diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py
index 303abe705..4d6140750 100644
--- a/gridfs/asynchronous/grid_file.py
+++ b/gridfs/asynchronous/grid_file.py
@@ -1892,8 +1892,24 @@ class AsyncGridOutCursor(AsyncCursor):
         next_file = await super().next()
         return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
 
-    async def to_list(self) -> list[AsyncGridOut]:
-        return [x async for x in self]  # noqa: C416,RUF100
+    async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
+        """Convert the cursor to a list.
+
+        If ``length`` is given, read at most that many files; a shorter list
+        is returned when the cursor is exhausted first.
+        """
+        if length is None:
+            return [x async for x in self]  # noqa: C416,RUF100
+        if length < 1:
+            raise ValueError("to_list() length must be greater than 0")
+        ret: list[AsyncGridOut] = []
+        try:
+            for _ in range(length):
+                ret.append(await self.next())
+        except StopAsyncIteration:
+            # Cursor exhausted early: return the partial list, like AsyncCursor.to_list.
+            pass
+        return ret
 
     __anext__ = next
 
diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py
index 1e3d265d4..bc2e29a61 100644
--- a/gridfs/synchronous/grid_file.py
+++ b/gridfs/synchronous/grid_file.py
@@ -1878,8 +1878,24 @@ class GridOutCursor(Cursor):
         next_file = super().next()
         return GridOut(self._root_collection, file_document=next_file, session=self.session)
 
-    def to_list(self) -> list[GridOut]:
-        return [x for x in self]  # noqa: C416,RUF100
+    def to_list(self, length: Optional[int] = None) -> list[GridOut]:
+        """Convert the cursor to a list.
+
+        If ``length`` is given, read at most that many files; a shorter list
+        is returned when the cursor is exhausted first.
+        """
+        if length is None:
+            return [x for x in self]  # noqa: C416,RUF100
+        if length < 1:
+            raise ValueError("to_list() length must be greater than 0")
+        ret: list[GridOut] = []
+        try:
+            for _ in range(length):
+                ret.append(self.next())
+        except StopIteration:
+            # Cursor exhausted early: return the partial list, like Cursor.to_list.
+            pass
+        return ret
 
     __next__ = next
 
diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py
index b28f983b1..b2cd345f6 100644
--- a/pymongo/asynchronous/command_cursor.py
+++ b/pymongo/asynchronous/command_cursor.py
@@ -346,13 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]):
         else:
             return None
 
-    async def _next_batch(self, result: list) -> bool:
-        """Get all available documents from the cursor."""
+    async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
+        """Get all or some available documents from the cursor."""
         if not len(self._data) and not self._killed:
             await self._refresh()
         if len(self._data):
-            result.extend(self._data)
-            self._data.clear()
+            if total is None:
+                result.extend(self._data)
+                self._data.clear()
+            else:
+                for _ in range(min(len(self._data), total)):
+                    result.append(self._data.popleft())
             return True
         else:
             return False
@@ -381,21 +385,32 @@ class AsyncCommandCursor(Generic[_DocumentType]):
     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         await self.close()
 
-    async def to_list(self) -> list[_DocumentType]:
+    async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
         """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
 
         To use::
 
           >>> await cursor.to_list()
 
+        Or, to read at most n items from the cursor::
+
+          >>> await cursor.to_list(n)
+
         If the cursor is empty or has no more results, an empty list will be returned.
 
         .. versionadded:: 4.9
         """
         res: list[_DocumentType] = []
+        remaining = length
+        if isinstance(length, int) and length < 1:
+            raise ValueError("to_list() length must be greater than 0")
         while self.alive:
-            if not await self._next_batch(res):
+            if not await self._next_batch(res, remaining):
                 break
+            if length is not None:
+                remaining = length - len(res)
+                if remaining == 0:
+                    break
         return res
 
 
diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py
index 8421667be..bae77bb30 100644
--- a/pymongo/asynchronous/cursor.py
+++ b/pymongo/asynchronous/cursor.py
@@ -1260,16 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]):
         else:
             raise StopAsyncIteration
 
-    async def _next_batch(self, result: list) -> bool:
-        """Get all available documents from the cursor."""
+    async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
+        """Get all or some documents from the cursor."""
         if not self._exhaust_checked:
             self._exhaust_checked = True
             await self._supports_exhaust()
         if self._empty:
             return False
         if len(self._data) or await self._refresh():
