PYTHON-1552 Prevent uploading partial or corrupt GridFS files after an error occurs

This commit is contained in:
Ben Warner 2022-06-16 15:26:27 -07:00 committed by GitHub
parent 922e63d6e0
commit 4ae93c4937
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 27 deletions

View File

@@ -879,12 +879,11 @@ and store it with other file metadata. For example::
import hashlib
my_db = MongoClient().test
fs = GridFSBucket(my_db)
grid_in = fs.open_upload_stream("test_file")
file_data = b'...'
sha356 = hashlib.sha256(file_data).hexdigest()
grid_in.write(file_data)
grid_in.sha356 = sha356 # Set the custom 'sha356' field
grid_in.close()
with fs.open_upload_stream("test_file") as grid_in:
file_data = b'...'
sha356 = hashlib.sha256(file_data).hexdigest()
grid_in.write(file_data)
grid_in.sha356 = sha356 # Set the custom 'sha356' field
Note that for large files, the checksum may need to be computed in chunks
to avoid the excessive memory needed to load the entire file at once.

View File

@@ -109,11 +109,8 @@ class GridFS(object):
Equivalent to doing::
try:
f = new_file(**kwargs)
with fs.new_file(**kwargs) as f:
f.write(data)
finally:
f.close()
`data` can be either an instance of :class:`bytes` or a file-like
object providing a :meth:`read` method. If an `encoding` keyword
@@ -134,13 +131,10 @@ class GridFS(object):
.. versionchanged:: 3.0
w=0 writes to GridFS are now prohibited.
"""
grid_file = GridIn(self.__collection, **kwargs)
try:
grid_file.write(data)
finally:
grid_file.close()
return grid_file._id
with GridIn(self.__collection, **kwargs) as grid_file:
grid_file.write(data)
return grid_file._id
def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut:
"""Get a file from GridFS by ``"_id"``.
@@ -528,11 +522,11 @@ class GridFSBucket(object):
my_db = MongoClient().test
fs = GridFSBucket(my_db)
grid_in = fs.open_upload_stream(
with fs.open_upload_stream(
"test_file", chunk_size_bytes=4,
metadata={"contentType": "text/plain"})
grid_in.write("data I want to store!")
grid_in.close() # uploaded on close
metadata={"contentType": "text/plain"}) as grid_in:
grid_in.write("data I want to store!")
# uploaded on close
Returns an instance of :class:`~gridfs.grid_file.GridIn`.
@@ -584,13 +578,13 @@ class GridFSBucket(object):
my_db = MongoClient().test
fs = GridFSBucket(my_db)
grid_in = fs.open_upload_stream_with_id(
with fs.open_upload_stream_with_id(
ObjectId(),
"test_file",
chunk_size_bytes=4,
metadata={"contentType": "text/plain"})
grid_in.write("data I want to store!")
grid_in.close() # uploaded on close
metadata={"contentType": "text/plain"}) as grid_in:
grid_in.write("data I want to store!")
# uploaded on close
Returns an instance of :class:`~gridfs.grid_file.GridIn`.

View File

@@ -396,9 +396,14 @@ class GridIn(object):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any:
    """Support for the context manager protocol.

    Close the file if no exceptions occur and allow exceptions to propagate.

    When the ``with`` body raised, :meth:`close` is deliberately skipped so
    that a partial or corrupt upload is not finished off as if it had
    succeeded; the object is only flagged as closed.

    :Parameters:
      - `exc_type`: exception class raised inside the ``with`` block, or
        ``None`` when the block completed normally.
      - `exc_val`: the exception instance (unused).
      - `exc_tb`: the traceback (unused).

    :Returns:
      Always ``False``, so an exception raised inside the ``with`` block is
      never suppressed.
    """
    if exc_type is None:
        # No exceptions happened.
        self.close()
    else:
        # Something happened, at minimum mark as closed.
        # NOTE(review): object.__setattr__ is used instead of a plain
        # assignment, presumably because the class customises
        # __setattr__ -- confirm against the class definition.
        object.__setattr__(self, "_closed", True)

    # propagate exceptions
    return False

View File

@@ -675,6 +675,22 @@ Bye"""
with GridOut(self.db.fs, infile._id) as outfile:
self.assertEqual(contents, outfile.read())
def test_exception_file_non_existence(self):
    """An upload interrupted by an exception must leave no files document.

    Writes through a ``GridIn`` context manager, raises inside the ``with``
    body, then checks that the chunk count matches the writer's chunk
    counter, that no matching ``fs.files`` document exists, and that the
    writer reports itself closed.
    """
    payload = b"Imagine this is some important data..."

    with self.assertRaises(ConnectionError):
        with GridIn(self.db.fs, filename="important") as upload:
            upload.write(payload)
            raise ConnectionError("Test exception")

    # Expectation: File chunks are written, entry in files doesn't appear.
    stored_chunks = self.db.fs.chunks.count_documents({"files_id": upload._id})
    self.assertEqual(stored_chunks, upload._chunk_number)

    files_entry = self.db.fs.files.find_one({"_id": upload._id})
    self.assertIsNone(files_entry)

    self.assertTrue(upload.closed)
def test_prechunked_string(self):
def write_me(s, chunk_size):
buf = BytesIO(s)

View File

@@ -540,7 +540,7 @@ class TestGridfsReplicaSet(IntegrationTest):
# Connects, doesn't create index.
self.assertRaises(NoFile, fs.get_last_version)
self.assertRaises(NotPrimaryError, fs.put, "data")
self.assertRaises(NotPrimaryError, fs.put, "data", encoding="utf-8")
if __name__ == "__main__":