diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py
index 0f6eaacac..6156d18b7 100644
--- a/gridfs/grid_file.py
+++ b/gridfs/grid_file.py
@@ -16,6 +16,7 @@
 
 import types
 import datetime
+import math
 
 from pymongo.son import SON
 from pymongo.database import Database
@@ -91,6 +92,11 @@ class GridFile(object):
         if mode == "w":
             self.__erase()
             self.__current_chunk = None
+        elif mode == "r":
+            # Read cursor: byte offset, next chunk to fetch, unread leftovers.
+            self.__position = 0
+            self.__current_chunk = file["next"]
+            self.__read_buffer = ""
         self.__closed = False
 
     def __erase(self):
@@ -199,18 +205,24 @@ class GridFile(object):
         """
         self.__assert_open("r")
 
-        if size < 0 or size > self.length:
-            size = self.length
+        if size == 0:
+            return ""
 
-        bytes = ""
-        next = self.next
-        chunk_number = 0
+        remainder = self.length - self.__position
+        if size < 0 or size > remainder:
+            size = remainder
+
+        bytes = self.__read_buffer
+        # Bytes still in the read buffer were already fetched from earlier
+        # chunks, so count them when working out which chunk comes next.
+        fetched = self.__position + len(bytes)
+        chunk_number = int(math.floor(fetched / self.chunk_size))
         while len(bytes) < size:
-            if not next:
+            if not self.__current_chunk:
                 raise CorruptGridFile("incorrect length for file: %r" % self)
-            chunk = self.__collection.database().dereference(next)
+            chunk = self.__collection.database().dereference(self.__current_chunk)
             if not chunk:
-                raise CorruptGridFile("could not dereference: %r" % next)
+                raise CorruptGridFile("could not dereference: %r" % self.__current_chunk)
             if chunk["cn"] != chunk_number:
                 raise CorruptGridFile("incorrect chunk number: %r, should be: %r" %
                                       (chunk["cn"], chunk_number))
@@ -218,10 +230,13 @@ class GridFile(object):
 
             bytes += chunk["data"]
             chunk_number += 1
-            next = chunk["next"]
+            self.__current_chunk = chunk["next"]
 
-        bytes = bytes[:size]
-        return bytes
+        self.__position += size
+        to_return = bytes[:size]
+        # Keep anything fetched beyond `size` for the next read() call.
+        self.__read_buffer = bytes[size:]
+        return to_return
 
     # TODO should support writing unicode to a file. this means that files will
     # need to have an encoding attribute.
diff --git a/test/test_grid_file.py b/test/test_grid_file.py
index d95b8632a..03217b32a 100644
--- a/test/test_grid_file.py
+++ b/test/test_grid_file.py
@@ -233,6 +233,10 @@ class TestGridFile(unittest.TestCase):
             self.assertEqual(self.db._chunks.find().count(), self.chunks)
 
             self.assertEqual(GridFile({"filename": filename}, self.db).read(), data)
+
+            reader = GridFile({"filename": filename}, self.db)
+            first_part = reader.read(10)
+            self.assertEqual(first_part + reader.read(10), data)
             return True
 
         qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20)))
@@ -255,5 +259,18 @@ class TestGridFile(unittest.TestCase):
         self.assertRaises(ValueError, file.read)
         self.assertRaises(ValueError, file.write, "hello")
 
+    def test_multiple_reads(self):
+        # Read a small file two bytes at a time until it is exhausted.
+        self.db._files.remove({})
+        self.db._chunks.remove({})
+
+        writer = GridFile({"filename": "test"}, self.db, "w")
+        writer.write("hello world")
+        writer.close()
+
+        reader = GridFile({"filename": "test"}, self.db, "r")
+        for expected in ("he", "ll", "o ", "wo", "rl", "d", ""):
+            self.assertEqual(reader.read(2), expected)
+
 if __name__ == "__main__":
     unittest.main()