From fbb56a231198c2c0e2299e466151a140c7004d4b Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 18 Apr 2019 12:01:43 -0700 Subject: [PATCH] PYTHON-1820 Validate bson size in RawBSONDocument init Also fixes a bug where an empty bson document could not be represented by RawBSONDocument. --- bson/__init__.py | 48 +++++++++++++++++++++---------------------- bson/raw_bson.py | 46 ++++++++++++++++++++++++++++------------- doc/changelog.rst | 3 +++ test/test_raw_bson.py | 16 +++++++++++++++ 4 files changed, 75 insertions(+), 38 deletions(-) diff --git a/bson/__init__.py b/bson/__init__.py index b56e8ed48..8b0b928de 100644 --- a/bson/__init__.py +++ b/bson/__init__.py @@ -183,14 +183,26 @@ def _get_string(data, position, obj_end, opts, dummy): opts.unicode_decode_error_handler, True)[0], end + 1 -def _get_object(data, position, obj_end, opts, dummy): - """Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef.""" - obj_size = _UNPACK_INT(data[position:position + 4])[0] +def _get_object_size(data, position, obj_end): + """Validate and return a BSON document's size.""" + try: + obj_size = _UNPACK_INT(data[position:position + 4])[0] + except struct.error as exc: + raise InvalidBSON(str(exc)) end = position + obj_size - 1 - if data[end:position + obj_size] != b"\x00": + if data[end:end + 1] != b"\x00": raise InvalidBSON("bad eoo") if end >= obj_end: raise InvalidBSON("invalid object length") + # If this is the top-level document, validate the total size too. + if position == 0 and obj_size != obj_end: + raise InvalidBSON("invalid object length") + return obj_size, end + + +def _get_object(data, position, obj_end, opts, dummy): + """Decode a BSON subdocument to opts.document_class or bson.dbref.DBRef.""" + obj_size, end = _get_object_size(data, position, obj_end) if _raw_document_class(opts.document_class): return (opts.document_class(data[position:end + 1], opts), position + obj_size) @@ -406,38 +418,26 @@ if _USE_C: _element_to_dict = _cbson._element_to_dict -def _iterate_elements(data, position, obj_end, opts): +def _elements_to_dict(data, position, obj_end, opts, result=None): + """Decode a BSON document into result.""" + if result is None: + result = opts.document_class() end = obj_end - 1 while position < end: - (key, value, position) = _element_to_dict(data, position, obj_end, opts) - yield key, value, position - - -def _elements_to_dict(data, position, obj_end, opts): - """Decode a BSON document.""" - result = opts.document_class() - pos = position - for key, value, pos in _iterate_elements(data, position, obj_end, opts): + key, value, position = _element_to_dict(data, position, obj_end, opts) result[key] = value - if pos != obj_end: + if position != obj_end: raise InvalidBSON('bad object or element length') return result def _bson_to_dict(data, opts): """Decode a BSON string to document_class.""" - try: - obj_size = _UNPACK_INT(data[:4])[0] - except struct.error as exc: - raise InvalidBSON(str(exc)) - if obj_size != len(data): - raise InvalidBSON("invalid object size") - if data[obj_size - 1:obj_size] != b"\x00": - raise InvalidBSON("bad eoo") try: if _raw_document_class(opts.document_class): return opts.document_class(data, opts) - return _elements_to_dict(data, 4, obj_size - 1, opts) + _, end = _get_object_size(data, 0, len(data)) + return _elements_to_dict(data, 4, end, opts) except InvalidBSON: raise except Exception: diff --git a/bson/raw_bson.py b/bson/raw_bson.py index f99017c96..c13b1a9b3 100644 --- a/bson/raw_bson.py +++ b/bson/raw_bson.py @@ -15,11 +15,10 @@ """Tools for representing raw BSON documents. """ -from bson import _UNPACK_INT, _iterate_elements +from bson import _elements_to_dict, _get_object_size from bson.py3compat import abc, iteritems from bson.codec_options import ( DEFAULT_CODEC_OPTIONS as DEFAULT, _RAW_BSON_DOCUMENT_MARKER) -from bson.errors import InvalidBSON class RawBSONDocument(abc.Mapping): @@ -34,12 +33,33 @@ class RawBSONDocument(abc.Mapping): _type_marker = _RAW_BSON_DOCUMENT_MARKER def __init__(self, bson_bytes, codec_options=None): - """Create a new :class:`RawBSONDocument`. + """Create a new :class:`RawBSONDocument` + + :class:`RawBSONDocument` is a representation of a BSON document that + provides access to the underlying raw BSON bytes. Only when a field is + accessed or modified within the document does RawBSONDocument decode + its bytes. + + :class:`RawBSONDocument` implements the ``Mapping`` abstract base + class from the standard library so it can be used like a read-only + ``dict``:: + + >>> raw_doc = RawBSONDocument(BSON.encode({'_id': 'my_doc'})) + >>> raw_doc.raw + b'...' + >>> raw_doc['_id'] + 'my_doc' :Parameters: - `bson_bytes`: the BSON bytes that compose this document - `codec_options` (optional): An instance of - :class:`~bson.codec_options.CodecOptions`. + :class:`~bson.codec_options.CodecOptions` whose ``document_class`` + must be :class:`RawBSONDocument`. The default is + :attr:`DEFAULT_RAW_BSON_OPTIONS`. + + .. versionchanged:: 3.8 + :class:`RawBSONDocument` now validates that the ``bson_bytes`` + passed in represent a single bson document. .. versionchanged:: 3.5 If a :class:`~bson.codec_options.CodecOptions` is passed in, its @@ -56,6 +76,8 @@ class RawBSONDocument(abc.Mapping): "RawBSONDocument cannot use CodecOptions with document " "class %s" % (codec_options.document_class, )) self.__codec_options = codec_options + # Validate the bson object size. + _get_object_size(bson_bytes, 0, len(bson_bytes)) @property def raw(self): @@ -70,16 +92,9 @@ class RawBSONDocument(abc.Mapping): def __inflated(self): if self.__inflated_doc is None: # We already validated the object's size when this document was - # created, so no need to do that again. We still need to check the - # size of all the elements and compare to the document size. - object_size = _UNPACK_INT(self.__raw[:4])[0] - 1 - position = 0 - self.__inflated_doc = {} - for key, value, position in _iterate_elements( - self.__raw, 4, object_size, self.__codec_options): - self.__inflated_doc[key] = value - if position != object_size: - raise InvalidBSON('bad object or element length') + # created, so no need to do that again. + self.__inflated_doc = _elements_to_dict( + self.__raw, 4, len(self.__raw)-1, self.__codec_options, {}) return self.__inflated_doc def __getitem__(self, item): @@ -102,3 +117,6 @@ class RawBSONDocument(abc.Mapping): DEFAULT_RAW_BSON_OPTIONS = DEFAULT.with_options(document_class=RawBSONDocument) +"""The default :class:`~bson.codec_options.CodecOptions` for +:class:`RawBSONDocument`. +""" diff --git a/doc/changelog.rst b/doc/changelog.rst index d8707f75c..149f2b251 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -98,6 +98,9 @@ Changes in Version 3.8.0.dev0 supported since PyMongo 2.7. Valid values are `pythonLegacy` (the default), `javaLegacy`, `csharpLegacy` and `standard`. New applications should consider setting this to `standard` for cross language compatibility. +- :class:`~bson.raw_bson.RawBSONDocument` now validates that the ``bson_bytes`` + passed in represent a single bson document. Earlier versions would mistakenly + accept multiple bson documents. Issues Resolved ............... diff --git a/test/test_raw_bson.py b/test/test_raw_bson.py index 3931693f4..37a2ca08c 100644 --- a/test/test_raw_bson.py +++ b/test/test_raw_bson.py @@ -18,6 +18,7 @@ import uuid from bson import BSON from bson.binary import JAVA_LEGACY from bson.codec_options import CodecOptions +from bson.errors import InvalidBSON from bson.raw_bson import RawBSONDocument from test import client_context, unittest @@ -51,6 +52,21 @@ class TestRawBSONDocument(unittest.TestCase): def test_raw(self): self.assertEqual(self.bson_string, self.document.raw) + def test_empty_doc(self): + doc = RawBSONDocument(BSON.encode({})) + with self.assertRaises(KeyError): + doc['does-not-exist'] + + def test_invalid_bson_sequence(self): + bson_byte_sequence = BSON.encode({'a': 1})+BSON.encode({}) + with self.assertRaisesRegex(InvalidBSON, 'invalid object length'): + RawBSONDocument(bson_byte_sequence) + + def test_invalid_bson_eoo(self): + invalid_bson_eoo = BSON.encode({'a': 1})[:-1] + b'\x01' + with self.assertRaisesRegex(InvalidBSON, 'bad eoo'): + RawBSONDocument(invalid_bson_eoo) + @client_context.require_connection def test_round_trip(self): db = self.client.get_database(