# Source: motor/test/asyncio_tests/test_asyncio_cursor.py
# (original listing: 732 lines, 25 KiB, Python)

# Copyright 2014 MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test AsyncIOMotorCursor."""
import asyncio
import sys
import traceback
import unittest
import warnings
from functools import partial
from test.asyncio_tests import (
AsyncIOMockServerTestCase,
AsyncIOTestCase,
asyncio_test,
get_command_line,
server_is_mongos,
)
from test.test_environment import env
from test.utils import (
AUTO_ISMASTER,
FailPoint,
TestListener,
get_async_test_timeout,
get_primary_pool,
one,
safe_get,
wait_until,
)
from unittest import SkipTest
import bson
from pymongo import CursorType
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
from motor import motor_asyncio
class TestAsyncIOCursor(AsyncIOMockServerTestCase):
    """Exercise AsyncIOMotorCursor.

    Covers ``fetch_next``/``next_object``, ``each()``, ``to_list()``,
    ``close()``, ``rewind()``, exhaust cursors, raw-batch cursors and
    ``async with`` support. Some tests talk to a real server through
    ``self.collection``; others drive a mock server obtained from
    ``self.client_server()``.
    """

    def test_cursor(self):
        # find() must hand back a Motor cursor that has not started yet.
        cursor = self.collection.find()
        self.assertIsInstance(cursor, motor_asyncio.AsyncIOMotorCursor)
        self.assertFalse(cursor.started, "Cursor shouldn't start immediately")

    @asyncio_test
    async def test_count(self):
        # After make_test_data() exactly 100 documents must have _id > 99.
        await self.make_test_data()
        coll = self.collection
        self.assertEqual(100, (await coll.count_documents({"_id": {"$gt": 99}})))

    @asyncio_test
    async def test_fetch_next(self):
        await self.make_test_data()
        coll = self.collection
        # 200 results, only including _id field, sorted by _id.
        cursor = coll.find({}, {"_id": 1}).sort("_id").batch_size(75)
        self.assertEqual(None, cursor.cursor_id)
        self.assertEqual(None, cursor.next_object())  # Haven't fetched yet.
        i = 0
        while await cursor.fetch_next:
            self.assertEqual({"_id": i}, cursor.next_object())
            i += 1
            # With batch_size 75 and 200 results, cursor should be exhausted on
            # the server by third fetch.
            if i <= 150:
                self.assertNotEqual(0, cursor.cursor_id)
            else:
                self.assertEqual(0, cursor.cursor_id)

        # Once drained, fetch_next keeps reporting False and nothing is left.
        self.assertEqual(False, (await cursor.fetch_next))
        self.assertEqual(None, cursor.next_object())
        self.assertEqual(0, cursor.cursor_id)
        self.assertEqual(200, i)

    @unittest.skipIf("PyPy" in sys.version, "PyPy")
    @asyncio_test
    async def test_fetch_next_delete(self):
        # Deleting a cursor whose fetch was started must make the client send
        # killCursors to the (mock) server.
        client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
        cursor = client.test.coll.find()
        self.fetch_next(cursor)
        request = await self.run_thread(server.receives, "find", "coll")
        request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})

        # Decref the cursor and clear from the event loop.
        del cursor
        request = await self.run_thread(server.receives, "killCursors", "coll")
        request.ok()

    @asyncio_test
    async def test_fetch_next_without_results(self):
        coll = self.collection
        # Nothing matches this query.
        cursor = coll.find({"foo": "bar"})
        self.assertEqual(None, cursor.next_object())
        self.assertEqual(False, (await cursor.fetch_next))
        self.assertEqual(None, cursor.next_object())
        # Now cursor knows it's exhausted.
        self.assertEqual(0, cursor.cursor_id)

    @asyncio_test
    async def test_fetch_next_is_idempotent(self):
        # Subsequent calls to fetch_next don't do anything
        await self.make_test_data()
        coll = self.collection
        cursor = coll.find()
        self.assertEqual(None, cursor.cursor_id)
        await cursor.fetch_next
        self.assertTrue(cursor.cursor_id)
        self.assertEqual(101, cursor._buffer_size())
        await cursor.fetch_next  # Does nothing
        self.assertEqual(101, cursor._buffer_size())
        await cursor.close()

    @asyncio_test
    async def test_fetch_next_exception(self):
        coll = self.collection
        await coll.insert_many([{} for _ in range(10)])
        cursor = coll.find(batch_size=2)
        await cursor.fetch_next
        self.assertTrue(cursor.next_object())

        # Not valid on server, causes CursorNotFound.
        cursor.delegate._id = bson.int64.Int64(1234)

        # The first fetch_next below is served from the already-buffered
        # document; the failure is expected once a getMore is needed.
        with self.assertRaises(OperationFailure):
            await cursor.fetch_next
            self.assertTrue(cursor.next_object())
            await cursor.fetch_next
            self.assertTrue(cursor.next_object())

    @asyncio_test(timeout=30)
    async def test_each(self):
        await self.make_test_data()
        cursor = self.collection.find({}, {"_id": 1}).sort("_id")
        future = self.loop.create_future()
        results = []

        def callback(result, error):
            if error:
                # Report the failure through the awaited future. Raising here
                # would be raised inside an event-loop callback and the test
                # would only fail by timing out.
                future.set_exception(error)
            elif result is not None:
                results.append(result)
            else:
                # Done iterating.
                future.set_result(True)

        cursor.each(callback)
        await future
        expected = [{"_id": i} for i in range(200)]
        self.assertEqual(expected, results)

    @asyncio_test
    async def test_to_list_argument_checking(self):
        # We need more than 10 documents so the cursor stays alive.
        await self.make_test_data()
        coll = self.collection
        cursor = coll.find()
        with self.assertRaises(ValueError):
            await cursor.to_list(-1)

        with self.assertRaises(TypeError):
            await cursor.to_list("foo")

    @asyncio_test
    async def test_to_list_with_length(self):
        await self.make_test_data()
        coll = self.collection
        cursor = coll.find().sort("_id")

        def expected(start, stop):
            return [{"_id": i} for i in range(start, stop)]

        self.assertEqual(expected(0, 10), (await cursor.to_list(10)))
        self.assertEqual(expected(10, 100), (await cursor.to_list(90)))

        # Test particularly rigorously around the 101-doc mark, since this is
        # where the first batch ends
        self.assertEqual(expected(100, 101), (await cursor.to_list(1)))
        self.assertEqual(expected(101, 102), (await cursor.to_list(1)))
        self.assertEqual(expected(102, 103), (await cursor.to_list(1)))
        self.assertEqual(expected(103, 105), (await cursor.to_list(2)))

        # Only 95 docs left, make sure length=100 doesn't error or hang
        self.assertEqual(expected(105, 200), (await cursor.to_list(100)))
        self.assertEqual(0, cursor.cursor_id)

        # Nothing left.
        self.assertEqual([], (await cursor.to_list(100)))
        await cursor.close()

    @asyncio_test
    async def test_to_list_multiple_getMores(self):
        await self.make_test_data()
        coll = self.collection
        cursor = coll.find(batch_size=5).sort("_id")

        def expected(start, stop):
            return [{"_id": i} for i in range(start, stop)]

        # 2 batches (find+getMore):
        self.assertEqual(expected(0, 10), (await cursor.to_list(10)))

        # 5 batches, stop in the middle of a batch:
        self.assertEqual(expected(10, 33), (await cursor.to_list(23)))

        # 33 batches:
        self.assertEqual(expected(33, 200), (await cursor.to_list(167)))

        # Nothing left.
        self.assertEqual([], (await cursor.to_list(100)))
        await cursor.close()

    @asyncio_test
    async def test_to_list_exc_info(self):
        await self.make_test_data()
        coll = self.collection
        cursor = coll.find()
        await cursor.to_list(length=10)
        await self.collection.drop()
        # NOTE(review): if no OperationFailure is raised here, the traceback
        # assertions below are skipped and the test still passes.
        try:
            await cursor.to_list(length=None)
        except OperationFailure:
            _, _, tb = sys.exc_info()

            # The call tree should include PyMongo code we ran on a thread.
            formatted = "\n".join(traceback.format_tb(tb))
            self.assertTrue(
                "_unpack_response" in formatted or "_check_command_response" in formatted
            )

    async def _test_cancelled_error(self, coro):
        # Shared body for the test_cancelled_error_* tests. ``coro`` is called
        # with the test collection and must return ``(cleanup, task)``: an
        # async callable used for cleanup and the awaitable that gets
        # cancelled.
        await self.make_test_data()

        # Cause an error on a getMore after the cursor.to_list task is
        # cancelled.
        fp = {
            "configureFailPoint": "failCommand",
            "data": {"failCommands": ["getMore"], "errorCode": 96},
            "mode": {"times": 1},
        }

        async with FailPoint(self.cx, fp):
            cleanup, task = coro(self.collection)
            task.cancel()
            with self.assertRaises(asyncio.CancelledError):
                await task

            await cleanup()
            # Yield for some time to allow pending Cursor callbacks to run.
            await asyncio.sleep(0.5)

    @env.require_version_min(4, 2)  # failCommand
    @asyncio_test
    async def test_cancelled_error_to_list(self):
        # Note: We intentionally don't use "async def" here to avoid wrapping
        # the returned to_list Future in a coroutine.
        def to_list(collection):
            cursor = collection.find(batch_size=2)
            return cursor.close, cursor.to_list(None)

        await self._test_cancelled_error(to_list)

    @env.require_version_min(4, 2)  # failCommand
    @asyncio_test
    async def test_cancelled_error_fetch_next(self):
        def fetch_next(collection):
            cursor = collection.find(batch_size=2)
            return cursor.close, cursor.fetch_next

        await self._test_cancelled_error(fetch_next)

    @env.require_version_min(4, 2)  # failCommand
    @asyncio_test
    async def test_cancelled_error_fetch_next_aggregate(self):
        def fetch_next(collection):
            cursor = collection.aggregate([], batchSize=2)
            return cursor.close, cursor.fetch_next

        await self._test_cancelled_error(fetch_next)

    @asyncio_test
    async def test_to_list_with_length_of_none(self):
        await self.make_test_data()
        collection = self.collection
        cursor = collection.find()
        docs = await cursor.to_list(None)  # Unlimited.
        count = await collection.count_documents({})
        self.assertEqual(count, len(docs))

    @asyncio_test
    async def test_to_list_tailable(self):
        coll = self.collection
        cursor = coll.find(cursor_type=CursorType.TAILABLE)

        # Can't call to_list on tailable cursor.
        with self.assertRaises(InvalidOperation):
            await cursor.to_list(10)

    @asyncio_test
    async def test_cursor_explicit_close(self):
        client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
        collection = client.test.coll
        cursor = collection.find()

        future = self.fetch_next(cursor)
        self.assertTrue(cursor.alive)
        request = await self.run_thread(server.receives, "find", "coll")
        request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})
        self.assertTrue(await future)
        self.assertEqual(123, cursor.cursor_id)

        future = asyncio.ensure_future(cursor.close())

        # The mock server acknowledges killCursors before close() is awaited.
        request = await self.run_thread(server.receives, "killCursors", "coll")
        request.ok()
        await future

        # Cursor reports it's alive because it has buffered data, even though
        # it's killed on the server.
        self.assertTrue(cursor.alive)
        self.assertEqual({"_id": 1}, cursor.next_object())
        self.assertFalse(await cursor.fetch_next)
        self.assertFalse(cursor.alive)

    @asyncio_test
    async def test_each_cancel(self):
        # A callback returning False pauses each(); iteration is then resumed
        # with a second callback and every document must still be delivered.
        await self.make_test_data()
        loop = self.loop
        collection = self.collection
        results = []
        future = self.loop.create_future()

        def cancel(result, error):
            if error:
                future.set_exception(error)
            else:
                results.append(result)
                loop.call_soon(canceled)
                return False  # Cancel iteration.

        def canceled():
            try:
                self.assertFalse(cursor.delegate._killed)
                self.assertTrue(cursor.alive)

                # Resume iteration
                cursor.each(each)
            except Exception as e:
                future.set_exception(e)

        def each(result, error):
            if error:
                future.set_exception(error)
            elif result:
                results.append(result)
            else:
                # Complete
                future.set_result(None)

        cursor = collection.find()
        cursor.each(cancel)
        await future
        self.assertEqual((await collection.count_documents({})), len(results))

    @asyncio_test
    async def test_rewind(self):
        await self.collection.insert_many([{}, {}, {}])
        cursor = self.collection.find().limit(2)

        count = 0
        while await cursor.fetch_next:
            cursor.next_object()
            count += 1
        self.assertEqual(2, count)

        # A full pass after rewind() yields the same two documents again.
        cursor.rewind()
        count = 0
        while await cursor.fetch_next:
            cursor.next_object()
            count += 1
        self.assertEqual(2, count)

        # Rewinding after a partial pass also restarts from the beginning.
        cursor.rewind()
        count = 0
        while await cursor.fetch_next:
            cursor.next_object()
            break

        cursor.rewind()
        while await cursor.fetch_next:
            cursor.next_object()
            count += 1

        self.assertEqual(2, count)
        self.assertEqual(cursor, cursor.rewind())

    @unittest.skipIf("PyPy" in sys.version, "PyPy")
    @asyncio_test
    async def test_cursor_del(self):
        client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
        cursor = client.test.coll.find()

        future = self.fetch_next(cursor)
        request = await self.run_thread(server.receives, "find", "coll")
        request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})
        await future  # Complete the first fetch.

        # Dereference the cursor.
        del cursor

        # Let the event loop iterate once more to clear its references to
        # callbacks, allowing the cursor to be freed.
        await asyncio.sleep(0)
        request = await self.run_thread(server.receives, "killCursors", "coll")
        request.ok()

    @asyncio_test
    async def test_exhaust(self):
        if await server_is_mongos(self.cx):
            self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST)
            return

        self.assertRaises(ValueError, self.db.test.find, cursor_type=5)

        # An exhaust cursor cannot be combined with a limit, whichever of the
        # two is set first.
        cur = self.db.test.find(cursor_type=CursorType.EXHAUST)
        self.assertRaises(InvalidOperation, cur.limit, 5)
        cur = self.db.test.find(limit=5)
        self.assertRaises(InvalidOperation, cur.add_option, 64)
        cur = self.db.test.find()
        cur.add_option(64)
        self.assertRaises(InvalidOperation, cur.limit, 5)

        await self.db.drop_collection("test")

        # Insert enough documents to require more than one batch.
        await self.db.test.insert_many([{} for _ in range(150)])

        client = self.asyncio_client(maxPoolSize=1)
        # Ensure a pool.
        await client.db.collection.find_one()
        pool = get_primary_pool(client)
        conns = pool.conns

        # Make sure the socket is returned after exhaustion.
        cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST)
        has_next = await cur.fetch_next
        self.assertTrue(has_next)
        self.assertEqual(0, len(conns))

        while await cur.fetch_next:
            cur.next_object()

        self.assertEqual(1, len(conns))

        # Same as previous but with to_list instead of next_object.
        docs = await client[self.db.name].test.find(cursor_type=CursorType.EXHAUST).to_list(None)
        self.assertEqual(1, len(conns))
        self.assertEqual((await self.db.test.count_documents({})), len(docs))

        # If the Cursor instance is discarded before being
        # completely iterated we have to close and
        # discard the socket.
        conn = one(conns)
        cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST).batch_size(1)
        await cur.fetch_next
        self.assertTrue(cur.next_object())
        # Run at least one getMore to initiate the OP_MSG exhaust protocol.
        if env.version.at_least(4, 2):
            await cur.fetch_next
            self.assertTrue(cur.next_object())
        self.assertEqual(0, len(conns))
        if "PyPy" in sys.version:
            # Don't wait for GC or use gc.collect(), it's unreliable.
            await cur.close()

        del cur

        async def conn_closed():
            return conn not in conns and conn.closed

        await wait_until(
            conn_closed, "close exhaust cursor socket", timeout=get_async_test_timeout()
        )

        # The exhaust cursor's socket was discarded, although another may
        # already have been opened to send OP_KILLCURSORS.
        self.assertNotIn(conn, conns)
        self.assertTrue(conn.closed)

    @asyncio_test
    async def test_close_with_docs_in_batch(self):
        # MOTOR-67 Killed cursor with docs batched is "alive", don't kill again.
        await self.make_test_data()  # Ensure multiple batches.
        cursor = self.collection.find()
        await cursor.fetch_next
        await cursor.close()  # Killed but still "alive": has a batch.
        self.cx.close()
        with warnings.catch_warnings(record=True) as w:
            del cursor  # No-op, no error.

        self.assertEqual(0, len(w))

    @asyncio_test
    async def test_aggregate_batch_size(self):
        listener = TestListener()
        cx = self.asyncio_client(event_listeners=[listener])
        c = cx.motor_test.collection
        await c.delete_many({})
        await c.insert_many({"_id": i} for i in range(3))

        # Two ways of setting batchSize.
        cursor0 = c.aggregate([{"$sort": {"_id": 1}}]).batch_size(2)
        cursor1 = c.aggregate([{"$sort": {"_id": 1}}], batchSize=2)
        for cursor in cursor0, cursor1:
            lst = []
            while await cursor.fetch_next:
                lst.append(cursor.next_object())

            self.assertEqual(lst, [{"_id": 0}, {"_id": 1}, {"_id": 2}])
            # Both the initial aggregate and the follow-up getMore must carry
            # the requested batch size.
            aggregate = listener.first_command_started("aggregate")
            self.assertEqual(aggregate.command["cursor"]["batchSize"], 2)
            getMore = listener.first_command_started("getMore")
            self.assertEqual(getMore.command["batchSize"], 2)

    @asyncio_test
    async def test_raw_batches(self):
        c = self.collection
        await c.delete_many({})
        await c.insert_many({"_id": i} for i in range(4))

        find = partial(c.find_raw_batches, {})
        agg = partial(c.aggregate_raw_batches, [{"$sort": {"_id": 1}}])

        for method in find, agg:
            # Each item from a raw-batch cursor is decoded here as a whole
            # batch of documents.
            cursor = method().batch_size(2)
            await cursor.fetch_next
            batch = cursor.next_object()
            self.assertEqual([{"_id": 0}, {"_id": 1}], bson.decode_all(batch))

            lst = await method().batch_size(2).to_list(length=1)
            self.assertEqual([{"_id": 0}, {"_id": 1}], bson.decode_all(lst[0]))

    @asyncio_test
    async def test_context_manager(self):
        coll = self.collection
        await coll.insert_many({"_id": i} for i in range(10))

        find = partial(coll.find, {})
        agg = partial(coll.aggregate, [{"$sort": {"_id": 1}}])
        find_raw_batches = partial(coll.find_raw_batches, {})
        agg_raw_batches = partial(coll.aggregate_raw_batches, [{"$sort": {"_id": 1}}])

        for method in find, agg, find_raw_batches, agg_raw_batches:
            # contrast_cursor is never entered as a context manager, so it
            # must stay open until closed explicitly.
            contrast_cursor = method().batch_size(2)
            async with method().batch_size(2) as cursor:
                self.assertFalse(cursor.started, "Cursor shouldn't start immediately")
                with self.assertWarns(DeprecationWarning):
                    await cursor.fetch_next
                record = cursor.next_object()
                # Raw-batch cursors yield bytes; decode before comparing.
                self.assertEqual(
                    {"_id": 0},
                    bson.decode_all(record)[0] if isinstance(record, bytes) else record,
                )
                self.assertTrue(cursor.started)
                self.assertFalse(cursor.closed)
                self.assertFalse(contrast_cursor.closed)

            # Leaving the ``async with`` block closes the managed cursor only.
            self.assertTrue(cursor.closed)
            await contrast_cursor.close()
            self.assertTrue(contrast_cursor.closed)

    @asyncio_test
    async def test_generate_keys(self):
        # get_keys() on the client-encryption object must return a Motor cursor.
        c = self.cx
        KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}}

        async with motor_asyncio.AsyncIOMotorClientEncryption(
            KMS_PROVIDERS, "keyvault.datakeys", c, bson.codec_options.CodecOptions()
        ) as client_encryption:
            self.assertIsInstance(
                await client_encryption.get_keys(), motor_asyncio.AsyncIOMotorCursor
            )
