diff --git a/mongo.py b/mongo.py index b9a6ff865..d0a1b6b27 100644 --- a/mongo.py +++ b/mongo.py @@ -8,14 +8,19 @@ import types import traceback import os -import bson -import objectid -import dbref +from son import SON +from bson import BSON +from objectid import ObjectId +from dbref import DBRef class ConnectionException(IOError): """Raised when a connection to the database cannot be made or is lost. """ +class InvalidCollection(ValueError): + """Raised when an invalid collection name is used. + """ + class Mongo(object): """A connection to a Mongo database. """ @@ -50,9 +55,83 @@ class Mongo(object): raise ConnectionException("could not connect to %s:%s, got: %s" % (self.__host, self.__port, traceback.format_exc())) + def __cmp__(self, other): + if isinstance(other, Mongo): + return cmp((self.__host, self.__port), (other.__host, other.__port)) + return NotImplemented + def __repr__(self): return "Mongo(" + repr(self.__host) + ", " + repr(self.__port) + ")" + def __getattr__(self, name): + """Get a collection of this database by name. + + Raises InvalidCollection if an invalid collection name is used. + + Arguments: + - `name`: the name of the collection to get + """ + return Collection(self, name) + + def __getitem__(self, name): + return self.__getattr__(name) + +class Collection(object): + """A Mongo collection. + """ + def __init__(self, database, name): + """Get / create a Mongo collection. + + Raises TypeError if database is not an instance of Mongo or name is not + an instance of (str, unicode). Raises InvalidCollection if name is not a + valid collection name. + + Arguments: + - `database`: the database to get a collection from + - `name`: the name of the collection to get + """ + if not isinstance(database, Mongo): + raise TypeError("database must be an instance of Mongo") + if not isinstance(name, types.StringTypes): + raise TypeError("name must be an instance of (str, unicode)") + + if not name or ".." in name: + raise InvalidCollection("collection names cannot be empty") + if "$" in name: + raise InvalidCollection("collection names must not contain '$'") + if name[0] == "." or name[-1] == ".": + raise InvalidCollection("collecion names must not start or end with '.'") + + self.__database = database + self.__collection_name = name + + def __getattr__(self, name): + """Get a sub-collection of this collection by name. + + Raises InvalidCollection if an invalid collection name is used. + + Arguments: + - `name`: the name of the collection to get + """ + return Collection(self.__database, u"%s.%s" % (self.__collection_name, name)) + + def __getitem__(self, name): + return self.__getattr__(name) + + def __repr__(self): + return "Collection(%r, %r)" % (self.__database, self.__collection_name) + + def __cmp__(self, other): + if isinstance(other, Collection): + return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name)) + return NotImplemented + + def save(self, to_save): + pass + + def find_one(self, spec): + pass + class TestMongo(unittest.TestCase): def setUp(self): self.host = os.environ.get("db_ip", "localhost") @@ -76,6 +155,42 @@ class TestMongo(unittest.TestCase): def test_repr(self): self.assertEqual(repr(Mongo(self.host, self.port)), "Mongo('%s', %s)" % (self.host, self.port)) + self.assertEqual(repr(Mongo(self.host, self.port).test), + "Collection(Mongo('%s', %s), 'test')" % (self.host, self.port)) + + def test_collection(self): + db = Mongo(self.host, self.port) + + self.assertRaises(TypeError, Collection, db, 5) + self.assertRaises(TypeError, Collection, 5, "test") + + def make_col(base, name): + base[name] + + self.assertRaises(InvalidCollection, make_col, db, "") + self.assertRaises(InvalidCollection, make_col, db, "te$t") + self.assertRaises(InvalidCollection, make_col, db, ".test") + self.assertRaises(InvalidCollection, make_col, db, "test.") + self.assertRaises(InvalidCollection, make_col, db, "tes..t") + self.assertRaises(InvalidCollection, make_col, db.test, "") + self.assertRaises(InvalidCollection, make_col, db.test, "te$t") + self.assertRaises(InvalidCollection, make_col, db.test, ".test") + self.assertRaises(InvalidCollection, make_col, db.test, "test.") + self.assertRaises(InvalidCollection, make_col, db.test, "tes..t") + + self.assertTrue(isinstance(db.test, Collection)) + self.assertEqual(db.test, db["test"]) + self.assertEqual(db.test, Collection(db, "test")) + self.assertEqual(db.test.mike, db["test.mike"]) + self.assertEqual(db.test["mike"], db["test.mike"]) + + def test_save_find_one(self): + db = Mongo(self.host, self.port) + + a_doc = SON({"hello": u"world"}) + db.test.save(a_doc) +# self.assertTrue(isinstance(a_doc["_id"], ObjectId)) +# self.assertEqual(a_doc, db.test.find_one({"_id": a_doc["_id"]})) if __name__ == "__main__": unittest.main()