PYTHON-4654 Clean up Async API to match Motor (#1789)

This commit is contained in:
Steven Silvester 2024-08-13 19:17:45 -05:00 committed by GitHub
parent 47b2257028
commit f69d330b25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 128 additions and 121 deletions

View File

@ -6,7 +6,7 @@
.. autoclass:: pymongo.asynchronous.mongo_client.AsyncMongoClient(host='localhost', port=27017, document_class=dict, tz_aware=False, connect=True, **kwargs)
.. automethod:: aclose
.. automethod:: close
.. describe:: c[db_name] || c.db_name

View File

@ -1893,9 +1893,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
"""
return AsyncCursor(self, *args, **kwargs)
async def find_raw_batches(
self, *args: Any, **kwargs: Any
) -> AsyncRawBatchCursor[_DocumentType]:
def find_raw_batches(self, *args: Any, **kwargs: Any) -> AsyncRawBatchCursor[_DocumentType]:
"""Query the database and retrieve batches of raw BSON.
Similar to the :meth:`find` method but returns a
@ -1907,7 +1905,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]):
:mod:`bson` module.
>>> import bson
>>> cursor = await db.test.find_raw_batches()
>>> cursor = db.test.find_raw_batches()
>>> async for batch in cursor:
... print(bson.decode_all(batch))

View File

@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
self.client_ref = None
self.key_vault_coll = None
if self.mongocryptd_client:
await self.mongocryptd_client.aclose()
await self.mongocryptd_client.close()
self.mongocryptd_client = None
@ -439,7 +439,7 @@ class _Encrypter:
self._closed = True
await self._auto_encrypter.close()
if self._internal_client:
await self._internal_client.aclose()
await self._internal_client.close()
self._internal_client = None

View File

@ -1378,7 +1378,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
return self
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
await self.aclose()
await self.close()
# See PYTHON-3084.
__iter__ = None
@ -1514,7 +1514,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# command.
pass
async def aclose(self) -> None:
async def close(self) -> None:
"""Cleanup client resources and disconnect from MongoDB.
End all server sessions created by this client by sending one or more
@ -1541,6 +1541,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
await self._encrypter.close()
if not _IS_SYNC:
# Add support for contextlib.aclosing.
aclose = close
async def _get_topology(self) -> Topology:
"""Get the internal :class:`~pymongo.asynchronous.topology.Topology` object.

View File

