From 11967eb160363894dbad41cf772a9d1c4dce4e46 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 21 Mar 2019 15:48:45 -0700 Subject: [PATCH] PYTHON-1784 Add filter support to list_collection_names Adhere to enumerate collection spec for setting nameOnly when filter is provided to allow filtering based on collection options. --- doc/changelog.rst | 3 ++- pymongo/database.py | 39 ++++++++++++++++++++++++++---- test/test_database.py | 56 ++++++++++++++++++++++++++++++++++--------- test/utils.py | 7 +++--- 4 files changed, 85 insertions(+), 20 deletions(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index 29481f21f..245dae0e8 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -82,7 +82,8 @@ Changes in Version 3.8.0.dev0 :meth:`~pymongo.cursor.Cursor.comment` as the "comment" top-level command option instead of "$comment". Also, note that "comment" must be a string. - +- Add the ``filter`` parameter to + :meth:`~pymongo.database.Database.list_collection_names`. Issues Resolved ............... diff --git a/pymongo/database.py b/pymongo/database.py index 223941a15..8ad91c4e1 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -361,7 +361,8 @@ class Database(common.BaseObject): Removed deprecated argument: options """ with self.__client._tmp_session(session) as s: - if name in self.list_collection_names(session=s): + if name in self.list_collection_names( + filter={"name": name}, session=s): raise CollectionInvalid("collection %s already exists" % name) return Collection(self, name, True, codec_options, @@ -651,12 +652,14 @@ class Database(common.BaseObject): cursor = self._command(sock_info, cmd, slave_okay)["cursor"] return CommandCursor(coll, cursor, sock_info.address) - def list_collections(self, session=None, **kwargs): + def list_collections(self, session=None, filter=None, **kwargs): """Get a cursor over the collectons of this database. :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. + - `filter` (optional): A query document to filter the list of + collections returned from the listCollections command. - `**kwargs` (optional): Optional parameters of the `listCollections command `_ @@ -668,6 +671,8 @@ class Database(common.BaseObject): .. versionadded:: 3.6 """ + if filter is not None: + kwargs['filter'] = filter read_pref = ((session and session._txn_read_preference()) or ReadPreference.PRIMARY) with self.__client._socket_for_reads( @@ -676,18 +681,42 @@ class Database(common.BaseObject): sock_info, slave_okay, session, read_preference=read_pref, **kwargs) - def list_collection_names(self, session=None): + def list_collection_names(self, session=None, filter=None, **kwargs): """Get a list of all the collection names in this database. + For example, to list all non-system collections:: + + filter = {"name": {"$regex": r"^(?!system\.)"}} + db.list_collection_names(filter=filter) + :Parameters: - `session` (optional): a :class:`~pymongo.client_session.ClientSession`. + - `filter` (optional): A query document to filter the list of + collections returned from the listCollections command. + - `**kwargs` (optional): Optional parameters of the + `listCollections command + `_ + can be passed as keyword arguments to this method. The supported + options differ by server version. + + .. versionchanged:: 3.8 + Added the ``filter`` and ``**kwargs`` parameters. .. versionadded:: 3.6 """ + if filter is None: + kwargs["nameOnly"] = True + else: + # The enumerate collections spec states that "drivers MUST NOT set + # nameOnly if a filter specifies any keys other than name." + common.validate_is_mapping("filter", filter) + kwargs["filter"] = filter + if not filter or (len(filter) == 1 and "name" in filter): + kwargs["nameOnly"] = True + return [result["name"] - for result in self.list_collections(session=session, - nameOnly=True)] + for result in self.list_collections(session=session, **kwargs)] def collection_names(self, include_system_collections=True, session=None): diff --git a/test/test_database.py b/test/test_database.py index 52f480a19..b6f66cada 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -56,7 +56,8 @@ from test.utils import (ignore_deprecations, rs_or_single_client_noauth, rs_or_single_client, server_started_with_auth, - IMPOSSIBLE_WRITE_CONCERN) + IMPOSSIBLE_WRITE_CONCERN, + OvertCommandListener) if PY3: @@ -156,7 +157,7 @@ class TestDatabase(IntegrationTest): self.assertTrue(u"test.foo" in db.list_collection_names()) self.assertRaises(CollectionInvalid, db.create_collection, "test.foo") - def _test_collection_names(self, meth, test_no_system): + def _test_collection_names(self, meth, **no_system_kwargs): db = Database(self.client, "pymongo_test") db.test.insert_one({"dummy": u"object"}) db.test.mike.insert_one({"dummy": u"object"}) @@ -167,13 +168,11 @@ class TestDatabase(IntegrationTest): for coll in colls: self.assertTrue("$" not in coll) - if test_no_system: - db.systemcoll.test.insert_one({}) - no_system_collections = getattr( - db, meth)(include_system_collections=False) - for coll in no_system_collections: - self.assertTrue(not coll.startswith("system.")) - self.assertIn("systemcoll.test", no_system_collections) + db.systemcoll.test.insert_one({}) + no_system_collections = getattr(db, meth)(**no_system_kwargs) + for coll in no_system_collections: + self.assertTrue(not coll.startswith("system.")) + self.assertIn("systemcoll.test", no_system_collections) # Force more than one batch. db = self.client.many_collections @@ -186,10 +185,45 @@ class TestDatabase(IntegrationTest): self.client.drop_database("many_collections") def test_collection_names(self): - self._test_collection_names('collection_names', True) + self._test_collection_names( + 'collection_names', include_system_collections=False) def test_list_collection_names(self): - self._test_collection_names('list_collection_names', False) + self._test_collection_names( + 'list_collection_names', filter={ + "name": {"$regex": r"^(?!system\.)"}}) + + def test_list_collection_names_filter(self): + listener = OvertCommandListener() + results = listener.results + client = rs_or_single_client(event_listeners=[listener]) + db = client[self.db.name] + db.capped.drop() + db.create_collection("capped", capped=True, size=4096) + db.capped.insert_one({}) + db.non_capped.insert_one({}) + self.addCleanup(client.drop_database, db.name) + + # Should not send nameOnly. + for filter in ({'options.capped': True}, + {'options.capped': True, 'name': 'capped'}): + results.clear() + names = db.list_collection_names(filter=filter) + self.assertEqual(names, ["capped"]) + self.assertNotIn("nameOnly", results["started"][0].command) + + # Should send nameOnly (except on 2.6). + for filter in (None, {}, {'name': {'$in': ['capped', 'non_capped']}}): + results.clear() + names = db.list_collection_names(filter=filter) + self.assertIn("capped", names) + self.assertIn("non_capped", names) + command = results["started"][0].command + if client_context.version >= (3, 0): + self.assertIn("nameOnly", command) + self.assertTrue(command["nameOnly"]) + else: + self.assertNotIn("nameOnly", command) def test_list_collections(self): self.client.drop_database("pymongo_test") diff --git a/test/utils.py b/test/utils.py index af3d86167..0cb0052fc 100644 --- a/test/utils.py +++ b/test/utils.py @@ -408,9 +408,10 @@ def server_is_master_with_slave(client): def drop_collections(db): - for coll in db.list_collection_names(): - if not coll.startswith('system'): - db.drop_collection(coll) + # Drop all non-system collections in this database. + for coll in db.list_collection_names( + filter={"name": {"$regex": r"^(?!system\.)"}}): + db.drop_collection(coll) def remove_all_users(db):