PYTHON-4584 Add length option to Cursor.to_list for motor compat (#1791)

This commit is contained in:
Steven Silvester 2024-08-14 13:13:56 -05:00 committed by GitHub
parent f2f75fc1c8
commit adf8817df8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 164 additions and 28 deletions

View File

@ -1892,8 +1892,16 @@ class AsyncGridOutCursor(AsyncCursor):
next_file = await super().next()
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
async def to_list(self) -> list[AsyncGridOut]:
return [x async for x in self] # noqa: C416,RUF100
async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
    """Convert the cursor to a list.

    :param length: Maximum number of files to read from the cursor. When
        ``None`` (the default) the cursor is read until it has no more results.

    :return: The files read from the cursor. The list holds fewer than
        ``length`` entries (possibly none) when the cursor runs out first.

    :raises ValueError: If ``length`` is given and is less than 1.
    """
    if length is None:
        return [x async for x in self]  # noqa: C416,RUF100
    if length < 1:
        raise ValueError("to_list() length must be greater than 0")
    ret = []
    try:
        for _ in range(length):
            ret.append(await self.next())
    except StopAsyncIteration:
        # The cursor ran out before ``length`` files were read. ``length`` is
        # an upper bound, so return what was collected instead of leaking the
        # iteration sentinel to the caller.
        pass
    return ret
__anext__ = next

View File

@ -1878,8 +1878,16 @@ class GridOutCursor(Cursor):
next_file = super().next()
return GridOut(self._root_collection, file_document=next_file, session=self.session)
def to_list(self) -> list[GridOut]:
return [x for x in self] # noqa: C416,RUF100
def to_list(self, length: Optional[int] = None) -> list[GridOut]:
    """Convert the cursor to a list.

    :param length: Maximum number of files to read from the cursor. When
        ``None`` (the default) the cursor is read until it has no more results.

    :return: The files read from the cursor. The list holds fewer than
        ``length`` entries (possibly none) when the cursor runs out first.

    :raises ValueError: If ``length`` is given and is less than 1.
    """
    if length is None:
        return [x for x in self]  # noqa: C416,RUF100
    if length < 1:
        raise ValueError("to_list() length must be greater than 0")
    ret = []
    try:
        for _ in range(length):
            ret.append(self.next())
    except StopIteration:
        # The cursor ran out before ``length`` files were read. ``length`` is
        # an upper bound, so return what was collected instead of leaking the
        # iteration sentinel to the caller.
        pass
    return ret
__next__ = next

View File

@ -346,13 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]):
else:
return None
async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
await self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
@ -381,21 +385,32 @@ class AsyncCommandCursor(Generic[_DocumentType]):
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
async def to_list(self) -> list[_DocumentType]:
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res):
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res

View File

@ -1260,16 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]):
else:
raise StopAsyncIteration
async def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
await self._supports_exhaust()
if self._empty:
return False
if len(self._data) or await self._refresh():
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
@ -1286,21 +1290,32 @@ class AsyncCursor(Generic[_DocumentType]):
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.close()
async def to_list(self) -> list[_DocumentType]:
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
To use::
>>> await cursor.to_list()
Or, to read at most n items from the cursor::
>>> await cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not await self._next_batch(res):
if not await self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res

View File

@ -346,13 +346,17 @@ class CommandCursor(Generic[_DocumentType]):
else:
return None
def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some available documents from the cursor."""
if not len(self._data) and not self._killed:
self._refresh()
if len(self._data):
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
@ -381,21 +385,32 @@ class CommandCursor(Generic[_DocumentType]):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def to_list(self) -> list[_DocumentType]:
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
To use::
>>> cursor.to_list()
Or, to read at most n items from the cursor::
>>> cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not self._next_batch(res):
if not self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res

View File

@ -1258,16 +1258,20 @@ class Cursor(Generic[_DocumentType]):
else:
raise StopIteration
def _next_batch(self, result: list) -> bool:
"""Get all available documents from the cursor."""
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
"""Get all or some documents from the cursor."""
if not self._exhaust_checked:
self._exhaust_checked = True
self._supports_exhaust()
if self._empty:
return False
if len(self._data) or self._refresh():
result.extend(self._data)
self._data.clear()
if total is None:
result.extend(self._data)
self._data.clear()
else:
for _ in range(min(len(self._data), total)):
result.append(self._data.popleft())
return True
else:
return False
@ -1284,21 +1288,32 @@ class Cursor(Generic[_DocumentType]):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.close()
def to_list(self) -> list[_DocumentType]:
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
To use::
>>> cursor.to_list()
Or, to read at most n items from the cursor::
>>> cursor.to_list(n)
If the cursor is empty or has no more results, an empty list will be returned.
.. versionadded:: 4.9
"""
res: list[_DocumentType] = []
remaining = length
if isinstance(length, int) and length < 1:
raise ValueError("to_list() length must be greater than 0")
while self.alive:
if not self._next_batch(res):
if not self._next_batch(res, remaining):
break
if length is not None:
remaining = length - len(res)
if remaining == 0:
break
return res

