1691 lines
67 KiB
Python
1691 lines
67 KiB
Python
#
|
|
# Copyright 2009-present MongoDB, Inc.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""Test the bson module."""
|
|
from __future__ import annotations
|
|
|
|
import array
|
|
import collections
|
|
import datetime
|
|
import mmap
|
|
import os
|
|
import pickle
|
|
import re
|
|
import struct
|
|
import sys
|
|
import tempfile
|
|
import uuid
|
|
from collections import OrderedDict, abc
|
|
from io import BytesIO
|
|
|
|
sys.path[0:0] = [""]
|
|
|
|
from test import qcheck, unittest
|
|
from test.helpers import ExceptionCatchingTask
|
|
|
|
import bson
|
|
from bson import (
|
|
BSON,
|
|
EPOCH_AWARE,
|
|
DatetimeMS,
|
|
Regex,
|
|
_array_of_documents_to_buffer,
|
|
_datetime_to_millis,
|
|
decode,
|
|
decode_all,
|
|
decode_file_iter,
|
|
decode_iter,
|
|
encode,
|
|
is_valid,
|
|
json_util,
|
|
)
|
|
from bson.binary import (
|
|
USER_DEFINED_SUBTYPE,
|
|
Binary,
|
|
BinaryVector,
|
|
BinaryVectorDtype,
|
|
UuidRepresentation,
|
|
)
|
|
from bson.code import Code
|
|
from bson.codec_options import CodecOptions, DatetimeConversion
|
|
from bson.datetime_ms import _DATETIME_ERROR_SUGGESTION
|
|
from bson.dbref import DBRef
|
|
from bson.errors import InvalidBSON, InvalidDocument
|
|
from bson.int64 import Int64
|
|
from bson.max_key import MaxKey
|
|
from bson.min_key import MinKey
|
|
from bson.objectid import ObjectId
|
|
from bson.son import SON
|
|
from bson.timestamp import Timestamp
|
|
from bson.tz_util import FixedOffset, utc
|
|
|
|
|
|
class NotADict(abc.MutableMapping):
    """Non-dict type that implements the mapping protocol.

    Items are kept in a plain ``dict`` stored on ``self._dict``; every
    mapping operation is forwarded to it.
    """

    def __init__(self, initial=None):
        # A falsy ``initial`` (None or an empty mapping) gets a fresh dict;
        # otherwise the caller's object is stored as-is, without copying.
        self._dict = initial if initial else {}

    def __iter__(self):
        return iter(self._dict)

    def __getitem__(self, item):
        return self._dict[item]

    def __delitem__(self, item):
        del self._dict[item]

    def __setitem__(self, item, value):
        self._dict[item] = value

    def __len__(self):
        return len(self._dict)

    def __eq__(self, other):
        """Return True when ``other`` is a mapping holding exactly the same items.

        Returns ``NotImplemented`` for non-mapping operands so Python can try
        the reflected comparison.
        """
        if not isinstance(other, abc.Mapping):
            return NotImplemented
        # Comparing sizes first stops a subset from matching a superset.
        if len(self) != len(other):
            return False
        # Membership is checked explicitly so that a missing key is never
        # treated as equal to a key whose value is None.
        return all(key in other and self[key] == other[key] for key in self)

    def __repr__(self):
        return "NotADict(%s)" % repr(self._dict)
|
|
|
|
|
|
class DSTAwareTimezone(datetime.tzinfo):
    """Fixed-offset timezone with a month-granular daylight-saving window.

    :param offset: standard (non-DST) offset from UTC, in minutes.
    :param name: value returned by :meth:`tzname`.
    :param dst_start_month: first month (inclusive) in which DST applies.
    :param dst_end_month: last month (inclusive) in which DST applies.
    """

    def __init__(self, offset, name, dst_start_month, dst_end_month):
        self.__offset = offset
        self.__dst_start_month = dst_start_month
        self.__dst_end_month = dst_end_month
        self.__name = name

    def _is_dst(self, dt):
        """Return True when ``dt`` falls inside the DST month window."""
        # tzinfo methods can be invoked with dt=None (for example through
        # datetime.time objects); with no date there is no DST to apply.
        if dt is None:
            return False
        return self.__dst_start_month <= dt.month <= self.__dst_end_month

    def utcoffset(self, dt):
        """Return the standard offset plus any DST adjustment for ``dt``."""
        return datetime.timedelta(minutes=self.__offset) + self.dst(dt)

    def dst(self, dt):
        """Return one hour while DST is in effect, otherwise a zero delta."""
        if self._is_dst(dt):
            return datetime.timedelta(hours=1)
        return datetime.timedelta(0)

    def tzname(self, dt):
        """Return the name given at construction, whatever ``dt`` is."""
        return self.__name
|
|
|
|
|
|
class TestBSON(unittest.TestCase):
|
|
def assertInvalid(self, data):
    """Assert that decoding ``data`` raises ``InvalidBSON``."""
    with self.assertRaises(InvalidBSON):
        decode(data)
|
|
|
|
def check_encode_then_decode(self, doc_class=dict, decoder=decode, encoder=encode):
    """Round-trip a range of documents through ``encoder`` and ``decoder``.

    :param doc_class: mapping type the sample documents are wrapped in before
        encoding, and the ``document_class`` requested in the qcheck pass.
    :param decoder: callable taking the encoded bytes and, optionally, a
        ``CodecOptions`` instance.
    :param encoder: callable turning a mapping into encoded bytes.
    """
    # Work around http://bugs.jython.org/issue1728
    if sys.platform.startswith("java"):
        doc_class = SON

    def helper(doc):
        # Check both the wrapped document and the plain one.
        self.assertEqual(doc, decoder(encoder(doc_class(doc))))
        self.assertEqual(doc, decoder(encoder(doc)))

    samples = [
        {},
        {"test": "hello"},
        {"mike": -10120},
        {"long": Int64(10)},
        {"really big long": 2147483648},
        {"hello": 0.0013109},
        {"something": True},
        {"false": False},
        {"an array": [1, True, 3.8, "world"]},
        {"an object": doc_class({"test": "something"})},
        {"a binary": Binary(b"test", 100)},
        {"a binary": Binary(b"test", 128)},
        {"a binary": Binary(b"test", 254)},
        {"another binary": Binary(b"test", 2)},
        {"binary packed bit vector": Binary(b"\x10\x00\x7f\x07", 9)},
        {"binary int8 vector": Binary(b"\x03\x00\x7f\x07", 9)},
        {"binary float32 vector": Binary(b"'\x00\x00\x00\xfeB\x00\x00\xe0@", 9)},
        SON([("test dst", datetime.datetime(1993, 4, 4, 2))]),
        SON([("test negative dst", datetime.datetime(1, 1, 1, 1, 1, 1))]),
        {"big float": float(10000000000)},
        {"ref": DBRef("coll", 5)},
        {"ref": DBRef("coll", 5, foo="bar", bar=4)},
        {"ref": DBRef("coll", 5, "foo")},
        {"ref": DBRef("coll", 5, "foo", foo="bar")},
        {"ref": Timestamp(1, 2)},
        {"foo": MinKey()},
        {"foo": MaxKey()},
        {"$field": Code("function(){ return true; }")},
        {"$field": Code("return function(){ return x; }", scope={"x": False})},
    ]
    for sample in samples:
        helper(sample)

    self.assertIsInstance(decoder(encoder({"hello": "world"}))["hello"], str)

    def encode_then_decode(doc):
        # Use the encoder under test (not the module-level ``encode``) so the
        # legacy variants exercise their own encoder here as well.
        return doc_class(doc) == decoder(encoder(doc), CodecOptions(document_class=doc_class))

    qcheck.check_unittest(self, encode_then_decode, qcheck.gen_mongo_dict(3))
|
|
|
|
def test_encode_then_decode(self):
    """Run the round-trip checks with the default dict/encode/decode trio."""
    self.check_encode_then_decode()
|
|
|
|
def test_encode_then_decode_any_mapping(self):
    """Run the round-trip checks with documents wrapped in ``NotADict``."""
    mapping_type = NotADict
    self.check_encode_then_decode(doc_class=mapping_type)
|
|
|
|
def test_encode_then_decode_legacy(self):
    """Run the round-trip checks through the ``BSON`` class API."""

    def legacy_decode(data, *rest):
        # Wrap the raw bytes in BSON and forward any codec options.
        return BSON(data).decode(*rest)

    self.check_encode_then_decode(encoder=BSON.encode, decoder=legacy_decode)
|
|
|
|
def test_encode_then_decode_any_mapping_legacy(self):
    """Run the ``NotADict`` round-trip checks through the ``BSON`` class API."""

    def legacy_decode(data, *rest):
        # Wrap the raw bytes in BSON and forward any codec options.
        return BSON(data).decode(*rest)

    self.check_encode_then_decode(
        doc_class=NotADict,
        encoder=BSON.encode,
        decoder=legacy_decode,
    )
|
|
|
|
def test_encoding_defaultdict(self):
    """After ``encode`` the defaultdict still equals a freshly built copy."""
    subject = collections.defaultdict(dict, [("foo", "bar")])  # type: ignore[arg-type]
    encode(subject)
    pristine = collections.defaultdict(dict, [("foo", "bar")])
    self.assertEqual(subject, pristine)
|
|
|
|
def test_basic_validation(self):
    """``is_valid`` rejects non-bytes arguments; malformed payloads fail to decode."""
    for not_bytes in (100, "test", 10.4):
        self.assertRaises(TypeError, is_valid, not_bytes)

    self.assertInvalid(b"test")

    # the simplest valid BSON document
    simplest = b"\x05\x00\x00\x00\x00"
    self.assertTrue(is_valid(simplest))
    self.assertTrue(is_valid(BSON(simplest)))

    # failure cases
    malformed = (
        b"\x04\x00\x00\x00\x00",
        b"\x05\x00\x00\x00\x01",
        b"\x05\x00\x00\x00",
        b"\x05\x00\x00\x00\x00\x00",
        b"\x07\x00\x00\x00\x02a\x00\x78\x56\x34\x12",
        b"\x09\x00\x00\x00\x10a\x00\x05\x00",
        b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
        b"\x13\x00\x00\x00\x02foo\x00\x04\x00\x00\x00bar\x00\x00",
        b"\x18\x00\x00\x00\x03foo\x00\x0f\x00\x00\x00\x10bar\x00\xff\xff\xff\x7f\x00\x00",
        b"\x15\x00\x00\x00\x03foo\x00\x0c\x00\x00\x00\x08bar\x00\x01\x00\x00",
        (
            b"\x1c\x00\x00\x00\x03foo\x00"
            b"\x12\x00\x00\x00\x02bar\x00"
            b"\x05\x00\x00\x00baz\x00\x00\x00"
        ),
        b"\x10\x00\x00\x00\x02a\x00\x04\x00\x00\x00abc\xff\x00",
    )
    for payload in malformed:
        self.assertInvalid(payload)
