PYTHON-5763 Improve async test coverage

This commit is contained in:
Jeffrey 'Alex' Clark 2026-03-18 17:24:55 -04:00
parent 13085ff679
commit 26f7a11253
10 changed files with 2550 additions and 0 deletions

View File

@ -781,6 +781,227 @@ class AsyncTestBulk(AsyncBulkTestBase):
self.assertEqual(6, result.inserted_count)
self.assertEqual(6, await self.coll.count_documents({}))
async def test_bulk_write_with_comment(self):
"""Test bulk write operations with comment parameter."""
requests = [
InsertOne({"x": 1}),
UpdateOne({"x": 1}, {"$set": {"y": 1}}),
DeleteOne({"x": 1}),
]
result = await self.coll.bulk_write(requests, comment="bulk_comment")
self.assertEqual(1, result.inserted_count)
self.assertEqual(1, result.modified_count)
self.assertEqual(1, result.deleted_count)
async def test_bulk_write_with_let(self):
"""Test bulk write operations with let parameter."""
if not async_client_context.version.at_least(5, 0):
self.skipTest("let parameter requires MongoDB 5.0+")
await self.coll.insert_one({"x": 1})
requests = [
UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}),
]
result = await self.coll.bulk_write(requests, let={"targetVal": 1})
self.assertEqual(1, result.modified_count)
async def test_bulk_write_all_operation_types(self):
"""Test bulk write with all operation types combined."""
await self.coll.insert_many([{"x": i} for i in range(5)])
requests = [
InsertOne({"x": 100}),
UpdateOne({"x": 0}, {"$set": {"updated": True}}),
UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}),
ReplaceOne({"x": 3}, {"x": 3, "replaced": True}),
DeleteOne({"x": 4}),
DeleteMany({"x": {"$gt": 50}}),
]
result = await self.coll.bulk_write(requests)
self.assertEqual(1, result.inserted_count)
self.assertGreaterEqual(result.modified_count, 1)
self.assertGreaterEqual(result.deleted_count, 1)
async def test_bulk_write_unordered(self):
"""Test unordered bulk write continues after error."""
await self.coll.create_index([("x", 1)], unique=True)
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
requests = [
InsertOne({"x": 1}),
InsertOne({"x": 1}), # Duplicate - will error
InsertOne({"x": 2}),
InsertOne({"x": 3}),
]
with self.assertRaises(BulkWriteError) as ctx:
await self.coll.bulk_write(requests, ordered=False)
# With unordered, should have inserted 3 documents
self.assertEqual(3, ctx.exception.details["nInserted"])
async def test_bulk_write_ordered(self):
"""Test ordered bulk write stops on first error."""
await self.coll.create_index([("x", 1)], unique=True)
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
requests = [
InsertOne({"x": 1}),
InsertOne({"x": 1}), # Duplicate - will error
InsertOne({"x": 2}),
InsertOne({"x": 3}),
]
with self.assertRaises(BulkWriteError) as ctx:
await self.coll.bulk_write(requests, ordered=True)
# With ordered, should have inserted only 1 document
self.assertEqual(1, ctx.exception.details["nInserted"])
async def test_bulk_write_bypass_document_validation(self):
"""Test bulk write with bypass_document_validation."""
if not async_client_context.version.at_least(3, 2):
self.skipTest("bypass_document_validation requires MongoDB 3.2+")
# Create collection with validator
await self.coll.drop()
await self.db.create_collection(
self.coll.name, validator={"$jsonSchema": {"required": ["name"]}}
)
# Without bypass, should fail
with self.assertRaises(BulkWriteError):
await self.coll.bulk_write([InsertOne({"x": 1})])
# With bypass, should succeed
result = await self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True)
self.assertEqual(1, result.inserted_count)
async def test_bulk_write_result_properties(self):
"""Test all BulkWriteResult properties."""
await self.coll.insert_one({"x": 1})
requests = [
InsertOne({"x": 2}),
UpdateOne({"x": 1}, {"$set": {"updated": True}}),
ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True),
DeleteOne({"x": 1}),
]
result = await self.coll.bulk_write(requests)
# Check all properties
self.assertTrue(result.acknowledged)
self.assertEqual(1, result.inserted_count)
self.assertGreaterEqual(result.matched_count, 0)
self.assertGreaterEqual(result.modified_count, 0)
self.assertEqual(1, result.deleted_count)
self.assertIsInstance(result.upserted_count, int)
self.assertIsInstance(result.upserted_ids, dict)
async def test_bulk_write_with_upsert(self):
"""Test bulk write upsert operations."""
requests = [
UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True),
UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True),
ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True),
]
result = await self.coll.bulk_write(requests)
self.assertEqual(3, result.upserted_count)
self.assertEqual(3, len(result.upserted_ids))
async def test_update_one_with_hint(self):
"""Test UpdateOne with hint parameter."""
await self.coll.create_index([("x", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
await self.coll.insert_one({"x": 1})
requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
result = await self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
async def test_update_many_with_hint(self):
"""Test UpdateMany with hint parameter."""
await self.coll.create_index([("x", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
await self.coll.insert_many([{"x": 1}, {"x": 1}])
requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
result = await self.coll.bulk_write(requests)
self.assertEqual(2, result.modified_count)
async def test_delete_one_with_hint(self):
"""Test DeleteOne with hint parameter."""
await self.coll.create_index([("x", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
await self.coll.insert_one({"x": 1})
requests = [DeleteOne({"x": 1}, hint=[("x", 1)])]
result = await self.coll.bulk_write(requests)
self.assertEqual(1, result.deleted_count)
async def test_delete_many_with_hint(self):
"""Test DeleteMany with hint parameter."""
await self.coll.create_index([("x", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
await self.coll.insert_many([{"x": 1}, {"x": 1}])
requests = [DeleteMany({"x": 1}, hint=[("x", 1)])]
result = await self.coll.bulk_write(requests)
self.assertEqual(2, result.deleted_count)
async def test_update_one_with_array_filters(self):
"""Test UpdateOne with array_filters parameter."""
await self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]})
requests = [
UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}])
]
result = await self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
doc = await self.coll.find_one()
# Elements with y > 1 should have z = 1
for elem in doc["x"]:
if elem["y"] > 1:
self.assertEqual(1, elem.get("z"))
async def test_replace_one_with_hint(self):
"""Test ReplaceOne with hint parameter."""
await self.coll.create_index([("x", 1)])
self.addAsyncCleanup(self.coll.drop_index, [("x", 1)])
await self.coll.insert_one({"x": 1})
requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])]
result = await self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
async def test_update_with_collation(self):
"""Test update operations with collation."""
await self.coll.insert_many(
[
{"name": "cafe"},
{"name": "Cafe"},
]
)
requests = [
UpdateMany(
{"name": "cafe"},
{"$set": {"updated": True}},
collation={"locale": "en", "strength": 2},
)
]
result = await self.coll.bulk_write(requests)
# With case-insensitive collation, both docs should match
self.assertEqual(2, result.modified_count)
class AsyncBulkAuthorizationTestBase(AsyncBulkTestBase):
@async_client_context.require_auth

View File

@ -1152,5 +1152,122 @@ globals().update(
)
class AsyncTestChangeStreamCoverage(TestAsyncCollectionAsyncChangeStream):
"""Additional tests to improve code coverage for AsyncChangeStream."""
async def test_change_stream_alive_property(self):
"""Test alive property state transitions."""
async with await self.change_stream() as cs:
self.assertTrue(cs.alive)
# After context exit, should be closed
self.assertFalse(cs.alive)
async def test_change_stream_idempotent_close(self):
"""Test that close() can be called multiple times safely."""
cs = await self.change_stream()
await cs.close()
# Second close should not raise
await cs.close()
self.assertFalse(cs.alive)
async def test_change_stream_resume_token_deepcopy(self):
"""Test that resume_token returns a deep copy."""
coll = self.watched_collection()
async with await self.change_stream() as cs:
await coll.insert_one({"x": 1})
await anext(cs) # Consume the change event
token1 = cs.resume_token
token2 = cs.resume_token
# Should be equal but different objects
self.assertEqual(token1, token2)
self.assertIsNot(token1, token2)
async def test_change_stream_with_comment(self):
"""Test change stream with comment parameter."""
client, listener = await self.client_with_listener("aggregate")
try:
async with await self.change_stream_with_client(client, comment="test_comment"):
pass
finally:
await client.close()
# Check that comment was in the aggregate command
self.assertGreater(len(listener.started_events), 0)
cmd = listener.started_events[0].command
self.assertEqual("test_comment", cmd.get("comment"))
async def test_change_stream_with_show_expanded_events(self):
"""Test change stream with show_expanded_events parameter."""
if not async_client_context.version.at_least(6, 0):
self.skipTest("show_expanded_events requires MongoDB 6.0+")
async with await self.change_stream(show_expanded_events=True) as cs:
# Just verify it doesn't error
self.assertTrue(cs.alive)
@async_client_context.require_version_min(6, 0)
async def test_change_stream_with_full_document_before_change(self):
"""Test change stream with full_document_before_change parameter."""
coll = self.watched_collection()
# Need to ensure collection exists with changeStreamPreAndPostImages enabled
await coll.drop()
await self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True})
await coll.insert_one({"x": 1})
async with await self.change_stream(full_document_before_change="whenAvailable") as cs:
await coll.update_one({"x": 1}, {"$set": {"x": 2}})
change = await anext(cs)
self.assertEqual("update", change["operationType"])
# fullDocumentBeforeChange should be present
self.assertIn("fullDocumentBeforeChange", change)
async def test_change_stream_next_after_close(self):
"""Test that next() on closed stream raises StopAsyncIteration."""
cs = await self.change_stream()
await cs.close()
with self.assertRaises(StopAsyncIteration):
await anext(cs)
async def test_change_stream_try_next_after_close(self):
"""Test that try_next() on closed stream raises StopAsyncIteration."""
cs = await self.change_stream()
await cs.close()
with self.assertRaises(StopAsyncIteration):
await cs.try_next()
async def test_change_stream_pipeline_construction(self):
"""Test change stream pipeline is properly constructed."""
pipeline = [{"$match": {"operationType": "insert"}}]
client, listener = await self.client_with_listener("aggregate")
try:
async with await self.change_stream_with_client(client, pipeline=pipeline):
pass
finally:
await client.close()
cmd = listener.started_events[0].command
agg_pipeline = cmd["pipeline"]
# First stage should be $changeStream
self.assertIn("$changeStream", agg_pipeline[0])
# Second stage should be our match
self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1])
async def test_change_stream_empty_pipeline(self):
"""Test change stream with empty pipeline."""
async with await self.change_stream(pipeline=[]) as cs:
self.assertTrue(cs.alive)
async def test_change_stream_context_manager_exception(self):
"""Test change stream context manager closes on exception."""
cs = None
try:
async with await self.change_stream() as cs:
raise ValueError("test exception")
except ValueError:
pass
# Stream should be closed
self.assertFalse(cs.alive)
if __name__ == "__main__":
unittest.main()