-            result.extend(self._data)
-            self._data.clear()
+            if total is None:
+                result.extend(self._data)
+                self._data.clear()
+            else:
+                for _ in range(min(len(self._data), total)):
+                    result.append(self._data.popleft())
             return True
         else:
             return False
@@ -1286,21 +1290,32 @@ class AsyncCursor(Generic[_DocumentType]):
     async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         await self.close()
 
-    async def to_list(self) -> list[_DocumentType]:
+    async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
         """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
 
         To use::
 
           >>> await cursor.to_list()
 
+        Or, to read at most n items from the cursor::
+
+          >>> await cursor.to_list(n)
+
         If the cursor is empty or has no more results, an empty list will be returned.
 
         .. versionadded:: 4.9
         """
         res: list[_DocumentType] = []
+        remaining = length
+        if isinstance(length, int) and length < 1:
+            raise ValueError("to_list() length must be greater than 0")
         while self.alive:
-            if not await self._next_batch(res):
+            if not await self._next_batch(res, remaining):
                 break
+            if length is not None:
+                remaining = length - len(res)
+                if remaining == 0:
+                    break
         return res
 
 
diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py
index 86fa69dcb..da05bf1a3 100644
--- a/pymongo/synchronous/command_cursor.py
+++ b/pymongo/synchronous/command_cursor.py
@@ -346,13 +346,17 @@ class CommandCursor(Generic[_DocumentType]):
         else:
             return None
 
-    def _next_batch(self, result: list) -> bool:
-        """Get all available documents from the cursor."""
+    def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
+        """Get all or some available documents from the cursor."""
         if not len(self._data) and not self._killed:
             self._refresh()
         if len(self._data):
-            result.extend(self._data)
-            self._data.clear()
+            if total is None:
+                result.extend(self._data)
+                self._data.clear()
+            else:
+                for _ in range(min(len(self._data), total)):
+                    result.append(self._data.popleft())
             return True
         else:
             return False
@@ -381,21 +385,32 @@ class CommandCursor(Generic[_DocumentType]):
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.close()
 
-    def to_list(self) -> list[_DocumentType]:
+    def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
         """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
 
         To use::
 
           >>> cursor.to_list()
 
+        Or, to read at most n items from the cursor::
+
+          >>> cursor.to_list(n)
+
         If the cursor is empty or has no more results, an empty list will be returned.
 
         .. versionadded:: 4.9
         """
         res: list[_DocumentType] = []
+        remaining = length
+        if isinstance(length, int) and length < 1:
+            raise ValueError("to_list() length must be greater than 0")
         while self.alive:
-            if not self._next_batch(res):
+            if not self._next_batch(res, remaining):
                 break
+            if length is not None:
+                remaining = length - len(res)
+                if remaining == 0:
+                    break
         return res
 
 
diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py
index 1595ce40b..c352b6409 100644
--- a/pymongo/synchronous/cursor.py
+++ b/pymongo/synchronous/cursor.py
@@ -1258,16 +1258,20 @@ class Cursor(Generic[_DocumentType]):
         else:
             raise StopIteration
 
-    def _next_batch(self, result: list) -> bool:
-        """Get all available documents from the cursor."""
+    def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
+        """Get all or some documents from the cursor."""
         if not self._exhaust_checked:
             self._exhaust_checked = True
             self._supports_exhaust()
         if self._empty:
             return False
         if len(self._data) or self._refresh():
-            result.extend(self._data)
-            self._data.clear()
+            if total is None:
+                result.extend(self._data)
+                self._data.clear()
+            else:
+                for _ in range(min(len(self._data), total)):
+                    result.append(self._data.popleft())
             return True
         else:
             return False
@@ -1284,21 +1288,32 @@ class Cursor(Generic[_DocumentType]):
     def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         self.close()
 
-    def to_list(self) -> list[_DocumentType]:
+    def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
         """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
 
         To use::
 
           >>> cursor.to_list()
 
+        Or, to read at most n items from the cursor::
+
+          >>> cursor.to_list(n)
+
         If the cursor is empty or has no more results, an empty list will be returned.
 
         .. versionadded:: 4.9
         """
         res: list[_DocumentType] = []
+        remaining = length
+        if isinstance(length, int) and length < 1:
+            raise ValueError("to_list() length must be greater than 0")
         while self.alive:
