From f69d330b25db33272fcccbe762668c1a7e8833ab Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 13 Aug 2024 19:17:45 -0500 Subject: [PATCH] PYTHON-4654 Clean up Async API to match Motor (#1789) --- doc/api/pymongo/asynchronous/mongo_client.rst | 2 +- pymongo/asynchronous/collection.py | 6 +- pymongo/asynchronous/encryption.py | 4 +- pymongo/asynchronous/mongo_client.py | 8 +- pymongo/synchronous/mongo_client.py | 4 + test/asynchronous/__init__.py | 6 +- test/asynchronous/test_client.py | 80 +++++++++---------- test/asynchronous/test_client_bulk_write.py | 32 ++++---- test/asynchronous/test_collection.py | 2 +- test/asynchronous/test_cursor.py | 40 +++++----- test/asynchronous/test_database.py | 2 +- test/asynchronous/test_session.py | 16 ++-- test/asynchronous/test_transactions.py | 23 +++--- test/asynchronous/utils_spec_runner.py | 2 +- test/test_cursor.py | 14 ++-- test/test_session.py | 2 +- test/test_transactions.py | 5 +- tools/synchro.py | 1 - 18 files changed, 128 insertions(+), 121 deletions(-) diff --git a/doc/api/pymongo/asynchronous/mongo_client.rst b/doc/api/pymongo/asynchronous/mongo_client.rst index 57aa33e3c..afbd802ff 100644 --- a/doc/api/pymongo/asynchronous/mongo_client.rst +++ b/doc/api/pymongo/asynchronous/mongo_client.rst @@ -6,7 +6,7 @@ .. autoclass:: pymongo.asynchronous.mongo_client.AsyncMongoClient(host='localhost', port=27017, document_class=dict, tz_aware=False, connect=True, **kwargs) - .. automethod:: aclose + .. automethod:: close .. describe:: c[db_name] || c.db_name diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index e634b449f..e5a54c090 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -1893,9 +1893,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ return AsyncCursor(self, *args, **kwargs) - async def find_raw_batches( - self, *args: Any, **kwargs: Any - ) -> AsyncRawBatchCursor[_DocumentType]: + def find_raw_batches(self, *args: Any, **kwargs: Any) -> AsyncRawBatchCursor[_DocumentType]: """Query the database and retrieve batches of raw BSON. Similar to the :meth:`find` method but returns a @@ -1907,7 +1905,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :mod:`bson` module. >>> import bson - >>> cursor = await db.test.find_raw_batches() + >>> cursor = db.test.find_raw_batches() >>> async for batch in cursor: ... print(bson.decode_all(batch)) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index f8c03b21c..8b63525f2 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] self.client_ref = None self.key_vault_coll = None if self.mongocryptd_client: - await self.mongocryptd_client.aclose() + await self.mongocryptd_client.close() self.mongocryptd_client = None @@ -439,7 +439,7 @@ class _Encrypter: self._closed = True await self._auto_encrypter.close() if self._internal_client: - await self._internal_client.aclose() + await self._internal_client.close() self._internal_client = None diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 90e40978a..fbbd9a4ee 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1378,7 +1378,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - await self.aclose() + await self.close() # See PYTHON-3084. __iter__ = None @@ -1514,7 +1514,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # command. pass - async def aclose(self) -> None: + async def close(self) -> None: """Cleanup client resources and disconnect from MongoDB. End all server sessions created by this client by sending one or more @@ -1541,6 +1541,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() + if not _IS_SYNC: + # Add support for contextlib.aclosing. + aclose = close + async def _get_topology(self) -> Topology: """Get the internal :class:`~pymongo.asynchronous.topology.Topology` object. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 41b4db4f1..186316562 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1536,6 +1536,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + if not _IS_SYNC: + # Add support for contextlib.closing. + aclose = close + def _get_topology(self) -> Topology: """Get the internal :class:`~pymongo.topology.Topology` object. diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index a95a9e31b..900b260c2 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -193,7 +193,7 @@ class AsyncClientContext: self.connection_attempts.append(f"failed to connect client {client!r}: {exc}") return None finally: - await client.aclose() + await client.close() async def _init_client(self): self.client = await self._connect(host, port) @@ -409,7 +409,7 @@ class AsyncClientContext: else: raise finally: - await client.aclose() + await client.close() def _server_started_with_auth(self): # MongoDB >= 2.0 @@ -1089,7 +1089,7 @@ async def async_teardown(): await c.drop_database("pymongo_test2") await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") - await c.aclose() + await c.close() print_running_clients() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index fa526b3ea..9489de156 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -137,7 +137,7 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _tearDown_class(cls): - await cls.client.aclose() + await cls.client.close() @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -663,7 +663,7 @@ class TestClient(AsyncIntegrationTest): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - await client.aclose() + await client.close() async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -679,7 +679,7 @@ class TestClient(AsyncIntegrationTest): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - await client.aclose() + await client.close() async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -697,7 +697,7 @@ class TestClient(AsyncIntegrationTest): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - await client.aclose() + await client.close() async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -717,7 +717,7 @@ class TestClient(AsyncIntegrationTest): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - await client.aclose() + await client.close() async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): @@ -845,13 +845,13 @@ class TestClient(AsyncIntegrationTest): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) # Seeds differ: @@ -867,10 +867,10 @@ class TestClient(AsyncIntegrationTest): async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertIn(c, {async_client_context.client}) c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -940,7 +940,7 @@ class TestClient(AsyncIntegrationTest): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) client = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -991,7 +991,7 @@ class TestClient(AsyncIntegrationTest): async def test_close(self): test_client = await async_rs_or_single_client() coll = test_client.pymongo_test.bar - await test_client.aclose() + await test_client.close() with self.assertRaises(InvalidOperation): await coll.count_documents({}) @@ -1024,7 +1024,7 @@ class TestClient(AsyncIntegrationTest): # Close the client and ensure the topology is closed. self.assertTrue(test_client._topology._opened) - await test_client.aclose() + await test_client.close() self.assertFalse(test_client._topology._opened) test_client = await async_rs_or_single_client() # The killCursors task should not need to re-open the topology. @@ -1037,7 +1037,7 @@ class TestClient(AsyncIntegrationTest): self.assertFalse(client._kill_cursors_executor._stopped) # Closing the client should stop the thread. - await client.aclose() + await client.close() self.assertTrue(client._kill_cursors_executor._stopped) # Reusing the closed client should raise an InvalidOperation error. @@ -1062,21 +1062,21 @@ class TestClient(AsyncIntegrationTest): self.assertTrue(kc_thread and kc_thread.is_alive()) # Tear down. - await client.aclose() + await client.close() async def test_close_does_not_open_servers(self): client = await async_rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) - await client.aclose() + await client.close() self.assertEqual(topology._servers, {}) async def test_close_closes_sockets(self): client = await async_rs_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test.test.find_one() topology = client._topology - await client.aclose() + await client.close() for server in topology._servers.values(): self.assertFalse(server._pool.conns) self.assertTrue(server._monitor._executor._stopped) @@ -1181,7 +1181,7 @@ class TestClient(AsyncIntegrationTest): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. client = await async_rs_or_single_client(uri) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1206,7 +1206,7 @@ class TestClient(AsyncIntegrationTest): self.assertFalse(isinstance(await db.test.find_one(), SON)) c = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) @@ -1248,7 +1248,7 @@ class TestClient(AsyncIntegrationTest): no_timeout = self.client timeout_sec = 1 timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addAsyncCleanup(timeout.aclose) + self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") await no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1310,7 +1310,7 @@ class TestClient(AsyncIntegrationTest): self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") aware = await async_rs_or_single_client(tz_aware=True) - self.addAsyncCleanup(aware.aclose) + self.addAsyncCleanup(aware.close) naive = self.client await aware.pymongo_test.drop_collection("test") @@ -1340,7 +1340,7 @@ class TestClient(AsyncIntegrationTest): uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") client = await async_rs_or_single_client_noauth(uri) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1442,7 +1442,7 @@ class TestClient(AsyncIntegrationTest): # to avoid race conditions caused by replica set failover or idle # socket reaping. client = await async_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) @@ -1467,7 +1467,7 @@ class TestClient(AsyncIntegrationTest): self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1476,7 +1476,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "find one document") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1485,7 +1485,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "update one document") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1498,7 +1498,7 @@ class TestClient(AsyncIntegrationTest): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1611,7 +1611,7 @@ class TestClient(AsyncIntegrationTest): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - await client.aclose() + await client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1709,7 +1709,7 @@ class TestClient(AsyncIntegrationTest): async def test_reset_during_update_pool(self): client = await async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1758,7 +1758,7 @@ class TestClient(AsyncIntegrationTest): client = await async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Create a single connection in the pool. await client.admin.command("ping") @@ -1793,7 +1793,7 @@ class TestClient(AsyncIntegrationTest): await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - await client.aclose() + await client.close() # direct_connection=False should result in RS topology. client = await async_rs_or_single_client(directConnection=False) @@ -1803,7 +1803,7 @@ class TestClient(AsyncIntegrationTest): client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - await client.aclose() + await client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1827,7 +1827,7 @@ class TestClient(AsyncIntegrationTest): "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() @@ -1841,7 +1841,7 @@ class TestClient(AsyncIntegrationTest): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): client = await async_single_client(retryReads=False) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") # connect async with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1860,7 +1860,7 @@ class TestClient(AsyncIntegrationTest): await cursor.next() c_id = cursor.cursor_id self.assertIsNotNone(c_id) - await client.aclose() + await client.close() # Add cursor to kill cursors queue del cursor wait_until( @@ -2315,7 +2315,7 @@ class TestMongoClientFailover(AsyncMockClientTest): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2342,7 +2342,7 @@ class TestMongoClientFailover(AsyncMockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2377,7 +2377,7 @@ class TestMongoClientFailover(AsyncMockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) @@ -2453,7 +2453,7 @@ class TestClientPool(AsyncMockClientTest): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) @@ -2483,7 +2483,7 @@ class TestClientPool(AsyncMockClientTest): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 4fe4fce2d..eea0b4e8e 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -43,7 +43,7 @@ class TestClientBulkWrite(AsyncIntegrationTest): @async_client_context.require_version_min(8, 0, 0, -24) async def test_returns_error_if_no_namespace_provided(self): client = await async_rs_or_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [InsertOne(document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -65,7 +65,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [] for _ in range(self.max_write_batch_size + 1): @@ -90,7 +90,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -126,7 +126,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): event_listeners=[listener], retryWrites=False, ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -165,7 +165,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -195,7 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -225,7 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -266,7 +266,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -309,7 +309,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_getMore_error(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -363,7 +363,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -419,7 +419,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() models.append( @@ -450,7 +450,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -487,7 +487,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_version_min(8, 0, 0, -24) async def test_returns_error_if_no_writes_can_be_added_to_ops(self): client = await async_rs_or_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -512,7 +512,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -535,7 +535,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): _OVERHEAD = 500 internal_client = await async_rs_or_single_client(timeoutMS=None) - self.addAsyncCleanup(internal_client.aclose) + self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -566,7 +566,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): timeoutMS=2000, w="majority", ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 90a53e8b8..10d64a525 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -1820,7 +1820,7 @@ class AsyncTestCollection(AsyncIntegrationTest): await self.db.test.insert_many([{"i": i} for i in range(150)]) client = await async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) pool = await async_get_pool(client) # Make sure the socket is returned after exhaustion. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d6d56244f..9dd67f2da 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -354,7 +354,7 @@ class TestCursor(AsyncIntegrationTest): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) started = listener.started_events @@ -1262,7 +1262,7 @@ class TestCursor(AsyncIntegrationTest): listener = AllowListEventListener("killCursors") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1301,7 +1301,7 @@ class TestCursor(AsyncIntegrationTest): async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1359,7 +1359,7 @@ class TestCursor(AsyncIntegrationTest): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1424,7 +1424,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.drop() docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] await c.insert_many(docs) - batches = await (await c.find_raw_batches()).sort("_id").to_list() + batches = await c.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1440,7 +1440,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): async with client.start_session() as session: async with await session.start_transaction(): batches = await ( - (await client[self.db.name].test.find_raw_batches(session=session)).sort("_id") + client[self.db.name].test.find_raw_batches(session=session).sort("_id") ).to_list() cmd = listener.started_events[0] self.assertEqual(cmd.command_name, "find") @@ -1470,9 +1470,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): - batches = ( - await (await client[self.db.name].test.find_raw_batches()).sort("_id").to_list() - ) + batches = await client[self.db.name].test.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1493,7 +1491,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) - batches = await (await db.test.find_raw_batches(session=session)).sort("_id").to_list() + batches = await db.test.find_raw_batches(session=session).sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1504,18 +1502,18 @@ class TestRawBatchCursor(AsyncIntegrationTest): async def test_explain(self): c = self.db.test await c.insert_one({}) - explanation = await (await c.find_raw_batches()).explain() + explanation = await c.find_raw_batches().explain() self.assertIsInstance(explanation, dict) async def test_empty(self): await self.db.test.drop() - cursor = await self.db.test.find_raw_batches() + cursor = self.db.test.find_raw_batches() with self.assertRaises(StopAsyncIteration): await anext(cursor) async def test_clone(self): await self.db.test.insert_one({}) - cursor = await self.db.test.find_raw_batches() + cursor = self.db.test.find_raw_batches() # Copy of a RawBatchCursor is also a RawBatchCursor, not a Cursor. self.assertIsInstance(await anext(cursor.clone()), bytes) self.assertIsInstance(await anext(copy.copy(cursor)), bytes) @@ -1525,24 +1523,22 @@ class TestRawBatchCursor(AsyncIntegrationTest): c = self.db.test await c.drop() await c.insert_many({"_id": i} for i in range(200)) - result = b"".join( - await (await c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list() - ) + result = b"".join(await c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list()) self.assertEqual([{"_id": i} for i in range(200)], decode_all(result)) async def test_server_error(self): with self.assertRaises(OperationFailure) as exc: - await anext(await self.db.test.find_raw_batches({"x": {"$bad": 1}})) + await anext(self.db.test.find_raw_batches({"x": {"$bad": 1}})) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) async def test_get_item(self): with self.assertRaises(InvalidOperation): - (await self.db.test.find_raw_batches())[0] + self.db.test.find_raw_batches()[0] async def test_collation(self): - await anext(await self.db.test.find_raw_batches(collation=Collation("en_US"))) + await anext(self.db.test.find_raw_batches(collation=Collation("en_US"))) @async_client_context.require_no_mmap # MMAPv1 does not support read concern async def test_read_concern(self): @@ -1550,7 +1546,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): {} ) c = self.db.get_collection("test", read_concern=ReadConcern("majority")) - await anext(await c.find_raw_batches()) + await anext(c.find_raw_batches()) async def test_monitoring(self): listener = EventListener() @@ -1560,7 +1556,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.insert_many([{"_id": i} for i in range(10)]) listener.reset() - cursor = await c.find_raw_batches(batch_size=4) + cursor = c.find_raw_batches(batch_size=4) # First raw batch of 4 documents. await anext(cursor) @@ -1766,7 +1762,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) c = client.pymongo_test.test await c.delete_many({}) await c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index c20a74d3d..8f6886a2a 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest): async def test_check_exists(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client[self.db.name] await db.drop_collection("unique") await db.create_collection("unique", check_exists=True) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index b8387b7f6..1e1f5659b 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -99,7 +99,7 @@ class TestSession(AsyncIntegrationTest): @classmethod async def _tearDown_class(cls): monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.aclose() + await cls.client2.close() await super()._tearDown_class() async def asyncSetUp(self): @@ -108,7 +108,7 @@ class TestSession(AsyncIntegrationTest): self.client = await async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.aclose) + self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} @@ -295,14 +295,14 @@ class TestSession(AsyncIntegrationTest): # Closing the client should end all sessions and clear the pool. self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1) - await client.aclose() + await client.close() self.assertEqual(len(client._topology._session_pool), 0) end_sessions = [e for e in listener.started_events if e.command_name == "endSessions"] self.assertEqual(len(end_sessions), 2) # Closing again should not send any commands. listener.reset() - await client.aclose() + await client.close() self.assertEqual(len(listener.started_events), 0) async def test_client(self): @@ -790,7 +790,7 @@ class TestSession(AsyncIntegrationTest): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) client = await async_rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -842,7 +842,7 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _tearDown_class(cls): - await cls.client.aclose() + await cls.client.close() @async_client_context.require_sessions async def asyncSetUp(self): @@ -927,7 +927,7 @@ class TestCausalConsistency(AsyncUnitTest): return await (await coll.aggregate_raw_batches([], session=session)).to_list() async def find_raw(coll, session): - return await (await coll.find_raw_batches({}, session=session)).to_list() + return await coll.find_raw_batches({}, session=session).to_list() await self._test_reads(aggregate) await self._test_reads(lambda coll, session: coll.find({}, session=session).to_list()) @@ -1156,7 +1156,7 @@ class TestClusterTime(AsyncIntegrationTest): client = await async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 373fda588..8fa1e70d0 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -74,7 +74,7 @@ class AsyncTransactionsBase(AsyncSpecRunner): @classmethod async def _tearDown_class(cls): for client in cls.mongos_clients: - await client.aclose() + await client.close() await super()._tearDown_class() def maybe_skip_scenario(self, test): @@ -121,7 +121,7 @@ class TestTransactions(AsyncTransactionsBase): async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = await async_rs_client(w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client.test coll = db.test await coll.insert_one({}) @@ -183,7 +183,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -211,7 +211,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -339,7 +339,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client[self.db.name].test await coll.delete_many({}) listener.reset() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -365,7 +365,7 @@ class TestTransactions(AsyncTransactionsBase): @async_client_context.require_transactions async def test_transaction_direct_connection(self): client = await async_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client.pymongo_test.test # Make sure the collection exists. @@ -375,6 +375,9 @@ class TestTransactions(AsyncTransactionsBase): async def find(*args, **kwargs): return coll.find(*args, **kwargs) + async def find_raw_batches(*args, **kwargs): + return coll.find_raw_batches(*args, **kwargs) + ops = [ (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), @@ -393,7 +396,7 @@ class TestTransactions(AsyncTransactionsBase): (coll.aggregate, [[]]), (find, [{}]), (coll.aggregate_raw_batches, [[]]), - (coll.find_raw_batches, [{}]), + (find_raw_batches, [{}]), (coll.database.command, ["find", coll.name]), ] for f, args in ops: @@ -452,7 +455,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -481,7 +484,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -516,7 +519,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 2e256ec17..71044d153 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -535,7 +535,7 @@ class AsyncSpecRunner(AsyncIntegrationTest): self.pool_listener = pool_listener self.server_listener = server_listener # Close the client explicitly to avoid having too many threads open. - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/test/test_cursor.py b/test/test_cursor.py index 0d6186519..1cde4718f 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1415,7 +1415,7 @@ class TestRawBatchCursor(IntegrationTest): c.drop() docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) - batches = (c.find_raw_batches()).sort("_id").to_list() + batches = c.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1431,7 +1431,7 @@ class TestRawBatchCursor(IntegrationTest): with client.start_session() as session: with session.start_transaction(): batches = ( - (client[self.db.name].test.find_raw_batches(session=session)).sort("_id") + client[self.db.name].test.find_raw_batches(session=session).sort("_id") ).to_list() cmd = listener.started_events[0] self.assertEqual(cmd.command_name, "find") @@ -1461,7 +1461,7 @@ class TestRawBatchCursor(IntegrationTest): with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): - batches = (client[self.db.name].test.find_raw_batches()).sort("_id").to_list() + batches = client[self.db.name].test.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1482,7 +1482,7 @@ class TestRawBatchCursor(IntegrationTest): db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) - batches = (db.test.find_raw_batches(session=session)).sort("_id").to_list() + batches = db.test.find_raw_batches(session=session).sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1493,7 +1493,7 @@ class TestRawBatchCursor(IntegrationTest): def test_explain(self): c = self.db.test c.insert_one({}) - explanation = (c.find_raw_batches()).explain() + explanation = c.find_raw_batches().explain() self.assertIsInstance(explanation, dict) def test_empty(self): @@ -1514,7 +1514,7 @@ class TestRawBatchCursor(IntegrationTest): c = self.db.test c.drop() c.insert_many({"_id": i} for i in range(200)) - result = b"".join((c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list()) + result = b"".join(c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list()) self.assertEqual([{"_id": i} for i in range(200)], decode_all(result)) def test_server_error(self): @@ -1526,7 +1526,7 @@ class TestRawBatchCursor(IntegrationTest): def test_get_item(self): with self.assertRaises(InvalidOperation): - (self.db.test.find_raw_batches())[0] + self.db.test.find_raw_batches()[0] def test_collation(self): next(self.db.test.find_raw_batches(collation=Collation("en_US"))) diff --git a/test/test_session.py b/test/test_session.py index dfc894804..563b33c70 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -925,7 +925,7 @@ class TestCausalConsistency(UnitTest): return (coll.aggregate_raw_batches([], session=session)).to_list() def find_raw(coll, session): - return (coll.find_raw_batches({}, session=session)).to_list() + return coll.find_raw_batches({}, session=session).to_list() self._test_reads(aggregate) self._test_reads(lambda coll, session: coll.find({}, session=session).to_list()) diff --git a/test/test_transactions.py b/test/test_transactions.py index 4ee018647..b1869bec7 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -371,6 +371,9 @@ class TestTransactions(TransactionsBase): def find(*args, **kwargs): return coll.find(*args, **kwargs) + def find_raw_batches(*args, **kwargs): + return coll.find_raw_batches(*args, **kwargs) + ops = [ (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), @@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase): (coll.aggregate, [[]]), (find, [{}]), (coll.aggregate_raw_batches, [[]]), - (coll.find_raw_batches, [{}]), + (find_raw_batches, [{}]), (coll.database.command, ["find", coll.name]), ] for f, args in ops: diff --git a/tools/synchro.py b/tools/synchro.py index 5711e1f84..0c2aff130 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -92,7 +92,6 @@ replacements = { "asyncAssertRaisesExactly": "assertRaisesExactly", "get_async_mock_client": "get_mock_client", "aconnect": "_connect", - "aclose": "close", "async-transactions-ref": "transactions-ref", "async-snapshot-reads-ref": "snapshot-reads-ref", "default_async": "default",