View File

@ -2260,5 +2260,264 @@ class AsyncTestCollection(AsyncIntegrationTest):
await helper(*args, let={}) # type: ignore
class AsyncTestCollectionCoverage(AsyncIntegrationTest):
"""Additional tests to improve code coverage for AsyncCollection."""
async def asyncSetUp(self):
await super().asyncSetUp()
await self.db.test.drop()
await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
async def test_collection_full_name(self):
"""Test full_name property."""
expected = f"{self.db.name}.test"
self.assertEqual(expected, self.db.test.full_name)
async def test_collection_name(self):
"""Test name property."""
self.assertEqual("test", self.db.test.name)
async def test_collection_database(self):
"""Test database property."""
self.assertEqual(self.db, self.db.test.database)
async def test_collection_equality(self):
"""Test collection equality."""
coll1 = self.db.test
coll2 = self.db.test
coll3 = self.db.other
self.assertEqual(coll1, coll2)
self.assertNotEqual(coll1, coll3)
async def test_collection_hash(self):
"""Test collection hashability."""
coll1 = self.db.test
coll2 = self.db.test
# Same collection should have same hash
self.assertEqual(hash(coll1), hash(coll2))
# Collections can be used in sets
s = {coll1, coll2}
self.assertEqual(1, len(s))
async def test_collection_repr(self):
"""Test collection repr."""
coll = self.db.test
repr_str = repr(coll)
self.assertIn("test", repr_str)
self.assertIn("AsyncCollection", repr_str)
async def test_collection_getattr(self):
"""Test sub-collection access via attribute."""
subcoll = self.db.test.subcollection
self.assertEqual("test.subcollection", subcoll.name)
async def test_collection_getitem(self):
"""Test sub-collection access via indexing."""
subcoll = self.db.test["subcollection"]
self.assertEqual("test.subcollection", subcoll.name)
async def test_collection_with_options(self):
"""Test with_options creates new collection with options."""
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
coll = self.db.test.with_options(
read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1)
)
self.assertEqual("majority", coll.read_concern.level)
self.assertEqual({"w": 1}, coll.write_concern.document)
# Original should be unchanged
self.assertNotEqual("majority", self.db.test.read_concern.level)
async def test_collection_drop(self):
"""Test collection drop."""
await self.db.test_drop.insert_one({"x": 1})
await self.db.test_drop.drop()
names = await self.db.list_collection_names()
self.assertNotIn("test_drop", names)
async def test_collection_drop_with_comment(self):
"""Test collection drop with comment."""
await self.db.test_drop_comment.insert_one({"x": 1})
await self.db.test_drop_comment.drop(comment="test_comment")
names = await self.db.list_collection_names()
self.assertNotIn("test_drop_comment", names)
async def test_find_raw_batches(self):
"""Test find_raw_batches returns raw BSON."""
from bson import decode_all
cursor = self.db.test.find_raw_batches(batch_size=5)
batch_count = 0
async for batch in cursor:
self.assertIsInstance(batch, bytes)
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
async def test_aggregate_raw_batches(self):
"""Test aggregate_raw_batches returns raw BSON."""
from bson import decode_all
cursor = await self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5)
batch_count = 0
async for batch in cursor:
self.assertIsInstance(batch, bytes)
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
async def test_distinct_with_collation(self):
"""Test distinct with collation."""
await self.db.test.drop()
await self.db.test.insert_many(
[
{"name": "abc"},
{"name": "ABC"},
{"name": "def"},
]
)
# Case-insensitive distinct
values = await self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2})
# abc and ABC should be considered the same
self.assertEqual(2, len(values))
async def test_count_documents_with_options(self):
"""Test count_documents with skip, limit, hint."""
await self.db.test.create_index([("x", 1)])
count = await self.db.test.count_documents(
{"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)]
)
self.assertEqual(5, count)
async def test_estimated_document_count(self):
"""Test estimated_document_count."""
count = await self.db.test.estimated_document_count()
self.assertEqual(10, count)
async def test_estimated_document_count_with_options(self):
"""Test estimated_document_count with maxTimeMS and comment."""
count = await self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment")
self.assertEqual(10, count)
async def test_find_one_and_delete_with_options(self):
"""Test find_one_and_delete with projection, sort."""
doc = await self.db.test.find_one_and_delete(
{"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)]
)
self.assertEqual(9, doc["x"])
self.assertNotIn("y", doc)
async def test_find_one_and_replace_with_options(self):
"""Test find_one_and_replace with various options."""
from pymongo import ReturnDocument
doc = await self.db.test.find_one_and_replace(
{"x": 0},
{"x": 0, "replaced": True},
projection={"x": 1, "replaced": 1},
return_document=ReturnDocument.AFTER,
)
self.assertEqual(0, doc["x"])
self.assertTrue(doc.get("replaced"))
async def test_find_one_and_update_with_options(self):
"""Test find_one_and_update with various options."""
from pymongo import ReturnDocument
doc = await self.db.test.find_one_and_update(
{"x": 0},
{"$set": {"updated": True}},
projection={"x": 1, "updated": 1},
return_document=ReturnDocument.AFTER,
)
self.assertEqual(0, doc["x"])
self.assertTrue(doc.get("updated"))
async def test_update_one_with_array_filters(self):
"""Test update_one with array_filters."""
await self.db.test.drop()
await self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]})
result = await self.db.test.update_one(
{}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}]
)
self.assertEqual(1, result.modified_count)
async def test_update_many_with_hint(self):
"""Test update_many with hint."""
await self.db.test.create_index([("x", 1)])
result = await self.db.test.update_many(
{"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)]
)
self.assertEqual(10, result.modified_count)
async def test_delete_one_with_hint(self):
"""Test delete_one with hint."""
await self.db.test.create_index([("x", 1)])
result = await self.db.test.delete_one({"x": 0}, hint=[("x", 1)])
self.assertEqual(1, result.deleted_count)
async def test_delete_many_with_hint(self):
"""Test delete_many with hint."""
await self.db.test.create_index([("x", 1)])
result = await self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)])
self.assertEqual(5, result.deleted_count)
async def test_aggregate_with_let(self):
"""Test aggregate with let parameter."""
if not async_client_context.version.at_least(5, 0):
self.skipTest("let parameter requires MongoDB 5.0+")
pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}]
cursor = await self.db.test.aggregate(pipeline, let={"targetVal": 5})
docs = await cursor.to_list()
self.assertEqual(1, len(docs))
self.assertEqual(5, docs[0]["x"])
async def test_aggregate_with_batch_size(self):
"""Test aggregate with batchSize."""
cursor = await self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2)
docs = await cursor.to_list()
self.assertEqual(10, len(docs))
async def test_list_indexes(self):
"""Test list_indexes returns cursor."""
await self.db.test.create_index([("x", 1)])
cursor = await self.db.test.list_indexes()
# Should get at least the _id index
indexes = await cursor.to_list()
self.assertGreaterEqual(len(indexes), 1)
index_names = [idx["name"] for idx in indexes]
self.assertIn("_id_", index_names)
async def test_index_information(self):
"""Test index_information returns dict."""
await self.db.test.create_index([("x", 1)], name="x_index")
info = await self.db.test.index_information()
self.assertIsInstance(info, dict)
self.assertIn("_id_", info)
self.assertIn("x_index", info)
async def test_options_method(self):
"""Test options() returns collection options."""
# Create a capped collection
await self.db.drop_collection("test_capped")
await self.db.create_collection("test_capped", capped=True, size=10000)
opts = await self.db.test_capped.options()
self.assertTrue(opts.get("capped"))
await self.db.drop_collection("test_capped")
if __name__ == "__main__":
unittest.main()

View File