-            if not self._next_batch(res):
+            if not self._next_batch(res, remaining):
                 break
+            if length is not None:
+                remaining = length - len(res)
+                if remaining == 0:
+                    break
         return res
 
 
diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py
index 9dd67f2da..6967205fe 100644
--- a/test/asynchronous/test_cursor.py
+++ b/test/asynchronous/test_cursor.py
@@ -1401,6 +1401,20 @@ class TestCursor(AsyncIntegrationTest):
         docs = await c.to_list()
         self.assertEqual([], docs)
 
+    async def test_to_list_length(self):
+        coll = self.db.test
+        await coll.insert_many([{} for _ in range(5)])
+        self.addCleanup(coll.drop)
+        c = coll.find()
+        docs = await c.to_list(3)
+        self.assertEqual(len(docs), 3)
+
+        c = coll.find(batch_size=2)
+        docs = await c.to_list(3)
+        self.assertEqual(len(docs), 3)
+        docs = await c.to_list(3)
+        self.assertEqual(len(docs), 2)
+
     @async_client_context.require_change_streams
     async def test_command_cursor_to_list(self):
         # Set maxAwaitTimeMS=1 to speed up the test.
@@ -1417,6 +1431,19 @@ class TestCursor(AsyncIntegrationTest):
         docs = await c.to_list()
         self.assertEqual([], docs)
 
+    @async_client_context.require_change_streams
+    async def test_command_cursor_to_list_length(self):
+        db = self.db
+        await db.drop_collection("test")
+        await db.test.insert_many([{"foo": 1}, {"foo": 2}])
+
+        pipeline = {"$project": {"_id": False, "foo": True}}
+        result = await db.test.aggregate([pipeline])
+        self.assertEqual(len(await result.to_list()), 2)
+
+        result = await db.test.aggregate([pipeline])
+        self.assertEqual(len(await result.to_list(1)), 1)
+
 
 class TestRawBatchCursor(AsyncIntegrationTest):
     async def test_find_raw(self):
diff --git a/test/test_cursor.py b/test/test_cursor.py
index 1cde4718f..8e6fade1e 100644
--- a/test/test_cursor.py
+++ b/test/test_cursor.py
@@ -1392,6 +1392,20 @@ class TestCursor(IntegrationTest):
         docs = c.to_list()
         self.assertEqual([], docs)
 
+    def test_to_list_length(self):
+        coll = self.db.test
+        coll.insert_many([{} for _ in range(5)])
+        self.addCleanup(coll.drop)
+        c = coll.find()
+        docs = c.to_list(3)
+        self.assertEqual(len(docs), 3)
+
+        c = coll.find(batch_size=2)
+        docs = c.to_list(3)
+        self.assertEqual(len(docs), 3)
+        docs = c.to_list(3)
+        self.assertEqual(len(docs), 2)
+
     @client_context.require_change_streams
     def test_command_cursor_to_list(self):
         # Set maxAwaitTimeMS=1 to speed up the test.
@@ -1408,6 +1422,19 @@ class TestCursor(IntegrationTest):
         docs = c.to_list()
         self.assertEqual([], docs)
 
+    @client_context.require_change_streams
+    def test_command_cursor_to_list_length(self):
+        db = self.db
+        db.drop_collection("test")
+        db.test.insert_many([{"foo": 1}, {"foo": 2}])
+
+        pipeline = {"$project": {"_id": False, "foo": True}}
+        result = db.test.aggregate([pipeline])
+        self.assertEqual(len(result.to_list()), 2)
+
+        result = db.test.aggregate([pipeline])
+        self.assertEqual(len(result.to_list(1)), 1)
+
 
 class TestRawBatchCursor(IntegrationTest):
     def test_find_raw(self):
diff --git a/test/test_gridfs.py b/test/test_gridfs.py
index 27b38dc0b..19ec152bd 100644
--- a/test/test_gridfs.py
+++ b/test/test_gridfs.py
@@ -440,6 +440,12 @@ class TestGridfs(IntegrationTest):
         gout = next(cursor)
         self.assertEqual(b"test2+", gout.read())
         self.assertRaises(StopIteration, cursor.__next__)
+        cursor.rewind()
+        items = cursor.to_list()
+        self.assertEqual(len(items), 2)
+        cursor.rewind()
+        items = cursor.to_list(1)
+        self.assertEqual(len(items), 1)
         cursor.close()
         self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
 