From a3cd7045df240acb4f5747701b238dedb9bc49ac Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 30 Jul 2024 15:45:30 -0700 Subject: [PATCH] PYTHON-4549 - Optimize Cursor.to_list (#1749) --- gridfs/asynchronous/grid_file.py | 3 +++ gridfs/synchronous/grid_file.py | 3 +++ pymongo/asynchronous/command_cursor.py | 17 +++++++++++- pymongo/asynchronous/cursor.py | 20 +++++++++++++- pymongo/synchronous/command_cursor.py | 17 +++++++++++- pymongo/synchronous/cursor.py | 20 +++++++++++++- test/asynchronous/test_cursor.py | 37 ++++++++++++++++++++++++++ test/test_cursor.py | 37 ++++++++++++++++++++++++++ 8 files changed, 150 insertions(+), 4 deletions(-) diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 0b68c3f4e..303abe705 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -1892,6 +1892,9 @@ class AsyncGridOutCursor(AsyncCursor): next_file = await super().next() return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session) + async def to_list(self) -> list[AsyncGridOut]: + return [x async for x in self] # noqa: C416,RUF100 + __anext__ = next def add_option(self, *args: Any, **kwargs: Any) -> NoReturn: diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 98374cc8c..1e3d265d4 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -1878,6 +1878,9 @@ class GridOutCursor(Cursor): next_file = super().next() return GridOut(self._root_collection, file_document=next_file, session=self.session) + def to_list(self) -> list[GridOut]: + return [x for x in self] # noqa: C416,RUF100 + __next__ = next def add_option(self, *args: Any, **kwargs: Any) -> NoReturn: diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index ee9b7ddcb..dac9a27a2 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -346,6 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]): else: return None + async def _next_batch(self, result: list) -> bool: + """Get all available documents from the cursor.""" + if not len(self._data) and not self._killed: + await self._refresh() + if len(self._data): + result.extend(self._data) + self._data.clear() + return True + else: + return False + async def try_next(self) -> Optional[_DocumentType]: """Advance the cursor without blocking indefinitely. @@ -371,7 +382,11 @@ class AsyncCommandCursor(Generic[_DocumentType]): await self.close() async def to_list(self) -> list[_DocumentType]: - return [x async for x in self] # noqa: C416,RUF100 + res: list[_DocumentType] = [] + while self.alive: + if not await self._next_batch(res): + break + return res class AsyncRawBatchCommandCursor(AsyncCommandCursor[_DocumentType]): diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 879556380..fd288c710 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1260,6 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]): else: raise StopAsyncIteration + async def _next_batch(self, result: list) -> bool: + """Get all available documents from the cursor.""" + if not self._exhaust_checked: + self._exhaust_checked = True + await self._supports_exhaust() + if self._empty: + return False + if len(self._data) or await self._refresh(): + result.extend(self._data) + self._data.clear() + return True + else: + return False + async def __anext__(self) -> _DocumentType: return await self.next() @@ -1273,7 +1287,11 @@ class AsyncCursor(Generic[_DocumentType]): await self.close() async def to_list(self) -> list[_DocumentType]: - return [x async for x in self] # noqa: C416,RUF100 + res: list[_DocumentType] = [] + while self.alive: + if not await self._next_batch(res): + break + return res class AsyncRawBatchCursor(AsyncCursor, Generic[_DocumentType]): diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 6a8ff9eba..d7a19c36b 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -346,6 +346,17 @@ class CommandCursor(Generic[_DocumentType]): else: return None + def _next_batch(self, result: list) -> bool: + """Get all available documents from the cursor.""" + if not len(self._data) and not self._killed: + self._refresh() + if len(self._data): + result.extend(self._data) + self._data.clear() + return True + else: + return False + def try_next(self) -> Optional[_DocumentType]: """Advance the cursor without blocking indefinitely. @@ -371,7 +382,11 @@ class CommandCursor(Generic[_DocumentType]): self.close() def to_list(self) -> list[_DocumentType]: - return [x for x in self] # noqa: C416,RUF100 + res: list[_DocumentType] = [] + while self.alive: + if not self._next_batch(res): + break + return res class RawBatchCommandCursor(CommandCursor[_DocumentType]): diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index f6bc5131c..e00c33d90 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1258,6 +1258,20 @@ class Cursor(Generic[_DocumentType]): else: raise StopIteration + def _next_batch(self, result: list) -> bool: + """Get all available documents from the cursor.""" + if not self._exhaust_checked: + self._exhaust_checked = True + self._supports_exhaust() + if self._empty: + return False + if len(self._data) or self._refresh(): + result.extend(self._data) + self._data.clear() + return True + else: + return False + def __next__(self) -> _DocumentType: return self.next() @@ -1271,7 +1285,11 @@ class Cursor(Generic[_DocumentType]): self.close() def to_list(self) -> list[_DocumentType]: - return [x for x in self] # noqa: C416,RUF100 + res: list[_DocumentType] = [] + while self.alive: + if not self._next_batch(res): + break + return res class RawBatchCursor(Cursor, Generic[_DocumentType]): diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 5b3175a61..925584b89 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1380,6 +1380,43 @@ class TestCursor(AsyncIntegrationTest): self.assertEqual("getMore", started[1].command_name) self.assertNotIn("$readPreference", started[1].command) + @async_client_context.require_replica_set + async def test_to_list_tailable(self): + oplog = self.client.local.oplog.rs + last = await oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next() + ts = last["ts"] + + c = oplog.find( + {"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True + ) + + docs = await c.to_list() + + self.assertGreaterEqual(len(docs), 1) + + async def test_to_list_empty(self): + c = self.db.does_not_exist.find() + + docs = await c.to_list() + + self.assertEqual([], docs) + + @async_client_context.require_replica_set + async def test_command_cursor_to_list(self): + c = await self.db.test.aggregate([{"$changeStream": {}}]) + + docs = await c.to_list() + + self.assertGreaterEqual(len(docs), 0) + + @async_client_context.require_replica_set + async def test_command_cursor_to_list_empty(self): + c = await self.db.does_not_exist.aggregate([{"$changeStream": {}}]) + + docs = await c.to_list() + + self.assertEqual([], docs) + class TestRawBatchCursor(AsyncIntegrationTest): async def test_find_raw(self): diff --git a/test/test_cursor.py b/test/test_cursor.py index 03824a5e7..12cb0cd57 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1371,6 +1371,43 @@ class TestCursor(IntegrationTest): self.assertEqual("getMore", started[1].command_name) self.assertNotIn("$readPreference", started[1].command) + @client_context.require_replica_set + def test_to_list_tailable(self): + oplog = self.client.local.oplog.rs + last = oplog.find().sort("$natural", pymongo.DESCENDING).limit(-1).next() + ts = last["ts"] + + c = oplog.find( + {"ts": {"$gte": ts}}, cursor_type=pymongo.CursorType.TAILABLE_AWAIT, oplog_replay=True + ) + + docs = c.to_list() + + self.assertGreaterEqual(len(docs), 1) + + def test_to_list_empty(self): + c = self.db.does_not_exist.find() + + docs = c.to_list() + + self.assertEqual([], docs) + + @client_context.require_replica_set + def test_command_cursor_to_list(self): + c = self.db.test.aggregate([{"$changeStream": {}}]) + + docs = c.to_list() + + self.assertGreaterEqual(len(docs), 0) + + @client_context.require_replica_set + def test_command_cursor_to_list_empty(self): + c = self.db.does_not_exist.aggregate([{"$changeStream": {}}]) + + docs = c.to_list() + + self.assertEqual([], docs) + class TestRawBatchCursor(IntegrationTest): def test_find_raw(self):