Merge branch 'master' of github.com:mongodb/mongo-python-driver
This commit is contained in:
commit
bf9452b6f9
@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
|
||||
# Use --capture=tee-sys so pytest prints test output inline:
|
||||
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
|
||||
if [ -z "$TEST_SUITES" ]; then
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
|
||||
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
|
||||
else
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
|
||||
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
|
||||
fi
|
||||
else
|
||||
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
|
||||
|
||||
5
.github/workflows/dist.yml
vendored
5
.github/workflows/dist.yml
vendored
@ -67,7 +67,7 @@ jobs:
|
||||
# Note: the default manylinux is manylinux2014
|
||||
run: |
|
||||
python -m pip install -U pip
|
||||
python -m pip install "cibuildwheel>=2.17,<3"
|
||||
python -m pip install "cibuildwheel>=2.20,<3"
|
||||
|
||||
- name: Build wheels
|
||||
env:
|
||||
@ -89,6 +89,9 @@ jobs:
|
||||
ls wheelhouse/*cp310*.whl
|
||||
ls wheelhouse/*cp311*.whl
|
||||
ls wheelhouse/*cp312*.whl
|
||||
ls wheelhouse/*cp313*.whl
|
||||
# Free-threading builds:
|
||||
ls wheelhouse/*cp313t*.whl
|
||||
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
23
.github/workflows/test-python.yml
vendored
23
.github/workflows/test-python.yml
vendored
@ -51,11 +51,18 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-20.04]
|
||||
python-version: ["3.9", "pypy-3.9", "3.13"]
|
||||
python-version: ["3.9", "pypy-3.9", "3.13", "3.13t"]
|
||||
name: CPython ${{ matrix.python-version }}-${{ matrix.os }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Setup Python
|
||||
- if: ${{ matrix.python-version == '3.13t' }}
|
||||
name: Setup free-threaded Python
|
||||
uses: deadsnakes/action@v3.2.0
|
||||
with:
|
||||
python-version: 3.13
|
||||
nogil: true
|
||||
- if: ${{ matrix.python-version != '3.13t' }}
|
||||
name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python-version }}
|
||||
@ -65,9 +72,13 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install -U pip
|
||||
if [ "${{ matrix.python-version }}" == "3.13" ]; then
|
||||
if [[ "${{ matrix.python-version }}" == "3.13" ]]; then
|
||||
pip install --pre cffi setuptools
|
||||
pip install --no-build-isolation hatch
|
||||
elif [[ "${{ matrix.python-version }}" == "3.13t" ]]; then
|
||||
# Hatch can't be installed on 3.13t, use pytest directly.
|
||||
pip install .
|
||||
pip install -r requirements/test.txt
|
||||
else
|
||||
pip install hatch
|
||||
fi
|
||||
@ -77,7 +88,11 @@ jobs:
|
||||
mongodb-version: 6.0
|
||||
- name: Run tests
|
||||
run: |
|
||||
hatch run test:test
|
||||
if [[ "${{ matrix.python-version }}" == "3.13t" ]]; then
|
||||
pytest -v --durations=5 --maxfail=10
|
||||
else
|
||||
hatch run test:test
|
||||
fi
|
||||
|
||||
doctest:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
@ -3184,6 +3184,9 @@ static PyModuleDef_Slot _cbson_slots[] = {
|
||||
{Py_mod_exec, _cbson_exec},
|
||||
#if defined(Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED)
|
||||
{Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED},
|
||||
#endif
|
||||
#if PY_VERSION_HEX >= 0x030D0000
|
||||
{Py_mod_gil, Py_MOD_GIL_NOT_USED},
|
||||
#endif
|
||||
{0, NULL},
|
||||
};
|
||||
|
||||
@ -12,6 +12,8 @@ PyMongo 4.11 brings a number of changes including:
|
||||
|
||||
- Dropped support for Python 3.8.
|
||||
- Dropped support for MongoDB 3.6.
|
||||
- Added support for free-threaded Python with the GIL disabled. For more information see:
|
||||
`Free-threaded CPython <https://docs.python.org/3.13/whatsnew/3.13.html#whatsnew313-free-threaded-cpython>`_.
|
||||
|
||||
Issues Resolved
|
||||
...............
|
||||
|
||||
@ -1022,6 +1022,9 @@ static PyModuleDef_Slot _cmessage_slots[] = {
|
||||
{Py_mod_exec, _cmessage_exec},
|
||||
#ifdef Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED
|
||||
{Py_mod_multiple_interpreters, Py_MOD_MULTIPLE_INTERPRETERS_SUPPORTED},
|
||||
#endif
|
||||
#if PY_VERSION_HEX >= 0x030D0000
|
||||
{Py_mod_gil, Py_MOD_GIL_NOT_USED},
|
||||
#endif
|
||||
{0, NULL},
|
||||
};
|
||||
|
||||
@ -180,10 +180,20 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
async_receive_data_socket,
|
||||
)
|
||||
|
||||
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
# Async raises an OSError instead of returning empty bytes
|
||||
except OSError as err:
|
||||
raise OSError("KMS connection closed") from err
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
finally:
|
||||
|
||||
@ -130,7 +130,7 @@ if sys.platform != "win32":
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
@ -145,6 +145,9 @@ if sys.platform != "win32":
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
total_read += read
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = conn.fileno()
|
||||
@ -275,6 +278,28 @@ async def async_receive_data(
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def async_receive_data_socket(
|
||||
sock: Union[socket.socket, _sslConn], length: int
|
||||
) -> memoryview:
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
return await asyncio.wait_for(
|
||||
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
|
||||
except asyncio.TimeoutError as err:
|
||||
raise socket.timeout("timed out") from err
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
|
||||
@ -180,10 +180,20 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
receive_data_socket,
|
||||
)
|
||||
|
||||
data = receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
# Async raises an OSError instead of returning empty bytes
|
||||
except OSError as err:
|
||||
raise OSError("KMS connection closed") from err
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
finally:
|
||||
|
||||
@ -236,6 +236,8 @@ partial_branches = ["if (.*and +)*not _use_c( and.*)*:"]
|
||||
directory = "htmlcov"
|
||||
|
||||
[tool.cibuildwheel]
|
||||
# Enable free-threaded support
|
||||
free-threaded-support = true
|
||||
skip = "pp* *-musllinux*"
|
||||
build-frontend = "build"
|
||||
test-command = "python {project}/tools/fail_if_no_c.py"
|
||||
|
||||
@ -464,11 +464,12 @@ class ClientContext:
|
||||
if not self.connected:
|
||||
pair = self.pair
|
||||
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
||||
if iscoroutinefunction(condition) and condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
if iscoroutinefunction(condition):
|
||||
if condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
elif condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -466,11 +466,12 @@ class AsyncClientContext:
|
||||
if not self.connected:
|
||||
pair = await self.pair
|
||||
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
||||
if iscoroutinefunction(condition) and await condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
if iscoroutinefunction(condition):
|
||||
if await condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
elif condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
@ -61,6 +61,13 @@ class TestAsyncClientContext(AsyncUnitTest):
|
||||
|
||||
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
|
||||
|
||||
def test_free_threading_is_enabled(self):
|
||||
if "free-threading build" not in sys.version:
|
||||
raise SkipTest("this test requires the Python free-threading build")
|
||||
|
||||
# If the GIL is enabled then pymongo or one of our deps does not support free-threading.
|
||||
self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
290
test/asynchronous/test_collation.py
Normal file
290
test/asynchronous/test_collation.py
Normal file
@ -0,0 +1,290 @@
|
||||
# Copyright 2016-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the collation module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import warnings
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest
|
||||
from test.utils import EventListener
|
||||
from typing import Any
|
||||
|
||||
from pymongo.asynchronous.helpers import anext
|
||||
from pymongo.collation import (
|
||||
Collation,
|
||||
CollationAlternate,
|
||||
CollationCaseFirst,
|
||||
CollationMaxVariable,
|
||||
CollationStrength,
|
||||
)
|
||||
from pymongo.errors import ConfigurationError
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
IndexModel,
|
||||
ReplaceOne,
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestCollationObject(unittest.TestCase):
|
||||
def test_constructor(self):
|
||||
self.assertRaises(TypeError, Collation, locale=42)
|
||||
# Fill in a locale to test the other options.
|
||||
_Collation = functools.partial(Collation, "en_US")
|
||||
# No error.
|
||||
_Collation(caseFirst=CollationCaseFirst.UPPER)
|
||||
self.assertRaises(TypeError, _Collation, caseLevel="true")
|
||||
self.assertRaises(ValueError, _Collation, strength="six")
|
||||
self.assertRaises(TypeError, _Collation, numericOrdering="true")
|
||||
self.assertRaises(TypeError, _Collation, alternate=5)
|
||||
self.assertRaises(TypeError, _Collation, maxVariable=2)
|
||||
self.assertRaises(TypeError, _Collation, normalization="false")
|
||||
self.assertRaises(TypeError, _Collation, backwards="true")
|
||||
|
||||
# No errors.
|
||||
Collation("en_US", future_option="bar", another_option=42)
|
||||
collation = Collation(
|
||||
"en_US",
|
||||
caseLevel=True,
|
||||
caseFirst=CollationCaseFirst.UPPER,
|
||||
strength=CollationStrength.QUATERNARY,
|
||||
numericOrdering=True,
|
||||
alternate=CollationAlternate.SHIFTED,
|
||||
maxVariable=CollationMaxVariable.SPACE,
|
||||
normalization=True,
|
||||
backwards=True,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{
|
||||
"locale": "en_US",
|
||||
"caseLevel": True,
|
||||
"caseFirst": "upper",
|
||||
"strength": 4,
|
||||
"numericOrdering": True,
|
||||
"alternate": "shifted",
|
||||
"maxVariable": "space",
|
||||
"normalization": True,
|
||||
"backwards": True,
|
||||
},
|
||||
collation.document,
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
{"locale": "en_US", "backwards": True}, Collation("en_US", backwards=True).document
|
||||
)
|
||||
|
||||
|
||||
class TestCollation(AsyncIntegrationTest):
|
||||
listener: EventListener
|
||||
warn_context: Any
|
||||
collation: Collation
|
||||
|
||||
@classmethod
|
||||
@async_client_context.require_connection
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
cls.listener = EventListener()
|
||||
cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
cls.collation = Collation("en_US")
|
||||
cls.warn_context = warnings.catch_warnings()
|
||||
cls.warn_context.__enter__()
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
@classmethod
|
||||
async def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
await cls.client.close()
|
||||
await super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
self.listener.reset()
|
||||
super().tearDown()
|
||||
|
||||
def last_command_started(self):
|
||||
return self.listener.started_events[-1].command
|
||||
|
||||
def assertCollationInLastCommand(self):
|
||||
self.assertEqual(self.collation.document, self.last_command_started()["collation"])
|
||||
|
||||
async def test_create_collection(self):
|
||||
await self.db.test.drop()
|
||||
await self.db.create_collection("test", collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
# Test passing collation as a dict as well.
|
||||
await self.db.test.drop()
|
||||
self.listener.reset()
|
||||
await self.db.create_collection("test", collation=self.collation.document)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
def test_index_model(self):
|
||||
model = IndexModel([("a", 1), ("b", -1)], collation=self.collation)
|
||||
self.assertEqual(self.collation.document, model.document["collation"])
|
||||
|
||||
async def test_create_index(self):
|
||||
await self.db.test.create_index("foo", collation=self.collation)
|
||||
ci_cmd = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, ci_cmd["indexes"][0]["collation"])
|
||||
|
||||
async def test_aggregate(self):
|
||||
await self.db.test.aggregate([{"$group": {"_id": 42}}], collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
async def test_count_documents(self):
|
||||
await self.db.test.count_documents({}, collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
async def test_distinct(self):
|
||||
await self.db.test.distinct("foo", collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.find(collation=self.collation).distinct("foo")
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
async def test_find_command(self):
|
||||
await self.db.test.insert_one({"is this thing on?": True})
|
||||
self.listener.reset()
|
||||
await anext(self.db.test.find(collation=self.collation))
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
async def test_explain_command(self):
|
||||
self.listener.reset()
|
||||
await self.db.test.find(collation=self.collation).explain()
|
||||
# The collation should be part of the explained command.
|
||||
self.assertEqual(
|
||||
self.collation.document, self.last_command_started()["explain"]["collation"]
|
||||
)
|
||||
|
||||
async def test_delete(self):
|
||||
await self.db.test.delete_one({"foo": 42}, collation=self.collation)
|
||||
command = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, command["deletes"][0]["collation"])
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.delete_many({"foo": 42}, collation=self.collation)
|
||||
command = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, command["deletes"][0]["collation"])
|
||||
|
||||
async def test_update(self):
|
||||
await self.db.test.replace_one({"foo": 42}, {"foo": 43}, collation=self.collation)
|
||||
command = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.update_one({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation)
|
||||
command = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.update_many({"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation)
|
||||
command = self.listener.started_events[0].command
|
||||
self.assertEqual(self.collation.document, command["updates"][0]["collation"])
|
||||
|
||||
async def test_find_and(self):
|
||||
await self.db.test.find_one_and_delete({"foo": 42}, collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.find_one_and_update(
|
||||
{"foo": 42}, {"$set": {"foo": 43}}, collation=self.collation
|
||||
)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
self.listener.reset()
|
||||
await self.db.test.find_one_and_replace({"foo": 42}, {"foo": 43}, collation=self.collation)
|
||||
self.assertCollationInLastCommand()
|
||||
|
||||
async def test_bulk_write(self):
|
||||
await self.db.test.collection.bulk_write(
|
||||
[
|
||||
DeleteOne({"noCollation": 42}),
|
||||
DeleteMany({"noCollation": 42}),
|
||||
DeleteOne({"foo": 42}, collation=self.collation),
|
||||
DeleteMany({"foo": 42}, collation=self.collation),
|
||||
ReplaceOne({"noCollation": 24}, {"bar": 42}),
|
||||
UpdateOne({"noCollation": 84}, {"$set": {"bar": 10}}, upsert=True),
|
||||
UpdateMany({"noCollation": 45}, {"$set": {"bar": 42}}),
|
||||
ReplaceOne({"foo": 24}, {"foo": 42}, collation=self.collation),
|
||||
UpdateOne(
|
||||
{"foo": 84}, {"$set": {"foo": 10}}, upsert=True, collation=self.collation
|
||||
),
|
||||
UpdateMany({"foo": 45}, {"$set": {"foo": 42}}, collation=self.collation),
|
||||
]
|
||||
)
|
||||
|
||||
delete_cmd = self.listener.started_events[0].command
|
||||
update_cmd = self.listener.started_events[1].command
|
||||
|
||||
def check_ops(ops):
|
||||
for op in ops:
|
||||
if "noCollation" in op["q"]:
|
||||
self.assertNotIn("collation", op)
|
||||
else:
|
||||
self.assertEqual(self.collation.document, op["collation"])
|
||||
|
||||
check_ops(delete_cmd["deletes"])
|
||||
check_ops(update_cmd["updates"])
|
||||
|
||||
async def test_indexes_same_keys_different_collations(self):
|
||||
await self.db.test.drop()
|
||||
usa_collation = Collation("en_US")
|
||||
ja_collation = Collation("ja")
|
||||
await self.db.test.create_indexes(
|
||||
[
|
||||
IndexModel("fieldname", collation=usa_collation),
|
||||
IndexModel("fieldname", name="japanese_version", collation=ja_collation),
|
||||
IndexModel("fieldname", name="simple"),
|
||||
]
|
||||
)
|
||||
indexes = await self.db.test.index_information()
|
||||
self.assertEqual(
|
||||
usa_collation.document["locale"], indexes["fieldname_1"]["collation"]["locale"]
|
||||
)
|
||||
self.assertEqual(
|
||||
ja_collation.document["locale"], indexes["japanese_version"]["collation"]["locale"]
|
||||
)
|
||||
self.assertNotIn("collation", indexes["simple"])
|
||||
await self.db.test.drop_index("fieldname_1")
|
||||
indexes = await self.db.test.index_information()
|
||||
self.assertIn("japanese_version", indexes)
|
||||
self.assertIn("simple", indexes)
|
||||
self.assertNotIn("fieldname", indexes)
|
||||
|
||||
async def test_unacknowledged_write(self):
|
||||
unacknowledged = WriteConcern(w=0)
|
||||
collection = self.db.get_collection("test", write_concern=unacknowledged)
|
||||
with self.assertRaises(ConfigurationError):
|
||||
await collection.update_one(
|
||||
{"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation
|
||||
)
|
||||
update_one = UpdateOne(
|
||||
{"hello": "world"}, {"$set": {"hello": "moon"}}, collation=self.collation
|
||||
)
|
||||
with self.assertRaises(ConfigurationError):
|
||||
await collection.bulk_write([update_one])
|
||||
|
||||
async def test_cursor_collation(self):
|
||||
await self.db.test.insert_one({"hello": "world"})
|
||||
await anext(self.db.test.find().collation(self.collation))
|
||||
self.assertCollationInLastCommand()
|
||||
44
test/asynchronous/test_command_logging.py
Normal file
44
test/asynchronous/test_command_logging.py
Normal file
@ -0,0 +1,44 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the command monitoring unified format spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
_TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
45
test/asynchronous/test_command_monitoring.py
Normal file
45
test/asynchronous/test_command_monitoring.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2015-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the command monitoring unified format spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.asynchronous.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
_TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
185
test/asynchronous/test_common.py
Normal file
185
test/asynchronous/test_common.py
Normal file
@ -0,0 +1,185 @@
|
||||
# Copyright 2011-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the pymongo common module."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected, unittest
|
||||
|
||||
from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation
|
||||
from bson.codec_options import CodecOptions
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
|
||||
class TestCommon(AsyncIntegrationTest):
|
||||
async def test_uuid_representation(self):
|
||||
coll = self.db.uuid
|
||||
await coll.drop()
|
||||
|
||||
# Test property
|
||||
self.assertEqual(UuidRepresentation.UNSPECIFIED, coll.codec_options.uuid_representation)
|
||||
|
||||
# Test basic query
|
||||
uu = uuid.uuid4()
|
||||
# Insert as binary subtype 3
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
legacy_opts = coll.codec_options
|
||||
await coll.insert_one({"uu": uu})
|
||||
self.assertEqual(uu, (await coll.find_one({"uu": uu}))["uu"]) # type: ignore
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
|
||||
self.assertEqual(None, await coll.find_one({"uu": uu}))
|
||||
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
|
||||
self.assertEqual(uul, (await coll.find_one({"uu": uul}))["uu"]) # type: ignore
|
||||
|
||||
# Test count_documents
|
||||
self.assertEqual(0, await coll.count_documents({"uu": uu}))
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(1, await coll.count_documents({"uu": uu}))
|
||||
|
||||
# Test delete
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
await coll.delete_one({"uu": uu})
|
||||
self.assertEqual(1, await coll.count_documents({}))
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
await coll.delete_one({"uu": uu})
|
||||
self.assertEqual(0, await coll.count_documents({}))
|
||||
|
||||
# Test update_one
|
||||
await coll.insert_one({"_id": uu, "i": 1})
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
await coll.update_one({"_id": uu}, {"$set": {"i": 2}})
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(1, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
await coll.update_one({"_id": uu}, {"$set": {"i": 2}})
|
||||
self.assertEqual(2, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
|
||||
# Test Cursor.distinct
|
||||
self.assertEqual([2], await coll.find({"_id": uu}).distinct("i"))
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
self.assertEqual([], await coll.find({"_id": uu}).distinct("i"))
|
||||
|
||||
# Test findAndModify
|
||||
self.assertEqual(None, await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(2, (await coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"])
|
||||
self.assertEqual(5, (await coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
|
||||
# Test command
|
||||
self.assertEqual(
|
||||
5,
|
||||
(
|
||||
await self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 6}},
|
||||
query={"_id": uu},
|
||||
codec_options=legacy_opts,
|
||||
)
|
||||
)["value"]["i"],
|
||||
)
|
||||
self.assertEqual(
|
||||
6,
|
||||
(
|
||||
await self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 7}},
|
||||
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
|
||||
)
|
||||
)["value"]["i"],
|
||||
)
|
||||
|
||||
async def test_write_concern(self):
|
||||
c = await self.async_rs_or_single_client(connect=False)
|
||||
self.assertEqual(WriteConcern(), c.write_concern)
|
||||
|
||||
c = await self.async_rs_or_single_client(connect=False, w=2, wTimeoutMS=1000)
|
||||
wc = WriteConcern(w=2, wtimeout=1000)
|
||||
self.assertEqual(wc, c.write_concern)
|
||||
|
||||
# Can we override back to the server default?
|
||||
db = c.get_database("pymongo_test", write_concern=WriteConcern())
|
||||
self.assertEqual(db.write_concern, WriteConcern())
|
||||
|
||||
db = c.pymongo_test
|
||||
self.assertEqual(wc, db.write_concern)
|
||||
coll = db.test
|
||||
self.assertEqual(wc, coll.write_concern)
|
||||
|
||||
cwc = WriteConcern(j=True)
|
||||
coll = db.get_collection("test", write_concern=cwc)
|
||||
self.assertEqual(cwc, coll.write_concern)
|
||||
self.assertEqual(wc, db.write_concern)
|
||||
|
||||
async def test_mongo_client(self):
|
||||
pair = await async_client_context.pair
|
||||
m = await self.async_rs_or_single_client(w=0)
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
await coll.drop()
|
||||
doc = {"_id": ObjectId()}
|
||||
await coll.insert_one(doc)
|
||||
self.assertTrue(await coll.insert_one(doc))
|
||||
coll = coll.with_options(write_concern=WriteConcern(w=1))
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.insert_one(doc)
|
||||
|
||||
m = await self.async_rs_or_single_client()
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
|
||||
self.assertTrue(await new_coll.insert_one(doc))
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.insert_one(doc)
|
||||
|
||||
m = await self.async_rs_or_single_client(
|
||||
f"mongodb://{pair}/", replicaSet=async_client_context.replica_set_name
|
||||
)
|
||||
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
with self.assertRaises(OperationFailure):
|
||||
await coll.insert_one(doc)
|
||||
m = await self.async_rs_or_single_client(
|
||||
f"mongodb://{pair}/?w=0", replicaSet=async_client_context.replica_set_name
|
||||
)
|
||||
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
await coll.insert_one(doc)
|
||||
|
||||
# Equality tests
|
||||
direct = await connected(await self.async_single_client(w=0))
|
||||
direct2 = await connected(
|
||||
await self.async_single_client(f"mongodb://{pair}/?w=0", **self.credentials)
|
||||
)
|
||||
self.assertEqual(direct, direct2)
|
||||
self.assertFalse(direct != direct2)
|
||||
|
||||
async def test_validate_boolean(self):
|
||||
await self.db.test.update_one({}, {"$set": {"total": 1}}, upsert=True)
|
||||
with self.assertRaisesRegex(
|
||||
TypeError, "upsert must be True or False, was: upsert={'upsert': True}"
|
||||
):
|
||||
await self.db.test.update_one({}, {"$set": {"total": 1}}, {"upsert": True}) # type: ignore
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
45
test/asynchronous/test_connection_logging.py
Normal file
45
test/asynchronous/test_connection_logging.py
Normal file
@ -0,0 +1,45 @@
|
||||
# Copyright 2023-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Run the connection logging unified format spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
_TEST_PATH,
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
39
test/asynchronous/test_crud_unified.py
Normal file
39
test/asynchronous/test_crud_unified.py
Normal file
@ -0,0 +1,39 @@
|
||||
# Copyright 2021-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test the CRUD unified spec tests."""
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
# Location of JSON test specifications.
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -30,6 +30,7 @@ import uuid
|
||||
import warnings
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
|
||||
from test.asynchronous.test_bulk import AsyncBulkTestBase
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
@ -59,7 +60,6 @@ from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
TopologyEventListener,
|
||||
async_wait_until,
|
||||
camel_to_snake_args,
|
||||
@ -626,132 +626,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
|
||||
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(AsyncSpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
class AsyncTestSpec(AsyncSpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
|
||||
return super().parse_client_options(opts)
|
||||
return super().parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[
|
||||
"datakeys"
|
||||
]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
async def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
await coll.delete_many({})
|
||||
if key_vault_data:
|
||||
await coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
await db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
await db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
await coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@async_client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
async def create_test(scenario_def, test, name):
|
||||
@async_client_context.require_test_commands
|
||||
async def run_scenario(self):
|
||||
await self.run_scenario(scenario_def, test)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
return run_scenario
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
|
||||
test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
# Prose Tests
|
||||
ALL_KMS_PROVIDERS = {
|
||||
|
||||
1573
test/asynchronous/unified_format.py
Normal file
1573
test/asynchronous/unified_format.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -15,8 +15,12 @@
|
||||
"""Utilities for testing driver specs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
|
||||
from test.utils import (
|
||||
@ -24,6 +28,7 @@ from test.utils import (
|
||||
CompareType,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
ScenarioDict,
|
||||
ServerAndTopologyEventListener,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
@ -32,11 +37,12 @@ from test.utils import (
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from bson import ObjectId, decode, encode
|
||||
from bson import ObjectId, decode, encode, json_util
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from gridfs import GridFSBucket
|
||||
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
|
||||
from pymongo.asynchronous import client_session
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
|
||||
self.stop()
|
||||
|
||||
|
||||
class AsyncSpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = async_client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = async_client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = async_client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not async_client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
async def valid_topology(run_on_req):
|
||||
return await async_client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return async_client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return async_client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return async_client_context.auth_enabled
|
||||
return not async_client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return async_client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not async_client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
async def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
await self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
|
||||
async def predicate():
|
||||
return await self.should_run_on(scenario_def)
|
||||
|
||||
return async_client_context._require(predicate, "runOn not satisfied", method)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
async def _create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = await self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
def create_tests(self):
|
||||
if _IS_SYNC:
|
||||
self._create_tests()
|
||||
else:
|
||||
asyncio.run(self._create_tests())
|
||||
|
||||
|
||||
class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
mongos_clients: List
|
||||
knobs: client_knobs
|
||||
@ -284,7 +445,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
if object_name == "gridfsbucket":
|
||||
# Only create the GridFSBucket when we need it (for the gridfs
|
||||
# retryable reads tests).
|
||||
obj = GridFSBucket(database, bucket_name=collection.name)
|
||||
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
|
||||
else:
|
||||
objects = {
|
||||
"client": database.client,
|
||||
@ -312,7 +473,10 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
args.update(arguments)
|
||||
arguments = args
|
||||
|
||||
result = cmd(**dict(arguments))
|
||||
if not _IS_SYNC and iscoroutinefunction(cmd):
|
||||
result = await cmd(**dict(arguments))
|
||||
else:
|
||||
result = cmd(**dict(arguments))
|
||||
# Cleanup open change stream cursors.
|
||||
if name == "watch":
|
||||
self.addAsyncCleanup(result.close)
|
||||
@ -588,7 +752,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
read_concern=ReadConcern("local"),
|
||||
)
|
||||
actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list()
|
||||
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
|
||||
|
||||
# The expected data needs to be the left hand side here otherwise
|
||||
# CompareType(Binary) doesn't work.
|
||||
|
||||
@ -110,7 +110,7 @@
|
||||
"listCollections"
|
||||
],
|
||||
"blockConnection": true,
|
||||
"blockTimeMS": 60
|
||||
"blockTimeMS": 600
|
||||
}
|
||||
},
|
||||
"clientOptions": {
|
||||
@ -119,7 +119,7 @@
|
||||
"aws": {}
|
||||
}
|
||||
},
|
||||
"timeoutMS": 50
|
||||
"timeoutMS": 500
|
||||
},
|
||||
"operations": [
|
||||
{
|
||||
|
||||
@ -61,6 +61,13 @@ class TestClientContext(UnitTest):
|
||||
|
||||
self.assertEqual(sys.getdefaultencoding(), os.environ["SETDEFAULTENCODING"])
|
||||
|
||||
def test_free_threading_is_enabled(self):
|
||||
if "free-threading build" not in sys.version:
|
||||
raise SkipTest("this test requires the Python free-threading build")
|
||||
|
||||
# If the GIL is enabled then pymongo or one of our deps does not support free-threading.
|
||||
self.assertFalse(sys._is_gil_enabled()) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -37,8 +37,11 @@ from pymongo.operations import (
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestCollationObject(unittest.TestCase):
|
||||
def test_constructor(self):
|
||||
@ -96,8 +99,8 @@ class TestCollation(IntegrationTest):
|
||||
|
||||
@classmethod
|
||||
@client_context.require_connection
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
cls.listener = EventListener()
|
||||
cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener])
|
||||
cls.db = cls.client.pymongo_test
|
||||
@ -107,11 +110,11 @@ class TestCollation(IntegrationTest):
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.warn_context.__exit__()
|
||||
cls.warn_context = None
|
||||
cls.client.close()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
def tearDown(self):
|
||||
self.listener.reset()
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -23,8 +24,14 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_logging")
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "command_monitoring")
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring")
|
||||
|
||||
|
||||
globals().update(
|
||||
|
||||
@ -28,10 +28,7 @@ from bson.objectid import ObjectId
|
||||
from pymongo.errors import OperationFailure
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
|
||||
@client_context.require_connection
|
||||
def setUpModule():
|
||||
pass
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
class TestCommon(IntegrationTest):
|
||||
@ -48,12 +45,12 @@ class TestCommon(IntegrationTest):
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
legacy_opts = coll.codec_options
|
||||
coll.insert_one({"uu": uu})
|
||||
self.assertEqual(uu, coll.find_one({"uu": uu})["uu"]) # type: ignore
|
||||
self.assertEqual(uu, (coll.find_one({"uu": uu}))["uu"]) # type: ignore
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
self.assertEqual(STANDARD, coll.codec_options.uuid_representation)
|
||||
self.assertEqual(None, coll.find_one({"uu": uu}))
|
||||
uul = Binary.from_uuid(uu, PYTHON_LEGACY)
|
||||
self.assertEqual(uul, coll.find_one({"uu": uul})["uu"]) # type: ignore
|
||||
self.assertEqual(uul, (coll.find_one({"uu": uul}))["uu"]) # type: ignore
|
||||
|
||||
# Test count_documents
|
||||
self.assertEqual(0, coll.count_documents({"uu": uu}))
|
||||
@ -73,9 +70,9 @@ class TestCommon(IntegrationTest):
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=STANDARD))
|
||||
coll.update_one({"_id": uu}, {"$set": {"i": 2}})
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(1, coll.find_one({"_id": uu})["i"]) # type: ignore
|
||||
self.assertEqual(1, (coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
coll.update_one({"_id": uu}, {"$set": {"i": 2}})
|
||||
self.assertEqual(2, coll.find_one({"_id": uu})["i"]) # type: ignore
|
||||
self.assertEqual(2, (coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
|
||||
# Test Cursor.distinct
|
||||
self.assertEqual([2], coll.find({"_id": uu}).distinct("i"))
|
||||
@ -85,27 +82,31 @@ class TestCommon(IntegrationTest):
|
||||
# Test findAndModify
|
||||
self.assertEqual(None, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))
|
||||
coll = self.db.get_collection("uuid", CodecOptions(uuid_representation=PYTHON_LEGACY))
|
||||
self.assertEqual(2, coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}})["i"])
|
||||
self.assertEqual(5, coll.find_one({"_id": uu})["i"]) # type: ignore
|
||||
self.assertEqual(2, (coll.find_one_and_update({"_id": uu}, {"$set": {"i": 5}}))["i"])
|
||||
self.assertEqual(5, (coll.find_one({"_id": uu}))["i"]) # type: ignore
|
||||
|
||||
# Test command
|
||||
self.assertEqual(
|
||||
5,
|
||||
self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 6}},
|
||||
query={"_id": uu},
|
||||
codec_options=legacy_opts,
|
||||
(
|
||||
self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 6}},
|
||||
query={"_id": uu},
|
||||
codec_options=legacy_opts,
|
||||
)
|
||||
)["value"]["i"],
|
||||
)
|
||||
self.assertEqual(
|
||||
6,
|
||||
self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 7}},
|
||||
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
|
||||
(
|
||||
self.db.command(
|
||||
"findAndModify",
|
||||
"uuid",
|
||||
update={"$set": {"i": 7}},
|
||||
query={"_id": Binary.from_uuid(uu, PYTHON_LEGACY)},
|
||||
)
|
||||
)["value"]["i"],
|
||||
)
|
||||
|
||||
@ -140,20 +141,23 @@ class TestCommon(IntegrationTest):
|
||||
coll.insert_one(doc)
|
||||
self.assertTrue(coll.insert_one(doc))
|
||||
coll = coll.with_options(write_concern=WriteConcern(w=1))
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
with self.assertRaises(OperationFailure):
|
||||
coll.insert_one(doc)
|
||||
|
||||
m = self.rs_or_single_client()
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
new_coll = coll.with_options(write_concern=WriteConcern(w=0))
|
||||
self.assertTrue(new_coll.insert_one(doc))
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
with self.assertRaises(OperationFailure):
|
||||
coll.insert_one(doc)
|
||||
|
||||
m = self.rs_or_single_client(
|
||||
f"mongodb://{pair}/", replicaSet=client_context.replica_set_name
|
||||
)
|
||||
|
||||
coll = m.pymongo_test.write_concern_test
|
||||
self.assertRaises(OperationFailure, coll.insert_one, doc)
|
||||
with self.assertRaises(OperationFailure):
|
||||
coll.insert_one(doc)
|
||||
m = self.rs_or_single_client(
|
||||
f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name
|
||||
)
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -23,8 +24,13 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "connection_logging")
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging")
|
||||
|
||||
|
||||
globals().update(
|
||||
|
||||
@ -25,14 +25,13 @@ from test import IntegrationTest, client_knobs, unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
SpecTestCreator,
|
||||
camel_to_snake,
|
||||
client_context,
|
||||
get_pool,
|
||||
get_pools,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.son import SON
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
import sys
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
@ -23,11 +24,16 @@ sys.path[0:0] = [""]
|
||||
from test import unittest
|
||||
from test.unified_format import generate_test_classes
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
# Location of JSON test specifications.
|
||||
TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "unified")
|
||||
if _IS_SYNC:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified")
|
||||
else:
|
||||
_TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified")
|
||||
|
||||
# Generate unified tests.
|
||||
globals().update(generate_test_classes(TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
|
||||
globals().update(generate_test_classes(_TEST_PATH, module=__name__, RUN_ON_SERVERLESS=True))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -30,6 +30,7 @@ import uuid
|
||||
import warnings
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context
|
||||
from test.test_bulk import BulkTestBase
|
||||
from test.utils_spec_runner import SpecRunner, SpecTestCreator
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
@ -58,7 +59,6 @@ from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
TopologyEventListener,
|
||||
camel_to_snake_args,
|
||||
is_greenthread_patched,
|
||||
@ -624,130 +624,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
|
||||
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
# TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
|
||||
return super().parse_client_options(opts)
|
||||
return super().parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
return run_scenario
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
# Prose Tests
|
||||
ALL_KMS_PROVIDERS = {
|
||||
|
||||
@ -21,11 +21,11 @@ from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
get_pool,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_selection_tests import create_topology
|
||||
from test.utils_spec_runner import SpecTestCreator
|
||||
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.monitoring import ConnectionReadyEvent
|
||||
|
||||
@ -18,41 +18,41 @@ https://github.com/mongodb/specifications/blob/master/source/unified-test-format
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import binascii
|
||||
import collections
|
||||
import copy
|
||||
import datetime
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import types
|
||||
from collections import abc, defaultdict
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import defaultdict
|
||||
from test import (
|
||||
IntegrationTest,
|
||||
client_context,
|
||||
client_knobs,
|
||||
unittest,
|
||||
)
|
||||
from test.helpers import (
|
||||
AWS_CREDS,
|
||||
AWS_CREDS_2,
|
||||
AZURE_CREDS,
|
||||
CA_PEM,
|
||||
CLIENT_PEM,
|
||||
GCP_CREDS,
|
||||
KMIP_CREDS,
|
||||
LOCAL_MASTER_KEY,
|
||||
client_knobs,
|
||||
from test.unified_format_shared import (
|
||||
IS_INTERRUPTED,
|
||||
KMS_TLS_OPTS,
|
||||
PLACEHOLDER_MAP,
|
||||
SKIP_CSOT_TESTS,
|
||||
EventListenerUtil,
|
||||
MatchEvaluatorUtil,
|
||||
coerce_result,
|
||||
parse_bulk_write_error_result,
|
||||
parse_bulk_write_result,
|
||||
parse_client_bulk_write_error_result,
|
||||
parse_collection_or_database_options,
|
||||
with_metaclass,
|
||||
)
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
get_pool,
|
||||
parse_collection_options,
|
||||
parse_spec_options,
|
||||
prepare_spec_arguments,
|
||||
snake_to_camel,
|
||||
@ -60,14 +60,12 @@ from test.utils import (
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
from test.version import Version
|
||||
from typing import Any, Dict, List, Mapping, Optional, Union
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import pymongo
|
||||
from bson import SON, Code, DBRef, Decimal128, Int64, MaxKey, MinKey, json_util
|
||||
from bson.binary import Binary
|
||||
from bson import SON, json_util
|
||||
from bson.codec_options import DEFAULT_CODEC_OPTIONS
|
||||
from bson.objectid import ObjectId
|
||||
from bson.regex import RE_TYPE, Regex
|
||||
from gridfs import GridFSBucket, GridOut
|
||||
from pymongo import ASCENDING, CursorType, MongoClient, _csot
|
||||
from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
|
||||
@ -83,55 +81,14 @@ from pymongo.errors import (
|
||||
PyMongoError,
|
||||
)
|
||||
from pymongo.monitoring import (
|
||||
_SENSITIVE_COMMANDS,
|
||||
CommandFailedEvent,
|
||||
CommandListener,
|
||||
CommandStartedEvent,
|
||||
CommandSucceededEvent,
|
||||
ConnectionCheckedInEvent,
|
||||
ConnectionCheckedOutEvent,
|
||||
ConnectionCheckOutFailedEvent,
|
||||
ConnectionCheckOutStartedEvent,
|
||||
ConnectionClosedEvent,
|
||||
ConnectionCreatedEvent,
|
||||
ConnectionReadyEvent,
|
||||
PoolClearedEvent,
|
||||
PoolClosedEvent,
|
||||
PoolCreatedEvent,
|
||||
PoolReadyEvent,
|
||||
ServerClosedEvent,
|
||||
ServerDescriptionChangedEvent,
|
||||
ServerHeartbeatFailedEvent,
|
||||
ServerHeartbeatListener,
|
||||
ServerHeartbeatStartedEvent,
|
||||
ServerHeartbeatSucceededEvent,
|
||||
ServerListener,
|
||||
ServerOpeningEvent,
|
||||
TopologyClosedEvent,
|
||||
TopologyDescriptionChangedEvent,
|
||||
TopologyEvent,
|
||||
TopologyListener,
|
||||
TopologyOpenedEvent,
|
||||
_CommandEvent,
|
||||
_ConnectionEvent,
|
||||
_PoolEvent,
|
||||
_ServerEvent,
|
||||
_ServerHeartbeatEvent,
|
||||
)
|
||||
from pymongo.operations import (
|
||||
DeleteMany,
|
||||
DeleteOne,
|
||||
InsertOne,
|
||||
ReplaceOne,
|
||||
SearchIndexModel,
|
||||
UpdateMany,
|
||||
UpdateOne,
|
||||
)
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
from pymongo.results import BulkWriteResult, ClientBulkWriteResult
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.server_selectors import Selection, writable_server_selector
|
||||
from pymongo.server_type import SERVER_TYPE
|
||||
from pymongo.synchronous.change_stream import ChangeStream
|
||||
@ -140,85 +97,12 @@ from pymongo.synchronous.collection import Collection
|
||||
from pymongo.synchronous.command_cursor import CommandCursor
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.encryption import ClientEncryption
|
||||
from pymongo.synchronous.helpers import next
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
from pymongo.typings import _Address
|
||||
from pymongo.write_concern import WriteConcern
|
||||
|
||||
SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS")
|
||||
|
||||
JSON_OPTS = json_util.JSONOptions(tz_aware=False)
|
||||
|
||||
IS_INTERRUPTED = False
|
||||
|
||||
KMS_TLS_OPTS = {
|
||||
"kmip": {
|
||||
"tlsCAFile": CA_PEM,
|
||||
"tlsCertificateKeyFile": CLIENT_PEM,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Build up a placeholder maps.
|
||||
PLACEHOLDER_MAP = {}
|
||||
for provider_name, provider_data in [
|
||||
("local", {"key": LOCAL_MASTER_KEY}),
|
||||
("local:name1", {"key": LOCAL_MASTER_KEY}),
|
||||
("aws", AWS_CREDS),
|
||||
("aws:name1", AWS_CREDS),
|
||||
("aws:name2", AWS_CREDS_2),
|
||||
("azure", AZURE_CREDS),
|
||||
("azure:name1", AZURE_CREDS),
|
||||
("gcp", GCP_CREDS),
|
||||
("gcp:name1", GCP_CREDS),
|
||||
("kmip", KMIP_CREDS),
|
||||
("kmip:name1", KMIP_CREDS),
|
||||
]:
|
||||
for key, value in provider_data.items():
|
||||
placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}"
|
||||
PLACEHOLDER_MAP[placeholder] = value
|
||||
|
||||
OIDC_ENV = os.environ.get("OIDC_ENV", "test")
|
||||
if OIDC_ENV == "test":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"}
|
||||
elif OIDC_ENV == "azure":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
|
||||
"ENVIRONMENT": "azure",
|
||||
"TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
|
||||
}
|
||||
elif OIDC_ENV == "gcp":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
|
||||
"ENVIRONMENT": "gcp",
|
||||
"TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
|
||||
}
|
||||
|
||||
|
||||
def interrupt_loop():
|
||||
global IS_INTERRUPTED
|
||||
IS_INTERRUPTED = True
|
||||
|
||||
|
||||
def with_metaclass(meta, *bases):
|
||||
"""Create a base class with a metaclass.
|
||||
|
||||
Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
|
||||
"""
|
||||
|
||||
# This requires a bit of explanation: the basic idea is to make a dummy
|
||||
# metaclass for one level of class instantiation that replaces itself with
|
||||
# the actual metaclass.
|
||||
class metaclass(type):
|
||||
def __new__(cls, name, this_bases, d):
|
||||
# __orig_bases__ is required by PEP 560.
|
||||
resolved_bases = types.resolve_bases(bases)
|
||||
if resolved_bases is not bases:
|
||||
d["__orig_bases__"] = bases
|
||||
return meta(name, resolved_bases, d)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, name, this_bases):
|
||||
return meta.__prepare__(name, bases)
|
||||
|
||||
return type.__new__(metaclass, "temporary_class", (), {})
|
||||
_IS_SYNC = True
|
||||
|
||||
|
||||
def is_run_on_requirement_satisfied(requirement):
|
||||
@ -283,77 +167,6 @@ def is_run_on_requirement_satisfied(requirement):
|
||||
)
|
||||
|
||||
|
||||
def parse_collection_or_database_options(options):
|
||||
return parse_collection_options(options)
|
||||
|
||||
|
||||
def parse_bulk_write_result(result):
|
||||
upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids}
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
"insertedCount": result.inserted_count,
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": result.upserted_count,
|
||||
"upsertedIds": upserted_ids,
|
||||
}
|
||||
|
||||
|
||||
def parse_client_bulk_write_individual(op_type, result):
|
||||
if op_type == "insert":
|
||||
return {"insertedId": result.inserted_id}
|
||||
if op_type == "update":
|
||||
if result.upserted_id:
|
||||
return {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedId": result.upserted_id,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
}
|
||||
if op_type == "delete":
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def parse_client_bulk_write_result(result):
|
||||
insert_results, update_results, delete_results = {}, {}, {}
|
||||
if result.has_verbose_results:
|
||||
for idx, res in result.insert_results.items():
|
||||
insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res)
|
||||
for idx, res in result.update_results.items():
|
||||
update_results[str(idx)] = parse_client_bulk_write_individual("update", res)
|
||||
for idx, res in result.delete_results.items():
|
||||
delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res)
|
||||
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
"insertedCount": result.inserted_count,
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": result.upserted_count,
|
||||
"insertResults": insert_results,
|
||||
"updateResults": update_results,
|
||||
"deleteResults": delete_results,
|
||||
}
|
||||
|
||||
|
||||
def parse_bulk_write_error_result(error):
|
||||
write_result = BulkWriteResult(error.details, True)
|
||||
return parse_bulk_write_result(write_result)
|
||||
|
||||
|
||||
def parse_client_bulk_write_error_result(error):
|
||||
write_result = error.partial_result
|
||||
if not write_result:
|
||||
return None
|
||||
return parse_client_bulk_write_result(write_result)
|
||||
|
||||
|
||||
class NonLazyCursor:
|
||||
"""A find cursor proxy that creates the remote cursor when initialized."""
|
||||
|
||||
@ -361,7 +174,16 @@ class NonLazyCursor:
|
||||
self.client = client
|
||||
self.find_cursor = find_cursor
|
||||
# Create the server side cursor.
|
||||
self.first_result = next(find_cursor, None)
|
||||
self.first_result = None
|
||||
|
||||
@classmethod
|
||||
def create(cls, find_cursor, client):
|
||||
cursor = cls(find_cursor, client)
|
||||
try:
|
||||
cursor.first_result = next(cursor.find_cursor)
|
||||
except StopIteration:
|
||||
cursor.first_result = None
|
||||
return cursor
|
||||
|
||||
@property
|
||||
def alive(self):
|
||||
@ -382,105 +204,6 @@ class NonLazyCursor:
|
||||
self.client = None
|
||||
|
||||
|
||||
class EventListenerUtil(
|
||||
CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener
|
||||
):
|
||||
def __init__(
|
||||
self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map
|
||||
):
|
||||
self._event_types = {name.lower() for name in observe_events}
|
||||
if observe_sensitive_commands:
|
||||
self._observe_sensitive_commands = True
|
||||
self._ignore_commands = set(ignore_commands)
|
||||
else:
|
||||
self._observe_sensitive_commands = False
|
||||
self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
|
||||
self._ignore_commands.add("configurefailpoint")
|
||||
self._event_mapping = collections.defaultdict(list)
|
||||
self.entity_map = entity_map
|
||||
if store_events:
|
||||
for i in store_events:
|
||||
id = i["id"]
|
||||
events = (i.lower() for i in i["events"])
|
||||
for i in events:
|
||||
self._event_mapping[i].append(id)
|
||||
self.entity_map[id] = []
|
||||
super().__init__()
|
||||
|
||||
def get_events(self, event_type):
|
||||
assert event_type in ("command", "cmap", "sdam", "all"), event_type
|
||||
if event_type == "all":
|
||||
return list(self.events)
|
||||
if event_type == "command":
|
||||
return [e for e in self.events if isinstance(e, _CommandEvent)]
|
||||
if event_type == "cmap":
|
||||
return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))]
|
||||
return [
|
||||
e
|
||||
for e in self.events
|
||||
if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent))
|
||||
]
|
||||
|
||||
def add_event(self, event):
|
||||
event_name = type(event).__name__.lower()
|
||||
if event_name in self._event_types:
|
||||
super().add_event(event)
|
||||
for id in self._event_mapping[event_name]:
|
||||
self.entity_map[id].append(
|
||||
{
|
||||
"name": type(event).__name__,
|
||||
"observedAt": time.time(),
|
||||
"description": repr(event),
|
||||
}
|
||||
)
|
||||
|
||||
def _command_event(self, event):
|
||||
if event.command_name.lower() not in self._ignore_commands:
|
||||
self.add_event(event)
|
||||
|
||||
def started(self, event):
|
||||
if isinstance(event, CommandStartedEvent):
|
||||
if event.command == {}:
|
||||
# Command is redacted. Observe only if flag is set.
|
||||
if self._observe_sensitive_commands:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def succeeded(self, event):
|
||||
if isinstance(event, CommandSucceededEvent):
|
||||
if event.reply == {}:
|
||||
# Command is redacted. Observe only if flag is set.
|
||||
if self._observe_sensitive_commands:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def failed(self, event):
|
||||
if isinstance(event, CommandFailedEvent):
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def description_changed(
|
||||
self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent]
|
||||
) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
|
||||
class EntityMapUtil:
|
||||
"""Utility class that implements an entity map as per the unified
|
||||
test format specification.
|
||||
@ -692,353 +415,12 @@ class EntityMapUtil:
|
||||
def advance_cluster_times(self) -> None:
|
||||
"""Manually synchronize entities when desired"""
|
||||
if not self._cluster_time:
|
||||
self._cluster_time = self.test.client.admin.command("ping").get("$clusterTime")
|
||||
self._cluster_time = (self.test.client.admin.command("ping")).get("$clusterTime")
|
||||
for entity in self._entities.values():
|
||||
if isinstance(entity, ClientSession) and self._cluster_time:
|
||||
entity.advance_cluster_time(self._cluster_time)
|
||||
|
||||
|
||||
binary_types = (Binary, bytes)
|
||||
long_types = (Int64,)
|
||||
unicode_type = str
|
||||
|
||||
|
||||
BSON_TYPE_ALIAS_MAP = {
|
||||
# https://mongodb.com/docs/manual/reference/operator/query/type/
|
||||
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
|
||||
"double": (float,),
|
||||
"string": (str,),
|
||||
"object": (abc.Mapping,),
|
||||
"array": (abc.MutableSequence,),
|
||||
"binData": binary_types,
|
||||
"undefined": (type(None),),
|
||||
"objectId": (ObjectId,),
|
||||
"bool": (bool,),
|
||||
"date": (datetime.datetime,),
|
||||
"null": (type(None),),
|
||||
"regex": (Regex, RE_TYPE),
|
||||
"dbPointer": (DBRef,),
|
||||
"javascript": (unicode_type, Code),
|
||||
"symbol": (unicode_type,),
|
||||
"javascriptWithScope": (unicode_type, Code),
|
||||
"int": (int,),
|
||||
"long": (Int64,),
|
||||
"decimal": (Decimal128,),
|
||||
"maxKey": (MaxKey,),
|
||||
"minKey": (MinKey,),
|
||||
}
|
||||
|
||||
|
||||
class MatchEvaluatorUtil:
|
||||
"""Utility class that implements methods for evaluating matches as per
|
||||
the unified test format specification.
|
||||
"""
|
||||
|
||||
def __init__(self, test_class):
|
||||
self.test = test_class
|
||||
|
||||
def _operation_exists(self, spec, actual, key_to_compare):
|
||||
if spec is True:
|
||||
if key_to_compare is None:
|
||||
assert actual is not None
|
||||
else:
|
||||
self.test.assertIn(key_to_compare, actual)
|
||||
elif spec is False:
|
||||
if key_to_compare is None:
|
||||
assert actual is None
|
||||
else:
|
||||
self.test.assertNotIn(key_to_compare, actual)
|
||||
else:
|
||||
self.test.fail(f"Expected boolean value for $$exists operator, got {spec}")
|
||||
|
||||
def __type_alias_to_type(self, alias):
|
||||
if alias not in BSON_TYPE_ALIAS_MAP:
|
||||
self.test.fail(f"Unrecognized BSON type alias {alias}")
|
||||
return BSON_TYPE_ALIAS_MAP[alias]
|
||||
|
||||
def _operation_type(self, spec, actual, key_to_compare):
|
||||
if isinstance(spec, abc.MutableSequence):
|
||||
permissible_types = tuple(
|
||||
[t for alias in spec for t in self.__type_alias_to_type(alias)]
|
||||
)
|
||||
else:
|
||||
permissible_types = self.__type_alias_to_type(spec)
|
||||
value = actual[key_to_compare] if key_to_compare else actual
|
||||
self.test.assertIsInstance(value, permissible_types)
|
||||
|
||||
def _operation_matchesEntity(self, spec, actual, key_to_compare):
|
||||
expected_entity = self.test.entity_map[spec]
|
||||
self.test.assertEqual(expected_entity, actual[key_to_compare])
|
||||
|
||||
def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
|
||||
expected = binascii.unhexlify(spec)
|
||||
value = actual[key_to_compare] if key_to_compare else actual
|
||||
self.test.assertEqual(value, expected)
|
||||
|
||||
def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
|
||||
if key_to_compare is None and not actual:
|
||||
# top-level document can be None when unset
|
||||
return
|
||||
|
||||
if key_to_compare not in actual:
|
||||
# we add a dummy value for the compared key to pass map size check
|
||||
actual[key_to_compare] = "dummyValue"
|
||||
return
|
||||
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
|
||||
|
||||
def _operation_sessionLsid(self, spec, actual, key_to_compare):
|
||||
expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
|
||||
self.test.assertEqual(expected_lsid, actual[key_to_compare])
|
||||
|
||||
def _operation_lte(self, spec, actual, key_to_compare):
|
||||
if key_to_compare not in actual:
|
||||
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}")
|
||||
self.test.assertLessEqual(actual[key_to_compare], spec)
|
||||
|
||||
def _operation_matchAsDocument(self, spec, actual, key_to_compare):
|
||||
self._match_document(spec, json_util.loads(actual[key_to_compare]), False)
|
||||
|
||||
def _operation_matchAsRoot(self, spec, actual, key_to_compare):
|
||||
self._match_document(spec, actual, True)
|
||||
|
||||
def _evaluate_special_operation(self, opname, spec, actual, key_to_compare):
|
||||
method_name = "_operation_{}".format(opname.strip("$"))
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
except AttributeError:
|
||||
self.test.fail(f"Unsupported special matching operator {opname}")
|
||||
else:
|
||||
method(spec, actual, key_to_compare)
|
||||
|
||||
def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None):
|
||||
"""Returns True if a special operation is evaluated, False
|
||||
otherwise. If the ``expectation`` map contains a single key,
|
||||
value pair we check it for a special operation.
|
||||
If given, ``key_to_compare`` is assumed to be the key in
|
||||
``expectation`` whose corresponding value needs to be
|
||||
evaluated for a possible special operation. ``key_to_compare``
|
||||
is ignored when ``expectation`` has only one key.
|
||||
"""
|
||||
if not isinstance(expectation, abc.Mapping):
|
||||
return False
|
||||
|
||||
is_special_op, opname, spec = False, False, False
|
||||
|
||||
if key_to_compare is not None:
|
||||
if key_to_compare.startswith("$$"):
|
||||
is_special_op = True
|
||||
opname = key_to_compare
|
||||
spec = expectation[key_to_compare]
|
||||
key_to_compare = None
|
||||
else:
|
||||
nested = expectation[key_to_compare]
|
||||
if isinstance(nested, abc.Mapping) and len(nested) == 1:
|
||||
opname, spec = next(iter(nested.items()))
|
||||
if opname.startswith("$$"):
|
||||
is_special_op = True
|
||||
elif len(expectation) == 1:
|
||||
opname, spec = next(iter(expectation.items()))
|
||||
if opname.startswith("$$"):
|
||||
is_special_op = True
|
||||
key_to_compare = None
|
||||
|
||||
if is_special_op:
|
||||
self._evaluate_special_operation(
|
||||
opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _match_document(self, expectation, actual, is_root, test=False):
|
||||
if self._evaluate_if_special_operation(expectation, actual):
|
||||
return
|
||||
|
||||
self.test.assertIsInstance(actual, abc.Mapping)
|
||||
for key, value in expectation.items():
|
||||
if self._evaluate_if_special_operation(expectation, actual, key):
|
||||
continue
|
||||
|
||||
self.test.assertIn(key, actual)
|
||||
if not self.match_result(value, actual[key], in_recursive_call=True, test=test):
|
||||
return False
|
||||
|
||||
if not is_root:
|
||||
expected_keys = set(expectation.keys())
|
||||
for key, value in expectation.items():
|
||||
if value == {"$$exists": False}:
|
||||
expected_keys.remove(key)
|
||||
if test:
|
||||
self.test.assertEqual(expected_keys, set(actual.keys()))
|
||||
else:
|
||||
return set(expected_keys).issubset(set(actual.keys()))
|
||||
return True
|
||||
|
||||
def match_result(self, expectation, actual, in_recursive_call=False, test=True):
|
||||
if isinstance(expectation, abc.Mapping):
|
||||
return self._match_document(
|
||||
expectation, actual, is_root=not in_recursive_call, test=test
|
||||
)
|
||||
|
||||
if isinstance(expectation, abc.MutableSequence):
|
||||
self.test.assertIsInstance(actual, abc.MutableSequence)
|
||||
for e, a in zip(expectation, actual):
|
||||
if isinstance(e, abc.Mapping):
|
||||
self._match_document(e, a, is_root=not in_recursive_call, test=test)
|
||||
else:
|
||||
self.match_result(e, a, in_recursive_call=True, test=test)
|
||||
return None
|
||||
|
||||
# account for flexible numerics in element-wise comparison
|
||||
if isinstance(expectation, int) or isinstance(expectation, float):
|
||||
if test:
|
||||
self.test.assertEqual(expectation, actual)
|
||||
else:
|
||||
return expectation == actual
|
||||
return None
|
||||
else:
|
||||
if test:
|
||||
self.test.assertIsInstance(actual, type(expectation))
|
||||
self.test.assertEqual(expectation, actual)
|
||||
else:
|
||||
return isinstance(actual, type(expectation)) and expectation == actual
|
||||
return None
|
||||
|
||||
def match_server_description(self, actual: ServerDescription, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
field = camel_to_snake(field)
|
||||
if field == "type":
|
||||
field = "server_type_name"
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
field = camel_to_snake(field)
|
||||
if field == "type":
|
||||
field = "topology_type_name"
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_event_fields(self, actual: Any, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
if field == "command" and isinstance(actual, CommandStartedEvent):
|
||||
command = spec["command"]
|
||||
if command:
|
||||
self.match_result(command, actual.command)
|
||||
continue
|
||||
if field == "reply" and isinstance(actual, CommandSucceededEvent):
|
||||
reply = spec["reply"]
|
||||
if reply:
|
||||
self.match_result(reply, actual.reply)
|
||||
continue
|
||||
if field == "hasServiceId":
|
||||
if spec["hasServiceId"]:
|
||||
self.test.assertIsNotNone(actual.service_id)
|
||||
self.test.assertIsInstance(actual.service_id, ObjectId)
|
||||
else:
|
||||
self.test.assertIsNone(actual.service_id)
|
||||
continue
|
||||
if field == "hasServerConnectionId":
|
||||
if spec["hasServerConnectionId"]:
|
||||
self.test.assertIsNotNone(actual.server_connection_id)
|
||||
self.test.assertIsInstance(actual.server_connection_id, int)
|
||||
else:
|
||||
self.test.assertIsNone(actual.server_connection_id)
|
||||
continue
|
||||
if field in ("previousDescription", "newDescription"):
|
||||
if isinstance(actual, ServerDescriptionChangedEvent):
|
||||
self.match_server_description(
|
||||
getattr(actual, camel_to_snake(field)), spec[field]
|
||||
)
|
||||
continue
|
||||
if isinstance(actual, TopologyDescriptionChangedEvent):
|
||||
self.match_topology_description(
|
||||
getattr(actual, camel_to_snake(field)), spec[field]
|
||||
)
|
||||
continue
|
||||
|
||||
if field == "interruptInUseConnections":
|
||||
field = "interrupt_connections"
|
||||
else:
|
||||
field = camel_to_snake(field)
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_event(self, expectation, actual):
|
||||
name, spec = next(iter(expectation.items()))
|
||||
if name == "commandStartedEvent":
|
||||
self.test.assertIsInstance(actual, CommandStartedEvent)
|
||||
elif name == "commandSucceededEvent":
|
||||
self.test.assertIsInstance(actual, CommandSucceededEvent)
|
||||
elif name == "commandFailedEvent":
|
||||
self.test.assertIsInstance(actual, CommandFailedEvent)
|
||||
elif name == "poolCreatedEvent":
|
||||
self.test.assertIsInstance(actual, PoolCreatedEvent)
|
||||
elif name == "poolReadyEvent":
|
||||
self.test.assertIsInstance(actual, PoolReadyEvent)
|
||||
elif name == "poolClearedEvent":
|
||||
self.test.assertIsInstance(actual, PoolClearedEvent)
|
||||
self.test.assertIsInstance(actual.interrupt_connections, bool)
|
||||
elif name == "poolClosedEvent":
|
||||
self.test.assertIsInstance(actual, PoolClosedEvent)
|
||||
elif name == "connectionCreatedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCreatedEvent)
|
||||
elif name == "connectionReadyEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionReadyEvent)
|
||||
elif name == "connectionClosedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionClosedEvent)
|
||||
elif name == "connectionCheckOutStartedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
|
||||
elif name == "connectionCheckOutFailedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
|
||||
elif name == "connectionCheckedOutEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
|
||||
elif name == "connectionCheckedInEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
|
||||
elif name == "serverDescriptionChangedEvent":
|
||||
self.test.assertIsInstance(actual, ServerDescriptionChangedEvent)
|
||||
elif name == "serverHeartbeatStartedEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent)
|
||||
elif name == "serverHeartbeatSucceededEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent)
|
||||
elif name == "serverHeartbeatFailedEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent)
|
||||
elif name == "topologyDescriptionChangedEvent":
|
||||
self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent)
|
||||
elif name == "topologyOpeningEvent":
|
||||
self.test.assertIsInstance(actual, TopologyOpenedEvent)
|
||||
elif name == "topologyClosedEvent":
|
||||
self.test.assertIsInstance(actual, TopologyClosedEvent)
|
||||
else:
|
||||
raise Exception(f"Unsupported event type {name}")
|
||||
|
||||
self.match_event_fields(actual, spec)
|
||||
|
||||
|
||||
def coerce_result(opname, result):
|
||||
"""Convert a pymongo result into the spec's result format."""
|
||||
if hasattr(result, "acknowledged") and not result.acknowledged:
|
||||
return {"acknowledged": False}
|
||||
if opname == "bulkWrite":
|
||||
return parse_bulk_write_result(result)
|
||||
if opname == "clientBulkWrite":
|
||||
return parse_client_bulk_write_result(result)
|
||||
if opname == "insertOne":
|
||||
return {"insertedId": result.inserted_id}
|
||||
if opname == "insertMany":
|
||||
return dict(enumerate(result.inserted_ids))
|
||||
if opname in ("deleteOne", "deleteMany"):
|
||||
return {"deletedCount": result.deleted_count}
|
||||
if opname in ("updateOne", "updateMany", "replaceOne"):
|
||||
value = {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": 0 if result.upserted_id is None else 1,
|
||||
}
|
||||
if result.upserted_id is not None:
|
||||
value["upsertedId"] = result.upserted_id
|
||||
return value
|
||||
return result
|
||||
|
||||
|
||||
class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
"""Mixin class to run test cases from test specification files.
|
||||
|
||||
@ -1090,9 +472,9 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
db.create_collection(coll_name, write_concern=wc, **opts)
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
def _setup_class(cls):
|
||||
# super call creates internal client cls.client
|
||||
super().setUpClass()
|
||||
super()._setup_class()
|
||||
# process file-level runOnRequirements
|
||||
run_on_spec = cls.TEST_SPEC.get("runOnRequirements", [])
|
||||
if not cls.should_run_on(run_on_spec):
|
||||
@ -1125,11 +507,11 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
cls.knobs.enable()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
def _tearDown_class(cls):
|
||||
cls.knobs.disable()
|
||||
for client in cls.mongos_clients:
|
||||
client.close()
|
||||
super().tearDownClass()
|
||||
super()._tearDown_class()
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
@ -1391,7 +773,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
def __entityOperation_aggregate(self, target, *args, **kwargs):
|
||||
self.__raise_if_unsupported("aggregate", target, Database, Collection)
|
||||
return list(target.aggregate(*args, **kwargs))
|
||||
return (target.aggregate(*args, **kwargs)).to_list()
|
||||
|
||||
def _databaseOperation_aggregate(self, target, *args, **kwargs):
|
||||
return self.__entityOperation_aggregate(target, *args, **kwargs)
|
||||
@ -1402,13 +784,13 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
def _collectionOperation_find(self, target, *args, **kwargs):
|
||||
self.__raise_if_unsupported("find", target, Collection)
|
||||
find_cursor = target.find(*args, **kwargs)
|
||||
return list(find_cursor)
|
||||
return find_cursor.to_list()
|
||||
|
||||
def _collectionOperation_createFindCursor(self, target, *args, **kwargs):
|
||||
self.__raise_if_unsupported("find", target, Collection)
|
||||
if "filter" not in kwargs:
|
||||
self.fail('createFindCursor requires a "filter" argument')
|
||||
cursor = NonLazyCursor(target.find(*args, **kwargs), target.database.client)
|
||||
cursor = NonLazyCursor.create(target.find(*args, **kwargs), target.database.client)
|
||||
self.addCleanup(cursor.close)
|
||||
return cursor
|
||||
|
||||
@ -1418,7 +800,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
def _collectionOperation_listIndexes(self, target, *args, **kwargs):
|
||||
if "batch_size" in kwargs:
|
||||
self.skipTest("PyMongo does not support batch_size for list_indexes")
|
||||
return list(target.list_indexes(*args, **kwargs))
|
||||
return (target.list_indexes(*args, **kwargs)).to_list()
|
||||
|
||||
def _collectionOperation_listIndexNames(self, target, *args, **kwargs):
|
||||
self.skipTest("PyMongo does not support list_index_names")
|
||||
@ -1430,7 +812,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
def _collectionOperation_listSearchIndexes(self, target, *args, **kwargs):
|
||||
name = kwargs.get("name")
|
||||
agg_kwargs = kwargs.get("aggregation_options", dict())
|
||||
return list(target.list_search_indexes(name, **agg_kwargs))
|
||||
return (target.list_search_indexes(name, **agg_kwargs)).to_list()
|
||||
|
||||
def _sessionOperation_withTransaction(self, target, *args, **kwargs):
|
||||
if client_context.storage_engine == "mmapv1":
|
||||
@ -1470,7 +852,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
return target.create_data_key(*args, **kwargs)
|
||||
|
||||
def _clientEncryptionOperation_getKeys(self, target, *args, **kwargs):
|
||||
return list(target.get_keys(*args, **kwargs))
|
||||
return (target.get_keys(*args, **kwargs)).to_list()
|
||||
|
||||
def _clientEncryptionOperation_deleteKey(self, target, *args, **kwargs):
|
||||
result = target.delete_key(*args, **kwargs)
|
||||
@ -1516,7 +898,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
def _bucketOperation_find(
|
||||
self, target: GridFSBucket, *args: Any, **kwargs: Any
|
||||
) -> List[GridOut]:
|
||||
return list(target.find(*args, **kwargs))
|
||||
return target.find(*args, **kwargs).to_list()
|
||||
|
||||
def run_entity_operation(self, spec):
|
||||
target = self.entity_map[spec["object"]]
|
||||
@ -1849,7 +1231,10 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
except AttributeError:
|
||||
self.fail(f"Unsupported special test operation {opname}")
|
||||
else:
|
||||
method(spec["arguments"])
|
||||
if iscoroutinefunction(method):
|
||||
method(spec["arguments"])
|
||||
else:
|
||||
method(spec["arguments"])
|
||||
|
||||
def run_operations(self, spec):
|
||||
for op in spec:
|
||||
@ -1985,7 +1370,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
|
||||
if expected_documents:
|
||||
sorted_expected_documents = sorted(expected_documents, key=lambda doc: doc["_id"])
|
||||
actual_documents = list(coll.find({}, sort=[("_id", ASCENDING)]))
|
||||
actual_documents = coll.find({}, sort=[("_id", ASCENDING)]).to_list()
|
||||
self.assertListEqual(sorted_expected_documents, actual_documents)
|
||||
|
||||
def run_scenario(self, spec, uri=None):
|
||||
@ -2040,7 +1425,7 @@ class UnifiedSpecTestMixinV1(IntegrationTest):
|
||||
# process initialData
|
||||
if "initialData" in self.TEST_SPEC:
|
||||
self.insert_initial_data(self.TEST_SPEC["initialData"])
|
||||
self._cluster_time = self.client.admin.command("ping").get("$clusterTime")
|
||||
self._cluster_time = (self.client.admin.command("ping")).get("$clusterTime")
|
||||
self.entity_map.advance_cluster_times()
|
||||
|
||||
if "expectLogMessages" in spec:
|
||||
|
||||
679
test/unified_format_shared.py
Normal file
679
test/unified_format_shared.py
Normal file
@ -0,0 +1,679 @@
|
||||
# Copyright 2024-present MongoDB, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared utility functions and constants for the unified test format runner.
|
||||
|
||||
https://github.com/mongodb/specifications/blob/master/source/unified-test-format/unified-test-format.rst
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import binascii
|
||||
import collections
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
import types
|
||||
from collections import abc
|
||||
from test.helpers import (
|
||||
AWS_CREDS,
|
||||
AWS_CREDS_2,
|
||||
AZURE_CREDS,
|
||||
CA_PEM,
|
||||
CLIENT_PEM,
|
||||
GCP_CREDS,
|
||||
KMIP_CREDS,
|
||||
LOCAL_MASTER_KEY,
|
||||
)
|
||||
from test.utils import CMAPListener, camel_to_snake, parse_collection_options
|
||||
from typing import Any, Union
|
||||
|
||||
from bson import (
|
||||
RE_TYPE,
|
||||
Binary,
|
||||
Code,
|
||||
DBRef,
|
||||
Decimal128,
|
||||
Int64,
|
||||
MaxKey,
|
||||
MinKey,
|
||||
ObjectId,
|
||||
Regex,
|
||||
json_util,
|
||||
)
|
||||
from pymongo.monitoring import (
|
||||
_SENSITIVE_COMMANDS,
|
||||
CommandFailedEvent,
|
||||
CommandListener,
|
||||
CommandStartedEvent,
|
||||
CommandSucceededEvent,
|
||||
ConnectionCheckedInEvent,
|
||||
ConnectionCheckedOutEvent,
|
||||
ConnectionCheckOutFailedEvent,
|
||||
ConnectionCheckOutStartedEvent,
|
||||
ConnectionClosedEvent,
|
||||
ConnectionCreatedEvent,
|
||||
ConnectionReadyEvent,
|
||||
PoolClearedEvent,
|
||||
PoolClosedEvent,
|
||||
PoolCreatedEvent,
|
||||
PoolReadyEvent,
|
||||
ServerClosedEvent,
|
||||
ServerDescriptionChangedEvent,
|
||||
ServerHeartbeatFailedEvent,
|
||||
ServerHeartbeatListener,
|
||||
ServerHeartbeatStartedEvent,
|
||||
ServerHeartbeatSucceededEvent,
|
||||
ServerListener,
|
||||
ServerOpeningEvent,
|
||||
TopologyClosedEvent,
|
||||
TopologyDescriptionChangedEvent,
|
||||
TopologyEvent,
|
||||
TopologyListener,
|
||||
TopologyOpenedEvent,
|
||||
_CommandEvent,
|
||||
_ConnectionEvent,
|
||||
_PoolEvent,
|
||||
_ServerEvent,
|
||||
_ServerHeartbeatEvent,
|
||||
)
|
||||
from pymongo.results import BulkWriteResult
|
||||
from pymongo.server_description import ServerDescription
|
||||
from pymongo.topology_description import TopologyDescription
|
||||
|
||||
SKIP_CSOT_TESTS = os.getenv("SKIP_CSOT_TESTS")
|
||||
|
||||
JSON_OPTS = json_util.JSONOptions(tz_aware=False)
|
||||
|
||||
IS_INTERRUPTED = False
|
||||
|
||||
KMS_TLS_OPTS = {
|
||||
"kmip": {
|
||||
"tlsCAFile": CA_PEM,
|
||||
"tlsCertificateKeyFile": CLIENT_PEM,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# Build up a placeholder maps.
|
||||
PLACEHOLDER_MAP = {}
|
||||
for provider_name, provider_data in [
|
||||
("local", {"key": LOCAL_MASTER_KEY}),
|
||||
("local:name1", {"key": LOCAL_MASTER_KEY}),
|
||||
("aws", AWS_CREDS),
|
||||
("aws:name1", AWS_CREDS),
|
||||
("aws:name2", AWS_CREDS_2),
|
||||
("azure", AZURE_CREDS),
|
||||
("azure:name1", AZURE_CREDS),
|
||||
("gcp", GCP_CREDS),
|
||||
("gcp:name1", GCP_CREDS),
|
||||
("kmip", KMIP_CREDS),
|
||||
("kmip:name1", KMIP_CREDS),
|
||||
]:
|
||||
for key, value in provider_data.items():
|
||||
placeholder = f"/clientEncryptionOpts/kmsProviders/{provider_name}/{key}"
|
||||
PLACEHOLDER_MAP[placeholder] = value
|
||||
|
||||
OIDC_ENV = os.environ.get("OIDC_ENV", "test")
|
||||
if OIDC_ENV == "test":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {"ENVIRONMENT": "test"}
|
||||
elif OIDC_ENV == "azure":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
|
||||
"ENVIRONMENT": "azure",
|
||||
"TOKEN_RESOURCE": os.environ["AZUREOIDC_RESOURCE"],
|
||||
}
|
||||
elif OIDC_ENV == "gcp":
|
||||
PLACEHOLDER_MAP["/uriOptions/authMechanismProperties"] = {
|
||||
"ENVIRONMENT": "gcp",
|
||||
"TOKEN_RESOURCE": os.environ["GCPOIDC_AUDIENCE"],
|
||||
}
|
||||
|
||||
|
||||
def interrupt_loop():
|
||||
global IS_INTERRUPTED
|
||||
IS_INTERRUPTED = True
|
||||
|
||||
|
||||
def with_metaclass(meta, *bases):
|
||||
"""Create a base class with a metaclass.
|
||||
|
||||
Vendored from six: https://github.com/benjaminp/six/blob/master/six.py
|
||||
"""
|
||||
|
||||
# This requires a bit of explanation: the basic idea is to make a dummy
|
||||
# metaclass for one level of class instantiation that replaces itself with
|
||||
# the actual metaclass.
|
||||
class metaclass(type):
|
||||
def __new__(cls, name, this_bases, d):
|
||||
# __orig_bases__ is required by PEP 560.
|
||||
resolved_bases = types.resolve_bases(bases)
|
||||
if resolved_bases is not bases:
|
||||
d["__orig_bases__"] = bases
|
||||
return meta(name, resolved_bases, d)
|
||||
|
||||
@classmethod
|
||||
def __prepare__(cls, name, this_bases):
|
||||
return meta.__prepare__(name, bases)
|
||||
|
||||
return type.__new__(metaclass, "temporary_class", (), {})
|
||||
|
||||
|
||||
def parse_collection_or_database_options(options):
|
||||
return parse_collection_options(options)
|
||||
|
||||
|
||||
def parse_bulk_write_result(result):
|
||||
upserted_ids = {str(int_idx): result.upserted_ids[int_idx] for int_idx in result.upserted_ids}
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
"insertedCount": result.inserted_count,
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": result.upserted_count,
|
||||
"upsertedIds": upserted_ids,
|
||||
}
|
||||
|
||||
|
||||
def parse_client_bulk_write_individual(op_type, result):
|
||||
if op_type == "insert":
|
||||
return {"insertedId": result.inserted_id}
|
||||
if op_type == "update":
|
||||
if result.upserted_id:
|
||||
return {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedId": result.upserted_id,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
}
|
||||
if op_type == "delete":
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
}
|
||||
|
||||
|
||||
def parse_client_bulk_write_result(result):
|
||||
insert_results, update_results, delete_results = {}, {}, {}
|
||||
if result.has_verbose_results:
|
||||
for idx, res in result.insert_results.items():
|
||||
insert_results[str(idx)] = parse_client_bulk_write_individual("insert", res)
|
||||
for idx, res in result.update_results.items():
|
||||
update_results[str(idx)] = parse_client_bulk_write_individual("update", res)
|
||||
for idx, res in result.delete_results.items():
|
||||
delete_results[str(idx)] = parse_client_bulk_write_individual("delete", res)
|
||||
|
||||
return {
|
||||
"deletedCount": result.deleted_count,
|
||||
"insertedCount": result.inserted_count,
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": result.upserted_count,
|
||||
"insertResults": insert_results,
|
||||
"updateResults": update_results,
|
||||
"deleteResults": delete_results,
|
||||
}
|
||||
|
||||
|
||||
def parse_bulk_write_error_result(error):
|
||||
write_result = BulkWriteResult(error.details, True)
|
||||
return parse_bulk_write_result(write_result)
|
||||
|
||||
|
||||
def parse_client_bulk_write_error_result(error):
|
||||
write_result = error.partial_result
|
||||
if not write_result:
|
||||
return None
|
||||
return parse_client_bulk_write_result(write_result)
|
||||
|
||||
|
||||
class EventListenerUtil(
|
||||
CMAPListener, CommandListener, ServerListener, ServerHeartbeatListener, TopologyListener
|
||||
):
|
||||
def __init__(
|
||||
self, observe_events, ignore_commands, observe_sensitive_commands, store_events, entity_map
|
||||
):
|
||||
self._event_types = {name.lower() for name in observe_events}
|
||||
if observe_sensitive_commands:
|
||||
self._observe_sensitive_commands = True
|
||||
self._ignore_commands = set(ignore_commands)
|
||||
else:
|
||||
self._observe_sensitive_commands = False
|
||||
self._ignore_commands = _SENSITIVE_COMMANDS | set(ignore_commands)
|
||||
self._ignore_commands.add("configurefailpoint")
|
||||
self._event_mapping = collections.defaultdict(list)
|
||||
self.entity_map = entity_map
|
||||
if store_events:
|
||||
for i in store_events:
|
||||
id = i["id"]
|
||||
events = (i.lower() for i in i["events"])
|
||||
for i in events:
|
||||
self._event_mapping[i].append(id)
|
||||
self.entity_map[id] = []
|
||||
super().__init__()
|
||||
|
||||
def get_events(self, event_type):
|
||||
assert event_type in ("command", "cmap", "sdam", "all"), event_type
|
||||
if event_type == "all":
|
||||
return list(self.events)
|
||||
if event_type == "command":
|
||||
return [e for e in self.events if isinstance(e, _CommandEvent)]
|
||||
if event_type == "cmap":
|
||||
return [e for e in self.events if isinstance(e, (_ConnectionEvent, _PoolEvent))]
|
||||
return [
|
||||
e
|
||||
for e in self.events
|
||||
if isinstance(e, (_ServerEvent, TopologyEvent, _ServerHeartbeatEvent))
|
||||
]
|
||||
|
||||
def add_event(self, event):
|
||||
event_name = type(event).__name__.lower()
|
||||
if event_name in self._event_types:
|
||||
super().add_event(event)
|
||||
for id in self._event_mapping[event_name]:
|
||||
self.entity_map[id].append(
|
||||
{
|
||||
"name": type(event).__name__,
|
||||
"observedAt": time.time(),
|
||||
"description": repr(event),
|
||||
}
|
||||
)
|
||||
|
||||
def _command_event(self, event):
|
||||
if event.command_name.lower() not in self._ignore_commands:
|
||||
self.add_event(event)
|
||||
|
||||
def started(self, event):
|
||||
if isinstance(event, CommandStartedEvent):
|
||||
if event.command == {}:
|
||||
# Command is redacted. Observe only if flag is set.
|
||||
if self._observe_sensitive_commands:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def succeeded(self, event):
|
||||
if isinstance(event, CommandSucceededEvent):
|
||||
if event.reply == {}:
|
||||
# Command is redacted. Observe only if flag is set.
|
||||
if self._observe_sensitive_commands:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def failed(self, event):
|
||||
if isinstance(event, CommandFailedEvent):
|
||||
self._command_event(event)
|
||||
else:
|
||||
self.add_event(event)
|
||||
|
||||
def opened(self, event: Union[ServerOpeningEvent, TopologyOpenedEvent]) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def description_changed(
|
||||
self, event: Union[ServerDescriptionChangedEvent, TopologyDescriptionChangedEvent]
|
||||
) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def topology_changed(self, event: TopologyDescriptionChangedEvent) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
def closed(self, event: Union[ServerClosedEvent, TopologyClosedEvent]) -> None:
|
||||
self.add_event(event)
|
||||
|
||||
|
||||
binary_types = (Binary, bytes)
|
||||
long_types = (Int64,)
|
||||
unicode_type = str
|
||||
|
||||
|
||||
BSON_TYPE_ALIAS_MAP = {
|
||||
# https://mongodb.com/docs/manual/reference/operator/query/type/
|
||||
# https://pymongo.readthedocs.io/en/stable/api/bson/index.html
|
||||
"double": (float,),
|
||||
"string": (str,),
|
||||
"object": (abc.Mapping,),
|
||||
"array": (abc.MutableSequence,),
|
||||
"binData": binary_types,
|
||||
"undefined": (type(None),),
|
||||
"objectId": (ObjectId,),
|
||||
"bool": (bool,),
|
||||
"date": (datetime.datetime,),
|
||||
"null": (type(None),),
|
||||
"regex": (Regex, RE_TYPE),
|
||||
"dbPointer": (DBRef,),
|
||||
"javascript": (unicode_type, Code),
|
||||
"symbol": (unicode_type,),
|
||||
"javascriptWithScope": (unicode_type, Code),
|
||||
"int": (int,),
|
||||
"long": (Int64,),
|
||||
"decimal": (Decimal128,),
|
||||
"maxKey": (MaxKey,),
|
||||
"minKey": (MinKey,),
|
||||
}
|
||||
|
||||
|
||||
class MatchEvaluatorUtil:
|
||||
"""Utility class that implements methods for evaluating matches as per
|
||||
the unified test format specification.
|
||||
"""
|
||||
|
||||
def __init__(self, test_class):
|
||||
self.test = test_class
|
||||
|
||||
def _operation_exists(self, spec, actual, key_to_compare):
|
||||
if spec is True:
|
||||
if key_to_compare is None:
|
||||
assert actual is not None
|
||||
else:
|
||||
self.test.assertIn(key_to_compare, actual)
|
||||
elif spec is False:
|
||||
if key_to_compare is None:
|
||||
assert actual is None
|
||||
else:
|
||||
self.test.assertNotIn(key_to_compare, actual)
|
||||
else:
|
||||
self.test.fail(f"Expected boolean value for $$exists operator, got {spec}")
|
||||
|
||||
def __type_alias_to_type(self, alias):
|
||||
if alias not in BSON_TYPE_ALIAS_MAP:
|
||||
self.test.fail(f"Unrecognized BSON type alias {alias}")
|
||||
return BSON_TYPE_ALIAS_MAP[alias]
|
||||
|
||||
def _operation_type(self, spec, actual, key_to_compare):
|
||||
if isinstance(spec, abc.MutableSequence):
|
||||
permissible_types = tuple(
|
||||
[t for alias in spec for t in self.__type_alias_to_type(alias)]
|
||||
)
|
||||
else:
|
||||
permissible_types = self.__type_alias_to_type(spec)
|
||||
value = actual[key_to_compare] if key_to_compare else actual
|
||||
self.test.assertIsInstance(value, permissible_types)
|
||||
|
||||
def _operation_matchesEntity(self, spec, actual, key_to_compare):
|
||||
expected_entity = self.test.entity_map[spec]
|
||||
self.test.assertEqual(expected_entity, actual[key_to_compare])
|
||||
|
||||
def _operation_matchesHexBytes(self, spec, actual, key_to_compare):
|
||||
expected = binascii.unhexlify(spec)
|
||||
value = actual[key_to_compare] if key_to_compare else actual
|
||||
self.test.assertEqual(value, expected)
|
||||
|
||||
def _operation_unsetOrMatches(self, spec, actual, key_to_compare):
|
||||
if key_to_compare is None and not actual:
|
||||
# top-level document can be None when unset
|
||||
return
|
||||
|
||||
if key_to_compare not in actual:
|
||||
# we add a dummy value for the compared key to pass map size check
|
||||
actual[key_to_compare] = "dummyValue"
|
||||
return
|
||||
self.match_result(spec, actual[key_to_compare], in_recursive_call=True)
|
||||
|
||||
def _operation_sessionLsid(self, spec, actual, key_to_compare):
|
||||
expected_lsid = self.test.entity_map.get_lsid_for_session(spec)
|
||||
self.test.assertEqual(expected_lsid, actual[key_to_compare])
|
||||
|
||||
def _operation_lte(self, spec, actual, key_to_compare):
|
||||
if key_to_compare not in actual:
|
||||
self.test.fail(f"Actual command is missing the {key_to_compare} field: {spec}")
|
||||
self.test.assertLessEqual(actual[key_to_compare], spec)
|
||||
|
||||
def _operation_matchAsDocument(self, spec, actual, key_to_compare):
|
||||
self._match_document(spec, json_util.loads(actual[key_to_compare]), False)
|
||||
|
||||
def _operation_matchAsRoot(self, spec, actual, key_to_compare):
|
||||
self._match_document(spec, actual, True)
|
||||
|
||||
def _evaluate_special_operation(self, opname, spec, actual, key_to_compare):
|
||||
method_name = "_operation_{}".format(opname.strip("$"))
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
except AttributeError:
|
||||
self.test.fail(f"Unsupported special matching operator {opname}")
|
||||
else:
|
||||
method(spec, actual, key_to_compare)
|
||||
|
||||
def _evaluate_if_special_operation(self, expectation, actual, key_to_compare=None):
|
||||
"""Returns True if a special operation is evaluated, False
|
||||
otherwise. If the ``expectation`` map contains a single key,
|
||||
value pair we check it for a special operation.
|
||||
If given, ``key_to_compare`` is assumed to be the key in
|
||||
``expectation`` whose corresponding value needs to be
|
||||
evaluated for a possible special operation. ``key_to_compare``
|
||||
is ignored when ``expectation`` has only one key.
|
||||
"""
|
||||
if not isinstance(expectation, abc.Mapping):
|
||||
return False
|
||||
|
||||
is_special_op, opname, spec = False, False, False
|
||||
|
||||
if key_to_compare is not None:
|
||||
if key_to_compare.startswith("$$"):
|
||||
is_special_op = True
|
||||
opname = key_to_compare
|
||||
spec = expectation[key_to_compare]
|
||||
key_to_compare = None
|
||||
else:
|
||||
nested = expectation[key_to_compare]
|
||||
if isinstance(nested, abc.Mapping) and len(nested) == 1:
|
||||
opname, spec = next(iter(nested.items()))
|
||||
if opname.startswith("$$"):
|
||||
is_special_op = True
|
||||
elif len(expectation) == 1:
|
||||
opname, spec = next(iter(expectation.items()))
|
||||
if opname.startswith("$$"):
|
||||
is_special_op = True
|
||||
key_to_compare = None
|
||||
|
||||
if is_special_op:
|
||||
self._evaluate_special_operation(
|
||||
opname=opname, spec=spec, actual=actual, key_to_compare=key_to_compare
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _match_document(self, expectation, actual, is_root, test=False):
|
||||
if self._evaluate_if_special_operation(expectation, actual):
|
||||
return
|
||||
|
||||
self.test.assertIsInstance(actual, abc.Mapping)
|
||||
for key, value in expectation.items():
|
||||
if self._evaluate_if_special_operation(expectation, actual, key):
|
||||
continue
|
||||
|
||||
self.test.assertIn(key, actual)
|
||||
if not self.match_result(value, actual[key], in_recursive_call=True, test=test):
|
||||
return False
|
||||
|
||||
if not is_root:
|
||||
expected_keys = set(expectation.keys())
|
||||
for key, value in expectation.items():
|
||||
if value == {"$$exists": False}:
|
||||
expected_keys.remove(key)
|
||||
if test:
|
||||
self.test.assertEqual(expected_keys, set(actual.keys()))
|
||||
else:
|
||||
return set(expected_keys).issubset(set(actual.keys()))
|
||||
return True
|
||||
|
||||
def match_result(self, expectation, actual, in_recursive_call=False, test=True):
|
||||
if isinstance(expectation, abc.Mapping):
|
||||
return self._match_document(
|
||||
expectation, actual, is_root=not in_recursive_call, test=test
|
||||
)
|
||||
|
||||
if isinstance(expectation, abc.MutableSequence):
|
||||
self.test.assertIsInstance(actual, abc.MutableSequence)
|
||||
for e, a in zip(expectation, actual):
|
||||
if isinstance(e, abc.Mapping):
|
||||
self._match_document(e, a, is_root=not in_recursive_call, test=test)
|
||||
else:
|
||||
self.match_result(e, a, in_recursive_call=True, test=test)
|
||||
return None
|
||||
|
||||
# account for flexible numerics in element-wise comparison
|
||||
if isinstance(expectation, int) or isinstance(expectation, float):
|
||||
if test:
|
||||
self.test.assertEqual(expectation, actual)
|
||||
else:
|
||||
return expectation == actual
|
||||
return None
|
||||
else:
|
||||
if test:
|
||||
self.test.assertIsInstance(actual, type(expectation))
|
||||
self.test.assertEqual(expectation, actual)
|
||||
else:
|
||||
return isinstance(actual, type(expectation)) and expectation == actual
|
||||
return None
|
||||
|
||||
def match_server_description(self, actual: ServerDescription, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
field = camel_to_snake(field)
|
||||
if field == "type":
|
||||
field = "server_type_name"
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_topology_description(self, actual: TopologyDescription, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
field = camel_to_snake(field)
|
||||
if field == "type":
|
||||
field = "topology_type_name"
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_event_fields(self, actual: Any, spec: dict) -> None:
|
||||
for field, expected in spec.items():
|
||||
if field == "command" and isinstance(actual, CommandStartedEvent):
|
||||
command = spec["command"]
|
||||
if command:
|
||||
self.match_result(command, actual.command)
|
||||
continue
|
||||
if field == "reply" and isinstance(actual, CommandSucceededEvent):
|
||||
reply = spec["reply"]
|
||||
if reply:
|
||||
self.match_result(reply, actual.reply)
|
||||
continue
|
||||
if field == "hasServiceId":
|
||||
if spec["hasServiceId"]:
|
||||
self.test.assertIsNotNone(actual.service_id)
|
||||
self.test.assertIsInstance(actual.service_id, ObjectId)
|
||||
else:
|
||||
self.test.assertIsNone(actual.service_id)
|
||||
continue
|
||||
if field == "hasServerConnectionId":
|
||||
if spec["hasServerConnectionId"]:
|
||||
self.test.assertIsNotNone(actual.server_connection_id)
|
||||
self.test.assertIsInstance(actual.server_connection_id, int)
|
||||
else:
|
||||
self.test.assertIsNone(actual.server_connection_id)
|
||||
continue
|
||||
if field in ("previousDescription", "newDescription"):
|
||||
if isinstance(actual, ServerDescriptionChangedEvent):
|
||||
self.match_server_description(
|
||||
getattr(actual, camel_to_snake(field)), spec[field]
|
||||
)
|
||||
continue
|
||||
if isinstance(actual, TopologyDescriptionChangedEvent):
|
||||
self.match_topology_description(
|
||||
getattr(actual, camel_to_snake(field)), spec[field]
|
||||
)
|
||||
continue
|
||||
|
||||
if field == "interruptInUseConnections":
|
||||
field = "interrupt_connections"
|
||||
else:
|
||||
field = camel_to_snake(field)
|
||||
self.test.assertEqual(getattr(actual, field), expected)
|
||||
|
||||
def match_event(self, expectation, actual):
|
||||
name, spec = next(iter(expectation.items()))
|
||||
if name == "commandStartedEvent":
|
||||
self.test.assertIsInstance(actual, CommandStartedEvent)
|
||||
elif name == "commandSucceededEvent":
|
||||
self.test.assertIsInstance(actual, CommandSucceededEvent)
|
||||
elif name == "commandFailedEvent":
|
||||
self.test.assertIsInstance(actual, CommandFailedEvent)
|
||||
elif name == "poolCreatedEvent":
|
||||
self.test.assertIsInstance(actual, PoolCreatedEvent)
|
||||
elif name == "poolReadyEvent":
|
||||
self.test.assertIsInstance(actual, PoolReadyEvent)
|
||||
elif name == "poolClearedEvent":
|
||||
self.test.assertIsInstance(actual, PoolClearedEvent)
|
||||
self.test.assertIsInstance(actual.interrupt_connections, bool)
|
||||
elif name == "poolClosedEvent":
|
||||
self.test.assertIsInstance(actual, PoolClosedEvent)
|
||||
elif name == "connectionCreatedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCreatedEvent)
|
||||
elif name == "connectionReadyEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionReadyEvent)
|
||||
elif name == "connectionClosedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionClosedEvent)
|
||||
elif name == "connectionCheckOutStartedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutStartedEvent)
|
||||
elif name == "connectionCheckOutFailedEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckOutFailedEvent)
|
||||
elif name == "connectionCheckedOutEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckedOutEvent)
|
||||
elif name == "connectionCheckedInEvent":
|
||||
self.test.assertIsInstance(actual, ConnectionCheckedInEvent)
|
||||
elif name == "serverDescriptionChangedEvent":
|
||||
self.test.assertIsInstance(actual, ServerDescriptionChangedEvent)
|
||||
elif name == "serverHeartbeatStartedEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatStartedEvent)
|
||||
elif name == "serverHeartbeatSucceededEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatSucceededEvent)
|
||||
elif name == "serverHeartbeatFailedEvent":
|
||||
self.test.assertIsInstance(actual, ServerHeartbeatFailedEvent)
|
||||
elif name == "topologyDescriptionChangedEvent":
|
||||
self.test.assertIsInstance(actual, TopologyDescriptionChangedEvent)
|
||||
elif name == "topologyOpeningEvent":
|
||||
self.test.assertIsInstance(actual, TopologyOpenedEvent)
|
||||
elif name == "topologyClosedEvent":
|
||||
self.test.assertIsInstance(actual, TopologyClosedEvent)
|
||||
else:
|
||||
raise Exception(f"Unsupported event type {name}")
|
||||
|
||||
self.match_event_fields(actual, spec)
|
||||
|
||||
|
||||
def coerce_result(opname, result):
|
||||
"""Convert a pymongo result into the spec's result format."""
|
||||
if hasattr(result, "acknowledged") and not result.acknowledged:
|
||||
return {"acknowledged": False}
|
||||
if opname == "bulkWrite":
|
||||
return parse_bulk_write_result(result)
|
||||
if opname == "clientBulkWrite":
|
||||
return parse_client_bulk_write_result(result)
|
||||
if opname == "insertOne":
|
||||
return {"insertedId": result.inserted_id}
|
||||
if opname == "insertMany":
|
||||
return dict(enumerate(result.inserted_ids))
|
||||
if opname in ("deleteOne", "deleteMany"):
|
||||
return {"deletedCount": result.deleted_count}
|
||||
if opname in ("updateOne", "updateMany", "replaceOne"):
|
||||
value = {
|
||||
"matchedCount": result.matched_count,
|
||||
"modifiedCount": result.modified_count,
|
||||
"upsertedCount": 0 if result.upserted_id is None else 1,
|
||||
}
|
||||
if result.upserted_id is not None:
|
||||
value["upsertedId"] = result.upserted_id
|
||||
return value
|
||||
return result
|
||||
147
test/utils.py
147
test/utils.py
@ -418,153 +418,6 @@ class FunctionCallRecorder:
|
||||
return len(self._call_list)
|
||||
|
||||
|
||||
class SpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
def valid_topology(run_on_req):
|
||||
return client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return client_context.auth_enabled
|
||||
return not client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
return client_context._require(
|
||||
lambda: self.should_run_on(scenario_def), "runOn not satisfied", method
|
||||
)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
def create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
|
||||
def ensure_all_connected(client: MongoClient) -> None:
|
||||
"""Ensure that the client's connection pool has socket connections to all
|
||||
members of a replica set. Raises ConfigurationError when called with a
|
||||
|
||||
@ -15,8 +15,12 @@
|
||||
"""Utilities for testing driver specs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test import IntegrationTest, client_context, client_knobs
|
||||
from test.utils import (
|
||||
@ -24,6 +28,7 @@ from test.utils import (
|
||||
CompareType,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
ScenarioDict,
|
||||
ServerAndTopologyEventListener,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
@ -32,11 +37,12 @@ from test.utils import (
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from bson import ObjectId, decode, encode
|
||||
from bson import ObjectId, decode, encode, json_util
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from gridfs import GridFSBucket
|
||||
from gridfs.synchronous.grid_file import GridFSBucket
|
||||
from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
|
||||
self.stop()
|
||||
|
||||
|
||||
class SpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
def valid_topology(run_on_req):
|
||||
return client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return client_context.auth_enabled
|
||||
return not client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
|
||||
def predicate():
|
||||
return self.should_run_on(scenario_def)
|
||||
|
||||
return client_context._require(predicate, "runOn not satisfied", method)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
def _create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
def create_tests(self):
|
||||
if _IS_SYNC:
|
||||
self._create_tests()
|
||||
else:
|
||||
asyncio.run(self._create_tests())
|
||||
|
||||
|
||||
class SpecRunner(IntegrationTest):
|
||||
mongos_clients: List
|
||||
knobs: client_knobs
|
||||
@ -312,7 +473,10 @@ class SpecRunner(IntegrationTest):
|
||||
args.update(arguments)
|
||||
arguments = args
|
||||
|
||||
result = cmd(**dict(arguments))
|
||||
if not _IS_SYNC and iscoroutinefunction(cmd):
|
||||
result = cmd(**dict(arguments))
|
||||
else:
|
||||
result = cmd(**dict(arguments))
|
||||
# Cleanup open change stream cursors.
|
||||
if name == "watch":
|
||||
self.addCleanup(result.close)
|
||||
@ -583,7 +747,7 @@ class SpecRunner(IntegrationTest):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
read_concern=ReadConcern("local"),
|
||||
)
|
||||
actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list()
|
||||
actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list()
|
||||
|
||||
# The expected data needs to be the left hand side here otherwise
|
||||
# CompareType(Binary) doesn't work.
|
||||
|
||||
@ -105,6 +105,8 @@ replacements = {
|
||||
"PyMongo|c|async": "PyMongo|c",
|
||||
"AsyncTestGridFile": "TestGridFile",
|
||||
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
|
||||
"AsyncTestSpec": "TestSpec",
|
||||
"AsyncSpecTestCreator": "SpecTestCreator",
|
||||
"async_set_fail_point": "set_fail_point",
|
||||
"async_ensure_all_connected": "ensure_all_connected",
|
||||
"async_repl_set_step_down": "repl_set_step_down",
|
||||
@ -188,8 +190,14 @@ converted_tests = [
|
||||
"test_client.py",
|
||||
"test_client_bulk_write.py",
|
||||
"test_client_context.py",
|
||||
"test_collation.py",
|
||||
"test_collection.py",
|
||||
"test_command_logging.py",
|
||||
"test_command_monitoring.py",
|
||||
"test_common.py",
|
||||
"test_connection_logging.py",
|
||||
"test_connections_survive_primary_stepdown_spec.py",
|
||||
"test_crud_unified.py",
|
||||
"test_cursor.py",
|
||||
"test_database.py",
|
||||
"test_encryption.py",
|
||||
@ -201,6 +209,7 @@ converted_tests = [
|
||||
"test_retryable_writes.py",
|
||||
"test_session.py",
|
||||
"test_transactions.py",
|
||||
"unified_format.py",
|
||||
]
|
||||
|
||||
sync_test_files = [
|
||||
|
||||
Loading…
Reference in New Issue
Block a user