diff --git a/bson/__init__.py b/bson/__init__.py index 97d38b37a..c8ac12e46 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -1166,7 +1166,7 @@ def decode_file_iter(file_obj, codec_options=DEFAULT_CODEC_OPTIONS): elif len(size_data) != 4: raise InvalidBSON("cut off in middle of objsize") obj_size = _UNPACK_INT_FROM(size_data, 0)[0] - 4 - elements = size_data + file_obj.read(obj_size) + elements = size_data + file_obj.read(max(0, obj_size)) yield _bson_to_dict(elements, codec_options) diff --git a/doc/contributors.rst b/doc/contributors.rst index 26717f76b..9773b3822 100644 --- a/doc/contributors.rst +++ b/doc/contributors.rst @@ -86,3 +86,4 @@ The following is a list of people who have contributed to - Shrey Batra(shreybatra) - Felipe Rodrigues(fbidu) - Terence Honles (terencehonles) +- Paul Fisher (thetorpedodog) diff --git a/test/test_bson.py b/test/test_bson.py index 0e1b8e1b7..dd604c738 100644 --- a/test/test_bson.py +++ b/test/test_bson.py @@ -18,8 +18,10 @@ import collections import datetime +import os import re import sys +import tempfile import uuid sys.path[0:0] = [""] @@ -335,41 +337,41 @@ class TestBSON(unittest.TestCase): self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(b"\x1B"))) - # An object size that's too small to even include the object size, - # but is correctly encoded, along with a correct EOO (and no data). - data = b"\x01\x00\x00\x00\x00" - self.assertRaises(InvalidBSON, decode_all, data) - self.assertRaises(InvalidBSON, list, decode_iter(data)) - self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(data))) - - # One object, but with object size listed smaller than it is in the - # data. - data = (b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\x00") - self.assertRaises(InvalidBSON, decode_all, data) - self.assertRaises(InvalidBSON, list, decode_iter(data)) - self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(data))) - - # One object, missing the EOO at the end. - data = (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00") - self.assertRaises(InvalidBSON, decode_all, data) - self.assertRaises(InvalidBSON, list, decode_iter(data)) - self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(data))) - - # One object, sized correctly, with a spot for an EOO, but the EOO - # isn't 0x00. - data = (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" - b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" - b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" - b"\x05\x00\x00\x00\xFF") - self.assertRaises(InvalidBSON, decode_all, data) - self.assertRaises(InvalidBSON, list, decode_iter(data)) - self.assertRaises(InvalidBSON, list, decode_file_iter(StringIO(data))) + bad_bsons = [ + # An object size that's too small to even include the object size, + # but is correctly encoded, along with a correct EOO (and no data). + b"\x01\x00\x00\x00\x00", + # One object, but with object size listed smaller than it is in the + # data. + (b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\x00"), + # One object, missing the EOO at the end. + (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00"), + # One object, sized correctly, with a spot for an EOO, but the EOO + # isn't 0x00. + (b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" + b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" + b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" + b"\x05\x00\x00\x00\xFF"), + ] + for i, data in enumerate(bad_bsons): + msg = "bad_bson[{}]".format(i) + with self.assertRaises(InvalidBSON, msg=msg): + decode_all(data) + with self.assertRaises(InvalidBSON, msg=msg): + list(decode_iter(data)) + with self.assertRaises(InvalidBSON, msg=msg): + list(decode_file_iter(StringIO(data))) + with tempfile.TemporaryFile() as scratch: + scratch.write(data) + scratch.seek(0, os.SEEK_SET) + with self.assertRaises(InvalidBSON, msg=msg): + list(decode_file_iter(scratch)) def test_data_timestamp(self): self.assertEqual({"test": Timestamp(4, 20)},