From 88e744d5065bc256a086ebf13506fe1c65179dc1 Mon Sep 17 00:00:00 2001 From: Julius Park Date: Thu, 16 Sep 2021 15:52:14 -0700 Subject: [PATCH] PYTHON-808 Prevent use of Database and Collection in boolean expressions (#728) --- doc/changelog.rst | 7 ++++++- doc/migrate-to-pymongo4.rst | 26 ++++++++++++++++++++++++++ pymongo/collection.py | 5 +++++ pymongo/database.py | 5 +++++ test/test_collection.py | 4 ++++ test/test_database.py | 4 ++++ 6 files changed, 50 insertions(+), 1 deletion(-) diff --git a/doc/changelog.rst b/doc/changelog.rst index ea8cd19ba..8adb9ae7f 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -134,8 +134,13 @@ decodes datetime as naive by default. - The ``hint`` option is now required when using ``min`` or ``max`` queries with :meth:`~pymongo.collection.Collection.find`. - ``name`` is now a required argument for the :class:`pymongo.driver_info.DriverInfo` class. +- :class:`~pymongo.collection.Collection` and :class:`~pymongo.database.Database` + now raises an error upon evaluating as a Boolean, please use the + syntax ``if collection is not None:`` or ``if database is not None:`` as + opposed to + the previous syntax which was simply ``if collection:`` or ``if database:``. + You must now explicitly compare with None. - d Notable improvements .................... diff --git a/doc/migrate-to-pymongo4.rst b/doc/migrate-to-pymongo4.rst index 8300a0991..d17c13e75 100644 --- a/doc/migrate-to-pymongo4.rst +++ b/doc/migrate-to-pymongo4.rst @@ -321,6 +321,19 @@ Can be changed to this:: .. _'system.profile' collection: https://docs.mongodb.com/manual/reference/database-profiler/ +Database.__bool__ raises NotImplementedError +............................................ +:class:`~pymongo.database.Database` now raises an error upon evaluating as a +Boolean. Code like this:: + + if database: + +Can be changed to this:: + + if database is not None: + +You must now explicitly compare with None. + Collection ---------- @@ -621,6 +634,19 @@ can be changed to this:: cursor = coll.find({}, min={'x', min_value}, hint=[('x', ASCENDING)]) +Collection.__bool__ raises NotImplementedError +.............................................. +:class:`~pymongo.collection.Collection` now raises an error upon evaluating +as a Boolean. Code like this:: + + if collection: + +Can be changed to this:: + + if collection is not None: + +You must now explicitly compare with None. + SONManipulator is removed ------------------------- diff --git a/pymongo/collection.py b/pymongo/collection.py index 7994c4d8a..a5d2fa5c3 100644 --- a/pymongo/collection.py +++ b/pymongo/collection.py @@ -289,6 +289,11 @@ class Collection(common.BaseObject): def __hash__(self): return hash((self.__database, self.__name)) + def __bool__(self): + raise NotImplementedError("Collection objects do not implement truth " + "value testing or bool(). Please compare " + "with None instead: collection is not None") + @property def full_name(self): """The full name of this :class:`Collection`. diff --git a/pymongo/database.py b/pymongo/database.py index df8d730fb..35a878322 100644 --- a/pymongo/database.py +++ b/pymongo/database.py @@ -840,6 +840,11 @@ class Database(common.BaseObject): next = __next__ + def __bool__(self): + raise NotImplementedError("Database objects do not implement truth " + "value testing or bool(). Please compare " + "with None instead: database is not None") + def dereference(self, dbref, session=None, **kwargs): """Dereference a :class:`~bson.dbref.DBRef`, getting the document it points to. diff --git a/test/test_collection.py b/test/test_collection.py index 2795b4742..26c616a4d 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -2265,6 +2265,10 @@ class TestCollection(IntegrationTest): ('$dumb', 2), ('filter', {'foo': 1})]).to_dict()) + def test_bool(self): + with self.assertRaises(NotImplementedError): + bool(Collection(self.db, 'test')) + if __name__ == "__main__": unittest.main() diff --git a/test/test_database.py b/test/test_database.py index 37bda18b7..4813f8d10 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -687,6 +687,10 @@ class TestDatabaseAggregation(IntegrationTest): with self.admin.aggregate(self.pipeline) as _: pass + def test_bool(self): + with self.assertRaises(NotImplementedError): + bool(Database(self.client, "test")) + if __name__ == "__main__": unittest.main()