# # Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Test the bson module.""" from __future__ import annotations import array import collections import datetime import importlib.util import mmap import os import pickle import re import struct import sys import tempfile import uuid from collections import OrderedDict, abc from io import BytesIO sys.path[0:0] = [""] from test import qcheck, unittest from test.helpers import ExceptionCatchingTask import bson from bson import ( BSON, EPOCH_AWARE, DatetimeMS, Regex, _array_of_documents_to_buffer, _datetime_to_millis, decode, decode_all, decode_file_iter, decode_iter, encode, is_valid, json_util, ) from bson.binary import ( USER_DEFINED_SUBTYPE, Binary, BinaryVector, BinaryVectorDtype, UuidRepresentation, ) from bson.code import Code from bson.codec_options import CodecOptions, DatetimeConversion from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION from bson.dbref import DBRef from bson.errors import InvalidBSON, InvalidDocument from bson.int64 import Int64 from bson.max_key import MaxKey from bson.min_key import MinKey from bson.objectid import ObjectId from bson.son import SON from bson.timestamp import Timestamp from bson.tz_util import FixedOffset, utc _NUMPY_AVAILABLE = importlib.util.find_spec("numpy") is not None class NotADict(abc.MutableMapping): """Non-dict type that implements the mapping protocol.""" def __init__(self, initial=None): if not initial: self._dict = {} else: self._dict = initial def __iter__(self): return iter(self._dict) def __getitem__(self, item): return self._dict[item] def __delitem__(self, item): del self._dict[item] def __setitem__(self, item, value): self._dict[item] = value def __len__(self): return len(self._dict) def __eq__(self, other): if isinstance(other, abc.Mapping): return all(self.get(k) == other.get(k) for k in self) return NotImplemented def __repr__(self): return "NotADict(%s)" % repr(self._dict) class DSTAwareTimezone(datetime.tzinfo): def __init__(self, offset, name, dst_start_month, dst_end_month): self.__offset = offset self.__dst_start_month = dst_start_month self.__dst_end_month = dst_end_month self.__name = name def _is_dst(self, dt): return self.__dst_start_month <= dt.month <= self.__dst_end_month def utcoffset(self, dt): return datetime.timedelta(minutes=self.__offset) + self.dst(dt) def dst(self, dt): if self._is_dst(dt): return datetime.timedelta(hours=1) return datetime.timedelta(0) def tzname(self, dt): return self.__name class TestBSON(unittest.TestCase): def assertInvalid(self, data): self.assertRaises(InvalidBSON, decode, data) def check_encode_then_decode(self, doc_class=dict, decoder=decode, encoder=encode): def helper(doc): self.assertEqual(doc, (decoder(encoder(doc_class(doc))))) self.assertEqual(doc, decoder(encoder(doc))) helper({}) helper({"test": "hello"}) self.assertIsInstance(decoder(encoder({"hello": "world"}))["hello"], str) helper({"mike": -10120}) helper({"long": Int64(10)}) helper({"really big long": 2147483648}) helper({"hello": 0.0013109}) helper({"something": True}) helper({"false": False}) helper({"an array": [1, True, 3.8, "world"]}) helper({"an object": doc_class({"test": "something"})}) helper({"a binary": Binary(b"test", 100)}) helper({"a binary": Binary(b"test", 128)}) helper({"a binary": Binary(b"test", 254)}) helper({"another binary": Binary(b"test", 2)}) helper({"binary packed bit vector": Binary(b"\x10\x00\x7f\x07", 9)}) helper({"binary int8 vector": Binary(b"\x03\x00\x7f\x07", 9)}) helper({"binary float32 vector": Binary(b"'\x00\x00\x00\xfeB\x00\x00\xe0@", 9)}) helper(SON([("test dst", datetime.datetime(1993, 4, 4, 2))])) helper(SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))])) helper({"big float": float(10000000000)}) helper({"ref": DBRef("coll", 5)}) helper({"ref": DBRef("coll", 5, foo="bar", bar=4)}) helper({"ref": DBRef("coll", 5, "foo")}) helper({"ref": DBRef("coll", 5, "foo", foo="bar")}) helper({"ref": Timestamp(1, 2)}) helper({"foo": MinKey()}) helper({"foo": MaxKey()}) helper({"$field": Code("function(){ return true; }")}) helper({"$field": Code("return function(){ return x; }", scope={"x": False})}) def encode_then_decode(doc): return doc_class(doc) == decoder(encode(doc), CodecOptions(document_class=doc_class)) qcheck.check_unittest(self, encode_then_decode, qcheck.gen_mongo_dict(3)) def test_encode_then_decode(self): self.check_encode_then_decode() def test_encode_then_decode_any_mapping(self): self.check_encode_then_decode(doc_class=NotADict) def test_encode_then_decode_legacy(self): self.check_encode_then_decode( encoder=BSON.encode, decoder=lambda *args: BSON(args[0]).decode(*args[1:]) ) def test_encode_then_decode_any_mapping_legacy(self): self.check_encode_then_decode( doc_class=NotADict, encoder=BSON.encode, decoder=lambda *args: BSON(args[0]).decode(*args[1:]), ) def test_encoding_defaultdict(self): dct = collections.defaultdict(dict, [("foo", "bar")]) # type: ignore[arg-type] encode(dct) self.assertEqual(dct, collections.defaultdict(dict, [("foo", "bar")])) def test_basic_validation(self): self.assertRaises(TypeError, is_valid, 100) self.assertRaises(TypeError, is_valid, "test") self.assertRaises(TypeError, is_valid, 10.4) self.assertInvalid(b"test") # the simplest valid BSON document self.assertTrue(is_valid(b"\x05\x00\x00\x00\x00")) self.assertTrue(is_valid(BSON(b"\x05\x00\x00\x00\x00"))) # failure cases self.assertInvalid(b"\x04\x00\x00\x00\x00") self.assertInvalid(b"\x05\x00\x00\x00\x01") self.assertInvalid(b"\x05\x00\x00\x00") self.assertInvalid(b"\x05\x00\x00\x00\x00\x00") self.assertInvalid(b"\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12") self.assertInvalid(b"\x09\x00\x00\x00\x10a\x00\x05\x00") self.assertInvalid(b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00") self.assertInvalid(b"\x13\x00\x00\x00\x02foo\x00\x04\x00\x00\x00bar\x00\x00") self.assertInvalid( b"\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00" ) self.assertInvalid(b"\x15\x00\x00\x00\x03foo\x00\x0c\x00\x00\x00\x08bar\x00\x01\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x03foo\x00" b"\x12\x00\x00\x00\x02bar\x00" b"\x05\x00\x00\x00baz\x00\x00\x00" ) self.assertInvalid(b"\x10\x00\x00\x00\x02a\x00\x04\x00\x00\x00abc\xff\x00") def test_bad_string_lengths(self): self.assertInvalid(b"\x0c\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00") self.assertInvalid(b"\x12\x00\x00\x00\x02\x00\xff\xff\xff\xfffoobar\x00\x00") self.assertInvalid(b"\x0c\x00\x00\x00\x0e\x00\x00\x00\x00\x00\x00\x00") self.assertInvalid(b"\x12\x00\x00\x00\x0e\x00\xff\xff\xff\xfffoobar\x00\x00") self.assertInvalid( b"\x18\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00" ) self.assertInvalid( b"\x1e\x00\x00\x00\x0c\x00" b"\xff\xff\xff\xfffoobar\x00" b"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00" ) self.assertInvalid(b"\x0c\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00") self.assertInvalid(b"\x0c\x00\x00\x00\r\x00\xff\xff\xff\xff\x00\x00") self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x00\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" b"\x00\x00\x00\x00" ) self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\xff\xff" b"\xff\xff\x00\x0c\x00\x00" b"\x00\x02\x00\x01\x00\x00" b"\x00\x00\x00\x00" ) self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\x00\x00\x00" b"\x00\x00\x00\x00" ) self.assertInvalid( b"\x1c\x00\x00\x00\x0f\x00" b"\x15\x00\x00\x00\x01\x00" b"\x00\x00\x00\x0c\x00\x00" b"\x00\x02\x00\xff\xff\xff" b"\xff\x00\x00\x00" ) def test_random_data_is_not_bson(self): qcheck.check_unittest( self, qcheck.isnt(is_valid), qcheck.gen_string(qcheck.gen_range(0, 40)) ) def test_basic_decode(self): self.assertEqual( {"test": "hello world"}, decode( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74\x00\x0C" b"\x00\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F" b"\x72\x6C\x64\x00\x00" ), ) self.assertEqual( [{"test": "hello world"}, {}], decode_all( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00" ), ) self.assertEqual( [{"test": "hello world"}, {}], list( decode_iter( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00" ) ), ) self.assertEqual( [{"test": "hello world"}, {}], list( decode_file_iter( BytesIO( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00" ) ) ), ) def test_decode_all_buffer_protocol(self): docs = [{"foo": "bar"}, {}] bs = b"".join(map(encode, docs)) # type: ignore[arg-type] self.assertEqual(docs, decode_all(bytearray(bs))) self.assertEqual(docs, decode_all(memoryview(bs))) self.assertEqual(docs, decode_all(memoryview(b"1" + bs + b"1")[1:-1])) self.assertEqual(docs, decode_all(array.array("B", bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) self.assertEqual(docs, decode_all(mm)) def test_decode_buffer_protocol(self): doc = {"foo": "bar"} bs = encode(doc) self.assertEqual(doc, decode(bs)) self.assertEqual(doc, decode(bytearray(bs))) self.assertEqual(doc, decode(memoryview(bs))) self.assertEqual(doc, decode(memoryview(b"1" + bs + b"1")[1:-1])) self.assertEqual(doc, decode(array.array("B", bs))) with mmap.mmap(-1, len(bs)) as mm: mm.write(bs) mm.seek(0) self.assertEqual(doc, decode(mm)) def test_invalid_decodes(self): # Invalid object size (not enough bytes in document for even # an object size of first object. # NOTE: decode_all and decode_iter don't care, not sure if they should? self.assertRaises(InvalidBSON, list, decode_file_iter(BytesIO(b"\x1B"))) bad_bsons = [ # An object size that's too small to even include the object size, # but is correctly encoded, along with a correct EOO (and no data). b"\x01\x00\x00\x00\x00", # One object, but with object size listed smaller than it is in the # data. ( b"\x1A\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\x00" ), # One object, missing the EOO at the end. ( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00" ), # One object, sized correctly, with a spot for an EOO, but the EOO # isn't 0x00. ( b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74" b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C" b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00" b"\x05\x00\x00\x00\xFF" ), ] for i, data in enumerate(bad_bsons): msg = f"bad_bson[{i}]" with self.assertRaises(InvalidBSON, msg=msg): decode_all(data) with self.assertRaises(InvalidBSON, msg=msg): list(decode_iter(data)) with self.assertRaises(InvalidBSON, msg=msg): list(decode_file_iter(BytesIO(data))) with tempfile.TemporaryFile() as scratch: scratch.write(data) scratch.seek(0, os.SEEK_SET) with self.assertRaises(InvalidBSON, msg=msg): list(decode_file_iter(scratch)) def test_invalid_field_name(self): # Decode a truncated field with self.assertRaises(InvalidBSON) as ctx: decode(b"\x0b\x00\x00\x00\x02field\x00") # Assert that the InvalidBSON error message is not empty. self.assertTrue(str(ctx.exception)) def test_data_timestamp(self): self.assertEqual( {"test": Timestamp(4, 20)}, decode(b"\x13\x00\x00\x00\x11\x74\x65\x73\x74\x00\x14\x00\x00\x00\x04\x00\x00\x00\x00"), ) def test_basic_encode(self): self.assertRaises(TypeError, encode, 100) self.assertRaises(TypeError, encode, "hello") self.assertRaises(TypeError, encode, None) self.assertRaises(TypeError, encode, []) self.assertEqual(encode({}), BSON(b"\x05\x00\x00\x00\x00")) self.assertEqual(encode({}), b"\x05\x00\x00\x00\x00") self.assertEqual( encode({"test": "hello world"}), b"\x1B\x00\x00\x00\x02\x74\x65\x73\x74\x00\x0C\x00" b"\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F\x72\x6C" b"\x64\x00\x00", ) self.assertEqual( encode({"mike": 100}), b"\x0F\x00\x00\x00\x10\x6D\x69\x6B\x65\x00\x64\x00\x00\x00\x00", ) self.assertEqual( encode({"hello": 1.5}), b"\x14\x00\x00\x00\x01\x68\x65\x6C\x6C\x6F\x00\x00\x00\x00\x00\x00\x00\xF8\x3F\x00", ) self.assertEqual( encode({"true": True}), b"\x0C\x00\x00\x00\x08\x74\x72\x75\x65\x00\x01\x00" ) self.assertEqual( encode({"false": False}), b"\x0D\x00\x00\x00\x08\x66\x61\x6C\x73\x65\x00\x00\x00" ) self.assertEqual( encode({"empty": []}), b"\x11\x00\x00\x00\x04\x65\x6D\x70\x74\x79\x00\x05\x00\x00\x00\x00\x00", ) self.assertEqual( encode({"none": {}}), b"\x10\x00\x00\x00\x03\x6E\x6F\x6E\x65\x00\x05\x00\x00\x00\x00\x00", ) self.assertEqual( encode({"test": Binary(b"test", 0)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x00\x74\x65\x73\x74\x00", ) self.assertEqual( encode({"test": Binary(b"test", 2)}), b"\x18\x00\x00\x00\x05\x74\x65\x73\x74\x00\x08\x00" b"\x00\x00\x02\x04\x00\x00\x00\x74\x65\x73\x74\x00", ) self.assertEqual( encode({"test": Binary(b"test", 128)}), b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x80\x74\x65\x73\x74\x00", ) self.assertEqual( encode({"vector_int8": Binary.from_vector([-128, -1, 127], BinaryVectorDtype.INT8)}), b"\x1c\x00\x00\x00\x05vector_int8\x00\x05\x00\x00\x00\t\x03\x00\x80\xff\x7f\x00", ) self.assertEqual( encode({"vector_bool": Binary.from_vector([1, 127], BinaryVectorDtype.PACKED_BIT)}), b"\x1b\x00\x00\x00\x05vector_bool\x00\x04\x00\x00\x00\t\x10\x00\x01\x7f\x00", ) self.assertEqual( encode( {"vector_float32": Binary.from_vector([-1.1, 1.1e10], BinaryVectorDtype.FLOAT32)} ), b"$\x00\x00\x00\x05vector_float32\x00\n\x00\x00\x00\t'\x00\xcd\xcc\x8c\xbf\xac\xe9#P\x00", ) self.assertEqual(encode({"test": None}), b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00") self.assertEqual( encode({"date": datetime.datetime(2007, 1, 8, 0, 30, 11)}), b"\x13\x00\x00\x00\x09\x64\x61\x74\x65\x00\x38\xBE\x1C\xFF\x0F\x01\x00\x00\x00", ) self.assertEqual( encode({"regex": re.compile(b"a*b", re.IGNORECASE)}), b"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61\x2A\x62\x00\x69\x00\x00", ) self.assertEqual( encode({"$where": Code("test")}), b"\x16\x00\x00\x00\r$where\x00\x05\x00\x00\x00test\x00\x00", ) self.assertEqual( encode({"$field": Code("function(){ return true;}", scope=None)}), b"+\x00\x00\x00\r$field\x00\x1a\x00\x00\x00function(){ return true;}\x00\x00", ) self.assertEqual( encode({"$field": Code("return function(){ return x; }", scope={"x": False})}), b"=\x00\x00\x00\x0f$field\x000\x00\x00\x00\x1f\x00" b"\x00\x00return function(){ return x; }\x00\t\x00" b"\x00\x00\x08x\x00\x00\x00\x00", ) unicode_empty_scope = Code("function(){ return 'héllo';}", {}) self.assertEqual( encode({"$field": unicode_empty_scope}), b"8\x00\x00\x00\x0f$field\x00+\x00\x00\x00\x1e\x00" b"\x00\x00function(){ return 'h\xc3\xa9llo';}\x00\x05" b"\x00\x00\x00\x00\x00", ) a = ObjectId(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B") self.assertEqual( encode({"oid": a}), b"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x00\x01\x02" b"\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00", ) self.assertEqual( encode({"ref": DBRef("coll", a)}), b"\x2F\x00\x00\x00\x03ref\x00\x25\x00\x00\x00\x02" b"$ref\x00\x05\x00\x00\x00coll\x00\x07$id\x00\x00" b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00" b"\x00", ) def test_bad_code(self): # Assert that decoding invalid Code with scope does not include a field name. def generate_payload(length: int) -> bytes: string_size = length - 0x1E return bytes.fromhex( struct.pack(", >=, !=, and ==. # These tests should be kept as assertTrue as opposed to using unittest's built-in comparison assertions because # MinKey and MaxKey define their own __ge__, __le__, and other comparison attributes, and we want to explicitly test that. self.assertTrue(MinKey() < None) self.assertTrue(MinKey() < 1) self.assertTrue(MinKey() <= 1) self.assertTrue(MinKey() <= MinKey()) self.assertFalse(MinKey() > None) self.assertFalse(MinKey() > 1) self.assertFalse(MinKey() >= 1) self.assertTrue(MinKey() >= MinKey()) self.assertTrue(MinKey() != 1) self.assertFalse(MinKey() == 1) self.assertTrue(MinKey() == MinKey()) # MinKey compared to MaxKey. self.assertTrue(MinKey() < MaxKey()) self.assertTrue(MinKey() <= MaxKey()) self.assertFalse(MinKey() > MaxKey()) self.assertFalse(MinKey() >= MaxKey()) self.assertTrue(MinKey() != MaxKey()) self.assertFalse(MinKey() == MaxKey()) # MaxKey's <, <=, >, >=, !=, and ==. self.assertFalse(MaxKey() < None) self.assertFalse(MaxKey() < 1) self.assertFalse(MaxKey() <= 1) self.assertTrue(MaxKey() <= MaxKey()) self.assertTrue(MaxKey() > None) self.assertTrue(MaxKey() > 1) self.assertTrue(MaxKey() >= 1) self.assertTrue(MaxKey() >= MaxKey()) self.assertTrue(MaxKey() != 1) self.assertFalse(MaxKey() == 1) self.assertTrue(MaxKey() == MaxKey()) # MaxKey compared to MinKey. self.assertFalse(MaxKey() < MinKey()) self.assertFalse(MaxKey() <= MinKey()) self.assertTrue(MaxKey() > MinKey()) self.assertTrue(MaxKey() >= MinKey()) self.assertTrue(MaxKey() != MinKey()) self.assertFalse(MaxKey() == MinKey()) def test_minkey_maxkey_hash(self): self.assertEqual(hash(MaxKey()), hash(MaxKey())) self.assertEqual(hash(MinKey()), hash(MinKey())) self.assertNotEqual(hash(MaxKey()), hash(MinKey())) def test_timestamp_comparison(self): # Timestamp is initialized with time, inc. Time is the more # significant comparand. self.assertTrue(Timestamp(1, 0) < Timestamp(2, 17)) self.assertTrue(Timestamp(2, 0) > Timestamp(1, 0)) self.assertTrue(Timestamp(1, 7) <= Timestamp(2, 0)) self.assertTrue(Timestamp(2, 0) >= Timestamp(1, 1)) self.assertTrue(Timestamp(2, 0) <= Timestamp(2, 0)) self.assertTrue(Timestamp(2, 0) >= Timestamp(2, 0)) self.assertFalse(Timestamp(1, 0) > Timestamp(2, 0)) # Comparison by inc. self.assertTrue(Timestamp(1, 0) < Timestamp(1, 1)) self.assertTrue(Timestamp(1, 1) > Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 1)) self.assertFalse(Timestamp(1, 0) >= Timestamp(1, 1)) self.assertTrue(Timestamp(1, 0) >= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 1) >= Timestamp(1, 0)) self.assertFalse(Timestamp(1, 1) <= Timestamp(1, 0)) self.assertTrue(Timestamp(1, 0) <= Timestamp(1, 0)) self.assertFalse(Timestamp(1, 0) > Timestamp(1, 0)) def test_timestamp_highorder_bits(self): doc = {"a": Timestamp(0xFFFFFFFF, 0xFFFFFFFF)} doc_bson = b"\x10\x00\x00\x00\x11a\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00" self.assertEqual(doc_bson, encode(doc)) self.assertEqual(doc, decode(doc_bson)) def test_bad_id_keys(self): self.assertRaises(InvalidDocument, encode, {"_id": {"$bad": 123}}, True) self.assertRaises( InvalidDocument, encode, {"_id": {"$oid": "52d0b971b3ba219fdeb4170e"}}, True ) encode({"_id": {"$oid": "52d0b971b3ba219fdeb4170e"}}) def test_bson_encode_thread_safe(self): def target(i): for j in range(1000): my_int = type(f"MyInt_{i}_{j}", (int,), {}) bson.encode({"my_int": my_int()}) threads = [ExceptionCatchingTask(target=target, args=(i,)) for i in range(3)] for t in threads: t.start() for t in threads: t.join() for t in threads: self.assertIsNone(t.exc) def test_raise_invalid_document(self): class Wrapper: def __init__(self, val): self.val = val def __repr__(self): return repr(self.val) self.assertEqual("1", repr(Wrapper(1))) with self.assertRaisesRegex( InvalidDocument, "cannot encode object: 1, of type: " + repr(Wrapper) ): encode({"t": Wrapper(1)}) def test_doc_in_invalid_document_error_as_property(self): class Wrapper: def __init__(self, val): self.val = val def __repr__(self): return repr(self.val) self.assertEqual("1", repr(Wrapper(1))) doc = {"t": Wrapper(1)} with self.assertRaisesRegex(InvalidDocument, "Invalid document:") as cm: encode(doc) self.assertEqual(cm.exception.document, doc) def test_doc_in_invalid_document_error_as_property_mapping(self): class MyMapping(abc.Mapping): def keys(self): return ["t"] def __getitem__(self, name): if name == "_id": return None return Wrapper(name) def __len__(self): return 1 def __iter__(self): return iter(["t"]) def __eq__(self, other): if isinstance(other, MyMapping): return True return False class Wrapper: def __init__(self, val): self.val = val def __repr__(self): return repr(self.val) self.assertEqual("1", repr(Wrapper(1))) doc = MyMapping() with self.assertRaisesRegex(InvalidDocument, "Invalid document:") as cm: encode(doc) self.assertEqual(cm.exception.document, doc) def test_binary_length_accounts_for_header(self): size = 20 binary_length = 12 # 5 more than the actual 7 bytes payload = b"" payload += struct.pack(" None: doc = {"_id": ObjectId()} encoded = bson.encode(doc) decoded = bson.decode(encoded) encoded = bson.encode(decoded) decoded = bson.decode(encoded) # Documents returned from decode are mutable. decoded["new_field"] = 1 self.assertTrue(decoded["_id"].generation_time) class TestDatetimeConversion(unittest.TestCase): def test_comps(self): # Tests other timestamp formats. # Test each of the rich comparison methods. pairs = [ (DatetimeMS(-1), DatetimeMS(1)), (DatetimeMS(0), DatetimeMS(0)), (DatetimeMS(1), DatetimeMS(-1)), ] comp_ops = ["__lt__", "__le__", "__eq__", "__ne__", "__gt__", "__ge__"] for lh, rh in pairs: for op in comp_ops: self.assertEqual(getattr(lh, op)(rh), getattr(lh._value, op)(rh._value)) def test_class_conversions(self): # Test class conversions. dtr1 = DatetimeMS(1234) dt1 = dtr1.as_datetime() self.assertEqual(dtr1, DatetimeMS(dt1)) dt2 = datetime.datetime(1969, 1, 1) dtr2 = DatetimeMS(dt2) self.assertEqual(dtr2.as_datetime(), dt2) # Test encode and decode without codec options. Expect: DatetimeMS => datetime dtr1 = DatetimeMS(0) enc1 = encode({"x": dtr1}) dec1 = decode(enc1) self.assertEqual(dec1["x"], datetime.datetime(1970, 1, 1)) self.assertNotEqual(type(dtr1), type(dec1["x"])) # Test encode and decode with codec options. Expect: UTCDateimteRaw => DatetimeMS opts1 = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_MS) enc1 = encode({"x": dtr1}) dec1 = decode(enc1, opts1) self.assertEqual(type(dtr1), type(dec1["x"])) self.assertEqual(dtr1, dec1["x"]) # Expect: datetime => DatetimeMS opts1 = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_MS) dt1 = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) enc1 = encode({"x": dt1}) dec1 = decode(enc1, opts1) self.assertEqual(dec1["x"], DatetimeMS(0)) self.assertNotEqual(dt1, type(dec1["x"])) def test_clamping(self): # Test clamping from below and above. opts = CodecOptions( datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True, tzinfo=datetime.timezone.utc, ) below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 1)}) dec_below = decode(below, opts) self.assertEqual( dec_below["x"], datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) ) above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 1)}) dec_above = decode(above, opts) self.assertEqual( dec_above["x"], datetime.datetime.max.replace(tzinfo=datetime.timezone.utc, microsecond=999000), ) def test_tz_clamping_local(self): # Naive clamping to local tz. opts = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=False) below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)}) dec_below = decode(below, opts) self.assertEqual(dec_below["x"], datetime.datetime.min) above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)}) dec_above = decode(above, opts) self.assertEqual( dec_above["x"], datetime.datetime.max.replace(microsecond=999000), ) def test_tz_clamping_utc(self): # Aware clamping default utc. opts = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True) below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)}) dec_below = decode(below, opts) self.assertEqual( dec_below["x"], datetime.datetime.min.replace(tzinfo=datetime.timezone.utc) ) above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)}) dec_above = decode(above, opts) self.assertEqual( dec_above["x"], datetime.datetime.max.replace(tzinfo=datetime.timezone.utc, microsecond=999000), ) def test_tz_clamping_non_utc(self): for tz in [FixedOffset(60, "+1H"), FixedOffset(-60, "-1H")]: opts = CodecOptions( datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True, tzinfo=tz ) # Min/max values in this timezone which can be represented in both BSON and datetime UTC. try: min_tz = datetime.datetime.min.replace(tzinfo=utc).astimezone(tz) except OverflowError: min_tz = datetime.datetime.min.replace(tzinfo=tz) try: max_tz = datetime.datetime.max.replace(tzinfo=utc, microsecond=999000).astimezone( tz ) except OverflowError: max_tz = datetime.datetime.max.replace(tzinfo=tz, microsecond=999000) for in_range in [ min_tz, min_tz + datetime.timedelta(milliseconds=1), max_tz - datetime.timedelta(milliseconds=1), max_tz, ]: doc = decode(encode({"x": in_range}), opts) self.assertEqual(doc["x"], in_range) for too_low in [ DatetimeMS(_datetime_to_millis(min_tz) - 1), DatetimeMS(_datetime_to_millis(min_tz) - 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(min_tz) - 1 - 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 1), DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 1 - 60 * 60 * 1000), ]: doc = decode(encode({"x": too_low}), opts) self.assertEqual(doc["x"], min_tz) for too_high in [ DatetimeMS(_datetime_to_millis(max_tz) + 1), DatetimeMS(_datetime_to_millis(max_tz) + 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(max_tz) + 1 + 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 1), DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 60 * 60 * 1000), DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 1 + 60 * 60 * 1000), ]: doc = decode(encode({"x": too_high}), opts) self.assertEqual(doc["x"], max_tz) def test_tz_clamping_non_utc_simple(self): dtm = datetime.datetime(2024, 8, 23) encoded = encode({"d": dtm}) self.assertEqual(decode(encoded)["d"], dtm) for conversion in [ DatetimeConversion.DATETIME, DatetimeConversion.DATETIME_CLAMP, DatetimeConversion.DATETIME_AUTO, ]: for tz in [FixedOffset(60, "+1H"), FixedOffset(-60, "-1H")]: opts = CodecOptions(datetime_conversion=conversion, tz_aware=True, tzinfo=tz) self.assertEqual(decode(encoded, opts)["d"], dtm.replace(tzinfo=utc).astimezone(tz)) def test_tz_clamping_non_hashable(self): class NonHashableTZ(FixedOffset): __hash__ = None tz = NonHashableTZ(0, "UTC-non-hashable") self.assertRaises(TypeError, hash, tz) # Aware clamping. opts = CodecOptions( datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True, tzinfo=tz ) below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)}) dec_below = decode(below, opts) self.assertEqual(dec_below["x"], datetime.datetime.min.replace(tzinfo=tz)) within = encode({"x": EPOCH_AWARE.astimezone(tz)}) dec_within = decode(within, opts) self.assertEqual(dec_within["x"], EPOCH_AWARE.astimezone(tz)) above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)}) dec_above = decode(above, opts) self.assertEqual( dec_above["x"], datetime.datetime.max.replace(tzinfo=tz, microsecond=999000), ) def test_datetime_auto(self): # Naive auto, in range. opts1 = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_AUTO) inr = encode({"x": datetime.datetime(1970, 1, 1)}, codec_options=opts1) dec_inr = decode(inr) self.assertEqual(dec_inr["x"], datetime.datetime(1970, 1, 1)) # Naive auto, below range. below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)}) dec_below = decode(below, opts1) self.assertEqual( dec_below["x"], DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60) ) # Naive auto, above range. above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)}) dec_above = decode(above, opts1) self.assertEqual( dec_above["x"], DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60), ) # Aware auto, in range. opts2 = CodecOptions( datetime_conversion=DatetimeConversion.DATETIME_AUTO, tz_aware=True, tzinfo=datetime.timezone.utc, ) inr = encode({"x": datetime.datetime(1970, 1, 1)}, codec_options=opts2) dec_inr = decode(inr) self.assertEqual(dec_inr["x"], datetime.datetime(1970, 1, 1)) # Aware auto, below range. below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)}) dec_below = decode(below, opts2) self.assertEqual( dec_below["x"], DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60) ) # Aware auto, above range. above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)}) dec_above = decode(above, opts2) self.assertEqual( dec_above["x"], DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60), ) def test_millis_from_datetime_ms(self): # Test 65+ bit integer conversion, expect OverflowError. big_ms = 2**65 with self.assertRaises(OverflowError): encode({"x": DatetimeMS(big_ms)}) # Subclass of DatetimeMS w/ __int__ override, expect an Error. class DatetimeMSOverride(DatetimeMS): def __int__(self): return float(self._value) float_ms = DatetimeMSOverride(2) with self.assertRaises(TypeError): encode({"x": float_ms}) # Test InvalidBSON errors on conversion include _DATETIME_ERROR_SUGGESTION small_ms = -2 << 51 with self.assertRaisesRegex(InvalidBSON, re.compile(re.escape(_DATETIME_ERROR_SUGGESTION))): decode(encode({"a": DatetimeMS(small_ms)})) def test_array_of_documents_to_buffer(self): doc = dict(a=1) buf = _array_of_documents_to_buffer(encode({"0": doc})) self.assertEqual(buf, encode(doc)) buf = _array_of_documents_to_buffer(encode({"0": doc, "1": doc})) self.assertEqual(buf, encode(doc) + encode(doc)) with self.assertRaises(InvalidBSON): _array_of_documents_to_buffer(encode({"0": doc, "1": doc}) + b"1") buf = encode({"0": doc, "1": doc}) buf = buf[:-1] + b"1" with self.assertRaises(InvalidBSON): _array_of_documents_to_buffer(buf) # We replace the size of the array with \xff\xff\xff\x00 which is -221 as an int32. buf = b"\x14\x00\x00\x00\x04a\x00\xff\xff\xff\x00\x100\x00\x01\x00\x00\x00\x00\x00" with self.assertRaises(InvalidBSON): _array_of_documents_to_buffer(buf) class TestLongLongToString(unittest.TestCase): def test_long_long_to_string(self): try: from bson import _cbson _cbson._test_long_long_to_str() except ImportError: print("_cbson was not imported. Check compilation logs.") if __name__ == "__main__": unittest.main()