Merge branch 'master' of github.com:mongodb/mongo-python-driver

This commit is contained in:
Steven Silvester 2024-10-11 15:01:58 -05:00
commit bf9452b6f9
No known key found for this signature in database
GPG Key ID: B1BF5EC3A8B32F91
38 changed files with 3696 additions and 1099 deletions

View File

@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
# Use --capture=tee-sys so pytest prints test output inline:
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
if [ -z "$TEST_SUITES" ]; then
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
else
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
fi
else
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS

View File

@ -67,7 +67,7 @@ jobs:
# Note: the default manylinux is manylinux2014
run: |
python -m pip install -U pip
python -m pip install "cibuildwheel>=2.17,<3"
python -m pip install "cibuildwheel>=2.20,<3"
- name: Build wheels
env:
@ -89,6 +89,9 @@ jobs:
ls wheelhouse/*cp310*.whl
ls wheelhouse/*cp311*.whl
ls wheelhouse/*cp312*.whl
ls wheelhouse/*cp313*.whl
# Free-threading builds:
ls wheelhouse/*cp313t*.whl
- uses: actions/upload-artifact@v4
with:

View File

@ -51,11 +51,18 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04]
python-version: ["3.9", "pypy-3.9", "3.13"]
python-version: ["3.9", "pypy-3.9", "3.13", "3.13t"]
name: CPython ${{ matrix.python-version }}-${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Setup Python
- if: ${{ matrix.python-version == '3.13t' }}
name: Setup free-threaded Python
uses: deadsnakes/action@v3.2.0
with:
python-version: 3.13
nogil: true
- if: ${{ matrix.python-version != '3.13t' }}
name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
@ -65,9 +72,13 @@ jobs:
- name: Install dependencies
run: |
pip install -U pip
if [ "${{ matrix.python-version }}" == "3.13" ]; then
if [[ "${{ matrix.python-version }}" == "3.13" ]]; then
pip install --pre cffi setuptools
pip install --no-build-isolation hatch
elif [[ "${{ matrix.python-version }}" == "3.13t" ]]; then
# Hatch can't be installed on 3.13t, use pytest directly.
pip install .
pip install -r requirements/test.txt
else
pip install hatch
fi
@ -77,7 +88,11 @@ jobs:
mongodb-version: 6.0
- name: Run tests
run: |
hatch run test:test
if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then
pytest -v --durations=5 --maxfail=10
else
hatch run test:test
fi
doctest:
runs-on: ubuntu-latest

View File

@ -3184,6 +3184,9 @@ static PyModuleDef_Slot _cbson_slots[] = {
{Py_mod_exec, _cbson_exec},
#if defined(Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED)
{Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED},
#endif
#if PY_VERSION_HEX >= 0x030D0000
{Py_mod_gil, Py_MOD_GIL_NOT_USED},
#endif
{0, NULL},
};

View File

@ -12,6 +12,8 @@ PyMongo 4.11 brings a number of changes including:
- Dropped support for Python 3.8.
- Dropped support for MongoDB 3.6.
- Added support for free-threaded Python with the GIL disabled. For more information see:
`Free-threaded CPython <https://docs.python.org/3.13/whatsnew/3.13.html#whatsnew313-free-threaded-cpython>`_.
Issues Resolved
...............

View File

@ -1022,6 +1022,9 @@ static PyModuleDef_Slot _cmessage_slots[] = {
{Py_mod_exec, _cmessage_exec},
#ifdef Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED
{Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED},
#endif
#if PY_VERSION_HEX >= 0x030D0000
{Py_mod_gil, Py_MOD_GIL_NOT_USED},
#endif
{0, NULL},
};

View File

@ -180,10 +180,20 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
async_receive_data_socket,
)
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:

View File

@ -130,7 +130,7 @@ if sys.platform != "win32":
loop.remove_writer(fd)
async def _async_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
@ -145,6 +145,9 @@ if sys.platform != "win32":
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
# KMS responses update their expected size after the first batch, stop reading after one loop
if once:
return mv[:read]
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
@ -275,6 +278,28 @@ async def async_receive_data(
sock.settimeout(sock_timeout)
async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
sock_timeout = sock.gettimeout()
timeout = sock_timeout
sock.settimeout(0.0)
loop = asyncio.get_event_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0

View File

@ -180,10 +180,20 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
receive_data_socket,
)
data = receive_data_socket(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:

View File

@ -236,6 +236,8 @@ partial_branches = ["if (.*and +)*not _use_c( and.*)*:"]
directory = "htmlcov"
[tool.cibuildwheel]
# Enable free-threaded support
free-threaded-support = true
skip = "pp* *-musllinux*"
build-frontend = "build"
test-command = "python {project}/tools/fail_if_no_c.py"

View File

@ -464,11 +464,12 @@ class ClientContext:
if not self.connected:
pair = self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return f(*args, **kwargs)

View File

@ -466,11 +466,12 @@ class AsyncClientContext:
if not self.connected:
pair = await self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return await f(*args, **kwargs)

View File

@ -61,6 +61,13 @@ class TestAsyncClientContext(AsyncUnitTest):
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
def test_free_threading_is_enabled(self):
if "free-threading build" not in sys.version:
raise SkipTest("this test requires the Python free-threading build")
# If the GIL is enabled then pymongo or one of our deps does not support free-threading.
self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined]
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,290 @@
# Copyright 2016-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the collation module."""
from __future__ import annotations
import functools
import warnings
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
from test.utils import EventListener
from typing import Any
from pymongo.asynchronous.helpers import anext
from pymongo.collation import (
Collation,
CollationAlternate,
CollationCaseFirst,
CollationMaxVariable,
CollationStrength,
)
from pymongo.errors import ConfigurationError
from pymongo.operations import (
DeleteMany,
DeleteOne,
IndexModel,
ReplaceOne,
UpdateMany,
UpdateOne,
)
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class TestCollationObject(unittest.TestCase):
def test_constructor(self):
self.assertRaises(TypeError, Collation, locale=42)
# Fill in a locale to test the other options.
_Collation = functools.partial(Collation, "en_US")
# No error.
_Collation(caseFirst=CollationCaseFirst.UPPER)
self.assertRaises(TypeError, _Collation, caseLevel="true")
self.assertRaises(ValueError, _Collation, strength="six")
self.assertRaises(TypeError, _Collation, numericOrdering="true")
self.assertRaises(TypeError, _Collation, alternate=5)
self.assertRaises(TypeError, _Collation, maxVariable=2)
self.assertRaises(TypeError, _Collation, normalization="false")
self.assertRaises(TypeError, _Collation, backwards="true")
# No errors.
Collation("en_US", future_option="bar", another_option=42)
collation = Collation(
"en_US",
caseLevel=True,
caseFirst=CollationCaseFirst.UPPER,
strength=CollationStrength.QUATERNARY,
numericOrdering=True,
alternate=CollationAlternate.SHIFTED,
maxVariable=CollationMaxVariable.SPACE,
normalization=True,
backwards=True,
)
self.assertEqual(
{
"locale": "en_US",
"caseLevel": True,
"caseFirst": "upper",
"strength": 4,
"numericOrdering": True,
"alternate": "shifted",
"maxVariable": "space",
"normalization": True,
"backwards": True,
},
collation.document,
)
self.assertEqual(
{"locale": "en_US", "backwards": True}, Collation("en_US", backwards=True).document
)
class TestCollation(AsyncIntegrationTest):
listener: EventListener
warn_context: Any
collation: Collation
@classmethod
@async_client_context.require_connection
async def _setup_class(cls):
await super()._setup_class()
cls.listener = EventListener()
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
cls.collation = Collation("en_US")
cls.warn_context = warnings.catch_warnings()
cls.warn_context.__enter__()
warnings.simplefilter("ignore", DeprecationWarning)
@classmethod
async def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
await cls.client.close()
await super()._tearDown_class()
def tearDown(self):
self.listener.reset()
super().tearDown()
def last_command_started(self):
return self.listener.started_events[-1].command
def assertCollationInLastCommand(self):
self.assertEqual(self.collation.document, self.last_command_started()["collation"])
async def test_create_collection(self):
await self.db.test.drop()
await self.db.create_collection("test", collation=self.collation)
self.assertCollationInLastCommand()
# Test passing collation as a dict as well.
await self.db.test.drop()
self.listener.reset()
await self.db.create_collection("test", collation=self.collation.document)
self.assertCollationInLastCommand()
def test_index_model(self):
model = IndexModel([("a", 1), ("b", -1)], collation=self.collation)
self.assertEqual(self.collation.document, model.document["collation"])
async def test_create_index(self):
await self.db.test.create_index("foo", collation=self.collation)
ci_cmd = self.listener.started_events[0].command
self.assertEqual(self.collation.document, ci_cmd["indexes"][0]["collation"])
async def test_aggregate(self):
await self.db.test.aggregate([{"$group": {"_id": 42}}], collation=self.collation)
self.assertCollationInLastCommand()
async def test_count_documents(self):
await self.db.test.count_documents({}, collation=self.collation)
self.assertCollationInLastCommand()
async def test_distinct(self):
await self.db.test.distinct("foo", collation=self.collation)
self.assertCollationInLastCommand()
self.listener.reset()
await self.db.test.find(collation=self.collation).distinct("foo")
self.assertCollationInLastCommand()
async def test_find_command(self):
await self.db.test.insert_one({"is this thing on?": True})
self.listener.reset()
await anext(self.db.test.find(collation=self.collation))
self.assertCollationInLastCommand()
async def test_explain_command(self):
self.listener.reset()
await self.db.test.find(collation=self.collation).explain()
# The collation should be part of the explained command.
self.assertEqual(
self.collation.document, self.last_command_started()["explain"]["collation"]
)
async def test_delete(self):
await self.db.test.delete_one({"foo": 42}, collation=self.collation)
command = self.listener.started_events[0].command
self.assertEqual(self.collation.document, command["deletes"][0]["collation"])
self.listener.reset()
await self.db.test.delete_many({"foo": 42}, collation=self.collation)
command = self.listener.started_events[0].command
self.assertEqual(self.collation.document, command["deletes"][0]["collation"])
async def test_update(self):
await self.db.test.replace_one({"foo": 42}, {"foo": 43}, collation=self.collation)
command = self.listener.started_events[0].command
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
self.listener.reset()
await self.db.test.update_one({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation)
command = self.listener.started_events[0].command
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
self.listener.reset()
await self.db.test.update_many({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation)
command = self.listener.started_events[0].command
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
async def test_find_and(self):
await self.db.test.find_one_and_delete({"foo": 42}, collation=self.collation)
self.assertCollationInLastCommand()
self.listener.reset()
await self.db.test.find_one_and_update(
{"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation
)
self.assertCollationInLastCommand()
self.listener.reset()
await self.db.test.find_one_and_replace({"foo": 42}, {"foo": 43}, collation=self.collation)
self.assertCollationInLastCommand()
async def test_bulk_write(self):
await self.db.test.collection.bulk_write(
[
DeleteOne({"noCollation": 42}),
DeleteMany({"noCollation": 42}),
DeleteOne({"foo": 42}, collation=self.collation),
DeleteMany({"foo": 42}, collation=self.collation),
ReplaceOne({"noCollation": 24}, {"bar": 42}),
UpdateOne({"noCollation": 84}, {"$set": {"bar": 10}}, upsert=True),
UpdateMany({"noCollation": 45}, {"$set": {"bar": 42}}),
ReplaceOne({"foo": 24}, {"foo": 42}, collation=self.collation),
UpdateOne(
{"foo": 84}, {"$set": {"foo": 10}}, upsert=True, collation=self.collation
),
UpdateMany({"foo": 45}, {"$set": {"foo": 42}}, collation=self.collation),
]
)
delete_cmd = self.listener.started_events[0].command
update_cmd = self.listener.started_events[1].command
def check_ops(ops):
for op in ops:
if "noCollation" in op["q"]:
self.assertNotIn("collation", op)
else:
self.assertEqual(self.collation.document, op["collation"])
check_ops(delete_cmd["deletes"])
check_ops(update_cmd["updates"])
async def test_indexes_same_keys_different_collations(self):
await self.db.test.drop()
usa_collation = Collation("en_US")
ja_collation = Collation("ja")
await self.db.test.create_indexes(
[
IndexModel("fieldname", collation=usa_collation),
IndexModel("fieldname", name="japanese_version", collation=ja_collation),
IndexModel("fieldname", name="simple"),
]
)
indexes = await self.db.test.index_information()
self.assertEqual(
usa_collation.document["locale"], indexes["fieldname_1"]["collation"]["locale"]
)
self.assertEqual(
ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"]
)
self.assertNotIn("collation", indexes["simple"])
await self.db.test.drop_index("fieldname_1")
indexes = await self.db.test.index_information()
self.assertIn("japanese_version", indexes)
self.assertIn("simple", indexes)
self.assertNotIn("fieldname", indexes)
async def test_unacknowledged_write(self):
unacknowledged = WriteConcern(w=0)
collection = self.db.get_collection("test", write_concern=unacknowledged)
with self.assertRaises(ConfigurationError):
await collection.update_one(
{"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation
)
update_one = UpdateOne(
{"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation
)
with self.assertRaises(ConfigurationError):
await collection.bulk_write([update_one])
async def test_cursor_collation(self):
await self.db.test.insert_one({"hello": "world"})
await anext(self.db.test.find().collation(self.collation))
self.assertCollationInLastCommand()

View File

@ -0,0 +1,44 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the command monitoring unified format spec tests."""
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,45 @@
# Copyright 2015-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the command monitoring unified format spec tests."""
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test import unittest
from test.asynchronous.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,185 @@
# Copyright 2011-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the pymongo common module."""
from __future__ import annotations
import sys
import uuid
sys.path[0:0] = [""]
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
from bson.codec_options import CodecOptions
from bson.objectid import ObjectId
from pymongo.errors import OperationFailure
from pymongo.write_concern import WriteConcern
_IS_SYNC = False
class TestCommon(AsyncIntegrationTest):
async def test_uuid_representation(self):
coll = self.db.uuid
await coll.drop()
# Test property
self.assertEqual(UuidRepresentation.UNSPECIFIED, coll.codec_options.uuid_representation)
# Test basic query
uu = uuid.uuid4()
# Insert as binary subtype 3
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
legacy_opts = coll.codec_options
await coll.insert_one({"uu": uu})
self.assertEqual(uu, (await coll.find_one({"uu": uu}))["uu"]) # type: ignore
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
self.assertEqual(None, await coll.find_one({"uu": uu}))
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
self.assertEqual(uul, (await coll.find_one({"uu": uul}))["uu"]) # type: ignore
# Test count_documents
self.assertEqual(0, await coll.count_documents({"uu": uu}))
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(1, await coll.count_documents({"uu": uu}))
# Test delete
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
await coll.delete_one({"uu": uu})
self.assertEqual(1, await coll.count_documents({}))
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
await coll.delete_one({"uu": uu})
self.assertEqual(0, await coll.count_documents({}))
# Test update_one
await coll.insert_one({"_id": uu, "i": 1})
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
await coll.update_one({"_id": uu}, {"$set": {"i": 2}})
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(1, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
await coll.update_one({"_id": uu}, {"$set": {"i": 2}})
self.assertEqual(2, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
# Test Cursor.distinct
self.assertEqual([2], await coll.find({"_id": uu}).distinct("i"))
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
self.assertEqual([], await coll.find({"_id": uu}).distinct("i"))
# Test findAndModify
self.assertEqual(None, await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(2, (await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"])
self.assertEqual(5, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
# Test command
self.assertEqual(
5,
(
await self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 6}},
query={"_id": uu},
codec_options=legacy_opts,
)
)["value"]["i"],
)
self.assertEqual(
6,
(
await self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 7}},
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
)
)["value"]["i"],
)
async def test_write_concern(self):
c = await self.async_rs_or_single_client(connect=False)
self.assertEqual(WriteConcern(), c.write_concern)
c = await self.async_rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
wc = WriteConcern(w=2, wtimeout=1000)
self.assertEqual(wc, c.write_concern)
# Can we override back to the server default?
db = c.get_database("pymongo_test", write_concern=WriteConcern())
self.assertEqual(db.write_concern, WriteConcern())
db = c.pymongo_test
self.assertEqual(wc, db.write_concern)
coll = db.test
self.assertEqual(wc, coll.write_concern)
cwc = WriteConcern(j=True)
coll = db.get_collection("test", write_concern=cwc)
self.assertEqual(cwc, coll.write_concern)
self.assertEqual(wc, db.write_concern)
async def test_mongo_client(self):
pair = await async_client_context.pair
m = await self.async_rs_or_single_client(w=0)
coll = m.pymongo_test.write_concern_test
await coll.drop()
doc = {"_id": ObjectId()}
await coll.insert_one(doc)
self.assertTrue(await coll.insert_one(doc))
coll = coll.with_options(write_concern=WriteConcern(w=1))
with self.assertRaises(OperationFailure):
await coll.insert_one(doc)
m = await self.async_rs_or_single_client()
coll = m.pymongo_test.write_concern_test
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
self.assertTrue(await new_coll.insert_one(doc))
with self.assertRaises(OperationFailure):
await coll.insert_one(doc)
m = await self.async_rs_or_single_client(
f"mongodb://{pair}/", replicaSet=async_client_context.replica_set_name
)
coll = m.pymongo_test.write_concern_test
with self.assertRaises(OperationFailure):
await coll.insert_one(doc)
m = await self.async_rs_or_single_client(
f"mongodb://{pair}/?w=0", replicaSet=async_client_context.replica_set_name
)
coll = m.pymongo_test.write_concern_test
await coll.insert_one(doc)
# Equality tests
direct = await connected(await self.async_single_client(w=0))
direct2 = await connected(
await self.async_single_client(f"mongodb://{pair}/?w=0", **self.credentials)
)
self.assertEqual(direct, direct2)
self.assertFalse(direct != direct2)
async def test_validate_boolean(self):
await self.db.test.update_one({}, {"$set": {"total": 1}}, upsert=True)
with self.assertRaisesRegex(
TypeError, "upsert must be True or False, was: upsert={'upsert': True}"
):
await self.db.test.update_one({}, {"$set": {"total": 1}}, {"upsert": True}) # type: ignore
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,45 @@
# Copyright 2023-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run the connection logging unified format spec tests."""
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging")
globals().update(
generate_test_classes(
_TEST_PATH,
module=__name__,
)
)
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,39 @@
# Copyright 2021-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test the CRUD unified spec tests."""
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = False
# Location of JSON test specifications.
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified")
# Generate unified tests.
globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
if __name__ == "__main__":
unittest.main()

View File

@ -30,6 +30,7 @@ import uuid
import warnings
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
from test.asynchronous.test_bulk import AsyncBulkTestBase
from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator
from threading import Thread
from typing import Any, Dict, Mapping, Optional
@ -59,7 +60,6 @@ from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
OvertCommandListener,
SpecTestCreator,
TopologyEventListener,
async_wait_until,
camel_to_snake_args,
@ -626,132 +626,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
if _IS_SYNC:
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
class TestSpec(AsyncSpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def setUpClass(cls):
super().setUpClass()
class AsyncTestSpec(AsyncSpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
async def _setup_class(cls):
await super()._setup_class()
def parse_auto_encrypt_opts(self, opts):
"""Parse clientOptions.autoEncryptOpts."""
opts = camel_to_snake_args(opts)
kms_providers = opts["kms_providers"]
if "aws" in kms_providers:
kms_providers["aws"] = AWS_CREDS
if not any(AWS_CREDS.values()):
self.skipTest("AWS environment credentials are not set")
if "awsTemporary" in kms_providers:
kms_providers["aws"] = AWS_TEMP_CREDS
del kms_providers["awsTemporary"]
if not any(AWS_TEMP_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "awsTemporaryNoSessionToken" in kms_providers:
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
del kms_providers["awsTemporaryNoSessionToken"]
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "azure" in kms_providers:
kms_providers["azure"] = AZURE_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("Azure environment credentials are not set")
if "gcp" in kms_providers:
kms_providers["gcp"] = GCP_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("GCP environment credentials are not set")
if "kmip" in kms_providers:
kms_providers["kmip"] = KMIP_CREDS
opts["kms_tls_options"] = KMS_TLS_OPTS
if "key_vault_namespace" not in opts:
opts["key_vault_namespace"] = "keyvault.datakeys"
if "extra_options" in opts:
opts.update(camel_to_snake_args(opts.pop("extra_options")))
def parse_auto_encrypt_opts(self, opts):
"""Parse clientOptions.autoEncryptOpts."""
opts = camel_to_snake_args(opts)
kms_providers = opts["kms_providers"]
if "aws" in kms_providers:
kms_providers["aws"] = AWS_CREDS
if not any(AWS_CREDS.values()):
self.skipTest("AWS environment credentials are not set")
if "awsTemporary" in kms_providers:
kms_providers["aws"] = AWS_TEMP_CREDS
del kms_providers["awsTemporary"]
if not any(AWS_TEMP_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "awsTemporaryNoSessionToken" in kms_providers:
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
del kms_providers["awsTemporaryNoSessionToken"]
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "azure" in kms_providers:
kms_providers["azure"] = AZURE_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("Azure environment credentials are not set")
if "gcp" in kms_providers:
kms_providers["gcp"] = GCP_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("GCP environment credentials are not set")
if "kmip" in kms_providers:
kms_providers["kmip"] = KMIP_CREDS
opts["kms_tls_options"] = KMS_TLS_OPTS
if "key_vault_namespace" not in opts:
opts["key_vault_namespace"] = "keyvault.datakeys"
if "extra_options" in opts:
opts.update(camel_to_snake_args(opts.pop("extra_options")))
opts = dict(opts)
return AutoEncryptionOpts(**opts)
opts = dict(opts)
return AutoEncryptionOpts(**opts)
def parse_client_options(self, opts):
"""Override clientOptions parsing to support autoEncryptOpts."""
encrypt_opts = opts.pop("autoEncryptOpts", None)
if encrypt_opts:
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
def parse_client_options(self, opts):
"""Override clientOptions parsing to support autoEncryptOpts."""
encrypt_opts = opts.pop("autoEncryptOpts", None)
if encrypt_opts:
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
return super().parse_client_options(opts)
return super().parse_client_options(opts)
def get_object_name(self, op):
"""Default object is collection."""
return op.get("object", "collection")
def get_object_name(self, op):
"""Default object is collection."""
return op.get("object", "collection")
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
desc = test["description"].lower()
if (
"timeoutms applied to listcollections to get collection schema" in desc
and sys.platform in ("win32", "darwin")
):
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
if "type=symbol" in desc:
self.skipTest("PyMongo does not support the symbol type")
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
desc = test["description"].lower()
if (
"timeoutms applied to listcollections to get collection schema" in desc
and sys.platform in ("win32", "darwin")
):
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
if "type=symbol" in desc:
self.skipTest("PyMongo does not support the symbol type")
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
self.skipTest("PYTHON-4844 flaky test on async")
def setup_scenario(self, scenario_def):
"""Override a test's setup."""
key_vault_data = scenario_def["key_vault_data"]
encrypted_fields = scenario_def["encrypted_fields"]
json_schema = scenario_def["json_schema"]
data = scenario_def["data"]
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[
"datakeys"
]
coll.delete_many({})
if key_vault_data:
coll.insert_many(key_vault_data)
async def setup_scenario(self, scenario_def):
"""Override a test's setup."""
key_vault_data = scenario_def["key_vault_data"]
encrypted_fields = scenario_def["encrypted_fields"]
json_schema = scenario_def["json_schema"]
data = scenario_def["data"]
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
await coll.delete_many({})
if key_vault_data:
await coll.insert_many(key_vault_data)
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
wc = WriteConcern(w="majority")
kwargs: Dict[str, Any] = {}
if json_schema:
kwargs["validator"] = {"$jsonSchema": json_schema}
kwargs["codec_options"] = OPTS
if not data:
kwargs["write_concern"] = wc
if encrypted_fields:
kwargs["encryptedFields"] = encrypted_fields
db.create_collection(coll_name, **kwargs)
coll = db[coll_name]
if data:
# Load data.
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
await db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
wc = WriteConcern(w="majority")
kwargs: Dict[str, Any] = {}
if json_schema:
kwargs["validator"] = {"$jsonSchema": json_schema}
kwargs["codec_options"] = OPTS
if not data:
kwargs["write_concern"] = wc
if encrypted_fields:
kwargs["encryptedFields"] = encrypted_fields
await db.create_collection(coll_name, **kwargs)
coll = db[coll_name]
if data:
# Load data.
await coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super().allowable_errors(op)
# An updateOne test expects encryption to error when no $ operator
# appears but pymongo raises a client side ValueError in this case.
if op["name"] == "updateOne":
errors += (ValueError,)
return errors
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super().allowable_errors(op)
# An updateOne test expects encryption to error when no $ operator
# appears but pymongo raises a client side ValueError in this case.
if op["name"] == "updateOne":
errors += (ValueError,)
return errors
def create_test(scenario_def, test, name):
@async_client_context.require_test_commands
def run_scenario(self):
self.run_scenario(scenario_def, test)
return run_scenario
async def create_test(scenario_def, test, name):
@async_client_context.require_test_commands
async def run_scenario(self):
await self.run_scenario(scenario_def, test)
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
test_creator.create_tests()
return run_scenario
if _HAVE_PYMONGOCRYPT:
globals().update(
generate_test_classes(
os.path.join(SPEC_PATH, "unified"),
module=__name__,
)
test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
test_creator.create_tests()
if _HAVE_PYMONGOCRYPT:
globals().update(
generate_test_classes(
os.path.join(SPEC_PATH, "unified"),
module=__name__,
)
)
# Prose Tests
ALL_KMS_PROVIDERS = {

File diff suppressed because it is too large Load Diff

View File

@ -15,8 +15,12 @@
"""Utilities for testing driver specs."""
from __future__ import annotations
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
from test.utils import (
@ -24,6 +28,7 @@ from test.utils import (
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
@ -32,11 +37,12 @@ from test.utils import (
)
from typing import List
from bson import ObjectId, decode, encode
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
from pymongo.asynchronous import client_session
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.cursor import AsyncCursor
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
self.stop()
class AsyncSpecTestCreator:
"""Class to create test cases from specifications."""
def __init__(self, create_test, test_class, test_path):
"""Create a TestCreator object.
:Parameters:
- `create_test`: callback that returns a test case. The callback
must accept the following arguments - a dictionary containing the
entire test specification (the `scenario_def`), a dictionary
containing the specification for which the test case will be
generated (the `test_def`).
- `test_class`: the unittest.TestCase class in which to create the
test case.
- `test_path`: path to the directory containing the JSON files with
the test specifications.
"""
self._create_test = create_test
self._test_class = test_class
self.test_path = test_path
def _ensure_min_max_server_version(self, scenario_def, method):
"""Test modifier that enforces a version range for the server on a
test case.
"""
if "minServerVersion" in scenario_def:
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
if min_ver is not None:
method = async_client_context.require_version_min(*min_ver)(method)
if "maxServerVersion" in scenario_def:
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
if max_ver is not None:
method = async_client_context.require_version_max(*max_ver)(method)
if "serverless" in scenario_def:
serverless = scenario_def["serverless"]
if serverless == "require":
serverless_satisfied = async_client_context.serverless
elif serverless == "forbid":
serverless_satisfied = not async_client_context.serverless
else: # unset or "allow"
serverless_satisfied = True
method = unittest.skipUnless(
serverless_satisfied, "Serverless requirement not satisfied"
)(method)
return method
@staticmethod
async def valid_topology(run_on_req):
return await async_client_context.is_topology_type(
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
)
@staticmethod
def min_server_version(run_on_req):
version = run_on_req.get("minServerVersion")
if version:
min_ver = tuple(int(elt) for elt in version.split("."))
return async_client_context.version >= min_ver
return True
@staticmethod
def max_server_version(run_on_req):
version = run_on_req.get("maxServerVersion")
if version:
max_ver = tuple(int(elt) for elt in version.split("."))
return async_client_context.version <= max_ver
return True
@staticmethod
def valid_auth_enabled(run_on_req):
if "authEnabled" in run_on_req:
if run_on_req["authEnabled"]:
return async_client_context.auth_enabled
return not async_client_context.auth_enabled
return True
@staticmethod
def serverless_ok(run_on_req):
serverless = run_on_req["serverless"]
if serverless == "require":
return async_client_context.serverless
elif serverless == "forbid":
return not async_client_context.serverless
else: # unset or "allow"
return True
async def should_run_on(self, scenario_def):
run_on = scenario_def.get("runOn", [])
if not run_on:
# Always run these tests.
return True
for req in run_on:
if (
await self.valid_topology(req)
and self.min_server_version(req)
and self.max_server_version(req)
and self.valid_auth_enabled(req)
and self.serverless_ok(req)
):
return True
return False
def ensure_run_on(self, scenario_def, method):
"""Test modifier that enforces a 'runOn' on a test case."""
async def predicate():
return await self.should_run_on(scenario_def)
return async_client_context._require(predicate, "runOn not satisfied", method)
def tests(self, scenario_def):
"""Allow CMAP spec test to override the location of test."""
return scenario_def["tests"]
async def _create_tests(self):
for dirpath, _, filenames in os.walk(self.test_path):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
# Use tz_aware=False to match how CodecOptions decodes
# dates.
opts = json_util.JSONOptions(tz_aware=False)
scenario_def = ScenarioDict(
json_util.loads(scenario_stream.read(), json_options=opts)
)
test_type = os.path.splitext(filename)[0]
# Construct test from scenario.
for test_def in self.tests(scenario_def):
test_name = "test_{}_{}_{}".format(
dirname,
test_type.replace("-", "_").replace(".", "_"),
str(test_def["description"].replace(" ", "_").replace(".", "_")),
)
new_test = await self._create_test(scenario_def, test_def, test_name)
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
new_test = self.ensure_run_on(scenario_def, new_test)
new_test.__name__ = test_name
setattr(self._test_class, new_test.__name__, new_test)
def create_tests(self):
if _IS_SYNC:
self._create_tests()
else:
asyncio.run(self._create_tests())
class AsyncSpecRunner(AsyncIntegrationTest):
mongos_clients: List
knobs: client_knobs
@ -284,7 +445,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
if object_name == "gridfsbucket":
# Only create the GridFSBucket when we need it (for the gridfs
# retryable reads tests).
obj = GridFSBucket(database, bucket_name=collection.name)
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
else:
objects = {
"client": database.client,
@ -312,7 +473,10 @@ class AsyncSpecRunner(AsyncIntegrationTest):
args.update(arguments)
arguments = args
result = cmd(**dict(arguments))
if not _IS_SYNC and iscoroutinefunction(cmd):
result = await cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addAsyncCleanup(result.close)
@ -588,7 +752,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list()
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.

View File

@ -110,7 +110,7 @@
"listCollections"
],
"blockConnection": true,
"blockTimeMS": 60
"blockTimeMS": 600
}
},
"clientOptions": {
@ -119,7 +119,7 @@
"aws": {}
}
},
"timeoutMS": 50
"timeoutMS": 500
},
"operations": [
{

View File

@ -61,6 +61,13 @@ class TestClientContext(UnitTest):
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
def test_free_threading_is_enabled(self):
if "free-threading build" not in sys.version:
raise SkipTest("this test requires the Python free-threading build")
# If the GIL is enabled then pymongo or one of our deps does not support free-threading.
self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined]
if __name__ == "__main__":
unittest.main()

View File

@ -37,8 +37,11 @@ from pymongo.operations import (
UpdateMany,
UpdateOne,
)
from pymongo.synchronous.helpers import next
from pymongo.write_concern import WriteConcern
_IS_SYNC = True
class TestCollationObject(unittest.TestCase):
def test_constructor(self):
@ -96,8 +99,8 @@ class TestCollation(IntegrationTest):
@classmethod
@client_context.require_connection
def setUpClass(cls):
super().setUpClass()
def _setup_class(cls):
super()._setup_class()
cls.listener = EventListener()
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
cls.db = cls.client.pymongo_test
@ -107,11 +110,11 @@ class TestCollation(IntegrationTest):
warnings.simplefilter("ignore", DeprecationWarning)
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.warn_context.__exit__()
cls.warn_context = None
cls.client.close()
super().tearDownClass()
super()._tearDown_class()
def tearDown(self):
self.listener.reset()

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
@ -23,8 +24,14 @@ sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_logging")
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging")
globals().update(
generate_test_classes(

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring")
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring")
globals().update(

View File

@ -28,10 +28,7 @@ from bson.objectid import ObjectId
from pymongo.errors import OperationFailure
from pymongo.write_concern import WriteConcern
@client_context.require_connection
def setUpModule():
pass
_IS_SYNC = True
class TestCommon(IntegrationTest):
@ -48,12 +45,12 @@ class TestCommon(IntegrationTest):
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
legacy_opts = coll.codec_options
coll.insert_one({"uu": uu})
self.assertEqual(uu, coll.find_one({"uu": uu})["uu"]) # type: ignore
self.assertEqual(uu, (coll.find_one({"uu": uu}))["uu"]) # type: ignore
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
self.assertEqual(None, coll.find_one({"uu": uu}))
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
self.assertEqual(uul, coll.find_one({"uu": uul})["uu"]) # type: ignore
self.assertEqual(uul, (coll.find_one({"uu": uul}))["uu"]) # type: ignore
# Test count_documents
self.assertEqual(0, coll.count_documents({"uu": uu}))
@ -73,9 +70,9 @@ class TestCommon(IntegrationTest):
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
coll.update_one({"_id": uu}, {"$set": {"i": 2}})
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(1, coll.find_one({"_id": uu})["i"]) # type: ignore
self.assertEqual(1, (coll.find_one({"_id": uu}))["i"]) # type: ignore
coll.update_one({"_id": uu}, {"$set": {"i": 2}})
self.assertEqual(2, coll.find_one({"_id": uu})["i"]) # type: ignore
self.assertEqual(2, (coll.find_one({"_id": uu}))["i"]) # type: ignore
# Test Cursor.distinct
self.assertEqual([2], coll.find({"_id": uu}).distinct("i"))
@ -85,27 +82,31 @@ class TestCommon(IntegrationTest):
# Test findAndModify
self.assertEqual(None, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
self.assertEqual(2, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})["i"])
self.assertEqual(5, coll.find_one({"_id": uu})["i"]) # type: ignore
self.assertEqual(2, (coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"])
self.assertEqual(5, (coll.find_one({"_id": uu}))["i"]) # type: ignore
# Test command
self.assertEqual(
5,
self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 6}},
query={"_id": uu},
codec_options=legacy_opts,
(
self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 6}},
query={"_id": uu},
codec_options=legacy_opts,
)
)["value"]["i"],
)
self.assertEqual(
6,
self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 7}},
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
(
self.db.command(
"findAndModify",
"uuid",
update={"$set": {"i": 7}},
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
)
)["value"]["i"],
)
@ -140,20 +141,23 @@ class TestCommon(IntegrationTest):
coll.insert_one(doc)
self.assertTrue(coll.insert_one(doc))
coll = coll.with_options(write_concern=WriteConcern(w=1))
self.assertRaises(OperationFailure, coll.insert_one, doc)
with self.assertRaises(OperationFailure):
coll.insert_one(doc)
m = self.rs_or_single_client()
coll = m.pymongo_test.write_concern_test
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
self.assertTrue(new_coll.insert_one(doc))
self.assertRaises(OperationFailure, coll.insert_one, doc)
with self.assertRaises(OperationFailure):
coll.insert_one(doc)
m = self.rs_or_single_client(
f"mongodb://{pair}/", replicaSet=client_context.replica_set_name
)
coll = m.pymongo_test.write_concern_test
self.assertRaises(OperationFailure, coll.insert_one, doc)
with self.assertRaises(OperationFailure):
coll.insert_one(doc)
m = self.rs_or_single_client(
f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name
)

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_logging")
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging")
globals().update(

View File

@ -25,14 +25,13 @@ from test import IntegrationTest, client_knobs, unittest
from test.pymongo_mocks import DummyMonitor
from test.utils import (
CMAPListener,
SpecTestCreator,
camel_to_snake,
client_context,
get_pool,
get_pools,
wait_until,
)
from test.utils_spec_runner import SpecRunnerThread
from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator
from bson.objectid import ObjectId
from bson.son import SON

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import os
import pathlib
import sys
sys.path[0:0] = [""]
@ -23,11 +24,16 @@ sys.path[0:0] = [""]
from test import unittest
from test.unified_format import generate_test_classes
_IS_SYNC = True
# Location of JSON test specifications.
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified")
if _IS_SYNC:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified")
else:
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified")
# Generate unified tests.
globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
if __name__ == "__main__":
unittest.main()

View File

@ -30,6 +30,7 @@ import uuid
import warnings
from test import IntegrationTest, PyMongoTestCase, client_context
from test.test_bulk import BulkTestBase
from test.utils_spec_runner import SpecRunner, SpecTestCreator
from threading import Thread
from typing import Any, Dict, Mapping, Optional
@ -58,7 +59,6 @@ from test.unified_format import generate_test_classes
from test.utils import (
AllowListEventListener,
OvertCommandListener,
SpecTestCreator,
TopologyEventListener,
camel_to_snake_args,
is_greenthread_patched,
@ -624,130 +624,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
if _IS_SYNC:
# TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
class TestSpec(SpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def setUpClass(cls):
super().setUpClass()
class TestSpec(SpecRunner):
@classmethod
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
def _setup_class(cls):
super()._setup_class()
def parse_auto_encrypt_opts(self, opts):
"""Parse clientOptions.autoEncryptOpts."""
opts = camel_to_snake_args(opts)
kms_providers = opts["kms_providers"]
if "aws" in kms_providers:
kms_providers["aws"] = AWS_CREDS
if not any(AWS_CREDS.values()):
self.skipTest("AWS environment credentials are not set")
if "awsTemporary" in kms_providers:
kms_providers["aws"] = AWS_TEMP_CREDS
del kms_providers["awsTemporary"]
if not any(AWS_TEMP_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "awsTemporaryNoSessionToken" in kms_providers:
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
del kms_providers["awsTemporaryNoSessionToken"]
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "azure" in kms_providers:
kms_providers["azure"] = AZURE_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("Azure environment credentials are not set")
if "gcp" in kms_providers:
kms_providers["gcp"] = GCP_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("GCP environment credentials are not set")
if "kmip" in kms_providers:
kms_providers["kmip"] = KMIP_CREDS
opts["kms_tls_options"] = KMS_TLS_OPTS
if "key_vault_namespace" not in opts:
opts["key_vault_namespace"] = "keyvault.datakeys"
if "extra_options" in opts:
opts.update(camel_to_snake_args(opts.pop("extra_options")))
def parse_auto_encrypt_opts(self, opts):
"""Parse clientOptions.autoEncryptOpts."""
opts = camel_to_snake_args(opts)
kms_providers = opts["kms_providers"]
if "aws" in kms_providers:
kms_providers["aws"] = AWS_CREDS
if not any(AWS_CREDS.values()):
self.skipTest("AWS environment credentials are not set")
if "awsTemporary" in kms_providers:
kms_providers["aws"] = AWS_TEMP_CREDS
del kms_providers["awsTemporary"]
if not any(AWS_TEMP_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "awsTemporaryNoSessionToken" in kms_providers:
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
del kms_providers["awsTemporaryNoSessionToken"]
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
self.skipTest("AWS Temp environment credentials are not set")
if "azure" in kms_providers:
kms_providers["azure"] = AZURE_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("Azure environment credentials are not set")
if "gcp" in kms_providers:
kms_providers["gcp"] = GCP_CREDS
if not any(AZURE_CREDS.values()):
self.skipTest("GCP environment credentials are not set")
if "kmip" in kms_providers:
kms_providers["kmip"] = KMIP_CREDS
opts["kms_tls_options"] = KMS_TLS_OPTS
if "key_vault_namespace" not in opts:
opts["key_vault_namespace"] = "keyvault.datakeys"
if "extra_options" in opts:
opts.update(camel_to_snake_args(opts.pop("extra_options")))
opts = dict(opts)
return AutoEncryptionOpts(**opts)
opts = dict(opts)
return AutoEncryptionOpts(**opts)
def parse_client_options(self, opts):
"""Override clientOptions parsing to support autoEncryptOpts."""
encrypt_opts = opts.pop("autoEncryptOpts", None)
if encrypt_opts:
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
def parse_client_options(self, opts):
"""Override clientOptions parsing to support autoEncryptOpts."""
encrypt_opts = opts.pop("autoEncryptOpts", None)
if encrypt_opts:
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
return super().parse_client_options(opts)
return super().parse_client_options(opts)
def get_object_name(self, op):
"""Default object is collection."""
return op.get("object", "collection")
def get_object_name(self, op):
"""Default object is collection."""
return op.get("object", "collection")
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
desc = test["description"].lower()
if (
"timeoutms applied to listcollections to get collection schema" in desc
and sys.platform in ("win32", "darwin")
):
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
if "type=symbol" in desc:
self.skipTest("PyMongo does not support the symbol type")
def maybe_skip_scenario(self, test):
super().maybe_skip_scenario(test)
desc = test["description"].lower()
if (
"timeoutms applied to listcollections to get collection schema" in desc
and sys.platform in ("win32", "darwin")
):
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
if "type=symbol" in desc:
self.skipTest("PyMongo does not support the symbol type")
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
self.skipTest("PYTHON-4844 flaky test on async")
def setup_scenario(self, scenario_def):
"""Override a test's setup."""
key_vault_data = scenario_def["key_vault_data"]
encrypted_fields = scenario_def["encrypted_fields"]
json_schema = scenario_def["json_schema"]
data = scenario_def["data"]
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
coll.delete_many({})
if key_vault_data:
coll.insert_many(key_vault_data)
def setup_scenario(self, scenario_def):
"""Override a test's setup."""
key_vault_data = scenario_def["key_vault_data"]
encrypted_fields = scenario_def["encrypted_fields"]
json_schema = scenario_def["json_schema"]
data = scenario_def["data"]
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
coll.delete_many({})
if key_vault_data:
coll.insert_many(key_vault_data)
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = client_context.client.get_database(db_name, codec_options=OPTS)
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
wc = WriteConcern(w="majority")
kwargs: Dict[str, Any] = {}
if json_schema:
kwargs["validator"] = {"$jsonSchema": json_schema}
kwargs["codec_options"] = OPTS
if not data:
kwargs["write_concern"] = wc
if encrypted_fields:
kwargs["encryptedFields"] = encrypted_fields
db.create_collection(coll_name, **kwargs)
coll = db[coll_name]
if data:
# Load data.
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
db_name = self.get_scenario_db_name(scenario_def)
coll_name = self.get_scenario_coll_name(scenario_def)
db = client_context.client.get_database(db_name, codec_options=OPTS)
db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
wc = WriteConcern(w="majority")
kwargs: Dict[str, Any] = {}
if json_schema:
kwargs["validator"] = {"$jsonSchema": json_schema}
kwargs["codec_options"] = OPTS
if not data:
kwargs["write_concern"] = wc
if encrypted_fields:
kwargs["encryptedFields"] = encrypted_fields
db.create_collection(coll_name, **kwargs)
coll = db[coll_name]
if data:
# Load data.
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super().allowable_errors(op)
# An updateOne test expects encryption to error when no $ operator
# appears but pymongo raises a client side ValueError in this case.
if op["name"] == "updateOne":
errors += (ValueError,)
return errors
def allowable_errors(self, op):
"""Override expected error classes."""
errors = super().allowable_errors(op)
# An updateOne test expects encryption to error when no $ operator
# appears but pymongo raises a client side ValueError in this case.
if op["name"] == "updateOne":
errors += (ValueError,)
return errors
def create_test(scenario_def, test, name):
@client_context.require_test_commands
def run_scenario(self):
self.run_scenario(scenario_def, test)
return run_scenario
def create_test(scenario_def, test, name):
@client_context.require_test_commands
def run_scenario(self):
self.run_scenario(scenario_def, test)
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
test_creator.create_tests()
return run_scenario
if _HAVE_PYMONGOCRYPT:
globals().update(
generate_test_classes(
os.path.join(SPEC_PATH, "unified"),
module=__name__,
)
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
test_creator.create_tests()
if _HAVE_PYMONGOCRYPT:
globals().update(
generate_test_classes(
os.path.join(SPEC_PATH, "unified"),
module=__name__,
)
)
# Prose Tests
ALL_KMS_PROVIDERS = {

View File

@ -21,11 +21,11 @@ from test import IntegrationTest, client_context, unittest
from test.utils import (
CMAPListener,
OvertCommandListener,
SpecTestCreator,
get_pool,
wait_until,
)
from test.utils_selection_tests import create_topology
from test.utils_spec_runner import SpecTestCreator
from pymongo.common import clean_node
from pymongo.monitoring import ConnectionReadyEvent

View File

@ -18,41 +18,41 @@ https://github.com/mongodb/specifications/blob/master/source/unified-test-format
"""
from __future__ import annotations
import asyncio
import binascii
import collections
import copy
import datetime
import functools
import os
import re
import sys
import time
import traceback
import types
from collections import abc, defaultdict
from asyncio import iscoroutinefunction
from collections import defaultdict
from test import (
IntegrationTest,
client_context,
client_knobs,
unittest,
)
from test.helpers import (
AWS_CREDS,
AWS_CREDS_2,
AZURE_CREDS,
CA_PEM,
CLIENT_PEM,
GCP_CREDS,
KMIP_CREDS,
LOCAL_MASTER_KEY,
client_knobs,
from test.unified_format_shared import (
IS_INTERRUPTED,
KMS_TLS_OPTS,
PLACEHOLDER_MAP,
SKIP_CSOT_TESTS,
EventListenerUtil,
MatchEvaluatorUtil,
coerce_result,
parse_bulk_write_error_result,
parse_bulk_write_result,
parse_client_bulk_write_error_result,
parse_collection_or_database_options,
with_metaclass,
)
from test.utils import (
CMAPListener,
camel_to_snake,
camel_to_snake_args,
get_pool,
parse_collection_options,
parse_spec_options,
prepare_spec_arguments,
snake_to_camel,
@ -60,14 +60,12 @@ from test.utils import (
)
from test.utils_spec_runner import SpecRunnerThread
from test.version import Version
from typing import Any, Dict, List, Mapping, Optional, Union
from typing import Any, Dict, List, Mapping, Optional
import pymongo
from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util
from bson.binary import Binary
from bson import SON, json_util
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.objectid import ObjectId
from bson.regex import RE_TYPE, Regex
from gridfs import GridFSBucket, GridOut
from pymongo import ASCENDING, CursorType, MongoClient, _csot
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
@ -83,55 +81,14 @@ from pymongo.errors import (
PyMongoError,
)
from pymongo.monitoring import (
_SENSITIVE_COMMANDS,
CommandFailedEvent,
CommandListener,
CommandStartedEvent,
CommandSucceededEvent,
ConnectionCheckedInEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutStartedEvent,
ConnectionClosedEvent,
ConnectionCreatedEvent,
ConnectionReadyEvent,
PoolClearedEvent,
PoolClosedEvent,
PoolCreatedEvent,
PoolReadyEvent,
ServerClosedEvent,
ServerDescriptionChangedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatListener,
ServerHeartbeatStartedEvent,
ServerHeartbeatSucceededEvent,
ServerListener,
ServerOpeningEvent,
TopologyClosedEvent,
TopologyDescriptionChangedEvent,
TopologyEvent,
TopologyListener,
TopologyOpenedEvent,
_CommandEvent,
_ConnectionEvent,
_PoolEvent,
_ServerEvent,
_ServerHeartbeatEvent,
)
from pymongo.operations import (
DeleteMany,
DeleteOne,
InsertOne,
ReplaceOne,
SearchIndexModel,
UpdateMany,
UpdateOne,
)
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
from pymongo.results import BulkWriteResult, ClientBulkWriteResult
from pymongo.server_api import ServerApi
from pymongo.server_description import ServerDescription
from pymongo.server_selectors import Selection, writable_server_selector
from pymongo.server_type import SERVER_TYPE
from pymongo.synchronous.change_stream import ChangeStream
@ -140,85 +97,12 @@ from pymongo.synchronous.collection import Collection
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.encryption import ClientEncryption
from pymongo.synchronous.helpers import next
from pymongo.topology_description import TopologyDescription
from pymongo.typings import _Address
from pymongo.write_concern import WriteConcern
SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS")
JSON_OPTS = json_util.JSONOptions(tz_aware=False)
IS_INTERRUPTED = False
KMS_TLS_OPTS = {
"kmip": {
"tlsCAFile": CA_PEM,
"tlsCertificateKeyFile": CLIENT_PEM,
}
}
# Build up a placeholder maps.
PLACEHOLDER_MAP = {}
for provider_name, provider_data in [
("local", {"key": LOCAL_MASTER_KEY}),
("local:name1", {"key": LOCAL_MASTER_KEY}),
("aws", AWS_CREDS),
("aws:name1", AWS_CREDS),
("aws:name2", AWS_CREDS_2),
("azure", AZURE_CREDS),
("azure:name1", AZURE_CREDS),
("gcp", GCP_CREDS),
("gcp:name1", GCP_CREDS),
("kmip", KMIP_CREDS),
("kmip:name1", KMIP_CREDS),
]:
for key, value in provider_data.items():
placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}"
PLACEHOLDER_MAP[placeholder] = value
OIDC_ENV = os.environ.get("OIDC_ENV", "test")
if OIDC_ENV == "test":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"}
elif OIDC_ENV == "azure":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
"ENVIRONMENT": "azure",
"TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
}
elif OIDC_ENV == "gcp":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
"ENVIRONMENT": "gcp",
"TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
}
def interrupt_loop():
global IS_INTERRUPTED
IS_INTERRUPTED = True
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass.
Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
"""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):
def __new__(cls, name, this_bases, d):
# __orig_bases__ is required by PEP 560.
resolved_bases = types.resolve_bases(bases)
if resolved_bases is not bases:
d["__orig_bases__"] = bases
return meta(name, resolved_bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, "temporary_class", (), {})
_IS_SYNC = True
def is_run_on_requirement_satisfied(requirement):
@ -283,77 +167,6 @@ def is_run_on_requirement_satisfied(requirement):
)
def parse_collection_or_database_options(options):
return parse_collection_options(options)
def parse_bulk_write_result(result):
upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids}
return {
"deletedCount": result.deleted_count,
"insertedCount": result.inserted_count,
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": result.upserted_count,
"upsertedIds": upserted_ids,
}
def parse_client_bulk_write_individual(op_type, result):
if op_type == "insert":
return {"insertedId": result.inserted_id}
if op_type == "update":
if result.upserted_id:
return {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedId": result.upserted_id,
}
else:
return {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
}
if op_type == "delete":
return {
"deletedCount": result.deleted_count,
}
def parse_client_bulk_write_result(result):
insert_results, update_results, delete_results = {}, {}, {}
if result.has_verbose_results:
for idx, res in result.insert_results.items():
insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res)
for idx, res in result.update_results.items():
update_results[str(idx)] = parse_client_bulk_write_individual("update", res)
for idx, res in result.delete_results.items():
delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res)
return {
"deletedCount": result.deleted_count,
"insertedCount": result.inserted_count,
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": result.upserted_count,
"insertResults": insert_results,
"updateResults": update_results,
"deleteResults": delete_results,
}
def parse_bulk_write_error_result(error):
write_result = BulkWriteResult(error.details, True)
return parse_bulk_write_result(write_result)
def parse_client_bulk_write_error_result(error):
write_result = error.partial_result
if not write_result:
return None
return parse_client_bulk_write_result(write_result)
class NonLazyCursor:
"""A find cursor proxy that creates the remote cursor when initialized."""
@ -361,7 +174,16 @@ class NonLazyCursor:
self.client = client
self.find_cursor = find_cursor
# Create the server side cursor.
self.first_result = next(find_cursor, None)
self.first_result = None
@classmethod
def create(cls, find_cursor, client):
cursor = cls(find_cursor, client)
try:
cursor.first_result = next(cursor.find_cursor)
except StopIteration:
cursor.first_result = None
return cursor
@property
def alive(self):
@ -382,105 +204,6 @@ class NonLazyCursor:
self.client = None
class EventListenerUtil(
CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener
):
def __init__(
self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map
):
self._event_types = {name.lower() for name in observe_events}
if observe_sensitive_commands:
self._observe_sensitive_commands = True
self._ignore_commands = set(ignore_commands)
else:
self._observe_sensitive_commands = False
self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
self._ignore_commands.add("configurefailpoint")
self._event_mapping = collections.defaultdict(list)
self.entity_map = entity_map
if store_events:
for i in store_events:
id = i["id"]
events = (i.lower() for i in i["events"])
for i in events:
self._event_mapping[i].append(id)
self.entity_map[id] = []
super().__init__()
def get_events(self, event_type):
assert event_type in ("command", "cmap", "sdam", "all"), event_type
if event_type == "all":
return list(self.events)
if event_type == "command":
return [e for e in self.events if isinstance(e, _CommandEvent)]
if event_type == "cmap":
return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))]
return [
e
for e in self.events
if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent))
]
def add_event(self, event):
event_name = type(event).__name__.lower()
if event_name in self._event_types:
super().add_event(event)
for id in self._event_mapping[event_name]:
self.entity_map[id].append(
{
"name": type(event).__name__,
"observedAt": time.time(),
"description": repr(event),
}
)
def _command_event(self, event):
if event.command_name.lower() not in self._ignore_commands:
self.add_event(event)
def started(self, event):
if isinstance(event, CommandStartedEvent):
if event.command == {}:
# Command is redacted. Observe only if flag is set.
if self._observe_sensitive_commands:
self._command_event(event)
else:
self._command_event(event)
else:
self.add_event(event)
def succeeded(self, event):
if isinstance(event, CommandSucceededEvent):
if event.reply == {}:
# Command is redacted. Observe only if flag is set.
if self._observe_sensitive_commands:
self._command_event(event)
else:
self._command_event(event)
else:
self.add_event(event)
def failed(self, event):
if isinstance(event, CommandFailedEvent):
self._command_event(event)
else:
self.add_event(event)
def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None:
self.add_event(event)
def description_changed(
self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent]
) -> None:
self.add_event(event)
def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None:
self.add_event(event)
def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
self.add_event(event)
class EntityMapUtil:
"""Utility class that implements an entity map as per the unified
test format specification.
@ -692,353 +415,12 @@ class EntityMapUtil:
def advance_cluster_times(self) -> None:
"""Manually synchronize entities when desired"""
if not self._cluster_time:
self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime")
self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime")
for entity in self._entities.values():
if isinstance(entity, ClientSession) and self._cluster_time:
entity.advance_cluster_time(self._cluster_time)
binary_types = (Binary, bytes)
long_types = (Int64,)
unicode_type = str
BSON_TYPE_ALIAS_MAP = {
# https://mongodb.com/docs/manual/reference/operator/query/type/
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
"double": (float,),
"string": (str,),
"object": (abc.Mapping,),
"array": (abc.MutableSequence,),
"binData": binary_types,
"undefined": (type(None),),
"objectId": (ObjectId,),
"bool": (bool,),
"date": (datetime.datetime,),
"null": (type(None),),
"regex": (Regex, RE_TYPE),
"dbPointer": (DBRef,),
"javascript": (unicode_type, Code),
"symbol": (unicode_type,),
"javascriptWithScope": (unicode_type, Code),
"int": (int,),
"long": (Int64,),
"decimal": (Decimal128,),
"maxKey": (MaxKey,),
"minKey": (MinKey,),
}
class MatchEvaluatorUtil:
"""Utility class that implements methods for evaluating matches as per
the unified test format specification.
"""
def __init__(self, test_class):
self.test = test_class
def _operation_exists(self, spec, actual, key_to_compare):
if spec is True:
if key_to_compare is None:
assert actual is not None
else:
self.test.assertIn(key_to_compare, actual)
elif spec is False:
if key_to_compare is None:
assert actual is None
else:
self.test.assertNotIn(key_to_compare, actual)
else:
self.test.fail(f"Expected boolean value for $$exists operator, got {spec}")
def __type_alias_to_type(self, alias):
if alias not in BSON_TYPE_ALIAS_MAP:
self.test.fail(f"Unrecognized BSON type alias {alias}")
return BSON_TYPE_ALIAS_MAP[alias]
def _operation_type(self, spec, actual, key_to_compare):
if isinstance(spec, abc.MutableSequence):
permissible_types = tuple(
[t for alias in spec for t in self.__type_alias_to_type(alias)]
)
else:
permissible_types = self.__type_alias_to_type(spec)
value = actual[key_to_compare] if key_to_compare else actual
self.test.assertIsInstance(value, permissible_types)
def _operation_matchesEntity(self, spec, actual, key_to_compare):
expected_entity = self.test.entity_map[spec]
self.test.assertEqual(expected_entity, actual[key_to_compare])
def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
expected = binascii.unhexlify(spec)
value = actual[key_to_compare] if key_to_compare else actual
self.test.assertEqual(value, expected)
def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
if key_to_compare is None and not actual:
# top-level document can be None when unset
return
if key_to_compare not in actual:
# we add a dummy value for the compared key to pass map size check
actual[key_to_compare] = "dummyValue"
return
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
def _operation_sessionLsid(self, spec, actual, key_to_compare):
expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
self.test.assertEqual(expected_lsid, actual[key_to_compare])
def _operation_lte(self, spec, actual, key_to_compare):
if key_to_compare not in actual:
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}")
self.test.assertLessEqual(actual[key_to_compare], spec)
def _operation_matchAsDocument(self, spec, actual, key_to_compare):
self._match_document(spec, json_util.loads(actual[key_to_compare]), False)
def _operation_matchAsRoot(self, spec, actual, key_to_compare):
self._match_document(spec, actual, True)
def _evaluate_special_operation(self, opname, spec, actual, key_to_compare):
method_name = "_operation_{}".format(opname.strip("$"))
try:
method = getattr(self, method_name)
except AttributeError:
self.test.fail(f"Unsupported special matching operator {opname}")
else:
method(spec, actual, key_to_compare)
def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None):
"""Returns True if a special operation is evaluated, False
otherwise. If the ``expectation`` map contains a single key,
value pair we check it for a special operation.
If given, ``key_to_compare`` is assumed to be the key in
``expectation`` whose corresponding value needs to be
evaluated for a possible special operation. ``key_to_compare``
is ignored when ``expectation`` has only one key.
"""
if not isinstance(expectation, abc.Mapping):
return False
is_special_op, opname, spec = False, False, False
if key_to_compare is not None:
if key_to_compare.startswith("$$"):
is_special_op = True
opname = key_to_compare
spec = expectation[key_to_compare]
key_to_compare = None
else:
nested = expectation[key_to_compare]
if isinstance(nested, abc.Mapping) and len(nested) == 1:
opname, spec = next(iter(nested.items()))
if opname.startswith("$$"):
is_special_op = True
elif len(expectation) == 1:
opname, spec = next(iter(expectation.items()))
if opname.startswith("$$"):
is_special_op = True
key_to_compare = None
if is_special_op:
self._evaluate_special_operation(
opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare
)
return True
return False
def _match_document(self, expectation, actual, is_root, test=False):
if self._evaluate_if_special_operation(expectation, actual):
return
self.test.assertIsInstance(actual, abc.Mapping)
for key, value in expectation.items():
if self._evaluate_if_special_operation(expectation, actual, key):
continue
self.test.assertIn(key, actual)
if not self.match_result(value, actual[key], in_recursive_call=True, test=test):
return False
if not is_root:
expected_keys = set(expectation.keys())
for key, value in expectation.items():
if value == {"$$exists": False}:
expected_keys.remove(key)
if test:
self.test.assertEqual(expected_keys, set(actual.keys()))
else:
return set(expected_keys).issubset(set(actual.keys()))
return True
def match_result(self, expectation, actual, in_recursive_call=False, test=True):
if isinstance(expectation, abc.Mapping):
return self._match_document(
expectation, actual, is_root=not in_recursive_call, test=test
)
if isinstance(expectation, abc.MutableSequence):
self.test.assertIsInstance(actual, abc.MutableSequence)
for e, a in zip(expectation, actual):
if isinstance(e, abc.Mapping):
self._match_document(e, a, is_root=not in_recursive_call, test=test)
else:
self.match_result(e, a, in_recursive_call=True, test=test)
return None
# account for flexible numerics in element-wise comparison
if isinstance(expectation, int) or isinstance(expectation, float):
if test:
self.test.assertEqual(expectation, actual)
else:
return expectation == actual
return None
else:
if test:
self.test.assertIsInstance(actual, type(expectation))
self.test.assertEqual(expectation, actual)
else:
return isinstance(actual, type(expectation)) and expectation == actual
return None
def match_server_description(self, actual: ServerDescription, spec: dict) -> None:
for field, expected in spec.items():
field = camel_to_snake(field)
if field == "type":
field = "server_type_name"
self.test.assertEqual(getattr(actual, field), expected)
def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None:
for field, expected in spec.items():
field = camel_to_snake(field)
if field == "type":
field = "topology_type_name"
self.test.assertEqual(getattr(actual, field), expected)
def match_event_fields(self, actual: Any, spec: dict) -> None:
for field, expected in spec.items():
if field == "command" and isinstance(actual, CommandStartedEvent):
command = spec["command"]
if command:
self.match_result(command, actual.command)
continue
if field == "reply" and isinstance(actual, CommandSucceededEvent):
reply = spec["reply"]
if reply:
self.match_result(reply, actual.reply)
continue
if field == "hasServiceId":
if spec["hasServiceId"]:
self.test.assertIsNotNone(actual.service_id)
self.test.assertIsInstance(actual.service_id, ObjectId)
else:
self.test.assertIsNone(actual.service_id)
continue
if field == "hasServerConnectionId":
if spec["hasServerConnectionId"]:
self.test.assertIsNotNone(actual.server_connection_id)
self.test.assertIsInstance(actual.server_connection_id, int)
else:
self.test.assertIsNone(actual.server_connection_id)
continue
if field in ("previousDescription", "newDescription"):
if isinstance(actual, ServerDescriptionChangedEvent):
self.match_server_description(
getattr(actual, camel_to_snake(field)), spec[field]
)
continue
if isinstance(actual, TopologyDescriptionChangedEvent):
self.match_topology_description(
getattr(actual, camel_to_snake(field)), spec[field]
)
continue
if field == "interruptInUseConnections":
field = "interrupt_connections"
else:
field = camel_to_snake(field)
self.test.assertEqual(getattr(actual, field), expected)
def match_event(self, expectation, actual):
name, spec = next(iter(expectation.items()))
if name == "commandStartedEvent":
self.test.assertIsInstance(actual, CommandStartedEvent)
elif name == "commandSucceededEvent":
self.test.assertIsInstance(actual, CommandSucceededEvent)
elif name == "commandFailedEvent":
self.test.assertIsInstance(actual, CommandFailedEvent)
elif name == "poolCreatedEvent":
self.test.assertIsInstance(actual, PoolCreatedEvent)
elif name == "poolReadyEvent":
self.test.assertIsInstance(actual, PoolReadyEvent)
elif name == "poolClearedEvent":
self.test.assertIsInstance(actual, PoolClearedEvent)
self.test.assertIsInstance(actual.interrupt_connections, bool)
elif name == "poolClosedEvent":
self.test.assertIsInstance(actual, PoolClosedEvent)
elif name == "connectionCreatedEvent":
self.test.assertIsInstance(actual, ConnectionCreatedEvent)
elif name == "connectionReadyEvent":
self.test.assertIsInstance(actual, ConnectionReadyEvent)
elif name == "connectionClosedEvent":
self.test.assertIsInstance(actual, ConnectionClosedEvent)
elif name == "connectionCheckOutStartedEvent":
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
elif name == "connectionCheckOutFailedEvent":
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
elif name == "connectionCheckedOutEvent":
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
elif name == "connectionCheckedInEvent":
self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
elif name == "serverDescriptionChangedEvent":
self.test.assertIsInstance(actual, ServerDescriptionChangedEvent)
elif name == "serverHeartbeatStartedEvent":
self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent)
elif name == "serverHeartbeatSucceededEvent":
self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent)
elif name == "serverHeartbeatFailedEvent":
self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent)
elif name == "topologyDescriptionChangedEvent":
self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent)
elif name == "topologyOpeningEvent":
self.test.assertIsInstance(actual, TopologyOpenedEvent)
elif name == "topologyClosedEvent":
self.test.assertIsInstance(actual, TopologyClosedEvent)
else:
raise Exception(f"Unsupported event type {name}")
self.match_event_fields(actual, spec)
def coerce_result(opname, result):
"""Convert a pymongo result into the spec's result format."""
if hasattr(result, "acknowledged") and not result.acknowledged:
return {"acknowledged": False}
if opname == "bulkWrite":
return parse_bulk_write_result(result)
if opname == "clientBulkWrite":
return parse_client_bulk_write_result(result)
if opname == "insertOne":
return {"insertedId": result.inserted_id}
if opname == "insertMany":
return dict(enumerate(result.inserted_ids))
if opname in ("deleteOne", "deleteMany"):
return {"deletedCount": result.deleted_count}
if opname in ("updateOne", "updateMany", "replaceOne"):
value = {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": 0 if result.upserted_id is None else 1,
}
if result.upserted_id is not None:
value["upsertedId"] = result.upserted_id
return value
return result
class UnifiedSpecTestMixinV1(IntegrationTest):
"""Mixin class to run test cases from test specification files.
@ -1090,9 +472,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
db.create_collection(coll_name, write_concern=wc, **opts)
@classmethod
def setUpClass(cls):
def _setup_class(cls):
# super call creates internal client cls.client
super().setUpClass()
super()._setup_class()
# process file-level runOnRequirements
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
if not cls.should_run_on(run_on_spec):
@ -1125,11 +507,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
cls.knobs.enable()
@classmethod
def tearDownClass(cls):
def _tearDown_class(cls):
cls.knobs.disable()
for client in cls.mongos_clients:
client.close()
super().tearDownClass()
super()._tearDown_class()
def setUp(self):
super().setUp()
@ -1391,7 +773,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def __entityOperation_aggregate(self, target, *args, **kwargs):
self.__raise_if_unsupported("aggregate", target, Database, Collection)
return list(target.aggregate(*args, **kwargs))
return (target.aggregate(*args, **kwargs)).to_list()
def _databaseOperation_aggregate(self, target, *args, **kwargs):
return self.__entityOperation_aggregate(target, *args, **kwargs)
@ -1402,13 +784,13 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _collectionOperation_find(self, target, *args, **kwargs):
self.__raise_if_unsupported("find", target, Collection)
find_cursor = target.find(*args, **kwargs)
return list(find_cursor)
return find_cursor.to_list()
def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
self.__raise_if_unsupported("find", target, Collection)
if "filter" not in kwargs:
self.fail('createFindCursor requires a "filter" argument')
cursor = NonLazyCursor(target.find(*args, **kwargs), target.database.client)
cursor = NonLazyCursor.create(target.find(*args, **kwargs), target.database.client)
self.addCleanup(cursor.close)
return cursor
@ -1418,7 +800,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _collectionOperation_listIndexes(self, target, *args, **kwargs):
if "batch_size" in kwargs:
self.skipTest("PyMongo does not support batch_size for list_indexes")
return list(target.list_indexes(*args, **kwargs))
return (target.list_indexes(*args, **kwargs)).to_list()
def _collectionOperation_listIndexNames(self, target, *args, **kwargs):
self.skipTest("PyMongo does not support list_index_names")
@ -1430,7 +812,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _collectionOperation_listSearchIndexes(self, target, *args, **kwargs):
name = kwargs.get("name")
agg_kwargs = kwargs.get("aggregation_options", dict())
return list(target.list_search_indexes(name, **agg_kwargs))
return (target.list_search_indexes(name, **agg_kwargs)).to_list()
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
if client_context.storage_engine == "mmapv1":
@ -1470,7 +852,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
return target.create_data_key(*args, **kwargs)
def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs):
return list(target.get_keys(*args, **kwargs))
return (target.get_keys(*args, **kwargs)).to_list()
def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs):
result = target.delete_key(*args, **kwargs)
@ -1516,7 +898,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
def _bucketOperation_find(
self, target: GridFSBucket, *args: Any, **kwargs: Any
) -> List[GridOut]:
return list(target.find(*args, **kwargs))
return target.find(*args, **kwargs).to_list()
def run_entity_operation(self, spec):
target = self.entity_map[spec["object"]]
@ -1849,7 +1231,10 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
except AttributeError:
self.fail(f"Unsupported special test operation {opname}")
else:
method(spec["arguments"])
if iscoroutinefunction(method):
method(spec["arguments"])
else:
method(spec["arguments"])
def run_operations(self, spec):
for op in spec:
@ -1985,7 +1370,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
if expected_documents:
sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"])
actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)]))
actual_documents = coll.find({}, sort=[("_id", ASCENDING)]).to_list()
self.assertListEqual(sorted_expected_documents, actual_documents)
def run_scenario(self, spec, uri=None):
@ -2040,7 +1425,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
# process initialData
if "initialData" in self.TEST_SPEC:
self.insert_initial_data(self.TEST_SPEC["initialData"])
self._cluster_time = self.client.admin.command("ping").get("$clusterTime")
self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime")
self.entity_map.advance_cluster_times()
if "expectLogMessages" in spec:

View File

@ -0,0 +1,679 @@
# Copyright 2024-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utility functions and constants for the unified test format runner.
https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst
"""
from __future__ import annotations
import binascii
import collections
import datetime
import os
import time
import types
from collections import abc
from test.helpers import (
AWS_CREDS,
AWS_CREDS_2,
AZURE_CREDS,
CA_PEM,
CLIENT_PEM,
GCP_CREDS,
KMIP_CREDS,
LOCAL_MASTER_KEY,
)
from test.utils import CMAPListener, camel_to_snake, parse_collection_options
from typing import Any, Union
from bson import (
RE_TYPE,
Binary,
Code,
DBRef,
Decimal128,
Int64,
MaxKey,
MinKey,
ObjectId,
Regex,
json_util,
)
from pymongo.monitoring import (
_SENSITIVE_COMMANDS,
CommandFailedEvent,
CommandListener,
CommandStartedEvent,
CommandSucceededEvent,
ConnectionCheckedInEvent,
ConnectionCheckedOutEvent,
ConnectionCheckOutFailedEvent,
ConnectionCheckOutStartedEvent,
ConnectionClosedEvent,
ConnectionCreatedEvent,
ConnectionReadyEvent,
PoolClearedEvent,
PoolClosedEvent,
PoolCreatedEvent,
PoolReadyEvent,
ServerClosedEvent,
ServerDescriptionChangedEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatListener,
ServerHeartbeatStartedEvent,
ServerHeartbeatSucceededEvent,
ServerListener,
ServerOpeningEvent,
TopologyClosedEvent,
TopologyDescriptionChangedEvent,
TopologyEvent,
TopologyListener,
TopologyOpenedEvent,
_CommandEvent,
_ConnectionEvent,
_PoolEvent,
_ServerEvent,
_ServerHeartbeatEvent,
)
from pymongo.results import BulkWriteResult
from pymongo.server_description import ServerDescription
from pymongo.topology_description import TopologyDescription
SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS")
JSON_OPTS = json_util.JSONOptions(tz_aware=False)
IS_INTERRUPTED = False
KMS_TLS_OPTS = {
"kmip": {
"tlsCAFile": CA_PEM,
"tlsCertificateKeyFile": CLIENT_PEM,
}
}
# Build up a placeholder maps.
PLACEHOLDER_MAP = {}
for provider_name, provider_data in [
("local", {"key": LOCAL_MASTER_KEY}),
("local:name1", {"key": LOCAL_MASTER_KEY}),
("aws", AWS_CREDS),
("aws:name1", AWS_CREDS),
("aws:name2", AWS_CREDS_2),
("azure", AZURE_CREDS),
("azure:name1", AZURE_CREDS),
("gcp", GCP_CREDS),
("gcp:name1", GCP_CREDS),
("kmip", KMIP_CREDS),
("kmip:name1", KMIP_CREDS),
]:
for key, value in provider_data.items():
placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}"
PLACEHOLDER_MAP[placeholder] = value
OIDC_ENV = os.environ.get("OIDC_ENV", "test")
if OIDC_ENV == "test":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"}
elif OIDC_ENV == "azure":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
"ENVIRONMENT": "azure",
"TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
}
elif OIDC_ENV == "gcp":
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
"ENVIRONMENT": "gcp",
"TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
}
def interrupt_loop():
global IS_INTERRUPTED
IS_INTERRUPTED = True
def with_metaclass(meta, *bases):
"""Create a base class with a metaclass.
Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
"""
# This requires a bit of explanation: the basic idea is to make a dummy
# metaclass for one level of class instantiation that replaces itself with
# the actual metaclass.
class metaclass(type):
def __new__(cls, name, this_bases, d):
# __orig_bases__ is required by PEP 560.
resolved_bases = types.resolve_bases(bases)
if resolved_bases is not bases:
d["__orig_bases__"] = bases
return meta(name, resolved_bases, d)
@classmethod
def __prepare__(cls, name, this_bases):
return meta.__prepare__(name, bases)
return type.__new__(metaclass, "temporary_class", (), {})
def parse_collection_or_database_options(options):
return parse_collection_options(options)
def parse_bulk_write_result(result):
upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids}
return {
"deletedCount": result.deleted_count,
"insertedCount": result.inserted_count,
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": result.upserted_count,
"upsertedIds": upserted_ids,
}
def parse_client_bulk_write_individual(op_type, result):
if op_type == "insert":
return {"insertedId": result.inserted_id}
if op_type == "update":
if result.upserted_id:
return {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedId": result.upserted_id,
}
else:
return {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
}
if op_type == "delete":
return {
"deletedCount": result.deleted_count,
}
def parse_client_bulk_write_result(result):
insert_results, update_results, delete_results = {}, {}, {}
if result.has_verbose_results:
for idx, res in result.insert_results.items():
insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res)
for idx, res in result.update_results.items():
update_results[str(idx)] = parse_client_bulk_write_individual("update", res)
for idx, res in result.delete_results.items():
delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res)
return {
"deletedCount": result.deleted_count,
"insertedCount": result.inserted_count,
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": result.upserted_count,
"insertResults": insert_results,
"updateResults": update_results,
"deleteResults": delete_results,
}
def parse_bulk_write_error_result(error):
write_result = BulkWriteResult(error.details, True)
return parse_bulk_write_result(write_result)
def parse_client_bulk_write_error_result(error):
write_result = error.partial_result
if not write_result:
return None
return parse_client_bulk_write_result(write_result)
class EventListenerUtil(
CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener
):
def __init__(
self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map
):
self._event_types = {name.lower() for name in observe_events}
if observe_sensitive_commands:
self._observe_sensitive_commands = True
self._ignore_commands = set(ignore_commands)
else:
self._observe_sensitive_commands = False
self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
self._ignore_commands.add("configurefailpoint")
self._event_mapping = collections.defaultdict(list)
self.entity_map = entity_map
if store_events:
for i in store_events:
id = i["id"]
events = (i.lower() for i in i["events"])
for i in events:
self._event_mapping[i].append(id)
self.entity_map[id] = []
super().__init__()
def get_events(self, event_type):
assert event_type in ("command", "cmap", "sdam", "all"), event_type
if event_type == "all":
return list(self.events)
if event_type == "command":
return [e for e in self.events if isinstance(e, _CommandEvent)]
if event_type == "cmap":
return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))]
return [
e
for e in self.events
if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent))
]
def add_event(self, event):
event_name = type(event).__name__.lower()
if event_name in self._event_types:
super().add_event(event)
for id in self._event_mapping[event_name]:
self.entity_map[id].append(
{
"name": type(event).__name__,
"observedAt": time.time(),
"description": repr(event),
}
)
def _command_event(self, event):
if event.command_name.lower() not in self._ignore_commands:
self.add_event(event)
def started(self, event):
if isinstance(event, CommandStartedEvent):
if event.command == {}:
# Command is redacted. Observe only if flag is set.
if self._observe_sensitive_commands:
self._command_event(event)
else:
self._command_event(event)
else:
self.add_event(event)
def succeeded(self, event):
if isinstance(event, CommandSucceededEvent):
if event.reply == {}:
# Command is redacted. Observe only if flag is set.
if self._observe_sensitive_commands:
self._command_event(event)
else:
self._command_event(event)
else:
self.add_event(event)
def failed(self, event):
if isinstance(event, CommandFailedEvent):
self._command_event(event)
else:
self.add_event(event)
def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None:
self.add_event(event)
def description_changed(
self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent]
) -> None:
self.add_event(event)
def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None:
self.add_event(event)
def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
self.add_event(event)
binary_types = (Binary, bytes)
long_types = (Int64,)
unicode_type = str
BSON_TYPE_ALIAS_MAP = {
# https://mongodb.com/docs/manual/reference/operator/query/type/
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
"double": (float,),
"string": (str,),
"object": (abc.Mapping,),
"array": (abc.MutableSequence,),
"binData": binary_types,
"undefined": (type(None),),
"objectId": (ObjectId,),
"bool": (bool,),
"date": (datetime.datetime,),
"null": (type(None),),
"regex": (Regex, RE_TYPE),
"dbPointer": (DBRef,),
"javascript": (unicode_type, Code),
"symbol": (unicode_type,),
"javascriptWithScope": (unicode_type, Code),
"int": (int,),
"long": (Int64,),
"decimal": (Decimal128,),
"maxKey": (MaxKey,),
"minKey": (MinKey,),
}
class MatchEvaluatorUtil:
"""Utility class that implements methods for evaluating matches as per
the unified test format specification.
"""
def __init__(self, test_class):
self.test = test_class
def _operation_exists(self, spec, actual, key_to_compare):
if spec is True:
if key_to_compare is None:
assert actual is not None
else:
self.test.assertIn(key_to_compare, actual)
elif spec is False:
if key_to_compare is None:
assert actual is None
else:
self.test.assertNotIn(key_to_compare, actual)
else:
self.test.fail(f"Expected boolean value for $$exists operator, got {spec}")
def __type_alias_to_type(self, alias):
if alias not in BSON_TYPE_ALIAS_MAP:
self.test.fail(f"Unrecognized BSON type alias {alias}")
return BSON_TYPE_ALIAS_MAP[alias]
def _operation_type(self, spec, actual, key_to_compare):
if isinstance(spec, abc.MutableSequence):
permissible_types = tuple(
[t for alias in spec for t in self.__type_alias_to_type(alias)]
)
else:
permissible_types = self.__type_alias_to_type(spec)
value = actual[key_to_compare] if key_to_compare else actual
self.test.assertIsInstance(value, permissible_types)
def _operation_matchesEntity(self, spec, actual, key_to_compare):
expected_entity = self.test.entity_map[spec]
self.test.assertEqual(expected_entity, actual[key_to_compare])
def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
expected = binascii.unhexlify(spec)
value = actual[key_to_compare] if key_to_compare else actual
self.test.assertEqual(value, expected)
def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
if key_to_compare is None and not actual:
# top-level document can be None when unset
return
if key_to_compare not in actual:
# we add a dummy value for the compared key to pass map size check
actual[key_to_compare] = "dummyValue"
return
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
def _operation_sessionLsid(self, spec, actual, key_to_compare):
expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
self.test.assertEqual(expected_lsid, actual[key_to_compare])
def _operation_lte(self, spec, actual, key_to_compare):
if key_to_compare not in actual:
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}")
self.test.assertLessEqual(actual[key_to_compare], spec)
def _operation_matchAsDocument(self, spec, actual, key_to_compare):
self._match_document(spec, json_util.loads(actual[key_to_compare]), False)
def _operation_matchAsRoot(self, spec, actual, key_to_compare):
self._match_document(spec, actual, True)
def _evaluate_special_operation(self, opname, spec, actual, key_to_compare):
method_name = "_operation_{}".format(opname.strip("$"))
try:
method = getattr(self, method_name)
except AttributeError:
self.test.fail(f"Unsupported special matching operator {opname}")
else:
method(spec, actual, key_to_compare)
def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None):
"""Returns True if a special operation is evaluated, False
otherwise. If the ``expectation`` map contains a single key,
value pair we check it for a special operation.
If given, ``key_to_compare`` is assumed to be the key in
``expectation`` whose corresponding value needs to be
evaluated for a possible special operation. ``key_to_compare``
is ignored when ``expectation`` has only one key.
"""
if not isinstance(expectation, abc.Mapping):
return False
is_special_op, opname, spec = False, False, False
if key_to_compare is not None:
if key_to_compare.startswith("$$"):
is_special_op = True
opname = key_to_compare
spec = expectation[key_to_compare]
key_to_compare = None
else:
nested = expectation[key_to_compare]
if isinstance(nested, abc.Mapping) and len(nested) == 1:
opname, spec = next(iter(nested.items()))
if opname.startswith("$$"):
is_special_op = True
elif len(expectation) == 1:
opname, spec = next(iter(expectation.items()))
if opname.startswith("$$"):
is_special_op = True
key_to_compare = None
if is_special_op:
self._evaluate_special_operation(
opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare
)
return True
return False
def _match_document(self, expectation, actual, is_root, test=False):
if self._evaluate_if_special_operation(expectation, actual):
return
self.test.assertIsInstance(actual, abc.Mapping)
for key, value in expectation.items():
if self._evaluate_if_special_operation(expectation, actual, key):
continue
self.test.assertIn(key, actual)
if not self.match_result(value, actual[key], in_recursive_call=True, test=test):
return False
if not is_root:
expected_keys = set(expectation.keys())
for key, value in expectation.items():
if value == {"$$exists": False}:
expected_keys.remove(key)
if test:
self.test.assertEqual(expected_keys, set(actual.keys()))
else:
return set(expected_keys).issubset(set(actual.keys()))
return True
def match_result(self, expectation, actual, in_recursive_call=False, test=True):
if isinstance(expectation, abc.Mapping):
return self._match_document(
expectation, actual, is_root=not in_recursive_call, test=test
)
if isinstance(expectation, abc.MutableSequence):
self.test.assertIsInstance(actual, abc.MutableSequence)
for e, a in zip(expectation, actual):
if isinstance(e, abc.Mapping):
self._match_document(e, a, is_root=not in_recursive_call, test=test)
else:
self.match_result(e, a, in_recursive_call=True, test=test)
return None
# account for flexible numerics in element-wise comparison
if isinstance(expectation, int) or isinstance(expectation, float):
if test:
self.test.assertEqual(expectation, actual)
else:
return expectation == actual
return None
else:
if test:
self.test.assertIsInstance(actual, type(expectation))
self.test.assertEqual(expectation, actual)
else:
return isinstance(actual, type(expectation)) and expectation == actual
return None
def match_server_description(self, actual: ServerDescription, spec: dict) -> None:
for field, expected in spec.items():
field = camel_to_snake(field)
if field == "type":
field = "server_type_name"
self.test.assertEqual(getattr(actual, field), expected)
def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None:
for field, expected in spec.items():
field = camel_to_snake(field)
if field == "type":
field = "topology_type_name"
self.test.assertEqual(getattr(actual, field), expected)
def match_event_fields(self, actual: Any, spec: dict) -> None:
for field, expected in spec.items():
if field == "command" and isinstance(actual, CommandStartedEvent):
command = spec["command"]
if command:
self.match_result(command, actual.command)
continue
if field == "reply" and isinstance(actual, CommandSucceededEvent):
reply = spec["reply"]
if reply:
self.match_result(reply, actual.reply)
continue
if field == "hasServiceId":
if spec["hasServiceId"]:
self.test.assertIsNotNone(actual.service_id)
self.test.assertIsInstance(actual.service_id, ObjectId)
else:
self.test.assertIsNone(actual.service_id)
continue
if field == "hasServerConnectionId":
if spec["hasServerConnectionId"]:
self.test.assertIsNotNone(actual.server_connection_id)
self.test.assertIsInstance(actual.server_connection_id, int)
else:
self.test.assertIsNone(actual.server_connection_id)
continue
if field in ("previousDescription", "newDescription"):
if isinstance(actual, ServerDescriptionChangedEvent):
self.match_server_description(
getattr(actual, camel_to_snake(field)), spec[field]
)
continue
if isinstance(actual, TopologyDescriptionChangedEvent):
self.match_topology_description(
getattr(actual, camel_to_snake(field)), spec[field]
)
continue
if field == "interruptInUseConnections":
field = "interrupt_connections"
else:
field = camel_to_snake(field)
self.test.assertEqual(getattr(actual, field), expected)
def match_event(self, expectation, actual):
name, spec = next(iter(expectation.items()))
if name == "commandStartedEvent":
self.test.assertIsInstance(actual, CommandStartedEvent)
elif name == "commandSucceededEvent":
self.test.assertIsInstance(actual, CommandSucceededEvent)
elif name == "commandFailedEvent":
self.test.assertIsInstance(actual, CommandFailedEvent)
elif name == "poolCreatedEvent":
self.test.assertIsInstance(actual, PoolCreatedEvent)
elif name == "poolReadyEvent":
self.test.assertIsInstance(actual, PoolReadyEvent)
elif name == "poolClearedEvent":
self.test.assertIsInstance(actual, PoolClearedEvent)
self.test.assertIsInstance(actual.interrupt_connections, bool)
elif name == "poolClosedEvent":
self.test.assertIsInstance(actual, PoolClosedEvent)
elif name == "connectionCreatedEvent":
self.test.assertIsInstance(actual, ConnectionCreatedEvent)
elif name == "connectionReadyEvent":
self.test.assertIsInstance(actual, ConnectionReadyEvent)
elif name == "connectionClosedEvent":
self.test.assertIsInstance(actual, ConnectionClosedEvent)
elif name == "connectionCheckOutStartedEvent":
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
elif name == "connectionCheckOutFailedEvent":
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
elif name == "connectionCheckedOutEvent":
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
elif name == "connectionCheckedInEvent":
self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
elif name == "serverDescriptionChangedEvent":
self.test.assertIsInstance(actual, ServerDescriptionChangedEvent)
elif name == "serverHeartbeatStartedEvent":
self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent)
elif name == "serverHeartbeatSucceededEvent":
self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent)
elif name == "serverHeartbeatFailedEvent":
self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent)
elif name == "topologyDescriptionChangedEvent":
self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent)
elif name == "topologyOpeningEvent":
self.test.assertIsInstance(actual, TopologyOpenedEvent)
elif name == "topologyClosedEvent":
self.test.assertIsInstance(actual, TopologyClosedEvent)
else:
raise Exception(f"Unsupported event type {name}")
self.match_event_fields(actual, spec)
def coerce_result(opname, result):
"""Convert a pymongo result into the spec's result format."""
if hasattr(result, "acknowledged") and not result.acknowledged:
return {"acknowledged": False}
if opname == "bulkWrite":
return parse_bulk_write_result(result)
if opname == "clientBulkWrite":
return parse_client_bulk_write_result(result)
if opname == "insertOne":
return {"insertedId": result.inserted_id}
if opname == "insertMany":
return dict(enumerate(result.inserted_ids))
if opname in ("deleteOne", "deleteMany"):
return {"deletedCount": result.deleted_count}
if opname in ("updateOne", "updateMany", "replaceOne"):
value = {
"matchedCount": result.matched_count,
"modifiedCount": result.modified_count,
"upsertedCount": 0 if result.upserted_id is None else 1,
}
if result.upserted_id is not None:
value["upsertedId"] = result.upserted_id
return value
return result

View File

@ -418,153 +418,6 @@ class FunctionCallRecorder:
return len(self._call_list)
class SpecTestCreator:
"""Class to create test cases from specifications."""
def __init__(self, create_test, test_class, test_path):
"""Create a TestCreator object.
:Parameters:
- `create_test`: callback that returns a test case. The callback
must accept the following arguments - a dictionary containing the
entire test specification (the `scenario_def`), a dictionary
containing the specification for which the test case will be
generated (the `test_def`).
- `test_class`: the unittest.TestCase class in which to create the
test case.
- `test_path`: path to the directory containing the JSON files with
the test specifications.
"""
self._create_test = create_test
self._test_class = test_class
self.test_path = test_path
def _ensure_min_max_server_version(self, scenario_def, method):
"""Test modifier that enforces a version range for the server on a
test case.
"""
if "minServerVersion" in scenario_def:
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
if min_ver is not None:
method = client_context.require_version_min(*min_ver)(method)
if "maxServerVersion" in scenario_def:
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
if max_ver is not None:
method = client_context.require_version_max(*max_ver)(method)
if "serverless" in scenario_def:
serverless = scenario_def["serverless"]
if serverless == "require":
serverless_satisfied = client_context.serverless
elif serverless == "forbid":
serverless_satisfied = not client_context.serverless
else: # unset or "allow"
serverless_satisfied = True
method = unittest.skipUnless(
serverless_satisfied, "Serverless requirement not satisfied"
)(method)
return method
@staticmethod
def valid_topology(run_on_req):
return client_context.is_topology_type(
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
)
@staticmethod
def min_server_version(run_on_req):
version = run_on_req.get("minServerVersion")
if version:
min_ver = tuple(int(elt) for elt in version.split("."))
return client_context.version >= min_ver
return True
@staticmethod
def max_server_version(run_on_req):
version = run_on_req.get("maxServerVersion")
if version:
max_ver = tuple(int(elt) for elt in version.split("."))
return client_context.version <= max_ver
return True
@staticmethod
def valid_auth_enabled(run_on_req):
if "authEnabled" in run_on_req:
if run_on_req["authEnabled"]:
return client_context.auth_enabled
return not client_context.auth_enabled
return True
@staticmethod
def serverless_ok(run_on_req):
serverless = run_on_req["serverless"]
if serverless == "require":
return client_context.serverless
elif serverless == "forbid":
return not client_context.serverless
else: # unset or "allow"
return True
def should_run_on(self, scenario_def):
run_on = scenario_def.get("runOn", [])
if not run_on:
# Always run these tests.
return True
for req in run_on:
if (
self.valid_topology(req)
and self.min_server_version(req)
and self.max_server_version(req)
and self.valid_auth_enabled(req)
and self.serverless_ok(req)
):
return True
return False
def ensure_run_on(self, scenario_def, method):
"""Test modifier that enforces a 'runOn' on a test case."""
return client_context._require(
lambda: self.should_run_on(scenario_def), "runOn not satisfied", method
)
def tests(self, scenario_def):
"""Allow CMAP spec test to override the location of test."""
return scenario_def["tests"]
def create_tests(self):
for dirpath, _, filenames in os.walk(self.test_path):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream:
# Use tz_aware=False to match how CodecOptions decodes
# dates.
opts = json_util.JSONOptions(tz_aware=False)
scenario_def = ScenarioDict(
json_util.loads(scenario_stream.read(), json_options=opts)
)
test_type = os.path.splitext(filename)[0]
# Construct test from scenario.
for test_def in self.tests(scenario_def):
test_name = "test_{}_{}_{}".format(
dirname,
test_type.replace("-", "_").replace(".", "_"),
str(test_def["description"].replace(" ", "_").replace(".", "_")),
)
new_test = self._create_test(scenario_def, test_def, test_name)
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
new_test = self.ensure_run_on(scenario_def, new_test)
new_test.__name__ = test_name
setattr(self._test_class, new_test.__name__, new_test)
def ensure_all_connected(client: MongoClient) -> None:
"""Ensure that the client's connection pool has socket connections to all
members of a replica set. Raises ConfigurationError when called with a

View File

@ -15,8 +15,12 @@
"""Utilities for testing driver specs."""
from __future__ import annotations
import asyncio
import functools
import os
import threading
import unittest
from asyncio import iscoroutinefunction
from collections import abc
from test import IntegrationTest, client_context, client_knobs
from test.utils import (
@ -24,6 +28,7 @@ from test.utils import (
CompareType,
EventListener,
OvertCommandListener,
ScenarioDict,
ServerAndTopologyEventListener,
camel_to_snake,
camel_to_snake_args,
@ -32,11 +37,12 @@ from test.utils import (
)
from typing import List
from bson import ObjectId, decode, encode
from bson import ObjectId, decode, encode, json_util
from bson.binary import Binary
from bson.int64 import Int64
from bson.son import SON
from gridfs import GridFSBucket
from gridfs.synchronous.grid_file import GridFSBucket
from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import ReadPreference
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
self.stop()
class SpecTestCreator:
"""Class to create test cases from specifications."""
def __init__(self, create_test, test_class, test_path):
"""Create a TestCreator object.
:Parameters:
- `create_test`: callback that returns a test case. The callback
must accept the following arguments - a dictionary containing the
entire test specification (the `scenario_def`), a dictionary
containing the specification for which the test case will be
generated (the `test_def`).
- `test_class`: the unittest.TestCase class in which to create the
test case.
- `test_path`: path to the directory containing the JSON files with
the test specifications.
"""
self._create_test = create_test
self._test_class = test_class
self.test_path = test_path
def _ensure_min_max_server_version(self, scenario_def, method):
"""Test modifier that enforces a version range for the server on a
test case.
"""
if "minServerVersion" in scenario_def:
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
if min_ver is not None:
method = client_context.require_version_min(*min_ver)(method)
if "maxServerVersion" in scenario_def:
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
if max_ver is not None:
method = client_context.require_version_max(*max_ver)(method)
if "serverless" in scenario_def:
serverless = scenario_def["serverless"]
if serverless == "require":
serverless_satisfied = client_context.serverless
elif serverless == "forbid":
serverless_satisfied = not client_context.serverless
else: # unset or "allow"
serverless_satisfied = True
method = unittest.skipUnless(
serverless_satisfied, "Serverless requirement not satisfied"
)(method)
return method
@staticmethod
def valid_topology(run_on_req):
return client_context.is_topology_type(
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
)
@staticmethod
def min_server_version(run_on_req):
version = run_on_req.get("minServerVersion")
if version:
min_ver = tuple(int(elt) for elt in version.split("."))
return client_context.version >= min_ver
return True
@staticmethod
def max_server_version(run_on_req):
version = run_on_req.get("maxServerVersion")
if version:
max_ver = tuple(int(elt) for elt in version.split("."))
return client_context.version <= max_ver
return True
@staticmethod
def valid_auth_enabled(run_on_req):
if "authEnabled" in run_on_req:
if run_on_req["authEnabled"]:
return client_context.auth_enabled
return not client_context.auth_enabled
return True
@staticmethod
def serverless_ok(run_on_req):
serverless = run_on_req["serverless"]
if serverless == "require":
return client_context.serverless
elif serverless == "forbid":
return not client_context.serverless
else: # unset or "allow"
return True
def should_run_on(self, scenario_def):
run_on = scenario_def.get("runOn", [])
if not run_on:
# Always run these tests.
return True
for req in run_on:
if (
self.valid_topology(req)
and self.min_server_version(req)
and self.max_server_version(req)
and self.valid_auth_enabled(req)
and self.serverless_ok(req)
):
return True
return False
def ensure_run_on(self, scenario_def, method):
"""Test modifier that enforces a 'runOn' on a test case."""
def predicate():
return self.should_run_on(scenario_def)
return client_context._require(predicate, "runOn not satisfied", method)
def tests(self, scenario_def):
"""Allow CMAP spec test to override the location of test."""
return scenario_def["tests"]
def _create_tests(self):
for dirpath, _, filenames in os.walk(self.test_path):
dirname = os.path.split(dirpath)[-1]
for filename in filenames:
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
# Use tz_aware=False to match how CodecOptions decodes
# dates.
opts = json_util.JSONOptions(tz_aware=False)
scenario_def = ScenarioDict(
json_util.loads(scenario_stream.read(), json_options=opts)
)
test_type = os.path.splitext(filename)[0]
# Construct test from scenario.
for test_def in self.tests(scenario_def):
test_name = "test_{}_{}_{}".format(
dirname,
test_type.replace("-", "_").replace(".", "_"),
str(test_def["description"].replace(" ", "_").replace(".", "_")),
)
new_test = self._create_test(scenario_def, test_def, test_name)
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
new_test = self.ensure_run_on(scenario_def, new_test)
new_test.__name__ = test_name
setattr(self._test_class, new_test.__name__, new_test)
def create_tests(self):
if _IS_SYNC:
self._create_tests()
else:
asyncio.run(self._create_tests())
class SpecRunner(IntegrationTest):
mongos_clients: List
knobs: client_knobs
@ -312,7 +473,10 @@ class SpecRunner(IntegrationTest):
args.update(arguments)
arguments = args
result = cmd(**dict(arguments))
if not _IS_SYNC and iscoroutinefunction(cmd):
result = cmd(**dict(arguments))
else:
result = cmd(**dict(arguments))
# Cleanup open change stream cursors.
if name == "watch":
self.addCleanup(result.close)
@ -583,7 +747,7 @@ class SpecRunner(IntegrationTest):
read_preference=ReadPreference.PRIMARY,
read_concern=ReadConcern("local"),
)
actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list()
actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list()
# The expected data needs to be the left hand side here otherwise
# CompareType(Binary) doesn't work.

View File

@ -105,6 +105,8 @@ replacements = {
"PyMongo|c|async": "PyMongo|c",
"AsyncTestGridFile": "TestGridFile",
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
"AsyncTestSpec": "TestSpec",
"AsyncSpecTestCreator": "SpecTestCreator",
"async_set_fail_point": "set_fail_point",
"async_ensure_all_connected": "ensure_all_connected",
"async_repl_set_step_down": "repl_set_step_down",
@ -188,8 +190,14 @@ converted_tests = [
"test_client.py",
"test_client_bulk_write.py",
"test_client_context.py",
"test_collation.py",
"test_collection.py",
"test_command_logging.py",
"test_command_monitoring.py",
"test_common.py",
"test_connection_logging.py",
"test_connections_survive_primary_stepdown_spec.py",
"test_crud_unified.py",
"test_cursor.py",
"test_database.py",
"test_encryption.py",
@ -201,6 +209,7 @@ converted_tests = [
"test_retryable_writes.py",
"test_session.py",
"test_transactions.py",
"unified_format.py",
]
sync_test_files = [