diff --git a/database.py b/database.py index 2c2328c1b..35b621cfc 100644 --- a/database.py +++ b/database.py @@ -4,7 +4,7 @@ import types from son_manipulator import ObjectIdInjector from collection import Collection -from errors import InvalidName +from errors import InvalidName, CollectionInvalid, OperationFailure ASCENDING = 1 DESCENDING = -1 @@ -132,8 +132,33 @@ class Database(object): if not isinstance(name, types.StringTypes): raise TypeError("name_or_collection must be an instance of (Collection, str, unicode)") + if name not in self.collection_names(): + return + self[name].drop_indexes() # must manually drop indexes result = self._command({"drop": unicode(name)}) if result["ok"] != 1: - raise OperationFailure("failed to drop collection") + raise OperationFailure("failed to drop collection: %s" % result["errmsg"]) + + def validate_collection(self, name_or_collection): + """Validate a collection. + + Returns a string of validation info. Raises CollectionInvalid if + validation fails. + """ + name = name_or_collection + if isinstance(name, Collection): + name = name.name() + + if not isinstance(name, types.StringTypes): + raise TypeError("name_or_collection must be an instance of (Collection, str, unicode)") + + result = self._command({"validate": unicode(name)}) + if result["ok"] != 1: + raise OperationFailure("failed to validate collection: %s" % result["errmsg"]) + + info = result["result"] + if info.find("exception") != -1 or info.find("corrupt") != -1: + raise CollectionInvalid("%s invalid: %s" % (name, info)) + return info diff --git a/errors.py b/errors.py index 0beae2d30..b37855edf 100644 --- a/errors.py +++ b/errors.py @@ -12,6 +12,10 @@ class InvalidOperation(Exception): """Raised when a client attempts to perform an invalid operation. """ +class CollectionInvalid(Exception): + """Raised when collection validation fails. + """ + class InvalidName(ValueError): """Raised when an invalid name is used. """ diff --git a/test/test_database.py b/test/test_database.py index 1b007a64f..bc5c41f5d 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -4,7 +4,7 @@ import unittest import types import random -from errors import InvalidName, InvalidOperation +from errors import InvalidName, InvalidOperation, CollectionInvalid, OperationFailure from son import SON from objectid import ObjectId from database import Database, ASCENDING, DESCENDING @@ -68,6 +68,22 @@ class TestDatabase(unittest.TestCase): db.drop_collection(db.test) self.assertFalse("test" in db.collection_names()) + db.drop_collection(db.test.doesnotexist) + + def test_validate_collection(self): + db = self.connection.test + + self.assertRaises(TypeError, db.validate_collection, 5) + self.assertRaises(TypeError, db.validate_collection, None) + + db.test.save({"dummy": u"object"}) + + self.assertRaises(OperationFailure, db.validate_collection, "test.doesnotexist") + self.assertRaises(OperationFailure, db.validate_collection, db.test.doesnotexist) + + self.assertTrue(db.validate_collection("test")) + self.assertTrue(db.validate_collection(db.test)) + def test_save_find_one(self): db = Database(self.connection, "test") db.test.remove({})