@ -1864,5 +1864,404 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
self.assertEqual(cmd.command["$db"], "pymongo_test")
class AsyncTestCursorCoverage(AsyncIntegrationTest):
"""Additional tests to improve code coverage for AsyncCursor."""
async def asyncSetUp(self):
await super().asyncSetUp()
await self.db.test.drop()
await self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
async def test_get_namespace(self):
"""Test _get_namespace() method."""
cursor = self.db.test.find()
expected_ns = f"{self.db.name}.test"
self.assertEqual(expected_ns, cursor._get_namespace())
async def test_cursor_alive_property_states(self):
"""Test cursor alive property in different states."""
cursor = self.db.test.find()
# Cursor is alive even before starting (has potential to return data)
self.assertTrue(cursor.alive)
# Start the cursor
await anext(cursor)
self.assertTrue(cursor.alive)
# Exhaust the cursor
await cursor.to_list()
self.assertFalse(cursor.alive)
async def test_cursor_closed_property(self):
"""Test cursor behavior after close."""
cursor = self.db.test.find()
await anext(cursor)
self.assertTrue(cursor.alive)
await cursor.close()
# After close, cursor is killed (check internal _killed flag)
self.assertTrue(cursor._killed)
async def test_retrieved_property(self):
"""Test the retrieved property tracking."""
cursor = self.db.test.find().batch_size(2)
self.assertEqual(0, cursor.retrieved)
await anext(cursor)
self.assertGreater(cursor.retrieved, 0)
async def test_cursor_with_let_parameter(self):
"""Test cursor with let parameter."""
# let parameter allows variables to be used in the filter
cursor = self.db.test.find(
{"$expr": {"$eq": ["$x", "$$targetValue"]}}, let={"targetValue": 5}
)
docs = await cursor.to_list()
self.assertEqual(1, len(docs))
self.assertEqual(5, docs[0]["x"])
async def test_cursor_with_invalid_let_parameter(self):
"""Test cursor raises error for invalid let parameter."""
with self.assertRaises(TypeError):
self.db.test.find(let="invalid") # type: ignore[arg-type]
async def test_cursor_with_show_record_id(self):
"""Test cursor with show_record_id option."""
cursor = self.db.test.find(show_record_id=True)
doc = await anext(cursor)
self.assertIn("$recordId", doc)
async def test_cursor_with_return_key(self):
"""Test cursor with return_key option."""
await self.db.test.create_index([("x", ASCENDING)])
cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)])
doc = await anext(cursor)
# return_key returns only index keys
self.assertIn("x", doc)
self.assertNotIn("y", doc)
async def test_check_okay_to_chain_after_iteration(self):
"""Test that cursor configuration methods raise after iteration."""
cursor = self.db.test.find()
await anext(cursor) # Start iteration
# All these should raise InvalidOperation
with self.assertRaises(InvalidOperation):
cursor.limit(5)
with self.assertRaises(InvalidOperation):
cursor.skip(2)
with self.assertRaises(InvalidOperation):
cursor.sort("x")
with self.assertRaises(InvalidOperation):
cursor.hint([("x", ASCENDING)])
with self.assertRaises(InvalidOperation):
cursor.max([("x", 10)])
with self.assertRaises(InvalidOperation):
cursor.min([("x", 0)])
with self.assertRaises(InvalidOperation):
await cursor.add_option(2)
with self.assertRaises(InvalidOperation):
cursor.remove_option(2)
with self.assertRaises(InvalidOperation):
cursor.batch_size(10)
with self.assertRaises(InvalidOperation):
cursor.max_time_ms(1000)
with self.assertRaises(InvalidOperation):
cursor.collation(Collation("en_US"))
with self.assertRaises(InvalidOperation):
cursor.allow_disk_use(True)
with self.assertRaises(InvalidOperation):
cursor.where("this.x > 5")
with self.assertRaises(InvalidOperation):
cursor.comment("test")
async def test_cursor_context_manager(self):
"""Test cursor as async context manager."""
async with self.db.test.find() as cursor:
doc = await anext(cursor)
self.assertIsNotNone(doc)
# Cursor should be killed after context (check _killed flag)
self.assertTrue(cursor._killed)
async def test_cursor_context_manager_with_exception(self):
"""Test cursor context manager closes on exception."""
cursor = None
try:
async with self.db.test.find() as cursor:
await anext(cursor)
raise ValueError("test exception")
except ValueError:
pass
# Cursor should be killed after exception
self.assertTrue(cursor._killed)
async def test_cursor_collation(self):
"""Test cursor with collation."""
await self.db.test.drop()
await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
# Case-insensitive sort
cursor = (
self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING)
)
docs = await cursor.to_list()
self.assertEqual(3, len(docs))
async def test_cursor_collation_type_error(self):
"""Test cursor raises error for invalid collation."""
with self.assertRaises(TypeError):
self.db.test.find().collation("invalid") # type: ignore[arg-type]
async def test_cursor_getitem_not_supported(self):
"""Test that AsyncCursor does not support indexing."""
cursor = self.db.test.find()
with self.assertRaises(IndexError) as ctx:
cursor[5]
self.assertIn("does not support indexing", str(ctx.exception))
async def test_cursor_next_after_close(self):
"""Test that next() raises StopAsyncIteration after close."""
cursor = self.db.test.find()
await cursor.close()
with self.assertRaises(StopAsyncIteration):
await anext(cursor)
async def test_cursor_rewind_resets_state(self):
"""Test that rewind properly resets cursor state."""
cursor = self.db.test.find().limit(3)
# Iterate fully
docs1 = await cursor.to_list()
self.assertEqual(3, len(docs1))
self.assertEqual(0, len(cursor._data))
# Rewind and iterate again
await cursor.rewind()
docs2 = await cursor.to_list()
self.assertEqual(3, len(docs2))
self.assertEqual(docs1, docs2)
async def test_cursor_clone_with_session(self):
"""Test that clone preserves explicit session."""
async with self.client.start_session() as session:
cursor = self.db.test.find(session=session)
cloned = cursor.clone()
# Clone should reference the same session
self.assertEqual(cursor.session, cloned.session)
async def test_cursor_clone_without_session(self):
"""Test that clone without session doesn't add one."""
cursor = self.db.test.find()
cloned = cursor.clone()
# Clone should have no session if original had none
self.assertIsNone(cloned.session)
async def test_cursor_distinct_with_collation(self):
"""Test distinct with collation."""
await self.db.test.drop()
await self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
# Case-insensitive distinct
cursor = self.db.test.find().collation(Collation("en_US", strength=2))
# distinct() on cursor with collation
values = await cursor.distinct("name")
# Should have 2 distinct values (abc/ABC treated as same)
self.assertEqual(2, len(values))
async def test_cursor_explain_with_options(self):
"""Test explain with cursor options set."""
cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1)
explanation = await cursor.explain()
self.assertIn("queryPlanner", explanation)
async def test_cursor_max_time_ms_type_errors(self):
"""Test max_time_ms raises TypeError for invalid input."""
cursor = self.db.test.find()
with self.assertRaises(TypeError):
cursor.max_time_ms("invalid") # type: ignore[arg-type]
async def test_cursor_max_await_time_ms_type_errors(self):
"""Test max_await_time_ms raises TypeError for invalid input."""
cursor = self.db.test.find()
with self.assertRaises(TypeError):
cursor.max_await_time_ms("invalid") # type: ignore[arg-type]
async def test_cursor_comment_type(self):
"""Test cursor with comment of various types."""
# String comment
cursor1 = self.db.test.find().comment("test comment")
docs1 = await cursor1.to_list()
self.assertGreater(len(docs1), 0)
# Dict comment
cursor2 = self.db.test.find().comment({"key": "value"})
docs2 = await cursor2.to_list()
self.assertGreater(len(docs2), 0)
async def test_cursor_batch_size_validation(self):
"""Test batch_size validation."""
with self.assertRaises(TypeError):
self.db.test.find(batch_size="invalid") # type: ignore[arg-type]
with self.assertRaises(ValueError):
self.db.test.find(batch_size=-1)
async def test_cursor_skip_validation(self):
"""Test skip validation."""
with self.assertRaises(TypeError):
self.db.test.find(skip="invalid") # type: ignore[arg-type]
async def test_cursor_limit_validation(self):
"""Test limit validation."""
with self.assertRaises(TypeError):
self.db.test.find(limit="invalid") # type: ignore[arg-type]
async def test_cursor_filter_validation(self):
"""Test filter validation."""
with self.assertRaises(TypeError):
self.db.test.find(filter="invalid") # type: ignore[arg-type]
async def test_cursor_type_validation(self):
"""Test cursor_type validation."""
with self.assertRaises(ValueError):
self.db.test.find(cursor_type=999)
async def test_cursor_query_spec_with_modifiers(self):
"""Test _query_spec includes modifiers."""
cursor = (
self.db.test.find()
.sort("x", ASCENDING)
.hint([("x", ASCENDING)])
.max_time_ms(1000)
.comment("test")
)
spec = cursor._query_spec()
self.assertIsInstance(spec, dict)
async def test_cursor_copy(self):
"""Test cursor __copy__ returns clone."""
cursor = self.db.test.find().limit(5)
copied = copy.copy(cursor)
self.assertIsNot(cursor, copied)
self.assertEqual(cursor._limit, copied._limit)
async def test_cursor_deepcopy(self):
"""Test cursor __deepcopy__ returns deep clone."""
cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5)
copied = copy.deepcopy(cursor)
self.assertIsNot(cursor, copied)
self.assertEqual(cursor._limit, copied._limit)
self.assertEqual(cursor._spec, copied._spec)
# Spec should be a different object
self.assertIsNot(cursor._spec, copied._spec)
async def test_cursor_iteration_protocol(self):
"""Test cursor async iteration protocol."""
cursor = self.db.test.find().limit(3)
# Test __aiter__ returns self
self.assertIs(cursor, cursor.__aiter__())
# Test __anext__ returns documents
doc1 = await cursor.__anext__()
self.assertIsNotNone(doc1)
async def test_cursor_to_list_with_limit(self):
"""Test to_list respects cursor limit."""
cursor = self.db.test.find().limit(3)
docs = await cursor.to_list()
self.assertEqual(3, len(docs))
async def test_cursor_to_list_with_length(self):
"""Test to_list with length parameter."""
cursor = self.db.test.find()
docs = await cursor.to_list(length=3)
self.assertEqual(3, len(docs))
async def test_min_max_require_hint(self):
"""Test that min/max require hint for proper execution."""
await self.db.test.create_index([("x", ASCENDING)])
# min without hint should work when index exists
cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)])
docs = await cursor.to_list()
self.assertTrue(all(doc["x"] >= 5 for doc in docs))
# max without hint should work when index exists
cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)])
docs = await cursor.to_list()
self.assertTrue(all(doc["x"] < 5 for doc in docs))
async def test_cursor_address_property(self):
"""Test cursor address is set after first batch."""
cursor = self.db.test.find()
self.assertIsNone(cursor.address)
await anext(cursor)
# Address should be set after query
self.assertIsNotNone(cursor.address)
async def test_cursor_session_property(self):
"""Test cursor session property."""
# Cursor without explicit session
cursor1 = self.db.test.find()
self.assertIsNone(cursor1.session)
# Cursor with explicit session
async with self.client.start_session() as session:
cursor2 = self.db.test.find(session=session)
self.assertEqual(session, cursor2.session)
async def test_cursor_allow_disk_use_type_error(self):
"""Test allow_disk_use raises TypeError for invalid input."""
with self.assertRaises(TypeError):
self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type]
class AsyncTestRawBatchCursorCoverage(AsyncIntegrationTest):
"""Additional tests for AsyncRawBatchCursor coverage."""
async def asyncSetUp(self):
await super().asyncSetUp()
await self.db.test.drop()
await self.db.test.insert_many([{"x": i} for i in range(20)])
async def test_raw_batch_cursor_iteration(self):
"""Test raw batch cursor returns raw BSON."""
cursor = self.db.test.find_raw_batches(batch_size=5)
batch_count = 0
async for batch in cursor:
self.assertIsInstance(batch, bytes)
# Decode the batch to verify it's valid BSON
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
async def test_raw_batch_cursor_explain(self):
"""Test raw batch cursor explain."""
cursor = self.db.test.find_raw_batches()
explanation = await cursor.explain()
self.assertIn("queryPlanner", explanation)
async def test_raw_batch_cursor_getitem_raises(self):
"""Test raw batch cursor __getitem__ raises InvalidOperation."""
cursor = self.db.test.find_raw_batches()
with self.assertRaises(InvalidOperation):
cursor[0]
async def test_raw_batch_cursor_with_sort(self):
"""Test raw batch cursor with sort."""
cursor = self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING)
first_batch = await anext(cursor)
docs = decode_all(first_batch)
# First doc should have highest x value
self.assertEqual(19, docs[0]["x"])
async def test_raw_batch_cursor_with_limit(self):
"""Test raw batch cursor with limit."""
cursor = self.db.test.find_raw_batches(batch_size=5).limit(7)
all_docs = []
async for batch in cursor:
all_docs.extend(decode_all(batch))
self.assertEqual(7, len(all_docs))
if __name__ == "__main__":
unittest.main()