|
|
|
|
def test_bad_string_lengths(self):
    """Every payload below must be rejected by ``decode`` with ``InvalidBSON``."""
    # Shared by the four payloads that begin with element type 0x0f.
    head_0f = b"\x1c\x00\x00\x00\x0f\x00"
    payloads = (
        b"\x0c\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00",
        b"\x12\x00\x00\x00\x02\x00\xff\xff\xff\xfffoobar\x00\x00",
        b"\x0c\x00\x00\x00\x0e\x00\x00\x00\x00\x00\x00\x00",
        b"\x12\x00\x00\x00\x0e\x00\xff\xff\xff\xfffoobar\x00\x00",
        b"\x18\x00\x00\x00\x0c\x00\x00\x00\x00\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00",
        (
            b"\x1e\x00\x00\x00\x0c\x00"
            b"\xff\xff\xff\xfffoobar\x00"
            b"RY\xb5j\xfa[\xd8A\xd6X]\x99\x00"
        ),
        b"\x0c\x00\x00\x00\r\x00\x00\x00\x00\x00\x00\x00",
        b"\x0c\x00\x00\x00\r\x00\xff\xff\xff\xff\x00\x00",
        (
            head_0f
            + b"\x15\x00\x00\x00\x00\x00"
            + b"\x00\x00\x00\x0c\x00\x00"
            + b"\x00\x02\x00\x01\x00\x00"
            + b"\x00\x00\x00\x00"
        ),
        (
            head_0f
            + b"\x15\x00\x00\x00\xff\xff"
            + b"\xff\xff\x00\x0c\x00\x00"
            + b"\x00\x02\x00\x01\x00\x00"
            + b"\x00\x00\x00\x00"
        ),
        (
            head_0f
            + b"\x15\x00\x00\x00\x01\x00"
            + b"\x00\x00\x00\x0c\x00\x00"
            + b"\x00\x02\x00\x00\x00\x00"
            + b"\x00\x00\x00\x00"
        ),
        (
            head_0f
            + b"\x15\x00\x00\x00\x01\x00"
            + b"\x00\x00\x00\x0c\x00\x00"
            + b"\x00\x02\x00\xff\xff\xff"
            + b"\xff\x00\x00\x00"
        ),
    )
    for payload in payloads:
        self.assertInvalid(payload)
|
|
|
|
def test_random_data_is_not_bson(self):
    """Feed qcheck-generated strings to the negation of ``is_valid``."""
    negated_is_valid = qcheck.isnt(is_valid)
    string_generator = qcheck.gen_string(qcheck.gen_range(0, 40))
    qcheck.check_unittest(self, negated_is_valid, string_generator)
|
|
|
|
def test_basic_decode(self):
    """Decode a fixed payload with each of the four decoding entry points."""
    hello_doc = (
        b"\x1B\x00\x00\x00\x0E\x74\x65\x73\x74\x00\x0C"
        b"\x00\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F"
        b"\x72\x6C\x64\x00\x00"
    )
    empty_doc = b"\x05\x00\x00\x00\x00"
    # Two documents back to back: the "hello world" one and an empty one.
    stream = hello_doc + empty_doc
    both = [{"test": "hello world"}, {}]

    self.assertEqual({"test": "hello world"}, decode(hello_doc))
    self.assertEqual(both, decode_all(stream))
    self.assertEqual(both, list(decode_iter(stream)))
    self.assertEqual(both, list(decode_file_iter(BytesIO(stream))))
|
|
|
|
def test_decode_all_buffer_protocol(self):
    """``decode_all`` accepts several buffer-like containers of the same bytes."""
    docs = [{"foo": "bar"}, {}]
    bs = b"".join(map(encode, docs))  # type: ignore[arg-type]

    in_memory_buffers = (
        bytearray(bs),
        memoryview(bs),
        # A view that starts and ends inside a larger buffer.
        memoryview(b"1" + bs + b"1")[1:-1],
        array.array("B", bs),
    )
    for buf in in_memory_buffers:
        self.assertEqual(docs, decode_all(buf))

    with mmap.mmap(-1, len(bs)) as mapped:
        mapped.write(bs)
        mapped.seek(0)
        self.assertEqual(docs, decode_all(mapped))
|
|
|
|
def test_decode_buffer_protocol(self):
    """``decode`` accepts several buffer-like containers of the same bytes."""
    doc = {"foo": "bar"}
    bs = encode(doc)

    in_memory_buffers = (
        bs,
        bytearray(bs),
        memoryview(bs),
        # A view that starts and ends inside a larger buffer.
        memoryview(b"1" + bs + b"1")[1:-1],
        array.array("B", bs),
    )
    for buf in in_memory_buffers:
        self.assertEqual(doc, decode(buf))

    with mmap.mmap(-1, len(bs)) as mapped:
        mapped.write(bs)
        mapped.seek(0)
        self.assertEqual(doc, decode(mapped))
|
|
|
|
def test_invalid_decodes(self):
    """Malformed streams raise ``InvalidBSON`` from every multi-document decoder."""
    # Invalid object size (not enough bytes in the document for even
    # the object size of the first object).
    # NOTE: decode_all and decode_iter don't care, not sure if they should?
    one_byte_stream = decode_file_iter(BytesIO(b"\x1B"))
    self.assertRaises(InvalidBSON, list, one_byte_stream)

    # Everything in the first document after its leading length byte.
    first_doc_rest = (
        b"\x00\x00\x00\x0E\x74\x65\x73\x74"
        b"\x00\x0C\x00\x00\x00\x68\x65\x6C\x6C"
        b"\x6f\x20\x77\x6F\x72\x6C\x64\x00\x00"
    )
    bad_bsons = [
        # An object size that's too small to even include the object size,
        # but is correctly encoded, along with a correct EOO (and no data).
        b"\x01\x00\x00\x00\x00",
        # One object, but with object size listed smaller than it is in the
        # data.
        b"\x1A" + first_doc_rest + b"\x05\x00\x00\x00\x00",
        # One object, missing the EOO at the end.
        b"\x1B" + first_doc_rest + b"\x05\x00\x00\x00",
        # One object, sized correctly, with a spot for an EOO, but the EOO
        # isn't 0x00.
        b"\x1B" + first_doc_rest + b"\x05\x00\x00\x00\xFF",
    ]

    def drain_iter(raw):
        return list(decode_iter(raw))

    def drain_file_iter(raw):
        return list(decode_file_iter(BytesIO(raw)))

    for index, raw in enumerate(bad_bsons):
        msg = f"bad_bson[{index}]"
        for consume in (decode_all, drain_iter, drain_file_iter):
            with self.assertRaises(InvalidBSON, msg=msg):
                consume(raw)
        # Same check again, this time reading from a real temporary file.
        with tempfile.TemporaryFile() as scratch:
            scratch.write(raw)
            scratch.seek(0, os.SEEK_SET)
            with self.assertRaises(InvalidBSON, msg=msg):
                list(decode_file_iter(scratch))
|
|
|
|
def test_invalid_field_name(self):
    """Decoding a truncated field raises ``InvalidBSON`` with a non-empty message."""
    # Decode a truncated field
    truncated = b"\x0b\x00\x00\x00\x02field\x00"
    with self.assertRaises(InvalidBSON) as ctx:
        decode(truncated)
    # Assert that the InvalidBSON error message is not empty.
    message = str(ctx.exception)
    self.assertTrue(message)
|
|
|
|
def test_data_timestamp(self):
    """A fixed payload decodes to ``{"test": Timestamp(4, 20)}``."""
    payload = b"\x13\x00\x00\x00\x11\x74\x65\x73\x74\x00\x14\x00\x00\x00\x04\x00\x00\x00\x00"
    expected = {"test": Timestamp(4, 20)}
    self.assertEqual(expected, decode(payload))
