This commit is contained in:
Mike Dirolf 2009-01-14 15:51:07 -05:00
parent 5ae22d7292
commit fa8e276509

View File

@ -18,6 +18,10 @@ class DatabaseException(Exception):
"""Raised when a database operation fails.
"""
class InvalidOperation(Exception):
"""Raised when a client attempts to perform an invalid operation.
"""
class ConnectionException(IOError):
"""Raised when a connection to the database cannot be made or is lost.
"""
@ -446,6 +450,28 @@ class Cursor(object):
self.__collection.database()._kill_cursor(self.__id)
self.__killed = True
def __check_okay_to_chain(self):
"""Check if it is okay to chain more options onto this cursor.
"""
if self.__retrieved or self.__id is not None:
raise InvalidOperation("cannot set options after executing query")
def limit(self, limit):
"""Limits the number of results to be returned by this cursor.
Raises TypeError if limit is not an instance of int. Raises
InvalidOperation if this cursor has already been used.
Arguments:
- `limit`: the number of results to return
"""
if not isinstance(limit, types.IntType):
raise TypeError("limit must be an int")
self.__check_okay_to_chain()
self.__limit = limit
return self
def _refresh(self):
"""Refreshes the cursor with more data from Mongo.
@ -488,6 +514,7 @@ class Cursor(object):
limit = self.__limit - self.__retrieved
else:
self._die()
return 0
message = struct.pack("<i", limit)
message += struct.pack("<q", self.__id)
@ -687,6 +714,55 @@ class TestMongo(unittest.TestCase):
(u"key", SON([(u"hello", -1),
(u"world", 1)]))]))
def test_limit(self):
db = Mongo("test", self.host, self.port)
self.assertRaises(TypeError, db.test.find().limit, None)
self.assertRaises(TypeError, db.test.find().limit, "hello")
self.assertRaises(TypeError, db.test.find().limit, 5.5)
db.test.remove({})
for i in range(100):
db.test.save({"x": i})
count = 0
for _ in db.test.find():
count += 1
self.assertEqual(count, 100)
count = 0
for _ in db.test.find().limit(20):
count += 1
self.assertEqual(count, 20)
count = 0
for _ in db.test.find().limit(99):
count += 1
self.assertEqual(count, 99)
count = 0
for _ in db.test.find().limit(1):
count += 1
self.assertEqual(count, 1)
count = 0
for _ in db.test.find().limit(0):
count += 1
self.assertEqual(count, 100)
count = 0
for _ in db.test.find().limit(0).limit(50).limit(10):
count += 1
self.assertEqual(count, 10)
a = db.test.find()
a.limit(10)
for _ in a:
break
self.assertRaises(InvalidOperation, a.limit, 5)
if __name__ == "__main__":
unittest.main()