diff --git a/database.py b/database.py index 35b621cfc..eea582e8b 100644 --- a/database.py +++ b/database.py @@ -9,6 +9,11 @@ from errors import InvalidName, CollectionInvalid, OperationFailure ASCENDING = 1 DESCENDING = -1 +# profiling levels +OFF = 0 +SLOW_ONLY = 1 +ALL = 2 + class Database(object): """A Mongo database. """ @@ -162,3 +167,35 @@ class Database(object): if info.find("exception") != -1 or info.find("corrupt") != -1: raise CollectionInvalid("%s invalid: %s" % (name, info)) return info + + def profiling_level(self): + """Get the database's current profiling level. + + Returns one of (OFF, SLOW_ONLY, ALL). + """ + result = self._command({"profile": -1}) + if result["ok"] != 1: + raise OperationFailure("failed to get profiling level: %s" % result["errmsg"]) + + assert result["was"] >= 0 and result["was"] <= 2 + return result["was"] + + def set_profiling_level(self, level): + """Set the database's profiling level. + + Raises ValueError if level is not one of (OFF, SLOW_ONLY, ALL). + + Arguments: + - `level`: the profiling level to use + """ + if not isinstance(level, types.IntType) or level < 0 or level > 2: + raise ValueError("level must be one of (OFF, SLOW_ONLY, ALL)") + + result = self._command({"profile": level}) + if result["ok"] != 1: + raise OperationFailure("failed to set profiling level: %s" % result["errmsg"]) + + def profiling_info(self): + """Returns a list containing current profiling information. + """ + return list(self.system.profile.find()) diff --git a/test/test_database.py b/test/test_database.py index bc5c41f5d..2646c2caa 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -3,11 +3,12 @@ import unittest import types import random +import datetime from errors import InvalidName, InvalidOperation, CollectionInvalid, OperationFailure from son import SON from objectid import ObjectId -from database import Database, ASCENDING, DESCENDING +from database import Database, ASCENDING, DESCENDING, OFF, SLOW_ONLY, ALL from connection import Connection from collection import Collection, SYSTEM_INDEX_COLLECTION from test_connection import get_connection @@ -84,6 +85,37 @@ class TestDatabase(unittest.TestCase): self.assertTrue(db.validate_collection("test")) self.assertTrue(db.validate_collection(db.test)) + def test_profiling_levels(self): + db = self.connection.test + self.assertEqual(db.profiling_level(), OFF) #default + + self.assertRaises(ValueError, db.set_profiling_level, 5.5) + self.assertRaises(ValueError, db.set_profiling_level, None) + self.assertRaises(ValueError, db.set_profiling_level, -1) + + db.set_profiling_level(SLOW_ONLY) + self.assertEqual(db.profiling_level(), SLOW_ONLY) + + db.set_profiling_level(ALL) + self.assertEqual(db.profiling_level(), ALL) + + db.set_profiling_level(OFF) + self.assertEqual(db.profiling_level(), OFF) + + def test_profiling_info(self): + db = self.connection.test + + db.set_profiling_level(ALL) + db.test.find() + db.set_profiling_level(OFF) + + info = db.profiling_info() + self.assertTrue(isinstance(info, types.ListType)) + self.assertTrue(len(info) >= 1) + self.assertTrue(isinstance(info[0]["info"], types.StringTypes)) + self.assertTrue(isinstance(info[0]["ts"], datetime.datetime)) + self.assertTrue(isinstance(info[0]["millis"], types.FloatType)) + def test_save_find_one(self): db = Database(self.connection, "test") db.test.remove({})