View File

@ -53,6 +53,8 @@ from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
@ -1361,5 +1363,284 @@ class TestClusterTime(AsyncIntegrationTest):
self.assertEqual(started.command["$clusterTime"], cluster_time)
class AsyncTestClientSessionCoverage(AsyncIntegrationTest):
"""Additional tests to improve code coverage for AsyncClientSession."""
@async_client_context.require_sessions
async def test_session_has_ended_property(self):
"""Test has_ended property state transitions."""
session = self.client.start_session()
self.assertFalse(session.has_ended)
await session.end_session()
self.assertTrue(session.has_ended)
@async_client_context.require_sessions
async def test_session_session_id_property(self):
"""Test session_id property returns correct value."""
async with self.client.start_session() as session:
session_id = session.session_id
self.assertIsInstance(session_id, dict)
self.assertIn("id", session_id)
@async_client_context.require_sessions
async def test_session_cluster_time_operations(self):
"""Test cluster time advance operations."""
async with self.client.start_session() as session:
# Initially None
self.assertIsNone(session.cluster_time)
# Perform operation to get cluster time
await self.db.test.find_one({}, session=session)
# Cluster time should be set after operation
# (may still be None on some server versions)
@async_client_context.require_sessions
async def test_session_operation_time_operations(self):
"""Test operation time advance operations."""
async with self.client.start_session() as session:
# Initially None
self.assertIsNone(session.operation_time)
# Perform operation to get operation time
await self.db.test.find_one({}, session=session)
@async_client_context.require_sessions
async def test_session_options_property(self):
"""Test session options property."""
async with self.client.start_session(causal_consistency=True) as session:
self.assertTrue(session.options.causal_consistency)
@async_client_context.require_sessions
async def test_session_client_property(self):
"""Test session client property."""
async with self.client.start_session() as session:
self.assertEqual(self.client, session.client)
@async_client_context.require_sessions
async def test_session_in_transaction_property(self):
"""Test in_transaction property."""
if async_client_context.is_rs or async_client_context.is_mongos:
async with self.client.start_session() as session:
self.assertFalse(session.in_transaction)
session.start_transaction()
self.assertTrue(session.in_transaction)
await session.abort_transaction()
self.assertFalse(session.in_transaction)
@async_client_context.require_sessions
async def test_session_context_manager(self):
"""Test session async context manager."""
async with self.client.start_session() as session:
self.assertFalse(session.has_ended)
await self.db.test.find_one({}, session=session)
self.assertTrue(session.has_ended)
@async_client_context.require_sessions
async def test_session_context_manager_exception(self):
"""Test session context manager closes on exception."""
session = None
try:
async with self.client.start_session() as session:
raise ValueError("test exception")
except ValueError:
pass
self.assertTrue(session.has_ended)
@async_client_context.require_sessions
async def test_session_operations_after_end(self):
"""Test operations on ended session raise InvalidOperation."""
session = self.client.start_session()
await session.end_session()
with self.assertRaises(InvalidOperation):
await self.db.test.find_one({}, session=session)
@async_client_context.require_sessions
async def test_session_end_session_idempotent(self):
"""Test that end_session can be called multiple times."""
session = self.client.start_session()
await session.end_session()
# Second call should not raise
await session.end_session()
self.assertTrue(session.has_ended)
@async_client_context.require_transactions
async def test_transaction_start_without_prior_transaction(self):
"""Test start_transaction on fresh session."""
async with self.client.start_session() as session:
session.start_transaction()
self.assertTrue(session.in_transaction)
await session.abort_transaction()
@async_client_context.require_transactions
async def test_transaction_start_twice_raises(self):
"""Test starting transaction twice raises error."""
async with self.client.start_session() as session:
session.start_transaction()
with self.assertRaises(InvalidOperation):
session.start_transaction()
await session.abort_transaction()
@async_client_context.require_transactions
async def test_transaction_abort_without_transaction_raises(self):
"""Test aborting without transaction raises error."""
async with self.client.start_session() as session:
with self.assertRaises(InvalidOperation):
await session.abort_transaction()
@async_client_context.require_transactions
async def test_transaction_commit_without_transaction_raises(self):
"""Test committing without transaction raises error."""
async with self.client.start_session() as session:
with self.assertRaises(InvalidOperation):
await session.commit_transaction()
@async_client_context.require_sessions
async def test_session_advance_cluster_time_validation(self):
"""Test advance_cluster_time with invalid input."""
async with self.client.start_session() as session:
with self.assertRaises(TypeError):
session.advance_cluster_time("invalid") # type: ignore
with self.assertRaises(ValueError):
session.advance_cluster_time({})
@async_client_context.require_sessions
async def test_session_advance_operation_time_validation(self):
"""Test advance_operation_time with invalid input."""
from bson import Timestamp
async with self.client.start_session() as session:
with self.assertRaises(TypeError):
session.advance_operation_time("invalid") # type: ignore
# Valid Timestamp should work
session.advance_operation_time(Timestamp(1, 1))
@async_client_context.require_transactions
async def test_with_transaction_callback_success(self):
"""Test with_transaction with successful callback."""
async with self.client.start_session() as session:
async def callback(session):
await self.db.test.insert_one({"x": 1}, session=session)
return "success"
result = await session.with_transaction(callback)
self.assertEqual("success", result)
@async_client_context.require_transactions
async def test_with_transaction_callback_exception(self):
"""Test with_transaction with callback exception."""
async with self.client.start_session() as session:
async def callback(session):
await self.db.test.insert_one({"x": 1}, session=session)
raise ValueError("callback error")
with self.assertRaises(ValueError):
await session.with_transaction(callback)
# Transaction should be aborted
self.assertFalse(session.in_transaction)
class AsyncTestSessionOptionsCoverage(AsyncUnitTest):
"""Tests for SessionOptions coverage."""
def test_session_options_defaults(self):
"""Test SessionOptions default values."""
from pymongo.asynchronous.client_session import SessionOptions
options = SessionOptions()
self.assertTrue(options.causal_consistency)
self.assertIsNone(options.default_transaction_options)
self.assertFalse(options.snapshot)
def test_session_options_snapshot_disables_causal_consistency(self):
"""Test snapshot=True forces causal_consistency=False."""
from pymongo.asynchronous.client_session import SessionOptions
options = SessionOptions(snapshot=True)
self.assertFalse(options.causal_consistency)
self.assertTrue(options.snapshot)
def test_session_options_snapshot_with_causal_raises(self):
"""Test snapshot=True with causal_consistency=True raises error."""
from pymongo.asynchronous.client_session import SessionOptions
with self.assertRaises(ConfigurationError):
SessionOptions(snapshot=True, causal_consistency=True)
def test_session_options_invalid_transaction_options(self):
"""Test SessionOptions with invalid transaction options type."""
from pymongo.asynchronous.client_session import SessionOptions
with self.assertRaises(TypeError):
SessionOptions(default_transaction_options="invalid") # type: ignore
class AsyncTestTransactionOptionsCoverage(AsyncUnitTest):
"""Tests for TransactionOptions coverage."""
def test_transaction_options_defaults(self):
"""Test TransactionOptions default values."""
from pymongo.asynchronous.client_session import TransactionOptions
options = TransactionOptions()
self.assertIsNone(options.read_concern)
self.assertIsNone(options.write_concern)
self.assertIsNone(options.read_preference)
self.assertIsNone(options.max_commit_time_ms)
def test_transaction_options_with_values(self):
"""Test TransactionOptions with all values set."""
from pymongo.asynchronous.client_session import TransactionOptions
options = TransactionOptions(
read_concern=ReadConcern("majority"),
write_concern=WriteConcern(w="majority"),
read_preference=ReadPreference.PRIMARY,
max_commit_time_ms=5000,
)
self.assertEqual("majority", options.read_concern.level)
self.assertEqual("majority", options.write_concern.document.get("w"))
self.assertEqual(ReadPreference.PRIMARY, options.read_preference)
self.assertEqual(5000, options.max_commit_time_ms)
def test_transaction_options_invalid_read_concern(self):
"""Test TransactionOptions with invalid read_concern type."""
from pymongo.asynchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(read_concern="invalid") # type: ignore
def test_transaction_options_invalid_write_concern(self):
"""Test TransactionOptions with invalid write_concern type."""
from pymongo.asynchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(write_concern="invalid") # type: ignore
def test_transaction_options_invalid_read_preference(self):
"""Test TransactionOptions with invalid read_preference type."""
from pymongo.asynchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(read_preference="invalid") # type: ignore
def test_transaction_options_invalid_max_commit_time(self):
"""Test TransactionOptions with invalid max_commit_time_ms type."""
from pymongo.asynchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(max_commit_time_ms="invalid") # type: ignore
def test_transaction_options_unacknowledged_write_concern(self):
"""Test TransactionOptions rejects unacknowledged write concern."""
from pymongo.asynchronous.client_session import TransactionOptions
with self.assertRaises(ConfigurationError):
TransactionOptions(write_concern=WriteConcern(w=0))
if __name__ == "__main__":
unittest.main()

View File

