make gridfs thread safe

This commit is contained in:
Mike Dirolf 2009-07-06 18:12:49 -04:00
parent 76eb09977e
commit 5e968f0cad
3 changed files with 104 additions and 25 deletions

View File

@ -17,6 +17,7 @@
import types
import datetime
import math
from threading import Condition
from pymongo.son import SON
from pymongo.database import Database
@ -26,11 +27,19 @@ from pymongo.binary import Binary
from errors import CorruptGridFile
from pymongo import ASCENDING
# TODO we should use per-file reader-writer locks here instead,
# for performance. Unfortunately they aren't in the Python standard library.
# Condition guarding _open_files. A GridFile being opened waits on it while
# its id is still registered, and waiters are woken with notifyAll() once the
# id has been removed again.
_files_lock = Condition()
# Registry of currently open files: repr(file _id) -> True.
_open_files = {}
class GridFile(object):
"""A "file" stored in GridFS.
"""
# TODO should be able to create a GridFile given a Collection object instead
# of a database and collection name?
# TODO this whole file_spec thing is over-engineered. ought to be just
# filename.
def __init__(self, file_spec, database, mode="r", collection="fs"):
"""Open a "file" in GridFS.
@ -79,19 +88,30 @@ class GridFile(object):
raise ValueError("mode must be one of ('r', 'w')")
self.__collection = database[collection]
# TODO change this to ensure_index
self.__collection.chunks.create_index([("files_id", ASCENDING), ("n", ASCENDING)])
_files_lock.acquire()
grid_file = self.__collection.files.find_one(file_spec)
if grid_file:
self.__id = grid_file["_id"]
else:
if mode == "r":
_files_lock.release()
raise IOError("No such file: %r" % file_spec)
file_spec["length"] = 0
file_spec["uploadDate"] = datetime.datetime.utcnow()
file_spec.setdefault("chunkSize", 256000)
self.__id = self.__collection.files.insert(file_spec)["_id"]
# We use repr(self.__id) here because we need it to be a string, and
# filename gets tricky with renaming. This is a hack.
while repr(self.__id) in _open_files:
_files_lock.wait()
_open_files[repr(self.__id)] = True
_files_lock.release()
self.__mode = mode
if mode == "w":
self.__erase()
@ -207,6 +227,12 @@ class GridFile(object):
self.flush()
self.__closed = True
_files_lock.acquire()
if repr(self.__id) in _open_files:
del _open_files[repr(self.__id)]
_files_lock.notifyAll()
_files_lock.release()
def __assert_open(self, mode=None):
if mode and self.mode != mode:
raise ValueError("file must be open in mode %r" % mode)

View File

@ -126,24 +126,24 @@ class TestGridFile(unittest.TestCase):
self.assertRaises(TypeError, GridFile, {}, "hello")
self.assertRaises(TypeError, GridFile, {}, None)
self.assertRaises(TypeError, GridFile, {}, 5)
self.assert_(GridFile({}, self.db))
GridFile({}, self.db).close()
self.assertRaises(TypeError, GridFile, {}, self.db, None)
self.assertRaises(TypeError, GridFile, {}, self.db, 5)
self.assertRaises(TypeError, GridFile, {}, self.db, [])
self.assertRaises(ValueError, GridFile, {}, self.db, "m")
self.assertRaises(ValueError, GridFile, {}, self.db, u"m")
self.assert_(GridFile({}, self.db, "r"))
self.assert_(GridFile({}, self.db, u"r"))
self.assert_(GridFile({}, self.db, "w"))
self.assert_(GridFile({}, self.db, u"w"))
GridFile({}, self.db, "r").close()
GridFile({}, self.db, u"r").close()
GridFile({}, self.db, "w").close()
GridFile({}, self.db, u"w").close()
self.assertRaises(TypeError, GridFile, {}, self.db, "r", None)
self.assertRaises(TypeError, GridFile, {}, self.db, "r", 5)
self.assertRaises(TypeError, GridFile, {}, self.db, "r", [])
self.assertRaises(IOError, GridFile, {"filename": "mike"}, self.db)
self.assert_(GridFile({"filename": "test"}, self.db))
GridFile({"filename": "test"}, self.db).close()
def test_properties(self):
self.db.fs.files.remove({})
@ -199,6 +199,8 @@ class TestGridFile(unittest.TestCase):
a.md5 = "what"
self.assertRaises(AttributeError, set_md5)
a.close()
def test_rename(self):
self.db.fs.files.remove({})
self.db.fs.chunks.remove({})
@ -211,9 +213,10 @@ class TestGridFile(unittest.TestCase):
a.rename("mike")
self.assertEqual("mike", a.name)
a.close()
self.assertRaises(IOError, GridFile, {"filename": "test"}, self.db)
a = GridFile({"filename": "mike"}, self.db)
GridFile({"filename": "mike"}, self.db).close()
def test_flush_close(self):
self.db.fs.files.remove({})
@ -224,7 +227,10 @@ class TestGridFile(unittest.TestCase):
file.close()
file.close()
self.assertRaises(ValueError, file.write, "test")
self.assertEqual(GridFile({}, self.db).read(), "")
file = GridFile({}, self.db)
self.assertEqual(file.read(), "")
file.close()
file = GridFile({"filename": "test"}, self.db, "w")
file.write("mike")
@ -237,7 +243,9 @@ class TestGridFile(unittest.TestCase):
file.close()
file.close()
self.assertRaises(ValueError, file.write, "test")
self.assertEqual(GridFile({}, self.db).read(), "miketesthuh")
file = GridFile({}, self.db)
self.assertEqual(file.read(), "miketesthuh")
file.close()
def test_overwrite(self):
self.db.fs.files.remove({})
@ -251,7 +259,9 @@ class TestGridFile(unittest.TestCase):
file.write("mike")
file.close()
self.assertEqual(GridFile({}, self.db).read(), "mike")
f = GridFile({}, self.db)
self.assertEqual(f.read(), "mike")
f.close()
def test_multi_chunk_file(self):
self.db.fs.files.remove({})
@ -266,7 +276,9 @@ class TestGridFile(unittest.TestCase):
self.assertEqual(self.db.fs.files.find().count(), 1)
self.assertEqual(self.db.fs.chunks.find().count(), 2)
self.assertEqual(GridFile({}, self.db).read(), random_string)
f = GridFile({}, self.db)
self.assertEqual(f.read(), random_string)
f.close()
def test_small_chunks(self):
self.db.fs.files.remove({})
@ -288,11 +300,13 @@ class TestGridFile(unittest.TestCase):
self.assertEqual(self.db.fs.files.find().count(), self.files)
self.assertEqual(self.db.fs.chunks.find().count(), self.chunks)
self.assertEqual(GridFile({"filename": filename}, self.db).read(),
data)
f = GridFile({"filename": filename}, self.db)
self.assertEqual(f.read(), data)
f.close()
f = GridFile({"filename": filename}, self.db)
self.assertEqual(f.read(10) + f.read(10), data)
f.close()
return True
qcheck.check_unittest(self, helper,
@ -332,6 +346,7 @@ class TestGridFile(unittest.TestCase):
self.assertEqual(file.read(2), "rl")
self.assertEqual(file.read(2), "d")
self.assertEqual(file.read(2), "")
file.close()
if __name__ == "__main__":
unittest.main()

View File

@ -31,12 +31,25 @@ class JustWrite(threading.Thread):
self.fs = fs
def run(self):
for _ in range(100):
for _ in range(10):
file = self.fs.open("test", "w")
file.write("hello")
file.close()
class JustRead(threading.Thread):
    """Worker thread that repeatedly opens the file "test" and checks its data.

    `fs` is any object whose ``open(name)`` returns a file-like object
    providing ``read()`` and ``close()``.
    """

    def __init__(self, fs):
        threading.Thread.__init__(self)
        self.fs = fs

    def run(self):
        """Open "test" ten times, asserting that every read returns "hello".

        Raises AssertionError if a read returns anything else.
        """
        for _ in range(10):
            f = self.fs.open("test")
            try:
                assert f.read() == "hello"
            finally:
                # Always close, even when the check fails: GridFile keeps an
                # open file registered until it is closed, and anything else
                # opening the same file waits for that (presumably fs.open
                # hands back a GridFile -- confirm against the gridfs module).
                f.close()
class TestGridfs(unittest.TestCase):
def setUp(self):
@ -94,8 +107,12 @@ class TestGridfs(unittest.TestCase):
self.assertEqual(["mike", "hello world"], self.fs.list())
self.assertEqual(self.db.fs.files.find().count(), 2)
self.assertEqual(self.db.fs.chunks.find().count(), 2)
self.assertEqual(self.fs.open("mike").read(), "hi")
self.assertEqual(self.fs.open("hello world").read(), "fly")
f = self.fs.open("mike")
self.assertEqual(f.read(), "hi")
f.close()
f = self.fs.open("hello world")
self.assertEqual(f.read(), "fly")
f.close()
self.assertRaises(IOError, self.fs.open, "test")
self.fs.remove({})
@ -148,12 +165,12 @@ class TestGridfs(unittest.TestCase):
self.fs.remove("test", "pymongo_test")
self.assertEqual(["mike", "hello world"], self.fs.list("pymongo_test"))
self.assertEqual(self.fs.open("mike",
collection="pymongo_test").read(),
"hi")
self.assertEqual(self.fs.open("hello world",
collection="pymongo_test").read(),
"fly")
f = self.fs.open("mike", collection="pymongo_test")
self.assertEqual(f.read(), "hi")
f.close()
f = self.fs.open("hello world", collection="pymongo_test")
self.assertEqual(f.read(), "fly")
f.close()
self.fs.remove({}, "pymongo_test")
@ -161,10 +178,31 @@ class TestGridfs(unittest.TestCase):
self.assertEqual(self.db.pymongo_test.files.find().count(), 0)
self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0)
def test_threading(self):
def test_threaded_reads(self):
f = self.fs.open("test", "w")
f.write("hello")
f.close()
threads = []
for i in range(10):
t = JustWrite(self.fs)
t.start()
threads.append(JustRead(self.fs))
threads[i].start()
for i in range(10):
threads[i].join()
def test_threaded_writes(self):
    # Start ten JustWrite threads against the same fs, wait for all of them
    # to finish, then check that "test" reads back as "hello".
    writers = []
    for _ in range(10):
        worker = JustWrite(self.fs)
        writers.append(worker)
        worker.start()

    for worker in writers:
        worker.join()

    handle = self.fs.open("test")
    self.assertEqual(handle.read(), "hello")
    handle.close()
if __name__ == "__main__":
unittest.main()