seperate modules for collection, cursor, database
This commit is contained in:
parent
c41f7adb43
commit
3bc86c9bb4
259
collection.py
Normal file
259
collection.py
Normal file
@ -0,0 +1,259 @@
|
||||
"""Collection level utilities for Mongo."""
|
||||
|
||||
import types
|
||||
|
||||
import bson
|
||||
from objectid import ObjectId
|
||||
from cursor import Cursor
|
||||
from son import SON
|
||||
from errors import InvalidName, OperationFailure
|
||||
|
||||
_ZERO = "\x00\x00\x00\x00"
|
||||
_ONE = "\x01\x00\x00\x00"
|
||||
SYSTEM_INDEX_COLLECTION = "system.indexes"
|
||||
|
||||
class Collection(object):
|
||||
"""A Mongo collection.
|
||||
"""
|
||||
def __init__(self, database, name):
|
||||
"""Get / create a Mongo collection.
|
||||
|
||||
Raises TypeError if name is not an instance of (str, unicode). Raises
|
||||
InvalidName if name is not a valid collection name.
|
||||
|
||||
Arguments:
|
||||
- `database`: the database to get a collection from
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
if not isinstance(name, types.StringTypes):
|
||||
raise TypeError("name must be an instance of (str, unicode)")
|
||||
|
||||
if not name or ".." in name:
|
||||
raise InvalidName("collection names cannot be empty")
|
||||
if "$" in name and name not in ["$cmd"]:
|
||||
raise InvalidName("collection names must not contain '$'")
|
||||
if name[0] == "." or name[-1] == ".":
|
||||
raise InvalidName("collecion names must not start or end with '.'")
|
||||
|
||||
self.__database = database
|
||||
self.__collection_name = unicode(name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get a sub-collection of this collection by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
|
||||
Arguments:
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
return Collection(self.__database, u"%s.%s" % (self.__collection_name, name))
|
||||
|
||||
def __getitem__(self, name):
|
||||
return self.__getattr__(name)
|
||||
|
||||
def __repr__(self):
|
||||
return "Collection(%r, %r)" % (self.__database, self.__collection_name)
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, Collection):
|
||||
return cmp((self.__database, self.__collection_name),
|
||||
(other.__database, other.__collection_name))
|
||||
return NotImplemented
|
||||
|
||||
def _full_name(self):
|
||||
return u"%s.%s" % (self.__database.name(), self.__collection_name)
|
||||
|
||||
def _name(self):
|
||||
return self.__collection_name
|
||||
|
||||
def _send_message(self, operation, data):
|
||||
"""Wrap up a message and send it.
|
||||
"""
|
||||
# reserved int, full collection name, message data
|
||||
message = _ZERO
|
||||
message += bson._make_c_string(self._full_name())
|
||||
message += data
|
||||
return self.__database.connection().send_message(operation, message)
|
||||
|
||||
def database(self):
|
||||
return self.__database
|
||||
|
||||
def save(self, to_save, add_meta=True):
|
||||
"""Save a SON object in this collection.
|
||||
|
||||
Raises TypeError if to_save is not an instance of (dict, SON).
|
||||
|
||||
Arguments:
|
||||
- `to_save`: the SON object to be saved
|
||||
- `add_meta` (optional): add meta information (like _id) to the object
|
||||
if it's missing
|
||||
"""
|
||||
if not isinstance(to_save, (types.DictType, SON)):
|
||||
raise TypeError("cannot save object of type %s" % type(to_save))
|
||||
|
||||
to_save = self.__database._fix_incoming(to_save, self, add_meta)
|
||||
|
||||
if "_id" not in to_save:
|
||||
self._send_message(2002, bson.BSON.from_dict(to_save))
|
||||
elif to_save["_id"].is_new():
|
||||
to_save["_id"]._use()
|
||||
self._send_message(2002, bson.BSON.from_dict(to_save))
|
||||
else:
|
||||
self._update({"_id": to_save["_id"]}, to_save, True)
|
||||
|
||||
return to_save.get("_id", None)
|
||||
|
||||
def _update(self, spec, document, upsert=False):
|
||||
"""Update an object(s) in this collection.
|
||||
|
||||
Raises TypeError if either spec or document isn't an instance of
|
||||
(dict, SON) or upsert isn't an instance of bool.
|
||||
|
||||
- `spec`: a SON object specifying elements which must be present for a
|
||||
document to be updated
|
||||
- `document`: a SON object specifying the fields to be changed in the
|
||||
selected document(s), or (in the case of an upsert) the document to
|
||||
be inserted.
|
||||
- `upsert` (optional): perform an upsert operation
|
||||
"""
|
||||
if not isinstance(spec, (types.DictType, SON)):
|
||||
raise TypeError("spec must be an instance of (dict, SON)")
|
||||
if not isinstance(document, (types.DictType, SON)):
|
||||
raise TypeError("document must be an instance of (dict, SON)")
|
||||
if not isinstance(upsert, types.BooleanType):
|
||||
raise TypeError("upsert must be an instance of bool")
|
||||
|
||||
message = upsert and _ONE or _ZERO
|
||||
message += bson.BSON.from_dict(spec)
|
||||
message += bson.BSON.from_dict(document)
|
||||
|
||||
self._send_message(2001, message)
|
||||
|
||||
def remove(self, spec_or_object_id):
|
||||
"""Remove an object(s) from this collection.
|
||||
|
||||
Raises TypeEror if the argument is not an instance of
|
||||
(dict, SON, ObjectId).
|
||||
|
||||
Arguments:
|
||||
- `spec_or_object_id` (optional): a SON object specifying elements
|
||||
which must be present for a document to be removed OR an instance of
|
||||
ObjectId to be used as the value for an _id element
|
||||
"""
|
||||
spec = spec_or_object_id
|
||||
if isinstance(spec, ObjectId):
|
||||
spec = SON({"_id": spec})
|
||||
|
||||
if not isinstance(spec, (types.DictType, SON)):
|
||||
raise TypeError("spec must be an instance of (dict, SON)")
|
||||
|
||||
self._send_message(2006, _ZERO + bson.BSON.from_dict(spec))
|
||||
|
||||
def find_one(self, spec_or_object_id=SON()):
|
||||
"""Get a single object from the database.
|
||||
|
||||
Raises TypeError if the argument is of an improper type. Returns a
|
||||
single SON object, or None if no result is found.
|
||||
|
||||
Arguments:
|
||||
- `spec_or_object_id` (optional): a SON object specifying elements
|
||||
which must be present for a document to be included in the result
|
||||
set OR an instance of ObjectId to be used as the value for an _id
|
||||
query
|
||||
"""
|
||||
spec = spec_or_object_id
|
||||
if isinstance(spec, ObjectId):
|
||||
spec = SON({"_id": spec})
|
||||
|
||||
for result in self.find(spec, limit=1):
|
||||
return result
|
||||
return None
|
||||
|
||||
def find(self, spec=SON(), fields=[], skip=0, limit=0):
|
||||
"""Query the database.
|
||||
|
||||
Raises TypeError if any of the arguments are of improper type. Returns
|
||||
an instance of Cursor corresponding to this query.
|
||||
|
||||
Arguments:
|
||||
- `spec` (optional): a SON object specifying elements which must be
|
||||
present for a document to be included in the result set
|
||||
- `fields` (optional): a list of field names that should be returned
|
||||
in the result set
|
||||
- `skip` (optional): the number of documents to omit (from the start of
|
||||
the result set) when returning the results
|
||||
- `limit` (optional): the maximum number of results to return in the
|
||||
first reply message, or 0 for the default return size
|
||||
"""
|
||||
if not isinstance(spec, (types.DictType, SON)):
|
||||
raise TypeError("spec must be an instance of (dict, SON)")
|
||||
if not isinstance(fields, types.ListType):
|
||||
raise TypeError("fields must be an instance of list")
|
||||
if not isinstance(skip, types.IntType):
|
||||
raise TypeError("skip must be an instance of int")
|
||||
if not isinstance(limit, types.IntType):
|
||||
raise TypeError("limit must be an instance of int")
|
||||
|
||||
return_fields = len(fields) and SON() or None
|
||||
for field in fields:
|
||||
if not isinstance(field, types.StringTypes):
|
||||
raise TypeError("fields must be a list of key names as (string, unicode)")
|
||||
return_fields[field] = 1
|
||||
|
||||
return Cursor(self, spec, return_fields, skip, limit)
|
||||
|
||||
def _gen_index_name(self, keys):
|
||||
"""Generate an index name from the set of fields it is over.
|
||||
"""
|
||||
return u"_".join([u"%s_%s" % item for item in keys])
|
||||
|
||||
def create_index(self, key_or_list, direction=None):
|
||||
"""Creates an index on this collection.
|
||||
|
||||
Takes either a single key and a direction, or a list of (key, direction)
|
||||
pairs. The key(s) must be an instance of (str, unicode), and the
|
||||
direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING).
|
||||
|
||||
Arguments:
|
||||
- `key_or_list`: a single key or a list of (key, direction) pairs
|
||||
specifying the index to ensure
|
||||
- `direction` (optional): must be included if key_or_list is a single
|
||||
key, otherwise must be None
|
||||
"""
|
||||
if direction:
|
||||
keys = [(key_or_list, direction)]
|
||||
else:
|
||||
keys = key_or_list
|
||||
|
||||
if not isinstance(keys, types.ListType):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
if not len(keys):
|
||||
raise ValueError("key_or_list must not be the empty list")
|
||||
|
||||
to_save = SON()
|
||||
to_save["name"] = self._gen_index_name(keys)
|
||||
to_save["ns"] = self._full_name()
|
||||
|
||||
key_object = SON()
|
||||
for (key, value) in keys:
|
||||
if not isinstance(key, types.StringTypes):
|
||||
raise TypeError("first item in each key pair must be a string")
|
||||
if not isinstance(value, types.IntType):
|
||||
raise TypeError("second item in each key pair must be Mongo.ASCENDING or Mongo.DESCENDING")
|
||||
key_object[key] = value
|
||||
to_save["key"] = key_object
|
||||
|
||||
self.__database[SYSTEM_INDEX_COLLECTION].save(to_save, False)
|
||||
|
||||
def drop_indexes(self):
|
||||
"""Drops all indexes on this collection.
|
||||
|
||||
Can be used on non-existant collections or collections with no indexes.
|
||||
Raises OperationFailure on an error.
|
||||
"""
|
||||
response = self.__database.command(SON([("deleteIndexes", self.__collection_name),
|
||||
("index", u"*")]))
|
||||
if response["ok"] != 1:
|
||||
if response["errmsg"] == "ns not found":
|
||||
return
|
||||
raise OperationFailure("error ping indexes: %s" % response["errmsg"])
|
||||
215
cursor.py
Normal file
215
cursor.py
Normal file
@ -0,0 +1,215 @@
|
||||
"""Cursor class to iterate over Mongo query results."""
|
||||
|
||||
import types
|
||||
import struct
|
||||
|
||||
import bson
|
||||
from son import SON
|
||||
from errors import InvalidOperation, OperationFailure
|
||||
|
||||
class Cursor(object):
|
||||
"""A cursor / iterator over Mongo query results.
|
||||
"""
|
||||
def __init__(self, collection, spec, fields, skip, limit):
|
||||
"""Create a new cursor.
|
||||
|
||||
Should not be called directly by application developers.
|
||||
"""
|
||||
self.__collection = collection
|
||||
self.__spec = spec
|
||||
self.__fields = fields
|
||||
self.__skip = skip
|
||||
self.__limit = limit
|
||||
self.__ordering = None
|
||||
|
||||
self.__data = []
|
||||
self.__id = None
|
||||
self.__retrieved = 0
|
||||
self.__killed = False
|
||||
|
||||
def __del__(self):
|
||||
if self.__id and not self.__killed:
|
||||
self.__die()
|
||||
|
||||
def __die(self):
|
||||
"""Kills this cursor.
|
||||
"""
|
||||
self.__collection.database()._kill_cursor(self.__id)
|
||||
self.__killed = True
|
||||
|
||||
def __query_spec(self):
|
||||
"""Get the spec to use for a query.
|
||||
|
||||
Just `self.__spec`, unless this cursor needs special query fields, like
|
||||
orderby.
|
||||
"""
|
||||
if not self.__ordering:
|
||||
return self.__spec
|
||||
return SON([("query", self.__spec),
|
||||
("orderby", self.__ordering)])
|
||||
|
||||
def __check_okay_to_chain(self):
|
||||
"""Check if it is okay to chain more options onto this cursor.
|
||||
"""
|
||||
if self.__retrieved or self.__id is not None:
|
||||
raise InvalidOperation("cannot set options after executing query")
|
||||
|
||||
def limit(self, limit):
|
||||
"""Limits the number of results to be returned by this cursor.
|
||||
|
||||
Raises TypeError if limit is not an instance of int. Raises
|
||||
InvalidOperation if this cursor has already been used.
|
||||
|
||||
Arguments:
|
||||
- `limit`: the number of results to return
|
||||
"""
|
||||
if not isinstance(limit, types.IntType):
|
||||
raise TypeError("limit must be an int")
|
||||
self.__check_okay_to_chain()
|
||||
|
||||
self.__limit = limit
|
||||
return self
|
||||
|
||||
def skip(self, skip):
|
||||
"""Skips the first `skip` results of this cursor.
|
||||
|
||||
Raises TypeError if skip is not an instance of int. Raises
|
||||
InvalidOperation if this cursor has already been used.
|
||||
|
||||
Arguments:
|
||||
- `skip`: the number of results to skip
|
||||
"""
|
||||
if not isinstance(skip, types.IntType):
|
||||
raise TypeError("skip must be an int")
|
||||
self.__check_okay_to_chain()
|
||||
|
||||
self.__skip = skip
|
||||
return self
|
||||
|
||||
def sort(self, key_or_list, direction=None):
|
||||
"""Sorts this cursors results.
|
||||
|
||||
Takes either a single key and a direction, or a list of (key, direction)
|
||||
pairs. The key(s) must be an instance of (str, unicode), and the
|
||||
direction(s) must be one of (Mongo.ASCENDING, Mongo.DESCENDING). Raises
|
||||
InvalidOperation if this cursor has already been used.
|
||||
|
||||
Arguments:
|
||||
- `key_or_list`: a single key or a list of (key, direction) pairs
|
||||
specifying the keys to sort on
|
||||
- `direction` (optional): must be included if key_or_list is a single
|
||||
key, otherwise must be None
|
||||
"""
|
||||
self.__check_okay_to_chain()
|
||||
|
||||
# TODO a lot of this logic could be shared with create_index()
|
||||
if direction:
|
||||
keys = [(key_or_list, direction)]
|
||||
else:
|
||||
keys = key_or_list
|
||||
|
||||
if not isinstance(keys, types.ListType):
|
||||
raise TypeError("if no direction is specified, key_or_list must be an instance of list")
|
||||
if not len(keys):
|
||||
raise ValueError("key_or_list must not be the empty list")
|
||||
|
||||
orderby = SON()
|
||||
for (key, value) in keys:
|
||||
if not isinstance(key, types.StringTypes):
|
||||
raise TypeError("first item in each key pair must be a string")
|
||||
if not isinstance(value, types.IntType):
|
||||
raise TypeError("second item in each key pair must be Mongo.ASCENDING or Mongo.DESCENDING")
|
||||
orderby[key] = value
|
||||
|
||||
self.__ordering = orderby
|
||||
return self
|
||||
|
||||
def count(self):
|
||||
"""Get the size of the results set for this query.
|
||||
|
||||
Returns the number of objects in the results set for this query. Does
|
||||
not take limit and skip into account. Raises InvalidOperation if this
|
||||
cursor has already been used. Raises OperationFailure on a database
|
||||
error.
|
||||
"""
|
||||
self.__check_okay_to_chain()
|
||||
|
||||
command = SON([("count", self.__collection._name()),
|
||||
("query", self.__spec)])
|
||||
response = self.__collection.database().command(command)
|
||||
if response["ok"] != 1:
|
||||
if response["errmsg"] == "ns does not exist":
|
||||
return 0
|
||||
raise OperationFailure("error getting count: %s" % response["errmsg"])
|
||||
return int(response["n"])
|
||||
|
||||
def _refresh(self):
|
||||
"""Refreshes the cursor with more data from Mongo.
|
||||
|
||||
Returns the length of self.__data after refresh. Will exit early if
|
||||
self.__data is already non-empty. Raises OperationFailure when the
|
||||
cursor cannot be refreshed due to an error on the query.
|
||||
"""
|
||||
if len(self.__data) or self.__killed:
|
||||
return len(self.__data)
|
||||
|
||||
def send_message(operation, message):
|
||||
# TODO the send and receive here should be synchronized...
|
||||
request_id = self.__collection._send_message(operation, message)
|
||||
response = self.__collection.database().connection().receive_message(1, request_id)
|
||||
|
||||
response_flag = struct.unpack("<i", response[:4])[0]
|
||||
if response_flag == 1:
|
||||
raise OperationFailure("cursor id '%s' not valid at server" % self.__id)
|
||||
elif response_flag == 2:
|
||||
error_object = bson.BSON(response[20:]).to_dict()
|
||||
raise OperationFailure("database error: %s" % error_object["$err"])
|
||||
else:
|
||||
assert response_flag == 0
|
||||
|
||||
self.__id = struct.unpack("<q", response[4:12])[0]
|
||||
assert struct.unpack("<i", response[12:16])[0] == self.__retrieved
|
||||
|
||||
number_returned = struct.unpack("<i", response[16:20])[0]
|
||||
self.__retrieved += number_returned
|
||||
self.__data = bson.to_dicts(response[20:])
|
||||
assert len(self.__data) == number_returned
|
||||
|
||||
if self.__id is None:
|
||||
# Query
|
||||
message = struct.pack("<i", self.__skip)
|
||||
message += struct.pack("<i", self.__limit)
|
||||
message += bson.BSON.from_dict(self.__query_spec())
|
||||
if self.__fields:
|
||||
message += bson.BSON.from_dict(self.__fields)
|
||||
|
||||
send_message(2004, message)
|
||||
elif self.__id != 0:
|
||||
# Get More
|
||||
limit = 0
|
||||
if self.__limit:
|
||||
if self.__limit > self.__retrieved:
|
||||
limit = self.__limit - self.__retrieved
|
||||
else:
|
||||
self.__die()
|
||||
return 0
|
||||
|
||||
message = struct.pack("<i", limit)
|
||||
message += struct.pack("<q", self.__id)
|
||||
|
||||
send_message(2005, message)
|
||||
|
||||
length = len(self.__data)
|
||||
if not length:
|
||||
self.__die()
|
||||
return length
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def next(self):
|
||||
if len(self.__data):
|
||||
return self.__collection.database()._fix_outgoing(self.__data.pop(0))
|
||||
if self._refresh():
|
||||
return self.__collection.database()._fix_outgoing(self.__data.pop(0))
|
||||
raise StopIteration
|
||||
77
database.py
Normal file
77
database.py
Normal file
@ -0,0 +1,77 @@
|
||||
"""Database level operations."""
|
||||
|
||||
import types
|
||||
|
||||
from collection import Collection
|
||||
from errors import InvalidName
|
||||
|
||||
class Database(object):
|
||||
"""A Mongo database.
|
||||
"""
|
||||
def __init__(self, connection, name):
|
||||
"""Get a database by connection and name.
|
||||
|
||||
Raises TypeError if name is not an instance of (str, unicode). Raises
|
||||
InvalidName if name is not a valid database name.
|
||||
|
||||
Arguments:
|
||||
- `connection`: a connection to Mongo
|
||||
- `name`: database name
|
||||
"""
|
||||
if not isinstance(name, types.StringTypes):
|
||||
raise TypeError("name must be an instance of (str, unicode)")
|
||||
|
||||
self.__check_name(name)
|
||||
|
||||
self.__name = unicode(name)
|
||||
self.__connection = connection
|
||||
|
||||
def __check_name(self, name):
|
||||
for invalid_char in " .$/\\":
|
||||
if invalid_char in name:
|
||||
raise InvalidName("database names cannot contain the character %r" % name)
|
||||
if not name:
|
||||
raise InvalidName("database name cannot be the empty string")
|
||||
|
||||
def connection(self):
|
||||
"""Get the database connection.
|
||||
"""
|
||||
return self.__connection
|
||||
|
||||
def name(self):
|
||||
"""Get the database name.
|
||||
"""
|
||||
return self.__name
|
||||
|
||||
def __cmp__(self, other):
|
||||
if isinstance(other, Database):
|
||||
return cmp((self.__connection, self.__name), (other.__connection, other.__name))
|
||||
return NotImplemented
|
||||
|
||||
def __repr__(self):
|
||||
return "Database(%r, %r)" % (self.__connection, self.__name)
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""Get a collection of this database by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
|
||||
Arguments:
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
return Collection(self, name)
|
||||
|
||||
def __getitem__(self, name):
|
||||
"""Get a collection of this database by name.
|
||||
|
||||
Raises InvalidName if an invalid collection name is used.
|
||||
|
||||
Arguments:
|
||||
- `name`: the name of the collection to get
|
||||
"""
|
||||
return self.__getattr__(name)
|
||||
|
||||
def collection_names(self):
|
||||
"""Get a list of all the collection names in this database.
|
||||
"""
|
||||
raise Exception("unimplemented")
|
||||
36
test/test_collection.py
Normal file
36
test/test_collection.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Test the collection module."""
|
||||
import unittest
|
||||
|
||||
from collection import Collection
|
||||
from test_connection import get_connection
|
||||
from errors import InvalidName
|
||||
|
||||
class TestCollection(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = get_connection().test
|
||||
|
||||
def test_collection(self):
|
||||
self.assertRaises(TypeError, Collection, self.db, 5)
|
||||
|
||||
def make_col(base, name):
|
||||
return base[name]
|
||||
|
||||
self.assertRaises(InvalidName, make_col, self.db, "")
|
||||
self.assertRaises(InvalidName, make_col, self.db, "te$t")
|
||||
self.assertRaises(InvalidName, make_col, self.db, ".test")
|
||||
self.assertRaises(InvalidName, make_col, self.db, "test.")
|
||||
self.assertRaises(InvalidName, make_col, self.db, "tes..t")
|
||||
self.assertRaises(InvalidName, make_col, self.db.test, "")
|
||||
self.assertRaises(InvalidName, make_col, self.db.test, "te$t")
|
||||
self.assertRaises(InvalidName, make_col, self.db.test, ".test")
|
||||
self.assertRaises(InvalidName, make_col, self.db.test, "test.")
|
||||
self.assertRaises(InvalidName, make_col, self.db.test, "tes..t")
|
||||
|
||||
self.assertTrue(isinstance(self.db.test, Collection))
|
||||
self.assertEqual(self.db.test, self.db["test"])
|
||||
self.assertEqual(self.db.test, Collection(self.db, "test"))
|
||||
self.assertEqual(self.db.test.mike, self.db["test.mike"])
|
||||
self.assertEqual(self.db.test["mike"], self.db["test.mike"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
16
test/test_cursor.py
Normal file
16
test/test_cursor.py
Normal file
@ -0,0 +1,16 @@
|
||||
"""Test the cursor module."""
|
||||
import unittest
|
||||
|
||||
from cursor import Cursor
|
||||
from database import Database
|
||||
from test_connection import get_connection
|
||||
|
||||
class TestCursor(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.db = Database(get_connection(), "test")
|
||||
|
||||
# TODO are there any tests that belong here? a lot of cursor stuff is hard to
|
||||
# test seperately...
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
36
test/test_database.py
Normal file
36
test/test_database.py
Normal file
@ -0,0 +1,36 @@
|
||||
"""Test the database module."""
|
||||
|
||||
import unittest
|
||||
|
||||
from errors import InvalidName
|
||||
from database import Database
|
||||
from connection import Connection
|
||||
from collection import Collection
|
||||
from test_connection import get_connection
|
||||
|
||||
class TestDatabase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.connection = get_connection()
|
||||
|
||||
def test_name(self):
|
||||
self.assertRaises(TypeError, Database, self.connection, 4)
|
||||
self.assertRaises(InvalidName, Database, self.connection, "my db")
|
||||
self.assertEqual("name", Database(self.connection, "name").name())
|
||||
|
||||
def test_cmp(self):
|
||||
self.assertNotEqual(Database(self.connection, "test"), Database(self.connection, "mike"))
|
||||
self.assertEqual(Database(self.connection, "test"), Database(self.connection, "test"))
|
||||
|
||||
def test_repr(self):
|
||||
self.assertEqual(repr(Database(self.connection, "test")),
|
||||
"Database(%r, u'test')" % self.connection)
|
||||
|
||||
def test_get_coll(self):
|
||||
db = Database(self.connection, "test")
|
||||
self.assertEqual(db.test, db["test"])
|
||||
self.assertEqual(db.test, Collection(db, "test"))
|
||||
self.assertNotEqual(db.test, Collection(db, "mike"))
|
||||
self.assertEqual(db.test.mike, db["test.mike"])
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user