From a3cf3cf568eb2624e5c537971155835919383441 Mon Sep 17 00:00:00 2001 From: aherlihy Date: Tue, 15 Dec 2015 15:28:20 -0500 Subject: [PATCH] PYTHON-1031 - GridFsBucket.download_to_stream now uses GridOutIterator --- gridfs/__init__.py | 11 +++++---- test/test_gridfs_bucket.py | 50 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 56 insertions(+), 5 deletions(-) diff --git a/gridfs/__init__.py b/gridfs/__init__.py index fc22921e7..42e94295c 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -525,7 +525,7 @@ class GridFSBucket(object): # Get _id of file to read file_id = fs.upload_from_stream("test_file", "data I want to store!") # Get file to write to - file = open('myfile','rw') + file = open('myfile','rwb') fs.download_to_stream(file_id, file) contents = file.read() @@ -536,7 +536,8 @@ class GridFSBucket(object): -`destination`: a file-like object implementing :meth:`write`. """ gout = self.open_download_stream(file_id) - destination.write(gout) + for chunk in gout: + destination.write(chunk) def delete(self, file_id): """Given an file_id, delete this stored file's files collection document @@ -663,7 +664,7 @@ class GridFSBucket(object): my_db = MongoClient().test fs = GridFSBucket(my_db) # Get file to write to - file = open('myfile','w') + file = open('myfile','wb') fs.download_to_stream_by_name("test_file", file) Raises :exc:`~gridfs.errors.NoFile` if no such version of @@ -687,8 +688,8 @@ class GridFSBucket(object): -1 = the most recent revision """ gout = self.open_download_stream_by_name(filename, revision) - - destination.write(gout) + for chunk in gout: + destination.write(chunk) def rename(self, file_id, new_filename): """Renames the stored file with the specified file_id. diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index a7995353f..43c78fcb3 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -369,6 +369,56 @@ class TestGridfs(IntegrationTest): self.assertEqual(0, self.db.fs.chunks.count( {"files_id": gin._id})) + def test_download_to_stream(self): + file1 = StringIO(b"hello world") + # Test with one chunk. + oid = self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, self.db.fs.chunks.count()) + file2 = StringIO() + self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. + self.db.drop_collection("fs.files") + self.db.drop_collection("fs.chunks") + file1.seek(0) + oid = self.fs.upload_from_stream("many_chunks", + file1, + chunk_size_bytes=1) + self.assertEqual(11, self.db.fs.chunks.count()) + file2 = StringIO() + self.fs.download_to_stream(oid, file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + def test_download_to_stream_by_name(self): + file1 = StringIO(b"hello world") + # Test with one chunk. + oid = self.fs.upload_from_stream("one_chunk", file1) + self.assertEqual(1, self.db.fs.chunks.count()) + file2 = StringIO() + self.fs.download_to_stream_by_name("one_chunk", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + # Test with many chunks. + self.db.drop_collection("fs.files") + self.db.drop_collection("fs.chunks") + file1.seek(0) + self.fs.upload_from_stream("many_chunks", file1, chunk_size_bytes=1) + self.assertEqual(11, self.db.fs.chunks.count()) + + file2 = StringIO() + self.fs.download_to_stream_by_name("many_chunks", file2) + file1.seek(0) + file2.seek(0) + self.assertEqual(file1.read(), file2.read()) + + class TestGridfsBucketReplicaSet(TestReplicaSetClientBase): def test_gridfs_replica_set(self):