@ -1536,6 +1536,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
self._encrypter.close()
if not _IS_SYNC:
# Add support for contextlib.closing.
aclose = close
def _get_topology(self) -> Topology:
"""Get the internal :class:`~pymongo.topology.Topology` object.

View File

@ -193,7 +193,7 @@ class AsyncClientContext:
self.connection_attempts.append(f"failed to connect client {client!r}: {exc}")
return None
finally:
await client.aclose()
await client.close()
async def _init_client(self):
self.client = await self._connect(host, port)
@ -409,7 +409,7 @@ class AsyncClientContext:
else:
raise
finally:
await client.aclose()
await client.close()
def _server_started_with_auth(self):
# MongoDB >= 2.0
@ -1089,7 +1089,7 @@ async def async_teardown():
await c.drop_database("pymongo_test2")
await c.drop_database("pymongo_test_mike")
await c.drop_database("pymongo_test_bernie")
await c.aclose()
await c.close()
print_running_clients()

View File

@ -137,7 +137,7 @@ class AsyncClientUnitTest(AsyncUnitTest):
@classmethod
async def _tearDown_class(cls):
await cls.client.aclose()
await cls.client.close()
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
@ -663,7 +663,7 @@ class TestClient(AsyncIntegrationTest):
pass
self.assertEqual(1, len(server._pool.conns))
self.assertTrue(conn in server._pool.conns)
await client.aclose()
await client.close()
async def test_max_idle_time_reaper_removes_stale_minPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -679,7 +679,7 @@ class TestClient(AsyncIntegrationTest):
self.assertGreaterEqual(len(server._pool.conns), 1)
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket")
await client.aclose()
await client.close()
async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -697,7 +697,7 @@ class TestClient(AsyncIntegrationTest):
self.assertEqual(1, len(server._pool.conns))
wait_until(lambda: conn not in server._pool.conns, "remove stale socket")
wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket")
await client.aclose()
await client.close()
async def test_max_idle_time_reaper_removes_stale(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -717,7 +717,7 @@ class TestClient(AsyncIntegrationTest):
lambda: len(server._pool.conns) == 0,
"stale socket reaped and new one NOT added to the pool",
)
await client.aclose()
await client.close()
async def test_min_pool_size(self):
with client_knobs(kill_cursor_frequency=0.1):
@ -845,13 +845,13 @@ class TestClient(AsyncIntegrationTest):
async def test_equality(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = await async_rs_or_single_client(seed, connect=False)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
self.assertEqual(async_client_context.client, c)
# Explicitly test inequality
self.assertFalse(async_client_context.client != c)
c = await async_rs_or_single_client("invalid.com", connect=False)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
self.assertNotEqual(async_client_context.client, c)
self.assertTrue(async_client_context.client != c)
# Seeds differ:
@ -867,10 +867,10 @@ class TestClient(AsyncIntegrationTest):
async def test_hashable(self):
seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0])
c = await async_rs_or_single_client(seed, connect=False)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
self.assertIn(c, {async_client_context.client})
c = await async_rs_or_single_client("invalid.com", connect=False)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
self.assertNotIn(c, {async_client_context.client})
async def test_host_w_port(self):
@ -940,7 +940,7 @@ class TestClient(AsyncIntegrationTest):
self.assertIs(type(helper_doc), dict)
self.assertEqual(helper_doc.keys(), cmd_doc.keys())
client = await async_rs_or_single_client(document_class=SON)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
async for doc in await client.list_databases():
self.assertIs(type(doc), dict)
@ -991,7 +991,7 @@ class TestClient(AsyncIntegrationTest):
async def test_close(self):
test_client = await async_rs_or_single_client()
coll = test_client.pymongo_test.bar
await test_client.aclose()
await test_client.close()
with self.assertRaises(InvalidOperation):
await coll.count_documents({})
@ -1024,7 +1024,7 @@ class TestClient(AsyncIntegrationTest):
# Close the client and ensure the topology is closed.
self.assertTrue(test_client._topology._opened)
await test_client.aclose()
await test_client.close()
self.assertFalse(test_client._topology._opened)
test_client = await async_rs_or_single_client()
# The killCursors task should not need to re-open the topology.
@ -1037,7 +1037,7 @@ class TestClient(AsyncIntegrationTest):
self.assertFalse(client._kill_cursors_executor._stopped)
# Closing the client should stop the thread.
await client.aclose()
await client.close()
self.assertTrue(client._kill_cursors_executor._stopped)
# Reusing the closed client should raise an InvalidOperation error.
@ -1062,21 +1062,21 @@ class TestClient(AsyncIntegrationTest):
self.assertTrue(kc_thread and kc_thread.is_alive())
# Tear down.
await client.aclose()
await client.close()
async def test_close_does_not_open_servers(self):
client = await async_rs_client(connect=False)
topology = client._topology
self.assertEqual(topology._servers, {})
await client.aclose()
await client.close()
self.assertEqual(topology._servers, {})
async def test_close_closes_sockets(self):
client = await async_rs_client()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.test.test.find_one()
topology = client._topology
await client.aclose()
await client.close()
for server in topology._servers.values():
self.assertFalse(server._pool.conns)
self.assertTrue(server._monitor._executor._stopped)
@ -1181,7 +1181,7 @@ class TestClient(AsyncIntegrationTest):
uri = "mongodb://%s" % encoded_socket
# Confirm we can do operations via the socket.
client = await async_rs_or_single_client(uri)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.pymongo_test.test.insert_one({"dummy": "object"})
dbs = await client.list_database_names()
self.assertTrue("pymongo_test" in dbs)
@ -1206,7 +1206,7 @@ class TestClient(AsyncIntegrationTest):
self.assertFalse(isinstance(await db.test.find_one(), SON))
c = await async_rs_or_single_client(document_class=SON)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
db = c.pymongo_test
self.assertEqual(SON, c.codec_options.document_class)
@ -1248,7 +1248,7 @@ class TestClient(AsyncIntegrationTest):
no_timeout = self.client
timeout_sec = 1
timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec)
self.addAsyncCleanup(timeout.aclose)
self.addAsyncCleanup(timeout.close)
await no_timeout.pymongo_test.drop_collection("test")
await no_timeout.pymongo_test.test.insert_one({"x": 1})
@ -1310,7 +1310,7 @@ class TestClient(AsyncIntegrationTest):
self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo")
aware = await async_rs_or_single_client(tz_aware=True)
self.addAsyncCleanup(aware.aclose)
self.addAsyncCleanup(aware.close)
naive = self.client
await aware.pymongo_test.drop_collection("test")
@ -1340,7 +1340,7 @@ class TestClient(AsyncIntegrationTest):
uri += "/?replicaSet=" + (async_client_context.replica_set_name or "")
client = await async_rs_or_single_client_noauth(uri)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.pymongo_test.test.insert_one({"dummy": "object"})
await client.pymongo_test_bernie.test.insert_one({"dummy": "object"})
@ -1442,7 +1442,7 @@ class TestClient(AsyncIntegrationTest):
# to avoid race conditions caused by replica set failover or idle
# socket reaping.
client = await async_single_client()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.pymongo_test.test.find_one()
pool = await async_get_pool(client)
socket_count = len(pool.conns)
@ -1467,7 +1467,7 @@ class TestClient(AsyncIntegrationTest):
self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0")
client = await async_rs_or_single_client(connect=False, w=0)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.test_lazy_connect_w0.test.insert_one({})
async def predicate():
@ -1476,7 +1476,7 @@ class TestClient(AsyncIntegrationTest):
await async_wait_until(predicate, "find one document")
client = await async_rs_or_single_client(connect=False, w=0)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}})
async def predicate():
@ -1485,7 +1485,7 @@ class TestClient(AsyncIntegrationTest):
await async_wait_until(predicate, "update one document")
client = await async_rs_or_single_client(connect=False, w=0)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.test_lazy_connect_w0.test.delete_one({})
async def predicate():
@ -1498,7 +1498,7 @@ class TestClient(AsyncIntegrationTest):
# When doing an exhaust query, the socket stays checked out on success
# but must be checked in on error to avoid semaphore leaks.
client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.pymongo_test.test
pool = await async_get_pool(client)
pool._check_interval_seconds = None # Never check.
@ -1611,7 +1611,7 @@ class TestClient(AsyncIntegrationTest):
# closer to 0.5 sec with heartbeatFrequencyMS configured.
self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2)
await client.aclose()
await client.close()
finally:
ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore
@ -1709,7 +1709,7 @@ class TestClient(AsyncIntegrationTest):
async def test_reset_during_update_pool(self):
client = await async_rs_or_single_client(minPoolSize=10)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.admin.command("ping")
pool = await async_get_pool(client)
generation = pool.gen.get_overall()
@ -1758,7 +1758,7 @@ class TestClient(AsyncIntegrationTest):
client = await async_rs_or_single_client(
serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False
)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
# Create a single connection in the pool.
await client.admin.command("ping")
@ -1793,7 +1793,7 @@ class TestClient(AsyncIntegrationTest):
await client.admin.command("ping")
self.assertEqual(len(client.nodes), 1)
self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single)
await client.aclose()
await client.close()
# direct_connection=False should result in RS topology.
client = await async_rs_or_single_client(directConnection=False)
@ -1803,7 +1803,7 @@ class TestClient(AsyncIntegrationTest):
client._topology_settings.get_topology_type(),
[TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary],
)
await client.aclose()
await client.close()
# directConnection=True, should error with multiple hosts as a list.
with self.assertRaises(ConfigurationError):
@ -1827,7 +1827,7 @@ class TestClient(AsyncIntegrationTest):
"invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150
)
initial_count = server_description_count()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
with self.assertRaises(ServerSelectionTimeoutError):
await client.test.test.find_one()
gc.collect()
@ -1841,7 +1841,7 @@ class TestClient(AsyncIntegrationTest):
@async_client_context.require_failCommand_fail_point
async def test_network_error_message(self):
client = await async_single_client(retryReads=False)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.admin.command("ping") # connect
async with self.fail_point(
{"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}}
@ -1860,7 +1860,7 @@ class TestClient(AsyncIntegrationTest):
await cursor.next()
c_id = cursor.cursor_id
self.assertIsNotNone(c_id)
await client.aclose()
await client.close()
# Add cursor to kill cursors queue
del cursor
wait_until(
@ -2315,7 +2315,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
replicaSet="rs",
heartbeatFrequencyMS=500,
)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
@ -2342,7 +2342,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
retryReads=False,
serverSelectionTimeoutMS=1000,
)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
@ -2377,7 +2377,7 @@ class TestMongoClientFailover(AsyncMockClientTest):
retryReads=False,
serverSelectionTimeoutMS=1000,
)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
# Set host-specific information so we can test whether it is reset.
c.set_wire_version_range("a:1", 2, 6)
@ -2453,7 +2453,7 @@ class TestClientPool(AsyncMockClientTest):
minPoolSize=1, # minPoolSize
event_listeners=[listener],
)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 3, "connect")
self.assertEqual(await c.address, ("a", 1))
@ -2483,7 +2483,7 @@ class TestClientPool(AsyncMockClientTest):
minPoolSize=1, # minPoolSize
event_listeners=[listener],
)
self.addAsyncCleanup(c.aclose)
self.addAsyncCleanup(c.close)
wait_until(lambda: len(c.nodes) == 1, "connect")
self.assertEqual(await c.address, ("c", 3))

View File

@ -43,7 +43,7 @@ class TestClientBulkWrite(AsyncIntegrationTest):
@async_client_context.require_version_min(8, 0, 0, -24)
async def test_returns_error_if_no_namespace_provided(self):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
models = [InsertOne(document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
@ -65,7 +65,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_batch_splits_if_num_operations_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
models = []
for _ in range(self.max_write_batch_size + 1):
@ -90,7 +90,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_batch_splits_if_ops_payload_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
models = []
num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1)
@ -126,7 +126,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
event_listeners=[listener],
retryWrites=False,
)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
fail_command = {
"configureFailPoint": "failCommand",
@ -165,7 +165,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_collects_write_errors_across_batches_unordered(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -195,7 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_collects_write_errors_across_batches_ordered(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -225,7 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_handles_cursor_requiring_getMore(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -266,7 +266,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_handles_cursor_requiring_getMore_within_transaction(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -309,7 +309,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_handles_getMore_error(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -363,7 +363,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_returns_error_if_unacknowledged_too_large_insert(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
b_repeated = "b" * self.max_bson_object_size
@ -419,7 +419,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_no_batch_splits_if_new_namespace_is_not_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
num_models, models = await self._setup_namespace_test_models()
models.append(
@ -450,7 +450,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
async def test_batch_splits_if_new_namespace_is_too_large(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
num_models, models = await self._setup_namespace_test_models()
c_repeated = "c" * 200
@ -487,7 +487,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
@async_client_context.require_version_min(8, 0, 0, -24)
async def test_returns_error_if_no_writes_can_be_added_to_ops(self):
client = await async_rs_or_single_client()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
# Document too large.
b_repeated = "b" * self.max_message_size_bytes
@ -512,7 +512,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest):
kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}},
)
client = await async_rs_or_single_client(auto_encryption_opts=opts)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
models = [InsertOne(namespace="db.coll", document={"a": "b"})]
with self.assertRaises(InvalidOperation) as context:
@ -535,7 +535,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
_OVERHEAD = 500
internal_client = await async_rs_or_single_client(timeoutMS=None)
self.addAsyncCleanup(internal_client.aclose)
self.addAsyncCleanup(internal_client.close)
collection = internal_client.db["coll"]
self.addAsyncCleanup(collection.drop)
@ -566,7 +566,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest):
timeoutMS=2000,
w="majority",
)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
await client.admin.command("ping") # Init the client first.
with self.assertRaises(ClientBulkWriteException) as context:
await client.bulk_write(models=models)

View File

@ -1820,7 +1820,7 @@ class AsyncTestCollection(AsyncIntegrationTest):
await self.db.test.insert_many([{"i": i} for i in range(150)])
client = await async_rs_or_single_client(maxPoolSize=1)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
pool = await async_get_pool(client)
# Make sure the socket is returned after exhaustion.

View File

@ -354,7 +354,7 @@ class TestCursor(AsyncIntegrationTest):
# Do not add readConcern level to explain.
listener = AllowListEventListener("explain")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local"))
self.assertTrue(await coll.find().explain())
started = listener.started_events
@ -1262,7 +1262,7 @@ class TestCursor(AsyncIntegrationTest):
listener = AllowListEventListener("killCursors")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client[self.db.name].test_close_kills_cursors
# Add some test data.
@ -1301,7 +1301,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_timeout_kills_cursor_asynchronously(self):
listener = AllowListEventListener("killCursors")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client[self.db.name].test_timeout_kills_cursor
# Add some test data.
@ -1359,7 +1359,7 @@ class TestCursor(AsyncIntegrationTest):
async def test_getMore_does_not_send_readPreference(self):
listener = AllowListEventListener("find", "getMore")
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
# We never send primary read preference so override the default.
coll = client[self.db.name].get_collection(
"test", read_preference=ReadPreference.PRIMARY_PREFERRED
@ -1424,7 +1424,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
await c.drop()
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
await c.insert_many(docs)
batches = await (await c.find_raw_batches()).sort("_id").to_list()
batches = await c.find_raw_batches().sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1440,7 +1440,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
async with client.start_session() as session:
async with await session.start_transaction():
batches = await (
(await client[self.db.name].test.find_raw_batches(session=session)).sort("_id")
client[self.db.name].test.find_raw_batches(session=session).sort("_id")
).to_list()
cmd = listener.started_events[0]
self.assertEqual(cmd.command_name, "find")
@ -1470,9 +1470,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
async with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
):
batches = (
await (await client[self.db.name].test.find_raw_batches()).sort("_id").to_list()
)
batches = await client[self.db.name].test.find_raw_batches().sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1493,7 +1491,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
db = client[self.db.name]
async with client.start_session(snapshot=True) as session:
await db.test.distinct("x", {}, session=session)
batches = await (await db.test.find_raw_batches(session=session)).sort("_id").to_list()
batches = await db.test.find_raw_batches(session=session).sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1504,18 +1502,18 @@ class TestRawBatchCursor(AsyncIntegrationTest):
async def test_explain(self):
c = self.db.test
await c.insert_one({})
explanation = await (await c.find_raw_batches()).explain()
explanation = await c.find_raw_batches().explain()
self.assertIsInstance(explanation, dict)
async def test_empty(self):
await self.db.test.drop()
cursor = await self.db.test.find_raw_batches()
cursor = self.db.test.find_raw_batches()
with self.assertRaises(StopAsyncIteration):
await anext(cursor)
async def test_clone(self):
await self.db.test.insert_one({})
cursor = await self.db.test.find_raw_batches()
cursor = self.db.test.find_raw_batches()
# Copy of a RawBatchCursor is also a RawBatchCursor, not a Cursor.
self.assertIsInstance(await anext(cursor.clone()), bytes)
self.assertIsInstance(await anext(copy.copy(cursor)), bytes)
@ -1525,24 +1523,22 @@ class TestRawBatchCursor(AsyncIntegrationTest):
c = self.db.test
await c.drop()
await c.insert_many({"_id": i} for i in range(200))
result = b"".join(
await (await c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list()
)
result = b"".join(await c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list())
self.assertEqual([{"_id": i} for i in range(200)], decode_all(result))
async def test_server_error(self):
with self.assertRaises(OperationFailure) as exc:
await anext(await self.db.test.find_raw_batches({"x": {"$bad": 1}}))
await anext(self.db.test.find_raw_batches({"x": {"$bad": 1}}))
# The server response was decoded, not left raw.
self.assertIsInstance(exc.exception.details, dict)
async def test_get_item(self):
with self.assertRaises(InvalidOperation):
(await self.db.test.find_raw_batches())[0]
self.db.test.find_raw_batches()[0]
async def test_collation(self):
await anext(await self.db.test.find_raw_batches(collation=Collation("en_US")))
await anext(self.db.test.find_raw_batches(collation=Collation("en_US")))
@async_client_context.require_no_mmap # MMAPv1 does not support read concern
async def test_read_concern(self):
@ -1550,7 +1546,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
{}
)
c = self.db.get_collection("test", read_concern=ReadConcern("majority"))
await anext(await c.find_raw_batches())
await anext(c.find_raw_batches())
async def test_monitoring(self):
listener = EventListener()
@ -1560,7 +1556,7 @@ class TestRawBatchCursor(AsyncIntegrationTest):
await c.insert_many([{"_id": i} for i in range(10)])
listener.reset()
cursor = await c.find_raw_batches(batch_size=4)
cursor = c.find_raw_batches(batch_size=4)
# First raw batch of 4 documents.
await anext(cursor)
@ -1766,7 +1762,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest):
async def test_exhaust_cursor_db_set(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
c = client.pymongo_test.test
await c.delete_many({})
await c.insert_many([{"_id": i} for i in range(3)])

View File

@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest):
async def test_check_exists(self):
listener = OvertCommandListener()
client = await async_rs_or_single_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
db = client[self.db.name]
await db.drop_collection("unique")
await db.create_collection("unique", check_exists=True)

View File

@ -99,7 +99,7 @@ class TestSession(AsyncIntegrationTest):
@classmethod
async def _tearDown_class(cls):
monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands)
await cls.client2.aclose()
await cls.client2.close()
await super()._tearDown_class()
async def asyncSetUp(self):
@ -108,7 +108,7 @@ class TestSession(AsyncIntegrationTest):
self.client = await async_rs_or_single_client(
event_listeners=[self.listener, self.session_checker_listener]
)
self.addAsyncCleanup(self.client.aclose)
self.addAsyncCleanup(self.client.close)
self.db = self.client.pymongo_test
self.initial_lsids = {s["id"] for s in session_ids(self.client)}
@ -295,14 +295,14 @@ class TestSession(AsyncIntegrationTest):
# Closing the client should end all sessions and clear the pool.
self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1)
await client.aclose()
await client.close()
self.assertEqual(len(client._topology._session_pool), 0)
end_sessions = [e for e in listener.started_events if e.command_name == "endSessions"]
self.assertEqual(len(end_sessions), 2)
# Closing again should not send any commands.
listener.reset()
await client.aclose()
await client.close()
self.assertEqual(len(listener.started_events), 0)
async def test_client(self):
@ -790,7 +790,7 @@ class TestSession(AsyncIntegrationTest):
# Ensure the collection exists.
await self.client.pymongo_test.test_unacked_writes.insert_one({})
client = await async_rs_or_single_client(w=0, event_listeners=[self.listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
db = client.pymongo_test
coll = db.test_unacked_writes
ops: list = [
@ -842,7 +842,7 @@ class TestCausalConsistency(AsyncUnitTest):
@classmethod
async def _tearDown_class(cls):
await cls.client.aclose()
await cls.client.close()
@async_client_context.require_sessions
async def asyncSetUp(self):
@ -927,7 +927,7 @@ class TestCausalConsistency(AsyncUnitTest):
return await (await coll.aggregate_raw_batches([], session=session)).to_list()
async def find_raw(coll, session):
return await (await coll.find_raw_batches({}, session=session)).to_list()
return await coll.find_raw_batches({}, session=session).to_list()
await self._test_reads(aggregate)
await self._test_reads(lambda coll, session: coll.find({}, session=session).to_list())
@ -1156,7 +1156,7 @@ class TestClusterTime(AsyncIntegrationTest):
client = await async_rs_or_single_client(
event_listeners=[listener], heartbeatFrequencyMS=999999
)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
collection = client.pymongo_test.collection
# Prepare for tests of find() and aggregate().
await collection.insert_many([{} for _ in range(10)])

View File

@ -74,7 +74,7 @@ class AsyncTransactionsBase(AsyncSpecRunner):
@classmethod
async def _tearDown_class(cls):
for client in cls.mongos_clients:
await client.aclose()
await client.close()
await super()._tearDown_class()
def maybe_skip_scenario(self, test):
@ -121,7 +121,7 @@ class TestTransactions(AsyncTransactionsBase):
async def test_transaction_write_concern_override(self):
"""Test txn overrides Client/Database/Collection write_concern."""
client = await async_rs_client(w=0)
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
db = client.test
coll = db.test
await coll.insert_one({})
@ -183,7 +183,7 @@ class TestTransactions(AsyncTransactionsBase):
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
@ -211,7 +211,7 @@ class TestTransactions(AsyncTransactionsBase):
coll = client.test.test
# Create the collection.
await coll.insert_one({})
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
async with client.start_session() as s:
# Session is pinned to Mongos.
async with await s.start_transaction():
@ -339,7 +339,7 @@ class TestTransactions(AsyncTransactionsBase):
coll = client[self.db.name].test
await coll.delete_many({})
listener.reset()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
self.addAsyncCleanup(coll.drop)
large_str = "\0" * (1 * 1024 * 1024)
ops: List[InsertOne[RawBSONDocument]] = [
@ -365,7 +365,7 @@ class TestTransactions(AsyncTransactionsBase):
@async_client_context.require_transactions
async def test_transaction_direct_connection(self):
client = await async_single_client()
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client.pymongo_test.test
# Make sure the collection exists.
@ -375,6 +375,9 @@ class TestTransactions(AsyncTransactionsBase):
async def find(*args, **kwargs):
return coll.find(*args, **kwargs)
async def find_raw_batches(*args, **kwargs):
return coll.find_raw_batches(*args, **kwargs)
ops = [
(coll.bulk_write, [[InsertOne[dict]({})]]),
(coll.insert_one, [{}]),
@ -393,7 +396,7 @@ class TestTransactions(AsyncTransactionsBase):
(coll.aggregate, [[]]),
(find, [{}]),
(coll.aggregate_raw_batches, [[]]),
(coll.find_raw_batches, [{}]),
(find_raw_batches, [{}]),
(coll.database.command, ["find", coll.name]),
]
for f, args in ops:
@ -452,7 +455,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async def test_callback_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client[self.db.name].test
async def callback(session):
@ -481,7 +484,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async def test_callback_not_retried_after_commit_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client[self.db.name].test
async def callback(session):
@ -516,7 +519,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase):
async def test_commit_not_retried_after_timeout(self):
listener = OvertCommandListener()
client = await async_rs_client(event_listeners=[listener])
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
coll = client[self.db.name].test
async def callback(session):

View File

@ -535,7 +535,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
self.pool_listener = pool_listener
self.server_listener = server_listener
# Close the client explicitly to avoid having too many threads open.
self.addAsyncCleanup(client.aclose)
self.addAsyncCleanup(client.close)
# Create session0 and session1.
sessions = {}

View File

@ -1415,7 +1415,7 @@ class TestRawBatchCursor(IntegrationTest):
c.drop()
docs = [{"_id": i, "x": 3.0 * i} for i in range(10)]
c.insert_many(docs)
batches = (c.find_raw_batches()).sort("_id").to_list()
batches = c.find_raw_batches().sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1431,7 +1431,7 @@ class TestRawBatchCursor(IntegrationTest):
with client.start_session() as session:
with session.start_transaction():
batches = (
(client[self.db.name].test.find_raw_batches(session=session)).sort("_id")
client[self.db.name].test.find_raw_batches(session=session).sort("_id")
).to_list()
cmd = listener.started_events[0]
self.assertEqual(cmd.command_name, "find")
@ -1461,7 +1461,7 @@ class TestRawBatchCursor(IntegrationTest):
with self.fail_point(
{"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}}
):
batches = (client[self.db.name].test.find_raw_batches()).sort("_id").to_list()
batches = client[self.db.name].test.find_raw_batches().sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1482,7 +1482,7 @@ class TestRawBatchCursor(IntegrationTest):
db = client[self.db.name]
with client.start_session(snapshot=True) as session:
db.test.distinct("x", {}, session=session)
batches = (db.test.find_raw_batches(session=session)).sort("_id").to_list()
batches = db.test.find_raw_batches(session=session).sort("_id").to_list()
self.assertEqual(1, len(batches))
self.assertEqual(docs, decode_all(batches[0]))
@ -1493,7 +1493,7 @@ class TestRawBatchCursor(IntegrationTest):
def test_explain(self):
c = self.db.test
c.insert_one({})
explanation = (c.find_raw_batches()).explain()
explanation = c.find_raw_batches().explain()
self.assertIsInstance(explanation, dict)
def test_empty(self):
@ -1514,7 +1514,7 @@ class TestRawBatchCursor(IntegrationTest):
c = self.db.test
c.drop()
c.insert_many({"_id": i} for i in range(200))
result = b"".join((c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list())
result = b"".join(c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list())
self.assertEqual([{"_id": i} for i in range(200)], decode_all(result))
def test_server_error(self):
@ -1526,7 +1526,7 @@ class TestRawBatchCursor(IntegrationTest):
def test_get_item(self):
with self.assertRaises(InvalidOperation):
(self.db.test.find_raw_batches())[0]
self.db.test.find_raw_batches()[0]
def test_collation(self):
next(self.db.test.find_raw_batches(collation=Collation("en_US")))

View File

@ -925,7 +925,7 @@ class TestCausalConsistency(UnitTest):
return (coll.aggregate_raw_batches([], session=session)).to_list()
def find_raw(coll, session):
return (coll.find_raw_batches({}, session=session)).to_list()
return coll.find_raw_batches({}, session=session).to_list()
self._test_reads(aggregate)
self._test_reads(lambda coll, session: coll.find({}, session=session).to_list())

View File

@ -371,6 +371,9 @@ class TestTransactions(TransactionsBase):
def find(*args, **kwargs):
return coll.find(*args, **kwargs)
def find_raw_batches(*args, **kwargs):
return coll.find_raw_batches(*args, **kwargs)
ops = [
(coll.bulk_write, [[InsertOne[dict]({})]]),
(coll.insert_one, [{}]),
@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase):
(coll.aggregate, [[]]),
(find, [{}]),
(coll.aggregate_raw_batches, [[]]),
(coll.find_raw_batches, [{}]),
(find_raw_batches, [{}]),
(coll.database.command, ["find", coll.name]),
]
for f, args in ops:

View File

@ -92,7 +92,6 @@ replacements = {
"asyncAssertRaisesExactly": "assertRaisesExactly",
"get_async_mock_client": "get_mock_client",
"aconnect": "_connect",
"aclose": "close",
"async-transactions-ref": "transactions-ref",
"async-snapshot-reads-ref": "snapshot-reads-ref",
"default_async": "default",