collections

This commit is contained in:
Mike Dirolf 2009-01-09 16:28:57 -05:00
parent 7181b4cfcd
commit 99d45eadf9

121
mongo.py
View File

@ -8,14 +8,19 @@ import types
import traceback
import os
import bson
import objectid
import dbref
from son import SON
from bson import BSON
from objectid import ObjectId
from dbref import DBRef
class ConnectionException(IOError):
"""Raised when a connection to the database cannot be made or is lost.
"""
class InvalidCollection(ValueError):
"""Raised when an invalid collection name is used.
"""
class Mongo(object):
"""A connection to a Mongo database.
"""
@ -50,9 +55,83 @@ class Mongo(object):
raise ConnectionException("could not connect to %s:%s, got: %s" %
(self.__host, self.__port, traceback.format_exc()))
def __cmp__(self, other):
if isinstance(other, Mongo):
return cmp((self.__host, self.__port), (other.__host, other.__port))
return NotImplemented
def __repr__(self):
return "Mongo(" + repr(self.__host) + ", " + repr(self.__port) + ")"
def __getattr__(self, name):
"""Get a collection of this database by name.
Raises InvalidCollection if an invalid collection name is used.
Arguments:
- `name`: the name of the collection to get
"""
return Collection(self, name)
def __getitem__(self, name):
return self.__getattr__(name)
class Collection(object):
"""A Mongo collection.
"""
def __init__(self, database, name):
"""Get / create a Mongo collection.
Raises TypeError if database is not an instance of Mongo or name is not
an instance of (str, unicode). Raises InvalidCollection if name is not a
valid collection name.
Arguments:
- `database`: the database to get a collection from
- `name`: the name of the collection to get
"""
if not isinstance(database, Mongo):
raise TypeError("database must be an instance of Mongo")
if not isinstance(name, types.StringTypes):
raise TypeError("name must be an instance of (str, unicode)")
if not name or ".." in name:
raise InvalidCollection("collection names cannot be empty")
if "$" in name:
raise InvalidCollection("collection names must not contain '$'")
if name[0] == "." or name[-1] == ".":
raise InvalidCollection("collecion names must not start or end with '.'")
self.__database = database
self.__collection_name = name
def __getattr__(self, name):
"""Get a sub-collection of this collection by name.
Raises InvalidCollection if an invalid collection name is used.
Arguments:
- `name`: the name of the collection to get
"""
return Collection(self.__database, u"%s.%s" % (self.__collection_name, name))
def __getitem__(self, name):
return self.__getattr__(name)
def __repr__(self):
return "Collection(%r, %r)" % (self.__database, self.__collection_name)
def __cmp__(self, other):
if isinstance(other, Collection):
return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name))
return NotImplemented
def save(self, to_save):
pass
def find_one(self, spec):
pass
class TestMongo(unittest.TestCase):
def setUp(self):
self.host = os.environ.get("db_ip", "localhost")
@ -76,6 +155,42 @@ class TestMongo(unittest.TestCase):
def test_repr(self):
self.assertEqual(repr(Mongo(self.host, self.port)),
"Mongo('%s', %s)" % (self.host, self.port))
self.assertEqual(repr(Mongo(self.host, self.port).test),
"Collection(Mongo('%s', %s), 'test')" % (self.host, self.port))
def test_collection(self):
db = Mongo(self.host, self.port)
self.assertRaises(TypeError, Collection, db, 5)
self.assertRaises(TypeError, Collection, 5, "test")
def make_col(base, name):
base[name]
self.assertRaises(InvalidCollection, make_col, db, "")
self.assertRaises(InvalidCollection, make_col, db, "te$t")
self.assertRaises(InvalidCollection, make_col, db, ".test")
self.assertRaises(InvalidCollection, make_col, db, "test.")
self.assertRaises(InvalidCollection, make_col, db, "tes..t")
self.assertRaises(InvalidCollection, make_col, db.test, "")
self.assertRaises(InvalidCollection, make_col, db.test, "te$t")
self.assertRaises(InvalidCollection, make_col, db.test, ".test")
self.assertRaises(InvalidCollection, make_col, db.test, "test.")
self.assertRaises(InvalidCollection, make_col, db.test, "tes..t")
self.assertTrue(isinstance(db.test, Collection))
self.assertEqual(db.test, db["test"])
self.assertEqual(db.test, Collection(db, "test"))
self.assertEqual(db.test.mike, db["test.mike"])
self.assertEqual(db.test["mike"], db["test.mike"])
def test_save_find_one(self):
db = Mongo(self.host, self.port)
a_doc = SON({"hello": u"world"})
db.test.save(a_doc)
# self.assertTrue(isinstance(a_doc["_id"], ObjectId))
# self.assertEqual(a_doc, db.test.find_one({"_id": a_doc["_id"]}))
if __name__ == "__main__":
unittest.main()