make gridfs thread safe
This commit is contained in:
parent
76eb09977e
commit
5e968f0cad
@ -17,6 +17,7 @@
|
||||
import types
|
||||
import datetime
|
||||
import math
|
||||
from threading import Condition
|
||||
|
||||
from pymongo.son import SON
|
||||
from pymongo.database import Database
|
||||
@ -26,11 +27,19 @@ from pymongo.binary import Binary
|
||||
from errors import CorruptGridFile
|
||||
from pymongo import ASCENDING
|
||||
|
||||
# TODO we should use per-file reader-writer locks here instead,
|
||||
# for performance. Unfortunately they aren't in the Python standard library.
|
||||
_files_lock = Condition()
|
||||
_open_files = {}
|
||||
|
||||
|
||||
class GridFile(object):
|
||||
"""A "file" stored in GridFS.
|
||||
"""
|
||||
# TODO should be able to create a GridFile given a Collection object instead
|
||||
# of a database and collection name?
|
||||
# TODO this whole file_spec thing is over-engineered. ought to be just
|
||||
# filename.
|
||||
def __init__(self, file_spec, database, mode="r", collection="fs"):
|
||||
"""Open a "file" in GridFS.
|
||||
|
||||
@ -79,19 +88,30 @@ class GridFile(object):
|
||||
raise ValueError("mode must be one of ('r', 'w')")
|
||||
|
||||
self.__collection = database[collection]
|
||||
# TODO change this to ensure_index
|
||||
self.__collection.chunks.create_index([("files_id", ASCENDING), ("n", ASCENDING)])
|
||||
|
||||
_files_lock.acquire()
|
||||
|
||||
grid_file = self.__collection.files.find_one(file_spec)
|
||||
if grid_file:
|
||||
self.__id = grid_file["_id"]
|
||||
else:
|
||||
if mode == "r":
|
||||
_files_lock.release()
|
||||
raise IOError("No such file: %r" % file_spec)
|
||||
file_spec["length"] = 0
|
||||
file_spec["uploadDate"] = datetime.datetime.utcnow()
|
||||
file_spec.setdefault("chunkSize", 256000)
|
||||
self.__id = self.__collection.files.insert(file_spec)["_id"]
|
||||
|
||||
# we use repr(self.__id) here because we need it to be string and
|
||||
# filename gets tricky with renaming. this is a hack.
|
||||
while repr(self.__id) in _open_files:
|
||||
_files_lock.wait()
|
||||
_open_files[repr(self.__id)] = True
|
||||
_files_lock.release()
|
||||
|
||||
self.__mode = mode
|
||||
if mode == "w":
|
||||
self.__erase()
|
||||
@ -207,6 +227,12 @@ class GridFile(object):
|
||||
self.flush()
|
||||
self.__closed = True
|
||||
|
||||
_files_lock.acquire()
|
||||
if repr(self.__id) in _open_files:
|
||||
del _open_files[repr(self.__id)]
|
||||
_files_lock.notifyAll()
|
||||
_files_lock.release()
|
||||
|
||||
def __assert_open(self, mode=None):
|
||||
if mode and self.mode != mode:
|
||||
raise ValueError("file must be open in mode %r" % mode)
|
||||
|
||||
@ -126,24 +126,24 @@ class TestGridFile(unittest.TestCase):
|
||||
self.assertRaises(TypeError, GridFile, {}, "hello")
|
||||
self.assertRaises(TypeError, GridFile, {}, None)
|
||||
self.assertRaises(TypeError, GridFile, {}, 5)
|
||||
self.assert_(GridFile({}, self.db))
|
||||
GridFile({}, self.db).close()
|
||||
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, None)
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, 5)
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, [])
|
||||
self.assertRaises(ValueError, GridFile, {}, self.db, "m")
|
||||
self.assertRaises(ValueError, GridFile, {}, self.db, u"m")
|
||||
self.assert_(GridFile({}, self.db, "r"))
|
||||
self.assert_(GridFile({}, self.db, u"r"))
|
||||
self.assert_(GridFile({}, self.db, "w"))
|
||||
self.assert_(GridFile({}, self.db, u"w"))
|
||||
GridFile({}, self.db, "r").close()
|
||||
GridFile({}, self.db, u"r").close()
|
||||
GridFile({}, self.db, "w").close()
|
||||
GridFile({}, self.db, u"w").close()
|
||||
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, "r", None)
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, "r", 5)
|
||||
self.assertRaises(TypeError, GridFile, {}, self.db, "r", [])
|
||||
|
||||
self.assertRaises(IOError, GridFile, {"filename": "mike"}, self.db)
|
||||
self.assert_(GridFile({"filename": "test"}, self.db))
|
||||
GridFile({"filename": "test"}, self.db).close()
|
||||
|
||||
def test_properties(self):
|
||||
self.db.fs.files.remove({})
|
||||
@ -199,6 +199,8 @@ class TestGridFile(unittest.TestCase):
|
||||
a.md5 = "what"
|
||||
self.assertRaises(AttributeError, set_md5)
|
||||
|
||||
a.close()
|
||||
|
||||
def test_rename(self):
|
||||
self.db.fs.files.remove({})
|
||||
self.db.fs.chunks.remove({})
|
||||
@ -211,9 +213,10 @@ class TestGridFile(unittest.TestCase):
|
||||
|
||||
a.rename("mike")
|
||||
self.assertEqual("mike", a.name)
|
||||
a.close()
|
||||
|
||||
self.assertRaises(IOError, GridFile, {"filename": "test"}, self.db)
|
||||
a = GridFile({"filename": "mike"}, self.db)
|
||||
GridFile({"filename": "mike"}, self.db).close()
|
||||
|
||||
def test_flush_close(self):
|
||||
self.db.fs.files.remove({})
|
||||
@ -224,7 +227,10 @@ class TestGridFile(unittest.TestCase):
|
||||
file.close()
|
||||
file.close()
|
||||
self.assertRaises(ValueError, file.write, "test")
|
||||
self.assertEqual(GridFile({}, self.db).read(), "")
|
||||
|
||||
file = GridFile({}, self.db)
|
||||
self.assertEqual(file.read(), "")
|
||||
file.close()
|
||||
|
||||
file = GridFile({"filename": "test"}, self.db, "w")
|
||||
file.write("mike")
|
||||
@ -237,7 +243,9 @@ class TestGridFile(unittest.TestCase):
|
||||
file.close()
|
||||
file.close()
|
||||
self.assertRaises(ValueError, file.write, "test")
|
||||
self.assertEqual(GridFile({}, self.db).read(), "miketesthuh")
|
||||
file = GridFile({}, self.db)
|
||||
self.assertEqual(file.read(), "miketesthuh")
|
||||
file.close()
|
||||
|
||||
def test_overwrite(self):
|
||||
self.db.fs.files.remove({})
|
||||
@ -251,7 +259,9 @@ class TestGridFile(unittest.TestCase):
|
||||
file.write("mike")
|
||||
file.close()
|
||||
|
||||
self.assertEqual(GridFile({}, self.db).read(), "mike")
|
||||
f = GridFile({}, self.db)
|
||||
self.assertEqual(f.read(), "mike")
|
||||
f.close()
|
||||
|
||||
def test_multi_chunk_file(self):
|
||||
self.db.fs.files.remove({})
|
||||
@ -266,7 +276,9 @@ class TestGridFile(unittest.TestCase):
|
||||
self.assertEqual(self.db.fs.files.find().count(), 1)
|
||||
self.assertEqual(self.db.fs.chunks.find().count(), 2)
|
||||
|
||||
self.assertEqual(GridFile({}, self.db).read(), random_string)
|
||||
f = GridFile({}, self.db)
|
||||
self.assertEqual(f.read(), random_string)
|
||||
f.close()
|
||||
|
||||
def test_small_chunks(self):
|
||||
self.db.fs.files.remove({})
|
||||
@ -288,11 +300,13 @@ class TestGridFile(unittest.TestCase):
|
||||
self.assertEqual(self.db.fs.files.find().count(), self.files)
|
||||
self.assertEqual(self.db.fs.chunks.find().count(), self.chunks)
|
||||
|
||||
self.assertEqual(GridFile({"filename": filename}, self.db).read(),
|
||||
data)
|
||||
f = GridFile({"filename": filename}, self.db)
|
||||
self.assertEqual(f.read(), data)
|
||||
f.close()
|
||||
|
||||
f = GridFile({"filename": filename}, self.db)
|
||||
self.assertEqual(f.read(10) + f.read(10), data)
|
||||
f.close()
|
||||
return True
|
||||
|
||||
qcheck.check_unittest(self, helper,
|
||||
@ -332,6 +346,7 @@ class TestGridFile(unittest.TestCase):
|
||||
self.assertEqual(file.read(2), "rl")
|
||||
self.assertEqual(file.read(2), "d")
|
||||
self.assertEqual(file.read(2), "")
|
||||
file.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -31,12 +31,25 @@ class JustWrite(threading.Thread):
|
||||
self.fs = fs
|
||||
|
||||
def run(self):
|
||||
for _ in range(100):
|
||||
for _ in range(10):
|
||||
file = self.fs.open("test", "w")
|
||||
file.write("hello")
|
||||
file.close()
|
||||
|
||||
|
||||
class JustRead(threading.Thread):
|
||||
|
||||
def __init__(self, fs):
|
||||
threading.Thread.__init__(self)
|
||||
self.fs = fs
|
||||
|
||||
def run(self):
|
||||
for _ in range(10):
|
||||
file = self.fs.open("test")
|
||||
assert file.read() == "hello"
|
||||
file.close()
|
||||
|
||||
|
||||
class TestGridfs(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
@ -94,8 +107,12 @@ class TestGridfs(unittest.TestCase):
|
||||
self.assertEqual(["mike", "hello world"], self.fs.list())
|
||||
self.assertEqual(self.db.fs.files.find().count(), 2)
|
||||
self.assertEqual(self.db.fs.chunks.find().count(), 2)
|
||||
self.assertEqual(self.fs.open("mike").read(), "hi")
|
||||
self.assertEqual(self.fs.open("hello world").read(), "fly")
|
||||
f = self.fs.open("mike")
|
||||
self.assertEqual(f.read(), "hi")
|
||||
f.close()
|
||||
f = self.fs.open("hello world")
|
||||
self.assertEqual(f.read(), "fly")
|
||||
f.close()
|
||||
self.assertRaises(IOError, self.fs.open, "test")
|
||||
|
||||
self.fs.remove({})
|
||||
@ -148,12 +165,12 @@ class TestGridfs(unittest.TestCase):
|
||||
self.fs.remove("test", "pymongo_test")
|
||||
self.assertEqual(["mike", "hello world"], self.fs.list("pymongo_test"))
|
||||
|
||||
self.assertEqual(self.fs.open("mike",
|
||||
collection="pymongo_test").read(),
|
||||
"hi")
|
||||
self.assertEqual(self.fs.open("hello world",
|
||||
collection="pymongo_test").read(),
|
||||
"fly")
|
||||
f = self.fs.open("mike", collection="pymongo_test")
|
||||
self.assertEqual(f.read(), "hi")
|
||||
f.close()
|
||||
f = self.fs.open("hello world", collection="pymongo_test")
|
||||
self.assertEqual(f.read(), "fly")
|
||||
f.close()
|
||||
|
||||
self.fs.remove({}, "pymongo_test")
|
||||
|
||||
@ -161,10 +178,31 @@ class TestGridfs(unittest.TestCase):
|
||||
self.assertEqual(self.db.pymongo_test.files.find().count(), 0)
|
||||
self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0)
|
||||
|
||||
def test_threading(self):
|
||||
def test_threaded_reads(self):
|
||||
f = self.fs.open("test", "w")
|
||||
f.write("hello")
|
||||
f.close()
|
||||
|
||||
threads = []
|
||||
for i in range(10):
|
||||
t = JustWrite(self.fs)
|
||||
t.start()
|
||||
threads.append(JustRead(self.fs))
|
||||
threads[i].start()
|
||||
|
||||
for i in range(10):
|
||||
threads[i].join()
|
||||
|
||||
def test_threaded_writes(self):
|
||||
threads = []
|
||||
for i in range(10):
|
||||
threads.append(JustWrite(self.fs))
|
||||
threads[i].start()
|
||||
|
||||
for i in range(10):
|
||||
threads[i].join()
|
||||
|
||||
f = self.fs.open("test")
|
||||
self.assertEqual(f.read(), "hello")
|
||||
f.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user