diff --git a/database.py b/database.py index 2290797d3..09f839d06 100644 --- a/database.py +++ b/database.py @@ -111,4 +111,9 @@ class Database(object): def collection_names(self): """Get a list of all the collection names in this database. """ - raise Exception("unimplemented") + results = self.system.namespaces.find() + names = [r["name"] for r in results] + names = [n[len(self.__name) + 1:] for n in names + if n.startswith(self.__name + ".")] + names = [n for n in names if "$" not in n] + return names diff --git a/test/test_database.py b/test/test_database.py index aff462329..306709e6e 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -36,6 +36,17 @@ class TestDatabase(unittest.TestCase): self.assertNotEqual(db.test, Collection(db, "mike")) self.assertEqual(db.test.mike, db["test.mike"]) + def test_collection_names(self): + db = Database(self.connection, "test") + db.test.save({"dummy": u"object"}) + db.test.mike.save({"dummy": u"object"}) + + colls = db.collection_names() + self.assertTrue("test" in colls) + self.assertTrue("test.mike" in colls) + for coll in colls: + self.assertTrue("$" not in coll) + def test_save_find_one(self): db = Database(self.connection, "test") db.test.remove({})