732 lines
25 KiB
Python
732 lines
25 KiB
Python
# Copyright 2014 MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test AsyncIOMotorCursor."""
|
|
|
|
import asyncio
|
|
import sys
|
|
import traceback
|
|
import unittest
|
|
import warnings
|
|
from functools import partial
|
|
from test.asyncio_tests import (
|
|
AsyncIOMockServerTestCase,
|
|
AsyncIOTestCase,
|
|
asyncio_test,
|
|
get_command_line,
|
|
server_is_mongos,
|
|
)
|
|
from test.test_environment import env
|
|
from test.utils import (
|
|
AUTO_ISMASTER,
|
|
FailPoint,
|
|
TestListener,
|
|
get_async_test_timeout,
|
|
get_primary_pool,
|
|
one,
|
|
safe_get,
|
|
wait_until,
|
|
)
|
|
from unittest import SkipTest
|
|
|
|
import bson
|
|
from pymongo import CursorType
|
|
from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
|
|
|
|
from motor import motor_asyncio
|
|
|
|
|
|
class TestAsyncIOCursor(AsyncIOMockServerTestCase):
|
|
def test_cursor(self):
|
|
cursor = self.collection.find()
|
|
self.assertTrue(isinstance(cursor, motor_asyncio.AsyncIOMotorCursor))
|
|
self.assertFalse(cursor.started, "Cursor shouldn't start immediately")
|
|
|
|
@asyncio_test
|
|
async def test_count(self):
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
self.assertEqual(100, (await coll.count_documents({"_id": {"$gt": 99}})))
|
|
|
|
@asyncio_test
|
|
async def test_fetch_next(self):
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
# 200 results, only including _id field, sorted by _id.
|
|
cursor = coll.find({}, {"_id": 1}).sort("_id").batch_size(75)
|
|
|
|
self.assertEqual(None, cursor.cursor_id)
|
|
self.assertEqual(None, cursor.next_object()) # Haven't fetched yet.
|
|
i = 0
|
|
while await cursor.fetch_next:
|
|
self.assertEqual({"_id": i}, cursor.next_object())
|
|
i += 1
|
|
# With batch_size 75 and 200 results, cursor should be exhausted on
|
|
# the server by third fetch.
|
|
if i <= 150:
|
|
self.assertNotEqual(0, cursor.cursor_id)
|
|
else:
|
|
self.assertEqual(0, cursor.cursor_id)
|
|
|
|
self.assertEqual(False, (await cursor.fetch_next))
|
|
self.assertEqual(None, cursor.next_object())
|
|
self.assertEqual(0, cursor.cursor_id)
|
|
self.assertEqual(200, i)
|
|
|
|
@unittest.skipIf("PyPy" in sys.version, "PyPy")
|
|
@asyncio_test
|
|
async def test_fetch_next_delete(self):
|
|
client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
|
|
|
|
cursor = client.test.coll.find()
|
|
self.fetch_next(cursor)
|
|
request = await self.run_thread(server.receives, "find", "coll")
|
|
request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})
|
|
|
|
# Decref the cursor and clear from the event loop.
|
|
del cursor
|
|
request = await self.run_thread(server.receives, "killCursors", "coll")
|
|
|
|
request.ok()
|
|
|
|
@asyncio_test
|
|
async def test_fetch_next_without_results(self):
|
|
coll = self.collection
|
|
# Nothing matches this query.
|
|
cursor = coll.find({"foo": "bar"})
|
|
self.assertEqual(None, cursor.next_object())
|
|
self.assertEqual(False, (await cursor.fetch_next))
|
|
self.assertEqual(None, cursor.next_object())
|
|
# Now cursor knows it's exhausted.
|
|
self.assertEqual(0, cursor.cursor_id)
|
|
|
|
@asyncio_test
|
|
async def test_fetch_next_is_idempotent(self):
|
|
# Subsequent calls to fetch_next don't do anything
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
cursor = coll.find()
|
|
self.assertEqual(None, cursor.cursor_id)
|
|
await cursor.fetch_next
|
|
self.assertTrue(cursor.cursor_id)
|
|
self.assertEqual(101, cursor._buffer_size())
|
|
await cursor.fetch_next # Does nothing
|
|
self.assertEqual(101, cursor._buffer_size())
|
|
await cursor.close()
|
|
|
|
@asyncio_test
|
|
async def test_fetch_next_exception(self):
|
|
coll = self.collection
|
|
await coll.insert_many([{} for _ in range(10)])
|
|
cursor = coll.find(batch_size=2)
|
|
await cursor.fetch_next
|
|
self.assertTrue(cursor.next_object())
|
|
|
|
# Not valid on server, causes CursorNotFound.
|
|
cursor.delegate._id = bson.int64.Int64(1234)
|
|
|
|
with self.assertRaises(OperationFailure):
|
|
await cursor.fetch_next
|
|
self.assertTrue(cursor.next_object())
|
|
await cursor.fetch_next
|
|
self.assertTrue(cursor.next_object())
|
|
|
|
@asyncio_test(timeout=30)
|
|
async def test_each(self):
|
|
await self.make_test_data()
|
|
cursor = self.collection.find({}, {"_id": 1}).sort("_id")
|
|
future = self.loop.create_future()
|
|
results = []
|
|
|
|
def callback(result, error):
|
|
if error:
|
|
raise error
|
|
|
|
if result is not None:
|
|
results.append(result)
|
|
else:
|
|
# Done iterating.
|
|
future.set_result(True)
|
|
|
|
cursor.each(callback)
|
|
await future
|
|
expected = [{"_id": i} for i in range(200)]
|
|
self.assertEqual(expected, results)
|
|
|
|
@asyncio_test
|
|
async def test_to_list_argument_checking(self):
|
|
# We need more than 10 documents so the cursor stays alive.
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
cursor = coll.find()
|
|
with self.assertRaises(ValueError):
|
|
await cursor.to_list(-1)
|
|
|
|
with self.assertRaises(TypeError):
|
|
await cursor.to_list("foo")
|
|
|
|
@asyncio_test
|
|
async def test_to_list_with_length(self):
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
cursor = coll.find().sort("_id")
|
|
|
|
def expected(start, stop):
|
|
return [{"_id": i} for i in range(start, stop)]
|
|
|
|
self.assertEqual(expected(0, 10), (await cursor.to_list(10)))
|
|
self.assertEqual(expected(10, 100), (await cursor.to_list(90)))
|
|
|
|
# Test particularly rigorously around the 101-doc mark, since this is
|
|
# where the first batch ends
|
|
self.assertEqual(expected(100, 101), (await cursor.to_list(1)))
|
|
self.assertEqual(expected(101, 102), (await cursor.to_list(1)))
|
|
self.assertEqual(expected(102, 103), (await cursor.to_list(1)))
|
|
self.assertEqual(expected(103, 105), (await cursor.to_list(2)))
|
|
|
|
# Only 95 docs left, make sure length=100 doesn't error or hang
|
|
self.assertEqual(expected(105, 200), (await cursor.to_list(100)))
|
|
self.assertEqual(0, cursor.cursor_id)
|
|
|
|
# Nothing left.
|
|
self.assertEqual([], (await cursor.to_list(100)))
|
|
|
|
await cursor.close()
|
|
|
|
@asyncio_test
|
|
async def test_to_list_multiple_getMores(self):
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
cursor = coll.find(batch_size=5).sort("_id")
|
|
|
|
def expected(start, stop):
|
|
return [{"_id": i} for i in range(start, stop)]
|
|
|
|
# 2 batches (find+getMore):
|
|
self.assertEqual(expected(0, 10), (await cursor.to_list(10)))
|
|
# 5 batches, stop in the middle of a batch:
|
|
self.assertEqual(expected(10, 33), (await cursor.to_list(23)))
|
|
# 33 batches:
|
|
self.assertEqual(expected(33, 200), (await cursor.to_list(167)))
|
|
# Nothing left.
|
|
self.assertEqual([], (await cursor.to_list(100)))
|
|
|
|
await cursor.close()
|
|
|
|
@asyncio_test
|
|
async def test_to_list_exc_info(self):
|
|
await self.make_test_data()
|
|
coll = self.collection
|
|
cursor = coll.find()
|
|
await cursor.to_list(length=10)
|
|
await self.collection.drop()
|
|
try:
|
|
await cursor.to_list(length=None)
|
|
except OperationFailure:
|
|
_, _, tb = sys.exc_info()
|
|
|
|
# The call tree should include PyMongo code we ran on a thread.
|
|
formatted = "\n".join(traceback.format_tb(tb))
|
|
self.assertTrue(
|
|
"_unpack_response" in formatted or "_check_command_response" in formatted
|
|
)
|
|
|
|
async def _test_cancelled_error(self, coro):
|
|
await self.make_test_data()
|
|
# Cause an error on a getMore after the cursor.to_list task is
|
|
# cancelled.
|
|
fp = {
|
|
"configureFailPoint": "failCommand",
|
|
"data": {"failCommands": ["getMore"], "errorCode": 96},
|
|
"mode": {"times": 1},
|
|
}
|
|
async with FailPoint(self.cx, fp):
|
|
cleanup, task = coro(self.collection)
|
|
task.cancel()
|
|
with self.assertRaises(asyncio.CancelledError):
|
|
await task
|
|
await cleanup()
|
|
# Yield for some time to allow pending Cursor callbacks to run.
|
|
await asyncio.sleep(0.5)
|
|
|
|
@env.require_version_min(4, 2) # failCommand
|
|
@asyncio_test
|
|
async def test_cancelled_error_to_list(self):
|
|
# Note: We intentionally don't use "async def" here to avoid wrapping
|
|
# the returned to_list Future in a coroutine.
|
|
def to_list(collection):
|
|
cursor = collection.find(batch_size=2)
|
|
return cursor.close, cursor.to_list(None)
|
|
|
|
await self._test_cancelled_error(to_list)
|
|
|
|
@env.require_version_min(4, 2) # failCommand
|
|
@asyncio_test
|
|
async def test_cancelled_error_fetch_next(self):
|
|
def fetch_next(collection):
|
|
cursor = collection.find(batch_size=2)
|
|
return cursor.close, cursor.fetch_next
|
|
|
|
await self._test_cancelled_error(fetch_next)
|
|
|
|
@env.require_version_min(4, 2) # failCommand
|
|
@asyncio_test
|
|
async def test_cancelled_error_fetch_next_aggregate(self):
|
|
def fetch_next(collection):
|
|
cursor = collection.aggregate([], batchSize=2)
|
|
return cursor.close, cursor.fetch_next
|
|
|
|
await self._test_cancelled_error(fetch_next)
|
|
|
|
@asyncio_test
|
|
async def test_to_list_with_length_of_none(self):
|
|
await self.make_test_data()
|
|
collection = self.collection
|
|
cursor = collection.find()
|
|
docs = await cursor.to_list(None) # Unlimited.
|
|
count = await collection.count_documents({})
|
|
self.assertEqual(count, len(docs))
|
|
|
|
@asyncio_test
|
|
async def test_to_list_tailable(self):
|
|
coll = self.collection
|
|
cursor = coll.find(cursor_type=CursorType.TAILABLE)
|
|
|
|
# Can't call to_list on tailable cursor.
|
|
with self.assertRaises(InvalidOperation):
|
|
await cursor.to_list(10)
|
|
|
|
@asyncio_test
|
|
async def test_cursor_explicit_close(self):
|
|
client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
|
|
collection = client.test.coll
|
|
cursor = collection.find()
|
|
|
|
future = self.fetch_next(cursor)
|
|
self.assertTrue(cursor.alive)
|
|
request = await self.run_thread(server.receives, "find", "coll")
|
|
request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})
|
|
|
|
self.assertTrue(await future)
|
|
self.assertEqual(123, cursor.cursor_id)
|
|
|
|
future = asyncio.ensure_future(cursor.close())
|
|
|
|
# No reply to OP_KILLCURSORS.
|
|
request = await self.run_thread(server.receives, "killCursors", "coll")
|
|
|
|
request.ok()
|
|
await future
|
|
|
|
# Cursor reports it's alive because it has buffered data, even though
|
|
# it's killed on the server.
|
|
self.assertTrue(cursor.alive)
|
|
self.assertEqual({"_id": 1}, cursor.next_object())
|
|
self.assertFalse(await cursor.fetch_next)
|
|
self.assertFalse(cursor.alive)
|
|
|
|
@asyncio_test
|
|
async def test_each_cancel(self):
|
|
await self.make_test_data()
|
|
loop = self.loop
|
|
collection = self.collection
|
|
results = []
|
|
future = self.loop.create_future()
|
|
|
|
def cancel(result, error):
|
|
if error:
|
|
future.set_exception(error)
|
|
|
|
else:
|
|
results.append(result)
|
|
loop.call_soon(canceled)
|
|
return False # Cancel iteration.
|
|
|
|
def canceled():
|
|
try:
|
|
self.assertFalse(cursor.delegate._killed)
|
|
self.assertTrue(cursor.alive)
|
|
|
|
# Resume iteration
|
|
cursor.each(each)
|
|
except Exception as e:
|
|
future.set_exception(e)
|
|
|
|
def each(result, error):
|
|
if error:
|
|
future.set_exception(error)
|
|
elif result:
|
|
results.append(result)
|
|
else:
|
|
# Complete
|
|
future.set_result(None)
|
|
|
|
cursor = collection.find()
|
|
cursor.each(cancel)
|
|
await future
|
|
self.assertEqual((await collection.count_documents({})), len(results))
|
|
|
|
@asyncio_test
|
|
async def test_rewind(self):
|
|
await self.collection.insert_many([{}, {}, {}])
|
|
cursor = self.collection.find().limit(2)
|
|
|
|
count = 0
|
|
while await cursor.fetch_next:
|
|
cursor.next_object()
|
|
count += 1
|
|
self.assertEqual(2, count)
|
|
|
|
cursor.rewind()
|
|
count = 0
|
|
while await cursor.fetch_next:
|
|
cursor.next_object()
|
|
count += 1
|
|
self.assertEqual(2, count)
|
|
|
|
cursor.rewind()
|
|
count = 0
|
|
while await cursor.fetch_next:
|
|
cursor.next_object()
|
|
break
|
|
|
|
cursor.rewind()
|
|
while await cursor.fetch_next:
|
|
cursor.next_object()
|
|
count += 1
|
|
|
|
self.assertEqual(2, count)
|
|
self.assertEqual(cursor, cursor.rewind())
|
|
|
|
@unittest.skipIf("PyPy" in sys.version, "PyPy")
|
|
@asyncio_test
|
|
async def test_cursor_del(self):
|
|
client, server = self.client_server(auto_ismaster=AUTO_ISMASTER)
|
|
cursor = client.test.coll.find()
|
|
|
|
future = self.fetch_next(cursor)
|
|
request = await self.run_thread(server.receives, "find", "coll")
|
|
request.replies({"cursor": {"id": 123, "ns": "db.coll", "firstBatch": [{"_id": 1}]}})
|
|
await future # Complete the first fetch.
|
|
|
|
# Dereference the cursor.
|
|
del cursor
|
|
|
|
# Let the event loop iterate once more to clear its references to
|
|
# callbacks, allowing the cursor to be freed.
|
|
await asyncio.sleep(0)
|
|
request = await self.run_thread(server.receives, "killCursors", "coll")
|
|
|
|
request.ok()
|
|
|
|
@asyncio_test
|
|
async def test_exhaust(self):
|
|
if await server_is_mongos(self.cx):
|
|
self.assertRaises(InvalidOperation, self.db.test.find, cursor_type=CursorType.EXHAUST)
|
|
return
|
|
|
|
self.assertRaises(ValueError, self.db.test.find, cursor_type=5)
|
|
|
|
cur = self.db.test.find(cursor_type=CursorType.EXHAUST)
|
|
self.assertRaises(InvalidOperation, cur.limit, 5)
|
|
cur = self.db.test.find(limit=5)
|
|
self.assertRaises(InvalidOperation, cur.add_option, 64)
|
|
cur = self.db.test.find()
|
|
cur.add_option(64)
|
|
self.assertRaises(InvalidOperation, cur.limit, 5)
|
|
|
|
await self.db.drop_collection("test")
|
|
|
|
# Insert enough documents to require more than one batch.
|
|
await self.db.test.insert_many([{} for _ in range(150)])
|
|
|
|
client = self.asyncio_client(maxPoolSize=1)
|
|
# Ensure a pool.
|
|
await client.db.collection.find_one()
|
|
|
|
pool = get_primary_pool(client)
|
|
conns = pool.conns
|
|
|
|
# Make sure the socket is returned after exhaustion.
|
|
cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST)
|
|
has_next = await cur.fetch_next
|
|
self.assertTrue(has_next)
|
|
self.assertEqual(0, len(conns))
|
|
|
|
while await cur.fetch_next:
|
|
cur.next_object()
|
|
|
|
self.assertEqual(1, len(conns))
|
|
|
|
# Same as previous but with to_list instead of next_object.
|
|
docs = await client[self.db.name].test.find(cursor_type=CursorType.EXHAUST).to_list(None)
|
|
self.assertEqual(1, len(conns))
|
|
self.assertEqual((await self.db.test.count_documents({})), len(docs))
|
|
|
|
# If the Cursor instance is discarded before being
|
|
# completely iterated we have to close and
|
|
# discard the socket.
|
|
conn = one(conns)
|
|
cur = client[self.db.name].test.find(cursor_type=CursorType.EXHAUST).batch_size(1)
|
|
await cur.fetch_next
|
|
self.assertTrue(cur.next_object())
|
|
# Run at least one getMore to initiate the OP_MSG exhaust protocol.
|
|
if env.version.at_least(4, 2):
|
|
await cur.fetch_next
|
|
self.assertTrue(cur.next_object())
|
|
self.assertEqual(0, len(conns))
|
|
if "PyPy" in sys.version:
|
|
# Don't wait for GC or use gc.collect(), it's unreliable.
|
|
await cur.close()
|
|
|
|
del cur
|
|
|
|
async def conn_closed():
|
|
return conn not in conns and conn.closed
|
|
|
|
await wait_until(
|
|
conn_closed, "close exhaust cursor socket", timeout=get_async_test_timeout()
|
|
)
|
|
|
|
# The exhaust cursor's socket was discarded, although another may
|
|
# already have been opened to send OP_KILLCURSORS.
|
|
self.assertNotIn(conn, conns)
|
|
self.assertTrue(conn.closed)
|
|
|
|
@asyncio_test
|
|
async def test_close_with_docs_in_batch(self):
|
|
# MOTOR-67 Killed cursor with docs batched is "alive", don't kill again.
|
|
await self.make_test_data() # Ensure multiple batches.
|
|
cursor = self.collection.find()
|
|
await cursor.fetch_next
|
|
await cursor.close() # Killed but still "alive": has a batch.
|
|
self.cx.close()
|
|
|
|
with warnings.catch_warnings(record=True) as w:
|
|
del cursor # No-op, no error.
|
|
|
|
self.assertEqual(0, len(w))
|
|
|
|
@asyncio_test
|
|
async def test_aggregate_batch_size(self):
|
|
listener = TestListener()
|
|
cx = self.asyncio_client(event_listeners=[listener])
|
|
c = cx.motor_test.collection
|
|
await c.delete_many({})
|
|
await c.insert_many({"_id": i} for i in range(3))
|
|
|
|
# Two ways of setting batchSize.
|
|
cursor0 = c.aggregate([{"$sort": {"_id": 1}}]).batch_size(2)
|
|
cursor1 = c.aggregate([{"$sort": {"_id": 1}}], batchSize=2)
|
|
for cursor in cursor0, cursor1:
|
|
lst = []
|
|
while await cursor.fetch_next:
|
|
lst.append(cursor.next_object())
|
|
|
|
self.assertEqual(lst, [{"_id": 0}, {"_id": 1}, {"_id": 2}])
|
|
aggregate = listener.first_command_started("aggregate")
|
|
self.assertEqual(aggregate.command["cursor"]["batchSize"], 2)
|
|
getMore = listener.first_command_started("getMore")
|
|
self.assertEqual(getMore.command["batchSize"], 2)
|
|
|
|
@asyncio_test
|
|
async def test_raw_batches(self):
|
|
c = self.collection
|
|
await c.delete_many({})
|
|
await c.insert_many({"_id": i} for i in range(4))
|
|
|
|
find = partial(c.find_raw_batches, {})
|
|
agg = partial(c.aggregate_raw_batches, [{"$sort": {"_id": 1}}])
|
|
|
|
for method in find, agg:
|
|
cursor = method().batch_size(2)
|
|
await cursor.fetch_next
|
|
batch = cursor.next_object()
|
|
self.assertEqual([{"_id": 0}, {"_id": 1}], bson.decode_all(batch))
|
|
|
|
lst = await method().batch_size(2).to_list(length=1)
|
|
self.assertEqual([{"_id": 0}, {"_id": 1}], bson.decode_all(lst[0]))
|
|
|
|
@asyncio_test
|
|
async def test_context_manager(self):
|
|
coll = self.collection
|
|
await coll.insert_many({"_id": i} for i in range(10))
|
|
|
|
find = partial(coll.find, {})
|
|
agg = partial(coll.aggregate, [{"$sort": {"_id": 1}}])
|
|
find_raw_batches = partial(coll.find_raw_batches, {})
|
|
agg_raw_batches = partial(coll.aggregate_raw_batches, [{"$sort": {"_id": 1}}])
|
|
for method in find, agg, find_raw_batches, agg_raw_batches:
|
|
contrast_cursor = method().batch_size(2)
|
|
async with method().batch_size(2) as cursor:
|
|
self.assertFalse(cursor.started, "Cursor shouldn't start immediately")
|
|
with self.assertWarns(DeprecationWarning):
|
|
await cursor.fetch_next
|
|
record = cursor.next_object()
|
|
self.assertEqual(
|
|
{"_id": 0}, bson.decode_all(record)[0] if type(record) is bytes else record
|
|
)
|
|
self.assertTrue(cursor.started)
|
|
self.assertFalse(cursor.closed)
|
|
self.assertFalse(contrast_cursor.closed)
|
|
self.assertTrue(cursor.closed)
|
|
await contrast_cursor.close()
|
|
self.assertTrue(contrast_cursor.closed)
|
|
|
|
@asyncio_test
|
|
async def test_generate_keys(self):
|
|
c = self.cx
|
|
KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}}
|
|
|
|
async with motor_asyncio.AsyncIOMotorClientEncryption(
|
|
KMS_PROVIDERS, "keyvault.datakeys", c, bson.codec_options.CodecOptions()
|
|
) as client_encryption:
|
|
self.assertIsInstance(
|
|
await client_encryption.get_keys(), motor_asyncio.AsyncIOMotorCursor
|
|
)
|
|
|
|
|
|
class TestAsyncIOCursorMaxTimeMS(AsyncIOTestCase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.loop.run_until_complete(self.maybe_skip())
|
|
|
|
def tearDown(self):
|
|
self.loop.run_until_complete(self.disable_timeout())
|
|
super().tearDown()
|
|
|
|
async def maybe_skip(self):
|
|
if await server_is_mongos(self.cx):
|
|
raise SkipTest("mongos has no maxTimeAlwaysTimeOut fail point")
|
|
|
|
cmdline = await get_command_line(self.cx)
|
|
if "1" != safe_get(cmdline, "parsed.setParameter.enableTestCommands"):
|
|
if "enableTestCommands=1" not in cmdline["argv"]:
|
|
raise SkipTest("testing maxTimeMS requires failpoints")
|
|
|
|
async def enable_timeout(self):
|
|
await self.cx.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="alwaysOn")
|
|
|
|
async def disable_timeout(self):
|
|
await self.cx.admin.command("configureFailPoint", "maxTimeAlwaysTimeOut", mode="off")
|
|
|
|
@asyncio_test
|
|
async def test_max_time_ms_query(self):
|
|
# Cursor parses server timeout error in response to initial query.
|
|
await self.enable_timeout()
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
with self.assertRaises(ExecutionTimeout):
|
|
await cursor.fetch_next
|
|
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
with self.assertRaises(ExecutionTimeout):
|
|
await cursor.to_list(10)
|
|
|
|
with self.assertRaises(ExecutionTimeout):
|
|
await self.collection.find_one(max_time_ms=100000)
|
|
|
|
@asyncio_test(timeout=60)
|
|
async def test_max_time_ms_getmore(self):
|
|
# Cursor handles server timeout during getmore, also.
|
|
await self.collection.insert_many({} for _ in range(200))
|
|
try:
|
|
# Send initial query.
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
await cursor.fetch_next
|
|
cursor.next_object()
|
|
|
|
# Test getmore timeout.
|
|
await self.enable_timeout()
|
|
with self.assertRaises(ExecutionTimeout):
|
|
while await cursor.fetch_next:
|
|
cursor.next_object()
|
|
|
|
await cursor.close()
|
|
|
|
# Send another initial query.
|
|
await self.disable_timeout()
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
await cursor.fetch_next
|
|
cursor.next_object()
|
|
|
|
# Test getmore timeout.
|
|
await self.enable_timeout()
|
|
with self.assertRaises(ExecutionTimeout):
|
|
await cursor.to_list(None)
|
|
|
|
# Avoid 'IOLoop is closing' warning.
|
|
await cursor.close()
|
|
finally:
|
|
# Cleanup.
|
|
await self.disable_timeout()
|
|
await self.collection.delete_many({})
|
|
|
|
@asyncio_test
|
|
async def test_max_time_ms_each_query(self):
|
|
# Cursor.each() handles server timeout during initial query.
|
|
await self.enable_timeout()
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
future = self.loop.create_future()
|
|
|
|
def callback(result, error):
|
|
if error:
|
|
future.set_exception(error)
|
|
elif not result:
|
|
# Done.
|
|
future.set_result(None)
|
|
|
|
with self.assertRaises(ExecutionTimeout):
|
|
cursor.each(callback)
|
|
await future
|
|
|
|
@asyncio_test(timeout=30)
|
|
async def test_max_time_ms_each_getmore(self):
|
|
# Cursor.each() handles server timeout during getmore.
|
|
await self.collection.insert_many({} for _ in range(200))
|
|
try:
|
|
# Send initial query.
|
|
cursor = self.collection.find().max_time_ms(100000)
|
|
await cursor.fetch_next
|
|
cursor.next_object()
|
|
|
|
future = self.loop.create_future()
|
|
|
|
def callback(result, error):
|
|
if error:
|
|
future.set_exception(error)
|
|
elif not result:
|
|
# Done.
|
|
future.set_result(None)
|
|
|
|
await self.enable_timeout()
|
|
with self.assertRaises(ExecutionTimeout):
|
|
cursor.each(callback)
|
|
await future
|
|
|
|
await cursor.close()
|
|
finally:
|
|
# Cleanup.
|
|
await self.disable_timeout()
|
|
await self.collection.delete_many({})
|
|
|
|
def test_iter(self):
|
|
# Iteration should be prohibited.
|
|
with self.assertRaises(TypeError):
|
|
for _ in self.db.test.find():
|
|
pass
|
|
|
|
|
|
if __name__ == "__main__":
|
|
unittest.main()
|