From f69d330b25db33272fcccbe762668c1a7e8833ab Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Tue, 13 Aug 2024 19:17:45 -0500 Subject: [PATCH 1/3] PYTHON-4654 Clean up Async API to match Motor (#1789) --- doc/api/pymongo/asynchronous/mongo_client.rst | 2 +- pymongo/asynchronous/collection.py | 6 +- pymongo/asynchronous/encryption.py | 4 +- pymongo/asynchronous/mongo_client.py | 8 +- pymongo/synchronous/mongo_client.py | 4 + test/asynchronous/__init__.py | 6 +- test/asynchronous/test_client.py | 80 +++++++++---------- test/asynchronous/test_client_bulk_write.py | 32 ++++---- test/asynchronous/test_collection.py | 2 +- test/asynchronous/test_cursor.py | 40 +++++----- test/asynchronous/test_database.py | 2 +- test/asynchronous/test_session.py | 16 ++-- test/asynchronous/test_transactions.py | 23 +++--- test/asynchronous/utils_spec_runner.py | 2 +- test/test_cursor.py | 14 ++-- test/test_session.py | 2 +- test/test_transactions.py | 5 +- tools/synchro.py | 1 - 18 files changed, 128 insertions(+), 121 deletions(-) diff --git a/doc/api/pymongo/asynchronous/mongo_client.rst b/doc/api/pymongo/asynchronous/mongo_client.rst index 57aa33e3c..afbd802ff 100644 --- a/doc/api/pymongo/asynchronous/mongo_client.rst +++ b/doc/api/pymongo/asynchronous/mongo_client.rst @@ -6,7 +6,7 @@ .. autoclass:: pymongo.asynchronous.mongo_client.AsyncMongoClient(host='localhost', port=27017, document_class=dict, tz_aware=False, connect=True, **kwargs) - .. automethod:: aclose + .. automethod:: close .. describe:: c[db_name] || c.db_name diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index e634b449f..e5a54c090 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -1893,9 +1893,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): """ return AsyncCursor(self, *args, **kwargs) - async def find_raw_batches( - self, *args: Any, **kwargs: Any - ) -> AsyncRawBatchCursor[_DocumentType]: + def find_raw_batches(self, *args: Any, **kwargs: Any) -> AsyncRawBatchCursor[_DocumentType]: """Query the database and retrieve batches of raw BSON. Similar to the :meth:`find` method but returns a @@ -1907,7 +1905,7 @@ class AsyncCollection(common.BaseObject, Generic[_DocumentType]): :mod:`bson` module. >>> import bson - >>> cursor = await db.test.find_raw_batches() + >>> cursor = db.test.find_raw_batches() >>> async for batch in cursor: ... print(bson.decode_all(batch)) diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index f8c03b21c..8b63525f2 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -299,7 +299,7 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc] self.client_ref = None self.key_vault_coll = None if self.mongocryptd_client: - await self.mongocryptd_client.aclose() + await self.mongocryptd_client.close() self.mongocryptd_client = None @@ -439,7 +439,7 @@ class _Encrypter: self._closed = True await self._auto_encrypter.close() if self._internal_client: - await self._internal_client.aclose() + await self._internal_client.close() self._internal_client = None diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 90e40978a..fbbd9a4ee 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1378,7 +1378,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - await self.aclose() + await self.close() # See PYTHON-3084. __iter__ = None @@ -1514,7 +1514,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # command. pass - async def aclose(self) -> None: + async def close(self) -> None: """Cleanup client resources and disconnect from MongoDB. End all server sessions created by this client by sending one or more @@ -1541,6 +1541,10 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() + if not _IS_SYNC: + # Add support for contextlib.aclosing. + aclose = close + async def _get_topology(self) -> Topology: """Get the internal :class:`~pymongo.asynchronous.topology.Topology` object. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 41b4db4f1..186316562 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1536,6 +1536,10 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + if not _IS_SYNC: + # Add support for contextlib.closing. + aclose = close + def _get_topology(self) -> Topology: """Get the internal :class:`~pymongo.topology.Topology` object. diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index a95a9e31b..900b260c2 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -193,7 +193,7 @@ class AsyncClientContext: self.connection_attempts.append(f"failed to connect client {client!r}: {exc}") return None finally: - await client.aclose() + await client.close() async def _init_client(self): self.client = await self._connect(host, port) @@ -409,7 +409,7 @@ class AsyncClientContext: else: raise finally: - await client.aclose() + await client.close() def _server_started_with_auth(self): # MongoDB >= 2.0 @@ -1089,7 +1089,7 @@ async def async_teardown(): await c.drop_database("pymongo_test2") await c.drop_database("pymongo_test_mike") await c.drop_database("pymongo_test_bernie") - await c.aclose() + await c.close() print_running_clients() diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index fa526b3ea..9489de156 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -137,7 +137,7 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _tearDown_class(cls): - await cls.client.aclose() + await cls.client.close() @pytest.fixture(autouse=True) def inject_fixtures(self, caplog): @@ -663,7 +663,7 @@ class TestClient(AsyncIntegrationTest): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - await client.aclose() + await client.close() async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -679,7 +679,7 @@ class TestClient(AsyncIntegrationTest): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - await client.aclose() + await client.close() async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -697,7 +697,7 @@ class TestClient(AsyncIntegrationTest): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - await client.aclose() + await client.close() async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -717,7 +717,7 @@ class TestClient(AsyncIntegrationTest): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - await client.aclose() + await client.close() async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): @@ -845,13 +845,13 @@ class TestClient(AsyncIntegrationTest): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) # Seeds differ: @@ -867,10 +867,10 @@ class TestClient(AsyncIntegrationTest): async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertIn(c, {async_client_context.client}) c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -940,7 +940,7 @@ class TestClient(AsyncIntegrationTest): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) client = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -991,7 +991,7 @@ class TestClient(AsyncIntegrationTest): async def test_close(self): test_client = await async_rs_or_single_client() coll = test_client.pymongo_test.bar - await test_client.aclose() + await test_client.close() with self.assertRaises(InvalidOperation): await coll.count_documents({}) @@ -1024,7 +1024,7 @@ class TestClient(AsyncIntegrationTest): # Close the client and ensure the topology is closed. self.assertTrue(test_client._topology._opened) - await test_client.aclose() + await test_client.close() self.assertFalse(test_client._topology._opened) test_client = await async_rs_or_single_client() # The killCursors task should not need to re-open the topology. @@ -1037,7 +1037,7 @@ class TestClient(AsyncIntegrationTest): self.assertFalse(client._kill_cursors_executor._stopped) # Closing the client should stop the thread. - await client.aclose() + await client.close() self.assertTrue(client._kill_cursors_executor._stopped) # Reusing the closed client should raise an InvalidOperation error. @@ -1062,21 +1062,21 @@ class TestClient(AsyncIntegrationTest): self.assertTrue(kc_thread and kc_thread.is_alive()) # Tear down. - await client.aclose() + await client.close() async def test_close_does_not_open_servers(self): client = await async_rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) - await client.aclose() + await client.close() self.assertEqual(topology._servers, {}) async def test_close_closes_sockets(self): client = await async_rs_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test.test.find_one() topology = client._topology - await client.aclose() + await client.close() for server in topology._servers.values(): self.assertFalse(server._pool.conns) self.assertTrue(server._monitor._executor._stopped) @@ -1181,7 +1181,7 @@ class TestClient(AsyncIntegrationTest): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. client = await async_rs_or_single_client(uri) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1206,7 +1206,7 @@ class TestClient(AsyncIntegrationTest): self.assertFalse(isinstance(await db.test.find_one(), SON)) c = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) @@ -1248,7 +1248,7 @@ class TestClient(AsyncIntegrationTest): no_timeout = self.client timeout_sec = 1 timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) - self.addAsyncCleanup(timeout.aclose) + self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") await no_timeout.pymongo_test.test.insert_one({"x": 1}) @@ -1310,7 +1310,7 @@ class TestClient(AsyncIntegrationTest): self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") aware = await async_rs_or_single_client(tz_aware=True) - self.addAsyncCleanup(aware.aclose) + self.addAsyncCleanup(aware.close) naive = self.client await aware.pymongo_test.drop_collection("test") @@ -1340,7 +1340,7 @@ class TestClient(AsyncIntegrationTest): uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") client = await async_rs_or_single_client_noauth(uri) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1442,7 +1442,7 @@ class TestClient(AsyncIntegrationTest): # to avoid race conditions caused by replica set failover or idle # socket reaping. client = await async_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) @@ -1467,7 +1467,7 @@ class TestClient(AsyncIntegrationTest): self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1476,7 +1476,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "find one document") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1485,7 +1485,7 @@ class TestClient(AsyncIntegrationTest): await async_wait_until(predicate, "update one document") client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1498,7 +1498,7 @@ class TestClient(AsyncIntegrationTest): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1611,7 +1611,7 @@ class TestClient(AsyncIntegrationTest): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - await client.aclose() + await client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1709,7 +1709,7 @@ class TestClient(AsyncIntegrationTest): async def test_reset_during_update_pool(self): client = await async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1758,7 +1758,7 @@ class TestClient(AsyncIntegrationTest): client = await async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Create a single connection in the pool. await client.admin.command("ping") @@ -1793,7 +1793,7 @@ class TestClient(AsyncIntegrationTest): await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - await client.aclose() + await client.close() # direct_connection=False should result in RS topology. client = await async_rs_or_single_client(directConnection=False) @@ -1803,7 +1803,7 @@ class TestClient(AsyncIntegrationTest): client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - await client.aclose() + await client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1827,7 +1827,7 @@ class TestClient(AsyncIntegrationTest): "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() @@ -1841,7 +1841,7 @@ class TestClient(AsyncIntegrationTest): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): client = await async_single_client(retryReads=False) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") # connect async with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1860,7 +1860,7 @@ class TestClient(AsyncIntegrationTest): await cursor.next() c_id = cursor.cursor_id self.assertIsNotNone(c_id) - await client.aclose() + await client.close() # Add cursor to kill cursors queue del cursor wait_until( @@ -2315,7 +2315,7 @@ class TestMongoClientFailover(AsyncMockClientTest): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2342,7 +2342,7 @@ class TestMongoClientFailover(AsyncMockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2377,7 +2377,7 @@ class TestMongoClientFailover(AsyncMockClientTest): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) @@ -2453,7 +2453,7 @@ class TestClientPool(AsyncMockClientTest): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) @@ -2483,7 +2483,7 @@ class TestClientPool(AsyncMockClientTest): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.aclose) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 4fe4fce2d..eea0b4e8e 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -43,7 +43,7 @@ class TestClientBulkWrite(AsyncIntegrationTest): @async_client_context.require_version_min(8, 0, 0, -24) async def test_returns_error_if_no_namespace_provided(self): client = await async_rs_or_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [InsertOne(document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -65,7 +65,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [] for _ in range(self.max_write_batch_size + 1): @@ -90,7 +90,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -126,7 +126,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): event_listeners=[listener], retryWrites=False, ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -165,7 +165,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -195,7 +195,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -225,7 +225,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -266,7 +266,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -309,7 +309,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_handles_getMore_error(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -363,7 +363,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -419,7 +419,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() models.append( @@ -450,7 +450,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -487,7 +487,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): @async_client_context.require_version_min(8, 0, 0, -24) async def test_returns_error_if_no_writes_can_be_added_to_ops(self): client = await async_rs_or_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -512,7 +512,7 @@ class TestClientBulkWriteCRUD(AsyncIntegrationTest): kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -535,7 +535,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): _OVERHEAD = 500 internal_client = await async_rs_or_single_client(timeoutMS=None) - self.addAsyncCleanup(internal_client.aclose) + self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -566,7 +566,7 @@ class TestClientBulkWriteCSOT(AsyncIntegrationTest): timeoutMS=2000, w="majority", ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) await client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 90a53e8b8..10d64a525 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -1820,7 +1820,7 @@ class AsyncTestCollection(AsyncIntegrationTest): await self.db.test.insert_many([{"i": i} for i in range(150)]) client = await async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) pool = await async_get_pool(client) # Make sure the socket is returned after exhaustion. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index d6d56244f..9dd67f2da 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -354,7 +354,7 @@ class TestCursor(AsyncIntegrationTest): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) started = listener.started_events @@ -1262,7 +1262,7 @@ class TestCursor(AsyncIntegrationTest): listener = AllowListEventListener("killCursors") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1301,7 +1301,7 @@ class TestCursor(AsyncIntegrationTest): async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1359,7 +1359,7 @@ class TestCursor(AsyncIntegrationTest): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1424,7 +1424,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.drop() docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] await c.insert_many(docs) - batches = await (await c.find_raw_batches()).sort("_id").to_list() + batches = await c.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1440,7 +1440,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): async with client.start_session() as session: async with await session.start_transaction(): batches = await ( - (await client[self.db.name].test.find_raw_batches(session=session)).sort("_id") + client[self.db.name].test.find_raw_batches(session=session).sort("_id") ).to_list() cmd = listener.started_events[0] self.assertEqual(cmd.command_name, "find") @@ -1470,9 +1470,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): - batches = ( - await (await client[self.db.name].test.find_raw_batches()).sort("_id").to_list() - ) + batches = await client[self.db.name].test.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1493,7 +1491,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) - batches = await (await db.test.find_raw_batches(session=session)).sort("_id").to_list() + batches = await db.test.find_raw_batches(session=session).sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1504,18 +1502,18 @@ class TestRawBatchCursor(AsyncIntegrationTest): async def test_explain(self): c = self.db.test await c.insert_one({}) - explanation = await (await c.find_raw_batches()).explain() + explanation = await c.find_raw_batches().explain() self.assertIsInstance(explanation, dict) async def test_empty(self): await self.db.test.drop() - cursor = await self.db.test.find_raw_batches() + cursor = self.db.test.find_raw_batches() with self.assertRaises(StopAsyncIteration): await anext(cursor) async def test_clone(self): await self.db.test.insert_one({}) - cursor = await self.db.test.find_raw_batches() + cursor = self.db.test.find_raw_batches() # Copy of a RawBatchCursor is also a RawBatchCursor, not a Cursor. self.assertIsInstance(await anext(cursor.clone()), bytes) self.assertIsInstance(await anext(copy.copy(cursor)), bytes) @@ -1525,24 +1523,22 @@ class TestRawBatchCursor(AsyncIntegrationTest): c = self.db.test await c.drop() await c.insert_many({"_id": i} for i in range(200)) - result = b"".join( - await (await c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list() - ) + result = b"".join(await c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list()) self.assertEqual([{"_id": i} for i in range(200)], decode_all(result)) async def test_server_error(self): with self.assertRaises(OperationFailure) as exc: - await anext(await self.db.test.find_raw_batches({"x": {"$bad": 1}})) + await anext(self.db.test.find_raw_batches({"x": {"$bad": 1}})) # The server response was decoded, not left raw. self.assertIsInstance(exc.exception.details, dict) async def test_get_item(self): with self.assertRaises(InvalidOperation): - (await self.db.test.find_raw_batches())[0] + self.db.test.find_raw_batches()[0] async def test_collation(self): - await anext(await self.db.test.find_raw_batches(collation=Collation("en_US"))) + await anext(self.db.test.find_raw_batches(collation=Collation("en_US"))) @async_client_context.require_no_mmap # MMAPv1 does not support read concern async def test_read_concern(self): @@ -1550,7 +1546,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): {} ) c = self.db.get_collection("test", read_concern=ReadConcern("majority")) - await anext(await c.find_raw_batches()) + await anext(c.find_raw_batches()) async def test_monitoring(self): listener = EventListener() @@ -1560,7 +1556,7 @@ class TestRawBatchCursor(AsyncIntegrationTest): await c.insert_many([{"_id": i} for i in range(10)]) listener.reset() - cursor = await c.find_raw_batches(batch_size=4) + cursor = c.find_raw_batches(batch_size=4) # First raw batch of 4 documents. await anext(cursor) @@ -1766,7 +1762,7 @@ class TestRawBatchCommandCursor(AsyncIntegrationTest): async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) c = client.pymongo_test.test await c.delete_many({}) await c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index c20a74d3d..8f6886a2a 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -236,7 +236,7 @@ class TestDatabase(AsyncIntegrationTest): async def test_check_exists(self): listener = OvertCommandListener() client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client[self.db.name] await db.drop_collection("unique") await db.create_collection("unique", check_exists=True) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index b8387b7f6..1e1f5659b 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -99,7 +99,7 @@ class TestSession(AsyncIntegrationTest): @classmethod async def _tearDown_class(cls): monitoring._SENSITIVE_COMMANDS.update(cls.sensitive_commands) - await cls.client2.aclose() + await cls.client2.close() await super()._tearDown_class() async def asyncSetUp(self): @@ -108,7 +108,7 @@ class TestSession(AsyncIntegrationTest): self.client = await async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) - self.addAsyncCleanup(self.client.aclose) + self.addAsyncCleanup(self.client.close) self.db = self.client.pymongo_test self.initial_lsids = {s["id"] for s in session_ids(self.client)} @@ -295,14 +295,14 @@ class TestSession(AsyncIntegrationTest): # Closing the client should end all sessions and clear the pool. self.assertEqual(len(client._topology._session_pool), _MAX_END_SESSIONS + 1) - await client.aclose() + await client.close() self.assertEqual(len(client._topology._session_pool), 0) end_sessions = [e for e in listener.started_events if e.command_name == "endSessions"] self.assertEqual(len(end_sessions), 2) # Closing again should not send any commands. listener.reset() - await client.aclose() + await client.close() self.assertEqual(len(listener.started_events), 0) async def test_client(self): @@ -790,7 +790,7 @@ class TestSession(AsyncIntegrationTest): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) client = await async_rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -842,7 +842,7 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _tearDown_class(cls): - await cls.client.aclose() + await cls.client.close() @async_client_context.require_sessions async def asyncSetUp(self): @@ -927,7 +927,7 @@ class TestCausalConsistency(AsyncUnitTest): return await (await coll.aggregate_raw_batches([], session=session)).to_list() async def find_raw(coll, session): - return await (await coll.find_raw_batches({}, session=session)).to_list() + return await coll.find_raw_batches({}, session=session).to_list() await self._test_reads(aggregate) await self._test_reads(lambda coll, session: coll.find({}, session=session).to_list()) @@ -1156,7 +1156,7 @@ class TestClusterTime(AsyncIntegrationTest): client = await async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 373fda588..8fa1e70d0 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -74,7 +74,7 @@ class AsyncTransactionsBase(AsyncSpecRunner): @classmethod async def _tearDown_class(cls): for client in cls.mongos_clients: - await client.aclose() + await client.close() await super()._tearDown_class() def maybe_skip_scenario(self, test): @@ -121,7 +121,7 @@ class TestTransactions(AsyncTransactionsBase): async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = await async_rs_client(w=0) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) db = client.test coll = db.test await coll.insert_one({}) @@ -183,7 +183,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -211,7 +211,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -339,7 +339,7 @@ class TestTransactions(AsyncTransactionsBase): coll = client[self.db.name].test await coll.delete_many({}) listener.reset() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -365,7 +365,7 @@ class TestTransactions(AsyncTransactionsBase): @async_client_context.require_transactions async def test_transaction_direct_connection(self): client = await async_single_client() - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client.pymongo_test.test # Make sure the collection exists. @@ -375,6 +375,9 @@ class TestTransactions(AsyncTransactionsBase): async def find(*args, **kwargs): return coll.find(*args, **kwargs) + async def find_raw_batches(*args, **kwargs): + return coll.find_raw_batches(*args, **kwargs) + ops = [ (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), @@ -393,7 +396,7 @@ class TestTransactions(AsyncTransactionsBase): (coll.aggregate, [[]]), (find, [{}]), (coll.aggregate_raw_batches, [[]]), - (coll.find_raw_batches, [{}]), + (find_raw_batches, [{}]), (coll.database.command, ["find", coll.name]), ] for f, args in ops: @@ -452,7 +455,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -481,7 +484,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -516,7 +519,7 @@ class TestTransactionsConvenientAPI(AsyncTransactionsBase): async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 2e256ec17..71044d153 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -535,7 +535,7 @@ class AsyncSpecRunner(AsyncIntegrationTest): self.pool_listener = pool_listener self.server_listener = server_listener # Close the client explicitly to avoid having too many threads open. - self.addAsyncCleanup(client.aclose) + self.addAsyncCleanup(client.close) # Create session0 and session1. sessions = {} diff --git a/test/test_cursor.py b/test/test_cursor.py index 0d6186519..1cde4718f 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1415,7 +1415,7 @@ class TestRawBatchCursor(IntegrationTest): c.drop() docs = [{"_id": i, "x": 3.0 * i} for i in range(10)] c.insert_many(docs) - batches = (c.find_raw_batches()).sort("_id").to_list() + batches = c.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1431,7 +1431,7 @@ class TestRawBatchCursor(IntegrationTest): with client.start_session() as session: with session.start_transaction(): batches = ( - (client[self.db.name].test.find_raw_batches(session=session)).sort("_id") + client[self.db.name].test.find_raw_batches(session=session).sort("_id") ).to_list() cmd = listener.started_events[0] self.assertEqual(cmd.command_name, "find") @@ -1461,7 +1461,7 @@ class TestRawBatchCursor(IntegrationTest): with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): - batches = (client[self.db.name].test.find_raw_batches()).sort("_id").to_list() + batches = client[self.db.name].test.find_raw_batches().sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1482,7 +1482,7 @@ class TestRawBatchCursor(IntegrationTest): db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) - batches = (db.test.find_raw_batches(session=session)).sort("_id").to_list() + batches = db.test.find_raw_batches(session=session).sort("_id").to_list() self.assertEqual(1, len(batches)) self.assertEqual(docs, decode_all(batches[0])) @@ -1493,7 +1493,7 @@ class TestRawBatchCursor(IntegrationTest): def test_explain(self): c = self.db.test c.insert_one({}) - explanation = (c.find_raw_batches()).explain() + explanation = c.find_raw_batches().explain() self.assertIsInstance(explanation, dict) def test_empty(self): @@ -1514,7 +1514,7 @@ class TestRawBatchCursor(IntegrationTest): c = self.db.test c.drop() c.insert_many({"_id": i} for i in range(200)) - result = b"".join((c.find_raw_batches(cursor_type=CursorType.EXHAUST)).to_list()) + result = b"".join(c.find_raw_batches(cursor_type=CursorType.EXHAUST).to_list()) self.assertEqual([{"_id": i} for i in range(200)], decode_all(result)) def test_server_error(self): @@ -1526,7 +1526,7 @@ class TestRawBatchCursor(IntegrationTest): def test_get_item(self): with self.assertRaises(InvalidOperation): - (self.db.test.find_raw_batches())[0] + self.db.test.find_raw_batches()[0] def test_collation(self): next(self.db.test.find_raw_batches(collation=Collation("en_US"))) diff --git a/test/test_session.py b/test/test_session.py index dfc894804..563b33c70 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -925,7 +925,7 @@ class TestCausalConsistency(UnitTest): return (coll.aggregate_raw_batches([], session=session)).to_list() def find_raw(coll, session): - return (coll.find_raw_batches({}, session=session)).to_list() + return coll.find_raw_batches({}, session=session).to_list() self._test_reads(aggregate) self._test_reads(lambda coll, session: coll.find({}, session=session).to_list()) diff --git a/test/test_transactions.py b/test/test_transactions.py index 4ee018647..b1869bec7 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -371,6 +371,9 @@ class TestTransactions(TransactionsBase): def find(*args, **kwargs): return coll.find(*args, **kwargs) + def find_raw_batches(*args, **kwargs): + return coll.find_raw_batches(*args, **kwargs) + ops = [ (coll.bulk_write, [[InsertOne[dict]({})]]), (coll.insert_one, [{}]), @@ -389,7 +392,7 @@ class TestTransactions(TransactionsBase): (coll.aggregate, [[]]), (find, [{}]), (coll.aggregate_raw_batches, [[]]), - (coll.find_raw_batches, [{}]), + (find_raw_batches, [{}]), (coll.database.command, ["find", coll.name]), ] for f, args in ops: diff --git a/tools/synchro.py b/tools/synchro.py index 5711e1f84..0c2aff130 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -92,7 +92,6 @@ replacements = { "asyncAssertRaisesExactly": "assertRaisesExactly", "get_async_mock_client": "get_mock_client", "aconnect": "_connect", - "aclose": "close", "async-transactions-ref": "transactions-ref", "async-snapshot-reads-ref": "snapshot-reads-ref", "default_async": "default", From f2f75fc1c8716622b7e04fcbbc327dc000afec02 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Tue, 13 Aug 2024 18:32:48 -0700 Subject: [PATCH 2/3] PYTHON-4659 Fix async with TLS (#1793) --- pymongo/network_layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index f1c378b9b..d99b4fee4 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -77,6 +77,7 @@ async def async_sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> Non async def _async_sendall_ssl( sock: Union[socket.socket, _sslConn], buf: bytes, loop: AbstractEventLoop ) -> None: + view = memoryview(buf) fd = sock.fileno() sent = 0 @@ -89,7 +90,7 @@ async def _async_sendall_ssl( while sent < len(buf): try: - sent += sock.send(buf) + sent += sock.send(view[sent:]) except BLOCKING_IO_ERRORS as exc: fd = sock.fileno() # Check for closed socket. From adf8817df8f3fd47d82cf4c5e3bd4e93cfdfb602 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Wed, 14 Aug 2024 13:13:56 -0500 Subject: [PATCH 3/3] PYTHON-4584 Add length option to Cursor.to_list for motor compat (#1791) --- gridfs/asynchronous/grid_file.py | 12 ++++++++++-- gridfs/synchronous/grid_file.py | 12 ++++++++++-- pymongo/asynchronous/command_cursor.py | 27 ++++++++++++++++++++------ pymongo/asynchronous/cursor.py | 27 ++++++++++++++++++++------ pymongo/synchronous/command_cursor.py | 27 ++++++++++++++++++++------ pymongo/synchronous/cursor.py | 27 ++++++++++++++++++++------ test/asynchronous/test_cursor.py | 27 ++++++++++++++++++++++++++ test/test_cursor.py | 27 ++++++++++++++++++++++++++ test/test_gridfs.py | 6 ++++++ 9 files changed, 164 insertions(+), 28 deletions(-) diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py index 303abe705..4d6140750 100644 --- a/gridfs/asynchronous/grid_file.py +++ b/gridfs/asynchronous/grid_file.py @@ -1892,8 +1892,16 @@ class AsyncGridOutCursor(AsyncCursor): next_file = await super().next() return AsyncGridOut(self._root_collection, file_document=next_file, session=self.session) - async def to_list(self) -> list[AsyncGridOut]: - return [x async for x in self] # noqa: C416,RUF100 + async def to_list(self, length: Optional[int] = None) -> list[AsyncGridOut]: + """Convert the cursor to a list.""" + if length is None: + return [x async for x in self] # noqa: C416,RUF100 + if length < 1: + raise ValueError("to_list() length must be greater than 0") + ret = [] + for _ in range(length): + ret.append(await self.next()) + return ret __anext__ = next diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index 1e3d265d4..bc2e29a61 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -1878,8 +1878,16 @@ class GridOutCursor(Cursor): next_file = super().next() return GridOut(self._root_collection, file_document=next_file, session=self.session) - def to_list(self) -> list[GridOut]: - return [x for x in self] # noqa: C416,RUF100 + def to_list(self, length: Optional[int] = None) -> list[GridOut]: + """Convert the cursor to a list.""" + if length is None: + return [x for x in self] # noqa: C416,RUF100 + if length < 1: + raise ValueError("to_list() length must be greater than 0") + ret = [] + for _ in range(length): + ret.append(self.next()) + return ret __next__ = next diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index b28f983b1..b2cd345f6 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -346,13 +346,17 @@ class AsyncCommandCursor(Generic[_DocumentType]): else: return None - async def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: await self._refresh() if len(self._data): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -381,21 +385,32 @@ class AsyncCommandCursor(Generic[_DocumentType]): async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() - async def to_list(self) -> list[_DocumentType]: + async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. To use:: >>> await cursor.to_list() + Or, so read at most n items from the cursor:: + + >>> await cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not await self._next_batch(res): + if not await self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 8421667be..bae77bb30 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -1260,16 +1260,20 @@ class AsyncCursor(Generic[_DocumentType]): else: raise StopAsyncIteration - async def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + async def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True await self._supports_exhaust() if self._empty: return False if len(self._data) or await self._refresh(): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -1286,21 +1290,32 @@ class AsyncCursor(Generic[_DocumentType]): async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: await self.close() - async def to_list(self) -> list[_DocumentType]: + async def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc async for doc in cursor]``. To use:: >>> await cursor.to_list() + Or, so read at most n items from the cursor:: + + >>> await cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not await self._next_batch(res): + if not await self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index 86fa69dcb..da05bf1a3 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -346,13 +346,17 @@ class CommandCursor(Generic[_DocumentType]): else: return None - def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some available documents from the cursor.""" if not len(self._data) and not self._killed: self._refresh() if len(self._data): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -381,21 +385,32 @@ class CommandCursor(Generic[_DocumentType]): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def to_list(self) -> list[_DocumentType]: + def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. To use:: >>> cursor.to_list() + Or, so read at most n items from the cursor:: + + >>> cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not self._next_batch(res): + if not self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index 1595ce40b..c352b6409 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -1258,16 +1258,20 @@ class Cursor(Generic[_DocumentType]): else: raise StopIteration - def _next_batch(self, result: list) -> bool: - """Get all available documents from the cursor.""" + def _next_batch(self, result: list, total: Optional[int] = None) -> bool: + """Get all or some documents from the cursor.""" if not self._exhaust_checked: self._exhaust_checked = True self._supports_exhaust() if self._empty: return False if len(self._data) or self._refresh(): - result.extend(self._data) - self._data.clear() + if total is None: + result.extend(self._data) + self._data.clear() + else: + for _ in range(min(len(self._data), total)): + result.append(self._data.popleft()) return True else: return False @@ -1284,21 +1288,32 @@ class Cursor(Generic[_DocumentType]): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: self.close() - def to_list(self) -> list[_DocumentType]: + def to_list(self, length: Optional[int] = None) -> list[_DocumentType]: """Converts the contents of this cursor to a list more efficiently than ``[doc for doc in cursor]``. To use:: >>> cursor.to_list() + Or, so read at most n items from the cursor:: + + >>> cursor.to_list(n) + If the cursor is empty or has no more results, an empty list will be returned. .. versionadded:: 4.9 """ res: list[_DocumentType] = [] + remaining = length + if isinstance(length, int) and length < 1: + raise ValueError("to_list() length must be greater than 0") while self.alive: - if not self._next_batch(res): + if not self._next_batch(res, remaining): break + if length is not None: + remaining = length - len(res) + if remaining == 0: + break return res diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 9dd67f2da..6967205fe 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -1401,6 +1401,20 @@ class TestCursor(AsyncIntegrationTest): docs = await c.to_list() self.assertEqual([], docs) + async def test_to_list_length(self): + coll = self.db.test + await coll.insert_many([{} for _ in range(5)]) + self.addCleanup(coll.drop) + c = coll.find() + docs = await c.to_list(3) + self.assertEqual(len(docs), 3) + + c = coll.find(batch_size=2) + docs = await c.to_list(3) + self.assertEqual(len(docs), 3) + docs = await c.to_list(3) + self.assertEqual(len(docs), 2) + @async_client_context.require_change_streams async def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. @@ -1417,6 +1431,19 @@ class TestCursor(AsyncIntegrationTest): docs = await c.to_list() self.assertEqual([], docs) + @async_client_context.require_change_streams + async def test_command_cursor_to_list_length(self): + db = self.db + await db.drop_collection("test") + await db.test.insert_many([{"foo": 1}, {"foo": 2}]) + + pipeline = {"$project": {"_id": False, "foo": True}} + result = await db.test.aggregate([pipeline]) + self.assertEqual(len(await result.to_list()), 2) + + result = await db.test.aggregate([pipeline]) + self.assertEqual(len(await result.to_list(1)), 1) + class TestRawBatchCursor(AsyncIntegrationTest): async def test_find_raw(self): diff --git a/test/test_cursor.py b/test/test_cursor.py index 1cde4718f..8e6fade1e 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -1392,6 +1392,20 @@ class TestCursor(IntegrationTest): docs = c.to_list() self.assertEqual([], docs) + def test_to_list_length(self): + coll = self.db.test + coll.insert_many([{} for _ in range(5)]) + self.addCleanup(coll.drop) + c = coll.find() + docs = c.to_list(3) + self.assertEqual(len(docs), 3) + + c = coll.find(batch_size=2) + docs = c.to_list(3) + self.assertEqual(len(docs), 3) + docs = c.to_list(3) + self.assertEqual(len(docs), 2) + @client_context.require_change_streams def test_command_cursor_to_list(self): # Set maxAwaitTimeMS=1 to speed up the test. @@ -1408,6 +1422,19 @@ class TestCursor(IntegrationTest): docs = c.to_list() self.assertEqual([], docs) + @client_context.require_change_streams + def test_command_cursor_to_list_length(self): + db = self.db + db.drop_collection("test") + db.test.insert_many([{"foo": 1}, {"foo": 2}]) + + pipeline = {"$project": {"_id": False, "foo": True}} + result = db.test.aggregate([pipeline]) + self.assertEqual(len(result.to_list()), 2) + + result = db.test.aggregate([pipeline]) + self.assertEqual(len(result.to_list(1)), 1) + class TestRawBatchCursor(IntegrationTest): def test_find_raw(self): diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 27b38dc0b..19ec152bd 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -440,6 +440,12 @@ class TestGridfs(IntegrationTest): gout = next(cursor) self.assertEqual(b"test2+", gout.read()) self.assertRaises(StopIteration, cursor.__next__) + cursor.rewind() + items = cursor.to_list() + self.assertEqual(len(items), 2) + cursor.rewind() + items = cursor.to_list(1) + self.assertEqual(len(items), 1) cursor.close() self.assertRaises(TypeError, self.fs.find, {}, {"_id": True})