Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
ff55b8178a
@ -6,7 +6,7 @@
|
||||
|
||||
.. autoclass:: pymongo.asynchronous.mongo_client.AsyncMongoClient(host='localhost', port=27017, document_class=dict, tz_aware=False, connect=True, **kwargs)
|
||||
|
||||
.. automethod:: aclose
|
||||
.. automethod:: close
|
||||
|
||||
.. describe:: c[db_name] || c.db_name
|
||||
|
||||
|
||||
@ -1892,8 +1892,16 @@ class AsyncGridOutCursor(AsyncCursor):
|
||||
next_file = await super().next()
|
||||
return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session)
|
||||
|
||||
async def to_list(self) -> list[AsyncGridOut]:
|
||||
return [x async for x in self] # noqa: C416,RUF100
|
||||
async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]:
|
||||
"""Convert the cursor to a list."""
|
||||
if length is None:
|
||||
return [x async for x in self] # noqa: C416,RUF100
|
||||
if length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
ret = []
|
||||
for _ in range(length):
|
||||
ret.append(await self.next())
|
||||
return ret
|
||||
|
||||
__anext__ = next
|
||||
|
||||
|
||||
@ -1878,8 +1878,16 @@ class GridOutCursor(Cursor):
|
||||
next_file = super().next()
|
||||
return GridOut(self._root_collection, file_document=next_file, session=self.session)
|
||||
|
||||
def to_list(self) -> list[GridOut]:
|
||||
return [x for x in self] # noqa: C416,RUF100
|
||||
def to_list(self, length: Optional[int] = None) -> list[GridOut]:
|
||||
"""Convert the cursor to a list."""
|
||||
if length is None:
|
||||
return [x for x in self] # noqa: C416,RUF100
|
||||
if length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
ret = []
|
||||
for _ in range(length):
|
||||
ret.append(self.next())
|
||||
return ret
|
||||
|
||||
__next__ = next
|
||||
|
||||
|
||||
@ -1893,9 +1893,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
"""
|
||||
return AsyncCursor(self, *args, **kwargs)
|
||||
|
||||
async def find_raw_batches(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> AsyncRawBatchCursor[_DocumentType]:
|
||||
def find_raw_batches(self, *args: Any, **kwargs: Any) -> AsyncRawBatchCursor[_DocumentType]:
|
||||
"""Query the database and retrieve batches of raw BSON.
|
||||
|
||||
Similar to the :meth:`find` method but returns a
|
||||
@ -1907,7 +1905,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
|
||||
:mod:`bson` module.
|
||||
|
||||
>>> import bson
|
||||
>>> cursor = await db.test.find_raw_batches()
|
||||
>>> cursor = db.test.find_raw_batches()
|
||||
>>> async for batch in cursor:
|
||||
... print(bson.decode_all(batch))
|
||||
|
||||
|
||||
@ -346,13 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
async def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
await self._refresh()
|
||||
if len(self._data):
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -381,21 +385,32 @@ class AsyncCommandCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
async def to_list(self) -> list[_DocumentType]:
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not await self._next_batch(res):
|
||||
if not await self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1260,16 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopAsyncIteration
|
||||
|
||||
async def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
async def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
await self._supports_exhaust()
|
||||
if self._empty:
|
||||
return False
|
||||
if len(self._data) or await self._refresh():
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -1286,21 +1290,32 @@ class AsyncCursor(Generic[_DocumentType]):
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.close()
|
||||
|
||||
async def to_list(self) -> list[_DocumentType]:
|
||||
async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> await cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> await cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not await self._next_batch(res):
|
||||
if not await self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
self.client_ref = None
|
||||
self.key_vault_coll = None
|
||||
if self.mongocryptd_client:
|
||||
await self.mongocryptd_client.aclose()
|
||||
await self.mongocryptd_client.close()
|
||||
self.mongocryptd_client = None
|
||||
|
||||
|
||||
@ -439,7 +439,7 @@ class _Encrypter:
|
||||
self._closed = True
|
||||
await self._auto_encrypter.close()
|
||||
if self._internal_client:
|
||||
await self._internal_client.aclose()
|
||||
await self._internal_client.close()
|
||||
self._internal_client = None
|
||||
|
||||
|
||||
|
||||
@ -1378,7 +1378,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
await self.aclose()
|
||||
await self.close()
|
||||
|
||||
# See PYTHON-3084.
|
||||
__iter__ = None
|
||||
@ -1514,7 +1514,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# command.
|
||||
pass
|
||||
|
||||
async def aclose(self) -> None:
|
||||
async def close(self) -> None:
|
||||
"""Cleanup client resources and disconnect from MongoDB.
|
||||
|
||||
End all server sessions created by this client by sending one or more
|
||||
@ -1541,6 +1541,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
await self._encrypter.close()
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.aclosing.
|
||||
aclose = close
|
||||
|
||||
async def _get_topology(self) -> Topology:
|
||||
"""Get the internal :class:`~pymongo.asynchronous.topology.Topology` object.
|
||||
|
||||
|
||||
@ -77,6 +77,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non
|
||||
async def _async_sendall_ssl(
|
||||
sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop
|
||||
) -> None:
|
||||
view = memoryview(buf)
|
||||
fd = sock.fileno()
|
||||
sent = 0
|
||||
|
||||
@ -89,7 +90,7 @@ async def _async_sendall_ssl(
|
||||
|
||||
while sent < len(buf):
|
||||
try:
|
||||
sent += sock.send(buf)
|
||||
sent += sock.send(view[sent:])
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = sock.fileno()
|
||||
# Check for closed socket.
|
||||
|
||||
@ -346,13 +346,17 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some available documents from the cursor."""
|
||||
if not len(self._data) and not self._killed:
|
||||
self._refresh()
|
||||
if len(self._data):
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -381,21 +385,32 @@ class CommandCursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def to_list(self) -> list[_DocumentType]:
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not self._next_batch(res):
|
||||
if not self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1258,16 +1258,20 @@ class Cursor(Generic[_DocumentType]):
|
||||
else:
|
||||
raise StopIteration
|
||||
|
||||
def _next_batch(self, result: list) -> bool:
|
||||
"""Get all available documents from the cursor."""
|
||||
def _next_batch(self, result: list, total: Optional[int] = None) -> bool:
|
||||
"""Get all or some documents from the cursor."""
|
||||
if not self._exhaust_checked:
|
||||
self._exhaust_checked = True
|
||||
self._supports_exhaust()
|
||||
if self._empty:
|
||||
return False
|
||||
if len(self._data) or self._refresh():
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
if total is None:
|
||||
result.extend(self._data)
|
||||
self._data.clear()
|
||||
else:
|
||||
for _ in range(min(len(self._data), total)):
|
||||
result.append(self._data.popleft())
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
@ -1284,21 +1288,32 @@ class Cursor(Generic[_DocumentType]):
|
||||
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
|
||||
self.close()
|
||||
|
||||
def to_list(self) -> list[_DocumentType]:
|
||||
def to_list(self, length: Optional[int] = None) -> list[_DocumentType]:
|
||||
"""Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``.
|
||||
|
||||
To use::
|
||||
|
||||
>>> cursor.to_list()
|
||||
|
||||
Or, so read at most n items from the cursor::
|
||||
|
||||
>>> cursor.to_list(n)
|
||||
|
||||
If the cursor is empty or has no more results, an empty list will be returned.
|
||||
|
||||
.. versionadded:: 4.9
|
||||
"""
|
||||
res: list[_DocumentType] = []
|
||||
remaining = length
|
||||
if isinstance(length, int) and length < 1:
|
||||
raise ValueError("to_list() length must be greater than 0")
|
||||
while self.alive:
|
||||
if not self._next_batch(res):
|
||||
if not self._next_batch(res, remaining):
|
||||
break
|
||||
if length is not None:
|
||||
remaining = length - len(res)
|
||||
if remaining == 0:
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
|
||||
@ -1536,6 +1536,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
self._encrypter.close()
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.closing.
|
||||
aclose = close
|
||||
|
||||
def _get_topology(self) -> Topology:
|
||||
"""Get the internal :class:`~pymongo.topology.Topology` object.
|
||||
|
||||
|
||||
@ -193,7 +193,7 @@ class AsyncClientContext:
|
||||
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
|
||||
return None
|
||||
finally:
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def _init_client(self):
|
||||
self.client = await self._connect(host, port)
|
||||
@ -409,7 +409,7 @@ class AsyncClientContext:
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
def _server_started_with_auth(self):
|
||||
# MongoDB >= 2.0
|
||||
@ -1089,7 +1089,7 @@ async def async_teardown():
|
||||
await c.drop_database("pymongo_test2")
|
||||
await c.drop_database("pymongo_test_mike")
|
||||
await c.drop_database("pymongo_test_bernie")
|
||||
await c.aclose()
|
||||
await c.close()
|
||||
|
||||
print_running_clients()
|
||||
|
||||
|
||||
@ -137,7 +137,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.aclose()
|
||||
await cls.client.close()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
@ -663,7 +663,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
pass
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
self.assertTrue(conn in server._pool.conns)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -679,7 +679,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertGreaterEqual(len(server._pool.conns), 1)
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -697,7 +697,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertEqual(1, len(server._pool.conns))
|
||||
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
|
||||
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def test_max_idle_time_reaper_removes_stale(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -717,7 +717,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
lambda: len(server._pool.conns) == 0,
|
||||
"stale socket reaped and new one NOT added to the pool",
|
||||
)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def test_min_pool_size(self):
|
||||
with client_knobs(kill_cursor_frequency=0.1):
|
||||
@ -845,13 +845,13 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_equality(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = await async_rs_or_single_client(seed, connect=False)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
self.assertEqual(async_client_context.client, c)
|
||||
# Explicitly test inequality
|
||||
self.assertFalse(async_client_context.client != c)
|
||||
|
||||
c = await async_rs_or_single_client("invalid.com", connect=False)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
self.assertNotEqual(async_client_context.client, c)
|
||||
self.assertTrue(async_client_context.client != c)
|
||||
# Seeds differ:
|
||||
@ -867,10 +867,10 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_hashable(self):
|
||||
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
|
||||
c = await async_rs_or_single_client(seed, connect=False)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
self.assertIn(c, {async_client_context.client})
|
||||
c = await async_rs_or_single_client("invalid.com", connect=False)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
self.assertNotIn(c, {async_client_context.client})
|
||||
|
||||
async def test_host_w_port(self):
|
||||
@ -940,7 +940,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertIs(type(helper_doc), dict)
|
||||
self.assertEqual(helper_doc.keys(), cmd_doc.keys())
|
||||
client = await async_rs_or_single_client(document_class=SON)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
async for doc in await client.list_databases():
|
||||
self.assertIs(type(doc), dict)
|
||||
|
||||
@ -991,7 +991,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
async def test_close(self):
|
||||
test_client = await async_rs_or_single_client()
|
||||
coll = test_client.pymongo_test.bar
|
||||
await test_client.aclose()
|
||||
await test_client.close()
|
||||
with self.assertRaises(InvalidOperation):
|
||||
await coll.count_documents({})
|
||||
|
||||
@ -1024,7 +1024,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
|
||||
# Close the client and ensure the topology is closed.
|
||||
self.assertTrue(test_client._topology._opened)
|
||||
await test_client.aclose()
|
||||
await test_client.close()
|
||||
self.assertFalse(test_client._topology._opened)
|
||||
test_client = await async_rs_or_single_client()
|
||||
# The killCursors task should not need to re-open the topology.
|
||||
@ -1037,7 +1037,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertFalse(client._kill_cursors_executor._stopped)
|
||||
|
||||
# Closing the client should stop the thread.
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
self.assertTrue(client._kill_cursors_executor._stopped)
|
||||
|
||||
# Reusing the closed client should raise an InvalidOperation error.
|
||||
@ -1062,21 +1062,21 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertTrue(kc_thread and kc_thread.is_alive())
|
||||
|
||||
# Tear down.
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
async def test_close_does_not_open_servers(self):
|
||||
client = await async_rs_client(connect=False)
|
||||
topology = client._topology
|
||||
self.assertEqual(topology._servers, {})
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
self.assertEqual(topology._servers, {})
|
||||
|
||||
async def test_close_closes_sockets(self):
|
||||
client = await async_rs_client()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.test.test.find_one()
|
||||
topology = client._topology
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
for server in topology._servers.values():
|
||||
self.assertFalse(server._pool.conns)
|
||||
self.assertTrue(server._monitor._executor._stopped)
|
||||
@ -1181,7 +1181,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
uri = "mongodb://%s" % encoded_socket
|
||||
# Confirm we can do operations via the socket.
|
||||
client = await async_rs_or_single_client(uri)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||
dbs = await client.list_database_names()
|
||||
self.assertTrue("pymongo_test" in dbs)
|
||||
@ -1206,7 +1206,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertFalse(isinstance(await db.test.find_one(), SON))
|
||||
|
||||
c = await async_rs_or_single_client(document_class=SON)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
db = c.pymongo_test
|
||||
|
||||
self.assertEqual(SON, c.codec_options.document_class)
|
||||
@ -1248,7 +1248,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
no_timeout = self.client
|
||||
timeout_sec = 1
|
||||
timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
|
||||
self.addAsyncCleanup(timeout.aclose)
|
||||
self.addAsyncCleanup(timeout.close)
|
||||
|
||||
await no_timeout.pymongo_test.drop_collection("test")
|
||||
await no_timeout.pymongo_test.test.insert_one({"x": 1})
|
||||
@ -1310,7 +1310,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo")
|
||||
|
||||
aware = await async_rs_or_single_client(tz_aware=True)
|
||||
self.addAsyncCleanup(aware.aclose)
|
||||
self.addAsyncCleanup(aware.close)
|
||||
naive = self.client
|
||||
await aware.pymongo_test.drop_collection("test")
|
||||
|
||||
@ -1340,7 +1340,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
uri += "/?replicaSet=" + (async_client_context.replica_set_name or "")
|
||||
|
||||
client = await async_rs_or_single_client_noauth(uri)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.pymongo_test.test.insert_one({"dummy": "object"})
|
||||
await client.pymongo_test_bernie.test.insert_one({"dummy": "object"})
|
||||
|
||||
@ -1442,7 +1442,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# to avoid race conditions caused by replica set failover or idle
|
||||
# socket reaping.
|
||||
client = await async_single_client()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.pymongo_test.test.find_one()
|
||||
pool = await async_get_pool(client)
|
||||
socket_count = len(pool.conns)
|
||||
@ -1467,7 +1467,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0")
|
||||
|
||||
client = await async_rs_or_single_client(connect=False, w=0)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.test_lazy_connect_w0.test.insert_one({})
|
||||
|
||||
async def predicate():
|
||||
@ -1476,7 +1476,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await async_wait_until(predicate, "find one document")
|
||||
|
||||
client = await async_rs_or_single_client(connect=False, w=0)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
|
||||
|
||||
async def predicate():
|
||||
@ -1485,7 +1485,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await async_wait_until(predicate, "update one document")
|
||||
|
||||
client = await async_rs_or_single_client(connect=False, w=0)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.test_lazy_connect_w0.test.delete_one({})
|
||||
|
||||
async def predicate():
|
||||
@ -1498,7 +1498,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# When doing an exhaust query, the socket stays checked out on success
|
||||
# but must be checked in on error to avoid semaphore leaks.
|
||||
client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
collection = client.pymongo_test.test
|
||||
pool = await async_get_pool(client)
|
||||
pool._check_interval_seconds = None # Never check.
|
||||
@ -1611,7 +1611,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
# closer to 0.5 sec with heartbeatFrequencyMS configured.
|
||||
self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2)
|
||||
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
finally:
|
||||
ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore
|
||||
|
||||
@ -1709,7 +1709,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
|
||||
async def test_reset_during_update_pool(self):
|
||||
client = await async_rs_or_single_client(minPoolSize=10)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.admin.command("ping")
|
||||
pool = await async_get_pool(client)
|
||||
generation = pool.gen.get_overall()
|
||||
@ -1758,7 +1758,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
client = await async_rs_or_single_client(
|
||||
serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
# Create a single connection in the pool.
|
||||
await client.admin.command("ping")
|
||||
@ -1793,7 +1793,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await client.admin.command("ping")
|
||||
self.assertEqual(len(client.nodes), 1)
|
||||
self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
# direct_connection=False should result in RS topology.
|
||||
client = await async_rs_or_single_client(directConnection=False)
|
||||
@ -1803,7 +1803,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
client._topology_settings.get_topology_type(),
|
||||
[TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary],
|
||||
)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
|
||||
# directConnection=True, should error with multiple hosts as a list.
|
||||
with self.assertRaises(ConfigurationError):
|
||||
@ -1827,7 +1827,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
"invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150
|
||||
)
|
||||
initial_count = server_description_count()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
with self.assertRaises(ServerSelectionTimeoutError):
|
||||
await client.test.test.find_one()
|
||||
gc.collect()
|
||||
@ -1841,7 +1841,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
@async_client_context.require_failCommand_fail_point
|
||||
async def test_network_error_message(self):
|
||||
client = await async_single_client(retryReads=False)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.admin.command("ping") # connect
|
||||
async with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
|
||||
@ -1860,7 +1860,7 @@ class TestClient(AsyncIntegrationTest):
|
||||
await cursor.next()
|
||||
c_id = cursor.cursor_id
|
||||
self.assertIsNotNone(c_id)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
# Add cursor to kill cursors queue
|
||||
del cursor
|
||||
wait_until(
|
||||
@ -2315,7 +2315,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
replicaSet="rs",
|
||||
heartbeatFrequencyMS=500,
|
||||
)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
@ -2342,7 +2342,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
retryReads=False,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
|
||||
@ -2377,7 +2377,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
|
||||
retryReads=False,
|
||||
serverSelectionTimeoutMS=1000,
|
||||
)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
# Set host-specific information so we can test whether it is reset.
|
||||
c.set_wire_version_range("a:1", 2, 6)
|
||||
@ -2453,7 +2453,7 @@ class TestClientPool(AsyncMockClientTest):
|
||||
minPoolSize=1, # minPoolSize
|
||||
event_listeners=[listener],
|
||||
)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 3, "connect")
|
||||
self.assertEqual(await c.address, ("a", 1))
|
||||
@ -2483,7 +2483,7 @@ class TestClientPool(AsyncMockClientTest):
|
||||
minPoolSize=1, # minPoolSize
|
||||
event_listeners=[listener],
|
||||
)
|
||||
self.addAsyncCleanup(c.aclose)
|
||||
self.addAsyncCleanup(c.close)
|
||||
|
||||
wait_until(lambda: len(c.nodes) == 1, "connect")
|
||||
self.assertEqual(await c.address, ("c", 3))
|
||||
|
||||
@ -43,7 +43,7 @@ class TestClientBulkWrite(AsyncIntegrationTest):
|
||||
@async_client_context.require_version_min(8, 0, 0, -24)
|
||||
async def test_returns_error_if_no_namespace_provided(self):
|
||||
client = await async_rs_or_single_client()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
models = [InsertOne(document={"a": "b"})]
|
||||
with self.assertRaises(InvalidOperation) as context:
|
||||
@ -65,7 +65,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_batch_splits_if_num_operations_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
models = []
|
||||
for _ in range(self.max_write_batch_size + 1):
|
||||
@ -90,7 +90,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_batch_splits_if_ops_payload_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
models = []
|
||||
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
|
||||
@ -126,7 +126,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
event_listeners=[listener],
|
||||
retryWrites=False,
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
fail_command = {
|
||||
"configureFailPoint": "failCommand",
|
||||
@ -165,7 +165,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_collects_write_errors_across_batches_unordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -195,7 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_collects_write_errors_across_batches_ordered(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -225,7 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_handles_cursor_requiring_getMore(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -266,7 +266,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_handles_cursor_requiring_getMore_within_transaction(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -309,7 +309,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_handles_getMore_error(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
collection = client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -363,7 +363,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_returns_error_if_unacknowledged_too_large_insert(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
b_repeated = "b" * self.max_bson_object_size
|
||||
|
||||
@ -419,7 +419,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
num_models, models = await self._setup_namespace_test_models()
|
||||
models.append(
|
||||
@ -450,7 +450,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
async def test_batch_splits_if_new_namespace_is_too_large(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
num_models, models = await self._setup_namespace_test_models()
|
||||
c_repeated = "c" * 200
|
||||
@ -487,7 +487,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
@async_client_context.require_version_min(8, 0, 0, -24)
|
||||
async def test_returns_error_if_no_writes_can_be_added_to_ops(self):
|
||||
client = await async_rs_or_single_client()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
# Document too large.
|
||||
b_repeated = "b" * self.max_message_size_bytes
|
||||
@ -512,7 +512,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
|
||||
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
|
||||
)
|
||||
client = await async_rs_or_single_client(auto_encryption_opts=opts)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
|
||||
with self.assertRaises(InvalidOperation) as context:
|
||||
@ -535,7 +535,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
_OVERHEAD = 500
|
||||
|
||||
internal_client = await async_rs_or_single_client(timeoutMS=None)
|
||||
self.addAsyncCleanup(internal_client.aclose)
|
||||
self.addAsyncCleanup(internal_client.close)
|
||||
|
||||
collection = internal_client.db["coll"]
|
||||
self.addAsyncCleanup(collection.drop)
|
||||
@ -566,7 +566,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
|
||||
timeoutMS=2000,
|
||||
w="majority",
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
await client.admin.command("ping") # Init the client first.
|
||||
with self.assertRaises(ClientBulkWriteException) as context:
|
||||
await client.bulk_write(models=models)
|
||||
|
||||
@ -1820,7 +1820,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
|
||||
await self.db.test.insert_many([{"i": i} for i in range(150)])
|
||||
|
||||
client = await async_rs_or_single_client(maxPoolSize=1)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
pool = await async_get_pool(client)
|
||||
|
||||
# Make sure the socket is returned after exhaustion.
|
||||
|
||||
@ -354,7 +354,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
# Do not add readConcern level to explain.
|
||||
listener = AllowListEventListener("explain")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
|
||||
self.assertTrue(await coll.find().explain())
|
||||
started = listener.started_events
|
||||
@ -1262,7 +1262,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client[self.db.name].test_close_kills_cursors
|
||||
|
||||
# Add some test data.
|
||||
@ -1301,7 +1301,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_timeout_kills_cursor_asynchronously(self):
|
||||
listener = AllowListEventListener("killCursors")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client[self.db.name].test_timeout_kills_cursor
|
||||
|
||||
# Add some test data.
|
||||
@ -1359,7 +1359,7 @@ class TestCursor(AsyncIntegrationTest):
|
||||
async def test_getMore_does_not_send_readPreference(self):
|
||||
listener = AllowListEventListener("find", "getMore")
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
# We never send primary read preference so override the default.
|
||||
coll = client[self.db.name].get_collection(
|
||||
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
|
||||
@ -1401,6 +1401,20 @@ class TestCursor(AsyncIntegrationTest):
|
||||
docs = await c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
async def test_to_list_length(self):
|
||||
coll = self.db.test
|
||||
await coll.insert_many([{} for _ in range(5)])
|
||||
self.addCleanup(coll.drop)
|
||||
c = coll.find()
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
|
||||
c = coll.find(batch_size=2)
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
docs = await c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
async def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1417,6 +1431,19 @@ class TestCursor(AsyncIntegrationTest):
|
||||
docs = await c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
@async_client_context.require_change_streams
|
||||
async def test_command_cursor_to_list_length(self):
|
||||
db = self.db
|
||||
await db.drop_collection("test")
|
||||
await db.test.insert_many([{"foo": 1}, {"foo": 2}])
|
||||
|
||||
pipeline = {"$project": {"_id": False, "foo": True}}
|
||||
result = await db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(await result.to_list()), 2)
|
||||
|
||||
result = await db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(await result.to_list(1)), 1)
|
||||
|
||||
|
||||
class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async def test_find_raw(self):
|
||||
@ -1424,7 +1451,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
await c.drop()
|
||||
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
|
||||
await c.insert_many(docs)
|
||||
batches = await (await c.find_raw_batches()).sort("_id").to_list()
|
||||
batches = await c.find_raw_batches().sort("_id").to_list()
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@ -1440,7 +1467,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async with client.start_session() as session:
|
||||
async with await session.start_transaction():
|
||||
batches = await (
|
||||
(await client[self.db.name].test.find_raw_batches(session=session)).sort("_id")
|
||||
client[self.db.name].test.find_raw_batches(session=session).sort("_id")
|
||||
).to_list()
|
||||
cmd = listener.started_events[0]
|
||||
self.assertEqual(cmd.command_name, "find")
|
||||
@ -1470,9 +1497,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
):
|
||||
batches = (
|
||||
await (await client[self.db.name].test.find_raw_batches()).sort("_id").to_list()
|
||||
)
|
||||
batches = await client[self.db.name].test.find_raw_batches().sort("_id").to_list()
|
||||
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
@ -1493,7 +1518,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
db = client[self.db.name]
|
||||
async with client.start_session(snapshot=True) as session:
|
||||
await db.test.distinct("x", {}, session=session)
|
||||
batches = await (await db.test.find_raw_batches(session=session)).sort("_id").to_list()
|
||||
batches = await db.test.find_raw_batches(session=session).sort("_id").to_list()
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@ -1504,18 +1529,18 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
async def test_explain(self):
|
||||
c = self.db.test
|
||||
await c.insert_one({})
|
||||
explanation = await (await c.find_raw_batches()).explain()
|
||||
explanation = await c.find_raw_batches().explain()
|
||||
self.assertIsInstance(explanation, dict)
|
||||
|
||||
async def test_empty(self):
|
||||
await self.db.test.drop()
|
||||
cursor = await self.db.test.find_raw_batches()
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
with self.assertRaises(StopAsyncIteration):
|
||||
await anext(cursor)
|
||||
|
||||
async def test_clone(self):
|
||||
await self.db.test.insert_one({})
|
||||
cursor = await self.db.test.find_raw_batches()
|
||||
cursor = self.db.test.find_raw_batches()
|
||||
# Copy of a RawBatchCursor is also a RawBatchCursor, not a Cursor.
|
||||
self.assertIsInstance(await anext(cursor.clone()), bytes)
|
||||
self.assertIsInstance(await anext(copy.copy(cursor)), bytes)
|
||||
@ -1525,24 +1550,22 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
c = self.db.test
|
||||
await c.drop()
|
||||
await c.insert_many({"_id": i} for i in range(200))
|
||||
result = b"".join(
|
||||
await (await c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list()
|
||||
)
|
||||
result = b"".join(await c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list())
|
||||
self.assertEqual([{"_id": i} for i in range(200)], decode_all(result))
|
||||
|
||||
async def test_server_error(self):
|
||||
with self.assertRaises(OperationFailure) as exc:
|
||||
await anext(await self.db.test.find_raw_batches({"x": {"$bad": 1}}))
|
||||
await anext(self.db.test.find_raw_batches({"x": {"$bad": 1}}))
|
||||
|
||||
# The server response was decoded, not left raw.
|
||||
self.assertIsInstance(exc.exception.details, dict)
|
||||
|
||||
async def test_get_item(self):
|
||||
with self.assertRaises(InvalidOperation):
|
||||
(await self.db.test.find_raw_batches())[0]
|
||||
self.db.test.find_raw_batches()[0]
|
||||
|
||||
async def test_collation(self):
|
||||
await anext(await self.db.test.find_raw_batches(collation=Collation("en_US")))
|
||||
await anext(self.db.test.find_raw_batches(collation=Collation("en_US")))
|
||||
|
||||
@async_client_context.require_no_mmap # MMAPv1 does not support read concern
|
||||
async def test_read_concern(self):
|
||||
@ -1550,7 +1573,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
{}
|
||||
)
|
||||
c = self.db.get_collection("test", read_concern=ReadConcern("majority"))
|
||||
await anext(await c.find_raw_batches())
|
||||
await anext(c.find_raw_batches())
|
||||
|
||||
async def test_monitoring(self):
|
||||
listener = EventListener()
|
||||
@ -1560,7 +1583,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
|
||||
await c.insert_many([{"_id": i} for i in range(10)])
|
||||
|
||||
listener.reset()
|
||||
cursor = await c.find_raw_batches(batch_size=4)
|
||||
cursor = c.find_raw_batches(batch_size=4)
|
||||
|
||||
# First raw batch of 4 documents.
|
||||
await anext(cursor)
|
||||
@ -1766,7 +1789,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
|
||||
async def test_exhaust_cursor_db_set(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
c = client.pymongo_test.test
|
||||
await c.delete_many({})
|
||||
await c.insert_many([{"_id": i} for i in range(3)])
|
||||
|
||||
@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest):
|
||||
async def test_check_exists(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_or_single_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
db = client[self.db.name]
|
||||
await db.drop_collection("unique")
|
||||
await db.create_collection("unique", check_exists=True)
|
||||
|
||||
@ -99,7 +99,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
|
||||
await cls.client2.aclose()
|
||||
await cls.client2.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
async def asyncSetUp(self):
|
||||
@ -108,7 +108,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
self.client = await async_rs_or_single_client(
|
||||
event_listeners=[self.listener, self.session_checker_listener]
|
||||
)
|
||||
self.addAsyncCleanup(self.client.aclose)
|
||||
self.addAsyncCleanup(self.client.close)
|
||||
self.db = self.client.pymongo_test
|
||||
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
|
||||
|
||||
@ -295,14 +295,14 @@ class TestSession(AsyncIntegrationTest):
|
||||
|
||||
# Closing the client should end all sessions and clear the pool.
|
||||
self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1)
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
self.assertEqual(len(client._topology._session_pool), 0)
|
||||
end_sessions = [e for e in listener.started_events if e.command_name == "endSessions"]
|
||||
self.assertEqual(len(end_sessions), 2)
|
||||
|
||||
# Closing again should not send any commands.
|
||||
listener.reset()
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
self.assertEqual(len(listener.started_events), 0)
|
||||
|
||||
async def test_client(self):
|
||||
@ -790,7 +790,7 @@ class TestSession(AsyncIntegrationTest):
|
||||
# Ensure the collection exists.
|
||||
await self.client.pymongo_test.test_unacked_writes.insert_one({})
|
||||
client = await async_rs_or_single_client(w=0, event_listeners=[self.listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
db = client.pymongo_test
|
||||
coll = db.test_unacked_writes
|
||||
ops: list = [
|
||||
@ -842,7 +842,7 @@ class TestCausalConsistency(AsyncUnitTest):
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
await cls.client.aclose()
|
||||
await cls.client.close()
|
||||
|
||||
@async_client_context.require_sessions
|
||||
async def asyncSetUp(self):
|
||||
@ -927,7 +927,7 @@ class TestCausalConsistency(AsyncUnitTest):
|
||||
return await (await coll.aggregate_raw_batches([], session=session)).to_list()
|
||||
|
||||
async def find_raw(coll, session):
|
||||
return await (await coll.find_raw_batches({}, session=session)).to_list()
|
||||
return await coll.find_raw_batches({}, session=session).to_list()
|
||||
|
||||
await self._test_reads(aggregate)
|
||||
await self._test_reads(lambda coll, session: coll.find({}, session=session).to_list())
|
||||
@ -1156,7 +1156,7 @@ class TestClusterTime(AsyncIntegrationTest):
|
||||
client = await async_rs_or_single_client(
|
||||
event_listeners=[listener], heartbeatFrequencyMS=999999
|
||||
)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
collection = client.pymongo_test.collection
|
||||
# Prepare for tests of find() and aggregate().
|
||||
await collection.insert_many([{} for _ in range(10)])
|
||||
|
||||
@ -74,7 +74,7 @@ class AsyncTransactionsBase(AsyncSpecRunner):
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
for client in cls.mongos_clients:
|
||||
await client.aclose()
|
||||
await client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
@ -121,7 +121,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
async def test_transaction_write_concern_override(self):
|
||||
"""Test txn overrides Client/Database/Collection write_concern."""
|
||||
client = await async_rs_client(w=0)
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
db = client.test
|
||||
coll = db.test
|
||||
await coll.insert_one({})
|
||||
@ -183,7 +183,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
async with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
async with await s.start_transaction():
|
||||
@ -211,7 +211,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
coll = client.test.test
|
||||
# Create the collection.
|
||||
await coll.insert_one({})
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
async with client.start_session() as s:
|
||||
# Session is pinned to Mongos.
|
||||
async with await s.start_transaction():
|
||||
@ -339,7 +339,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
coll = client[self.db.name].test
|
||||
await coll.delete_many({})
|
||||
listener.reset()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
self.addAsyncCleanup(coll.drop)
|
||||
large_str = "\0" * (1 * 1024 * 1024)
|
||||
ops: List[InsertOne[RawBSONDocument]] = [
|
||||
@ -365,7 +365,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
@async_client_context.require_transactions
|
||||
async def test_transaction_direct_connection(self):
|
||||
client = await async_single_client()
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client.pymongo_test.test
|
||||
|
||||
# Make sure the collection exists.
|
||||
@ -375,6 +375,9 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
async def find(*args, **kwargs):
|
||||
return coll.find(*args, **kwargs)
|
||||
|
||||
async def find_raw_batches(*args, **kwargs):
|
||||
return coll.find_raw_batches(*args, **kwargs)
|
||||
|
||||
ops = [
|
||||
(coll.bulk_write, [[InsertOne[dict]({})]]),
|
||||
(coll.insert_one, [{}]),
|
||||
@ -393,7 +396,7 @@ class TestTransactions(AsyncTransactionsBase):
|
||||
(coll.aggregate, [[]]),
|
||||
(find, [{}]),
|
||||
(coll.aggregate_raw_batches, [[]]),
|
||||
(coll.find_raw_batches, [{}]),
|
||||
(find_raw_batches, [{}]),
|
||||
(coll.database.command, ["find", coll.name]),
|
||||
]
|
||||
for f, args in ops:
|
||||
@ -452,7 +455,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
async def test_callback_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
@ -481,7 +484,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
async def test_callback_not_retried_after_commit_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
@ -516,7 +519,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
|
||||
async def test_commit_not_retried_after_timeout(self):
|
||||
listener = OvertCommandListener()
|
||||
client = await async_rs_client(event_listeners=[listener])
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
coll = client[self.db.name].test
|
||||
|
||||
async def callback(session):
|
||||
|
||||
@ -535,7 +535,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
self.pool_listener = pool_listener
|
||||
self.server_listener = server_listener
|
||||
# Close the client explicitly to avoid having too many threads open.
|
||||
self.addAsyncCleanup(client.aclose)
|
||||
self.addAsyncCleanup(client.close)
|
||||
|
||||
# Create session0 and session1.
|
||||
sessions = {}
|
||||
|
||||
@ -1392,6 +1392,20 @@ class TestCursor(IntegrationTest):
|
||||
docs = c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
def test_to_list_length(self):
|
||||
coll = self.db.test
|
||||
coll.insert_many([{} for _ in range(5)])
|
||||
self.addCleanup(coll.drop)
|
||||
c = coll.find()
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
|
||||
c = coll.find(batch_size=2)
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 3)
|
||||
docs = c.to_list(3)
|
||||
self.assertEqual(len(docs), 2)
|
||||
|
||||
@client_context.require_change_streams
|
||||
def test_command_cursor_to_list(self):
|
||||
# Set maxAwaitTimeMS=1 to speed up the test.
|
||||
@ -1408,6 +1422,19 @@ class TestCursor(IntegrationTest):
|
||||
docs = c.to_list()
|
||||
self.assertEqual([], docs)
|
||||
|
||||
@client_context.require_change_streams
|
||||
def test_command_cursor_to_list_length(self):
|
||||
db = self.db
|
||||
db.drop_collection("test")
|
||||
db.test.insert_many([{"foo": 1}, {"foo": 2}])
|
||||
|
||||
pipeline = {"$project": {"_id": False, "foo": True}}
|
||||
result = db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(result.to_list()), 2)
|
||||
|
||||
result = db.test.aggregate([pipeline])
|
||||
self.assertEqual(len(result.to_list(1)), 1)
|
||||
|
||||
|
||||
class TestRawBatchCursor(IntegrationTest):
|
||||
def test_find_raw(self):
|
||||
@ -1415,7 +1442,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
c.drop()
|
||||
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
|
||||
c.insert_many(docs)
|
||||
batches = (c.find_raw_batches()).sort("_id").to_list()
|
||||
batches = c.find_raw_batches().sort("_id").to_list()
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@ -1431,7 +1458,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
with client.start_session() as session:
|
||||
with session.start_transaction():
|
||||
batches = (
|
||||
(client[self.db.name].test.find_raw_batches(session=session)).sort("_id")
|
||||
client[self.db.name].test.find_raw_batches(session=session).sort("_id")
|
||||
).to_list()
|
||||
cmd = listener.started_events[0]
|
||||
self.assertEqual(cmd.command_name, "find")
|
||||
@ -1461,7 +1488,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
with self.fail_point(
|
||||
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
|
||||
):
|
||||
batches = (client[self.db.name].test.find_raw_batches()).sort("_id").to_list()
|
||||
batches = client[self.db.name].test.find_raw_batches().sort("_id").to_list()
|
||||
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
@ -1482,7 +1509,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
db = client[self.db.name]
|
||||
with client.start_session(snapshot=True) as session:
|
||||
db.test.distinct("x", {}, session=session)
|
||||
batches = (db.test.find_raw_batches(session=session)).sort("_id").to_list()
|
||||
batches = db.test.find_raw_batches(session=session).sort("_id").to_list()
|
||||
self.assertEqual(1, len(batches))
|
||||
self.assertEqual(docs, decode_all(batches[0]))
|
||||
|
||||
@ -1493,7 +1520,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
def test_explain(self):
|
||||
c = self.db.test
|
||||
c.insert_one({})
|
||||
explanation = (c.find_raw_batches()).explain()
|
||||
explanation = c.find_raw_batches().explain()
|
||||
self.assertIsInstance(explanation, dict)
|
||||
|
||||
def test_empty(self):
|
||||
@ -1514,7 +1541,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
c = self.db.test
|
||||
c.drop()
|
||||
c.insert_many({"_id": i} for i in range(200))
|
||||
result = b"".join((c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list())
|
||||
result = b"".join(c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list())
|
||||
self.assertEqual([{"_id": i} for i in range(200)], decode_all(result))
|
||||
|
||||
def test_server_error(self):
|
||||
@ -1526,7 +1553,7 @@ class TestRawBatchCursor(IntegrationTest):
|
||||
|
||||
def test_get_item(self):
|
||||
with self.assertRaises(InvalidOperation):
|
||||
(self.db.test.find_raw_batches())[0]
|
||||
self.db.test.find_raw_batches()[0]
|
||||
|
||||
def test_collation(self):
|
||||
next(self.db.test.find_raw_batches(collation=Collation("en_US")))
|
||||
|
||||
@ -440,6 +440,12 @@ class TestGridfs(IntegrationTest):
|
||||
gout = next(cursor)
|
||||
self.assertEqual(b"test2+", gout.read())
|
||||
self.assertRaises(StopIteration, cursor.__next__)
|
||||
cursor.rewind()
|
||||
items = cursor.to_list()
|
||||
self.assertEqual(len(items), 2)
|
||||
cursor.rewind()
|
||||
items = cursor.to_list(1)
|
||||
self.assertEqual(len(items), 1)
|
||||
cursor.close()
|
||||
self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})
|
||||
|
||||
|
||||
@ -925,7 +925,7 @@ class TestCausalConsistency(UnitTest):
|
||||
return (coll.aggregate_raw_batches([], session=session)).to_list()
|
||||
|
||||
def find_raw(coll, session):
|
||||
return (coll.find_raw_batches({}, session=session)).to_list()
|
||||
return coll.find_raw_batches({}, session=session).to_list()
|
||||
|
||||
self._test_reads(aggregate)
|
||||
self._test_reads(lambda coll, session: coll.find({}, session=session).to_list())
|
||||
|
||||
@ -371,6 +371,9 @@ class TestTransactions(TransactionsBase):
|
||||
def find(*args, **kwargs):
|
||||
return coll.find(*args, **kwargs)
|
||||
|
||||
def find_raw_batches(*args, **kwargs):
|
||||
return coll.find_raw_batches(*args, **kwargs)
|
||||
|
||||
ops = [
|
||||
(coll.bulk_write, [[InsertOne[dict]({})]]),
|
||||
(coll.insert_one, [{}]),
|
||||
@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase):
|
||||
(coll.aggregate, [[]]),
|
||||
(find, [{}]),
|
||||
(coll.aggregate_raw_batches, [[]]),
|
||||
(coll.find_raw_batches, [{}]),
|
||||
(find_raw_batches, [{}]),
|
||||
(coll.database.command, ["find", coll.name]),
|
||||
]
|
||||
for f, args in ops:
|
||||
|
||||
@ -92,7 +92,6 @@ replacements = {
|
||||
"asyncAssertRaisesExactly": "assertRaisesExactly",
|
||||
"get_async_mock_client": "get_mock_client",
|
||||
"aconnect": "_connect",
|
||||
"aclose": "close",
|
||||
"async-transactions-ref": "transactions-ref",
|
||||
"async-snapshot-reads-ref": "snapshot-reads-ref",
|
||||
"default_async": "default",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user