|
|
|
|
def test_basic_encode(self):
    """``encode`` rejects non-mappings and emits the expected bytes for known docs."""
    for not_a_mapping in (100, "hello", None, []):
        self.assertRaises(TypeError, encode, not_a_mapping)

    def expect(document, wire):
        # One document in, one exact byte string out.
        self.assertEqual(encode(document), wire)

    self.assertEqual(encode({}), BSON(b"\x05\x00\x00\x00\x00"))
    expect({}, b"\x05\x00\x00\x00\x00")
    expect(
        {"test": "hello world"},
        b"\x1B\x00\x00\x00\x02\x74\x65\x73\x74\x00\x0C\x00"
        b"\x00\x00\x68\x65\x6C\x6C\x6F\x20\x77\x6F\x72\x6C"
        b"\x64\x00\x00",
    )
    expect(
        {"mike": 100},
        b"\x0F\x00\x00\x00\x10\x6D\x69\x6B\x65\x00\x64\x00\x00\x00\x00",
    )
    expect(
        {"hello": 1.5},
        b"\x14\x00\x00\x00\x01\x68\x65\x6C\x6C\x6F\x00\x00\x00\x00\x00\x00\x00\xF8\x3F\x00",
    )
    expect({"true": True}, b"\x0C\x00\x00\x00\x08\x74\x72\x75\x65\x00\x01\x00")
    expect({"false": False}, b"\x0D\x00\x00\x00\x08\x66\x61\x6C\x73\x65\x00\x00\x00")
    expect(
        {"empty": []},
        b"\x11\x00\x00\x00\x04\x65\x6D\x70\x74\x79\x00\x05\x00\x00\x00\x00\x00",
    )
    expect(
        {"none": {}},
        b"\x10\x00\x00\x00\x03\x6E\x6F\x6E\x65\x00\x05\x00\x00\x00\x00\x00",
    )

    # Binary values with different subtypes.
    expect(
        {"test": Binary(b"test", 0)},
        b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x00\x74\x65\x73\x74\x00",
    )
    expect(
        {"test": Binary(b"test", 2)},
        b"\x18\x00\x00\x00\x05\x74\x65\x73\x74\x00\x08\x00"
        b"\x00\x00\x02\x04\x00\x00\x00\x74\x65\x73\x74\x00",
    )
    expect(
        {"test": Binary(b"test", 128)},
        b"\x14\x00\x00\x00\x05\x74\x65\x73\x74\x00\x04\x00\x00\x00\x80\x74\x65\x73\x74\x00",
    )

    # Vectors built through Binary.from_vector.
    expect(
        {"vector_int8": Binary.from_vector([-128, -1, 127], BinaryVectorDtype.INT8)},
        b"\x1c\x00\x00\x00\x05vector_int8\x00\x05\x00\x00\x00\t\x03\x00\x80\xff\x7f\x00",
    )
    expect(
        {"vector_bool": Binary.from_vector([1, 127], BinaryVectorDtype.PACKED_BIT)},
        b"\x1b\x00\x00\x00\x05vector_bool\x00\x04\x00\x00\x00\t\x10\x00\x01\x7f\x00",
    )
    expect(
        {"vector_float32": Binary.from_vector([-1.1, 1.1e10], BinaryVectorDtype.FLOAT32)},
        b"$\x00\x00\x00\x05vector_float32\x00\n\x00\x00\x00\t'\x00\xcd\xcc\x8c\xbf\xac\xe9#P\x00",
    )

    expect({"test": None}, b"\x0B\x00\x00\x00\x0A\x74\x65\x73\x74\x00\x00")
    expect(
        {"date": datetime.datetime(2007, 1, 8, 0, 30, 11)},
        b"\x13\x00\x00\x00\x09\x64\x61\x74\x65\x00\x38\xBE\x1C\xFF\x0F\x01\x00\x00\x00",
    )
    expect(
        {"regex": re.compile(b"a*b", re.IGNORECASE)},
        b"\x12\x00\x00\x00\x0B\x72\x65\x67\x65\x78\x00\x61\x2A\x62\x00\x69\x00\x00",
    )

    # Code, with and without a scope.
    expect(
        {"$where": Code("test")},
        b"\x16\x00\x00\x00\r$where\x00\x05\x00\x00\x00test\x00\x00",
    )
    expect(
        {"$field": Code("function(){ return true;}", scope=None)},
        b"+\x00\x00\x00\r$field\x00\x1a\x00\x00\x00function(){ return true;}\x00\x00",
    )
    expect(
        {"$field": Code("return function(){ return x; }", scope={"x": False})},
        b"=\x00\x00\x00\x0f$field\x000\x00\x00\x00\x1f\x00"
        b"\x00\x00return function(){ return x; }\x00\t\x00"
        b"\x00\x00\x08x\x00\x00\x00\x00",
    )
    unicode_empty_scope = Code("function(){ return 'héllo';}", {})
    expect(
        {"$field": unicode_empty_scope},
        b"8\x00\x00\x00\x0f$field\x00+\x00\x00\x00\x1e\x00"
        b"\x00\x00function(){ return 'h\xc3\xa9llo';}\x00\x05"
        b"\x00\x00\x00\x00\x00",
    )

    # ObjectId on its own and inside a DBRef.
    oid = ObjectId(b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B")
    expect(
        {"oid": oid},
        b"\x16\x00\x00\x00\x07\x6F\x69\x64\x00\x00\x01\x02"
        b"\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00",
    )
    expect(
        {"ref": DBRef("coll", oid)},
        b"\x2F\x00\x00\x00\x03ref\x00\x25\x00\x00\x00\x02"
        b"$ref\x00\x05\x00\x00\x00coll\x00\x07$id\x00\x00"
        b"\x01\x02\x03\x04\x05\x06\x07\x08\x09\x0A\x0B\x00"
        b"\x00",
    )
|
|
|
|
def test_bad_code(self):
    """Invalid code-with-scope payloads fail without naming a field."""

    # Assert that decoding invalid Code with scope does not include a field name.
    def generate_payload(length: int) -> bytes:
        string_size = length - 0x1E
        hex_fields = [
            struct.pack("<I", length).hex(),  # payload size
            "0f",  # type "code with scope"
            "3100",  # key (cstring)
            "0a000000",  # c_w_s_size
            "04000000",  # code_size
            "41004200",  # code (cstring)
            "feffffff",  # scope_size
            "02",  # type "string"
            "3200",  # key (cstring)
            struct.pack("<I", string_size).hex(),  # string size
            "00" * string_size,  # value (cstring)
            # next bytes is a field name for type \x00
            # type \x00 is invalid so bson throws an exception
        ]
        return bytes.fromhex("".join(hex_fields))

    for step in range(100):
        payload = generate_payload(0x54F + step)
        with self.assertRaisesRegex(InvalidBSON, "invalid") as ctx:
            bson.decode(payload)
        self.assertNotIn("fieldname", str(ctx.exception))
|
|
|
|
def test_unknown_type(self):
    """An unknown element type is reported with its type byte and field name."""
    # Repr value differs with major python version
    expected_fragment = "type " + repr(b"\x14") + " for fieldname 'foo'"
    payloads = (
        b"\x0e\x00\x00\x00\x14foo\x00\x01\x00\x00\x00\x00",
        b"\x16\x00\x00\x00\x04foo\x00\x0c\x00\x00\x00\x140\x00\x01\x00\x00\x00\x00\x00",
        (
            b" \x00\x00\x00\x04bar\x00\x16\x00\x00\x00\x030\x00\x0e\x00\x00"
            b"\x00\x14foo\x00\x01\x00\x00\x00\x00\x00\x00"
        ),
    )
    for payload in payloads:
        try:
            decode(payload)
        except Exception as err:
            self.assertIsInstance(err, InvalidBSON)
            self.assertIn(expected_fragment, str(err))
        else:
            self.fail("Failed to raise an exception.")
|
|
|
|
def test_dbpointer(self):
    """A DBPointer payload decodes to a ``DBRef``."""
    # *Note* - DBPointer and DBRef are *not* the same thing. DBPointer
    # is a deprecated BSON type. DBRef is a convention that does not
    # exist in the BSON spec, meant to replace DBPointer. PyMongo does
    # not support creation of the DBPointer type, but will decode
    # DBPointer to DBRef.
    payload = b"\x18\x00\x00\x00\x0c\x00\x01\x00\x00\x00\x00RY\xb5j\xfa[\xd8A\xd6X]\x99\x00"
    expected = {"": DBRef("", ObjectId("5259b56afa5bd841d6585d99"))}

    self.assertEqual(expected, decode(payload))
|
|
|
|
def test_bad_dbref(self):
    """Sub-documents holding only ``$ref`` or only ``$id`` round-trip unchanged."""
    ref_only = {"ref": {"$ref": "collection"}}
    id_only = {"ref": {"$id": ObjectId()}}

    for partial in (ref_only, id_only):
        self.assertEqual(partial, decode(encode(partial)))
|
|
|
|
def test_bytes_as_keys(self):
    """Encoding a document with a ``bytes`` key raises ``InvalidDocument``."""
    bytes_keyed = {b"foo": "bar"}
    # Since `bytes` are stored as Binary you can't use them
    # as keys. Using binary data as a key makes no sense in BSON
    # anyway and little sense in python.
    with self.assertRaises(InvalidDocument):
        encode(bytes_keyed)
|
|
|
|
def test_datetime_encode_decode(self):
    """Millisecond-precision datetimes on both sides of the epoch round-trip."""
    samples = (
        # Negative timestamps
        datetime.datetime(1, 1, 1, 1, 1, 1, 111000),
        datetime.datetime(1959, 6, 25, 12, 16, 59, 999000),
        # Positive timestamps
        datetime.datetime(9999, 12, 31, 23, 59, 59, 999000),
        datetime.datetime(2011, 6, 14, 10, 47, 53, 444000),
    )
    for before in samples:
        after = decode(encode({"date": before}))["date"]
        self.assertEqual(before, after)
|
|
|
|
def test_large_datetime_truncation(self):
    """Sub-millisecond digits are dropped without changing the seconds field."""
    # Ensure that a large datetime is truncated correctly.
    before = datetime.datetime(9999, 1, 1, 1, 1, 1, 999999)
    encoded = encode({"date": before})
    after = decode(encoded)["date"]
    self.assertEqual(after.microsecond, 999000)
    self.assertEqual(after.second, before.second)
|
|
|
|
def test_aware_datetime(self):
    """An aware datetime decodes, with ``tz_aware=True``, to the same instant in UTC."""
    some_zone = FixedOffset(555, "SomeZone")
    aware = datetime.datetime(1993, 4, 4, 2, tzinfo=some_zone)
    offset = aware.utcoffset()
    assert offset is not None
    as_utc = (aware - offset).replace(tzinfo=utc)
    self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45, tzinfo=utc), as_utc)

    encoded = encode({"date": aware})
    after = decode(encoded, CodecOptions(tz_aware=True))["date"]
    self.assertEqual(utc, after.tzinfo)
    self.assertEqual(as_utc, after)
|
|
|
|
def test_local_datetime(self):
    """Datetimes in a DST-aware zone convert to and from UTC as expected."""
    # Timezone 60 minutes ahead of UTC, with DST between April and July.
    tz = DSTAwareTimezone(60, "sixty-minutes", 4, 7)

    # It's not DST.
    winter = datetime.datetime(year=2025, month=12, hour=2, day=1, tzinfo=tz)
    options = CodecOptions(tz_aware=True, tzinfo=tz)
    # Encode with this timezone, then decode to UTC.
    winter_bytes = encode({"date": winter}, codec_options=options)
    winter_as_utc = winter.replace(hour=1, tzinfo=None)
    self.assertEqual(winter_as_utc, decode(winter_bytes)["date"])

    # It's DST.
    local = datetime.datetime(year=2025, month=4, hour=1, day=1, tzinfo=tz)
    summer_bytes = encode({"date": local}, codec_options=options)
    summer_as_utc = local.replace(month=3, day=31, hour=23, tzinfo=None)
    self.assertEqual(summer_as_utc, decode(summer_bytes)["date"])

    # Encode UTC, then decode in a different timezone.
    utc_bytes = encode({"date": local.replace(tzinfo=utc)})
    decoded = decode(utc_bytes, options)["date"]
    self.assertEqual(local.replace(hour=3), decoded)
    self.assertEqual(tz, decoded.tzinfo)

    # Test round-tripping.
    round_tripped = decode(encode({"date": local}, codec_options=options), options)["date"]
    self.assertEqual(local, round_tripped)

    # Test around the Unix Epoch.
    epochs = (
        EPOCH_AWARE,
        EPOCH_AWARE.astimezone(FixedOffset(120, "one twenty")),
        EPOCH_AWARE.astimezone(FixedOffset(-120, "minus one twenty")),
    )
    utc_co = CodecOptions(tz_aware=True)
    for epoch in epochs:
        doc = {"epoch": epoch}
        # We always retrieve datetimes in UTC unless told to do otherwise.
        in_utc = decode(encode(doc), codec_options=utc_co)["epoch"]
        self.assertEqual(EPOCH_AWARE, in_utc)
        # Round-trip the epoch.
        local_co = CodecOptions(tz_aware=True, tzinfo=epoch.tzinfo)
        in_local = decode(encode(doc), codec_options=local_co)["epoch"]
        self.assertEqual(epoch, in_local)
