basic support for find and find_one. Cursor. database connections must specify database name.

This commit is contained in:
Mike Dirolf 2009-01-12 15:45:29 -05:00
parent 5da5102100
commit 7bdc7646c6

213
mongo.py
View File

@ -25,29 +25,36 @@ class InvalidCollection(ValueError):
class Mongo(object):
"""A connection to a Mongo database.
"""
def __init__(self, host="localhost", port=27017):
def __init__(self, name, host="localhost", port=27017):
"""Open a new connection to the database at host:port.
Raises TypeError if host is not an instance of string or port is
Raises TypeError if name or host is not an instance of string or port is
not an instance of int. Raises ConnectionException if the connection
cannot be made.
Arguments:
- `name`: the name of the database to connect to
- `host` (optional): the hostname or IPv4 address of the database to
connect to
connect to
- `port` (optional): the port number on which to connect
"""
if not isinstance(name, types.StringTypes):
raise TypeError("name must be an instance of (str, unicode)")
if not isinstance(host, types.StringType):
raise TypeError("host must be an instance of str")
if not isinstance(port, types.IntType):
raise TypeError("port must be an instance of int")
self.__name = name
self.__host = host
self.__port = port
self.__id = 1
self.__connect()
def name(self):
return self.__name
def __connect(self):
"""(Re-)connect to the database."""
try:
@ -58,7 +65,7 @@ class Mongo(object):
(self.__host, self.__port, traceback.format_exc()))
def _send_message(self, operation, data):
"""Say something to the database. Return the response.
"""Say something to the database.
Arguments:
- `operation`: the opcode of the message
@ -82,13 +89,41 @@ class Mongo(object):
raise ConnectionException("connection closed")
total_sent += sent
return self.__id - 1
def _receive_message(self, operation, request_id):
"""Receive a message from the database.
Returns the message body. Asserts that the message uses the given opcode
and request id.
Arguments:
- `operation`: the opcode of the message
- `request_id`: the request id that the message should be in response to
"""
def receive(length):
message = ""
while len(message) < length:
chunk = self.__connection.recv(length - len(message))
if chunk == "":
raise ConnectionException("connection closed")
message += chunk
return message
header = receive(16)
length = struct.unpack("<i", header[:4])[0]
assert request_id == struct.unpack("<i", header[8:12])[0]
assert operation == struct.unpack("<i", header[12:])[0]
return receive(length - 16)
def __cmp__(self, other):
if isinstance(other, Mongo):
return cmp((self.__host, self.__port), (other.__host, other.__port))
return NotImplemented
def __repr__(self):
return "Mongo(" + repr(self.__host) + ", " + repr(self.__port) + ")"
return "Mongo(%r, %r, %r)" % (self.__name, self.__host, self.__port)
def __getattr__(self, name):
"""Get a collection of this database by name.
@ -153,6 +188,19 @@ class Collection(object):
return cmp((self.__database, self.__collection_name), (other.__database, other.__collection_name))
return NotImplemented
def _full_name(self):
return "%s.%s" % (self.__database.name(), self.__collection_name)
def _send_message(self, operation, data):
"""Wrap up a message and send it.
"""
# reserved int, full collection name, message data
message = "\x00\x00\x00\x00%s%s" % (bson._make_c_string(self._full_name()), data)
return self.__database._send_message(operation, message)
def database(self):
return self.__database
def save(self, to_save):
"""Save a SON object in this collection.
@ -171,13 +219,128 @@ class Collection(object):
# TODO possibly add _ns?
# TODO support for named databases instead of just prepending test.
message_data = "\x00\x00\x00\x00%s%s" % (bson._make_c_string("test." + self.__collection_name), bson.BSON.from_dict(to_save))
self._send_message(2002, bson.BSON.from_dict(to_save))
self.__database._send_message(2002, message_data)
return to_save["_id"]
def find_one(self, spec):
pass
def find_one(self, spec_or_object_id=SON()):
"""Get a single object from the database.
Raises TypeError if the argument is of an improper type. Returns a
single SON object, or None if no result is found.
Arguments:
- `spec_or_object_id` (optional): a SON object specifying elements
which must be present for a document to be included in the result
set OR an instance of ObjectId to be used as the value for an _id
query
"""
spec = spec_or_object_id
if isinstance(spec, ObjectId):
spec = SON({"_id": spec})
for result in self.find(spec, limit=1):
return result
return None
def find(self, spec=SON(), fields=[], skip=0, limit=0):
"""Query the database.
Raises TypeError if any of the arguments are of improper type. Returns
an instance of Cursor corresponding to this query.
Arguments:
- `spec` (optional): a SON object specifying elements which must be
present for a document to be included in the result set
- `fields` (optional): a list of field names that should be returned
in the result set
- `skip` (optional): the number of documents to omit (from the start of
the result set) when returning the results
- `limit` (optional): the maximum number of results to return in the
first reply message, or 0 for the default return size
"""
if not isinstance(spec, (types.DictType, SON)):
raise TypeError("spec must be an instance of (dict, SON)")
if not isinstance(fields, types.ListType):
raise TypeError("fields must be an instance of list")
if not isinstance(skip, types.IntType):
raise TypeError("skip must be an instance of int")
if not isinstance(limit, types.IntType):
raise TypeError("limit must be an instance of int")
return_fields = len(fields) and SON() or None
for field in fields:
if not isinstance(field, types.StringTypes):
raise TypeError("fields must be a list of key names as (string, unicode)")
return_fields[field] = 1
return Cursor(self, spec, return_fields, skip, limit)
class Cursor(object):
"""A cursor / iterator over Mongo query results.
"""
def __init__(self, collection, spec, fields, skip, limit):
"""Create a new cursor.
Generally not needed to be used by application developers.
"""
self.__collection = collection
self.__spec = spec
self.__fields = fields
self.__skip = skip
self.__limit = limit
self.__data = []
self.__id = None
def _refresh(self):
"""Refreshes the cursor with more data from Mongo.
Returns the length of self.__data after refresh. Will exit early if
self.__data is already non-empty.
"""
if len(self.__data):
return len(self.__data)
if self.__id is None:
# Initial query
message = ""
message += struct.pack("<i", self.__skip)
message += struct.pack("<i", self.__limit)
message += bson.BSON.from_dict(self.__spec)
if self.__fields:
message += bson.BSON.from_dict(self.__fields)
# TODO the send and receive here should be synchronized...
request_id = self.__collection._send_message(2004, message)
response = self.__collection.database()._receive_message(1, request_id)
# TODO handle non-zero response flags
assert struct.unpack("<i", response[:4])[0] == 0
self.__id = struct.unpack("<q", response[4:12])[0]
# starting from
assert struct.unpack("<i", response[12:16])[0] == self.__skip
number_returned = struct.unpack("<i", response[16:20])[0]
self.__data = bson.to_dicts(response[20:])
assert len(self.__data) == number_returned
else:
raise Exception("unimplemented...")
return len(self.__data)
def __iter__(self):
return self
def next(self):
if len(self.__data):
return self.__data.pop(0)
if self._refresh():
return self.__data.pop(0)
raise StopIteration
class TestMongo(unittest.TestCase):
def setUp(self):
@ -189,24 +352,28 @@ class TestMongo(unittest.TestCase):
self.assertRaises(TypeError, Mongo, 1.14)
self.assertRaises(TypeError, Mongo, None)
self.assertRaises(TypeError, Mongo, [])
self.assertRaises(TypeError, Mongo, "localhost", "27017")
self.assertRaises(TypeError, Mongo, "localhost", 1.14)
self.assertRaises(TypeError, Mongo, "localhost", None)
self.assertRaises(TypeError, Mongo, "localhost", [])
self.assertRaises(TypeError, Mongo, "test", 1)
self.assertRaises(TypeError, Mongo, "test", 1.14)
self.assertRaises(TypeError, Mongo, "test", None)
self.assertRaises(TypeError, Mongo, "test", [])
self.assertRaises(TypeError, Mongo, "test", "localhost", "27017")
self.assertRaises(TypeError, Mongo, "test", "localhost", 1.14)
self.assertRaises(TypeError, Mongo, "test", "localhost", None)
self.assertRaises(TypeError, Mongo, "test", "localhost", [])
self.assertRaises(ConnectionException, Mongo, "somedomainthatdoesntexist.org")
self.assertRaises(ConnectionException, Mongo, self.host, 123456789)
self.assertRaises(ConnectionException, Mongo, "test", "somedomainthatdoesntexist.org")
self.assertRaises(ConnectionException, Mongo, "test", self.host, 123456789)
self.assertTrue(Mongo(self.host, self.port))
self.assertTrue(Mongo("test", self.host, self.port))
def test_repr(self):
self.assertEqual(repr(Mongo(self.host, self.port)),
"Mongo('%s', %s)" % (self.host, self.port))
self.assertEqual(repr(Mongo(self.host, self.port).test),
"Collection(Mongo('%s', %s), 'test')" % (self.host, self.port))
self.assertEqual(repr(Mongo("test", self.host, self.port)),
"Mongo('test', '%s', %s)" % (self.host, self.port))
self.assertEqual(repr(Mongo("test", self.host, self.port).test),
"Collection(Mongo('test', '%s', %s), 'test')" % (self.host, self.port))
def test_collection(self):
db = Mongo(self.host, self.port)
db = Mongo(u"test", self.host, self.port)
self.assertRaises(TypeError, Collection, db, 5)
self.assertRaises(TypeError, Collection, 5, "test")
@ -232,7 +399,7 @@ class TestMongo(unittest.TestCase):
self.assertEqual(db.test["mike"], db["test.mike"])
def test_save_find_one(self):
db = Mongo(self.host, self.port)
db = Mongo("test", self.host, self.port)
a_doc = SON({"hello": u"world"})
db.test.save(a_doc)