diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index ad25b9bb3..0948fca72 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -17,6 +17,7 @@ import types import datetime import math +from threading import Condition from pymongo.son import SON from pymongo.database import Database @@ -26,11 +27,19 @@ from pymongo.binary import Binary from errors import CorruptGridFile from pymongo import ASCENDING +# TODO we should use per-file reader-writer locks here instead, +# for performance. Unfortunately they aren't in the Python standard library. +_files_lock = Condition() +_open_files = {} + + class GridFile(object): """A "file" stored in GridFS. """ # TODO should be able to create a GridFile given a Collection object instead # of a database and collection name? + # TODO this whole file_spec thing is over-engineered. ought to be just + # filename. def __init__(self, file_spec, database, mode="r", collection="fs"): """Open a "file" in GridFS. @@ -79,19 +88,30 @@ class GridFile(object): raise ValueError("mode must be one of ('r', 'w')") self.__collection = database[collection] + # TODO change this to ensure_index self.__collection.chunks.create_index([("files_id", ASCENDING), ("n", ASCENDING)]) + _files_lock.acquire() + grid_file = self.__collection.files.find_one(file_spec) if grid_file: self.__id = grid_file["_id"] else: if mode == "r": + _files_lock.release() raise IOError("No such file: %r" % file_spec) file_spec["length"] = 0 file_spec["uploadDate"] = datetime.datetime.utcnow() file_spec.setdefault("chunkSize", 256000) self.__id = self.__collection.files.insert(file_spec)["_id"] + # we use repr(self.__id) here because we need it to be string and + # filename gets tricky with renaming. this is a hack. + while repr(self.__id) in _open_files: + _files_lock.wait() + _open_files[repr(self.__id)] = True + _files_lock.release() + self.__mode = mode if mode == "w": self.__erase() @@ -207,6 +227,12 @@ class GridFile(object): self.flush() self.__closed = True + _files_lock.acquire() + if repr(self.__id) in _open_files: + del _open_files[repr(self.__id)] + _files_lock.notifyAll() + _files_lock.release() + def __assert_open(self, mode=None): if mode and self.mode != mode: raise ValueError("file must be open in mode %r" % mode) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 0b38d7424..4824bf581 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -126,24 +126,24 @@ class TestGridFile(unittest.TestCase): self.assertRaises(TypeError, GridFile, {}, "hello") self.assertRaises(TypeError, GridFile, {}, None) self.assertRaises(TypeError, GridFile, {}, 5) - self.assert_(GridFile({}, self.db)) + GridFile({}, self.db).close() self.assertRaises(TypeError, GridFile, {}, self.db, None) self.assertRaises(TypeError, GridFile, {}, self.db, 5) self.assertRaises(TypeError, GridFile, {}, self.db, []) self.assertRaises(ValueError, GridFile, {}, self.db, "m") self.assertRaises(ValueError, GridFile, {}, self.db, u"m") - self.assert_(GridFile({}, self.db, "r")) - self.assert_(GridFile({}, self.db, u"r")) - self.assert_(GridFile({}, self.db, "w")) - self.assert_(GridFile({}, self.db, u"w")) + GridFile({}, self.db, "r").close() + GridFile({}, self.db, u"r").close() + GridFile({}, self.db, "w").close() + GridFile({}, self.db, u"w").close() self.assertRaises(TypeError, GridFile, {}, self.db, "r", None) self.assertRaises(TypeError, GridFile, {}, self.db, "r", 5) self.assertRaises(TypeError, GridFile, {}, self.db, "r", []) self.assertRaises(IOError, GridFile, {"filename": "mike"}, self.db) - self.assert_(GridFile({"filename": "test"}, self.db)) + GridFile({"filename": "test"}, self.db).close() def test_properties(self): self.db.fs.files.remove({}) @@ -199,6 +199,8 @@ class TestGridFile(unittest.TestCase): a.md5 = "what" self.assertRaises(AttributeError, set_md5) + a.close() + def test_rename(self): self.db.fs.files.remove({}) self.db.fs.chunks.remove({}) @@ -211,9 +213,10 @@ class TestGridFile(unittest.TestCase): a.rename("mike") self.assertEqual("mike", a.name) + a.close() self.assertRaises(IOError, GridFile, {"filename": "test"}, self.db) - a = GridFile({"filename": "mike"}, self.db) + GridFile({"filename": "mike"}, self.db).close() def test_flush_close(self): self.db.fs.files.remove({}) @@ -224,7 +227,10 @@ class TestGridFile(unittest.TestCase): file.close() file.close() self.assertRaises(ValueError, file.write, "test") - self.assertEqual(GridFile({}, self.db).read(), "") + + file = GridFile({}, self.db) + self.assertEqual(file.read(), "") + file.close() file = GridFile({"filename": "test"}, self.db, "w") file.write("mike") @@ -237,7 +243,9 @@ class TestGridFile(unittest.TestCase): file.close() file.close() self.assertRaises(ValueError, file.write, "test") - self.assertEqual(GridFile({}, self.db).read(), "miketesthuh") + file = GridFile({}, self.db) + self.assertEqual(file.read(), "miketesthuh") + file.close() def test_overwrite(self): self.db.fs.files.remove({}) @@ -251,7 +259,9 @@ class TestGridFile(unittest.TestCase): file.write("mike") file.close() - self.assertEqual(GridFile({}, self.db).read(), "mike") + f = GridFile({}, self.db) + self.assertEqual(f.read(), "mike") + f.close() def test_multi_chunk_file(self): self.db.fs.files.remove({}) @@ -266,7 +276,9 @@ class TestGridFile(unittest.TestCase): self.assertEqual(self.db.fs.files.find().count(), 1) self.assertEqual(self.db.fs.chunks.find().count(), 2) - self.assertEqual(GridFile({}, self.db).read(), random_string) + f = GridFile({}, self.db) + self.assertEqual(f.read(), random_string) + f.close() def test_small_chunks(self): self.db.fs.files.remove({}) @@ -288,11 +300,13 @@ class TestGridFile(unittest.TestCase): self.assertEqual(self.db.fs.files.find().count(), self.files) self.assertEqual(self.db.fs.chunks.find().count(), self.chunks) - self.assertEqual(GridFile({"filename": filename}, self.db).read(), - data) + f = GridFile({"filename": filename}, self.db) + self.assertEqual(f.read(), data) + f.close() f = GridFile({"filename": filename}, self.db) self.assertEqual(f.read(10) + f.read(10), data) + f.close() return True qcheck.check_unittest(self, helper, @@ -332,6 +346,7 @@ class TestGridFile(unittest.TestCase): self.assertEqual(file.read(2), "rl") self.assertEqual(file.read(2), "d") self.assertEqual(file.read(2), "") + file.close() if __name__ == "__main__": unittest.main() diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 796cb59b6..c0f1e482e 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -31,12 +31,25 @@ class JustWrite(threading.Thread): self.fs = fs def run(self): - for _ in range(100): + for _ in range(10): file = self.fs.open("test", "w") file.write("hello") file.close() +class JustRead(threading.Thread): + + def __init__(self, fs): + threading.Thread.__init__(self) + self.fs = fs + + def run(self): + for _ in range(10): + file = self.fs.open("test") + assert file.read() == "hello" + file.close() + + class TestGridfs(unittest.TestCase): def setUp(self): @@ -94,8 +107,12 @@ class TestGridfs(unittest.TestCase): self.assertEqual(["mike", "hello world"], self.fs.list()) self.assertEqual(self.db.fs.files.find().count(), 2) self.assertEqual(self.db.fs.chunks.find().count(), 2) - self.assertEqual(self.fs.open("mike").read(), "hi") - self.assertEqual(self.fs.open("hello world").read(), "fly") + f = self.fs.open("mike") + self.assertEqual(f.read(), "hi") + f.close() + f = self.fs.open("hello world") + self.assertEqual(f.read(), "fly") + f.close() self.assertRaises(IOError, self.fs.open, "test") self.fs.remove({}) @@ -148,12 +165,12 @@ class TestGridfs(unittest.TestCase): self.fs.remove("test", "pymongo_test") self.assertEqual(["mike", "hello world"], self.fs.list("pymongo_test")) - self.assertEqual(self.fs.open("mike", - collection="pymongo_test").read(), - "hi") - self.assertEqual(self.fs.open("hello world", - collection="pymongo_test").read(), - "fly") + f = self.fs.open("mike", collection="pymongo_test") + self.assertEqual(f.read(), "hi") + f.close() + f = self.fs.open("hello world", collection="pymongo_test") + self.assertEqual(f.read(), "fly") + f.close() self.fs.remove({}, "pymongo_test") @@ -161,10 +178,31 @@ class TestGridfs(unittest.TestCase): self.assertEqual(self.db.pymongo_test.files.find().count(), 0) self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0) - def test_threading(self): + def test_threaded_reads(self): + f = self.fs.open("test", "w") + f.write("hello") + f.close() + + threads = [] for i in range(10): - t = JustWrite(self.fs) - t.start() + threads.append(JustRead(self.fs)) + threads[i].start() + + for i in range(10): + threads[i].join() + + def test_threaded_writes(self): + threads = [] + for i in range(10): + threads.append(JustWrite(self.fs)) + threads[i].start() + + for i in range(10): + threads[i].join() + + f = self.fs.open("test") + self.assertEqual(f.read(), "hello") + f.close() if __name__ == "__main__": unittest.main()