|
|
|
|
def test_naive_decode(self):
    """Without ``tz_aware`` an aware datetime decodes to a naive UTC value."""
    some_zone = FixedOffset(555, "SomeZone")
    aware = datetime.datetime(1993, 4, 4, 2, tzinfo=some_zone)
    offset = aware.utcoffset()
    assert offset is not None
    naive_utc = (aware - offset).replace(tzinfo=None)
    self.assertEqual(datetime.datetime(1993, 4, 3, 16, 45), naive_utc)

    encoded = encode({"date": aware})
    after = decode(encoded)["date"]
    self.assertEqual(None, after.tzinfo)
    self.assertEqual(naive_utc, after)
|
|
|
|
def test_dst(self):
    """The naive datetime 1993-04-04 02:00 round-trips unchanged."""
    doc = {"x": datetime.datetime(1993, 4, 4, 2)}
    round_tripped = decode(encode(doc))
    self.assertEqual(doc, round_tripped)
|
|
|
|
@unittest.skip("Disabled due to http://bugs.python.org/issue25222")
def test_bad_encode(self):
    """Self-referencing documents must make ``encode`` raise."""
    # A dict whose list value contains the dict itself.
    cyclic_via_list: dict = {"a": []}
    cyclic_via_list["a"].append(cyclic_via_list)
    # A dict that is its own value.
    cyclic_direct: dict = {}
    cyclic_direct["a"] = cyclic_direct
    for cyclic in (cyclic_direct, cyclic_via_list):
        with self.assertRaises(Exception):
            encode(cyclic)
|
|
|
|
def test_overflow(self):
    """The signed 64-bit extremes encode; one step beyond raises ``OverflowError``."""
    largest = 9223372036854775807
    smallest = -9223372036854775808

    self.assertTrue(encode({"x": largest}))
    self.assertRaises(OverflowError, encode, {"x": largest + 1})

    self.assertTrue(encode({"x": smallest}))
    self.assertRaises(OverflowError, encode, {"x": smallest - 1})
|
|
|
|
def test_small_long_encode_decode(self):
    """A plain 256 decodes as ``int``; ``Int64(256)`` decodes as ``Int64``."""
    plain = decode(encode({"x": 256}))["x"]
    self.assertEqual(256, plain)
    self.assertEqual(int, type(plain))

    expected = Int64(256)
    wrapped = decode(encode({"x": Int64(256)}))["x"]
    self.assertEqual(expected, wrapped)
    self.assertEqual(type(expected), type(wrapped))

    self.assertNotEqual(type(plain), type(wrapped))
|
|
|
|
def test_tuple(self):
    """A tuple value comes back from a round trip as a list."""
    round_tripped = decode(encode({"tuple": (1, 2)}))
    self.assertEqual({"tuple": [1, 2]}, round_tripped)
|
|
|
|
def test_uuid(self):
    """Native UUIDs need an explicit representation before they can be encoded."""
    original = uuid.uuid4()
    # The default uuid_representation is UNSPECIFIED
    with self.assertRaisesRegex(ValueError, "cannot encode native uuid"):
        bson.decode_all(encode({"uuid": original}))

    opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
    encoded = encode({"id": original}, codec_options=opts)
    transformed_id = decode(encoded, codec_options=opts)["id"]
    self.assertIsInstance(transformed_id, uuid.UUID)
    self.assertEqual(original, transformed_id)
    self.assertNotEqual(uuid.uuid4(), transformed_id)
|
|
|
|
def test_uuid_legacy(self):
    """A PYTHON_LEGACY UUID binary has subtype 3 and converts back to the UUID."""
    original = uuid.uuid4()
    legacy = Binary.from_uuid(original, UuidRepresentation.PYTHON_LEGACY)
    self.assertEqual(3, legacy.subtype)

    decoded = decode(encode({"uuid": legacy}))["uuid"]
    self.assertIsInstance(decoded, Binary)
    restored = decoded.as_uuid(UuidRepresentation.PYTHON_LEGACY)
    self.assertEqual(original, restored)
|
|
|
|
def test_vector(self):
    """Tests of subtype 9"""
    # All checks use unittest assertions rather than bare ``assert`` so they
    # still run under ``python -O`` and match the rest of this class.

    # We start with valid cases, across the 3 dtypes implemented.
    # Work with a simple vector that can be interpreted as int8, float32, or ubyte
    list_vector = [127, 8]
    # As INT8, vector has length 2
    binary_vector = Binary.from_vector(list_vector, BinaryVectorDtype.INT8)
    vector = binary_vector.as_vector()
    self.assertEqual(vector.data, list_vector)
    # test encoding roundtrip
    self.assertEqual({"vector": binary_vector}, decode(encode({"vector": binary_vector})))
    # test json roundtrip
    self.assertEqual(binary_vector, json_util.loads(json_util.dumps(binary_vector)))

    # For vectors of bits, aka PACKED_BIT type, vector has length 8 * 2
    packed_bit_binary = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT)
    packed_bit_vec = packed_bit_binary.as_vector()
    self.assertEqual(packed_bit_vec.data, list_vector)

    # A padding parameter permits vectors of length that aren't divisible by 8
    # The following ignores the last 3 bits in list_vector,
    # hence its length is 8 * len(list_vector) - padding
    padding = 3
    padded_vec = Binary.from_vector(list_vector, BinaryVectorDtype.PACKED_BIT, padding=padding)
    self.assertEqual(padded_vec.as_vector().data, list_vector)
    # To visualize how this looks as a binary vector..
    uncompressed = "".join(format(val, "08b") for val in list_vector)
    self.assertEqual(uncompressed[:-padding], "0111111100001")

    # It is worthwhile explicitly showing the values encoded to BSON
    padded_doc = {"padded_vec": padded_vec}
    self.assertEqual(
        encode(padded_doc),
        b"\x1a\x00\x00\x00\x05padded_vec\x00\x04\x00\x00\x00\t\x10\x03\x7f\x08\x00",
    )
    # and dumped to json
    self.assertEqual(
        json_util.dumps(padded_doc),
        '{"padded_vec": {"$binary": {"base64": "EAN/CA==", "subType": "09"}}}',
    )

    # FLOAT32 is also implemented
    float_binary = Binary.from_vector(list_vector, BinaryVectorDtype.FLOAT32)
    self.assertTrue(all(isinstance(d, float) for d in float_binary.as_vector().data))

    # Now some invalid cases
    for x in [-1, 257]:
        with self.assertRaises(struct.error):
            Binary.from_vector([x], BinaryVectorDtype.PACKED_BIT)

    # Test one must pass zeros for all ignored bits
    with self.assertRaises(ValueError):
        Binary.from_vector([255], BinaryVectorDtype.PACKED_BIT, padding=7)

    with self.assertWarns(DeprecationWarning):
        meta = struct.pack("<sB", BinaryVectorDtype.PACKED_BIT.value, 7)
        data = struct.pack("1B", 255)
        Binary(meta + data, subtype=9).as_vector()

    # Test form of Binary.from_vector(BinaryVector)
    self.assertEqual(
        padded_vec,
        Binary.from_vector(BinaryVector(list_vector, BinaryVectorDtype.PACKED_BIT, padding)),
    )
    self.assertEqual(
        binary_vector,
        Binary.from_vector(BinaryVector(list_vector, BinaryVectorDtype.INT8)),
    )
    self.assertEqual(
        float_binary,
        Binary.from_vector(BinaryVector(list_vector, BinaryVectorDtype.FLOAT32)),
    )
    # Confirm kwargs cannot be passed when BinaryVector is provided
    with self.assertRaises(ValueError):
        Binary.from_vector(
            BinaryVector(list_vector, BinaryVectorDtype.PACKED_BIT, padding),
            dtype=BinaryVectorDtype.PACKED_BIT,
        )  # type: ignore[call-overload]
|
|
|
|
def assertRepr(self, obj):
    """Assert that evaluating ``repr(obj)`` rebuilds the same type with the same repr."""
    # eval is used on reprs of objects the tests construct themselves.
    rebuilt = eval(repr(obj))
    rebuilt_type, original_type = type(rebuilt), type(obj)
    self.assertEqual(rebuilt_type, original_type)
    rebuilt_text, original_text = repr(rebuilt), repr(obj)
    self.assertEqual(rebuilt_text, original_text)
|
|
|
|
def test_binaryvector_repr(self):
    """Tests of repr(BinaryVector)"""

    def check(vec, expected_text):
        # The repr must match exactly and must evaluate back to an equal repr.
        self.assertEqual(repr(vec), expected_text)
        self.assertRepr(vec)

    floats = [1 / 127, -7 / 6]
    check(
        BinaryVector(floats, BinaryVectorDtype.FLOAT32),
        f"BinaryVector(dtype=BinaryVectorDtype.FLOAT32, padding=0, data={floats})",
    )

    ints = [127, 7]
    check(
        BinaryVector(ints, BinaryVectorDtype.INT8),
        f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data={ints})",
    )
    # Explicit padding=0 gives the same repr as the default.
    check(
        BinaryVector(ints, BinaryVectorDtype.INT8, padding=0),
        f"BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data={ints})",
    )
    check(
        BinaryVector(ints, BinaryVectorDtype.PACKED_BIT, padding=3),
        f"BinaryVector(dtype=BinaryVectorDtype.PACKED_BIT, padding=3, data={ints})",
    )
    check(
        BinaryVector([], BinaryVectorDtype.INT8),
        "BinaryVector(dtype=BinaryVectorDtype.INT8, padding=0, data=[])",
    )
|
|
|
|
def test_binaryvector_equality(self):
    """Tests of == __eq__"""
    float32 = BinaryVectorDtype.FLOAT32
    int8 = BinaryVectorDtype.INT8

    same_left = BinaryVector([1.2, 1 - 1 / 3], float32, 0)
    same_right = BinaryVector([1.2, 1 - 1.0 / 3.0], float32, 0)
    self.assertEqual(same_left, same_right)

    differs_left = BinaryVector([1.2, 1 - 1 / 3], float32, 0)
    differs_right = BinaryVector([1.2, 6.0 / 9.0], float32, 0)
    self.assertNotEqual(differs_left, differs_right)

    # Two empty vectors of the same dtype.
    self.assertEqual(BinaryVector([], float32, 0), BinaryVector([], float32, 0))

    self.assertNotEqual(BinaryVector([1], int8), BinaryVector([2], int8))