class TestAsyncIOCursorMaxTimeMS(AsyncIOTestCase):
    """Test how cursors surface server-side ``maxTimeMS`` timeouts.

    The tests toggle the ``maxTimeAlwaysTimeOut`` fail point through
    ``enable_timeout()``/``disable_timeout()`` and are skipped when the
    deployment does not support it (see ``maybe_skip``).
    """

    def setUp(self):
        super().setUp()
        # Raises SkipTest when the fail point is unavailable.
        self.loop.run_until_complete(self.maybe_skip())

    def tearDown(self):
        # Always switch the fail point off so later tests are unaffected.
        self.loop.run_until_complete(self.disable_timeout())
        super().tearDown()

    async def maybe_skip(self):
        # Skip on mongos, or when test commands are not enabled according to
        # either the parsed server options or the raw argv.
        if await server_is_mongos(self.cx):
            raise SkipTest("mongos has no maxTimeAlwaysTimeOut fail point")

        cmdline = await get_command_line(self.cx)
        # Short-circuit keeps the argv lookup conditional on the first check.
        if (
            "1" != safe_get(cmdline, "parsed.setParameter.enableTestCommands")
            and "enableTestCommands=1" not in cmdline["argv"]
        ):
            raise SkipTest("testing maxTimeMS requires failpoints")

    async def enable_timeout(self):
        await self.cx.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn")

    async def disable_timeout(self):
        await self.cx.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off")

    @asyncio_test
    async def test_max_time_ms_query(self):
        # Cursor parses server timeout error in response to initial query.
        await self.enable_timeout()
        cursor = self.collection.find().max_time_ms(100000)
        with self.assertRaises(ExecutionTimeout):
            await cursor.fetch_next

        cursor = self.collection.find().max_time_ms(100000)
        with self.assertRaises(ExecutionTimeout):
            await cursor.to_list(10)

        with self.assertRaises(ExecutionTimeout):
            await self.collection.find_one(max_time_ms=100000)

    @asyncio_test(timeout=60)
    async def test_max_time_ms_getmore(self):
        # Cursor handles server timeout during getmore, also.
        await self.collection.insert_many({} for _ in range(200))
        try:
            # Send initial query.
            cursor = self.collection.find().max_time_ms(100000)
            await cursor.fetch_next
            cursor.next_object()

            # Test getmore timeout.
            await self.enable_timeout()
            with self.assertRaises(ExecutionTimeout):
                while await cursor.fetch_next:
                    cursor.next_object()

            await cursor.close()

            # Send another initial query.
            await self.disable_timeout()
            cursor = self.collection.find().max_time_ms(100000)
            await cursor.fetch_next
            cursor.next_object()

            # Test getmore timeout.
            await self.enable_timeout()
            with self.assertRaises(ExecutionTimeout):
                await cursor.to_list(None)

            # Avoid 'IOLoop is closing' warning.
            await cursor.close()
        finally:
            # Cleanup.
            await self.disable_timeout()
            await self.collection.delete_many({})

    @asyncio_test
    async def test_max_time_ms_each_query(self):
        # Cursor.each() handles server timeout during initial query.
        await self.enable_timeout()
        cursor = self.collection.find().max_time_ms(100000)
        future = self.loop.create_future()

        def callback(result, error):
            if error:
                future.set_exception(error)
            elif not result:
                # Done.
                future.set_result(None)

        # The timeout reaches the test through the future set by the callback.
        with self.assertRaises(ExecutionTimeout):
            cursor.each(callback)
            await future

    @asyncio_test(timeout=30)
    async def test_max_time_ms_each_getmore(self):
        # Cursor.each() handles server timeout during getmore.
        await self.collection.insert_many({} for _ in range(200))
        try:
            # Send initial query.
            cursor = self.collection.find().max_time_ms(100000)
            await cursor.fetch_next
            cursor.next_object()

            future = self.loop.create_future()

            def callback(result, error):
                if error:
                    future.set_exception(error)
                elif not result:
                    # Done.
                    future.set_result(None)

            await self.enable_timeout()
            with self.assertRaises(ExecutionTimeout):
                cursor.each(callback)
                await future

            await cursor.close()
        finally:
            # Cleanup.
            await self.disable_timeout()
            await self.collection.delete_many({})

    def test_iter(self):
        # Iteration should be prohibited.
        with self.assertRaises(TypeError):
            for _ in self.db.test.find():
                pass
# Run the tests in this module when it is executed as a script.
if __name__ == "__main__":
    unittest.main()