@ -779,6 +779,225 @@ class TestBulk(BulkTestBase):
self.assertEqual(6, result.inserted_count)
self.assertEqual(6, self.coll.count_documents({}))
def test_bulk_write_with_comment(self):
"""Test bulk write operations with comment parameter."""
requests = [
InsertOne({"x": 1}),
UpdateOne({"x": 1}, {"$set": {"y": 1}}),
DeleteOne({"x": 1}),
]
result = self.coll.bulk_write(requests, comment="bulk_comment")
self.assertEqual(1, result.inserted_count)
self.assertEqual(1, result.modified_count)
self.assertEqual(1, result.deleted_count)
def test_bulk_write_with_let(self):
"""Test bulk write operations with let parameter."""
if not client_context.version.at_least(5, 0):
self.skipTest("let parameter requires MongoDB 5.0+")
self.coll.insert_one({"x": 1})
requests = [
UpdateOne({"$expr": {"$eq": ["$x", "$$targetVal"]}}, {"$set": {"updated": True}}),
]
result = self.coll.bulk_write(requests, let={"targetVal": 1})
self.assertEqual(1, result.modified_count)
def test_bulk_write_all_operation_types(self):
"""Test bulk write with all operation types combined."""
self.coll.insert_many([{"x": i} for i in range(5)])
requests = [
InsertOne({"x": 100}),
UpdateOne({"x": 0}, {"$set": {"updated": True}}),
UpdateMany({"x": {"$lte": 2}}, {"$set": {"batch_updated": True}}),
ReplaceOne({"x": 3}, {"x": 3, "replaced": True}),
DeleteOne({"x": 4}),
DeleteMany({"x": {"$gt": 50}}),
]
result = self.coll.bulk_write(requests)
self.assertEqual(1, result.inserted_count)
self.assertGreaterEqual(result.modified_count, 1)
self.assertGreaterEqual(result.deleted_count, 1)
def test_bulk_write_unordered(self):
"""Test unordered bulk write continues after error."""
self.coll.create_index([("x", 1)], unique=True)
self.addCleanup(self.coll.drop_index, [("x", 1)])
requests = [
InsertOne({"x": 1}),
InsertOne({"x": 1}), # Duplicate - will error
InsertOne({"x": 2}),
InsertOne({"x": 3}),
]
with self.assertRaises(BulkWriteError) as ctx:
self.coll.bulk_write(requests, ordered=False)
# With unordered, should have inserted 3 documents
self.assertEqual(3, ctx.exception.details["nInserted"])
def test_bulk_write_ordered(self):
"""Test ordered bulk write stops on first error."""
self.coll.create_index([("x", 1)], unique=True)
self.addCleanup(self.coll.drop_index, [("x", 1)])
requests = [
InsertOne({"x": 1}),
InsertOne({"x": 1}), # Duplicate - will error
InsertOne({"x": 2}),
InsertOne({"x": 3}),
]
with self.assertRaises(BulkWriteError) as ctx:
self.coll.bulk_write(requests, ordered=True)
# With ordered, should have inserted only 1 document
self.assertEqual(1, ctx.exception.details["nInserted"])
def test_bulk_write_bypass_document_validation(self):
"""Test bulk write with bypass_document_validation."""
if not client_context.version.at_least(3, 2):
self.skipTest("bypass_document_validation requires MongoDB 3.2+")
# Create collection with validator
self.coll.drop()
self.db.create_collection(self.coll.name, validator={"$jsonSchema": {"required": ["name"]}})
# Without bypass, should fail
with self.assertRaises(BulkWriteError):
self.coll.bulk_write([InsertOne({"x": 1})])
# With bypass, should succeed
result = self.coll.bulk_write([InsertOne({"x": 1})], bypass_document_validation=True)
self.assertEqual(1, result.inserted_count)
def test_bulk_write_result_properties(self):
"""Test all BulkWriteResult properties."""
self.coll.insert_one({"x": 1})
requests = [
InsertOne({"x": 2}),
UpdateOne({"x": 1}, {"$set": {"updated": True}}),
ReplaceOne({"x": 2}, {"x": 2, "replaced": True}, upsert=True),
DeleteOne({"x": 1}),
]
result = self.coll.bulk_write(requests)
# Check all properties
self.assertTrue(result.acknowledged)
self.assertEqual(1, result.inserted_count)
self.assertGreaterEqual(result.matched_count, 0)
self.assertGreaterEqual(result.modified_count, 0)
self.assertEqual(1, result.deleted_count)
self.assertIsInstance(result.upserted_count, int)
self.assertIsInstance(result.upserted_ids, dict)
def test_bulk_write_with_upsert(self):
"""Test bulk write upsert operations."""
requests = [
UpdateOne({"x": 1}, {"$set": {"y": 1}}, upsert=True),
UpdateOne({"x": 2}, {"$set": {"y": 2}}, upsert=True),
ReplaceOne({"x": 3}, {"x": 3, "y": 3}, upsert=True),
]
result = self.coll.bulk_write(requests)
self.assertEqual(3, result.upserted_count)
self.assertEqual(3, len(result.upserted_ids))
def test_update_one_with_hint(self):
"""Test UpdateOne with hint parameter."""
self.coll.create_index([("x", 1)])
self.addCleanup(self.coll.drop_index, [("x", 1)])
self.coll.insert_one({"x": 1})
requests = [UpdateOne({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
result = self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
def test_update_many_with_hint(self):
"""Test UpdateMany with hint parameter."""
self.coll.create_index([("x", 1)])
self.addCleanup(self.coll.drop_index, [("x", 1)])
self.coll.insert_many([{"x": 1}, {"x": 1}])
requests = [UpdateMany({"x": 1}, {"$set": {"y": 1}}, hint=[("x", 1)])]
result = self.coll.bulk_write(requests)
self.assertEqual(2, result.modified_count)
def test_delete_one_with_hint(self):
"""Test DeleteOne with hint parameter."""
self.coll.create_index([("x", 1)])
self.addCleanup(self.coll.drop_index, [("x", 1)])
self.coll.insert_one({"x": 1})
requests = [DeleteOne({"x": 1}, hint=[("x", 1)])]
result = self.coll.bulk_write(requests)
self.assertEqual(1, result.deleted_count)
def test_delete_many_with_hint(self):
"""Test DeleteMany with hint parameter."""
self.coll.create_index([("x", 1)])
self.addCleanup(self.coll.drop_index, [("x", 1)])
self.coll.insert_many([{"x": 1}, {"x": 1}])
requests = [DeleteMany({"x": 1}, hint=[("x", 1)])]
result = self.coll.bulk_write(requests)
self.assertEqual(2, result.deleted_count)
def test_update_one_with_array_filters(self):
"""Test UpdateOne with array_filters parameter."""
self.coll.insert_one({"x": [{"y": 1}, {"y": 2}, {"y": 3}]})
requests = [
UpdateOne({}, {"$set": {"x.$[elem].z": 1}}, array_filters=[{"elem.y": {"$gt": 1}}])
]
result = self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
doc = self.coll.find_one()
# Elements with y > 1 should have z = 1
for elem in doc["x"]:
if elem["y"] > 1:
self.assertEqual(1, elem.get("z"))
def test_replace_one_with_hint(self):
"""Test ReplaceOne with hint parameter."""
self.coll.create_index([("x", 1)])
self.addCleanup(self.coll.drop_index, [("x", 1)])
self.coll.insert_one({"x": 1})
requests = [ReplaceOne({"x": 1}, {"x": 1, "replaced": True}, hint=[("x", 1)])]
result = self.coll.bulk_write(requests)
self.assertEqual(1, result.modified_count)
def test_update_with_collation(self):
"""Test update operations with collation."""
self.coll.insert_many(
[
{"name": "cafe"},
{"name": "Cafe"},
]
)
requests = [
UpdateMany(
{"name": "cafe"},
{"$set": {"updated": True}},
collation={"locale": "en", "strength": 2},
)
]
result = self.coll.bulk_write(requests)
# With case-insensitive collation, both docs should match
self.assertEqual(2, result.modified_count)
class BulkAuthorizationTestBase(BulkTestBase):
@client_context.require_auth

View File

@ -1132,5 +1132,122 @@ globals().update(
)
class TestChangeStreamCoverage(TestCollectionChangeStream):
"""Additional tests to improve code coverage for ChangeStream."""
def test_change_stream_alive_property(self):
"""Test alive property state transitions."""
with self.change_stream() as cs:
self.assertTrue(cs.alive)
# After context exit, should be closed
self.assertFalse(cs.alive)
def test_change_stream_idempotent_close(self):
"""Test that close() can be called multiple times safely."""
cs = self.change_stream()
cs.close()
# Second close should not raise
cs.close()
self.assertFalse(cs.alive)
def test_change_stream_resume_token_deepcopy(self):
"""Test that resume_token returns a deep copy."""
coll = self.watched_collection()
with self.change_stream() as cs:
coll.insert_one({"x": 1})
next(cs) # Consume the change event
token1 = cs.resume_token
token2 = cs.resume_token
# Should be equal but different objects
self.assertEqual(token1, token2)
self.assertIsNot(token1, token2)
def test_change_stream_with_comment(self):
"""Test change stream with comment parameter."""
client, listener = self.client_with_listener("aggregate")
try:
with self.change_stream_with_client(client, comment="test_comment"):
pass
finally:
client.close()
# Check that comment was in the aggregate command
self.assertGreater(len(listener.started_events), 0)
cmd = listener.started_events[0].command
self.assertEqual("test_comment", cmd.get("comment"))
def test_change_stream_with_show_expanded_events(self):
"""Test change stream with show_expanded_events parameter."""
if not client_context.version.at_least(6, 0):
self.skipTest("show_expanded_events requires MongoDB 6.0+")
with self.change_stream(show_expanded_events=True) as cs:
# Just verify it doesn't error
self.assertTrue(cs.alive)
@client_context.require_version_min(6, 0)
def test_change_stream_with_full_document_before_change(self):
"""Test change stream with full_document_before_change parameter."""
coll = self.watched_collection()
# Need to ensure collection exists with changeStreamPreAndPostImages enabled
coll.drop()
self.db.create_collection(coll.name, changeStreamPreAndPostImages={"enabled": True})
coll.insert_one({"x": 1})
with self.change_stream(full_document_before_change="whenAvailable") as cs:
coll.update_one({"x": 1}, {"$set": {"x": 2}})
change = next(cs)
self.assertEqual("update", change["operationType"])
# fullDocumentBeforeChange should be present
self.assertIn("fullDocumentBeforeChange", change)
def test_change_stream_next_after_close(self):
"""Test that next() on closed stream raises StopIteration."""
cs = self.change_stream()
cs.close()
with self.assertRaises(StopIteration):
next(cs)
def test_change_stream_try_next_after_close(self):
"""Test that try_next() on closed stream raises StopIteration."""
cs = self.change_stream()
cs.close()
with self.assertRaises(StopIteration):
cs.try_next()
def test_change_stream_pipeline_construction(self):
"""Test change stream pipeline is properly constructed."""
pipeline = [{"$match": {"operationType": "insert"}}]
client, listener = self.client_with_listener("aggregate")
try:
with self.change_stream_with_client(client, pipeline=pipeline):
pass
finally:
client.close()
cmd = listener.started_events[0].command
agg_pipeline = cmd["pipeline"]
# First stage should be $changeStream
self.assertIn("$changeStream", agg_pipeline[0])
# Second stage should be our match
self.assertEqual({"$match": {"operationType": "insert"}}, agg_pipeline[1])
def test_change_stream_empty_pipeline(self):
"""Test change stream with empty pipeline."""
with self.change_stream(pipeline=[]) as cs:
self.assertTrue(cs.alive)
def test_change_stream_context_manager_exception(self):
"""Test change stream context manager closes on exception."""
cs = None
try:
with self.change_stream() as cs:
raise ValueError("test exception")
except ValueError:
pass
# Stream should be closed
self.assertFalse(cs.alive)
if __name__ == "__main__":
unittest.main()

View File

@ -2238,5 +2238,262 @@ class TestCollection(IntegrationTest):
helper(*args, let={}) # type: ignore
class TestCollectionCoverage(IntegrationTest):
"""Additional tests to improve code coverage for Collection."""
def setUp(self):
super().setUp()
self.db.test.drop()
self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
def test_collection_full_name(self):
"""Test full_name property."""
expected = f"{self.db.name}.test"
self.assertEqual(expected, self.db.test.full_name)
def test_collection_name(self):
"""Test name property."""
self.assertEqual("test", self.db.test.name)
def test_collection_database(self):
"""Test database property."""
self.assertEqual(self.db, self.db.test.database)
def test_collection_equality(self):
"""Test collection equality."""
coll1 = self.db.test
coll2 = self.db.test
coll3 = self.db.other
self.assertEqual(coll1, coll2)
self.assertNotEqual(coll1, coll3)
def test_collection_hash(self):
"""Test collection hashability."""
coll1 = self.db.test
coll2 = self.db.test
# Same collection should have same hash
self.assertEqual(hash(coll1), hash(coll2))
# Collections can be used in sets
s = {coll1, coll2}
self.assertEqual(1, len(s))
def test_collection_repr(self):
"""Test collection repr."""
coll = self.db.test
repr_str = repr(coll)
self.assertIn("test", repr_str)
self.assertIn("Collection", repr_str)
def test_collection_getattr(self):
"""Test sub-collection access via attribute."""
subcoll = self.db.test.subcollection
self.assertEqual("test.subcollection", subcoll.name)
def test_collection_getitem(self):
"""Test sub-collection access via indexing."""
subcoll = self.db.test["subcollection"]
self.assertEqual("test.subcollection", subcoll.name)
def test_collection_with_options(self):
"""Test with_options creates new collection with options."""
from pymongo.read_concern import ReadConcern
from pymongo.write_concern import WriteConcern
coll = self.db.test.with_options(
read_concern=ReadConcern("majority"), write_concern=WriteConcern(w=1)
)
self.assertEqual("majority", coll.read_concern.level)
self.assertEqual({"w": 1}, coll.write_concern.document)
# Original should be unchanged
self.assertNotEqual("majority", self.db.test.read_concern.level)
def test_collection_drop(self):
"""Test collection drop."""
self.db.test_drop.insert_one({"x": 1})
self.db.test_drop.drop()
names = self.db.list_collection_names()
self.assertNotIn("test_drop", names)
def test_collection_drop_with_comment(self):
"""Test collection drop with comment."""
self.db.test_drop_comment.insert_one({"x": 1})
self.db.test_drop_comment.drop(comment="test_comment")
names = self.db.list_collection_names()
self.assertNotIn("test_drop_comment", names)
def test_find_raw_batches(self):
"""Test find_raw_batches returns raw BSON."""
from bson import decode_all
cursor = self.db.test.find_raw_batches(batch_size=5)
batch_count = 0
for batch in cursor:
self.assertIsInstance(batch, bytes)
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
def test_aggregate_raw_batches(self):
"""Test aggregate_raw_batches returns raw BSON."""
from bson import decode_all
cursor = self.db.test.aggregate_raw_batches([{"$sort": {"x": 1}}], batchSize=5)
batch_count = 0
for batch in cursor:
self.assertIsInstance(batch, bytes)
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
def test_distinct_with_collation(self):
"""Test distinct with collation."""
self.db.test.drop()
self.db.test.insert_many(
[
{"name": "abc"},
{"name": "ABC"},
{"name": "def"},
]
)
# Case-insensitive distinct
values = self.db.test.distinct("name", collation={"locale": "en_US", "strength": 2})
# abc and ABC should be considered the same
self.assertEqual(2, len(values))
def test_count_documents_with_options(self):
"""Test count_documents with skip, limit, hint."""
self.db.test.create_index([("x", 1)])
count = self.db.test.count_documents({"x": {"$gte": 0}}, skip=2, limit=5, hint=[("x", 1)])
self.assertEqual(5, count)
def test_estimated_document_count(self):
"""Test estimated_document_count."""
count = self.db.test.estimated_document_count()
self.assertEqual(10, count)
def test_estimated_document_count_with_options(self):
"""Test estimated_document_count with maxTimeMS and comment."""
count = self.db.test.estimated_document_count(maxTimeMS=5000, comment="test_comment")
self.assertEqual(10, count)
def test_find_one_and_delete_with_options(self):
"""Test find_one_and_delete with projection, sort."""
doc = self.db.test.find_one_and_delete(
{"x": {"$gte": 0}}, projection={"x": 1}, sort=[("x", -1)]
)
self.assertEqual(9, doc["x"])
self.assertNotIn("y", doc)
def test_find_one_and_replace_with_options(self):
"""Test find_one_and_replace with various options."""
from pymongo import ReturnDocument
doc = self.db.test.find_one_and_replace(
{"x": 0},
{"x": 0, "replaced": True},
projection={"x": 1, "replaced": 1},
return_document=ReturnDocument.AFTER,
)
self.assertEqual(0, doc["x"])
self.assertTrue(doc.get("replaced"))
def test_find_one_and_update_with_options(self):
"""Test find_one_and_update with various options."""
from pymongo import ReturnDocument
doc = self.db.test.find_one_and_update(
{"x": 0},
{"$set": {"updated": True}},
projection={"x": 1, "updated": 1},
return_document=ReturnDocument.AFTER,
)
self.assertEqual(0, doc["x"])
self.assertTrue(doc.get("updated"))
def test_update_one_with_array_filters(self):
"""Test update_one with array_filters."""
self.db.test.drop()
self.db.test.insert_one({"items": [{"v": 1}, {"v": 2}, {"v": 3}]})
result = self.db.test.update_one(
{}, {"$set": {"items.$[elem].updated": True}}, array_filters=[{"elem.v": {"$gt": 1}}]
)
self.assertEqual(1, result.modified_count)
def test_update_many_with_hint(self):
"""Test update_many with hint."""
self.db.test.create_index([("x", 1)])
result = self.db.test.update_many(
{"x": {"$gte": 0}}, {"$set": {"batch_updated": True}}, hint=[("x", 1)]
)
self.assertEqual(10, result.modified_count)
def test_delete_one_with_hint(self):
"""Test delete_one with hint."""
self.db.test.create_index([("x", 1)])
result = self.db.test.delete_one({"x": 0}, hint=[("x", 1)])
self.assertEqual(1, result.deleted_count)
def test_delete_many_with_hint(self):
"""Test delete_many with hint."""
self.db.test.create_index([("x", 1)])
result = self.db.test.delete_many({"x": {"$lt": 5}}, hint=[("x", 1)])
self.assertEqual(5, result.deleted_count)
def test_aggregate_with_let(self):
"""Test aggregate with let parameter."""
if not client_context.version.at_least(5, 0):
self.skipTest("let parameter requires MongoDB 5.0+")
pipeline = [{"$match": {"$expr": {"$eq": ["$x", "$$targetVal"]}}}]
cursor = self.db.test.aggregate(pipeline, let={"targetVal": 5})
docs = cursor.to_list()
self.assertEqual(1, len(docs))
self.assertEqual(5, docs[0]["x"])
def test_aggregate_with_batch_size(self):
"""Test aggregate with batchSize."""
cursor = self.db.test.aggregate([{"$sort": {"x": 1}}], batchSize=2)
docs = cursor.to_list()
self.assertEqual(10, len(docs))
def test_list_indexes(self):
"""Test list_indexes returns cursor."""
self.db.test.create_index([("x", 1)])
cursor = self.db.test.list_indexes()
# Should get at least the _id index
indexes = cursor.to_list()
self.assertGreaterEqual(len(indexes), 1)
index_names = [idx["name"] for idx in indexes]
self.assertIn("_id_", index_names)
def test_index_information(self):
"""Test index_information returns dict."""
self.db.test.create_index([("x", 1)], name="x_index")
info = self.db.test.index_information()
self.assertIsInstance(info, dict)
self.assertIn("_id_", info)
self.assertIn("x_index", info)
def test_options_method(self):
"""Test options() returns collection options."""
# Create a capped collection
self.db.drop_collection("test_capped")
self.db.create_collection("test_capped", capped=True, size=10000)
opts = self.db.test_capped.options()
self.assertTrue(opts.get("capped"))
self.db.drop_collection("test_capped")
if __name__ == "__main__":
unittest.main()

View File

@ -1853,5 +1853,404 @@ class TestRawBatchCommandCursor(IntegrationTest):
self.assertEqual(cmd.command["$db"], "pymongo_test")
class TestCursorCoverage(IntegrationTest):
"""Additional tests to improve code coverage for Cursor."""
def setUp(self):
super().setUp()
self.db.test.drop()
self.db.test.insert_many([{"x": i, "y": i * 2} for i in range(10)])
def test_get_namespace(self):
"""Test _get_namespace() method."""
cursor = self.db.test.find()
expected_ns = f"{self.db.name}.test"
self.assertEqual(expected_ns, cursor._get_namespace())
def test_cursor_alive_property_states(self):
"""Test cursor alive property in different states."""
cursor = self.db.test.find()
# Cursor is alive even before starting (has potential to return data)
self.assertTrue(cursor.alive)
# Start the cursor
next(cursor)
self.assertTrue(cursor.alive)
# Exhaust the cursor
cursor.to_list()
self.assertFalse(cursor.alive)
def test_cursor_closed_property(self):
"""Test cursor behavior after close."""
cursor = self.db.test.find()
next(cursor)
self.assertTrue(cursor.alive)
cursor.close()
# After close, cursor is killed (check internal _killed flag)
self.assertTrue(cursor._killed)
def test_retrieved_property(self):
"""Test the retrieved property tracking."""
cursor = self.db.test.find().batch_size(2)
self.assertEqual(0, cursor.retrieved)
next(cursor)
self.assertGreater(cursor.retrieved, 0)
def test_cursor_with_let_parameter(self):
"""Test cursor with let parameter."""
# let parameter allows variables to be used in the filter
cursor = self.db.test.find(
{"$expr": {"$eq": ["$x", "$$targetValue"]}}, let={"targetValue": 5}
)
docs = cursor.to_list()
self.assertEqual(1, len(docs))
self.assertEqual(5, docs[0]["x"])
def test_cursor_with_invalid_let_parameter(self):
"""Test cursor raises error for invalid let parameter."""
with self.assertRaises(TypeError):
self.db.test.find(let="invalid") # type: ignore[arg-type]
def test_cursor_with_show_record_id(self):
"""Test cursor with show_record_id option."""
cursor = self.db.test.find(show_record_id=True)
doc = next(cursor)
self.assertIn("$recordId", doc)
def test_cursor_with_return_key(self):
"""Test cursor with return_key option."""
self.db.test.create_index([("x", ASCENDING)])
cursor = self.db.test.find({"x": 5}, return_key=True).hint([("x", ASCENDING)])
doc = next(cursor)
# return_key returns only index keys
self.assertIn("x", doc)
self.assertNotIn("y", doc)
def test_check_okay_to_chain_after_iteration(self):
"""Test that cursor configuration methods raise after iteration."""
cursor = self.db.test.find()
next(cursor) # Start iteration
# All these should raise InvalidOperation
with self.assertRaises(InvalidOperation):
cursor.limit(5)
with self.assertRaises(InvalidOperation):
cursor.skip(2)
with self.assertRaises(InvalidOperation):
cursor.sort("x")
with self.assertRaises(InvalidOperation):
cursor.hint([("x", ASCENDING)])
with self.assertRaises(InvalidOperation):
cursor.max([("x", 10)])
with self.assertRaises(InvalidOperation):
cursor.min([("x", 0)])
with self.assertRaises(InvalidOperation):
cursor.add_option(2)
with self.assertRaises(InvalidOperation):
cursor.remove_option(2)
with self.assertRaises(InvalidOperation):
cursor.batch_size(10)
with self.assertRaises(InvalidOperation):
cursor.max_time_ms(1000)
with self.assertRaises(InvalidOperation):
cursor.collation(Collation("en_US"))
with self.assertRaises(InvalidOperation):
cursor.allow_disk_use(True)
with self.assertRaises(InvalidOperation):
cursor.where("this.x > 5")
with self.assertRaises(InvalidOperation):
cursor.comment("test")
def test_cursor_context_manager(self):
"""Test cursor as async context manager."""
with self.db.test.find() as cursor:
doc = next(cursor)
self.assertIsNotNone(doc)
# Cursor should be killed after context (check _killed flag)
self.assertTrue(cursor._killed)
def test_cursor_context_manager_with_exception(self):
"""Test cursor context manager closes on exception."""
cursor = None
try:
with self.db.test.find() as cursor:
next(cursor)
raise ValueError("test exception")
except ValueError:
pass
# Cursor should be killed after exception
self.assertTrue(cursor._killed)
def test_cursor_collation(self):
"""Test cursor with collation."""
self.db.test.drop()
self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
# Case-insensitive sort
cursor = (
self.db.test.find().collation(Collation("en_US", strength=2)).sort("name", ASCENDING)
)
docs = cursor.to_list()
self.assertEqual(3, len(docs))
def test_cursor_collation_type_error(self):
"""Test cursor raises error for invalid collation."""
with self.assertRaises(TypeError):
self.db.test.find().collation("invalid") # type: ignore[arg-type]
def test_cursor_getitem_not_supported(self):
"""Test that Cursor does not support indexing."""
cursor = self.db.test.find()
with self.assertRaises(IndexError) as ctx:
cursor[5]
self.assertIn("does not support indexing", str(ctx.exception))
def test_cursor_next_after_close(self):
"""Test that next() raises StopIteration after close."""
cursor = self.db.test.find()
cursor.close()
with self.assertRaises(StopIteration):
next(cursor)
def test_cursor_rewind_resets_state(self):
"""Test that rewind properly resets cursor state."""
cursor = self.db.test.find().limit(3)
# Iterate fully
docs1 = cursor.to_list()
self.assertEqual(3, len(docs1))
self.assertEqual(0, len(cursor._data))
# Rewind and iterate again
cursor.rewind()
docs2 = cursor.to_list()
self.assertEqual(3, len(docs2))
self.assertEqual(docs1, docs2)
def test_cursor_clone_with_session(self):
"""Test that clone preserves explicit session."""
with self.client.start_session() as session:
cursor = self.db.test.find(session=session)
cloned = cursor.clone()
# Clone should reference the same session
self.assertEqual(cursor.session, cloned.session)
def test_cursor_clone_without_session(self):
"""Test that clone without session doesn't add one."""
cursor = self.db.test.find()
cloned = cursor.clone()
# Clone should have no session if original had none
self.assertIsNone(cloned.session)
def test_cursor_distinct_with_collation(self):
"""Test distinct with collation."""
self.db.test.drop()
self.db.test.insert_many([{"name": "abc"}, {"name": "ABC"}, {"name": "def"}])
# Case-insensitive distinct
cursor = self.db.test.find().collation(Collation("en_US", strength=2))
# distinct() on cursor with collation
values = cursor.distinct("name")
# Should have 2 distinct values (abc/ABC treated as same)
self.assertEqual(2, len(values))
def test_cursor_explain_with_options(self):
"""Test explain with cursor options set."""
cursor = self.db.test.find({"x": {"$gt": 5}}).sort("x", ASCENDING).limit(5).skip(1)
explanation = cursor.explain()
self.assertIn("queryPlanner", explanation)
def test_cursor_max_time_ms_type_errors(self):
"""Test max_time_ms raises TypeError for invalid input."""
cursor = self.db.test.find()
with self.assertRaises(TypeError):
cursor.max_time_ms("invalid") # type: ignore[arg-type]
def test_cursor_max_await_time_ms_type_errors(self):
"""Test max_await_time_ms raises TypeError for invalid input."""
cursor = self.db.test.find()
with self.assertRaises(TypeError):
cursor.max_await_time_ms("invalid") # type: ignore[arg-type]
def test_cursor_comment_type(self):
"""Test cursor with comment of various types."""
# String comment
cursor1 = self.db.test.find().comment("test comment")
docs1 = cursor1.to_list()
self.assertGreater(len(docs1), 0)
# Dict comment
cursor2 = self.db.test.find().comment({"key": "value"})
docs2 = cursor2.to_list()
self.assertGreater(len(docs2), 0)
def test_cursor_batch_size_validation(self):
"""Test batch_size validation."""
with self.assertRaises(TypeError):
self.db.test.find(batch_size="invalid") # type: ignore[arg-type]
with self.assertRaises(ValueError):
self.db.test.find(batch_size=-1)
def test_cursor_skip_validation(self):
"""Test skip validation."""
with self.assertRaises(TypeError):
self.db.test.find(skip="invalid") # type: ignore[arg-type]
def test_cursor_limit_validation(self):
"""Test limit validation."""
with self.assertRaises(TypeError):
self.db.test.find(limit="invalid") # type: ignore[arg-type]
def test_cursor_filter_validation(self):
"""Test filter validation."""
with self.assertRaises(TypeError):
self.db.test.find(filter="invalid") # type: ignore[arg-type]
def test_cursor_type_validation(self):
"""Test cursor_type validation."""
with self.assertRaises(ValueError):
self.db.test.find(cursor_type=999)
def test_cursor_query_spec_with_modifiers(self):
"""Test _query_spec includes modifiers."""
cursor = (
self.db.test.find()
.sort("x", ASCENDING)
.hint([("x", ASCENDING)])
.max_time_ms(1000)
.comment("test")
)
spec = cursor._query_spec()
self.assertIsInstance(spec, dict)
def test_cursor_copy(self):
"""Test cursor __copy__ returns clone."""
cursor = self.db.test.find().limit(5)
copied = copy.copy(cursor)
self.assertIsNot(cursor, copied)
self.assertEqual(cursor._limit, copied._limit)
def test_cursor_deepcopy(self):
"""Test cursor __deepcopy__ returns deep clone."""
cursor = self.db.test.find({"x": {"$gt": 0}}).limit(5)
copied = copy.deepcopy(cursor)
self.assertIsNot(cursor, copied)
self.assertEqual(cursor._limit, copied._limit)
self.assertEqual(cursor._spec, copied._spec)
# Spec should be a different object
self.assertIsNot(cursor._spec, copied._spec)
def test_cursor_iteration_protocol(self):
"""Test cursor async iteration protocol."""
cursor = self.db.test.find().limit(3)
# Test __iter__ returns self
self.assertIs(cursor, cursor.__iter__())
# Test __next__ returns documents
doc1 = cursor.__next__()
self.assertIsNotNone(doc1)
def test_cursor_to_list_with_limit(self):
"""Test to_list respects cursor limit."""
cursor = self.db.test.find().limit(3)
docs = cursor.to_list()
self.assertEqual(3, len(docs))
def test_cursor_to_list_with_length(self):
"""Test to_list with length parameter."""
cursor = self.db.test.find()
docs = cursor.to_list(length=3)
self.assertEqual(3, len(docs))
def test_min_max_require_hint(self):
"""Test that min/max require hint for proper execution."""
self.db.test.create_index([("x", ASCENDING)])
# min without hint should work when index exists
cursor = self.db.test.find().min([("x", 5)]).hint([("x", ASCENDING)])
docs = cursor.to_list()
self.assertTrue(all(doc["x"] >= 5 for doc in docs))
# max without hint should work when index exists
cursor = self.db.test.find().max([("x", 5)]).hint([("x", ASCENDING)])
docs = cursor.to_list()
self.assertTrue(all(doc["x"] < 5 for doc in docs))
def test_cursor_address_property(self):
"""Test cursor address is set after first batch."""
cursor = self.db.test.find()
self.assertIsNone(cursor.address)
next(cursor)
# Address should be set after query
self.assertIsNotNone(cursor.address)
def test_cursor_session_property(self):
"""Test cursor session property."""
# Cursor without explicit session
cursor1 = self.db.test.find()
self.assertIsNone(cursor1.session)
# Cursor with explicit session
with self.client.start_session() as session:
cursor2 = self.db.test.find(session=session)
self.assertEqual(session, cursor2.session)
def test_cursor_allow_disk_use_type_error(self):
"""Test allow_disk_use raises TypeError for invalid input."""
with self.assertRaises(TypeError):
self.db.test.find().allow_disk_use("invalid") # type: ignore[arg-type]
class TestRawBatchCursorCoverage(IntegrationTest):
"""Additional tests for RawBatchCursor coverage."""
def setUp(self):
super().setUp()
self.db.test.drop()
self.db.test.insert_many([{"x": i} for i in range(20)])
def test_raw_batch_cursor_iteration(self):
"""Test raw batch cursor returns raw BSON."""
cursor = self.db.test.find_raw_batches(batch_size=5)
batch_count = 0
for batch in cursor:
self.assertIsInstance(batch, bytes)
# Decode the batch to verify it's valid BSON
docs = decode_all(batch)
self.assertGreater(len(docs), 0)
batch_count += 1
self.assertGreater(batch_count, 0)
def test_raw_batch_cursor_explain(self):
"""Test raw batch cursor explain."""
cursor = self.db.test.find_raw_batches()
explanation = cursor.explain()
self.assertIn("queryPlanner", explanation)
def test_raw_batch_cursor_getitem_raises(self):
"""Test raw batch cursor __getitem__ raises InvalidOperation."""
cursor = self.db.test.find_raw_batches()
with self.assertRaises(InvalidOperation):
cursor[0]
def test_raw_batch_cursor_with_sort(self):
"""Test raw batch cursor with sort."""
cursor = self.db.test.find_raw_batches(batch_size=5).sort("x", DESCENDING)
first_batch = next(cursor)
docs = decode_all(first_batch)
# First doc should have highest x value
self.assertEqual(19, docs[0]["x"])
def test_raw_batch_cursor_with_limit(self):
"""Test raw batch cursor with limit."""
cursor = self.db.test.find_raw_batches(batch_size=5).limit(7)
all_docs = []
for batch in cursor:
all_docs.extend(decode_all(batch))
self.assertEqual(7, len(all_docs))
if __name__ == "__main__":
unittest.main()

View File

@ -50,9 +50,11 @@ from pymongo.common import _MAX_END_SESSIONS
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
from pymongo.operations import IndexModel, InsertOne, UpdateOne
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.helpers import next
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
@ -1345,5 +1347,284 @@ class TestClusterTime(IntegrationTest):
self.assertEqual(started.command["$clusterTime"], cluster_time)
class TestClientSessionCoverage(IntegrationTest):
"""Additional tests to improve code coverage for ClientSession."""
@client_context.require_sessions
def test_session_has_ended_property(self):
"""Test has_ended property state transitions."""
session = self.client.start_session()
self.assertFalse(session.has_ended)
session.end_session()
self.assertTrue(session.has_ended)
@client_context.require_sessions
def test_session_session_id_property(self):
"""Test session_id property returns correct value."""
with self.client.start_session() as session:
session_id = session.session_id
self.assertIsInstance(session_id, dict)
self.assertIn("id", session_id)
@client_context.require_sessions
def test_session_cluster_time_operations(self):
"""Test cluster time advance operations."""
with self.client.start_session() as session:
# Initially None
self.assertIsNone(session.cluster_time)
# Perform operation to get cluster time
self.db.test.find_one({}, session=session)
# Cluster time should be set after operation
# (may still be None on some server versions)
@client_context.require_sessions
def test_session_operation_time_operations(self):
"""Test operation time advance operations."""
with self.client.start_session() as session:
# Initially None
self.assertIsNone(session.operation_time)
# Perform operation to get operation time
self.db.test.find_one({}, session=session)
@client_context.require_sessions
def test_session_options_property(self):
"""Test session options property."""
with self.client.start_session(causal_consistency=True) as session:
self.assertTrue(session.options.causal_consistency)
@client_context.require_sessions
def test_session_client_property(self):
"""Test session client property."""
with self.client.start_session() as session:
self.assertEqual(self.client, session.client)
@client_context.require_sessions
def test_session_in_transaction_property(self):
"""Test in_transaction property."""
if client_context.is_rs or client_context.is_mongos:
with self.client.start_session() as session:
self.assertFalse(session.in_transaction)
session.start_transaction()
self.assertTrue(session.in_transaction)
session.abort_transaction()
self.assertFalse(session.in_transaction)
@client_context.require_sessions
def test_session_context_manager(self):
"""Test session async context manager."""
with self.client.start_session() as session:
self.assertFalse(session.has_ended)
self.db.test.find_one({}, session=session)
self.assertTrue(session.has_ended)
@client_context.require_sessions
def test_session_context_manager_exception(self):
"""Test session context manager closes on exception."""
session = None
try:
with self.client.start_session() as session:
raise ValueError("test exception")
except ValueError:
pass
self.assertTrue(session.has_ended)
@client_context.require_sessions
def test_session_operations_after_end(self):
"""Test operations on ended session raise InvalidOperation."""
session = self.client.start_session()
session.end_session()
with self.assertRaises(InvalidOperation):
self.db.test.find_one({}, session=session)
@client_context.require_sessions
def test_session_end_session_idempotent(self):
"""Test that end_session can be called multiple times."""
session = self.client.start_session()
session.end_session()
# Second call should not raise
session.end_session()
self.assertTrue(session.has_ended)
@client_context.require_transactions
def test_transaction_start_without_prior_transaction(self):
"""Test start_transaction on fresh session."""
with self.client.start_session() as session:
session.start_transaction()
self.assertTrue(session.in_transaction)
session.abort_transaction()
@client_context.require_transactions
def test_transaction_start_twice_raises(self):
"""Test starting transaction twice raises error."""
with self.client.start_session() as session:
session.start_transaction()
with self.assertRaises(InvalidOperation):
session.start_transaction()
session.abort_transaction()
@client_context.require_transactions
def test_transaction_abort_without_transaction_raises(self):
"""Test aborting without transaction raises error."""
with self.client.start_session() as session:
with self.assertRaises(InvalidOperation):
session.abort_transaction()
@client_context.require_transactions
def test_transaction_commit_without_transaction_raises(self):
"""Test committing without transaction raises error."""
with self.client.start_session() as session:
with self.assertRaises(InvalidOperation):
session.commit_transaction()
@client_context.require_sessions
def test_session_advance_cluster_time_validation(self):
"""Test advance_cluster_time with invalid input."""
with self.client.start_session() as session:
with self.assertRaises(TypeError):
session.advance_cluster_time("invalid") # type: ignore
with self.assertRaises(ValueError):
session.advance_cluster_time({})
@client_context.require_sessions
def test_session_advance_operation_time_validation(self):
"""Test advance_operation_time with invalid input."""
from bson import Timestamp
with self.client.start_session() as session:
with self.assertRaises(TypeError):
session.advance_operation_time("invalid") # type: ignore
# Valid Timestamp should work
session.advance_operation_time(Timestamp(1, 1))
@client_context.require_transactions
def test_with_transaction_callback_success(self):
"""Test with_transaction with successful callback."""
with self.client.start_session() as session:
def callback(session):
self.db.test.insert_one({"x": 1}, session=session)
return "success"
result = session.with_transaction(callback)
self.assertEqual("success", result)
@client_context.require_transactions
def test_with_transaction_callback_exception(self):
"""Test with_transaction with callback exception."""
with self.client.start_session() as session:
def callback(session):
self.db.test.insert_one({"x": 1}, session=session)
raise ValueError("callback error")
with self.assertRaises(ValueError):
session.with_transaction(callback)
# Transaction should be aborted
self.assertFalse(session.in_transaction)
class TestSessionOptionsCoverage(UnitTest):
"""Tests for SessionOptions coverage."""
def test_session_options_defaults(self):
"""Test SessionOptions default values."""
from pymongo.synchronous.client_session import SessionOptions
options = SessionOptions()
self.assertTrue(options.causal_consistency)
self.assertIsNone(options.default_transaction_options)
self.assertFalse(options.snapshot)
def test_session_options_snapshot_disables_causal_consistency(self):
"""Test snapshot=True forces causal_consistency=False."""
from pymongo.synchronous.client_session import SessionOptions
options = SessionOptions(snapshot=True)
self.assertFalse(options.causal_consistency)
self.assertTrue(options.snapshot)
def test_session_options_snapshot_with_causal_raises(self):
"""Test snapshot=True with causal_consistency=True raises error."""
from pymongo.synchronous.client_session import SessionOptions
with self.assertRaises(ConfigurationError):
SessionOptions(snapshot=True, causal_consistency=True)
def test_session_options_invalid_transaction_options(self):
"""Test SessionOptions with invalid transaction options type."""
from pymongo.synchronous.client_session import SessionOptions
with self.assertRaises(TypeError):
SessionOptions(default_transaction_options="invalid") # type: ignore
class TestTransactionOptionsCoverage(UnitTest):
"""Tests for TransactionOptions coverage."""
def test_transaction_options_defaults(self):
"""Test TransactionOptions default values."""
from pymongo.synchronous.client_session import TransactionOptions
options = TransactionOptions()
self.assertIsNone(options.read_concern)
self.assertIsNone(options.write_concern)
self.assertIsNone(options.read_preference)
self.assertIsNone(options.max_commit_time_ms)
def test_transaction_options_with_values(self):
"""Test TransactionOptions with all values set."""
from pymongo.synchronous.client_session import TransactionOptions
options = TransactionOptions(
read_concern=ReadConcern("majority"),
write_concern=WriteConcern(w="majority"),
read_preference=ReadPreference.PRIMARY,
max_commit_time_ms=5000,
)
self.assertEqual("majority", options.read_concern.level)
self.assertEqual("majority", options.write_concern.document.get("w"))
self.assertEqual(ReadPreference.PRIMARY, options.read_preference)
self.assertEqual(5000, options.max_commit_time_ms)
def test_transaction_options_invalid_read_concern(self):
"""Test TransactionOptions with invalid read_concern type."""
from pymongo.synchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(read_concern="invalid") # type: ignore
def test_transaction_options_invalid_write_concern(self):
"""Test TransactionOptions with invalid write_concern type."""
from pymongo.synchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(write_concern="invalid") # type: ignore
def test_transaction_options_invalid_read_preference(self):
"""Test TransactionOptions with invalid read_preference type."""
from pymongo.synchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(read_preference="invalid") # type: ignore
def test_transaction_options_invalid_max_commit_time(self):
"""Test TransactionOptions with invalid max_commit_time_ms type."""
from pymongo.synchronous.client_session import TransactionOptions
with self.assertRaises(TypeError):
TransactionOptions(max_commit_time_ms="invalid") # type: ignore
def test_transaction_options_unacknowledged_write_concern(self):
"""Test TransactionOptions rejects unacknowledged write concern."""
from pymongo.synchronous.client_session import TransactionOptions
with self.assertRaises(ConfigurationError):
TransactionOptions(write_concern=WriteConcern(w=0))
if __name__ == "__main__":
unittest.main()