|
|
|
|
def test_unicode_regex(self):
    """Tests we do not get a segfault for C extension on unicode RegExs.

    This had been happening.
    """
    pattern = re.compile("revisi\xf3n")
    encoded = encode({"regex": pattern})
    decode(encoded)
|
|
|
|
def test_non_string_keys(self):
    """Encoding a document with a float key raises ``InvalidDocument``."""
    float_keyed = {8.9: "test"}
    with self.assertRaises(InvalidDocument):
        encode(float_keyed)
|
|
|
|
def test_utf8(self):
    """Non-ASCII text round-trips as ``str``; raw bytes round-trip as ``bytes``."""
    text_doc = {"aéあ": "aéあ"}
    self.assertEqual(text_doc, decode(encode(text_doc)))

    # b'a\xe9' == "aé".encode("iso-8859-1")
    iso8859_bytes = b"a\xe9"
    bytes_doc = {"hello": iso8859_bytes}
    # Stored as BSON binary subtype 0.
    value = decode(encode(bytes_doc))["hello"]
    self.assertIsInstance(value, bytes)
    self.assertEqual(value, iso8859_bytes)
|
|
|
|
def test_null_character(self):
    """NUL is accepted in string values but rejected in keys and regex patterns."""
    # A NUL inside a string *value* round-trips unchanged.
    doc = {"a": "\x00"}
    self.assertEqual(doc, decode(encode(doc)))

    # A NUL inside a key is rejected for both bytes and str keys.
    for bad_key in (b"\x00", "\x00"):
        with self.assertRaises(InvalidDocument):
            encode({bad_key: "a"})

    # A NUL inside a regex pattern is rejected for both bytes and str patterns.
    for bad_pattern in (b"ab\x00c", "ab\x00c"):
        # Compile outside the assertion so only encode() is under test.
        compiled = re.compile(bad_pattern)
        with self.assertRaises(InvalidDocument):
            encode({"a": compiled})
|
|
|
|
def test_move_id(self):
    """``_id`` is emitted first at the top level but not inside sub-documents."""
    # Top level: "_id" is written before "a" although it was inserted second.
    flat_expected = (
        b"\x19\x00\x00\x00"
        b"\x02_id\x00\x02\x00\x00\x00a\x00"
        b"\x02a\x00\x02\x00\x00\x00a\x00"
        b"\x00"
    )
    flat_doc = SON([("a", "a"), ("_id", "a")])
    self.assertEqual(flat_expected, encode(flat_doc))

    # Nested: the outer "_id" moves to the front, while the embedded document
    # keeps its insertion order ("a" before "_id").
    nested_expected = (
        b"\x2c\x00\x00\x00"
        b"\x02_id\x00\x02\x00\x00\x00b\x00"
        b"\x03b\x00"
        b"\x19\x00\x00\x00\x02a\x00\x02\x00\x00\x00a\x00"
        b"\x02_id\x00\x02\x00\x00\x00a\x00\x00"
        b"\x00"
    )
    inner = SON([("a", "a"), ("_id", "a")])
    nested_doc = SON([("b", inner), ("_id", "b")])
    self.assertEqual(nested_expected, encode(nested_doc))
|
|
|
|
def test_dates(self):
    """Datetimes well before and well after 1970 round-trip unchanged."""
    doc = {"early": datetime.datetime(1686, 5, 5), "late": datetime.datetime(2086, 5, 5)}
    try:
        round_tripped = decode(encode(doc))
    except ValueError:
        # Ignore ValueError when no C ext, since it's probably a problem
        # w/ 32-bit Python - we work around this in the C ext, though.
        if bson.has_c():
            raise
        return
    self.assertEqual(doc, round_tripped)
|
|
|
|
def test_custom_class(self):
    """``document_class`` controls the mapping type of decoded documents."""
    son_opts = CodecOptions(document_class=SON)  # type: ignore[type-var]
    empty = encode({})

    # Default decoding yields a plain dict, not a SON.
    self.assertIsInstance(decode(empty), dict)
    self.assertNotIsInstance(decode(empty), SON)

    # With document_class=SON the top-level document is a SON.
    self.assertIsInstance(decode(empty, son_opts), SON)
    self.assertEqual(1, decode(encode({"x": 1}), son_opts)["x"])

    # Documents nested inside arrays use the custom class too.
    nested = encode({"x": [{"y": 1}]})
    first_item = decode(nested, son_opts)["x"][0]
    self.assertIsInstance(first_item, SON)
|
|
|
|
def test_subclasses(self):
    """Subclasses of int, float and str encode and decode as their base type."""

    # make sure we can serialize subclasses of native Python types.
    class _myint(int):
        pass

    class _myfloat(float):
        pass

    class _myunicode(str):
        pass

    d = {"a": _myint(42), "b": _myfloat(63.9), "c": _myunicode("hello world")}
    d2 = decode(encode(d))

    # Guard against a vacuous loop: every key must survive the round trip.
    self.assertEqual(set(d), set(d2))
    for key, value in d2.items():
        orig_value = d[key]
        orig_type = orig_value.__class__.__bases__[0]
        # The decoded value is the plain built-in, not the subclass ...
        self.assertEqual(type(value), orig_type)
        # ... and carries the value that was originally encoded.
        self.assertEqual(value, orig_type(orig_value))
|
|
|
|
def test_encode_type_marker(self):
    """Objects are BSON-encoded according to their ``_type_marker`` attribute."""

    # A class unrelated to MaxKey that only declares the marker value 127.
    class MyMaxKey:
        _type_marker = 127

    reference = encode({"a": MaxKey()})
    self.assertEqual(encode({"a": MyMaxKey()}), reference)

    # A direct subclass of Binary encodes exactly like Binary itself.
    class MyBinary(Binary):
        pass

    reference = encode({"a": Binary(b"bin", USER_DEFINED_SUBTYPE)})
    candidate = encode({"a": MyBinary(b"bin", USER_DEFINED_SUBTYPE)})
    self.assertEqual(candidate, reference)
|
|
|
|
def test_ordered_dict(self):
    """An OrderedDict round-trips when decoded with ``document_class=OrderedDict``."""
    original = OrderedDict([("one", 1), ("two", 2), ("three", 3), ("four", 4)])
    opts = CodecOptions(document_class=OrderedDict)  # type: ignore[type-var]
    self.assertEqual(original, decode(encode(original), opts))
|
|
|
|
def test_bson_regex(self):
    """``Regex`` holds patterns ``re`` rejects and encodes like compiled patterns."""
    # Invalid Python regex, though valid PCRE.
    pcre_pattern = r"[\w-\.]"
    bson_re1 = Regex(pcre_pattern)
    self.assertEqual(pcre_pattern, bson_re1.pattern)
    self.assertEqual(0, bson_re1.flags)

    doc1 = {"r": bson_re1}
    doc1_bson = (
        b"\x11\x00\x00\x00"  # document length
        b"\x0br\x00[\\w-\\.]\x00\x00"  # r: regex
        b"\x00"  # document terminator
    )
    self.assertEqual(doc1_bson, encode(doc1))
    self.assertEqual(doc1, decode(doc1_bson))

    # Valid Python regex, with flags.
    all_flags = re.I | re.M | re.S | re.U | re.X
    re2 = re.compile(".*", all_flags)
    bson_re2 = Regex(".*", all_flags)

    doc2_bson = (
        b"\x11\x00\x00\x00"  # document length
        b"\x0br\x00.*\x00imsux\x00"  # r: regex
        b"\x00"  # document terminator
    )
    # A compiled pattern and a Regex produce the same bytes.
    self.assertEqual(doc2_bson, encode({"r": re2}))
    self.assertEqual(doc2_bson, encode({"r": bson_re2}))

    # Decoding preserves both the pattern text and the flags.
    self.assertEqual(re2.pattern, decode(doc2_bson)["r"].pattern)
    self.assertEqual(re2.flags, decode(doc2_bson)["r"].flags)
|
|
|
|
def test_regex_from_native(self):
    """``Regex.from_native`` copies pattern and flags from a compiled pattern."""
    from_str = Regex.from_native(re.compile(".*"))
    self.assertEqual(".*", from_str.pattern)

    # A bytes pattern compiled without flags reports flags == 0.
    from_bytes = Regex.from_native(re.compile(b""))
    self.assertEqual(0, from_bytes.flags)

    # Flags given to a bytes pattern are carried over unchanged.
    bytes_flags = re.I | re.L | re.M | re.S | re.X
    flagged = re.compile(b"", bytes_flags)
    self.assertEqual(bytes_flags, Regex.from_native(flagged).flags)

    # A str pattern reports re.U.
    unicode_regex = re.compile("", re.U)
    self.assertEqual(re.U, Regex.from_native(unicode_regex).flags)
|
|
|
|
def test_regex_hash(self):
    """``Regex`` instances are unhashable: ``hash()`` raises TypeError."""
    candidate = Regex("hello")
    with self.assertRaises(TypeError):
        hash(candidate)
|
|
|
|
def test_regex_comparison(self):
    """``Regex`` equality takes both the pattern and the flags into account."""
    # Different patterns.
    self.assertNotEqual(Regex("a"), Regex("b"))
    # Same pattern, different flags.
    self.assertNotEqual(Regex("a", re.I), Regex("a", re.M))
    # Same pattern, same flags.
    self.assertEqual(Regex("a", re.I), Regex("a", re.I))
|
|
|
|
def test_exception_wrapping(self):
    """Any exception raised while decoding BSON surfaces as InvalidBSON."""
    # {'s': '\xff'}, will throw attempting to decode utf-8.
    bad_doc = b"\x0f\x00\x00\x00\x02s\x00\x03\x00\x00\x00\xff\x00\x00\x00"

    with self.assertRaises(InvalidBSON) as caught:
        decode_all(bad_doc)

    # The text of the underlying decode error is kept in the message.
    message = str(caught.exception)
    self.assertIn("codec can't decode byte 0xff", message)
