PYTHON-4549 - Optimize Cursor.to_list (#1749)

This commit is contained in:
Noah Stapp 2024-07-30 15:45:30 -07:00 committed by GitHub
parent d79eee51ba
commit a3cd7045df
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 150 additions and 4 deletions

View File

@ -1892,6 +1892,9 @@ class AsyncGridOutCursor(AsyncCursor):
next_file = await super().next()
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
async def to_list(self) -> list[AsyncGridOut]:
return [x async for x in self] # noqa: C416,RUF100
__anext__ = next
def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:

View File

@ -1878,6 +1878,9 @@ class GridOutCursor(Cursor):
next_file = super().next()
return GridOut(self._root_collection, file_document=next_file, session=self.session)
def to_list(self) -> list[GridOut]:
return [x for x in self] # noqa: C416,RUF100
__next__ = next
def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:

View File

@ -346,6 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]):
else:
return None
async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
if not len(self._data) and not self._killed:
await self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
return True
else:
return False
async def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
@ -371,7 +382,11 @@ class AsyncCommandCursor(Generic[_DocumentType]):
await self.close()
async def to_list(self) -> list[_DocumentType]:
return [x async for x in self] # noqa: C416,RUF100
res: list[_DocumentType] = []
while self.alive:
if not await self._next_batch(res):
break
return res
class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]):

View File

@ -1260,6 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]):
else:
raise StopAsyncIteration
async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
await self._supports_exhaust()
if self._empty:
return False
if len(self._data) or await self._refresh():
result.extend(self._data)
self._data.clear()
return True
else:
return False
async def __anext__(self) -> _DocumentType:
return await self.next()
@ -1273,7 +1287,11 @@ class AsyncCursor(Generic[_DocumentType]):
await self.close()
async def to_list(self) -> list[_DocumentType]:
return [x async for x in self] # noqa: C416,RUF100
res: list[_DocumentType] = []
while self.alive:
if not await self._next_batch(res):
break
return res
class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]):

View File

@ -346,6 +346,17 @@ class CommandCursor(Generic[_DocumentType]):
else:
return None
def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
if not len(self._data) and not self._killed:
self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
return True
else:
return False
def try_next(self) -> Optional[_DocumentType]:
"""Advance the cursor without blocking indefinitely.
@ -371,7 +382,11 @@ class CommandCursor(Generic[_DocumentType]):
self.close()
def to_list(self) -> list[_DocumentType]:
return [x for x in self] # noqa: C416,RUF100
res: list[_DocumentType] = []
while self.alive:
if not self._next_batch(res):
break
return res
class RawBatchCommandCursor(CommandCursor[_DocumentType]):

View File

@ -1258,6 +1258,20 @@ class Cursor(Generic[_DocumentType]):
else:
raise StopIteration
def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
self._supports_exhaust()
if self._empty:
return False
if len(self._data) or self._refresh():
result.extend(self._data)
self._data.clear()
return True
else:
return False
def __next__(self) -> _DocumentType:
return self.next()
@ -1271,7 +1285,11 @@ class Cursor(Generic[_DocumentType]):
self.close()
def to_list(self) -> list[_DocumentType]:
return [x for x in self] # noqa: C416,RUF100
res: list[_DocumentType] = []
while self.alive:
if not self._next_batch(res):
break
return res
class RawBatchCursor(Cursor, Generic[_DocumentType]):

View File

@ -1380,6 +1380,43 @@ class TestCursor(AsyncIntegrationTest):
self.assertEqual("getMore", started[1].command_name)
self.assertNotIn("$readPreference", started[1].command)
@async_client_context.require_replica_set
async def test_to_list_tailable(self):
oplog = self.client.local.oplog.rs
last = await oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next()
ts = last["ts"]
c = oplog.find(
{"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True
)
docs = await c.to_list()
self.assertGreaterEqual(len(docs), 1)
async def test_to_list_empty(self):
c = self.db.does_not_exist.find()
docs = await c.to_list()
self.assertEqual([], docs)
@async_client_context.require_replica_set
async def test_command_cursor_to_list(self):
c = await self.db.test.aggregate([{"$changeStream": {}}])
docs = await c.to_list()
self.assertGreaterEqual(len(docs), 0)
@async_client_context.require_replica_set
async def test_command_cursor_to_list_empty(self):
c = await self.db.does_not_exist.aggregate([{"$changeStream": {}}])
docs = await c.to_list()
self.assertEqual([], docs)
class TestRawBatchCursor(AsyncIntegrationTest):
async def test_find_raw(self):

View File

@ -1371,6 +1371,43 @@ class TestCursor(IntegrationTest):
self.assertEqual("getMore", started[1].command_name)
self.assertNotIn("$readPreference", started[1].command)
@client_context.require_replica_set
def test_to_list_tailable(self):
oplog = self.client.local.oplog.rs
last = oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next()
ts = last["ts"]
c = oplog.find(
{"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True
)
docs = c.to_list()
self.assertGreaterEqual(len(docs), 1)
def test_to_list_empty(self):
c = self.db.does_not_exist.find()
docs = c.to_list()
self.assertEqual([], docs)
@client_context.require_replica_set
def test_command_cursor_to_list(self):
c = self.db.test.aggregate([{"$changeStream": {}}])
docs = c.to_list()
self.assertGreaterEqual(len(docs), 0)
@client_context.require_replica_set
def test_command_cursor_to_list_empty(self):
c = self.db.does_not_exist.aggregate([{"$changeStream": {}}])
docs = c.to_list()
self.assertEqual([], docs)
class TestRawBatchCursor(IntegrationTest):
def test_find_raw(self):