View File

@ -1401,6 +1401,20 @@ class TestCursor(AsyncIntegrationTest):
docs = await c.to_list()
self.assertEqual([], docs)
async def test_to_list_length(self):
    """Check that ``to_list(n)`` returns at most ``n`` documents per call."""
    collection = self.db.test
    await collection.insert_many([{} for _ in range(5)])
    self.addCleanup(collection.drop)

    # Default batch size: ask for three of the five inserted documents.
    cursor = collection.find()
    first_three = await cursor.to_list(3)
    self.assertEqual(len(first_three), 3)

    # Batch size of two: the first call returns three documents and a second
    # call on the same cursor returns the remaining two.
    cursor = collection.find(batch_size=2)
    first_three = await cursor.to_list(3)
    self.assertEqual(len(first_three), 3)
    leftover = await cursor.to_list(3)
    self.assertEqual(len(leftover), 2)
@async_client_context.require_change_streams
async def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
@ -1417,6 +1431,19 @@ class TestCursor(AsyncIntegrationTest):
docs = await c.to_list()
self.assertEqual([], docs)
@async_client_context.require_change_streams
async def test_command_cursor_to_list_length(self):
    """Check ``to_list`` with and without a length on an aggregate cursor."""
    database = self.db
    await database.drop_collection("test")
    await database.test.insert_many([{"foo": 1}, {"foo": 2}])
    project_stage = {"$project": {"_id": False, "foo": True}}

    # No length: both inserted documents come back.
    cursor = await database.test.aggregate([project_stage])
    everything = await cursor.to_list()
    self.assertEqual(len(everything), 2)

    # A length of one caps the result at a single document.
    cursor = await database.test.aggregate([project_stage])
    capped = await cursor.to_list(1)
    self.assertEqual(len(capped), 1)
class TestRawBatchCursor(AsyncIntegrationTest):
async def test_find_raw(self):

View File

@ -1392,6 +1392,20 @@ class TestCursor(IntegrationTest):
docs = c.to_list()
self.assertEqual([], docs)
def test_to_list_length(self):
    """Check that ``to_list(n)`` returns at most ``n`` documents per call."""
    collection = self.db.test
    collection.insert_many([{} for _ in range(5)])
    self.addCleanup(collection.drop)

    # Default batch size: ask for three of the five inserted documents.
    cursor = collection.find()
    first_three = cursor.to_list(3)
    self.assertEqual(len(first_three), 3)

    # Batch size of two: the first call returns three documents and a second
    # call on the same cursor returns the remaining two.
    cursor = collection.find(batch_size=2)
    first_three = cursor.to_list(3)
    self.assertEqual(len(first_three), 3)
    leftover = cursor.to_list(3)
    self.assertEqual(len(leftover), 2)
@client_context.require_change_streams
def test_command_cursor_to_list(self):
# Set maxAwaitTimeMS=1 to speed up the test.
@ -1408,6 +1422,19 @@ class TestCursor(IntegrationTest):
docs = c.to_list()
self.assertEqual([], docs)
@client_context.require_change_streams
def test_command_cursor_to_list_length(self):
    """Check ``to_list`` with and without a length on an aggregate cursor."""
    database = self.db
    database.drop_collection("test")
    database.test.insert_many([{"foo": 1}, {"foo": 2}])
    project_stage = {"$project": {"_id": False, "foo": True}}

    # No length: both inserted documents come back.
    cursor = database.test.aggregate([project_stage])
    everything = cursor.to_list()
    self.assertEqual(len(everything), 2)

    # A length of one caps the result at a single document.
    cursor = database.test.aggregate([project_stage])
    capped = cursor.to_list(1)
    self.assertEqual(len(capped), 1)
class TestRawBatchCursor(IntegrationTest):
def test_find_raw(self):

View File

@ -440,6 +440,12 @@ class TestGridfs(IntegrationTest):
gout = next(cursor)
self.assertEqual(b"test2+", gout.read())
self.assertRaises(StopIteration, cursor.__next__)
cursor.rewind()
items = cursor.to_list()
self.assertEqual(len(items), 2)
cursor.rewind()
items = cursor.to_list(1)
self.assertEqual(len(items), 1)
cursor.close()
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})