diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 015fd7c41..53c75bd83 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -17,3 +17,72 @@ The `gridfs` package is an implementation of GridFS on top of `pymongo`, exposing a file-like interface. """ + +import types + +from grid_file import GridFile +from pymongo.database import Database + +class GridFS(object): + """An instance of GridFS on top of a single `pymongo.database.Database`. + """ + def __init__(self, database): + """Create a new instance of GridFS. + + Raises TypeError if database is not an instance of + `pymongo.database.Database`. + + Arguments: + - `database`: database to use + """ + if not isinstance(database, Database): + raise TypeError("database must be an instance of Database") + + self.__database = database + + def open(self, filename, mode="r"): + """Open a GridFile for reading or writing. + + Shorthand method for creating / opening a GridFile from a filename. mode + must be a mode supported by `gridfs.grid_file.GridFile` + + Arguments: + - `filename`: the name of the GridFile to open + - `mode` (optional): the mode to open the file in + """ + return GridFile({"filename": filename}, self.__database, mode) + + def remove(self, filename_or_spec): + """Remove one or more GridFile(s). + + Can remove by filename, or by an entire file spec (see + `gridfs.grid_file.GridFile` for documentation on valid fields. Delete + all GridFiles that match filename_or_spec. Raises TypeError if + filename_or_spec is not an instance of (str, unicode, dict, SON). + + Arguments: + - `filename_or_spec`: identifier of file(s) to remove + """ + spec = filename_or_spec + if isinstance(filename_or_spec, types.StringTypes): + spec = {"filename": filename_or_spec} + + # convert to _id's so we can uniquely create GridFile instances + ids = [] + for file in self.__database._files.find(spec): + ids.append(file["_id"]) + + # open for writing to remove the chunks for these files + for id in ids: + f = GridFile({"_id": id}, self.__database, "w") + f.close() + + self.__database._files.remove(spec) + + def list(self): + """List the names of all GridFiles stored in this instance of GridFS. + """ + names = [] + for file in self.__database._files.find(): + names.append(file["filename"]) + return names diff --git a/test/test_gridfs.py b/test/test_gridfs.py new file mode 100644 index 000000000..3809bea14 --- /dev/null +++ b/test/test_gridfs.py @@ -0,0 +1,91 @@ +# Copyright 2009 10gen, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the gridfs package. +""" + +import unittest + +from test_connection import get_connection +import gridfs + +class TestGridfs(unittest.TestCase): + def setUp(self): + self.db = get_connection().test + self.db._files.remove({}) + self.db._chunks.remove({}) + self.fs = gridfs.GridFS(self.db) + + def test_open(self): + self.assertRaises(IOError, self.fs.open, "my file", "r") + f = self.fs.open("my file", "w") + f.write("hello gridfs world!") + f.close() + + g = self.fs.open("my file", "r") + self.assertEqual("hello gridfs world!", g.read()) + g.close() + + def test_list(self): + self.assertEqual(self.fs.list(), []) + + f = self.fs.open("mike", "w") + f.close() + + f = self.fs.open("test", "w") + f.close() + + f = self.fs.open("hello world", "w") + f.close() + + self.assertEqual(["mike", "test", "hello world"], self.fs.list()) + + def test_remove(self): + self.assertRaises(TypeError, self.fs.remove, 5) + self.assertRaises(TypeError, self.fs.remove, None) + self.assertRaises(TypeError, self.fs.remove, []) + + f = self.fs.open("mike", "w") + f.write("hi") + f.close() + f = self.fs.open("test", "w") + f.write("bye") + f.close() + f = self.fs.open("hello world", "w") + f.write("fly") + f.close() + self.assertEqual(["mike", "test", "hello world"], self.fs.list()) + self.assertEqual(self.db._files.find().count(), 3) + self.assertEqual(self.db._chunks.find().count(), 3) + + self.fs.remove("test") + + self.assertEqual(["mike", "hello world"], self.fs.list()) + self.assertEqual(self.db._files.find().count(), 2) + self.assertEqual(self.db._chunks.find().count(), 2) + self.assertEqual(self.fs.open("mike").read(), "hi") + self.assertEqual(self.fs.open("hello world").read(), "fly") + self.assertRaises(IOError, self.fs.open, "test") + + self.fs.remove({}) + + self.assertEqual([], self.fs.list()) + self.assertEqual(self.db._files.find().count(), 0) + self.assertEqual(self.db._chunks.find().count(), 0) + self.assertRaises(IOError, self.fs.open, "test") + self.assertRaises(IOError, self.fs.open, "mike") + self.assertRaises(IOError, self.fs.open, "hello world") + +if __name__ == "__main__": + unittest.main()