|
|
|
|
def test_minkey_maxkey_comparison(self):
    """Exercise every rich comparison operator of MinKey and MaxKey.

    The checks stay as assertTrue/assertFalse around an explicit operator,
    not unittest's comparison assertions, because MinKey and MaxKey define
    their own __ge__, __le__, etc. and those are what is under test.
    """
    lo = MinKey()
    hi = MaxKey()

    # MinKey's <, <=, >, >=, !=, and ==.
    self.assertTrue(lo < None)
    self.assertTrue(lo < 1)
    self.assertTrue(lo <= 1)
    self.assertTrue(lo <= MinKey())
    self.assertFalse(lo > None)
    self.assertFalse(lo > 1)
    self.assertFalse(lo >= 1)
    self.assertTrue(lo >= MinKey())
    self.assertTrue(lo != 1)
    self.assertFalse(lo == 1)
    self.assertTrue(lo == MinKey())

    # MinKey compared to MaxKey.
    self.assertTrue(lo < hi)
    self.assertTrue(lo <= hi)
    self.assertFalse(lo > hi)
    self.assertFalse(lo >= hi)
    self.assertTrue(lo != hi)
    self.assertFalse(lo == hi)

    # MaxKey's <, <=, >, >=, !=, and ==.
    self.assertFalse(hi < None)
    self.assertFalse(hi < 1)
    self.assertFalse(hi <= 1)
    self.assertTrue(hi <= MaxKey())
    self.assertTrue(hi > None)
    self.assertTrue(hi > 1)
    self.assertTrue(hi >= 1)
    self.assertTrue(hi >= MaxKey())
    self.assertTrue(hi != 1)
    self.assertFalse(hi == 1)
    self.assertTrue(hi == MaxKey())

    # MaxKey compared to MinKey.
    self.assertFalse(hi < lo)
    self.assertFalse(hi <= lo)
    self.assertTrue(hi > lo)
    self.assertTrue(hi >= lo)
    self.assertTrue(hi != lo)
    self.assertFalse(hi == lo)
|
|
|
|
def test_minkey_maxkey_hash(self):
    """Separate instances of one key type hash alike; MinKey and MaxKey differ."""
    max_hash = hash(MaxKey())
    min_hash = hash(MinKey())
    self.assertEqual(max_hash, hash(MaxKey()))
    self.assertEqual(min_hash, hash(MinKey()))
    self.assertNotEqual(max_hash, min_hash)
|
|
|
|
def test_timestamp_comparison(self):
    """Timestamps order by time first and by inc second."""
    ts = Timestamp  # short alias; arguments are (time, inc)

    # Time is the more significant comparand.
    self.assertTrue(ts(1, 0) < ts(2, 17))
    self.assertTrue(ts(2, 0) > ts(1, 0))
    self.assertTrue(ts(1, 7) <= ts(2, 0))
    self.assertTrue(ts(2, 0) >= ts(1, 1))
    self.assertTrue(ts(2, 0) <= ts(2, 0))
    self.assertTrue(ts(2, 0) >= ts(2, 0))
    self.assertFalse(ts(1, 0) > ts(2, 0))

    # Comparison by inc.
    self.assertTrue(ts(1, 0) < ts(1, 1))
    self.assertTrue(ts(1, 1) > ts(1, 0))
    self.assertTrue(ts(1, 0) <= ts(1, 0))
    self.assertTrue(ts(1, 0) <= ts(1, 1))
    self.assertFalse(ts(1, 0) >= ts(1, 1))
    self.assertTrue(ts(1, 0) >= ts(1, 0))
    self.assertTrue(ts(1, 1) >= ts(1, 0))
    self.assertFalse(ts(1, 1) <= ts(1, 0))
    self.assertTrue(ts(1, 0) <= ts(1, 0))
    self.assertFalse(ts(1, 0) > ts(1, 0))
|
|
|
|
def test_timestamp_highorder_bits(self):
    """A Timestamp with all 64 bits set encodes and decodes losslessly."""
    all_bits = Timestamp(0xFFFFFFFF, 0xFFFFFFFF)
    doc = {"a": all_bits}
    # 0x11 is the element type byte, followed by eight 0xff payload bytes.
    expected = b"\x10\x00\x00\x00\x11a\x00\xff\xff\xff\xff\xff\xff\xff\xff\x00"
    self.assertEqual(expected, encode(doc))
    self.assertEqual(doc, decode(expected))
|
|
|
|
def test_bad_id_keys(self):
    """``$``-prefixed keys under ``_id`` fail when encode's 2nd argument is True."""
    # NOTE(review): the second positional argument is presumably check_keys.
    for rejected in ({"$bad": 123}, {"$oid": "52d0b971b3ba219fdeb4170e"}):
        with self.assertRaises(InvalidDocument):
            encode({"_id": rejected}, True)

    # Without that flag the same document encodes without error.
    encode({"_id": {"$oid": "52d0b971b3ba219fdeb4170e"}})
|
|
|
|
def test_bson_encode_thread_safe(self):
    """Encode freshly created ``int`` subclasses from three threads at once."""

    def target(i):
        # Every iteration defines a brand-new type, so each encode call
        # meets a class it has never encountered before.
        for j in range(1000):
            my_int = type(f"MyInt_{i}_{j}", (int,), {})
            bson.encode({"my_int": my_int()})

    workers = []
    for index in range(3):
        workers.append(ExceptionCatchingTask(target=target, args=(index,)))

    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join()

    # No worker may have recorded an exception.
    for worker in workers:
        self.assertIsNone(worker.exc)
|
|
|
|
def test_raise_invalid_document(self):
    """The InvalidDocument message names the offending object and its type."""

    class Wrapper:
        def __init__(self, val):
            self.val = val

        def __repr__(self):
            return repr(self.val)

    self.assertEqual("1", repr(Wrapper(1)))

    # repr(Wrapper) contains "." characters; escape the expected text so
    # assertRaisesRegex matches it literally instead of as a regex.
    expected = re.escape("cannot encode object: 1, of type: " + repr(Wrapper))
    with self.assertRaisesRegex(InvalidDocument, expected):
        encode({"t": Wrapper(1)})
|
|
|
|
def test_doc_in_invalid_document_error_message(self):
    """The InvalidDocument message includes the document that failed to encode."""

    class Wrapper:
        def __init__(self, val):
            self.val = val

        def __repr__(self):
            return repr(self.val)

    self.assertEqual("1", repr(Wrapper(1)))
    doc = {"t": Wrapper(1)}

    # Escape the expected text so the dict's braces and any other regex
    # metacharacters are matched literally.
    expected = re.escape(f"Invalid document {doc}")
    with self.assertRaisesRegex(InvalidDocument, expected):
        encode(doc)
|
|
|
|
def test_doc_in_invalid_document_error_message_mapping(self):
    """The failing document also appears in the message for custom Mappings."""

    class Wrapper:
        def __init__(self, val):
            self.val = val

        def __repr__(self):
            return repr(self.val)

    class MyMapping(abc.Mapping):
        """Minimal read-only mapping exposing a single key, "t"."""

        def keys(self):
            return ["t"]

        def __getitem__(self, name):
            if name == "_id":
                return None
            # Every other key yields a value that cannot be encoded.
            return Wrapper(name)

        def __len__(self):
            return 1

        def __iter__(self):
            return iter(["t"])

    self.assertEqual("1", repr(Wrapper(1)))
    doc = MyMapping()

    # str(doc) is the default object repr, which contains "." characters;
    # escape it so it is matched literally.
    expected = re.escape(f"Invalid document {doc}")
    with self.assertRaisesRegex(InvalidDocument, expected):
        encode(doc)
