# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
|
|
|
"""Unit tests for the Atlas Stream Processing driver module.
|
|
|
|
All tests run offline — no live workspace is required. The wire layer is
stubbed out by replacing ``_command`` on the client with a lightweight spy
that records calls and returns pre-configured responses.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import inspect
|
|
import sys
|
|
import unittest
|
|
from datetime import datetime, timezone
|
|
from typing import Any, Mapping, Optional
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
sys.path[0:0] = [""]
|
|
|
|
_IS_SYNC = False
|
|
|
|
import pymongo.asynchronous.stream_processing
|
|
from bson import Timestamp
|
|
from pymongo import (
|
|
AsyncSampleCursor,
|
|
AsyncStreamProcessingClient,
|
|
AsyncStreamProcessor,
|
|
AsyncStreamProcessors,
|
|
CreateStreamProcessorOptions,
|
|
GetStreamProcessorSamplesOptions,
|
|
GetStreamProcessorSamplesResult,
|
|
GetStreamProcessorStatsOptions,
|
|
StartStreamProcessorOptions,
|
|
StreamProcessorInfo,
|
|
)
|
|
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
|
|
|
# ---------------------------------------------------------------------------
# Spy helper
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _spy_client(
    responses: Optional[list[Mapping[str, Any]]] = None,
    raises: Optional[Exception] = None,
) -> tuple[AsyncStreamProcessingClient, list[dict[str, Any]]]:
    """Build a client whose ``_command`` is a recording stub.

    Returns the pair *(client, calls)*. Each invocation of ``client._command``
    appends a dict with the keys ``"cmd"``, ``"retryable_read"`` and
    ``"session"`` to *calls*.

    *responses* are handed out first-in, first-out from a private copy, so the
    caller's list is left untouched; once the copy is empty every further call
    gets ``{}``. When *raises* is given, each call is still recorded but then
    raises that exception instead of replying.
    """
    recorded: list[dict[str, Any]] = []
    queued: list[Mapping[str, Any]] = [{}] if responses is None else list(responses)

    async def _fake_command(
        cmd: dict[str, Any],
        *,
        retryable_read: bool = False,
        session: Any = None,
    ) -> Mapping[str, Any]:
        # Snapshot the command so later mutation by the caller cannot alter the record.
        recorded.append(
            {"cmd": dict(cmd), "retryable_read": retryable_read, "session": session}
        )
        if raises is not None:
            raise raises
        if not queued:
            return {}
        return queued.pop(0)

    # Bypass __init__ so no URI validation or real client construction happens.
    stub = AsyncStreamProcessingClient.__new__(AsyncStreamProcessingClient)
    stub._client = MagicMock()
    stub._command = _fake_command
    return stub, recorded
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Minimal info response fixture
# ---------------------------------------------------------------------------
|
|
|
|
_INFO_RESPONSE: dict[str, Any] = {
|
|
"ok": 1,
|
|
"id": "abc",
|
|
"name": "demo",
|
|
"state": "CREATED",
|
|
"pipeline": [],
|
|
"pipelineVersion": 1,
|
|
"enableAutoScaling": False,
|
|
"failoverEnabled": False,
|
|
"activeRegion": "us-east-1",
|
|
"hasStarted": False,
|
|
}
|
|
|
|
|
|
# ===========================================================================
# Test class 1 — client constructor validation
# ===========================================================================
|
|
|
|
|
|
class TestStreamProcessingClientConfig(unittest.TestCase):
    """Argument validation and defaulting done by the AsyncStreamProcessingClient constructor."""

    # Patch target for the underlying client class the constructor instantiates.
    _MONGO_CLIENT_TARGET = "pymongo.asynchronous.stream_processing.AsyncMongoClient"

    def _forwarded_kwargs(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
        """Construct the client with AsyncMongoClient mocked; return the kwargs it received."""
        with patch(self._MONGO_CLIENT_TARGET) as mongo_client_cls:
            mongo_client_cls.return_value = MagicMock()
            AsyncStreamProcessingClient(*args, **kwargs)
        return mongo_client_cls.call_args.kwargs

    def test_rejects_srv_uri(self) -> None:
        with self.assertRaises(ConfigurationError) as caught:
            AsyncStreamProcessingClient("mongodb+srv://example.com/")
        message = str(caught.exception).lower()
        self.assertIn("mongodb+srv", message)

    def test_rejects_tls_false_kwarg(self) -> None:
        self.assertRaises(
            ConfigurationError,
            AsyncStreamProcessingClient,
            "mongodb://example.com/",
            tls=False,
        )

    def test_rejects_ssl_false_kwarg(self) -> None:
        self.assertRaises(
            ConfigurationError,
            AsyncStreamProcessingClient,
            "mongodb://example.com/",
            ssl=False,
        )

    def test_rejects_tls_false_in_uri(self) -> None:
        self.assertRaises(
            ConfigurationError,
            AsyncStreamProcessingClient,
            "mongodb://example.com/?tls=false",
        )

    def test_forces_tls_true(self) -> None:
        forwarded = self._forwarded_kwargs("mongodb://example.com/")
        self.assertTrue(forwarded.get("tls"))

    def test_defaults_authsource_to_admin(self) -> None:
        forwarded = self._forwarded_kwargs("mongodb://example.com/")
        self.assertEqual(forwarded.get("authSource"), "admin")

    def test_preserves_explicit_authsource_in_kwarg(self) -> None:
        forwarded = self._forwarded_kwargs("mongodb://example.com/", authSource="other")
        # The default must never clobber a value the caller passed explicitly.
        self.assertEqual(forwarded.get("authSource"), "other")

    def test_preserves_explicit_authsource_in_uri(self) -> None:
        forwarded = self._forwarded_kwargs("mongodb://example.com/?authSource=other")
        # The URI already carries authSource, so no kwarg of that name may be added.
        self.assertNotIn("authSource", forwarded)

    def test_drops_ssl_kwarg_to_avoid_duplicate(self) -> None:
        forwarded = self._forwarded_kwargs("mongodb://example.com/", ssl=True)
        self.assertNotIn("ssl", forwarded)
        self.assertTrue(forwarded.get("tls"))

    def test_workspace_endpoint_detection(self) -> None:
        from pymongo.asynchronous.stream_processing import _is_workspace_endpoint

        workspace_hosts = (
            "atlas-stream-foo.virginia-usa.a.query.mongodb.net",
            "something.a.query.mongodb.net",
        )
        other_hosts = ("cluster0.mongodb.net", "localhost")

        for host in workspace_hosts:
            self.assertTrue(_is_workspace_endpoint(host))
        for host in other_hosts:
            self.assertFalse(_is_workspace_endpoint(host))
|
|
|
|
|
|
# ===========================================================================
# Test class 2 — options dataclass validation
# ===========================================================================
|
|
|
|
|
|
class TestStreamProcessingOptions(unittest.TestCase):
    """Construction-time validation of the options classes and StreamProcessorInfo decoding."""

    def test_start_options_mutex_start_after_and_at_op_time(self) -> None:
        self.assertRaises(
            InvalidOperation,
            StartStreamProcessorOptions,
            start_after={"_id": 1},
            start_at_operation_time=Timestamp(0, 0),
        )

    def test_start_options_invalid_tier(self) -> None:
        self.assertRaises(InvalidOperation, StartStreamProcessorOptions, tier="SP1")
        for accepted in ("SP2", "SP5", "SP10", "SP30", "SP50"):
            StartStreamProcessorOptions(tier=accepted)  # must not raise

    def test_start_options_invalid_workers(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                StartStreamProcessorOptions(workers=rejected)
        StartStreamProcessorOptions(workers=1)  # must not raise

    def test_stats_options_invalid_scale(self) -> None:
        for rejected in (0, -1):
            with self.assertRaises(InvalidOperation):
                GetStreamProcessorStatsOptions(scale=rejected)
        for accepted in (1, 1024):
            GetStreamProcessorStatsOptions(scale=accepted)

    def test_samples_options_invalid_cursor_id(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, cursor_id=-1)
        # Zero passes construction; it is rejected later, when the method is called.
        for accepted in (0, 42):
            GetStreamProcessorSamplesOptions(cursor_id=accepted)

    def test_samples_options_invalid_limit(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, limit=-1)
        for accepted in (0, 10):
            GetStreamProcessorSamplesOptions(limit=accepted)

    def test_samples_options_invalid_batch_size(self) -> None:
        self.assertRaises(InvalidOperation, GetStreamProcessorSamplesOptions, batch_size=-1)
        for accepted in (0, 5):
            GetStreamProcessorSamplesOptions(batch_size=accepted)

    def test_create_options_all_optional(self) -> None:
        defaults = CreateStreamProcessorOptions()
        for field_name in ("dlq", "stream_meta_field_name", "tier", "failover"):
            self.assertIsNone(getattr(defaults, field_name))

    def test_stream_processor_info_from_response_full(self) -> None:
        stamp = datetime.now(tz=timezone.utc)
        doc: dict[str, Any] = {
            "ok": 1,
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipelineVersion": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "streamMetaFieldName": "ts",
            "enableAutoScaling": True,
            "failoverEnabled": True,
            "activeRegion": "us-east-1",
            "lastModifiedAt": stamp,
            "modifiedBy": "user@example.com",
            "lastStateChange": stamp,
            "lastHeartbeat": stamp,
            "hasStarted": True,
            "stats": {"inputMessageCount": 42},
            "errorMsg": None,
            "errorCode": None,
            "errorRetryable": None,
        }
        info = StreamProcessorInfo.from_response(doc)

        # Attributes compared for equality against the decoded value.
        expected_values: dict[str, Any] = {
            "id": "abc123",
            "name": "demo",
            "state": "STARTED",
            "pipeline": [{"$source": {}}],
            "pipeline_version": 3,
            "tier": "SP10",
            "dlq": {"connectionName": "c1"},
            "stream_meta_field_name": "ts",
            "active_region": "us-east-1",
            "stats": {"inputMessageCount": 42},
            "raw": doc,
        }
        for attr, wanted in expected_values.items():
            self.assertEqual(getattr(info, attr), wanted)

        # Boolean flags only need to be truthy.
        for flag in ("enable_auto_scaling", "failover_enabled", "has_started"):
            self.assertTrue(getattr(info, flag))

    def test_stream_processor_info_from_response_minimal(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        info = StreamProcessorInfo.from_response(doc)
        self.assertEqual(info.id, "x")
        for absent in ("tier", "dlq"):
            self.assertIsNone(getattr(info, absent))
        for flag in ("enable_auto_scaling", "has_started"):
            self.assertFalse(getattr(info, flag))
        self.assertEqual(info.raw, doc)

    def test_stream_processor_info_preserves_unknown_fields(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "CREATED",
            "pipeline": [],
            "pipelineVersion": 1,
            "futureField": "someValue",
        }
        decoded = StreamProcessorInfo.from_response(doc)
        self.assertEqual(decoded.raw["futureField"], "someValue")

    def test_stream_processor_info_state_is_string(self) -> None:
        doc: dict[str, Any] = {
            "id": "x",
            "name": "p",
            "state": "FUTURE_UNKNOWN_STATE",
            "pipeline": [],
            "pipelineVersion": 1,
        }
        decoded = StreamProcessorInfo.from_response(doc)
        self.assertEqual(decoded.state, "FUTURE_UNKNOWN_STATE")
|
|
|
|
|
|
# ===========================================================================
# Test class 3 — lifecycle commands
# ===========================================================================
|
|
|
|
|
|
class AsyncTestStreamProcessorsCommands(unittest.IsolatedAsyncioTestCase):
    """Command documents sent by AsyncStreamProcessors and AsyncStreamProcessor."""

    async def asyncSetUp(self) -> None:
        # A generous queue of acknowledgements so no test in this class runs dry.
        acks = [{"ok": 1}] * 30
        self.client, self.calls = _spy_client(responses=acks)
        self.sps = AsyncStreamProcessors(self.client)
        self.proc = AsyncStreamProcessor(client=self.client, name="demo")

    async def _assert_create_rejected(self, name: str, pipeline: list[Any]) -> None:
        """Expect ``create`` to raise InvalidOperation without sending any command."""
        with self.assertRaises(InvalidOperation):
            await self.sps.create(name, pipeline=pipeline)
        self.assertEqual(len(self.calls), 0)

    async def test_create_sends_correct_command(self) -> None:
        await self.sps.create("demo", pipeline=[{"$source": {}}])
        sent = self.calls[-1]
        command = sent["cmd"]
        self.assertEqual(command.get("createStreamProcessor"), "demo")
        self.assertEqual(command.get("pipeline"), [{"$source": {}}])
        self.assertNotIn("options", command)
        self.assertFalse(sent["retryable_read"])

    async def test_create_with_options_includes_them(self) -> None:
        create_opts = CreateStreamProcessorOptions(dlq={"connectionName": "c"}, tier="SP10")
        await self.sps.create("demo", pipeline=[{"$source": {}}], options=create_opts)
        sent_options = self.calls[-1]["cmd"]["options"]
        self.assertEqual(sent_options["dlq"], {"connectionName": "c"})
        self.assertEqual(sent_options["tier"], "SP10")

    async def test_create_rejects_empty_name(self) -> None:
        await self._assert_create_rejected("", [{"$x": 1}])

    async def test_create_rejects_whitespace_name(self) -> None:
        await self._assert_create_rejected(" ", [{"$x": 1}])

    async def test_create_rejects_empty_pipeline(self) -> None:
        await self._assert_create_rejected("demo", [])

    async def test_get_returns_handle_without_calling_server(self) -> None:
        handle = self.sps.get("demo")
        self.assertEqual(len(self.calls), 0)
        self.assertEqual(handle.name, "demo")
        self.assertIsInstance(handle, AsyncStreamProcessor)

    async def test_get_rejects_empty_name(self) -> None:
        self.assertRaises(InvalidOperation, self.sps.get, "")

    async def test_get_info_sends_correct_command_and_decodes(self) -> None:
        client, calls = _spy_client(responses=[dict(_INFO_RESPONSE)])
        info = await AsyncStreamProcessors(client).get_info("demo")
        first_call = calls[0]
        self.assertEqual(first_call["cmd"].get("getStreamProcessor"), "demo")
        self.assertTrue(first_call["retryable_read"])
        self.assertIsInstance(info, StreamProcessorInfo)
        self.assertEqual(info.id, "abc")
        self.assertEqual(info.state, "CREATED")

    async def test_get_info_preserves_unknown_response_fields(self) -> None:
        reply = dict(_INFO_RESPONSE, futureField="value")
        client, _ = _spy_client(responses=[reply])
        info = await AsyncStreamProcessors(client).get_info("demo")
        self.assertEqual(info.raw["futureField"], "value")

    async def test_start_sends_correct_command(self) -> None:
        await self.proc.start()
        sent = self.calls[-1]
        command = sent["cmd"]
        self.assertEqual(command.get("startStreamProcessor"), "demo")
        self.assertNotIn("workers", command)
        self.assertNotIn("options", command)
        self.assertFalse(sent["retryable_read"])

    async def test_start_with_options_includes_them(self) -> None:
        start_opts = StartStreamProcessorOptions(workers=2, clear_checkpoints=True, tier="SP10")
        await self.proc.start(options=start_opts)
        command = self.calls[-1]["cmd"]
        # ``workers`` sits at the top level of the command; the rest is under "options".
        self.assertEqual(command.get("workers"), 2)
        self.assertTrue(command["options"]["clearCheckpoints"])
        self.assertEqual(command["options"]["tier"], "SP10")

    async def test_stop_sends_correct_command(self) -> None:
        await self.proc.stop()
        sent = self.calls[-1]
        self.assertEqual(sent["cmd"].get("stopStreamProcessor"), "demo")
        self.assertFalse(sent["retryable_read"])

    async def test_drop_sends_correct_command(self) -> None:
        await self.proc.drop()
        sent = self.calls[-1]
        self.assertEqual(sent["cmd"].get("dropStreamProcessor"), "demo")
        self.assertFalse(sent["retryable_read"])

    async def test_stats_sends_correct_command_and_returns_raw_dict(self) -> None:
        reply: dict[str, Any] = {
            "ok": 1,
            "stats": {"inputMessageCount": 5},
            "futureField": "x",
        }
        client, calls = _spy_client(responses=[reply])
        processor = AsyncStreamProcessor(client=client, name="demo")
        outcome = await processor.stats()
        self.assertEqual(calls[0]["cmd"].get("getStreamProcessorStats"), "demo")
        self.assertTrue(calls[0]["retryable_read"])
        self.assertEqual(outcome, reply)

    async def test_stats_with_options(self) -> None:
        client, calls = _spy_client(responses=[{"ok": 1}])
        stats_opts = GetStreamProcessorStatsOptions(scale=1024, verbose=True)
        await AsyncStreamProcessor(client=client, name="demo").stats(options=stats_opts)
        sent_options = calls[0]["cmd"]["options"]
        self.assertEqual(sent_options["scale"], 1024)
        self.assertTrue(sent_options["verbose"])

    async def test_session_propagates(self) -> None:
        session_stub = MagicMock()
        await self.proc.start(session=session_stub)
        self.assertIs(self.calls[-1]["session"], session_stub)
|
|
|
|
|
|
# ===========================================================================
# Test class 4 — sample cursor
# ===========================================================================
|
|
|
|
|
|
class AsyncTestSampleCursor(unittest.IsolatedAsyncioTestCase):
    """Start / get-more sampling protocol and the AsyncSampleCursor wrapper."""

    @staticmethod
    def _processor(
        responses: Optional[list[Mapping[str, Any]]] = None,
    ) -> tuple[AsyncStreamProcessor, list[dict[str, Any]]]:
        """Return a "demo" processor on a fresh spy client, plus the spy's call log."""
        client, calls = _spy_client(responses=responses)
        return AsyncStreamProcessor(client=client, name="demo"), calls

    async def test_get_samples_initial_call_sends_start_command(self) -> None:
        proc, calls = self._processor([{"cursorId": 42, "firstBatch": [{"x": 1}]}])
        result = await proc.get_stream_processor_samples()
        first_call = calls[0]
        command = first_call["cmd"]
        self.assertIn("startSampleStreamProcessor", command)
        for absent_key in ("cursorId", "batchSize"):
            self.assertNotIn(absent_key, command)
        self.assertFalse(first_call["retryable_read"])
        self.assertEqual(result.cursor_id, 42)
        self.assertEqual(result.documents, [{"x": 1}])

    async def test_get_samples_initial_call_includes_limit(self) -> None:
        proc, calls = self._processor([{"cursorId": 1, "firstBatch": []}])
        await proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(limit=100))
        command = calls[0]["cmd"]
        self.assertEqual(command.get("limit"), 100)
        self.assertNotIn("batchSize", command)

    async def test_get_samples_continuation_sends_get_more(self) -> None:
        proc, calls = self._processor([{"cursorId": 0, "nextBatch": [{"y": 2}]}])
        resume = GetStreamProcessorSamplesOptions(cursor_id=42)
        result = await proc.get_stream_processor_samples(resume)
        command = calls[0]["cmd"]
        self.assertIn("getMoreSampleStreamProcessor", command)
        self.assertEqual(command.get("cursorId"), 42)
        self.assertNotIn("limit", command)
        self.assertEqual(result.cursor_id, 0)
        self.assertEqual(result.documents, [{"y": 2}])

    async def test_get_samples_continuation_includes_batch_size(self) -> None:
        proc, calls = self._processor([{"cursorId": 0, "nextBatch": []}])
        resume = GetStreamProcessorSamplesOptions(cursor_id=42, batch_size=5)
        await proc.get_stream_processor_samples(resume)
        self.assertEqual(calls[0]["cmd"].get("batchSize"), 5)

    async def test_get_samples_rejects_cursor_id_zero(self) -> None:
        proc, calls = self._processor()
        with self.assertRaises(InvalidOperation):
            await proc.get_stream_processor_samples(GetStreamProcessorSamplesOptions(cursor_id=0))
        self.assertEqual(len(calls), 0)

    async def test_sample_iterator_drains_single_batch(self) -> None:
        proc, calls = self._processor(
            [{"cursorId": 0, "firstBatch": [{"a": 1}, {"a": 2}, {"a": 3}]}]
        )
        seen = [doc async for doc in proc.sample()]
        self.assertEqual(seen, [{"a": 1}, {"a": 2}, {"a": 3}])
        self.assertEqual(len(calls), 1)

    async def test_sample_iterator_drains_multiple_batches(self) -> None:
        proc, calls = self._processor(
            [
                {"cursorId": 11, "firstBatch": [{"a": 1}]},
                {"cursorId": 22, "nextBatch": [{"a": 2}]},
                {"cursorId": 0, "nextBatch": [{"a": 3}]},
            ]
        )
        seen = [doc async for doc in proc.sample()]
        self.assertEqual(seen, [{"a": 1}, {"a": 2}, {"a": 3}])
        self.assertEqual(len(calls), 3)

    async def test_sample_iterator_handles_empty_batch_with_nonzero_cursor(self) -> None:
        proc, calls = self._processor(
            [
                {"cursorId": 11, "firstBatch": []},
                {"cursorId": 22, "nextBatch": []},
                {"cursorId": 0, "nextBatch": [{"a": 1}]},
            ]
        )
        seen = [doc async for doc in proc.sample()]
        self.assertEqual(seen, [{"a": 1}])
        self.assertEqual(len(calls), 3)

    async def test_sample_iterator_stops_after_cursor_id_zero(self) -> None:
        proc, calls = self._processor([{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        cursor = proc.sample()
        self.assertEqual(await cursor.__anext__(), {"a": 1})
        with self.assertRaises(StopAsyncIteration):
            await cursor.__anext__()
        # cursorId 0 came back on the first reply, so exactly one command was sent.
        self.assertEqual(len(calls), 1)

    async def test_sample_cursor_close_stops_iteration(self) -> None:
        proc, calls = self._processor()
        cursor = proc.sample()
        await cursor.close()
        with self.assertRaises(StopAsyncIteration):
            await cursor.__anext__()
        self.assertEqual(len(calls), 0)

    async def test_sample_cursor_alive_property(self) -> None:
        proc, _ = self._processor([{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        cursor = proc.sample()
        self.assertTrue(cursor.alive)
        self.assertEqual(await cursor.__anext__(), {"a": 1})
        # The reply that produced this document carried cursorId 0, so the cursor is spent.
        self.assertFalse(cursor.alive)

    async def test_sample_cursor_id_property(self) -> None:
        proc, _ = self._processor([{"cursorId": 42, "firstBatch": [{"a": 1}]}])
        cursor = proc.sample()
        self.assertIsNone(cursor.cursor_id)
        await cursor.__anext__()
        self.assertEqual(cursor.cursor_id, 42)

    async def test_sample_cursor_async_context_manager(self) -> None:
        proc, _ = self._processor()
        async with proc.sample() as cur:
            self.assertIsInstance(cur, AsyncSampleCursor)
        self.assertTrue(cur._closed)

    async def test_sample_cursor_session_propagates_to_refill(self) -> None:
        proc, calls = self._processor([{"cursorId": 0, "firstBatch": [{"a": 1}]}])
        session_stub = MagicMock()
        async for _ in proc.sample(session=session_stub):
            pass
        self.assertIs(calls[0]["session"], session_stub)

    def test_sample_cursor_does_not_inherit_from_cursor_base(self) -> None:
        self.assertEqual(AsyncSampleCursor.__bases__, (object,))
|
|
|
|
|
|
# ===========================================================================
# Test class 5 — error handling
# ===========================================================================
|
|
|
|
|
|
class AsyncTestErrorHandling(unittest.IsolatedAsyncioTestCase):
    """Server errors MUST propagate unchanged — no rewrapping, no filtering."""

    def _failure(self, code: int = 9) -> OperationFailure:
        """Build an OperationFailure carrying *code* in both its code and details."""
        details = {"errmsg": "test", "code": code}
        return OperationFailure("test error", code=code, details=details)

    async def _assert_propagates(self, fn: Any, expected_code: int) -> None:
        """Await ``fn()`` and require an OperationFailure whose code is *expected_code*."""
        with self.assertRaises(OperationFailure) as caught:
            await fn()
        self.assertEqual(caught.exception.code, expected_code)

    async def test_create_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        processors = AsyncStreamProcessors(client)
        await self._assert_propagates(
            lambda: processors.create("p", pipeline=[{"$x": 1}]), 9
        )

    async def test_start_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))
        processor = AsyncStreamProcessor(client=client, name="p")
        await self._assert_propagates(processor.start, 72)

    async def test_stop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(125))
        processor = AsyncStreamProcessor(client=client, name="p")
        await self._assert_propagates(processor.stop, 125)

    async def test_drop_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(1))
        processor = AsyncStreamProcessor(client=client, name="p")
        await self._assert_propagates(processor.drop, 1)

    async def test_get_info_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        processors = AsyncStreamProcessors(client)
        await self._assert_propagates(lambda: processors.get_info("p"), 9)

    async def test_stats_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(72))
        processor = AsyncStreamProcessor(client=client, name="p")
        await self._assert_propagates(processor.stats, 72)

    async def test_get_stream_processor_samples_propagates_operation_failure(self) -> None:
        client, _ = _spy_client(raises=self._failure(9))
        processor = AsyncStreamProcessor(client=client, name="p")
        await self._assert_propagates(processor.get_stream_processor_samples, 9)

    async def test_unknown_error_code_propagates(self) -> None:
        # 999 is outside the documented set of codes and must pass through all the same.
        client, _ = _spy_client(raises=self._failure(999))
        processors = AsyncStreamProcessors(client)
        await self._assert_propagates(
            lambda: processors.create("p", pipeline=[{"$x": 1}]), 999
        )

    def test_no_operation_failure_caught_in_module(self) -> None:
        """Structural: the module source contains no handler that could hide server errors."""
        forbidden_snippets = (
            "except OperationFailure",
            "except PyMongoError",
            "exc.code in",
        )
        for mod in (pymongo.asynchronous.stream_processing,):
            src = inspect.getsource(mod)
            for snippet in forbidden_snippets:
                self.assertNotIn(snippet, src, mod.__name__)
|
|
|
|
|
|
# ===========================================================================
# Test class 6 — spec compliance
# ===========================================================================
|
|
|
|
|
|
class AsyncTestSpecCompliance(unittest.IsolatedAsyncioTestCase):
    """Spec MUST-level compliance checks."""

    async def test_retryability_classification(self) -> None:
        """Each command must route as retryable or non-retryable per the spec.

        Every command is exercised inside its own ``subTest`` so that one
        misclassified command is reported individually and does not prevent
        the remaining commands from being checked.
        """
        cases: list[tuple[str, Any, bool]] = [
            # (label, coroutine-factory, expected_retryable_read)
            (
                "create",
                lambda c: AsyncStreamProcessors(c).create("demo", pipeline=[{"$x": 1}]),
                False,
            ),
            ("start", lambda c: AsyncStreamProcessor(client=c, name="demo").start(), False),
            ("stop", lambda c: AsyncStreamProcessor(client=c, name="demo").stop(), False),
            ("drop", lambda c: AsyncStreamProcessor(client=c, name="demo").drop(), False),
            ("get_info", lambda c: AsyncStreamProcessors(c).get_info("demo"), True),
            ("stats", lambda c: AsyncStreamProcessor(client=c, name="demo").stats(), True),
            (
                "get_samples_initial",
                lambda c: AsyncStreamProcessor(
                    client=c, name="demo"
                ).get_stream_processor_samples(),
                False,
            ),
            (
                "get_samples_continuation",
                lambda c: AsyncStreamProcessor(client=c, name="demo").get_stream_processor_samples(
                    GetStreamProcessorSamplesOptions(cursor_id=42)
                ),
                False,
            ),
        ]

        # Commands whose reply is decoded need a reply of the right shape;
        # everything else is satisfied by a bare acknowledgement.
        responses_for: dict[str, list[Mapping[str, Any]]] = {
            "get_info": [dict(_INFO_RESPONSE)],
            "get_samples_initial": [{"cursorId": 0, "firstBatch": []}],
            "get_samples_continuation": [{"cursorId": 0, "nextBatch": []}],
        }

        for label, make_coro, expected in cases:
            with self.subTest(command=label):
                resp = responses_for.get(label, [{"ok": 1}])
                client, calls = _spy_client(responses=resp)
                await make_coro(client)
                # Fail with a readable message rather than an IndexError if
                # the command never reached the wire layer.
                self.assertTrue(calls, f"no command was sent for '{label}'")
                self.assertEqual(
                    calls[-1]["retryable_read"],
                    expected,
                    f"retryability mismatch for '{label}'",
                )

    def test_async_sample_cursor_class_hierarchy(self) -> None:
        """AsyncSampleCursor must derive directly from ``object`` and nothing else."""
        self.assertEqual(AsyncSampleCursor.__bases__, (object,))
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|