diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 764604601..dfd8f02f4 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -40,49 +40,61 @@ class GridFS(object): self.__database = database - def open(self, filename, mode="r"): + def open(self, filename, mode="r", collection="gridfs"): """Open a GridFile for reading or writing. Shorthand method for creating / opening a GridFile from a filename. mode must be a mode supported by `gridfs.grid_file.GridFile` :Parameters: - - `filename`: the name of the GridFile to open - - `mode` (optional): the mode to open the file in + - `filename`: name of the GridFile to open + - `mode` (optional): mode to open the file in + - `collection` (optional): root collection to use for this file """ - return GridFile({"filename": filename}, self.__database, mode) + return GridFile({"filename": filename}, self.__database, mode, collection) - def remove(self, filename_or_spec): + def remove(self, filename_or_spec, collection="gridfs"): """Remove one or more GridFile(s). Can remove by filename, or by an entire file spec (see `gridfs.grid_file.GridFile` for documentation on valid fields. Delete all GridFiles that match filename_or_spec. Raises TypeError if - filename_or_spec is not an instance of (str, unicode, dict, SON). + filename_or_spec is not an instance of (str, unicode, dict, SON) or + collection is not an instance of (str, unicode). :Parameters: - `filename_or_spec`: identifier of file(s) to remove + - `collection` (optional): root collection where this file is located """ spec = filename_or_spec if isinstance(filename_or_spec, types.StringTypes): spec = {"filename": filename_or_spec} + if not isinstance(collection, types.StringTypes): + raise TypeError("collection must be an instance of (str, unicode)") # convert to _id's so we can uniquely create GridFile instances ids = [] - for file in self.__database._files.find(spec): + for file in self.__database[collection].files.find(spec): ids.append(file["_id"]) # open for writing to remove the chunks for these files for id in ids: - f = GridFile({"_id": id}, self.__database, "w") + f = GridFile({"_id": id}, self.__database, "w", collection) f.close() - self.__database._files.remove(spec) + self.__database[collection].files.remove(spec) - def list(self): + def list(self, collection="gridfs"): """List the names of all GridFiles stored in this instance of GridFS. + + Raises TypeError if collection is not an instance of (str, unicode). + + :Parameters: + - `collection` (optional): root collection to list files from """ + if not isinstance(collection, types.StringTypes): + raise TypeError("collection must be an instance of (str, unicode)") names = [] - for file in self.__database._files.find(): + for file in self.__database[collection].files.find(): names.append(file["filename"]) return names diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index a434010eb..5d1b6c9de 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -30,7 +30,7 @@ class GridFile(object): """ # TODO should be able to create a GridFile given a Collection object instead # of a database and collection name? - def __init__(self, file_spec, database, mode="r", collection="_files"): + def __init__(self, file_spec, database, mode="r", collection="gridfs"): """Open a "file" in GridFS. Application developers should generally not need to instantiate this @@ -57,6 +57,7 @@ class GridFile(object): - "uploadDate": date when the object was first stored * only used for querying, automatically set for inserts - "aliases": array of alias strings + - "metadata": a SON document containing arbitrary data :Parameters: - `file_spec`: query specifier as described above @@ -79,7 +80,7 @@ class GridFile(object): self.__collection = database[collection] - file = self.__collection.find_one(file_spec) + file = self.__collection.files.find_one(file_spec) if file: self.__id = file["_id"] else: @@ -88,7 +89,7 @@ class GridFile(object): file_spec["length"] = 0 file_spec["uploadDate"] = datetime.datetime.now() file_spec.setdefault("chunkSize", 256000) - self.__id = self.__collection.insert(file_spec)["_id"] + self.__id = self.__collection.files.insert(file_spec)["_id"] self.__mode = mode if mode == "w": @@ -103,13 +104,13 @@ class GridFile(object): def __erase(self): """Erase all of the data stored in this GridFile. """ - file = self.__collection.find_one(self.__id) + file = self.__collection.files.find_one(self.__id) next = file.get("next", None) chunk_number = 0 while next: - chunk = self.__collection.database().dereference(next) + chunk = self.__collection.files.database().dereference(next) if not chunk: raise CorruptGridFile("could not dereference: %r" % next) if chunk["cn"] != chunk_number: @@ -123,7 +124,7 @@ class GridFile(object): file["next"] = None file["length"] = 0 - self.__collection.save(file) + self.__collection.files.save(file) @property def closed(self): @@ -135,11 +136,11 @@ class GridFile(object): def __create_property(field_name, read_only=False): def get(self): - return self.__collection.find_one(self.__id).get(field_name, None) + return self.__collection.files.find_one(self.__id).get(field_name, None) def set(self, value): - file = self.__collection.find_one(self.__id) + file = self.__collection.files.find_one(self.__id) file[field_name] = value - self.__collection.save(file) + self.__collection.files.save(file) if not read_only: return property(get, set) return property(get) @@ -161,15 +162,13 @@ class GridFile(object): :Parameters: - `filename`: the new name for this GridFile """ - file = self.__collection.find_one(self.__id) + file = self.__collection.files.find_one(self.__id) file["filename"] = filename - self.__collection.save(file) + self.__collection.files.save(file) - __chunks_collection = "_chunks" def __write_current_chunk(self): - # TODO _chunks collection should be configurable? self.__current_chunk["data"] = Binary(self.__current_chunk["data"]) - self.__collection.database()[self.__chunks_collection].save(self.__current_chunk) + self.__collection.chunks.save(self.__current_chunk) def flush(self): """Flush the GridFile to the database. @@ -178,13 +177,13 @@ class GridFile(object): if self.mode != "w" or not self.__current_chunk: return - file = self.__collection.find_one(self.__id) + file = self.__collection.files.find_one(self.__id) length = file["chunkSize"] * self.__current_chunk["cn"] + len(self.__current_chunk["data"]) file["length"] = length self.__write_current_chunk() - self.__collection.save(file) + self.__collection.files.save(file) def close(self): """Close the GridFile. @@ -274,10 +273,10 @@ class GridFile(object): if not self.__current_chunk: self.__current_chunk = initialize_chunk(0) - ref = DBRef(self.__chunks_collection, self.__current_chunk["_id"]) - file = self.__collection.find_one(self.__id) + ref = DBRef(self.__collection.chunks.name(), self.__current_chunk["_id"]) + file = self.__collection.files.find_one(self.__id) file["next"] = ref - self.__collection.save(file) + self.__collection.files.save(file) data = self.__current_chunk["data"] data += str @@ -288,7 +287,7 @@ class GridFile(object): new_chunk = initialize_chunk(self.__current_chunk["cn"] + 1) - self.__current_chunk["next"] = DBRef(self.__chunks_collection, new_chunk["_id"]) + self.__current_chunk["next"] = DBRef(self.__collection.chunks.name(), new_chunk["_id"]) self.__write_current_chunk() self.__current_chunk = new_chunk diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 1c8714e91..216180ad5 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -29,17 +29,17 @@ class TestGridFile(unittest.TestCase): self.db = get_connection().pymongo_test def test_basic(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) - self.assertEqual(self.db._files.find().count(), 0) - self.assertEqual(self.db._chunks.find().count(), 0) + self.assertEqual(self.db.gridfs.files.find().count(), 0) + self.assertEqual(self.db.gridfs.chunks.find().count(), 0) file = GridFile({"filename": "test"}, self.db, "w") file.write("hello world") file.close() - self.assertEqual(self.db._files.find().count(), 1) - self.assertEqual(self.db._chunks.find().count(), 1) + self.assertEqual(self.db.gridfs.files.find().count(), 1) + self.assertEqual(self.db.gridfs.chunks.find().count(), 1) file = GridFile({"filename": "test"}, self.db) self.assertEqual(file.read(), "hello world") @@ -53,17 +53,50 @@ class TestGridFile(unittest.TestCase): file = GridFile({"filename": "test"}, self.db, "w") file.close() - self.assertEqual(self.db._files.find().count(), 1) - self.assertEqual(self.db._chunks.find().count(), 0) + self.assertEqual(self.db.gridfs.files.find().count(), 1) + self.assertEqual(self.db.gridfs.chunks.find().count(), 0) file = GridFile({"filename": "test"}, self.db) self.assertEqual(file.next, None) self.assertEqual(file.read(), "") file.close() + def test_alternate_collection(self): + self.db.pymongo_test.files.remove({}) + self.db.pymongo_test.chunks.remove({}) + + self.assertEqual(self.db.pymongo_test.files.find().count(), 0) + self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0) + file = GridFile({"filename": "test"}, self.db, "w", collection="pymongo_test") + file.write("hello world") + file.close() + + self.assertEqual(self.db.pymongo_test.files.find().count(), 1) + self.assertEqual(self.db.pymongo_test.chunks.find().count(), 1) + + file = GridFile({"filename": "test"}, self.db, collection="pymongo_test") + self.assertEqual(file.read(), "hello world") + file.close() + + # make sure it's still there... + file = GridFile({"filename": "test"}, self.db, collection="pymongo_test") + self.assertEqual(file.read(), "hello world") + file.close() + + file = GridFile({"filename": "test"}, self.db, "w", collection="pymongo_test") + file.close() + + self.assertEqual(self.db.pymongo_test.files.find().count(), 1) + self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0) + + file = GridFile({"filename": "test"}, self.db, collection="pymongo_test") + self.assertEqual(file.next, None) + self.assertEqual(file.read(), "") + file.close() + def test_create_grid_file(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) # just write a blank file so that reads on {} don't fail file = GridFile({"filename": "test"}, self.db, "w") @@ -96,8 +129,8 @@ class TestGridFile(unittest.TestCase): self.assertTrue(GridFile({"filename": "test"}, self.db)) def test_properties(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") self.assertEqual(file.mode, "w") @@ -146,8 +179,8 @@ class TestGridFile(unittest.TestCase): self.assertRaises(AttributeError, set_name) def test_rename(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") file.close() @@ -162,8 +195,8 @@ class TestGridFile(unittest.TestCase): a = GridFile({"filename": "mike"}, self.db) def test_flush_close(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") file.flush() @@ -186,8 +219,8 @@ class TestGridFile(unittest.TestCase): self.assertEqual(GridFile({}, self.db).read(), "miketesthuh") def test_overwrite(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") file.write("test") @@ -200,8 +233,8 @@ class TestGridFile(unittest.TestCase): self.assertEqual(GridFile({}, self.db).read(), "mike") def test_multi_chunk_file(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) random_string = qcheck.gen_string(qcheck.lift(300000))() @@ -209,14 +242,14 @@ class TestGridFile(unittest.TestCase): file.write(random_string) file.close() - self.assertEqual(self.db._files.find().count(), 1) - self.assertEqual(self.db._chunks.find().count(), 2) + self.assertEqual(self.db.gridfs.files.find().count(), 1) + self.assertEqual(self.db.gridfs.chunks.find().count(), 2) self.assertEqual(GridFile({}, self.db).read(), random_string) def test_small_chunks(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) self.files = 0 self.chunks = 0 @@ -231,8 +264,8 @@ class TestGridFile(unittest.TestCase): self.files += 1 self.chunks += len(data) - self.assertEqual(self.db._files.find().count(), self.files) - self.assertEqual(self.db._chunks.find().count(), self.chunks) + self.assertEqual(self.db.gridfs.files.find().count(), self.files) + self.assertEqual(self.db.gridfs.chunks.find().count(), self.chunks) self.assertEqual(GridFile({"filename": filename}, self.db).read(), data) @@ -243,8 +276,8 @@ class TestGridFile(unittest.TestCase): qcheck.check_unittest(self, helper, qcheck.gen_string(qcheck.gen_range(0, 20))) def test_modes(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") self.assertRaises(ValueError, file.read) @@ -261,8 +294,8 @@ class TestGridFile(unittest.TestCase): self.assertRaises(ValueError, file.write, "hello") def test_multiple_reads(self): - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) file = GridFile({"filename": "test"}, self.db, "w") file.write("hello world") diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 452512f9c..5468fdaab 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -25,8 +25,10 @@ from test_connection import get_connection class TestGridfs(unittest.TestCase): def setUp(self): self.db = get_connection().pymongo_test - self.db._files.remove({}) - self.db._chunks.remove({}) + self.db.gridfs.files.remove({}) + self.db.gridfs.chunks.remove({}) + self.db.pymongo_test.files.remove({}) + self.db.pymongo_test.chunks.remove({}) self.fs = gridfs.GridFS(self.db) def test_open(self): @@ -68,14 +70,14 @@ class TestGridfs(unittest.TestCase): f.write("fly") f.close() self.assertEqual(["mike", "test", "hello world"], self.fs.list()) - self.assertEqual(self.db._files.find().count(), 3) - self.assertEqual(self.db._chunks.find().count(), 3) + self.assertEqual(self.db.gridfs.files.find().count(), 3) + self.assertEqual(self.db.gridfs.chunks.find().count(), 3) self.fs.remove("test") self.assertEqual(["mike", "hello world"], self.fs.list()) - self.assertEqual(self.db._files.find().count(), 2) - self.assertEqual(self.db._chunks.find().count(), 2) + self.assertEqual(self.db.gridfs.files.find().count(), 2) + self.assertEqual(self.db.gridfs.chunks.find().count(), 2) self.assertEqual(self.fs.open("mike").read(), "hi") self.assertEqual(self.fs.open("hello world").read(), "fly") self.assertRaises(IOError, self.fs.open, "test") @@ -83,11 +85,59 @@ class TestGridfs(unittest.TestCase): self.fs.remove({}) self.assertEqual([], self.fs.list()) - self.assertEqual(self.db._files.find().count(), 0) - self.assertEqual(self.db._chunks.find().count(), 0) + self.assertEqual(self.db.gridfs.files.find().count(), 0) + self.assertEqual(self.db.gridfs.chunks.find().count(), 0) self.assertRaises(IOError, self.fs.open, "test") self.assertRaises(IOError, self.fs.open, "mike") self.assertRaises(IOError, self.fs.open, "hello world") + def test_open_alt_coll(self): + f = self.fs.open("my file", "w", "pymongo_test") + f.write("hello gridfs world!") + f.close() + + self.assertRaises(IOError, self.fs.open, "my file", "r") + g = self.fs.open("my file", "r", "pymongo_test") + self.assertEqual("hello gridfs world!", g.read()) + g.close() + + def test_list_alt_coll(self): + f = self.fs.open("mike", "w", "pymongo_test") + f.close() + + f = self.fs.open("test", "w", "pymongo_test") + f.close() + + f = self.fs.open("hello world", "w", "pymongo_test") + f.close() + + self.assertEqual([], self.fs.list()) + self.assertEqual(["mike", "test", "hello world"], self.fs.list("pymongo_test")) + + def test_remove_alt_coll(self): + f = self.fs.open("mike", "w", "pymongo_test") + f.write("hi") + f.close() + f = self.fs.open("test", "w", "pymongo_test") + f.write("bye") + f.close() + f = self.fs.open("hello world", "w", "pymongo_test") + f.write("fly") + f.close() + + self.fs.remove("test") + self.assertEqual(["mike", "test", "hello world"], self.fs.list("pymongo_test")) + self.fs.remove("test", "pymongo_test") + self.assertEqual(["mike", "hello world"], self.fs.list("pymongo_test")) + + self.assertEqual(self.fs.open("mike", collection="pymongo_test").read(), "hi") + self.assertEqual(self.fs.open("hello world", collection="pymongo_test").read(), "fly") + + self.fs.remove({}, "pymongo_test") + + self.assertEqual([], self.fs.list("pymongo_test")) + self.assertEqual(self.db.pymongo_test.files.find().count(), 0) + self.assertEqual(self.db.pymongo_test.chunks.find().count(), 0) + if __name__ == "__main__": unittest.main()