diff --git a/doc/changelog.rst b/doc/changelog.rst index b9ce9e6e9..101fda7a6 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -15,6 +15,13 @@ Changes in Version 3.8.0.dev0 - :class:`~bson.objectid.ObjectId` now implements the `ObjectID specification version 0.2 `_. +- For better performance and to better follow the GridFS spec, + :class:`~gridfs.grid_file.GridOut` now uses a single cursor to read all the + chunks in the file. Previously, each chunk in the file was queried + individually using :meth:`~pymongo.collection.Collection.find_one`. +- :meth:`gridfs.grid_file.GridOut.read` now only checks for extra chunks after + reading the entire file. Previously, this method would check for extra + chunks on every call. - :meth:`~pymongo.database.Database.current_op` now always uses the ``Database``'s :attr:`~pymongo.database.Database.codec_options` diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 24cab5581..11084171d 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -715,9 +715,9 @@ class GridFSBucket(object): .. versionchanged:: 3.6 Added ``session`` parameter. """ - gout = self.open_download_stream(file_id, session=session) - for chunk in gout: - destination.write(chunk) + with self.open_download_stream(file_id, session=session) as gout: + for chunk in gout: + destination.write(chunk) def delete(self, file_id, session=None): """Given an file_id, delete this stored file's files collection document @@ -890,10 +890,10 @@ class GridFSBucket(object): .. versionchanged:: 3.6 Added ``session`` parameter. """ - gout = self.open_download_stream_by_name( - filename, revision, session=session) - for chunk in gout: - destination.write(chunk) + with self.open_download_stream_by_name( + filename, revision, session=session) as gout: + for chunk in gout: + destination.write(chunk) def rename(self, file_id, new_filename, session=None): """Renames the stored file with the specified file_id. 
diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 8368c53e0..43a6f1b13 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -27,6 +27,7 @@ from pymongo import ASCENDING from pymongo.collection import Collection from pymongo.cursor import Cursor from pymongo.errors import (ConfigurationError, + CursorNotFound, DuplicateKeyError, OperationFailure) from pymongo.read_preferences import ReadPreference @@ -419,6 +420,11 @@ class GridOut(object): :class:`~pymongo.client_session.ClientSession` to use for all commands + .. versionchanged:: 3.8 + For better performance and to better follow the GridFS spec, + :class:`GridOut` now uses a single cursor to read all the chunks in + the file. + .. versionchanged:: 3.6 Added ``session`` parameter. @@ -434,6 +440,7 @@ class GridOut(object): self.__files = root_collection.files self.__file_id = file_id self.__buffer = EMPTY + self.__chunk_iter = None self.__position = 0 self._file = file_document self._session = session @@ -477,12 +484,11 @@ class GridOut(object): chunk_data = self.__buffer elif self.__position < int(self.length): chunk_number = int((received + self.__position) / chunk_size) - chunk = self.__chunks.find_one({"files_id": self._id, - "n": chunk_number}, - session=self._session) - if not chunk: - raise CorruptGridFile("no chunk #%d" % chunk_number) + if self.__chunk_iter is None: + self.__chunk_iter = _GridOutChunkIterator( + self, self.__chunks, self._session, chunk_number) + chunk = self.__chunk_iter.next() chunk_data = chunk["data"][self.__position % chunk_size:] if not chunk_data: @@ -501,16 +507,21 @@ class GridOut(object): :Parameters: - `size` (optional): the number of bytes to read + + .. versionchanged:: 3.8 + This method now only checks for extra chunks after reading the + entire file. Previously, this method would check for extra chunks + on every call. 
""" self._ensure_file() - if size == 0: - return EMPTY - remainder = int(self.length) - self.__position if size < 0 or size > remainder: size = remainder + if size == 0: + return EMPTY + received = 0 data = StringIO() while received < size: @@ -518,16 +529,12 @@ class GridOut(object): received += len(chunk_data) data.write(chunk_data) - # Detect extra chunks. - max_chunk_n = math.ceil(self.length / float(self.chunk_size)) - chunk = self.__chunks.find_one({"files_id": self._id, - "n": {"$gte": max_chunk_n}}, - session=self._session) - # According to spec, ignore extra chunks if they are empty. - if chunk is not None and len(chunk['data']): - raise CorruptGridFile( - "Extra chunk found: expected %i chunks but found " - "chunk with n=%i" % (max_chunk_n, chunk['n'])) + # Detect extra chunks after reading the entire file. + if size == remainder and self.__chunk_iter: + try: + self.__chunk_iter.next() + except StopIteration: + pass self.__position -= received - size @@ -543,13 +550,13 @@ class GridOut(object): :Parameters: - `size` (optional): the maximum number of bytes to read """ - if size == 0: - return b'' - remainder = int(self.length) - self.__position if size < 0 or size > remainder: size = remainder + if size == 0: + return EMPTY + received = 0 data = StringIO() while received < size: @@ -600,8 +607,15 @@ class GridOut(object): if new_pos < 0: raise IOError(22, "Invalid value for `pos` - must be positive") + # Optimization, continue using the same buffer and chunk iterator. + if new_pos == self.__position: + return + self.__position = new_pos self.__buffer = EMPTY + if self.__chunk_iter: + self.__chunk_iter.close() + self.__chunk_iter = None def __iter__(self): """Return an iterator over all of this file's data. @@ -610,12 +624,20 @@ class GridOut(object): :class:`str` (:class:`bytes` in python 3). This can be useful when serving files using a webserver that handles such an iterator efficiently. + + .. 
versionchanged:: 3.8 + The iterator now raises :class:`CorruptGridFile` when encountering + any truncated, missing, or extra chunk in a file. The previous + behavior was to only raise :class:`CorruptGridFile` on a missing + chunk. """ return GridOutIterator(self, self.__chunks, self._session) def close(self): """Make GridOut more generically file-like.""" - pass + if self.__chunk_iter: + self.__chunk_iter.close() + self.__chunk_iter = None def __enter__(self): """Makes it possible to use :class:`GridOut` files @@ -627,30 +649,108 @@ class GridOut(object): """Makes it possible to use :class:`GridOut` files with the context manager protocol. """ + self.close() return False +class _GridOutChunkIterator(object): + """Iterates over a file's chunks using a single cursor. + + Raises CorruptGridFile when encountering any truncated, missing, or extra + chunk in a file. + """ + def __init__(self, grid_out, chunks, session, next_chunk): + self._id = grid_out._id + self._chunk_size = int(grid_out.chunk_size) + self._length = int(grid_out.length) + self._chunks = chunks + self._session = session + self._next_chunk = next_chunk + self._num_chunks = math.ceil(float(self._length) / self._chunk_size) + self._cursor = None + + def expected_chunk_length(self, chunk_n): + if chunk_n < self._num_chunks - 1: + return self._chunk_size + return self._length - (self._chunk_size * (self._num_chunks - 1)) + + def __iter__(self): + return self + + def _create_cursor(self): + filter = {"files_id": self._id} + if self._next_chunk > 0: + filter["n"] = {"$gte": self._next_chunk} + self._cursor = self._chunks.find(filter, sort=[("n", 1)], + session=self._session) + + def _next_with_retry(self): + """Return the next chunk and retry once on CursorNotFound. + + We retry on CursorNotFound to maintain backwards compatibility in + cases where two calls to read occur more than 10 minutes apart (the + server's default cursor timeout). 
+ """ + if self._cursor is None: + self._create_cursor() + + try: + return self._cursor.next() + except CursorNotFound: + self._cursor.close() + self._create_cursor() + return self._cursor.next() + + def next(self): + try: + chunk = self._next_with_retry() + except StopIteration: + if self._next_chunk >= self._num_chunks: + raise + raise CorruptGridFile("no chunk #%d" % self._next_chunk) + + if chunk["n"] != self._next_chunk: + self.close() + raise CorruptGridFile( + "Missing chunk: expected chunk #%d but found " + "chunk with n=%d" % (self._next_chunk, chunk["n"])) + + if chunk["n"] >= self._num_chunks: + # According to spec, ignore extra chunks if they are empty. + if len(chunk["data"]): + self.close() + raise CorruptGridFile( + "Extra chunk found: expected %d chunks but found " + "chunk with n=%d" % (self._num_chunks, chunk["n"])) + + expected_length = self.expected_chunk_length(chunk["n"]) + if len(chunk["data"]) != expected_length: + self.close() + raise CorruptGridFile( + "truncated chunk #%d: expected chunk length to be %d but " + "found chunk with length %d" % ( + chunk["n"], expected_length, len(chunk["data"]))) + + self._next_chunk += 1 + return chunk + + __next__ = next + + def close(self): + if self._cursor: + self._cursor.close() + self._cursor = None + + class GridOutIterator(object): def __init__(self, grid_out, chunks, session): - self.__id = grid_out._id - self.__chunks = chunks - self.__session = session - self.__current_chunk = 0 - self.__max_chunk = math.ceil(float(grid_out.length) / - grid_out.chunk_size) + self.__chunk_iter = _GridOutChunkIterator(grid_out, chunks, session, 0) def __iter__(self): return self def next(self): - if self.__current_chunk >= self.__max_chunk: - raise StopIteration - chunk = self.__chunks.find_one({"files_id": self.__id, - "n": self.__current_chunk}, - session=self.__session) - if not chunk: - raise CorruptGridFile("no chunk #%d" % self.__current_chunk) - self.__current_chunk += 1 + chunk = self.__chunk_iter.next() 
return bytes(chunk["data"]) __next__ = next diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 8bd674e6a..1c8a629c7 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,10 +33,11 @@ from gridfs.grid_file import (DEFAULT_CHUNK_SIZE, from gridfs.errors import NoFile from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError +from pymongo.message import _CursorAddress from test import (IntegrationTest, unittest, qcheck) -from test.utils import rs_or_single_client +from test.utils import rs_or_single_client, EventListener class TestGridFileNoConnect(unittest.TestCase): @@ -616,6 +617,33 @@ Bye""")) with self.assertRaises(ConfigurationError): GridIn(rs_or_single_client(w=0).pymongo_test.fs) + def test_survive_cursor_not_found(self): + # By default the find command returns 101 documents in the first batch. + # Use 102 chunks to cause a single getMore. + chunk_size = 1024 + data = b'd' * (102 * chunk_size) + listener = EventListener() + client = rs_or_single_client(event_listeners=[listener]) + db = client.pymongo_test + with GridIn(db.fs, chunk_size=chunk_size) as infile: + infile.write(data) + + with GridOut(db.fs, infile._id) as outfile: + self.assertEqual(len(outfile.readchunk()), chunk_size) + + # Kill the cursor to simulate the cursor timing out on the server + # when an application spends a long time between two calls to + # readchunk(). + client._close_cursor_now( + outfile._GridOut__chunk_iter._cursor.cursor_id, + _CursorAddress(client.address, db.fs.chunks.full_name)) + + # Read the rest of the file without error. + self.assertEqual(len(outfile.read()), len(data) - chunk_size) + + # Paranoid, ensure that a getMore was actually sent. 
+ self.assertIn("getMore", listener.started_command_names()) + if __name__ == "__main__": unittest.main() diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index 736a6f9e3..c1394f255 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -163,8 +163,8 @@ def create_test(scenario_def): if test['assert'].get("error", False): self.assertIsNotNone(error) - self.assertTrue(isinstance(error, - errors[test['assert']['error']])) + self.assertIsInstance(error, errors[test['assert']['error']], + test['description']) else: self.assertIsNone(error) diff --git a/test/test_session.py b/test/test_session.py index 6ffaf3f09..292411780 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -571,8 +571,11 @@ class TestSession(IntegrationTest): for f in files: f.read() - with self.assertRaisesRegex(InvalidOperation, "ended session"): - files[0].read() + for f in files: + # Attempt to read the file again. + f.seek(0) + with self.assertRaisesRegex(InvalidOperation, "ended session"): + f.read() def test_aggregate(self): client = self.client