|
|
|
|
|
|
class TestCodecOptions(unittest.TestCase):
    """Tests for CodecOptions validation, decode defaults and pickling of BSON types."""

    def test_document_class(self):
        """``document_class=object`` is rejected; ``SON`` is stored as given."""
        with self.assertRaises(TypeError):
            CodecOptions(document_class=object)
        opts = CodecOptions(document_class=SON)  # type: ignore[type-var]
        self.assertIs(SON, opts.document_class)

    def test_tz_aware(self):
        """``tz_aware`` rejects the int 1, and defaults to False."""
        with self.assertRaises(TypeError):
            CodecOptions(tz_aware=1)
        self.assertFalse(CodecOptions().tz_aware)
        self.assertTrue(CodecOptions(tz_aware=True).tz_aware)

    def test_uuid_representation(self):
        """The values 7 and 2 are rejected as ``uuid_representation``."""
        for invalid in (7, 2):
            with self.assertRaises(ValueError):
                CodecOptions(uuid_representation=invalid)

    def test_tzinfo(self):
        """``tzinfo`` must be a tzinfo object and needs ``tz_aware=True``."""
        with self.assertRaises(TypeError):
            CodecOptions(tzinfo="pacific")

        tz = FixedOffset(42, "forty-two")
        # A tzinfo on its own, without tz_aware=True, is rejected.
        with self.assertRaises(ValueError):
            CodecOptions(tzinfo=tz)
        self.assertEqual(tz, CodecOptions(tz_aware=True, tzinfo=tz).tzinfo)

        # repr() must be stable and evaluate back to an equivalent object.
        # eval() is safe here: its input is the repr of an object built above.
        expected_repr = "FixedOffset(datetime.timedelta(seconds=2520), 'forty-two')"
        self.assertEqual(repr(tz), expected_repr)
        self.assertEqual(repr(eval(repr(tz))), expected_repr)

    def test_codec_options_repr(self):
        """``repr(CodecOptions())`` lists every option with its default value."""
        r = (
            "CodecOptions(document_class=dict, tz_aware=False, "
            "uuid_representation=UuidRepresentation.UNSPECIFIED, "
            "unicode_decode_error_handler='strict', "
            "tzinfo=None, type_registry=TypeRegistry(type_codecs=[], "
            "fallback_encoder=None), "
            "datetime_conversion=DatetimeConversion.DATETIME)"
        )
        self.assertEqual(r, repr(CodecOptions()))

    def test_decode_all_defaults(self):
        """decode_all() defaults to dict documents and naive datetimes."""
        doc = {"sub_document": {}, "dt": datetime.datetime.now(tz=datetime.timezone.utc)}

        decoded = bson.decode_all(bson.encode(doc))[0]
        self.assertIsInstance(decoded["sub_document"], dict)
        self.assertIsNone(decoded["dt"].tzinfo)

        # The default uuid_representation is UNSPECIFIED, so a native UUID
        # cannot be encoded. The error comes from encode() itself; nothing
        # is ever produced for decode_all() to consume.
        with self.assertRaisesRegex(ValueError, "cannot encode native uuid"):
            bson.encode({"uuid": uuid.uuid4()})

    def test_decode_all_no_options(self):
        """Passing ``None`` as codec_options to decode_all() behaves like the defaults."""
        doc = {"sub_document": {}, "dt": datetime.datetime.now(tz=datetime.timezone.utc)}

        decoded = bson.decode_all(bson.encode(doc), None)[0]
        self.assertIsInstance(decoded["sub_document"], dict)
        self.assertIsNone(decoded["dt"].tzinfo)

        # A UUID stored as Binary comes back as Binary, not as uuid.UUID.
        doc2 = {"id": Binary.from_uuid(uuid.uuid4())}
        decoded = bson.decode_all(bson.encode(doc2), None)[0]
        self.assertIsInstance(decoded["id"], Binary)

    def test_decode_all_kwarg(self):
        """decode_all() accepts codec_options positionally and by keyword."""
        doc = {"a": uuid.uuid4()}
        opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD)
        encoded = encode(doc, codec_options=opts)
        # Positional codec_options
        self.assertEqual([doc], decode_all(encoded, opts))
        # Keyword codec_options
        self.assertEqual([doc], decode_all(encoded, codec_options=opts))

    def test_unicode_decode_error_handler(self):
        """``unicode_decode_error_handler`` governs how invalid UTF-8 is decoded."""
        enc = encode({"keystr": "foobar"})

        # Test handling of bad key value, bad string value, and both.
        # Offset 7 falls inside the key and offset 18 inside the value; each
        # is overwritten with 0xe9, which is not valid UTF-8 in that position.
        invalid_key = enc[:7] + b"\xe9" + enc[8:]
        invalid_val = enc[:18] + b"\xe9" + enc[19:]
        invalid_both = enc[:7] + b"\xe9" + enc[8:18] + b"\xe9" + enc[19:]

        # Ensure that strict mode raises an error, whether it is requested
        # explicitly, via default CodecOptions, or with no options at all.
        strict = CodecOptions(unicode_decode_error_handler="strict")
        for invalid in [invalid_key, invalid_val, invalid_both]:
            with self.assertRaises(InvalidBSON):
                decode(invalid, strict)
            with self.assertRaises(InvalidBSON):
                decode(invalid, CodecOptions())
            with self.assertRaises(InvalidBSON):
                decode(invalid)

        # Test all other error handlers: results must match bytes.decode().
        for handler in ["replace", "backslashreplace", "surrogateescape", "ignore"]:
            opts = CodecOptions(unicode_decode_error_handler=handler)
            expected_key = b"ke\xe9str".decode("utf-8", handler)
            expected_val = b"fo\xe9bar".decode("utf-8", handler)
            self.assertEqual(decode(invalid_key, opts), {expected_key: "foobar"})
            self.assertEqual(decode(invalid_val, opts), {"keystr": expected_val})
            self.assertEqual(decode(invalid_both, opts), {expected_key: expected_val})

        # Test handling bad error mode: an unknown handler name goes
        # unnoticed while the data is valid ...
        junk = CodecOptions(unicode_decode_error_handler="junk")
        self.assertEqual(decode(enc, junk), {"keystr": "foobar"})

        # ... and surfaces as InvalidBSON once a decode error occurs.
        with self.assertRaises(InvalidBSON):
            decode(invalid_both, junk)

    def round_trip_pickle(self, obj, pickled_with_older):
        """Assert that ``obj`` survives pickling under every protocol.

        Each unpickled copy must equal both ``obj`` and the object loaded from
        ``pickled_with_older``, a pickle byte string recorded by the caller.
        Those byte strings are literals in this file, so pickle.loads is safe.
        """
        pickled_with_older_obj = pickle.loads(pickled_with_older)
        for protocol in range(pickle.HIGHEST_PROTOCOL + 1):
            restored = pickle.loads(pickle.dumps(obj, protocol=protocol))
            self.assertEqual(obj, restored)
            self.assertEqual(pickled_with_older_obj, restored)

    def test_regex_pickling(self):
        """Regex round-trips through pickle and matches a recorded pickle."""
        reg = Regex(".?")
        # Recorded pickle; the \x80\x04 header marks pickle protocol 4.
        recorded = (
            b"\x80\x04\x959\x00\x00\x00\x00\x00\x00\x00\x8c\n"
            b"bson.regex\x94\x8c\x05Regex\x94\x93\x94)\x81\x94}"
            b"\x94(\x8c\x07pattern\x94\x8c\x02.?\x94\x8c\x05flag"
            b"s\x94K\x00ub."
        )
        self.round_trip_pickle(reg, recorded)

    def test_timestamp_pickling(self):
        """Timestamp round-trips through pickle and matches a recorded pickle."""
        ts = Timestamp(0, 1)
        recorded = (
            b"\x80\x04\x95Q\x00\x00\x00\x00\x00\x00\x00\x8c"
            b"\x0ebson.timestamp\x94\x8c\tTimestamp\x94\x93\x94)"
            b"\x81\x94}\x94("
            b"\x8c\x10_Timestamp__time\x94K\x00\x8c"
            b"\x0f_Timestamp__inc\x94K\x01ub."
        )
        self.round_trip_pickle(ts, recorded)

    def test_dbref_pickling(self):
        """DBRef round-trips through pickle, with and without database/kwargs."""
        dbr = DBRef("foo", 5)
        recorded = (
            b"\x80\x04\x95q\x00\x00\x00\x00\x00\x00\x00\x8c\n"
            b"bson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}"
            b"\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94"
            b"\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database"
            b"\x94N\x8c\x0e_DBRef__kwargs\x94}\x94ub."
        )
        self.round_trip_pickle(dbr, recorded)

        dbr = DBRef("foo", 5, database="db", kwargs1=None)
        recorded = (
            b"\x80\x04\x95\x81\x00\x00\x00\x00\x00\x00\x00\x8c"
            b"\nbson.dbref\x94\x8c\x05DBRef\x94\x93\x94)\x81\x94}"
            b"\x94(\x8c\x12_DBRef__collection\x94\x8c\x03foo\x94"
            b"\x8c\n_DBRef__id\x94K\x05\x8c\x10_DBRef__database"
            b"\x94\x8c\x02db\x94\x8c\x0e_DBRef__kwargs\x94}\x94"
            b"\x8c\x07kwargs1\x94Nsub."
        )
        self.round_trip_pickle(dbr, recorded)

    def test_minkey_pickling(self):
        """MinKey round-trips through pickle and matches a recorded pickle."""
        mink = MinKey()
        recorded = (
            b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c"
            b"\x0cbson.min_key\x94\x8c\x06MinKey\x94\x93\x94)"
            b"\x81\x94."
        )
        self.round_trip_pickle(mink, recorded)

    def test_maxkey_pickling(self):
        """MaxKey round-trips through pickle and matches a recorded pickle."""
        maxk = MaxKey()
        recorded = (
            b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c"
            b"\x0cbson.max_key\x94\x8c\x06MaxKey\x94\x93\x94)"
            b"\x81\x94."
        )
        self.round_trip_pickle(maxk, recorded)

    def test_int64_pickling(self):
        """Int64 round-trips through pickle and matches a recorded pickle."""
        i64 = Int64(9)
        recorded = (
            b"\x80\x04\x95\x1e\x00\x00\x00\x00\x00\x00\x00\x8c\n"
            b"bson.int64\x94\x8c\x05Int64\x94\x93\x94K\t\x85\x94"
            b"\x81\x94."
        )
        self.round_trip_pickle(i64, recorded)

    def test_bson_encode_decode(self) -> None:
        """Decoded documents can be re-encoded and are ordinary mutable dicts."""
        doc = {"_id": ObjectId()}
        encoded = bson.encode(doc)
        decoded = bson.decode(encoded)
        # A second pass shows that decode() output is valid encode() input.
        encoded = bson.encode(decoded)
        decoded = bson.decode(encoded)
        # Documents returned from decode are mutable.
        decoded["new_field"] = 1
        self.assertTrue(decoded["_id"].generation_time)
