PYTHON-4584 Add length option to Cursor.to_list for motor compat (#1791)
This commit is contained in:
parent
f2f75fc1c8
commit
adf8817df8
@ -1892,8 +1892,16 @@ class AsyncGridOutCursor(AsyncCursor):
|
||||
next_file = await super().next()
|
||||
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
|
||||
|
||||
async def to_list(self) -> list[AsyncGridOut]:
|
||||
return [x async for x in self] # noqa: C416,RUF100
|
||||
async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
|
||||
"""Convert the cursor to a list."""
|
||||
if length is None:
|
||||
return [x async for x in self] # noqa: C416,RUF100
|
||||
if length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
ret = []
|
||||
for _ in range(length):
|
||||
ret.append(await self.next())
|
||||
return ret
|
||||
|
||||
__anext__ = next
|
||||
|
||||
|
||||
@ -1878,8 +1878,16 @@ class GridOutCursor(Cursor):
|
||||
next_file = super().next()
|
||||
return GridOut(self._root_collection, file_document=next_file, session=self.session)
|
||||
|
||||
def to_list(self) -> list[GridOut]:
|
||||
return [x for x in self] # noqa: C416,RUF100
|
||||
def to_list(self, length: Optional[int] = None) -> list[GridOut]:
|
||||
"""Convert the cursor to a list."""
|
||||
if length is None:
|
||||
return [x for x in self] # noqa: C416,RUF100
|
||||
if length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
ret = []
|
||||
for _ in range(length):
|
||||
ret.append(self.next())
|
||||
return ret
|
||||
|
||||
__next__ = next
|
||||
|
||||
|
||||
@ -346,13 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
await self._refresh()
|
||||
if len(self._data):
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -381,21 +385,32 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
async def to_list(self) -> list[_DocumentType]:
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not await self._next_batch(res):
|
||||
if not await self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1260,16 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
await self._supports_exhaust()
|
||||
if self._empty:
|
||||
return False
|
||||
if len(self._data) or await self._refresh():
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -1286,21 +1290,32 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
async def to_list(self) -> list[_DocumentType]:
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not await self._next_batch(res):
|
||||
if not await self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -346,13 +346,17 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
self._refresh()
|
||||
if len(self._data):
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -381,21 +385,32 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def to_list(self) -> list[_DocumentType]:
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not self._next_batch(res):
|
||||
if not self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1258,16 +1258,20 @@ class Cursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
self._supports_exhaust()
|
||||
if self._empty:
|
||||
return False
|
||||
if len(self._data) or self._refresh():
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -1284,21 +1288,32 @@ class Cursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def to_list(self) -> list[_DocumentType]:
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not self._next_batch(res):
|
||||
if not self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1401,6 +1401,20 @@ class TestCursor(AsyncIntegrationTest):
|
||||
docs = await c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
async def test_to_list_length(self):
|
||||
coll = self.db.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
self.addCleanup(coll.drop)
|
||||
c = coll.find()
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
|
||||
c = coll.find(batch_size=2)
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
async def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1417,6 +1431,19 @@ class TestCursor(AsyncIntegrationTest):
|
||||
docs = await c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
async def test_command_cursor_to_list_length(self):
|
||||
db = self.db
|
||||
await db.drop_collection("test")
|
||||
await db.test.insert_many([{"foo": 1}, {"foo": 2}])
|
||||
|
||||
pipeline = {"$project": {"_id": False, "foo": True}}
|
||||
result = await db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(await result.to_list()), 2)
|
||||
|
||||
result = await db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(await result.to_list(1)), 1)
|
||||
|
||||
|
||||
class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async def test_find_raw(self):
|
||||
|
||||
@ -1392,6 +1392,20 @@ class TestCursor(IntegrationTest):
|
||||
docs = c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
def test_to_list_length(self):
|
||||
coll = self.db.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
self.addCleanup(coll.drop)
|
||||
c = coll.find()
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
|
||||
c = coll.find(batch_size=2)
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
@client_context.require_change_streams
|
||||
def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1408,6 +1422,19 @@ class TestCursor(IntegrationTest):
|
||||
docs = c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
@client_context.require_change_streams
|
||||
def test_command_cursor_to_list_length(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.insert_many([{"foo": 1}, {"foo": 2}])
|
||||
|
||||
pipeline = {"$project": {"_id": False, "foo": True}}
|
||||
result = db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(result.to_list()), 2)
|
||||
|
||||
result = db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(result.to_list(1)), 1)
|
||||
|
||||
|
||||
class TestRawBatchCursor(IntegrationTest):
|
||||
def test_find_raw(self):
|
||||
|
||||
@ -440,6 +440,12 @@ class TestGridfs(IntegrationTest):
|
||||
gout = next(cursor)
|
||||
self.assertEqual(b"test2+", gout.read())
|
||||
self.assertRaises(StopIteration, cursor.__next__)
|
||||
cursor.rewind()
|
||||
items = cursor.to_list()
|
||||
self.assertEqual(len(items), 2)
|
||||
cursor.rewind()
|
||||
items = cursor.to_list(1)
|
||||
self.assertEqual(len(items), 1)
|
||||
cursor.close()
|
||||
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user