From 758d3800806bc407b289bfada64e507e08bb2d9a Mon Sep 17 00:00:00 2001 From: Mike Dirolf Date: Wed, 21 Jan 2009 12:59:13 -0500 Subject: [PATCH] create_collection --- collection.py | 25 +++++++++++++++++++++++-- database.py | 23 +++++++++++++++++++++++ test/test_database.py | 25 +++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 2 deletions(-) diff --git a/collection.py b/collection.py index c89b04202..42f4ee232 100644 --- a/collection.py +++ b/collection.py @@ -14,19 +14,27 @@ _ONE = "\x01\x00\x00\x00" class Collection(object): """A Mongo collection. """ - def __init__(self, database, name): + def __init__(self, database, name, options={}): """Get / create a Mongo collection. Raises TypeError if name is not an instance of (str, unicode). Raises - InvalidName if name is not a valid collection name. + InvalidName if name is not a valid collection name. Raises TypeError if + options is not an instance of dict. If options is non-empty a create + command will be sent to the database. Otherwise the collection will be + created implicitly on first use. Arguments: - `database`: the database to get a collection from - `name`: the name of the collection to get + - `options`: dictionary of collection options. + see `Database.create_collection` for details. """ if not isinstance(name, types.StringTypes): raise TypeError("name must be an instance of (str, unicode)") + if not isinstance(options, types.DictType): + raise TypeError("options must be an instance of dict") + if not name or ".." in name: raise InvalidName("collection names cannot be empty") if "$" in name and name not in ["$cmd"]: @@ -36,6 +44,19 @@ class Collection(object): self.__database = database self.__collection_name = unicode(name) + if options: + self.__create(options) + + def __create(self, options): + """Sends a create command with the given options. + """ + command = SON({"create": self.__collection_name}) + command.update(options) + + response = self.__database._command(command) + if response["ok"] not in [0, 1]: + raise OperationFailure("error creating collection: %s" % + response["errmsg"]) def __getattr__(self, name): """Get a sub-collection of this collection by name. diff --git a/database.py b/database.py index d93e07428..e32d3802d 100644 --- a/database.py +++ b/database.py @@ -87,6 +87,29 @@ class Database(object): """ return self.__getattr__(name) + def create_collection(self, name, options={}): + """Create a new collection in this database. + + Normally collection creation is automatic. This method should only if you + want to specify options on creation. CollectionInvalid is raised if the + collection already exists. + + Options should be a dictionary, with any of the following options: + - "size": desired initial size for the collection. must be less than or + equal to 10000000000. for capped collections this size is the max + size of the collection. + - "capped": if True, this is a capped collection + - "max": maximum number of objects if capped (optional) + + Arguments: + - `name`: the name of the collection to create + - `options` (optional): options to use on the new collection + """ + if name in self.collection_names(): + raise CollectionInvalid("collection %s already exists" % name) + + return Collection(self, name, options) + def _fix_incoming(self, son, collection): """Apply manipulators to an incoming SON object before it gets stored. diff --git a/test/test_database.py b/test/test_database.py index 865da0551..8c729d029 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -37,6 +37,31 @@ class TestDatabase(unittest.TestCase): self.assertNotEqual(db.test, Collection(db, "mike")) self.assertEqual(db.test.mike, db["test.mike"]) + def test_create_collection(self): + db = Database(self.connection, "test") + + db.test.insert({"hello": "world"}) + self.assertRaises(CollectionInvalid, db.create_collection, "test") + + db.drop_collection("test") + + self.assertRaises(TypeError, db.create_collection, 5) + self.assertRaises(TypeError, db.create_collection, None) + self.assertRaises(InvalidName, db.create_collection, "coll..ection") + self.assertRaises(TypeError, db.create_collection, "test", 5) + self.assertRaises(TypeError, db.create_collection, "test", None) + + test = db.create_collection("test") + test.save({"hello": u"world"}) + self.assertEqual(db.test.find_one()["hello"], "world") + self.assertTrue(u"test" in db.collection_names()) + + db.drop_collection("test.foo") + db.create_collection("test.foo") + self.assertFalse(u"test.foo" in db.collection_names()) + db.create_collection("test.foo", {"capped": True}) + self.assertTrue(u"test.foo" in db.collection_names()) + def test_collection_names(self): db = Database(self.connection, "test") db.test.save({"dummy": u"object"})