|
|
|
|
|
|
class TestDatetimeConversion(unittest.TestCase):
    """Tests for DatetimeMS and the ``datetime_conversion`` codec option."""

    def test_comps(self):
        """Each rich comparison on DatetimeMS agrees with the same one on ``_value``."""
        pairs = [
            (DatetimeMS(-1), DatetimeMS(1)),
            (DatetimeMS(0), DatetimeMS(0)),
            (DatetimeMS(1), DatetimeMS(-1)),
        ]

        comp_ops = ["__lt__", "__le__", "__eq__", "__ne__", "__gt__", "__ge__"]
        for lh, rh in pairs:
            for op in comp_ops:
                wrapped = getattr(lh, op)(rh)
                raw = getattr(lh._value, op)(rh._value)
                self.assertEqual(wrapped, raw)

    def test_class_conversions(self):
        """DatetimeMS and datetime convert into each other, directly and via BSON."""
        # Test class conversions.
        dtr1 = DatetimeMS(1234)
        dt1 = dtr1.as_datetime()
        self.assertEqual(dtr1, DatetimeMS(dt1))

        dt2 = datetime.datetime(1969, 1, 1)
        dtr2 = DatetimeMS(dt2)
        self.assertEqual(dtr2.as_datetime(), dt2)

        # Test encode and decode without codec options. Expect: DatetimeMS => datetime
        dtr1 = DatetimeMS(0)
        enc1 = encode({"x": dtr1})
        dec1 = decode(enc1)
        self.assertEqual(dec1["x"], datetime.datetime(1970, 1, 1))
        self.assertNotEqual(type(dtr1), type(dec1["x"]))

        # Test encode and decode with codec options. Expect: DatetimeMS => DatetimeMS
        opts1 = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_MS)
        dec1 = decode(enc1, opts1)
        self.assertEqual(type(dtr1), type(dec1["x"]))
        self.assertEqual(dtr1, dec1["x"])

        # Expect: datetime => DatetimeMS
        dt1 = datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc)
        enc1 = encode({"x": dt1})
        dec1 = decode(enc1, opts1)
        self.assertEqual(dec1["x"], DatetimeMS(0))
        # Compare the two *types*: an instance never equals a type, so
        # comparing dt1 itself would pass no matter what was decoded.
        self.assertNotEqual(type(dt1), type(dec1["x"]))

    def test_clamping(self):
        """DATETIME_CLAMP clamps out-of-range values to datetime.min / datetime.max."""
        # Test clamping from below and above.
        opts = CodecOptions(
            datetime_conversion=DatetimeConversion.DATETIME_CLAMP,
            tz_aware=True,
            tzinfo=datetime.timezone.utc,
        )
        below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 1)})
        dec_below = decode(below, opts)
        self.assertEqual(
            dec_below["x"], datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)
        )

        above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 1)})
        dec_above = decode(above, opts)
        # The expected maximum is truncated to whole milliseconds (999000 us).
        self.assertEqual(
            dec_above["x"],
            datetime.datetime.max.replace(tzinfo=datetime.timezone.utc, microsecond=999000),
        )

    def test_tz_clamping_local(self):
        """With ``tz_aware=False`` clamping yields naive min/max datetimes."""
        # Naive clamping to local tz.
        opts = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=False)
        below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)})

        dec_below = decode(below, opts)
        self.assertEqual(dec_below["x"], datetime.datetime.min)

        above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)})
        dec_above = decode(above, opts)
        self.assertEqual(
            dec_above["x"],
            datetime.datetime.max.replace(microsecond=999000),
        )

    def test_tz_clamping_utc(self):
        """With ``tz_aware=True`` and no tzinfo, clamped values carry UTC."""
        # Aware clamping default utc.
        opts = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True)
        below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)})
        dec_below = decode(below, opts)
        self.assertEqual(
            dec_below["x"], datetime.datetime.min.replace(tzinfo=datetime.timezone.utc)
        )

        above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)})
        dec_above = decode(above, opts)
        self.assertEqual(
            dec_above["x"],
            datetime.datetime.max.replace(tzinfo=datetime.timezone.utc, microsecond=999000),
        )

    def test_tz_clamping_non_utc(self):
        """Clamping with a fixed-offset tzinfo clamps to that zone's min/max."""
        hour_ms = 60 * 60 * 1000
        # The same three overshoots are applied on both sides of each bound.
        overshoots = (1, hour_ms, 1 + hour_ms)
        naive_min_ms = _datetime_to_millis(datetime.datetime.min)
        naive_max_ms = _datetime_to_millis(datetime.datetime.max)

        for tz in [FixedOffset(60, "+1H"), FixedOffset(-60, "-1H")]:
            opts = CodecOptions(
                datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True, tzinfo=tz
            )
            # Min/max values in this timezone which can be represented in both BSON and datetime UTC.
            # astimezone() overflows when the shift would leave datetime's
            # range; in that case the bound is the zone's own min/max.
            try:
                min_tz = datetime.datetime.min.replace(tzinfo=utc).astimezone(tz)
            except OverflowError:
                min_tz = datetime.datetime.min.replace(tzinfo=tz)
            try:
                max_tz = datetime.datetime.max.replace(tzinfo=utc, microsecond=999000).astimezone(
                    tz
                )
            except OverflowError:
                max_tz = datetime.datetime.max.replace(tzinfo=tz, microsecond=999000)

            # Values at or just inside the bounds round-trip unchanged.
            for in_range in [
                min_tz,
                min_tz + datetime.timedelta(milliseconds=1),
                max_tz - datetime.timedelta(milliseconds=1),
                max_tz,
            ]:
                doc = decode(encode({"x": in_range}), opts)
                self.assertEqual(doc["x"], in_range)

            # Anything below either lower bound clamps to min_tz.
            for base_ms in (_datetime_to_millis(min_tz), naive_min_ms):
                for delta in overshoots:
                    too_low = DatetimeMS(base_ms - delta)
                    doc = decode(encode({"x": too_low}), opts)
                    self.assertEqual(doc["x"], min_tz)

            # Anything above either upper bound clamps to max_tz.
            for base_ms in (_datetime_to_millis(max_tz), naive_max_ms):
                for delta in overshoots:
                    too_high = DatetimeMS(base_ms + delta)
                    doc = decode(encode({"x": too_high}), opts)
                    self.assertEqual(doc["x"], max_tz)

    def test_tz_clamping_non_utc_simple(self):
        """An in-range datetime decodes into the requested fixed-offset zone."""
        dtm = datetime.datetime(2024, 8, 23)
        encoded = encode({"d": dtm})
        self.assertEqual(decode(encoded)["d"], dtm)
        for conversion in [
            DatetimeConversion.DATETIME,
            DatetimeConversion.DATETIME_CLAMP,
            DatetimeConversion.DATETIME_AUTO,
        ]:
            for tz in [FixedOffset(60, "+1H"), FixedOffset(-60, "-1H")]:
                opts = CodecOptions(datetime_conversion=conversion, tz_aware=True, tzinfo=tz)
                expected = dtm.replace(tzinfo=utc).astimezone(tz)
                self.assertEqual(decode(encoded, opts)["d"], expected)

    def test_tz_clamping_non_hashable(self):
        """Clamping works with a tzinfo whose ``__hash__`` is None."""

        class NonHashableTZ(FixedOffset):
            __hash__ = None

        tz = NonHashableTZ(0, "UTC-non-hashable")
        self.assertRaises(TypeError, hash, tz)
        # Aware clamping.
        opts = CodecOptions(
            datetime_conversion=DatetimeConversion.DATETIME_CLAMP, tz_aware=True, tzinfo=tz
        )
        below = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60)})
        dec_below = decode(below, opts)
        self.assertEqual(dec_below["x"], datetime.datetime.min.replace(tzinfo=tz))

        within = encode({"x": EPOCH_AWARE.astimezone(tz)})
        dec_within = decode(within, opts)
        self.assertEqual(dec_within["x"], EPOCH_AWARE.astimezone(tz))

        above = encode({"x": DatetimeMS(_datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60)})
        dec_above = decode(above, opts)
        self.assertEqual(
            dec_above["x"],
            datetime.datetime.max.replace(tzinfo=tz, microsecond=999000),
        )

    def test_datetime_auto(self):
        """DATETIME_AUTO gives datetime when in range and DatetimeMS otherwise."""
        below_ms = _datetime_to_millis(datetime.datetime.min) - 24 * 60 * 60
        above_ms = _datetime_to_millis(datetime.datetime.max) + 24 * 60 * 60
        epoch_naive = datetime.datetime(1970, 1, 1)

        # Naive auto, in range. Decode *with* the AUTO options so that the
        # conversion under test is actually exercised.
        opts1 = CodecOptions(datetime_conversion=DatetimeConversion.DATETIME_AUTO)
        inr = encode({"x": epoch_naive}, codec_options=opts1)
        dec_inr = decode(inr, opts1)
        self.assertEqual(dec_inr["x"], epoch_naive)

        # Naive auto, below range.
        below = encode({"x": DatetimeMS(below_ms)})
        dec_below = decode(below, opts1)
        self.assertEqual(dec_below["x"], DatetimeMS(below_ms))

        # Naive auto, above range.
        above = encode({"x": DatetimeMS(above_ms)})
        dec_above = decode(above, opts1)
        self.assertEqual(dec_above["x"], DatetimeMS(above_ms))

        # Aware auto, in range: the decoded value is an aware UTC datetime.
        opts2 = CodecOptions(
            datetime_conversion=DatetimeConversion.DATETIME_AUTO,
            tz_aware=True,
            tzinfo=datetime.timezone.utc,
        )
        inr = encode({"x": epoch_naive}, codec_options=opts2)
        dec_inr = decode(inr, opts2)
        self.assertEqual(dec_inr["x"], epoch_naive.replace(tzinfo=datetime.timezone.utc))

        # Aware auto, below range.
        below = encode({"x": DatetimeMS(below_ms)})
        dec_below = decode(below, opts2)
        self.assertEqual(dec_below["x"], DatetimeMS(below_ms))

        # Aware auto, above range.
        above = encode({"x": DatetimeMS(above_ms)})
        dec_above = decode(above, opts2)
        self.assertEqual(dec_above["x"], DatetimeMS(above_ms))

    def test_millis_from_datetime_ms(self):
        """Out-of-range or non-int DatetimeMS values produce explicit errors."""
        # Test 65+ bit integer conversion, expect OverflowError.
        big_ms = 2**65
        with self.assertRaises(OverflowError):
            encode({"x": DatetimeMS(big_ms)})

        # Subclass of DatetimeMS w/ __int__ override, expect an Error.
        class DatetimeMSOverride(DatetimeMS):
            def __int__(self):
                # Deliberately violates the __int__ contract by returning a float.
                return float(self._value)

        float_ms = DatetimeMSOverride(2)
        with self.assertRaises(TypeError):
            encode({"x": float_ms})

        # Test InvalidBSON errors on conversion include _DATETIME_ERROR_SUGGESTION
        small_ms = -2 << 51
        suggestion = re.compile(re.escape(_DATETIME_ERROR_SUGGESTION))
        with self.assertRaisesRegex(InvalidBSON, suggestion):
            decode(encode({"a": DatetimeMS(small_ms)}))

    def test_array_of_documents_to_buffer(self):
        """_array_of_documents_to_buffer concatenates documents and rejects bad input."""
        doc = dict(a=1)
        # One and two embedded documents come back as their concatenated encodings.
        buf = _array_of_documents_to_buffer(encode({"0": doc}))
        self.assertEqual(buf, encode(doc))
        buf = _array_of_documents_to_buffer(encode({"0": doc, "1": doc}))
        self.assertEqual(buf, encode(doc) + encode(doc))

        # Trailing garbage after the array is rejected.
        with self.assertRaises(InvalidBSON):
            _array_of_documents_to_buffer(encode({"0": doc, "1": doc}) + b"1")

        # A final byte that is not the NUL terminator is rejected.
        buf = encode({"0": doc, "1": doc})
        buf = buf[:-1] + b"1"
        with self.assertRaises(InvalidBSON):
            _array_of_documents_to_buffer(buf)

        # We replace the size of the array with \xff\xff\xff\x00, which reads
        # as 16777215 as a little-endian int32: far larger than the 20-byte
        # buffer that contains it.
        buf = b"\x14\x00\x00\x00\x04a\x00\xff\xff\xff\x00\x100\x00\x01\x00\x00\x00\x00\x00"
        with self.assertRaises(InvalidBSON):
            _array_of_documents_to_buffer(buf)
|
|
|
|
|
|
class TestLongLongToString(unittest.TestCase):
    """Invokes ``_cbson._test_long_long_to_str()`` when the C extension imports."""

    def test_long_long_to_string(self):
        """Run the C extension's self-test, or skip when ``_cbson`` is unavailable."""
        # Keep the try body to the import alone, so an ImportError raised
        # inside the self-test itself is not mistaken for a missing extension.
        try:
            from bson import _cbson
        except ImportError:
            # Report a skip instead of printing and passing silently.
            self.skipTest("_cbson was not imported. Check compilation logs.")

        _cbson._test_long_long_to_str()
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|