PYTHON-4700 - Convert CSFLE tests to async (#1907)
This commit is contained in:
parent
8118aea985
commit
3a662291e0
@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
|
||||
# Use --capture=tee-sys so pytest prints test output inline:
|
||||
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
|
||||
if [ -z "$TEST_SUITES" ]; then
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
|
||||
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
|
||||
else
|
||||
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
|
||||
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
|
||||
fi
|
||||
else
|
||||
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
|
||||
|
||||
@ -180,10 +180,20 @@ class _EncryptionIO(AsyncMongoCryptCallback): # type: ignore[misc]
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
async_receive_data_socket,
|
||||
)
|
||||
|
||||
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
# Async raises an OSError instead of returning empty bytes
|
||||
except OSError as err:
|
||||
raise OSError("KMS connection closed") from err
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
finally:
|
||||
|
||||
@ -130,7 +130,7 @@ if sys.platform != "win32":
|
||||
loop.remove_writer(fd)
|
||||
|
||||
async def _async_receive_ssl(
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop
|
||||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
|
||||
) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
total_read = 0
|
||||
@ -145,6 +145,9 @@ if sys.platform != "win32":
|
||||
read = conn.recv_into(mv[total_read:])
|
||||
if read == 0:
|
||||
raise OSError("connection closed")
|
||||
# KMS responses update their expected size after the first batch, stop reading after one loop
|
||||
if once:
|
||||
return mv[:read]
|
||||
total_read += read
|
||||
except BLOCKING_IO_ERRORS as exc:
|
||||
fd = conn.fileno()
|
||||
@ -275,6 +278,28 @@ async def async_receive_data(
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def async_receive_data_socket(
|
||||
sock: Union[socket.socket, _sslConn], length: int
|
||||
) -> memoryview:
|
||||
sock_timeout = sock.gettimeout()
|
||||
timeout = sock_timeout
|
||||
|
||||
sock.settimeout(0.0)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
|
||||
return await asyncio.wait_for(
|
||||
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
|
||||
except asyncio.TimeoutError as err:
|
||||
raise socket.timeout("timed out") from err
|
||||
finally:
|
||||
sock.settimeout(sock_timeout)
|
||||
|
||||
|
||||
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
|
||||
mv = memoryview(bytearray(length))
|
||||
bytes_read = 0
|
||||
|
||||
@ -180,10 +180,20 @@ class _EncryptionIO(MongoCryptCallback): # type: ignore[misc]
|
||||
while kms_context.bytes_needed > 0:
|
||||
# CSOT: update timeout.
|
||||
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
if _IS_SYNC:
|
||||
data = conn.recv(kms_context.bytes_needed)
|
||||
else:
|
||||
from pymongo.network_layer import ( # type: ignore[attr-defined]
|
||||
receive_data_socket,
|
||||
)
|
||||
|
||||
data = receive_data_socket(conn, kms_context.bytes_needed)
|
||||
if not data:
|
||||
raise OSError("KMS connection closed")
|
||||
kms_context.feed(data)
|
||||
# Async raises an OSError instead of returning empty bytes
|
||||
except OSError as err:
|
||||
raise OSError("KMS connection closed") from err
|
||||
except BLOCKING_IO_ERRORS:
|
||||
raise socket.timeout("timed out") from None
|
||||
finally:
|
||||
|
||||
@ -464,11 +464,12 @@ class ClientContext:
|
||||
if not self.connected:
|
||||
pair = self.pair
|
||||
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
||||
if iscoroutinefunction(condition) and condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
if iscoroutinefunction(condition):
|
||||
if condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
elif condition():
|
||||
if wraps_async:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@ -466,11 +466,12 @@ class AsyncClientContext:
|
||||
if not self.connected:
|
||||
pair = await self.pair
|
||||
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
|
||||
if iscoroutinefunction(condition) and await condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
if iscoroutinefunction(condition):
|
||||
if await condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
else:
|
||||
return f(*args, **kwargs)
|
||||
elif condition():
|
||||
if wraps_async:
|
||||
return await f(*args, **kwargs)
|
||||
|
||||
@ -30,6 +30,7 @@ import uuid
|
||||
import warnings
|
||||
from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context
|
||||
from test.asynchronous.test_bulk import AsyncBulkTestBase
|
||||
from test.asynchronous.utils_spec_runner import AsyncSpecRunner, AsyncSpecTestCreator
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
@ -59,7 +60,6 @@ from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
TopologyEventListener,
|
||||
async_wait_until,
|
||||
camel_to_snake_args,
|
||||
@ -626,137 +626,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
|
||||
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
# TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(AsyncSpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
class AsyncTestSpec(AsyncSpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
async def _setup_class(cls):
|
||||
await super()._setup_class()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
|
||||
return super().parse_client_options(opts)
|
||||
return super().parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and not _IS_SYNC
|
||||
):
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)[
|
||||
"datakeys"
|
||||
]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
async def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = async_client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
await coll.delete_many({})
|
||||
if key_vault_data:
|
||||
await coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = async_client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
await db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
await db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
await coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@async_client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
async def create_test(scenario_def, test, name):
|
||||
@async_client_context.require_test_commands
|
||||
async def run_scenario(self):
|
||||
await self.run_scenario(scenario_def, test)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
return run_scenario
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
|
||||
test_creator = AsyncSpecTestCreator(create_test, AsyncTestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
# Prose Tests
|
||||
ALL_KMS_PROVIDERS = {
|
||||
|
||||
@ -15,8 +15,12 @@
|
||||
"""Utilities for testing driver specs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test.asynchronous import AsyncIntegrationTest, async_client_context, client_knobs
|
||||
from test.utils import (
|
||||
@ -24,6 +28,7 @@ from test.utils import (
|
||||
CompareType,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
ScenarioDict,
|
||||
ServerAndTopologyEventListener,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
@ -32,11 +37,12 @@ from test.utils import (
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from bson import ObjectId, decode, encode
|
||||
from bson import ObjectId, decode, encode, json_util
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from gridfs import GridFSBucket
|
||||
from gridfs.asynchronous.grid_file import AsyncGridFSBucket
|
||||
from pymongo.asynchronous import client_session
|
||||
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
|
||||
from pymongo.asynchronous.cursor import AsyncCursor
|
||||
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
|
||||
self.stop()
|
||||
|
||||
|
||||
class AsyncSpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = async_client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = async_client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = async_client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not async_client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
async def valid_topology(run_on_req):
|
||||
return await async_client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return async_client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return async_client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return async_client_context.auth_enabled
|
||||
return not async_client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return async_client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not async_client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
async def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
await self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
|
||||
async def predicate():
|
||||
return await self.should_run_on(scenario_def)
|
||||
|
||||
return async_client_context._require(predicate, "runOn not satisfied", method)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
async def _create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = await self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
def create_tests(self):
|
||||
if _IS_SYNC:
|
||||
self._create_tests()
|
||||
else:
|
||||
asyncio.run(self._create_tests())
|
||||
|
||||
|
||||
class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
mongos_clients: List
|
||||
knobs: client_knobs
|
||||
@ -284,7 +445,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
if object_name == "gridfsbucket":
|
||||
# Only create the GridFSBucket when we need it (for the gridfs
|
||||
# retryable reads tests).
|
||||
obj = GridFSBucket(database, bucket_name=collection.name)
|
||||
obj = AsyncGridFSBucket(database, bucket_name=collection.name)
|
||||
else:
|
||||
objects = {
|
||||
"client": database.client,
|
||||
@ -312,7 +473,10 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
args.update(arguments)
|
||||
arguments = args
|
||||
|
||||
result = cmd(**dict(arguments))
|
||||
if not _IS_SYNC and iscoroutinefunction(cmd):
|
||||
result = await cmd(**dict(arguments))
|
||||
else:
|
||||
result = cmd(**dict(arguments))
|
||||
# Cleanup open change stream cursors.
|
||||
if name == "watch":
|
||||
self.addAsyncCleanup(result.close)
|
||||
@ -588,7 +752,7 @@ class AsyncSpecRunner(AsyncIntegrationTest):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
read_concern=ReadConcern("local"),
|
||||
)
|
||||
actual_data = await (await outcome_coll.find(sort=[("_id", 1)])).to_list()
|
||||
actual_data = await outcome_coll.find(sort=[("_id", 1)]).to_list()
|
||||
|
||||
# The expected data needs to be the left hand side here otherwise
|
||||
# CompareType(Binary) doesn't work.
|
||||
|
||||
@ -110,7 +110,7 @@
|
||||
"listCollections"
|
||||
],
|
||||
"blockConnection": true,
|
||||
"blockTimeMS": 60
|
||||
"blockTimeMS": 600
|
||||
}
|
||||
},
|
||||
"clientOptions": {
|
||||
@ -119,7 +119,7 @@
|
||||
"aws": {}
|
||||
}
|
||||
},
|
||||
"timeoutMS": 50
|
||||
"timeoutMS": 500
|
||||
},
|
||||
"operations": [
|
||||
{
|
||||
|
||||
@ -25,14 +25,13 @@ from test import IntegrationTest, client_knobs, unittest
|
||||
from test.pymongo_mocks import DummyMonitor
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
SpecTestCreator,
|
||||
camel_to_snake,
|
||||
client_context,
|
||||
get_pool,
|
||||
get_pools,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_spec_runner import SpecRunnerThread
|
||||
from test.utils_spec_runner import SpecRunnerThread, SpecTestCreator
|
||||
|
||||
from bson.objectid import ObjectId
|
||||
from bson.son import SON
|
||||
|
||||
@ -30,6 +30,7 @@ import uuid
|
||||
import warnings
|
||||
from test import IntegrationTest, PyMongoTestCase, client_context
|
||||
from test.test_bulk import BulkTestBase
|
||||
from test.utils_spec_runner import SpecRunner, SpecTestCreator
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, Mapping, Optional
|
||||
|
||||
@ -58,7 +59,6 @@ from test.unified_format import generate_test_classes
|
||||
from test.utils import (
|
||||
AllowListEventListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
TopologyEventListener,
|
||||
camel_to_snake_args,
|
||||
is_greenthread_patched,
|
||||
@ -624,135 +624,132 @@ AWS_TEMP_NO_SESSION_CREDS = {
|
||||
KMS_TLS_OPTS = {"kmip": {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM}}
|
||||
|
||||
|
||||
if _IS_SYNC:
|
||||
# TODO: Add synchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700)
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
class TestSpec(SpecRunner):
|
||||
@classmethod
|
||||
@unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed")
|
||||
def _setup_class(cls):
|
||||
super()._setup_class()
|
||||
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
def parse_auto_encrypt_opts(self, opts):
|
||||
"""Parse clientOptions.autoEncryptOpts."""
|
||||
opts = camel_to_snake_args(opts)
|
||||
kms_providers = opts["kms_providers"]
|
||||
if "aws" in kms_providers:
|
||||
kms_providers["aws"] = AWS_CREDS
|
||||
if not any(AWS_CREDS.values()):
|
||||
self.skipTest("AWS environment credentials are not set")
|
||||
if "awsTemporary" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_CREDS
|
||||
del kms_providers["awsTemporary"]
|
||||
if not any(AWS_TEMP_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "awsTemporaryNoSessionToken" in kms_providers:
|
||||
kms_providers["aws"] = AWS_TEMP_NO_SESSION_CREDS
|
||||
del kms_providers["awsTemporaryNoSessionToken"]
|
||||
if not any(AWS_TEMP_NO_SESSION_CREDS.values()):
|
||||
self.skipTest("AWS Temp environment credentials are not set")
|
||||
if "azure" in kms_providers:
|
||||
kms_providers["azure"] = AZURE_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("Azure environment credentials are not set")
|
||||
if "gcp" in kms_providers:
|
||||
kms_providers["gcp"] = GCP_CREDS
|
||||
if not any(AZURE_CREDS.values()):
|
||||
self.skipTest("GCP environment credentials are not set")
|
||||
if "kmip" in kms_providers:
|
||||
kms_providers["kmip"] = KMIP_CREDS
|
||||
opts["kms_tls_options"] = KMS_TLS_OPTS
|
||||
if "key_vault_namespace" not in opts:
|
||||
opts["key_vault_namespace"] = "keyvault.datakeys"
|
||||
if "extra_options" in opts:
|
||||
opts.update(camel_to_snake_args(opts.pop("extra_options")))
|
||||
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
opts = dict(opts)
|
||||
return AutoEncryptionOpts(**opts)
|
||||
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
def parse_client_options(self, opts):
|
||||
"""Override clientOptions parsing to support autoEncryptOpts."""
|
||||
encrypt_opts = opts.pop("autoEncryptOpts", None)
|
||||
if encrypt_opts:
|
||||
opts["auto_encryption_opts"] = self.parse_auto_encrypt_opts(encrypt_opts)
|
||||
|
||||
return super().parse_client_options(opts)
|
||||
return super().parse_client_options(opts)
|
||||
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
def get_object_name(self, op):
|
||||
"""Default object is collection."""
|
||||
return op.get("object", "collection")
|
||||
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and not _IS_SYNC
|
||||
):
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
def maybe_skip_scenario(self, test):
|
||||
super().maybe_skip_scenario(test)
|
||||
desc = test["description"].lower()
|
||||
if (
|
||||
"timeoutms applied to listcollections to get collection schema" in desc
|
||||
and sys.platform in ("win32", "darwin")
|
||||
):
|
||||
self.skipTest("PYTHON-3706 flaky test on Windows/macOS")
|
||||
if "type=symbol" in desc:
|
||||
self.skipTest("PyMongo does not support the symbol type")
|
||||
if "timeoutms applied to listcollections to get collection schema" in desc and not _IS_SYNC:
|
||||
self.skipTest("PYTHON-4844 flaky test on async")
|
||||
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
def setup_scenario(self, scenario_def):
|
||||
"""Override a test's setup."""
|
||||
key_vault_data = scenario_def["key_vault_data"]
|
||||
encrypted_fields = scenario_def["encrypted_fields"]
|
||||
json_schema = scenario_def["json_schema"]
|
||||
data = scenario_def["data"]
|
||||
coll = client_context.client.get_database("keyvault", codec_options=OPTS)["datakeys"]
|
||||
coll.delete_many({})
|
||||
if key_vault_data:
|
||||
coll.insert_many(key_vault_data)
|
||||
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
coll = db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
db_name = self.get_scenario_db_name(scenario_def)
|
||||
coll_name = self.get_scenario_coll_name(scenario_def)
|
||||
db = client_context.client.get_database(db_name, codec_options=OPTS)
|
||||
db.drop_collection(coll_name, encrypted_fields=encrypted_fields)
|
||||
wc = WriteConcern(w="majority")
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if json_schema:
|
||||
kwargs["validator"] = {"$jsonSchema": json_schema}
|
||||
kwargs["codec_options"] = OPTS
|
||||
if not data:
|
||||
kwargs["write_concern"] = wc
|
||||
if encrypted_fields:
|
||||
kwargs["encryptedFields"] = encrypted_fields
|
||||
db.create_collection(coll_name, **kwargs)
|
||||
coll = db[coll_name]
|
||||
if data:
|
||||
# Load data.
|
||||
coll.with_options(write_concern=wc).insert_many(scenario_def["data"])
|
||||
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
def allowable_errors(self, op):
|
||||
"""Override expected error classes."""
|
||||
errors = super().allowable_errors(op)
|
||||
# An updateOne test expects encryption to error when no $ operator
|
||||
# appears but pymongo raises a client side ValueError in this case.
|
||||
if op["name"] == "updateOne":
|
||||
errors += (ValueError,)
|
||||
return errors
|
||||
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
return run_scenario
|
||||
def create_test(scenario_def, test, name):
|
||||
@client_context.require_test_commands
|
||||
def run_scenario(self):
|
||||
self.run_scenario(scenario_def, test)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
return run_scenario
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
|
||||
test_creator = SpecTestCreator(create_test, TestSpec, os.path.join(SPEC_PATH, "legacy"))
|
||||
test_creator.create_tests()
|
||||
|
||||
if _HAVE_PYMONGOCRYPT:
|
||||
globals().update(
|
||||
generate_test_classes(
|
||||
os.path.join(SPEC_PATH, "unified"),
|
||||
module=__name__,
|
||||
)
|
||||
)
|
||||
|
||||
# Prose Tests
|
||||
ALL_KMS_PROVIDERS = {
|
||||
|
||||
@ -21,11 +21,11 @@ from test import IntegrationTest, client_context, unittest
|
||||
from test.utils import (
|
||||
CMAPListener,
|
||||
OvertCommandListener,
|
||||
SpecTestCreator,
|
||||
get_pool,
|
||||
wait_until,
|
||||
)
|
||||
from test.utils_selection_tests import create_topology
|
||||
from test.utils_spec_runner import SpecTestCreator
|
||||
|
||||
from pymongo.common import clean_node
|
||||
from pymongo.monitoring import ConnectionReadyEvent
|
||||
|
||||
147
test/utils.py
147
test/utils.py
@ -418,153 +418,6 @@ class FunctionCallRecorder:
|
||||
return len(self._call_list)
|
||||
|
||||
|
||||
class SpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
def valid_topology(run_on_req):
|
||||
return client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return client_context.auth_enabled
|
||||
return not client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
return client_context._require(
|
||||
lambda: self.should_run_on(scenario_def), "runOn not satisfied", method
|
||||
)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
def create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream:
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
|
||||
def ensure_all_connected(client: MongoClient) -> None:
|
||||
"""Ensure that the client's connection pool has socket connections to all
|
||||
members of a replica set. Raises ConfigurationError when called with a
|
||||
|
||||
@ -15,8 +15,12 @@
|
||||
"""Utilities for testing driver specs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
import threading
|
||||
import unittest
|
||||
from asyncio import iscoroutinefunction
|
||||
from collections import abc
|
||||
from test import IntegrationTest, client_context, client_knobs
|
||||
from test.utils import (
|
||||
@ -24,6 +28,7 @@ from test.utils import (
|
||||
CompareType,
|
||||
EventListener,
|
||||
OvertCommandListener,
|
||||
ScenarioDict,
|
||||
ServerAndTopologyEventListener,
|
||||
camel_to_snake,
|
||||
camel_to_snake_args,
|
||||
@ -32,11 +37,12 @@ from test.utils import (
|
||||
)
|
||||
from typing import List
|
||||
|
||||
from bson import ObjectId, decode, encode
|
||||
from bson import ObjectId, decode, encode, json_util
|
||||
from bson.binary import Binary
|
||||
from bson.int64 import Int64
|
||||
from bson.son import SON
|
||||
from gridfs import GridFSBucket
|
||||
from gridfs.synchronous.grid_file import GridFSBucket
|
||||
from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
|
||||
from pymongo.read_concern import ReadConcern
|
||||
from pymongo.read_preferences import ReadPreference
|
||||
@ -83,6 +89,161 @@ class SpecRunnerThread(threading.Thread):
|
||||
self.stop()
|
||||
|
||||
|
||||
class SpecTestCreator:
|
||||
"""Class to create test cases from specifications."""
|
||||
|
||||
def __init__(self, create_test, test_class, test_path):
|
||||
"""Create a TestCreator object.
|
||||
|
||||
:Parameters:
|
||||
- `create_test`: callback that returns a test case. The callback
|
||||
must accept the following arguments - a dictionary containing the
|
||||
entire test specification (the `scenario_def`), a dictionary
|
||||
containing the specification for which the test case will be
|
||||
generated (the `test_def`).
|
||||
- `test_class`: the unittest.TestCase class in which to create the
|
||||
test case.
|
||||
- `test_path`: path to the directory containing the JSON files with
|
||||
the test specifications.
|
||||
"""
|
||||
self._create_test = create_test
|
||||
self._test_class = test_class
|
||||
self.test_path = test_path
|
||||
|
||||
def _ensure_min_max_server_version(self, scenario_def, method):
|
||||
"""Test modifier that enforces a version range for the server on a
|
||||
test case.
|
||||
"""
|
||||
if "minServerVersion" in scenario_def:
|
||||
min_ver = tuple(int(elt) for elt in scenario_def["minServerVersion"].split("."))
|
||||
if min_ver is not None:
|
||||
method = client_context.require_version_min(*min_ver)(method)
|
||||
|
||||
if "maxServerVersion" in scenario_def:
|
||||
max_ver = tuple(int(elt) for elt in scenario_def["maxServerVersion"].split("."))
|
||||
if max_ver is not None:
|
||||
method = client_context.require_version_max(*max_ver)(method)
|
||||
|
||||
if "serverless" in scenario_def:
|
||||
serverless = scenario_def["serverless"]
|
||||
if serverless == "require":
|
||||
serverless_satisfied = client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
serverless_satisfied = not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
serverless_satisfied = True
|
||||
method = unittest.skipUnless(
|
||||
serverless_satisfied, "Serverless requirement not satisfied"
|
||||
)(method)
|
||||
|
||||
return method
|
||||
|
||||
@staticmethod
|
||||
def valid_topology(run_on_req):
|
||||
return client_context.is_topology_type(
|
||||
run_on_req.get("topology", ["single", "replicaset", "sharded", "load-balanced"])
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def min_server_version(run_on_req):
|
||||
version = run_on_req.get("minServerVersion")
|
||||
if version:
|
||||
min_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version >= min_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def max_server_version(run_on_req):
|
||||
version = run_on_req.get("maxServerVersion")
|
||||
if version:
|
||||
max_ver = tuple(int(elt) for elt in version.split("."))
|
||||
return client_context.version <= max_ver
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def valid_auth_enabled(run_on_req):
|
||||
if "authEnabled" in run_on_req:
|
||||
if run_on_req["authEnabled"]:
|
||||
return client_context.auth_enabled
|
||||
return not client_context.auth_enabled
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def serverless_ok(run_on_req):
|
||||
serverless = run_on_req["serverless"]
|
||||
if serverless == "require":
|
||||
return client_context.serverless
|
||||
elif serverless == "forbid":
|
||||
return not client_context.serverless
|
||||
else: # unset or "allow"
|
||||
return True
|
||||
|
||||
def should_run_on(self, scenario_def):
|
||||
run_on = scenario_def.get("runOn", [])
|
||||
if not run_on:
|
||||
# Always run these tests.
|
||||
return True
|
||||
|
||||
for req in run_on:
|
||||
if (
|
||||
self.valid_topology(req)
|
||||
and self.min_server_version(req)
|
||||
and self.max_server_version(req)
|
||||
and self.valid_auth_enabled(req)
|
||||
and self.serverless_ok(req)
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def ensure_run_on(self, scenario_def, method):
|
||||
"""Test modifier that enforces a 'runOn' on a test case."""
|
||||
|
||||
def predicate():
|
||||
return self.should_run_on(scenario_def)
|
||||
|
||||
return client_context._require(predicate, "runOn not satisfied", method)
|
||||
|
||||
def tests(self, scenario_def):
|
||||
"""Allow CMAP spec test to override the location of test."""
|
||||
return scenario_def["tests"]
|
||||
|
||||
def _create_tests(self):
|
||||
for dirpath, _, filenames in os.walk(self.test_path):
|
||||
dirname = os.path.split(dirpath)[-1]
|
||||
|
||||
for filename in filenames:
|
||||
with open(os.path.join(dirpath, filename)) as scenario_stream: # noqa: ASYNC101, RUF100
|
||||
# Use tz_aware=False to match how CodecOptions decodes
|
||||
# dates.
|
||||
opts = json_util.JSONOptions(tz_aware=False)
|
||||
scenario_def = ScenarioDict(
|
||||
json_util.loads(scenario_stream.read(), json_options=opts)
|
||||
)
|
||||
|
||||
test_type = os.path.splitext(filename)[0]
|
||||
|
||||
# Construct test from scenario.
|
||||
for test_def in self.tests(scenario_def):
|
||||
test_name = "test_{}_{}_{}".format(
|
||||
dirname,
|
||||
test_type.replace("-", "_").replace(".", "_"),
|
||||
str(test_def["description"].replace(" ", "_").replace(".", "_")),
|
||||
)
|
||||
|
||||
new_test = self._create_test(scenario_def, test_def, test_name)
|
||||
new_test = self._ensure_min_max_server_version(scenario_def, new_test)
|
||||
new_test = self.ensure_run_on(scenario_def, new_test)
|
||||
|
||||
new_test.__name__ = test_name
|
||||
setattr(self._test_class, new_test.__name__, new_test)
|
||||
|
||||
def create_tests(self):
|
||||
if _IS_SYNC:
|
||||
self._create_tests()
|
||||
else:
|
||||
asyncio.run(self._create_tests())
|
||||
|
||||
|
||||
class SpecRunner(IntegrationTest):
|
||||
mongos_clients: List
|
||||
knobs: client_knobs
|
||||
@ -312,7 +473,10 @@ class SpecRunner(IntegrationTest):
|
||||
args.update(arguments)
|
||||
arguments = args
|
||||
|
||||
result = cmd(**dict(arguments))
|
||||
if not _IS_SYNC and iscoroutinefunction(cmd):
|
||||
result = cmd(**dict(arguments))
|
||||
else:
|
||||
result = cmd(**dict(arguments))
|
||||
# Cleanup open change stream cursors.
|
||||
if name == "watch":
|
||||
self.addCleanup(result.close)
|
||||
@ -583,7 +747,7 @@ class SpecRunner(IntegrationTest):
|
||||
read_preference=ReadPreference.PRIMARY,
|
||||
read_concern=ReadConcern("local"),
|
||||
)
|
||||
actual_data = (outcome_coll.find(sort=[("_id", 1)])).to_list()
|
||||
actual_data = outcome_coll.find(sort=[("_id", 1)]).to_list()
|
||||
|
||||
# The expected data needs to be the left hand side here otherwise
|
||||
# CompareType(Binary) doesn't work.
|
||||
|
||||
@ -105,6 +105,8 @@ replacements = {
|
||||
"PyMongo|c|async": "PyMongo|c",
|
||||
"AsyncTestGridFile": "TestGridFile",
|
||||
"AsyncTestGridFileNoConnect": "TestGridFileNoConnect",
|
||||
"AsyncTestSpec": "TestSpec",
|
||||
"AsyncSpecTestCreator": "SpecTestCreator",
|
||||
"async_set_fail_point": "set_fail_point",
|
||||
"async_ensure_all_connected": "ensure_all_connected",
|
||||
"async_repl_set_step_down": "repl_set_step_down",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user