diff --git a/doc/migrate-to-pymongo4.rst b/doc/migrate-to-pymongo4.rst index eca479c7c..5843a2261 100644 --- a/doc/migrate-to-pymongo4.rst +++ b/doc/migrate-to-pymongo4.rst @@ -879,12 +879,11 @@ and store it with other file metadata. For example:: import hashlib my_db = MongoClient().test fs = GridFSBucket(my_db) - grid_in = fs.open_upload_stream("test_file") - file_data = b'...' - sha356 = hashlib.sha256(file_data).hexdigest() - grid_in.write(file_data) - grid_in.sha356 = sha356 # Set the custom 'sha356' field - grid_in.close() + with fs.open_upload_stream("test_file") as grid_in: + file_data = b'...' + sha356 = hashlib.sha256(file_data).hexdigest() + grid_in.write(file_data) + grid_in.sha356 = sha356 # Set the custom 'sha356' field Note that for large files, the checksum may need to be computed in chunks to avoid the excessive memory needed to load the entire file at once. diff --git a/gridfs/__init__.py b/gridfs/__init__.py index 5675e8f93..29d582cd2 100644 --- a/gridfs/__init__.py +++ b/gridfs/__init__.py @@ -109,11 +109,8 @@ class GridFS(object): Equivalent to doing:: - try: - f = new_file(**kwargs) + with fs.new_file(**kwargs) as f: f.write(data) - finally: - f.close() `data` can be either an instance of :class:`bytes` or a file-like object providing a :meth:`read` method. If an `encoding` keyword @@ -134,13 +131,10 @@ class GridFS(object): .. versionchanged:: 3.0 w=0 writes to GridFS are now prohibited. """ - grid_file = GridIn(self.__collection, **kwargs) - try: - grid_file.write(data) - finally: - grid_file.close() - return grid_file._id + with GridIn(self.__collection, **kwargs) as grid_file: + grid_file.write(data) + return grid_file._id def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut: """Get a file from GridFS by ``"_id"``. @@ -528,11 +522,11 @@ class GridFSBucket(object): my_db = MongoClient().test fs = GridFSBucket(my_db) - grid_in = fs.open_upload_stream( + with fs.open_upload_stream( "test_file", chunk_size_bytes=4, - metadata={"contentType": "text/plain"}) - grid_in.write("data I want to store!") - grid_in.close() # uploaded on close + metadata={"contentType": "text/plain"}) as grid_in: + grid_in.write("data I want to store!") + # uploaded on close Returns an instance of :class:`~gridfs.grid_file.GridIn`. @@ -584,13 +578,13 @@ class GridFSBucket(object): my_db = MongoClient().test fs = GridFSBucket(my_db) - grid_in = fs.open_upload_stream_with_id( + with fs.open_upload_stream_with_id( ObjectId(), "test_file", chunk_size_bytes=4, - metadata={"contentType": "text/plain"}) - grid_in.write("data I want to store!") - grid_in.close() # uploaded on close + metadata={"contentType": "text/plain"}) as grid_in: + grid_in.write("data I want to store!") + # uploaded on close Returns an instance of :class:`~gridfs.grid_file.GridIn`. diff --git a/gridfs/grid_file.py b/gridfs/grid_file.py index 5d63d5c65..cec7d57a2 100644 --- a/gridfs/grid_file.py +++ b/gridfs/grid_file.py @@ -396,9 +396,14 @@ class GridIn(object): def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: """Support for the context manager protocol. - Close the file and allow exceptions to propagate. + Close the file if no exceptions occur and allow exceptions to propagate. """ - self.close() + if exc_type is None: + # No exceptions happened. + self.close() + else: + # Something happened, at minimum mark as closed. + object.__setattr__(self, "_closed", True) # propagate exceptions return False diff --git a/test/test_grid_file.py b/test/test_grid_file.py index b9fdeacef..8b46133a6 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -675,6 +675,22 @@ Bye""" with GridOut(self.db.fs, infile._id) as outfile: self.assertEqual(contents, outfile.read()) + def test_exception_file_non_existence(self): + contents = b"Imagine this is some important data..." + + with self.assertRaises(ConnectionError): + with GridIn(self.db.fs, filename="important") as infile: + infile.write(contents) + raise ConnectionError("Test exception") + + # Expectation: File chunks are written, entry in files doesn't appear. + self.assertEqual( + self.db.fs.chunks.count_documents({"files_id": infile._id}), infile._chunk_number + ) + + self.assertIsNone(self.db.fs.files.find_one({"_id": infile._id})) + self.assertTrue(infile.closed) + def test_prechunked_string(self): def write_me(s, chunk_size): buf = BytesIO(s) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index ec88dcd48..35a574a1d 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -540,7 +540,7 @@ class TestGridfsReplicaSet(IntegrationTest): # Connects, doesn't create index. self.assertRaises(NoFile, fs.get_last_version) - self.assertRaises(NotPrimaryError, fs.put, "data") + self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8") if __name__ == "__main__":