collections
This commit is contained in:
parent
7181b4cfcd
commit
99d45eadf9
121
mongo.py
121
mongo.py
@ -8,14 +8,19 @@ import types
|
||||
import traceback
|
||||
import os
|
||||
|
||||
import bson
|
||||
import objectid
|
||||
import dbref
|
||||
from son import SON
|
||||
from bson import BSON
|
||||
from objectid import ObjectId
|
||||
from dbref import DBRef
|
||||
|
||||
class ConnectionException(IOError):
|
||||
"""Raised when a connection to the database cannot be made or is lost.
|
||||
"""
|
||||
|
||||
class InvalidCollection(ValueError):
|
||||
"""Raised when an invalid collection name is used.
|
||||
"""
|
||||
|
||||
class Mongo(object):
|
||||
"""A connection to a Mongo database.
|
||||
"""
|
||||
@ -50,9 +55,83 @@ class Mongo(object):
|
||||
raise ConnectionException("could not connect to %s:%s, got: %s" %
|
||||
(self.__host, self.__port, traceback.format_exc()))
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, Mongo):
|
||||
return cmp((self.__host, self.__port), (other.__host, other.__port))
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
return "Mongo(" + repr(self.__host) + ", " + repr(self.__port) + ")"
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get a collection of this database by name.
|
||||
|
||||
Raises InvalidCollection if an invalid collection name is used.
|
||||
|
||||
Arguments:
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
return Collection(self, name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.__getattr__(name)
|
||||
|
||||
class Collection(object):
|
||||
"""A Mongo collection.
|
||||
"""
|
||||
def __init__(self, database, name):
|
||||
"""Get / create a Mongo collection.
|
||||
|
||||
Raises TypeError if database is not an instance of Mongo or name is not
|
||||
an instance of (str, unicode). Raises InvalidCollection if name is not a
|
||||
valid collection name.
|
||||
|
||||
Arguments:
|
||||
- `database`: the database to get a collection from
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
if not isinstance(database, Mongo):
|
||||
raise TypeError("database must be an instance of Mongo")
|
||||
if not isinstance(name, types.StringTypes):
|
||||
raise TypeError("name must be an instance of (str, unicode)")
|
||||
|
||||
if not name or ".." in name:
|
||||
raise InvalidCollection("collection names cannot be empty")
|
||||
if "$" in name:
|
||||
raise InvalidCollection("collection names must not contain '$'")
|
||||
if name[0] == "." or name[-1] == ".":
|
||||
raise InvalidCollection("collecion names must not start or end with '.'")
|
||||
|
||||
self.__database = database
|
||||
self.__collection_name = name
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get a sub-collection of this collection by name.
|
||||
|
||||
Raises InvalidCollection if an invalid collection name is used.
|
||||
|
||||
Arguments:
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
return Collection(self.__database, u"%s.%s" % (self.__collection_name, name))
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.__getattr__(name)
|
||||
|
||||
def __repr__(self):
|
||||
return "Collection(%r, %r)" % (self.__database, self.__collection_name)
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, Collection):
|
||||
return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name))
|
||||
return NotImplemented
|
||||
|
||||
def save(self, to_save):
|
||||
pass
|
||||
|
||||
def find_one(self, spec):
|
||||
pass
|
||||
|
||||
class TestMongo(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.host = os.environ.get("db_ip", "localhost")
|
||||
@ -76,6 +155,42 @@ class TestMongo(unittest.TestCase):
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(Mongo(self.host, self.port)),
|
||||
"Mongo('%s', %s)" % (self.host, self.port))
|
||||
self.assertEqual(repr(Mongo(self.host, self.port).test),
|
||||
"Collection(Mongo('%s', %s), 'test')" % (self.host, self.port))
|
||||
|
||||
def test_collection(self):
|
||||
db = Mongo(self.host, self.port)
|
||||
|
||||
self.assertRaises(TypeError, Collection, db, 5)
|
||||
self.assertRaises(TypeError, Collection, 5, "test")
|
||||
|
||||
def make_col(base, name):
|
||||
base[name]
|
||||
|
||||
self.assertRaises(InvalidCollection, make_col, db, "")
|
||||
self.assertRaises(InvalidCollection, make_col, db, "te$t")
|
||||
self.assertRaises(InvalidCollection, make_col, db, ".test")
|
||||
self.assertRaises(InvalidCollection, make_col, db, "test.")
|
||||
self.assertRaises(InvalidCollection, make_col, db, "tes..t")
|
||||
self.assertRaises(InvalidCollection, make_col, db.test, "")
|
||||
self.assertRaises(InvalidCollection, make_col, db.test, "te$t")
|
||||
self.assertRaises(InvalidCollection, make_col, db.test, ".test")
|
||||
self.assertRaises(InvalidCollection, make_col, db.test, "test.")
|
||||
self.assertRaises(InvalidCollection, make_col, db.test, "tes..t")
|
||||
|
||||
self.assertTrue(isinstance(db.test, Collection))
|
||||
self.assertEqual(db.test, db["test"])
|
||||
self.assertEqual(db.test, Collection(db, "test"))
|
||||
self.assertEqual(db.test.mike, db["test.mike"])
|
||||
self.assertEqual(db.test["mike"], db["test.mike"])
|
||||
|
||||
def test_save_find_one(self):
|
||||
db = Mongo(self.host, self.port)
|
||||
|
||||
a_doc = SON({"hello": u"world"})
|
||||
db.test.save(a_doc)
|
||||
# self.assertTrue(isinstance(a_doc["_id"], ObjectId))
|
||||
# self.assertEqual(a_doc, db.test.find_one({"_id": a_doc["_id"]}))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user