PYTHON-5763 Improve async test coverage
This commit is contained in:
parent
13085ff679
commit
26f7a11253
@ -781,6 +781,227 @@ class AsyncTestBulk(AsyncBulkTestBase):
|
||||
self.assertEqual(6, result.inserted_count)
|
||||
self.assertEqual(6, await self.coll.count_documents({}))
|
||||
|
||||
async def test_bulk_write_with_comment(self):
|
||||
"""Test bulk write operations with comment parameter."""
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
UpdateOne({"x": 1}, {"$set": {"y": 1}}),
|
||||
DeleteOne({"x": 1}),
|
||||
]
|
||||
result = await self.coll.bulk_write(requests, comment="bulk_comment")
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
async def test_bulk_write_with_let(self):
|
||||
"""Test bulk write operations with let parameter."""
|
||||
if not async_client_context.version.at_least(5, 0):
|
||||
self.skipTest("let parameter requires MongoDB 5.0+")
|
||||
|
||||
await self.coll.insert_one({"x": 1})
|
||||
requests = [
|
||||
UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}),
|
||||
]
|
||||
result = await self.coll.bulk_write(requests, let={"targetVal": 1})
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
async def test_bulk_write_all_operation_types(self):
|
||||
"""Test bulk write with all operation types combined."""
|
||||
await self.coll.insert_many([{"x": i} for i in range(5)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 100}),
|
||||
UpdateOne({"x": 0}, {"$set": {"updated": True}}),
|
||||
UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}),
|
||||
ReplaceOne({"x": 3}, {"x": 3, "replaced": True}),
|
||||
DeleteOne({"x": 4}),
|
||||
DeleteMany({"x": {"$gt": 50}}),
|
||||
]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertGreaterEqual(result.modified_count, 1)
|
||||
self.assertGreaterEqual(result.deleted_count, 1)
|
||||
|
||||
async def test_bulk_write_unordered(self):
|
||||
"""Test unordered bulk write continues after error."""
|
||||
await self.coll.create_index([("x", 1)], unique=True)
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
InsertOne({"x": 1}), # Duplicate - will error
|
||||
InsertOne({"x": 2}),
|
||||
InsertOne({"x": 3}),
|
||||
]
|
||||
|
||||
with self.assertRaises(BulkWriteError) as ctx:
|
||||
await self.coll.bulk_write(requests, ordered=False)
|
||||
|
||||
# With unordered, should have inserted 3 documents
|
||||
self.assertEqual(3, ctx.exception.details["nInserted"])
|
||||
|
||||
async def test_bulk_write_ordered(self):
|
||||
"""Test ordered bulk write stops on first error."""
|
||||
await self.coll.create_index([("x", 1)], unique=True)
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
InsertOne({"x": 1}), # Duplicate - will error
|
||||
InsertOne({"x": 2}),
|
||||
InsertOne({"x": 3}),
|
||||
]
|
||||
|
||||
with self.assertRaises(BulkWriteError) as ctx:
|
||||
await self.coll.bulk_write(requests, ordered=True)
|
||||
|
||||
# With ordered, should have inserted only 1 document
|
||||
self.assertEqual(1, ctx.exception.details["nInserted"])
|
||||
|
||||
async def test_bulk_write_bypass_document_validation(self):
|
||||
"""Test bulk write with bypass_document_validation."""
|
||||
if not async_client_context.version.at_least(3, 2):
|
||||
self.skipTest("bypass_document_validation requires MongoDB 3.2+")
|
||||
|
||||
# Create collection with validator
|
||||
await self.coll.drop()
|
||||
await self.db.create_collection(
|
||||
self.coll.name, validator={"$jsonSchema": {"required": ["name"]}}
|
||||
)
|
||||
|
||||
# Without bypass, should fail
|
||||
with self.assertRaises(BulkWriteError):
|
||||
await self.coll.bulk_write([InsertOne({"x": 1})])
|
||||
|
||||
# With bypass, should succeed
|
||||
result = await self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True)
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
|
||||
async def test_bulk_write_result_properties(self):
|
||||
"""Test all BulkWriteResult properties."""
|
||||
await self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 2}),
|
||||
UpdateOne({"x": 1}, {"$set": {"updated": True}}),
|
||||
ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True),
|
||||
DeleteOne({"x": 1}),
|
||||
]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
|
||||
# Check all properties
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertGreaterEqual(result.matched_count, 0)
|
||||
self.assertGreaterEqual(result.modified_count, 0)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
self.assertIsInstance(result.upserted_count, int)
|
||||
self.assertIsInstance(result.upserted_ids, dict)
|
||||
|
||||
async def test_bulk_write_with_upsert(self):
|
||||
"""Test bulk write upsert operations."""
|
||||
requests = [
|
||||
UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True),
|
||||
UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True),
|
||||
ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True),
|
||||
]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
|
||||
self.assertEqual(3, result.upserted_count)
|
||||
self.assertEqual(3, len(result.upserted_ids))
|
||||
|
||||
async def test_update_one_with_hint(self):
|
||||
"""Test UpdateOne with hint parameter."""
|
||||
await self.coll.create_index([("x", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
await self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
async def test_update_many_with_hint(self):
|
||||
"""Test UpdateMany with hint parameter."""
|
||||
await self.coll.create_index([("x", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
await self.coll.insert_many([{"x": 1}, {"x": 1}])
|
||||
|
||||
requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(2, result.modified_count)
|
||||
|
||||
async def test_delete_one_with_hint(self):
|
||||
"""Test DeleteOne with hint parameter."""
|
||||
await self.coll.create_index([("x", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
await self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [DeleteOne({"x": 1}, hint=[("x", 1)])]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
async def test_delete_many_with_hint(self):
|
||||
"""Test DeleteMany with hint parameter."""
|
||||
await self.coll.create_index([("x", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
await self.coll.insert_many([{"x": 1}, {"x": 1}])
|
||||
|
||||
requests = [DeleteMany({"x": 1}, hint=[("x", 1)])]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(2, result.deleted_count)
|
||||
|
||||
async def test_update_one_with_array_filters(self):
|
||||
"""Test UpdateOne with array_filters parameter."""
|
||||
await self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]})
|
||||
|
||||
requests = [
|
||||
UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}])
|
||||
]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
doc = await self.coll.find_one()
|
||||
# Elements with y > 1 should have z = 1
|
||||
for elem in doc["x"]:
|
||||
if elem["y"] > 1:
|
||||
self.assertEqual(1, elem.get("z"))
|
||||
|
||||
async def test_replace_one_with_hint(self):
|
||||
"""Test ReplaceOne with hint parameter."""
|
||||
await self.coll.create_index([("x", 1)])
|
||||
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
await self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
async def test_update_with_collation(self):
|
||||
"""Test update operations with collation."""
|
||||
await self.coll.insert_many(
|
||||
[
|
||||
{"name": "cafe"},
|
||||
{"name": "Cafe"},
|
||||
]
|
||||
)
|
||||
|
||||
requests = [
|
||||
UpdateMany(
|
||||
{"name": "cafe"},
|
||||
{"$set": {"updated": True}},
|
||||
collation={"locale": "en", "strength": 2},
|
||||
)
|
||||
]
|
||||
result = await self.coll.bulk_write(requests)
|
||||
# With case-insensitive collation, both docs should match
|
||||
self.assertEqual(2, result.modified_count)
|
||||
|
||||
|
||||
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
|
||||
@async_client_context.require_auth
|
||||
|
||||
@ -1152,5 +1152,122 @@ globals().update(
|
||||
)
|
||||
|
||||
|
||||
class AsyncTestChangeStreamCoverage(TestAsyncCollectionAsyncChangeStream):
|
||||
"""Additional tests to improve code coverage for AsyncChangeStream."""
|
||||
|
||||
async def test_change_stream_alive_property(self):
|
||||
"""Test alive property state transitions."""
|
||||
async with await self.change_stream() as cs:
|
||||
self.assertTrue(cs.alive)
|
||||
# After context exit, should be closed
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
async def test_change_stream_idempotent_close(self):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
cs = await self.change_stream()
|
||||
await cs.close()
|
||||
# Second close should not raise
|
||||
await cs.close()
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
async def test_change_stream_resume_token_deepcopy(self):
|
||||
"""Test that resume_token returns a deep copy."""
|
||||
coll = self.watched_collection()
|
||||
async with await self.change_stream() as cs:
|
||||
await coll.insert_one({"x": 1})
|
||||
await anext(cs) # Consume the change event
|
||||
token1 = cs.resume_token
|
||||
token2 = cs.resume_token
|
||||
# Should be equal but different objects
|
||||
self.assertEqual(token1, token2)
|
||||
self.assertIsNot(token1, token2)
|
||||
|
||||
async def test_change_stream_with_comment(self):
|
||||
"""Test change stream with comment parameter."""
|
||||
client, listener = await self.client_with_listener("aggregate")
|
||||
try:
|
||||
async with await self.change_stream_with_client(client, comment="test_comment"):
|
||||
pass
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
# Check that comment was in the aggregate command
|
||||
self.assertGreater(len(listener.started_events), 0)
|
||||
cmd = listener.started_events[0].command
|
||||
self.assertEqual("test_comment", cmd.get("comment"))
|
||||
|
||||
async def test_change_stream_with_show_expanded_events(self):
|
||||
"""Test change stream with show_expanded_events parameter."""
|
||||
if not async_client_context.version.at_least(6, 0):
|
||||
self.skipTest("show_expanded_events requires MongoDB 6.0+")
|
||||
|
||||
async with await self.change_stream(show_expanded_events=True) as cs:
|
||||
# Just verify it doesn't error
|
||||
self.assertTrue(cs.alive)
|
||||
|
||||
@async_client_context.require_version_min(6, 0)
|
||||
async def test_change_stream_with_full_document_before_change(self):
|
||||
"""Test change stream with full_document_before_change parameter."""
|
||||
coll = self.watched_collection()
|
||||
# Need to ensure collection exists with changeStreamPreAndPostImages enabled
|
||||
await coll.drop()
|
||||
await self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True})
|
||||
await coll.insert_one({"x": 1})
|
||||
|
||||
async with await self.change_stream(full_document_before_change="whenAvailable") as cs:
|
||||
await coll.update_one({"x": 1}, {"$set": {"x": 2}})
|
||||
change = await anext(cs)
|
||||
self.assertEqual("update", change["operationType"])
|
||||
# fullDocumentBeforeChange should be present
|
||||
self.assertIn("fullDocumentBeforeChange", change)
|
||||
|
||||
async def test_change_stream_next_after_close(self):
|
||||
"""Test that next() on closed stream raises StopAsyncIteration."""
|
||||
cs = await self.change_stream()
|
||||
await cs.close()
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await anext(cs)
|
||||
|
||||
async def test_change_stream_try_next_after_close(self):
|
||||
"""Test that try_next() on closed stream raises StopAsyncIteration."""
|
||||
cs = await self.change_stream()
|
||||
await cs.close()
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await cs.try_next()
|
||||
|
||||
async def test_change_stream_pipeline_construction(self):
|
||||
"""Test change stream pipeline is properly constructed."""
|
||||
pipeline = [{"$match": {"operationType": "insert"}}]
|
||||
client, listener = await self.client_with_listener("aggregate")
|
||||
try:
|
||||
async with await self.change_stream_with_client(client, pipeline=pipeline):
|
||||
pass
|
||||
finally:
|
||||
await client.close()
|
||||
|
||||
cmd = listener.started_events[0].command
|
||||
agg_pipeline = cmd["pipeline"]
|
||||
# First stage should be $changeStream
|
||||
self.assertIn("$changeStream", agg_pipeline[0])
|
||||
# Second stage should be our match
|
||||
self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1])
|
||||
|
||||
async def test_change_stream_empty_pipeline(self):
|
||||
"""Test change stream with empty pipeline."""
|
||||
async with await self.change_stream(pipeline=[]) as cs:
|
||||
self.assertTrue(cs.alive)
|
||||
|
||||
async def test_change_stream_context_manager_exception(self):
|
||||
"""Test change stream context manager closes on exception."""
|
||||
cs = None
|
||||
try:
|
||||
async with await self.change_stream() as cs:
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
# Stream should be closed
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -2260,5 +2260,264 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await helper(*args, let={}) # type: ignore
|
||||
|
||||
|
||||
class AsyncTestCollectionCoverage(AsyncIntegrationTest):
|
||||
"""Additional tests to improve code coverage for AsyncCollection."""
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
|
||||
|
||||
async def test_collection_full_name(self):
|
||||
"""Test full_name property."""
|
||||
expected = f"{self.db.name}.test"
|
||||
self.assertEqual(expected, self.db.test.full_name)
|
||||
|
||||
async def test_collection_name(self):
|
||||
"""Test name property."""
|
||||
self.assertEqual("test", self.db.test.name)
|
||||
|
||||
async def test_collection_database(self):
|
||||
"""Test database property."""
|
||||
self.assertEqual(self.db, self.db.test.database)
|
||||
|
||||
async def test_collection_equality(self):
|
||||
"""Test collection equality."""
|
||||
coll1 = self.db.test
|
||||
coll2 = self.db.test
|
||||
coll3 = self.db.other
|
||||
self.assertEqual(coll1, coll2)
|
||||
self.assertNotEqual(coll1, coll3)
|
||||
|
||||
async def test_collection_hash(self):
|
||||
"""Test collection hashability."""
|
||||
coll1 = self.db.test
|
||||
coll2 = self.db.test
|
||||
# Same collection should have same hash
|
||||
self.assertEqual(hash(coll1), hash(coll2))
|
||||
# Collections can be used in sets
|
||||
s = {coll1, coll2}
|
||||
self.assertEqual(1, len(s))
|
||||
|
||||
async def test_collection_repr(self):
|
||||
"""Test collection repr."""
|
||||
coll = self.db.test
|
||||
repr_str = repr(coll)
|
||||
self.assertIn("test", repr_str)
|
||||
self.assertIn("AsyncCollection", repr_str)
|
||||
|
||||
async def test_collection_getattr(self):
|
||||
"""Test sub-collection access via attribute."""
|
||||
subcoll = self.db.test.subcollection
|
||||
self.assertEqual("test.subcollection", subcoll.name)
|
||||
|
||||
async def test_collection_getitem(self):
|
||||
"""Test sub-collection access via indexing."""
|
||||
subcoll = self.db.test["subcollection"]
|
||||
self.assertEqual("test.subcollection", subcoll.name)
|
||||
|
||||
async def test_collection_with_options(self):
|
||||
"""Test with_options creates new collection with options."""
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
coll = self.db.test.with_options(
|
||||
read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1)
|
||||
)
|
||||
self.assertEqual("majority", coll.read_concern.level)
|
||||
self.assertEqual({"w": 1}, coll.write_concern.document)
|
||||
# Original should be unchanged
|
||||
self.assertNotEqual("majority", self.db.test.read_concern.level)
|
||||
|
||||
async def test_collection_drop(self):
|
||||
"""Test collection drop."""
|
||||
await self.db.test_drop.insert_one({"x": 1})
|
||||
await self.db.test_drop.drop()
|
||||
names = await self.db.list_collection_names()
|
||||
self.assertNotIn("test_drop", names)
|
||||
|
||||
async def test_collection_drop_with_comment(self):
|
||||
"""Test collection drop with comment."""
|
||||
await self.db.test_drop_comment.insert_one({"x": 1})
|
||||
await self.db.test_drop_comment.drop(comment="test_comment")
|
||||
names = await self.db.list_collection_names()
|
||||
self.assertNotIn("test_drop_comment", names)
|
||||
|
||||
async def test_find_raw_batches(self):
|
||||
"""Test find_raw_batches returns raw BSON."""
|
||||
from bson import decode_all
|
||||
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5)
|
||||
batch_count = 0
|
||||
async for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
async def test_aggregate_raw_batches(self):
|
||||
"""Test aggregate_raw_batches returns raw BSON."""
|
||||
from bson import decode_all
|
||||
|
||||
cursor = await self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5)
|
||||
batch_count = 0
|
||||
async for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
async def test_distinct_with_collation(self):
|
||||
"""Test distinct with collation."""
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many(
|
||||
[
|
||||
{"name": "abc"},
|
||||
{"name": "ABC"},
|
||||
{"name": "def"},
|
||||
]
|
||||
)
|
||||
# Case-insensitive distinct
|
||||
values = await self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2})
|
||||
# abc and ABC should be considered the same
|
||||
self.assertEqual(2, len(values))
|
||||
|
||||
async def test_count_documents_with_options(self):
|
||||
"""Test count_documents with skip, limit, hint."""
|
||||
await self.db.test.create_index([("x", 1)])
|
||||
|
||||
count = await self.db.test.count_documents(
|
||||
{"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)]
|
||||
)
|
||||
self.assertEqual(5, count)
|
||||
|
||||
async def test_estimated_document_count(self):
|
||||
"""Test estimated_document_count."""
|
||||
count = await self.db.test.estimated_document_count()
|
||||
self.assertEqual(10, count)
|
||||
|
||||
async def test_estimated_document_count_with_options(self):
|
||||
"""Test estimated_document_count with maxTimeMS and comment."""
|
||||
count = await self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment")
|
||||
self.assertEqual(10, count)
|
||||
|
||||
async def test_find_one_and_delete_with_options(self):
|
||||
"""Test find_one_and_delete with projection, sort."""
|
||||
doc = await self.db.test.find_one_and_delete(
|
||||
{"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)]
|
||||
)
|
||||
self.assertEqual(9, doc["x"])
|
||||
self.assertNotIn("y", doc)
|
||||
|
||||
async def test_find_one_and_replace_with_options(self):
|
||||
"""Test find_one_and_replace with various options."""
|
||||
from pymongo import ReturnDocument
|
||||
|
||||
doc = await self.db.test.find_one_and_replace(
|
||||
{"x": 0},
|
||||
{"x": 0, "replaced": True},
|
||||
projection={"x": 1, "replaced": 1},
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
self.assertEqual(0, doc["x"])
|
||||
self.assertTrue(doc.get("replaced"))
|
||||
|
||||
async def test_find_one_and_update_with_options(self):
|
||||
"""Test find_one_and_update with various options."""
|
||||
from pymongo import ReturnDocument
|
||||
|
||||
doc = await self.db.test.find_one_and_update(
|
||||
{"x": 0},
|
||||
{"$set": {"updated": True}},
|
||||
projection={"x": 1, "updated": 1},
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
self.assertEqual(0, doc["x"])
|
||||
self.assertTrue(doc.get("updated"))
|
||||
|
||||
async def test_update_one_with_array_filters(self):
|
||||
"""Test update_one with array_filters."""
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]})
|
||||
|
||||
result = await self.db.test.update_one(
|
||||
{}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}]
|
||||
)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
async def test_update_many_with_hint(self):
|
||||
"""Test update_many with hint."""
|
||||
await self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = await self.db.test.update_many(
|
||||
{"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)]
|
||||
)
|
||||
self.assertEqual(10, result.modified_count)
|
||||
|
||||
async def test_delete_one_with_hint(self):
|
||||
"""Test delete_one with hint."""
|
||||
await self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = await self.db.test.delete_one({"x": 0}, hint=[("x", 1)])
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
async def test_delete_many_with_hint(self):
|
||||
"""Test delete_many with hint."""
|
||||
await self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = await self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)])
|
||||
self.assertEqual(5, result.deleted_count)
|
||||
|
||||
async def test_aggregate_with_let(self):
|
||||
"""Test aggregate with let parameter."""
|
||||
if not async_client_context.version.at_least(5, 0):
|
||||
self.skipTest("let parameter requires MongoDB 5.0+")
|
||||
|
||||
pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}]
|
||||
cursor = await self.db.test.aggregate(pipeline, let={"targetVal": 5})
|
||||
docs = await cursor.to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(5, docs[0]["x"])
|
||||
|
||||
async def test_aggregate_with_batch_size(self):
|
||||
"""Test aggregate with batchSize."""
|
||||
cursor = await self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2)
|
||||
docs = await cursor.to_list()
|
||||
self.assertEqual(10, len(docs))
|
||||
|
||||
async def test_list_indexes(self):
|
||||
"""Test list_indexes returns cursor."""
|
||||
await self.db.test.create_index([("x", 1)])
|
||||
cursor = await self.db.test.list_indexes()
|
||||
|
||||
# Should get at least the _id index
|
||||
indexes = await cursor.to_list()
|
||||
self.assertGreaterEqual(len(indexes), 1)
|
||||
index_names = [idx["name"] for idx in indexes]
|
||||
self.assertIn("_id_", index_names)
|
||||
|
||||
async def test_index_information(self):
|
||||
"""Test index_information returns dict."""
|
||||
await self.db.test.create_index([("x", 1)], name="x_index")
|
||||
info = await self.db.test.index_information()
|
||||
|
||||
self.assertIsInstance(info, dict)
|
||||
self.assertIn("_id_", info)
|
||||
self.assertIn("x_index", info)
|
||||
|
||||
async def test_options_method(self):
|
||||
"""Test options() returns collection options."""
|
||||
# Create a capped collection
|
||||
await self.db.drop_collection("test_capped")
|
||||
await self.db.create_collection("test_capped", capped=True, size=10000)
|
||||
|
||||
opts = await self.db.test_capped.options()
|
||||
self.assertTrue(opts.get("capped"))
|
||||
|
||||
await self.db.drop_collection("test_capped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1864,5 +1864,404 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
self.assertEqual(cmd.command["$db"], "pymongo_test")
|
||||
|
||||
|
||||
class AsyncTestCursorCoverage(AsyncIntegrationTest):
|
||||
"""Additional tests to improve code coverage for AsyncCursor."""
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
|
||||
|
||||
async def test_get_namespace(self):
|
||||
"""Test _get_namespace() method."""
|
||||
cursor = self.db.test.find()
|
||||
expected_ns = f"{self.db.name}.test"
|
||||
self.assertEqual(expected_ns, cursor._get_namespace())
|
||||
|
||||
async def test_cursor_alive_property_states(self):
|
||||
"""Test cursor alive property in different states."""
|
||||
cursor = self.db.test.find()
|
||||
# Cursor is alive even before starting (has potential to return data)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
# Start the cursor
|
||||
await anext(cursor)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
# Exhaust the cursor
|
||||
await cursor.to_list()
|
||||
self.assertFalse(cursor.alive)
|
||||
|
||||
async def test_cursor_closed_property(self):
|
||||
"""Test cursor behavior after close."""
|
||||
cursor = self.db.test.find()
|
||||
await anext(cursor)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
await cursor.close()
|
||||
# After close, cursor is killed (check internal _killed flag)
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
async def test_retrieved_property(self):
|
||||
"""Test the retrieved property tracking."""
|
||||
cursor = self.db.test.find().batch_size(2)
|
||||
self.assertEqual(0, cursor.retrieved)
|
||||
|
||||
await anext(cursor)
|
||||
self.assertGreater(cursor.retrieved, 0)
|
||||
|
||||
async def test_cursor_with_let_parameter(self):
|
||||
"""Test cursor with let parameter."""
|
||||
# let parameter allows variables to be used in the filter
|
||||
cursor = self.db.test.find(
|
||||
{"$expr": {"$eq": ["$x", "$$targetValue"]}}, let={"targetValue": 5}
|
||||
)
|
||||
docs = await cursor.to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(5, docs[0]["x"])
|
||||
|
||||
async def test_cursor_with_invalid_let_parameter(self):
|
||||
"""Test cursor raises error for invalid let parameter."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(let="invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_with_show_record_id(self):
|
||||
"""Test cursor with show_record_id option."""
|
||||
cursor = self.db.test.find(show_record_id=True)
|
||||
doc = await anext(cursor)
|
||||
self.assertIn("$recordId", doc)
|
||||
|
||||
async def test_cursor_with_return_key(self):
|
||||
"""Test cursor with return_key option."""
|
||||
await self.db.test.create_index([("x", ASCENDING)])
|
||||
cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)])
|
||||
doc = await anext(cursor)
|
||||
# return_key returns only index keys
|
||||
self.assertIn("x", doc)
|
||||
self.assertNotIn("y", doc)
|
||||
|
||||
async def test_check_okay_to_chain_after_iteration(self):
|
||||
"""Test that cursor configuration methods raise after iteration."""
|
||||
cursor = self.db.test.find()
|
||||
await anext(cursor) # Start iteration
|
||||
|
||||
# All these should raise InvalidOperation
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.limit(5)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.skip(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.sort("x")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.hint([("x", ASCENDING)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.max([("x", 10)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.min([("x", 0)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await cursor.add_option(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.remove_option(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.batch_size(10)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.max_time_ms(1000)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.collation(Collation("en_US"))
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.allow_disk_use(True)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.where("this.x > 5")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.comment("test")
|
||||
|
||||
async def test_cursor_context_manager(self):
|
||||
"""Test cursor as async context manager."""
|
||||
async with self.db.test.find() as cursor:
|
||||
doc = await anext(cursor)
|
||||
self.assertIsNotNone(doc)
|
||||
# Cursor should be killed after context (check _killed flag)
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
async def test_cursor_context_manager_with_exception(self):
|
||||
"""Test cursor context manager closes on exception."""
|
||||
cursor = None
|
||||
try:
|
||||
async with self.db.test.find() as cursor:
|
||||
await anext(cursor)
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
# Cursor should be killed after exception
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
async def test_cursor_collation(self):
|
||||
"""Test cursor with collation."""
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
|
||||
# Case-insensitive sort
|
||||
cursor = (
|
||||
self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING)
|
||||
)
|
||||
docs = await cursor.to_list()
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
async def test_cursor_collation_type_error(self):
|
||||
"""Test cursor raises error for invalid collation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find().collation("invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_getitem_not_supported(self):
|
||||
"""Test that AsyncCursor does not support indexing."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(IndexError) as ctx:
|
||||
cursor[5]
|
||||
self.assertIn("does not support indexing", str(ctx.exception))
|
||||
|
||||
async def test_cursor_next_after_close(self):
|
||||
"""Test that next() raises StopAsyncIteration after close."""
|
||||
cursor = self.db.test.find()
|
||||
await cursor.close()
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await anext(cursor)
|
||||
|
||||
async def test_cursor_rewind_resets_state(self):
|
||||
"""Test that rewind properly resets cursor state."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
|
||||
# Iterate fully
|
||||
docs1 = await cursor.to_list()
|
||||
self.assertEqual(3, len(docs1))
|
||||
self.assertEqual(0, len(cursor._data))
|
||||
|
||||
# Rewind and iterate again
|
||||
await cursor.rewind()
|
||||
docs2 = await cursor.to_list()
|
||||
self.assertEqual(3, len(docs2))
|
||||
self.assertEqual(docs1, docs2)
|
||||
|
||||
async def test_cursor_clone_with_session(self):
|
||||
"""Test that clone preserves explicit session."""
|
||||
async with self.client.start_session() as session:
|
||||
cursor = self.db.test.find(session=session)
|
||||
cloned = cursor.clone()
|
||||
# Clone should reference the same session
|
||||
self.assertEqual(cursor.session, cloned.session)
|
||||
|
||||
async def test_cursor_clone_without_session(self):
|
||||
"""Test that clone without session doesn't add one."""
|
||||
cursor = self.db.test.find()
|
||||
cloned = cursor.clone()
|
||||
# Clone should have no session if original had none
|
||||
self.assertIsNone(cloned.session)
|
||||
|
||||
async def test_cursor_distinct_with_collation(self):
|
||||
"""Test distinct with collation."""
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
|
||||
# Case-insensitive distinct
|
||||
cursor = self.db.test.find().collation(Collation("en_US", strength=2))
|
||||
# distinct() on cursor with collation
|
||||
values = await cursor.distinct("name")
|
||||
# Should have 2 distinct values (abc/ABC treated as same)
|
||||
self.assertEqual(2, len(values))
|
||||
|
||||
async def test_cursor_explain_with_options(self):
|
||||
"""Test explain with cursor options set."""
|
||||
cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1)
|
||||
explanation = await cursor.explain()
|
||||
self.assertIn("queryPlanner", explanation)
|
||||
|
||||
async def test_cursor_max_time_ms_type_errors(self):
|
||||
"""Test max_time_ms raises TypeError for invalid input."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.max_time_ms("invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_max_await_time_ms_type_errors(self):
|
||||
"""Test max_await_time_ms raises TypeError for invalid input."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.max_await_time_ms("invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_comment_type(self):
|
||||
"""Test cursor with comment of various types."""
|
||||
# String comment
|
||||
cursor1 = self.db.test.find().comment("test comment")
|
||||
docs1 = await cursor1.to_list()
|
||||
self.assertGreater(len(docs1), 0)
|
||||
|
||||
# Dict comment
|
||||
cursor2 = self.db.test.find().comment({"key": "value"})
|
||||
docs2 = await cursor2.to_list()
|
||||
self.assertGreater(len(docs2), 0)
|
||||
|
||||
async def test_cursor_batch_size_validation(self):
|
||||
"""Test batch_size validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(batch_size="invalid") # type: ignore[arg-type]
|
||||
with self.assertRaises(ValueError):
|
||||
self.db.test.find(batch_size=-1)
|
||||
|
||||
async def test_cursor_skip_validation(self):
|
||||
"""Test skip validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(skip="invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_limit_validation(self):
|
||||
"""Test limit validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(limit="invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_filter_validation(self):
|
||||
"""Test filter validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(filter="invalid") # type: ignore[arg-type]
|
||||
|
||||
async def test_cursor_type_validation(self):
|
||||
"""Test cursor_type validation."""
|
||||
with self.assertRaises(ValueError):
|
||||
self.db.test.find(cursor_type=999)
|
||||
|
||||
async def test_cursor_query_spec_with_modifiers(self):
|
||||
"""Test _query_spec includes modifiers."""
|
||||
cursor = (
|
||||
self.db.test.find()
|
||||
.sort("x", ASCENDING)
|
||||
.hint([("x", ASCENDING)])
|
||||
.max_time_ms(1000)
|
||||
.comment("test")
|
||||
)
|
||||
spec = cursor._query_spec()
|
||||
self.assertIsInstance(spec, dict)
|
||||
|
||||
async def test_cursor_copy(self):
|
||||
"""Test cursor __copy__ returns clone."""
|
||||
cursor = self.db.test.find().limit(5)
|
||||
copied = copy.copy(cursor)
|
||||
self.assertIsNot(cursor, copied)
|
||||
self.assertEqual(cursor._limit, copied._limit)
|
||||
|
||||
async def test_cursor_deepcopy(self):
|
||||
"""Test cursor __deepcopy__ returns deep clone."""
|
||||
cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5)
|
||||
copied = copy.deepcopy(cursor)
|
||||
self.assertIsNot(cursor, copied)
|
||||
self.assertEqual(cursor._limit, copied._limit)
|
||||
self.assertEqual(cursor._spec, copied._spec)
|
||||
# Spec should be a different object
|
||||
self.assertIsNot(cursor._spec, copied._spec)
|
||||
|
||||
async def test_cursor_iteration_protocol(self):
|
||||
"""Test cursor async iteration protocol."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
|
||||
# Test __aiter__ returns self
|
||||
self.assertIs(cursor, cursor.__aiter__())
|
||||
|
||||
# Test __anext__ returns documents
|
||||
doc1 = await cursor.__anext__()
|
||||
self.assertIsNotNone(doc1)
|
||||
|
||||
async def test_cursor_to_list_with_limit(self):
|
||||
"""Test to_list respects cursor limit."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
docs = await cursor.to_list()
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
async def test_cursor_to_list_with_length(self):
|
||||
"""Test to_list with length parameter."""
|
||||
cursor = self.db.test.find()
|
||||
docs = await cursor.to_list(length=3)
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
async def test_min_max_require_hint(self):
|
||||
"""Test that min/max require hint for proper execution."""
|
||||
await self.db.test.create_index([("x", ASCENDING)])
|
||||
|
||||
# min without hint should work when index exists
|
||||
cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)])
|
||||
docs = await cursor.to_list()
|
||||
self.assertTrue(all(doc["x"] >= 5 for doc in docs))
|
||||
|
||||
# max without hint should work when index exists
|
||||
cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)])
|
||||
docs = await cursor.to_list()
|
||||
self.assertTrue(all(doc["x"] < 5 for doc in docs))
|
||||
|
||||
async def test_cursor_address_property(self):
|
||||
"""Test cursor address is set after first batch."""
|
||||
cursor = self.db.test.find()
|
||||
self.assertIsNone(cursor.address)
|
||||
await anext(cursor)
|
||||
# Address should be set after query
|
||||
self.assertIsNotNone(cursor.address)
|
||||
|
||||
async def test_cursor_session_property(self):
|
||||
"""Test cursor session property."""
|
||||
# Cursor without explicit session
|
||||
cursor1 = self.db.test.find()
|
||||
self.assertIsNone(cursor1.session)
|
||||
|
||||
# Cursor with explicit session
|
||||
async with self.client.start_session() as session:
|
||||
cursor2 = self.db.test.find(session=session)
|
||||
self.assertEqual(session, cursor2.session)
|
||||
|
||||
async def test_cursor_allow_disk_use_type_error(self):
|
||||
"""Test allow_disk_use raises TypeError for invalid input."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type]
|
||||
|
||||
|
||||
class AsyncTestRawBatchCursorCoverage(AsyncIntegrationTest):
|
||||
"""Additional tests for AsyncRawBatchCursor coverage."""
|
||||
|
||||
async def asyncSetUp(self):
|
||||
await super().asyncSetUp()
|
||||
await self.db.test.drop()
|
||||
await self.db.test.insert_many([{"x": i} for i in range(20)])
|
||||
|
||||
async def test_raw_batch_cursor_iteration(self):
|
||||
"""Test raw batch cursor returns raw BSON."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5)
|
||||
batch_count = 0
|
||||
async for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
# Decode the batch to verify it's valid BSON
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
async def test_raw_batch_cursor_explain(self):
|
||||
"""Test raw batch cursor explain."""
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
explanation = await cursor.explain()
|
||||
self.assertIn("queryPlanner", explanation)
|
||||
|
||||
async def test_raw_batch_cursor_getitem_raises(self):
|
||||
"""Test raw batch cursor __getitem__ raises InvalidOperation."""
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor[0]
|
||||
|
||||
async def test_raw_batch_cursor_with_sort(self):
|
||||
"""Test raw batch cursor with sort."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING)
|
||||
first_batch = await anext(cursor)
|
||||
docs = decode_all(first_batch)
|
||||
# First doc should have highest x value
|
||||
self.assertEqual(19, docs[0]["x"])
|
||||
|
||||
async def test_raw_batch_cursor_with_limit(self):
|
||||
"""Test raw batch cursor with limit."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5).limit(7)
|
||||
all_docs = []
|
||||
async for batch in cursor:
|
||||
all_docs.extend(decode_all(batch))
|
||||
self.assertEqual(7, len(all_docs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -53,6 +53,8 @@ from pymongo.common import _MAX_END_SESSIONS
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import IndexModel, InsertOne, UpdateOne
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -1361,5 +1363,284 @@ class TestClusterTime(AsyncIntegrationTest):
|
||||
self.assertEqual(started.command["$clusterTime"], cluster_time)
|
||||
|
||||
|
||||
class AsyncTestClientSessionCoverage(AsyncIntegrationTest):
|
||||
"""Additional tests to improve code coverage for AsyncClientSession."""
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_has_ended_property(self):
|
||||
"""Test has_ended property state transitions."""
|
||||
session = self.client.start_session()
|
||||
self.assertFalse(session.has_ended)
|
||||
await session.end_session()
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_session_id_property(self):
|
||||
"""Test session_id property returns correct value."""
|
||||
async with self.client.start_session() as session:
|
||||
session_id = session.session_id
|
||||
self.assertIsInstance(session_id, dict)
|
||||
self.assertIn("id", session_id)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_cluster_time_operations(self):
|
||||
"""Test cluster time advance operations."""
|
||||
async with self.client.start_session() as session:
|
||||
# Initially None
|
||||
self.assertIsNone(session.cluster_time)
|
||||
|
||||
# Perform operation to get cluster time
|
||||
await self.db.test.find_one({}, session=session)
|
||||
|
||||
# Cluster time should be set after operation
|
||||
# (may still be None on some server versions)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_operation_time_operations(self):
|
||||
"""Test operation time advance operations."""
|
||||
async with self.client.start_session() as session:
|
||||
# Initially None
|
||||
self.assertIsNone(session.operation_time)
|
||||
|
||||
# Perform operation to get operation time
|
||||
await self.db.test.find_one({}, session=session)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_options_property(self):
|
||||
"""Test session options property."""
|
||||
async with self.client.start_session(causal_consistency=True) as session:
|
||||
self.assertTrue(session.options.causal_consistency)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_client_property(self):
|
||||
"""Test session client property."""
|
||||
async with self.client.start_session() as session:
|
||||
self.assertEqual(self.client, session.client)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_in_transaction_property(self):
|
||||
"""Test in_transaction property."""
|
||||
if async_client_context.is_rs or async_client_context.is_mongos:
|
||||
async with self.client.start_session() as session:
|
||||
self.assertFalse(session.in_transaction)
|
||||
session.start_transaction()
|
||||
self.assertTrue(session.in_transaction)
|
||||
await session.abort_transaction()
|
||||
self.assertFalse(session.in_transaction)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_context_manager(self):
|
||||
"""Test session async context manager."""
|
||||
async with self.client.start_session() as session:
|
||||
self.assertFalse(session.has_ended)
|
||||
await self.db.test.find_one({}, session=session)
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_context_manager_exception(self):
|
||||
"""Test session context manager closes on exception."""
|
||||
session = None
|
||||
try:
|
||||
async with self.client.start_session() as session:
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_operations_after_end(self):
|
||||
"""Test operations on ended session raise InvalidOperation."""
|
||||
session = self.client.start_session()
|
||||
await session.end_session()
|
||||
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await self.db.test.find_one({}, session=session)
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_end_session_idempotent(self):
|
||||
"""Test that end_session can be called multiple times."""
|
||||
session = self.client.start_session()
|
||||
await session.end_session()
|
||||
# Second call should not raise
|
||||
await session.end_session()
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_start_without_prior_transaction(self):
|
||||
"""Test start_transaction on fresh session."""
|
||||
async with self.client.start_session() as session:
|
||||
session.start_transaction()
|
||||
self.assertTrue(session.in_transaction)
|
||||
await session.abort_transaction()
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_start_twice_raises(self):
|
||||
"""Test starting transaction twice raises error."""
|
||||
async with self.client.start_session() as session:
|
||||
session.start_transaction()
|
||||
with self.assertRaises(InvalidOperation):
|
||||
session.start_transaction()
|
||||
await session.abort_transaction()
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_abort_without_transaction_raises(self):
|
||||
"""Test aborting without transaction raises error."""
|
||||
async with self.client.start_session() as session:
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await session.abort_transaction()
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_commit_without_transaction_raises(self):
|
||||
"""Test committing without transaction raises error."""
|
||||
async with self.client.start_session() as session:
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await session.commit_transaction()
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_advance_cluster_time_validation(self):
|
||||
"""Test advance_cluster_time with invalid input."""
|
||||
async with self.client.start_session() as session:
|
||||
with self.assertRaises(TypeError):
|
||||
session.advance_cluster_time("invalid") # type: ignore
|
||||
with self.assertRaises(ValueError):
|
||||
session.advance_cluster_time({})
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def test_session_advance_operation_time_validation(self):
|
||||
"""Test advance_operation_time with invalid input."""
|
||||
from bson import Timestamp
|
||||
|
||||
async with self.client.start_session() as session:
|
||||
with self.assertRaises(TypeError):
|
||||
session.advance_operation_time("invalid") # type: ignore
|
||||
# Valid Timestamp should work
|
||||
session.advance_operation_time(Timestamp(1, 1))
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_with_transaction_callback_success(self):
|
||||
"""Test with_transaction with successful callback."""
|
||||
async with self.client.start_session() as session:
|
||||
|
||||
async def callback(session):
|
||||
await self.db.test.insert_one({"x": 1}, session=session)
|
||||
return "success"
|
||||
|
||||
result = await session.with_transaction(callback)
|
||||
self.assertEqual("success", result)
|
||||
|
||||
@async_client_context.require_transactions
|
||||
async def test_with_transaction_callback_exception(self):
|
||||
"""Test with_transaction with callback exception."""
|
||||
async with self.client.start_session() as session:
|
||||
|
||||
async def callback(session):
|
||||
await self.db.test.insert_one({"x": 1}, session=session)
|
||||
raise ValueError("callback error")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
await session.with_transaction(callback)
|
||||
# Transaction should be aborted
|
||||
self.assertFalse(session.in_transaction)
|
||||
|
||||
|
||||
class AsyncTestSessionOptionsCoverage(AsyncUnitTest):
|
||||
"""Tests for SessionOptions coverage."""
|
||||
|
||||
def test_session_options_defaults(self):
|
||||
"""Test SessionOptions default values."""
|
||||
from pymongo.asynchronous.client_session import SessionOptions
|
||||
|
||||
options = SessionOptions()
|
||||
self.assertTrue(options.causal_consistency)
|
||||
self.assertIsNone(options.default_transaction_options)
|
||||
self.assertFalse(options.snapshot)
|
||||
|
||||
def test_session_options_snapshot_disables_causal_consistency(self):
|
||||
"""Test snapshot=True forces causal_consistency=False."""
|
||||
from pymongo.asynchronous.client_session import SessionOptions
|
||||
|
||||
options = SessionOptions(snapshot=True)
|
||||
self.assertFalse(options.causal_consistency)
|
||||
self.assertTrue(options.snapshot)
|
||||
|
||||
def test_session_options_snapshot_with_causal_raises(self):
|
||||
"""Test snapshot=True with causal_consistency=True raises error."""
|
||||
from pymongo.asynchronous.client_session import SessionOptions
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
SessionOptions(snapshot=True, causal_consistency=True)
|
||||
|
||||
def test_session_options_invalid_transaction_options(self):
|
||||
"""Test SessionOptions with invalid transaction options type."""
|
||||
from pymongo.asynchronous.client_session import SessionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
SessionOptions(default_transaction_options="invalid") # type: ignore
|
||||
|
||||
|
||||
class AsyncTestTransactionOptionsCoverage(AsyncUnitTest):
|
||||
"""Tests for TransactionOptions coverage."""
|
||||
|
||||
def test_transaction_options_defaults(self):
|
||||
"""Test TransactionOptions default values."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
options = TransactionOptions()
|
||||
self.assertIsNone(options.read_concern)
|
||||
self.assertIsNone(options.write_concern)
|
||||
self.assertIsNone(options.read_preference)
|
||||
self.assertIsNone(options.max_commit_time_ms)
|
||||
|
||||
def test_transaction_options_with_values(self):
|
||||
"""Test TransactionOptions with all values set."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
options = TransactionOptions(
|
||||
read_concern=ReadConcern("majority"),
|
||||
write_concern=WriteConcern(w="majority"),
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
max_commit_time_ms=5000,
|
||||
)
|
||||
self.assertEqual("majority", options.read_concern.level)
|
||||
self.assertEqual("majority", options.write_concern.document.get("w"))
|
||||
self.assertEqual(ReadPreference.PRIMARY, options.read_preference)
|
||||
self.assertEqual(5000, options.max_commit_time_ms)
|
||||
|
||||
def test_transaction_options_invalid_read_concern(self):
|
||||
"""Test TransactionOptions with invalid read_concern type."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(read_concern="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_write_concern(self):
|
||||
"""Test TransactionOptions with invalid write_concern type."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(write_concern="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_read_preference(self):
|
||||
"""Test TransactionOptions with invalid read_preference type."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(read_preference="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_max_commit_time(self):
|
||||
"""Test TransactionOptions with invalid max_commit_time_ms type."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(max_commit_time_ms="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_unacknowledged_write_concern(self):
|
||||
"""Test TransactionOptions rejects unacknowledged write concern."""
|
||||
from pymongo.asynchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
TransactionOptions(write_concern=WriteConcern(w=0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -779,6 +779,225 @@ class TestBulk(BulkTestBase):
|
||||
self.assertEqual(6, result.inserted_count)
|
||||
self.assertEqual(6, self.coll.count_documents({}))
|
||||
|
||||
def test_bulk_write_with_comment(self):
|
||||
"""Test bulk write operations with comment parameter."""
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
UpdateOne({"x": 1}, {"$set": {"y": 1}}),
|
||||
DeleteOne({"x": 1}),
|
||||
]
|
||||
result = self.coll.bulk_write(requests, comment="bulk_comment")
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
def test_bulk_write_with_let(self):
|
||||
"""Test bulk write operations with let parameter."""
|
||||
if not client_context.version.at_least(5, 0):
|
||||
self.skipTest("let parameter requires MongoDB 5.0+")
|
||||
|
||||
self.coll.insert_one({"x": 1})
|
||||
requests = [
|
||||
UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}),
|
||||
]
|
||||
result = self.coll.bulk_write(requests, let={"targetVal": 1})
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
def test_bulk_write_all_operation_types(self):
|
||||
"""Test bulk write with all operation types combined."""
|
||||
self.coll.insert_many([{"x": i} for i in range(5)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 100}),
|
||||
UpdateOne({"x": 0}, {"$set": {"updated": True}}),
|
||||
UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}),
|
||||
ReplaceOne({"x": 3}, {"x": 3, "replaced": True}),
|
||||
DeleteOne({"x": 4}),
|
||||
DeleteMany({"x": {"$gt": 50}}),
|
||||
]
|
||||
result = self.coll.bulk_write(requests)
|
||||
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertGreaterEqual(result.modified_count, 1)
|
||||
self.assertGreaterEqual(result.deleted_count, 1)
|
||||
|
||||
def test_bulk_write_unordered(self):
|
||||
"""Test unordered bulk write continues after error."""
|
||||
self.coll.create_index([("x", 1)], unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
InsertOne({"x": 1}), # Duplicate - will error
|
||||
InsertOne({"x": 2}),
|
||||
InsertOne({"x": 3}),
|
||||
]
|
||||
|
||||
with self.assertRaises(BulkWriteError) as ctx:
|
||||
self.coll.bulk_write(requests, ordered=False)
|
||||
|
||||
# With unordered, should have inserted 3 documents
|
||||
self.assertEqual(3, ctx.exception.details["nInserted"])
|
||||
|
||||
def test_bulk_write_ordered(self):
|
||||
"""Test ordered bulk write stops on first error."""
|
||||
self.coll.create_index([("x", 1)], unique=True)
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 1}),
|
||||
InsertOne({"x": 1}), # Duplicate - will error
|
||||
InsertOne({"x": 2}),
|
||||
InsertOne({"x": 3}),
|
||||
]
|
||||
|
||||
with self.assertRaises(BulkWriteError) as ctx:
|
||||
self.coll.bulk_write(requests, ordered=True)
|
||||
|
||||
# With ordered, should have inserted only 1 document
|
||||
self.assertEqual(1, ctx.exception.details["nInserted"])
|
||||
|
||||
def test_bulk_write_bypass_document_validation(self):
|
||||
"""Test bulk write with bypass_document_validation."""
|
||||
if not client_context.version.at_least(3, 2):
|
||||
self.skipTest("bypass_document_validation requires MongoDB 3.2+")
|
||||
|
||||
# Create collection with validator
|
||||
self.coll.drop()
|
||||
self.db.create_collection(self.coll.name, validator={"$jsonSchema": {"required": ["name"]}})
|
||||
|
||||
# Without bypass, should fail
|
||||
with self.assertRaises(BulkWriteError):
|
||||
self.coll.bulk_write([InsertOne({"x": 1})])
|
||||
|
||||
# With bypass, should succeed
|
||||
result = self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True)
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
|
||||
def test_bulk_write_result_properties(self):
|
||||
"""Test all BulkWriteResult properties."""
|
||||
self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [
|
||||
InsertOne({"x": 2}),
|
||||
UpdateOne({"x": 1}, {"$set": {"updated": True}}),
|
||||
ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True),
|
||||
DeleteOne({"x": 1}),
|
||||
]
|
||||
result = self.coll.bulk_write(requests)
|
||||
|
||||
# Check all properties
|
||||
self.assertTrue(result.acknowledged)
|
||||
self.assertEqual(1, result.inserted_count)
|
||||
self.assertGreaterEqual(result.matched_count, 0)
|
||||
self.assertGreaterEqual(result.modified_count, 0)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
self.assertIsInstance(result.upserted_count, int)
|
||||
self.assertIsInstance(result.upserted_ids, dict)
|
||||
|
||||
def test_bulk_write_with_upsert(self):
|
||||
"""Test bulk write upsert operations."""
|
||||
requests = [
|
||||
UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True),
|
||||
UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True),
|
||||
ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True),
|
||||
]
|
||||
result = self.coll.bulk_write(requests)
|
||||
|
||||
self.assertEqual(3, result.upserted_count)
|
||||
self.assertEqual(3, len(result.upserted_ids))
|
||||
|
||||
def test_update_one_with_hint(self):
|
||||
"""Test UpdateOne with hint parameter."""
|
||||
self.coll.create_index([("x", 1)])
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
def test_update_many_with_hint(self):
|
||||
"""Test UpdateMany with hint parameter."""
|
||||
self.coll.create_index([("x", 1)])
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
self.coll.insert_many([{"x": 1}, {"x": 1}])
|
||||
|
||||
requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(2, result.modified_count)
|
||||
|
||||
def test_delete_one_with_hint(self):
|
||||
"""Test DeleteOne with hint parameter."""
|
||||
self.coll.create_index([("x", 1)])
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [DeleteOne({"x": 1}, hint=[("x", 1)])]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
def test_delete_many_with_hint(self):
|
||||
"""Test DeleteMany with hint parameter."""
|
||||
self.coll.create_index([("x", 1)])
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
self.coll.insert_many([{"x": 1}, {"x": 1}])
|
||||
|
||||
requests = [DeleteMany({"x": 1}, hint=[("x", 1)])]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(2, result.deleted_count)
|
||||
|
||||
def test_update_one_with_array_filters(self):
|
||||
"""Test UpdateOne with array_filters parameter."""
|
||||
self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]})
|
||||
|
||||
requests = [
|
||||
UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}])
|
||||
]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
doc = self.coll.find_one()
|
||||
# Elements with y > 1 should have z = 1
|
||||
for elem in doc["x"]:
|
||||
if elem["y"] > 1:
|
||||
self.assertEqual(1, elem.get("z"))
|
||||
|
||||
def test_replace_one_with_hint(self):
|
||||
"""Test ReplaceOne with hint parameter."""
|
||||
self.coll.create_index([("x", 1)])
|
||||
self.addCleanup(self.coll.drop_index, [("x", 1)])
|
||||
|
||||
self.coll.insert_one({"x": 1})
|
||||
|
||||
requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])]
|
||||
result = self.coll.bulk_write(requests)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
def test_update_with_collation(self):
|
||||
"""Test update operations with collation."""
|
||||
self.coll.insert_many(
|
||||
[
|
||||
{"name": "cafe"},
|
||||
{"name": "Cafe"},
|
||||
]
|
||||
)
|
||||
|
||||
requests = [
|
||||
UpdateMany(
|
||||
{"name": "cafe"},
|
||||
{"$set": {"updated": True}},
|
||||
collation={"locale": "en", "strength": 2},
|
||||
)
|
||||
]
|
||||
result = self.coll.bulk_write(requests)
|
||||
# With case-insensitive collation, both docs should match
|
||||
self.assertEqual(2, result.modified_count)
|
||||
|
||||
|
||||
class BulkAuthorizationTestBase(BulkTestBase):
|
||||
@client_context.require_auth
|
||||
|
||||
@ -1132,5 +1132,122 @@ globals().update(
|
||||
)
|
||||
|
||||
|
||||
class TestChangeStreamCoverage(TestCollectionChangeStream):
|
||||
"""Additional tests to improve code coverage for ChangeStream."""
|
||||
|
||||
def test_change_stream_alive_property(self):
|
||||
"""Test alive property state transitions."""
|
||||
with self.change_stream() as cs:
|
||||
self.assertTrue(cs.alive)
|
||||
# After context exit, should be closed
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
def test_change_stream_idempotent_close(self):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
cs = self.change_stream()
|
||||
cs.close()
|
||||
# Second close should not raise
|
||||
cs.close()
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
def test_change_stream_resume_token_deepcopy(self):
|
||||
"""Test that resume_token returns a deep copy."""
|
||||
coll = self.watched_collection()
|
||||
with self.change_stream() as cs:
|
||||
coll.insert_one({"x": 1})
|
||||
next(cs) # Consume the change event
|
||||
token1 = cs.resume_token
|
||||
token2 = cs.resume_token
|
||||
# Should be equal but different objects
|
||||
self.assertEqual(token1, token2)
|
||||
self.assertIsNot(token1, token2)
|
||||
|
||||
def test_change_stream_with_comment(self):
|
||||
"""Test change stream with comment parameter."""
|
||||
client, listener = self.client_with_listener("aggregate")
|
||||
try:
|
||||
with self.change_stream_with_client(client, comment="test_comment"):
|
||||
pass
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
# Check that comment was in the aggregate command
|
||||
self.assertGreater(len(listener.started_events), 0)
|
||||
cmd = listener.started_events[0].command
|
||||
self.assertEqual("test_comment", cmd.get("comment"))
|
||||
|
||||
def test_change_stream_with_show_expanded_events(self):
|
||||
"""Test change stream with show_expanded_events parameter."""
|
||||
if not client_context.version.at_least(6, 0):
|
||||
self.skipTest("show_expanded_events requires MongoDB 6.0+")
|
||||
|
||||
with self.change_stream(show_expanded_events=True) as cs:
|
||||
# Just verify it doesn't error
|
||||
self.assertTrue(cs.alive)
|
||||
|
||||
@client_context.require_version_min(6, 0)
|
||||
def test_change_stream_with_full_document_before_change(self):
|
||||
"""Test change stream with full_document_before_change parameter."""
|
||||
coll = self.watched_collection()
|
||||
# Need to ensure collection exists with changeStreamPreAndPostImages enabled
|
||||
coll.drop()
|
||||
self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True})
|
||||
coll.insert_one({"x": 1})
|
||||
|
||||
with self.change_stream(full_document_before_change="whenAvailable") as cs:
|
||||
coll.update_one({"x": 1}, {"$set": {"x": 2}})
|
||||
change = next(cs)
|
||||
self.assertEqual("update", change["operationType"])
|
||||
# fullDocumentBeforeChange should be present
|
||||
self.assertIn("fullDocumentBeforeChange", change)
|
||||
|
||||
def test_change_stream_next_after_close(self):
|
||||
"""Test that next() on closed stream raises StopIteration."""
|
||||
cs = self.change_stream()
|
||||
cs.close()
|
||||
with self.assertRaises(StopIteration):
|
||||
next(cs)
|
||||
|
||||
def test_change_stream_try_next_after_close(self):
|
||||
"""Test that try_next() on closed stream raises StopIteration."""
|
||||
cs = self.change_stream()
|
||||
cs.close()
|
||||
with self.assertRaises(StopIteration):
|
||||
cs.try_next()
|
||||
|
||||
def test_change_stream_pipeline_construction(self):
|
||||
"""Test change stream pipeline is properly constructed."""
|
||||
pipeline = [{"$match": {"operationType": "insert"}}]
|
||||
client, listener = self.client_with_listener("aggregate")
|
||||
try:
|
||||
with self.change_stream_with_client(client, pipeline=pipeline):
|
||||
pass
|
||||
finally:
|
||||
client.close()
|
||||
|
||||
cmd = listener.started_events[0].command
|
||||
agg_pipeline = cmd["pipeline"]
|
||||
# First stage should be $changeStream
|
||||
self.assertIn("$changeStream", agg_pipeline[0])
|
||||
# Second stage should be our match
|
||||
self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1])
|
||||
|
||||
def test_change_stream_empty_pipeline(self):
|
||||
"""Test change stream with empty pipeline."""
|
||||
with self.change_stream(pipeline=[]) as cs:
|
||||
self.assertTrue(cs.alive)
|
||||
|
||||
def test_change_stream_context_manager_exception(self):
|
||||
"""Test change stream context manager closes on exception."""
|
||||
cs = None
|
||||
try:
|
||||
with self.change_stream() as cs:
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
# Stream should be closed
|
||||
self.assertFalse(cs.alive)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -2238,5 +2238,262 @@ class TestCollection(IntegrationTest):
|
||||
helper(*args, let={}) # type: ignore
|
||||
|
||||
|
||||
class TestCollectionCoverage(IntegrationTest):
|
||||
"""Additional tests to improve code coverage for Collection."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
|
||||
|
||||
def test_collection_full_name(self):
|
||||
"""Test full_name property."""
|
||||
expected = f"{self.db.name}.test"
|
||||
self.assertEqual(expected, self.db.test.full_name)
|
||||
|
||||
def test_collection_name(self):
|
||||
"""Test name property."""
|
||||
self.assertEqual("test", self.db.test.name)
|
||||
|
||||
def test_collection_database(self):
|
||||
"""Test database property."""
|
||||
self.assertEqual(self.db, self.db.test.database)
|
||||
|
||||
def test_collection_equality(self):
|
||||
"""Test collection equality."""
|
||||
coll1 = self.db.test
|
||||
coll2 = self.db.test
|
||||
coll3 = self.db.other
|
||||
self.assertEqual(coll1, coll2)
|
||||
self.assertNotEqual(coll1, coll3)
|
||||
|
||||
def test_collection_hash(self):
|
||||
"""Test collection hashability."""
|
||||
coll1 = self.db.test
|
||||
coll2 = self.db.test
|
||||
# Same collection should have same hash
|
||||
self.assertEqual(hash(coll1), hash(coll2))
|
||||
# Collections can be used in sets
|
||||
s = {coll1, coll2}
|
||||
self.assertEqual(1, len(s))
|
||||
|
||||
def test_collection_repr(self):
|
||||
"""Test collection repr."""
|
||||
coll = self.db.test
|
||||
repr_str = repr(coll)
|
||||
self.assertIn("test", repr_str)
|
||||
self.assertIn("Collection", repr_str)
|
||||
|
||||
def test_collection_getattr(self):
|
||||
"""Test sub-collection access via attribute."""
|
||||
subcoll = self.db.test.subcollection
|
||||
self.assertEqual("test.subcollection", subcoll.name)
|
||||
|
||||
def test_collection_getitem(self):
|
||||
"""Test sub-collection access via indexing."""
|
||||
subcoll = self.db.test["subcollection"]
|
||||
self.assertEqual("test.subcollection", subcoll.name)
|
||||
|
||||
def test_collection_with_options(self):
|
||||
"""Test with_options creates new collection with options."""
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
coll = self.db.test.with_options(
|
||||
read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1)
|
||||
)
|
||||
self.assertEqual("majority", coll.read_concern.level)
|
||||
self.assertEqual({"w": 1}, coll.write_concern.document)
|
||||
# Original should be unchanged
|
||||
self.assertNotEqual("majority", self.db.test.read_concern.level)
|
||||
|
||||
def test_collection_drop(self):
|
||||
"""Test collection drop."""
|
||||
self.db.test_drop.insert_one({"x": 1})
|
||||
self.db.test_drop.drop()
|
||||
names = self.db.list_collection_names()
|
||||
self.assertNotIn("test_drop", names)
|
||||
|
||||
def test_collection_drop_with_comment(self):
|
||||
"""Test collection drop with comment."""
|
||||
self.db.test_drop_comment.insert_one({"x": 1})
|
||||
self.db.test_drop_comment.drop(comment="test_comment")
|
||||
names = self.db.list_collection_names()
|
||||
self.assertNotIn("test_drop_comment", names)
|
||||
|
||||
def test_find_raw_batches(self):
|
||||
"""Test find_raw_batches returns raw BSON."""
|
||||
from bson import decode_all
|
||||
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5)
|
||||
batch_count = 0
|
||||
for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
def test_aggregate_raw_batches(self):
|
||||
"""Test aggregate_raw_batches returns raw BSON."""
|
||||
from bson import decode_all
|
||||
|
||||
cursor = self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5)
|
||||
batch_count = 0
|
||||
for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
def test_distinct_with_collation(self):
|
||||
"""Test distinct with collation."""
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many(
|
||||
[
|
||||
{"name": "abc"},
|
||||
{"name": "ABC"},
|
||||
{"name": "def"},
|
||||
]
|
||||
)
|
||||
# Case-insensitive distinct
|
||||
values = self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2})
|
||||
# abc and ABC should be considered the same
|
||||
self.assertEqual(2, len(values))
|
||||
|
||||
def test_count_documents_with_options(self):
|
||||
"""Test count_documents with skip, limit, hint."""
|
||||
self.db.test.create_index([("x", 1)])
|
||||
|
||||
count = self.db.test.count_documents({"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)])
|
||||
self.assertEqual(5, count)
|
||||
|
||||
def test_estimated_document_count(self):
|
||||
"""Test estimated_document_count."""
|
||||
count = self.db.test.estimated_document_count()
|
||||
self.assertEqual(10, count)
|
||||
|
||||
def test_estimated_document_count_with_options(self):
|
||||
"""Test estimated_document_count with maxTimeMS and comment."""
|
||||
count = self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment")
|
||||
self.assertEqual(10, count)
|
||||
|
||||
def test_find_one_and_delete_with_options(self):
|
||||
"""Test find_one_and_delete with projection, sort."""
|
||||
doc = self.db.test.find_one_and_delete(
|
||||
{"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)]
|
||||
)
|
||||
self.assertEqual(9, doc["x"])
|
||||
self.assertNotIn("y", doc)
|
||||
|
||||
def test_find_one_and_replace_with_options(self):
|
||||
"""Test find_one_and_replace with various options."""
|
||||
from pymongo import ReturnDocument
|
||||
|
||||
doc = self.db.test.find_one_and_replace(
|
||||
{"x": 0},
|
||||
{"x": 0, "replaced": True},
|
||||
projection={"x": 1, "replaced": 1},
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
self.assertEqual(0, doc["x"])
|
||||
self.assertTrue(doc.get("replaced"))
|
||||
|
||||
def test_find_one_and_update_with_options(self):
|
||||
"""Test find_one_and_update with various options."""
|
||||
from pymongo import ReturnDocument
|
||||
|
||||
doc = self.db.test.find_one_and_update(
|
||||
{"x": 0},
|
||||
{"$set": {"updated": True}},
|
||||
projection={"x": 1, "updated": 1},
|
||||
return_document=ReturnDocument.AFTER,
|
||||
)
|
||||
self.assertEqual(0, doc["x"])
|
||||
self.assertTrue(doc.get("updated"))
|
||||
|
||||
def test_update_one_with_array_filters(self):
|
||||
"""Test update_one with array_filters."""
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]})
|
||||
|
||||
result = self.db.test.update_one(
|
||||
{}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}]
|
||||
)
|
||||
self.assertEqual(1, result.modified_count)
|
||||
|
||||
def test_update_many_with_hint(self):
|
||||
"""Test update_many with hint."""
|
||||
self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = self.db.test.update_many(
|
||||
{"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)]
|
||||
)
|
||||
self.assertEqual(10, result.modified_count)
|
||||
|
||||
def test_delete_one_with_hint(self):
|
||||
"""Test delete_one with hint."""
|
||||
self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = self.db.test.delete_one({"x": 0}, hint=[("x", 1)])
|
||||
self.assertEqual(1, result.deleted_count)
|
||||
|
||||
def test_delete_many_with_hint(self):
|
||||
"""Test delete_many with hint."""
|
||||
self.db.test.create_index([("x", 1)])
|
||||
|
||||
result = self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)])
|
||||
self.assertEqual(5, result.deleted_count)
|
||||
|
||||
def test_aggregate_with_let(self):
|
||||
"""Test aggregate with let parameter."""
|
||||
if not client_context.version.at_least(5, 0):
|
||||
self.skipTest("let parameter requires MongoDB 5.0+")
|
||||
|
||||
pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}]
|
||||
cursor = self.db.test.aggregate(pipeline, let={"targetVal": 5})
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(5, docs[0]["x"])
|
||||
|
||||
def test_aggregate_with_batch_size(self):
|
||||
"""Test aggregate with batchSize."""
|
||||
cursor = self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2)
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(10, len(docs))
|
||||
|
||||
def test_list_indexes(self):
|
||||
"""Test list_indexes returns cursor."""
|
||||
self.db.test.create_index([("x", 1)])
|
||||
cursor = self.db.test.list_indexes()
|
||||
|
||||
# Should get at least the _id index
|
||||
indexes = cursor.to_list()
|
||||
self.assertGreaterEqual(len(indexes), 1)
|
||||
index_names = [idx["name"] for idx in indexes]
|
||||
self.assertIn("_id_", index_names)
|
||||
|
||||
def test_index_information(self):
|
||||
"""Test index_information returns dict."""
|
||||
self.db.test.create_index([("x", 1)], name="x_index")
|
||||
info = self.db.test.index_information()
|
||||
|
||||
self.assertIsInstance(info, dict)
|
||||
self.assertIn("_id_", info)
|
||||
self.assertIn("x_index", info)
|
||||
|
||||
def test_options_method(self):
|
||||
"""Test options() returns collection options."""
|
||||
# Create a capped collection
|
||||
self.db.drop_collection("test_capped")
|
||||
self.db.create_collection("test_capped", capped=True, size=10000)
|
||||
|
||||
opts = self.db.test_capped.options()
|
||||
self.assertTrue(opts.get("capped"))
|
||||
|
||||
self.db.drop_collection("test_capped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -1853,5 +1853,404 @@ class TestRawBatchCommandCursor(IntegrationTest):
|
||||
self.assertEqual(cmd.command["$db"], "pymongo_test")
|
||||
|
||||
|
||||
class TestCursorCoverage(IntegrationTest):
|
||||
"""Additional tests to improve code coverage for Cursor."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
|
||||
|
||||
def test_get_namespace(self):
|
||||
"""Test _get_namespace() method."""
|
||||
cursor = self.db.test.find()
|
||||
expected_ns = f"{self.db.name}.test"
|
||||
self.assertEqual(expected_ns, cursor._get_namespace())
|
||||
|
||||
def test_cursor_alive_property_states(self):
|
||||
"""Test cursor alive property in different states."""
|
||||
cursor = self.db.test.find()
|
||||
# Cursor is alive even before starting (has potential to return data)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
# Start the cursor
|
||||
next(cursor)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
# Exhaust the cursor
|
||||
cursor.to_list()
|
||||
self.assertFalse(cursor.alive)
|
||||
|
||||
def test_cursor_closed_property(self):
|
||||
"""Test cursor behavior after close."""
|
||||
cursor = self.db.test.find()
|
||||
next(cursor)
|
||||
self.assertTrue(cursor.alive)
|
||||
|
||||
cursor.close()
|
||||
# After close, cursor is killed (check internal _killed flag)
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
def test_retrieved_property(self):
|
||||
"""Test the retrieved property tracking."""
|
||||
cursor = self.db.test.find().batch_size(2)
|
||||
self.assertEqual(0, cursor.retrieved)
|
||||
|
||||
next(cursor)
|
||||
self.assertGreater(cursor.retrieved, 0)
|
||||
|
||||
def test_cursor_with_let_parameter(self):
|
||||
"""Test cursor with let parameter."""
|
||||
# let parameter allows variables to be used in the filter
|
||||
cursor = self.db.test.find(
|
||||
{"$expr": {"$eq": ["$x", "$$targetValue"]}}, let={"targetValue": 5}
|
||||
)
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(1, len(docs))
|
||||
self.assertEqual(5, docs[0]["x"])
|
||||
|
||||
def test_cursor_with_invalid_let_parameter(self):
|
||||
"""Test cursor raises error for invalid let parameter."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(let="invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_with_show_record_id(self):
|
||||
"""Test cursor with show_record_id option."""
|
||||
cursor = self.db.test.find(show_record_id=True)
|
||||
doc = next(cursor)
|
||||
self.assertIn("$recordId", doc)
|
||||
|
||||
def test_cursor_with_return_key(self):
|
||||
"""Test cursor with return_key option."""
|
||||
self.db.test.create_index([("x", ASCENDING)])
|
||||
cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)])
|
||||
doc = next(cursor)
|
||||
# return_key returns only index keys
|
||||
self.assertIn("x", doc)
|
||||
self.assertNotIn("y", doc)
|
||||
|
||||
def test_check_okay_to_chain_after_iteration(self):
|
||||
"""Test that cursor configuration methods raise after iteration."""
|
||||
cursor = self.db.test.find()
|
||||
next(cursor) # Start iteration
|
||||
|
||||
# All these should raise InvalidOperation
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.limit(5)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.skip(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.sort("x")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.hint([("x", ASCENDING)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.max([("x", 10)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.min([("x", 0)])
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.add_option(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.remove_option(2)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.batch_size(10)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.max_time_ms(1000)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.collation(Collation("en_US"))
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.allow_disk_use(True)
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.where("this.x > 5")
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor.comment("test")
|
||||
|
||||
def test_cursor_context_manager(self):
|
||||
"""Test cursor as async context manager."""
|
||||
with self.db.test.find() as cursor:
|
||||
doc = next(cursor)
|
||||
self.assertIsNotNone(doc)
|
||||
# Cursor should be killed after context (check _killed flag)
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
def test_cursor_context_manager_with_exception(self):
|
||||
"""Test cursor context manager closes on exception."""
|
||||
cursor = None
|
||||
try:
|
||||
with self.db.test.find() as cursor:
|
||||
next(cursor)
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
# Cursor should be killed after exception
|
||||
self.assertTrue(cursor._killed)
|
||||
|
||||
def test_cursor_collation(self):
|
||||
"""Test cursor with collation."""
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
|
||||
# Case-insensitive sort
|
||||
cursor = (
|
||||
self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING)
|
||||
)
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
def test_cursor_collation_type_error(self):
|
||||
"""Test cursor raises error for invalid collation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find().collation("invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_getitem_not_supported(self):
|
||||
"""Test that Cursor does not support indexing."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(IndexError) as ctx:
|
||||
cursor[5]
|
||||
self.assertIn("does not support indexing", str(ctx.exception))
|
||||
|
||||
def test_cursor_next_after_close(self):
|
||||
"""Test that next() raises StopIteration after close."""
|
||||
cursor = self.db.test.find()
|
||||
cursor.close()
|
||||
with self.assertRaises(StopIteration):
|
||||
next(cursor)
|
||||
|
||||
def test_cursor_rewind_resets_state(self):
|
||||
"""Test that rewind properly resets cursor state."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
|
||||
# Iterate fully
|
||||
docs1 = cursor.to_list()
|
||||
self.assertEqual(3, len(docs1))
|
||||
self.assertEqual(0, len(cursor._data))
|
||||
|
||||
# Rewind and iterate again
|
||||
cursor.rewind()
|
||||
docs2 = cursor.to_list()
|
||||
self.assertEqual(3, len(docs2))
|
||||
self.assertEqual(docs1, docs2)
|
||||
|
||||
def test_cursor_clone_with_session(self):
|
||||
"""Test that clone preserves explicit session."""
|
||||
with self.client.start_session() as session:
|
||||
cursor = self.db.test.find(session=session)
|
||||
cloned = cursor.clone()
|
||||
# Clone should reference the same session
|
||||
self.assertEqual(cursor.session, cloned.session)
|
||||
|
||||
def test_cursor_clone_without_session(self):
|
||||
"""Test that clone without session doesn't add one."""
|
||||
cursor = self.db.test.find()
|
||||
cloned = cursor.clone()
|
||||
# Clone should have no session if original had none
|
||||
self.assertIsNone(cloned.session)
|
||||
|
||||
def test_cursor_distinct_with_collation(self):
|
||||
"""Test distinct with collation."""
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
|
||||
# Case-insensitive distinct
|
||||
cursor = self.db.test.find().collation(Collation("en_US", strength=2))
|
||||
# distinct() on cursor with collation
|
||||
values = cursor.distinct("name")
|
||||
# Should have 2 distinct values (abc/ABC treated as same)
|
||||
self.assertEqual(2, len(values))
|
||||
|
||||
def test_cursor_explain_with_options(self):
|
||||
"""Test explain with cursor options set."""
|
||||
cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1)
|
||||
explanation = cursor.explain()
|
||||
self.assertIn("queryPlanner", explanation)
|
||||
|
||||
def test_cursor_max_time_ms_type_errors(self):
|
||||
"""Test max_time_ms raises TypeError for invalid input."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.max_time_ms("invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_max_await_time_ms_type_errors(self):
|
||||
"""Test max_await_time_ms raises TypeError for invalid input."""
|
||||
cursor = self.db.test.find()
|
||||
with self.assertRaises(TypeError):
|
||||
cursor.max_await_time_ms("invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_comment_type(self):
|
||||
"""Test cursor with comment of various types."""
|
||||
# String comment
|
||||
cursor1 = self.db.test.find().comment("test comment")
|
||||
docs1 = cursor1.to_list()
|
||||
self.assertGreater(len(docs1), 0)
|
||||
|
||||
# Dict comment
|
||||
cursor2 = self.db.test.find().comment({"key": "value"})
|
||||
docs2 = cursor2.to_list()
|
||||
self.assertGreater(len(docs2), 0)
|
||||
|
||||
def test_cursor_batch_size_validation(self):
|
||||
"""Test batch_size validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(batch_size="invalid") # type: ignore[arg-type]
|
||||
with self.assertRaises(ValueError):
|
||||
self.db.test.find(batch_size=-1)
|
||||
|
||||
def test_cursor_skip_validation(self):
|
||||
"""Test skip validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(skip="invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_limit_validation(self):
|
||||
"""Test limit validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(limit="invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_filter_validation(self):
|
||||
"""Test filter validation."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find(filter="invalid") # type: ignore[arg-type]
|
||||
|
||||
def test_cursor_type_validation(self):
|
||||
"""Test cursor_type validation."""
|
||||
with self.assertRaises(ValueError):
|
||||
self.db.test.find(cursor_type=999)
|
||||
|
||||
def test_cursor_query_spec_with_modifiers(self):
|
||||
"""Test _query_spec includes modifiers."""
|
||||
cursor = (
|
||||
self.db.test.find()
|
||||
.sort("x", ASCENDING)
|
||||
.hint([("x", ASCENDING)])
|
||||
.max_time_ms(1000)
|
||||
.comment("test")
|
||||
)
|
||||
spec = cursor._query_spec()
|
||||
self.assertIsInstance(spec, dict)
|
||||
|
||||
def test_cursor_copy(self):
|
||||
"""Test cursor __copy__ returns clone."""
|
||||
cursor = self.db.test.find().limit(5)
|
||||
copied = copy.copy(cursor)
|
||||
self.assertIsNot(cursor, copied)
|
||||
self.assertEqual(cursor._limit, copied._limit)
|
||||
|
||||
def test_cursor_deepcopy(self):
|
||||
"""Test cursor __deepcopy__ returns deep clone."""
|
||||
cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5)
|
||||
copied = copy.deepcopy(cursor)
|
||||
self.assertIsNot(cursor, copied)
|
||||
self.assertEqual(cursor._limit, copied._limit)
|
||||
self.assertEqual(cursor._spec, copied._spec)
|
||||
# Spec should be a different object
|
||||
self.assertIsNot(cursor._spec, copied._spec)
|
||||
|
||||
def test_cursor_iteration_protocol(self):
|
||||
"""Test cursor async iteration protocol."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
|
||||
# Test __iter__ returns self
|
||||
self.assertIs(cursor, cursor.__iter__())
|
||||
|
||||
# Test __next__ returns documents
|
||||
doc1 = cursor.__next__()
|
||||
self.assertIsNotNone(doc1)
|
||||
|
||||
def test_cursor_to_list_with_limit(self):
|
||||
"""Test to_list respects cursor limit."""
|
||||
cursor = self.db.test.find().limit(3)
|
||||
docs = cursor.to_list()
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
def test_cursor_to_list_with_length(self):
|
||||
"""Test to_list with length parameter."""
|
||||
cursor = self.db.test.find()
|
||||
docs = cursor.to_list(length=3)
|
||||
self.assertEqual(3, len(docs))
|
||||
|
||||
def test_min_max_require_hint(self):
|
||||
"""Test that min/max require hint for proper execution."""
|
||||
self.db.test.create_index([("x", ASCENDING)])
|
||||
|
||||
# min without hint should work when index exists
|
||||
cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)])
|
||||
docs = cursor.to_list()
|
||||
self.assertTrue(all(doc["x"] >= 5 for doc in docs))
|
||||
|
||||
# max without hint should work when index exists
|
||||
cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)])
|
||||
docs = cursor.to_list()
|
||||
self.assertTrue(all(doc["x"] < 5 for doc in docs))
|
||||
|
||||
def test_cursor_address_property(self):
|
||||
"""Test cursor address is set after first batch."""
|
||||
cursor = self.db.test.find()
|
||||
self.assertIsNone(cursor.address)
|
||||
next(cursor)
|
||||
# Address should be set after query
|
||||
self.assertIsNotNone(cursor.address)
|
||||
|
||||
def test_cursor_session_property(self):
|
||||
"""Test cursor session property."""
|
||||
# Cursor without explicit session
|
||||
cursor1 = self.db.test.find()
|
||||
self.assertIsNone(cursor1.session)
|
||||
|
||||
# Cursor with explicit session
|
||||
with self.client.start_session() as session:
|
||||
cursor2 = self.db.test.find(session=session)
|
||||
self.assertEqual(session, cursor2.session)
|
||||
|
||||
def test_cursor_allow_disk_use_type_error(self):
|
||||
"""Test allow_disk_use raises TypeError for invalid input."""
|
||||
with self.assertRaises(TypeError):
|
||||
self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestRawBatchCursorCoverage(IntegrationTest):
|
||||
"""Additional tests for RawBatchCursor coverage."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.db.test.drop()
|
||||
self.db.test.insert_many([{"x": i} for i in range(20)])
|
||||
|
||||
def test_raw_batch_cursor_iteration(self):
|
||||
"""Test raw batch cursor returns raw BSON."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5)
|
||||
batch_count = 0
|
||||
for batch in cursor:
|
||||
self.assertIsInstance(batch, bytes)
|
||||
# Decode the batch to verify it's valid BSON
|
||||
docs = decode_all(batch)
|
||||
self.assertGreater(len(docs), 0)
|
||||
batch_count += 1
|
||||
self.assertGreater(batch_count, 0)
|
||||
|
||||
def test_raw_batch_cursor_explain(self):
|
||||
"""Test raw batch cursor explain."""
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
explanation = cursor.explain()
|
||||
self.assertIn("queryPlanner", explanation)
|
||||
|
||||
def test_raw_batch_cursor_getitem_raises(self):
|
||||
"""Test raw batch cursor __getitem__ raises InvalidOperation."""
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
with self.assertRaises(InvalidOperation):
|
||||
cursor[0]
|
||||
|
||||
def test_raw_batch_cursor_with_sort(self):
|
||||
"""Test raw batch cursor with sort."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING)
|
||||
first_batch = next(cursor)
|
||||
docs = decode_all(first_batch)
|
||||
# First doc should have highest x value
|
||||
self.assertEqual(19, docs[0]["x"])
|
||||
|
||||
def test_raw_batch_cursor_with_limit(self):
|
||||
"""Test raw batch cursor with limit."""
|
||||
cursor = self.db.test.find_raw_batches(batch_size=5).limit(7)
|
||||
all_docs = []
|
||||
for batch in cursor:
|
||||
all_docs.extend(decode_all(batch))
|
||||
self.assertEqual(7, len(all_docs))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -50,9 +50,11 @@ from pymongo.common import _MAX_END_SESSIONS
|
||||
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
|
||||
from pymongo.operations import IndexModel, InsertOne, UpdateOne
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.cursor import Cursor
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -1345,5 +1347,284 @@ class TestClusterTime(IntegrationTest):
|
||||
self.assertEqual(started.command["$clusterTime"], cluster_time)
|
||||
|
||||
|
||||
class TestClientSessionCoverage(IntegrationTest):
|
||||
"""Additional tests to improve code coverage for ClientSession."""
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_has_ended_property(self):
|
||||
"""Test has_ended property state transitions."""
|
||||
session = self.client.start_session()
|
||||
self.assertFalse(session.has_ended)
|
||||
session.end_session()
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_session_id_property(self):
|
||||
"""Test session_id property returns correct value."""
|
||||
with self.client.start_session() as session:
|
||||
session_id = session.session_id
|
||||
self.assertIsInstance(session_id, dict)
|
||||
self.assertIn("id", session_id)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_cluster_time_operations(self):
|
||||
"""Test cluster time advance operations."""
|
||||
with self.client.start_session() as session:
|
||||
# Initially None
|
||||
self.assertIsNone(session.cluster_time)
|
||||
|
||||
# Perform operation to get cluster time
|
||||
self.db.test.find_one({}, session=session)
|
||||
|
||||
# Cluster time should be set after operation
|
||||
# (may still be None on some server versions)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_operation_time_operations(self):
|
||||
"""Test operation time advance operations."""
|
||||
with self.client.start_session() as session:
|
||||
# Initially None
|
||||
self.assertIsNone(session.operation_time)
|
||||
|
||||
# Perform operation to get operation time
|
||||
self.db.test.find_one({}, session=session)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_options_property(self):
|
||||
"""Test session options property."""
|
||||
with self.client.start_session(causal_consistency=True) as session:
|
||||
self.assertTrue(session.options.causal_consistency)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_client_property(self):
|
||||
"""Test session client property."""
|
||||
with self.client.start_session() as session:
|
||||
self.assertEqual(self.client, session.client)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_in_transaction_property(self):
|
||||
"""Test in_transaction property."""
|
||||
if client_context.is_rs or client_context.is_mongos:
|
||||
with self.client.start_session() as session:
|
||||
self.assertFalse(session.in_transaction)
|
||||
session.start_transaction()
|
||||
self.assertTrue(session.in_transaction)
|
||||
session.abort_transaction()
|
||||
self.assertFalse(session.in_transaction)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_context_manager(self):
|
||||
"""Test session async context manager."""
|
||||
with self.client.start_session() as session:
|
||||
self.assertFalse(session.has_ended)
|
||||
self.db.test.find_one({}, session=session)
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_context_manager_exception(self):
|
||||
"""Test session context manager closes on exception."""
|
||||
session = None
|
||||
try:
|
||||
with self.client.start_session() as session:
|
||||
raise ValueError("test exception")
|
||||
except ValueError:
|
||||
pass
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_operations_after_end(self):
|
||||
"""Test operations on ended session raise InvalidOperation."""
|
||||
session = self.client.start_session()
|
||||
session.end_session()
|
||||
|
||||
with self.assertRaises(InvalidOperation):
|
||||
self.db.test.find_one({}, session=session)
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_end_session_idempotent(self):
|
||||
"""Test that end_session can be called multiple times."""
|
||||
session = self.client.start_session()
|
||||
session.end_session()
|
||||
# Second call should not raise
|
||||
session.end_session()
|
||||
self.assertTrue(session.has_ended)
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_start_without_prior_transaction(self):
|
||||
"""Test start_transaction on fresh session."""
|
||||
with self.client.start_session() as session:
|
||||
session.start_transaction()
|
||||
self.assertTrue(session.in_transaction)
|
||||
session.abort_transaction()
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_start_twice_raises(self):
|
||||
"""Test starting transaction twice raises error."""
|
||||
with self.client.start_session() as session:
|
||||
session.start_transaction()
|
||||
with self.assertRaises(InvalidOperation):
|
||||
session.start_transaction()
|
||||
session.abort_transaction()
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_abort_without_transaction_raises(self):
|
||||
"""Test aborting without transaction raises error."""
|
||||
with self.client.start_session() as session:
|
||||
with self.assertRaises(InvalidOperation):
|
||||
session.abort_transaction()
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_transaction_commit_without_transaction_raises(self):
|
||||
"""Test committing without transaction raises error."""
|
||||
with self.client.start_session() as session:
|
||||
with self.assertRaises(InvalidOperation):
|
||||
session.commit_transaction()
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_advance_cluster_time_validation(self):
|
||||
"""Test advance_cluster_time with invalid input."""
|
||||
with self.client.start_session() as session:
|
||||
with self.assertRaises(TypeError):
|
||||
session.advance_cluster_time("invalid") # type: ignore
|
||||
with self.assertRaises(ValueError):
|
||||
session.advance_cluster_time({})
|
||||
|
||||
@client_context.require_sessions
|
||||
def test_session_advance_operation_time_validation(self):
|
||||
"""Test advance_operation_time with invalid input."""
|
||||
from bson import Timestamp
|
||||
|
||||
with self.client.start_session() as session:
|
||||
with self.assertRaises(TypeError):
|
||||
session.advance_operation_time("invalid") # type: ignore
|
||||
# Valid Timestamp should work
|
||||
session.advance_operation_time(Timestamp(1, 1))
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_with_transaction_callback_success(self):
|
||||
"""Test with_transaction with successful callback."""
|
||||
with self.client.start_session() as session:
|
||||
|
||||
def callback(session):
|
||||
self.db.test.insert_one({"x": 1}, session=session)
|
||||
return "success"
|
||||
|
||||
result = session.with_transaction(callback)
|
||||
self.assertEqual("success", result)
|
||||
|
||||
@client_context.require_transactions
|
||||
def test_with_transaction_callback_exception(self):
|
||||
"""Test with_transaction with callback exception."""
|
||||
with self.client.start_session() as session:
|
||||
|
||||
def callback(session):
|
||||
self.db.test.insert_one({"x": 1}, session=session)
|
||||
raise ValueError("callback error")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
session.with_transaction(callback)
|
||||
# Transaction should be aborted
|
||||
self.assertFalse(session.in_transaction)
|
||||
|
||||
|
||||
class TestSessionOptionsCoverage(UnitTest):
|
||||
"""Tests for SessionOptions coverage."""
|
||||
|
||||
def test_session_options_defaults(self):
|
||||
"""Test SessionOptions default values."""
|
||||
from pymongo.synchronous.client_session import SessionOptions
|
||||
|
||||
options = SessionOptions()
|
||||
self.assertTrue(options.causal_consistency)
|
||||
self.assertIsNone(options.default_transaction_options)
|
||||
self.assertFalse(options.snapshot)
|
||||
|
||||
def test_session_options_snapshot_disables_causal_consistency(self):
|
||||
"""Test snapshot=True forces causal_consistency=False."""
|
||||
from pymongo.synchronous.client_session import SessionOptions
|
||||
|
||||
options = SessionOptions(snapshot=True)
|
||||
self.assertFalse(options.causal_consistency)
|
||||
self.assertTrue(options.snapshot)
|
||||
|
||||
def test_session_options_snapshot_with_causal_raises(self):
|
||||
"""Test snapshot=True with causal_consistency=True raises error."""
|
||||
from pymongo.synchronous.client_session import SessionOptions
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
SessionOptions(snapshot=True, causal_consistency=True)
|
||||
|
||||
def test_session_options_invalid_transaction_options(self):
|
||||
"""Test SessionOptions with invalid transaction options type."""
|
||||
from pymongo.synchronous.client_session import SessionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
SessionOptions(default_transaction_options="invalid") # type: ignore
|
||||
|
||||
|
||||
class TestTransactionOptionsCoverage(UnitTest):
|
||||
"""Tests for TransactionOptions coverage."""
|
||||
|
||||
def test_transaction_options_defaults(self):
|
||||
"""Test TransactionOptions default values."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
options = TransactionOptions()
|
||||
self.assertIsNone(options.read_concern)
|
||||
self.assertIsNone(options.write_concern)
|
||||
self.assertIsNone(options.read_preference)
|
||||
self.assertIsNone(options.max_commit_time_ms)
|
||||
|
||||
def test_transaction_options_with_values(self):
|
||||
"""Test TransactionOptions with all values set."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
options = TransactionOptions(
|
||||
read_concern=ReadConcern("majority"),
|
||||
write_concern=WriteConcern(w="majority"),
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
max_commit_time_ms=5000,
|
||||
)
|
||||
self.assertEqual("majority", options.read_concern.level)
|
||||
self.assertEqual("majority", options.write_concern.document.get("w"))
|
||||
self.assertEqual(ReadPreference.PRIMARY, options.read_preference)
|
||||
self.assertEqual(5000, options.max_commit_time_ms)
|
||||
|
||||
def test_transaction_options_invalid_read_concern(self):
|
||||
"""Test TransactionOptions with invalid read_concern type."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(read_concern="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_write_concern(self):
|
||||
"""Test TransactionOptions with invalid write_concern type."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(write_concern="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_read_preference(self):
|
||||
"""Test TransactionOptions with invalid read_preference type."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(read_preference="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_invalid_max_commit_time(self):
|
||||
"""Test TransactionOptions with invalid max_commit_time_ms type."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
TransactionOptions(max_commit_time_ms="invalid") # type: ignore
|
||||
|
||||
def test_transaction_options_unacknowledged_write_concern(self):
|
||||
"""Test TransactionOptions rejects unacknowledged write concern."""
|
||||
from pymongo.synchronous.client_session import TransactionOptions
|
||||
|
||||
with self.assertRaises(ConfigurationError):
|
||||
TransactionOptions(write_concern=WriteConcern(w=0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user