diff --git a/test/asynchronous/test_auth_oidc.py b/test/asynchronous/test_auth_oidc.py index 639c155e7..ff604f55a 100644 --- a/test/asynchronous/test_auth_oidc.py +++ b/test/asynchronous/test_auth_oidc.py @@ -30,7 +30,7 @@ import pytest sys.path[0:0] = [""] -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import EventListener, OvertCommandListener from bson import SON @@ -54,14 +54,13 @@ from pymongo.synchronous.uri_parser import parse_uri _IS_SYNC = False ROOT = Path(__file__).parent.parent.resolve() -TEST_PATH = ROOT / "auth" / "unified" ENVIRON = os.environ.get("OIDC_ENV", "test") DOMAIN = os.environ.get("OIDC_DOMAIN", "") TOKEN_DIR = os.environ.get("OIDC_TOKEN_DIR", "") TOKEN_FILE = os.environ.get("OIDC_TOKEN_FILE", "") # Generate unified tests. -globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) +globals().update(generate_test_classes(get_test_path("auth", "unified"), module=__name__)) pytestmark = pytest.mark.auth_oidc diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 7c659c6d9..a40687348 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -27,7 +27,7 @@ import pytest sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from pymongo import AsyncMongoClient from pymongo.auth_oidc_shared import OIDCCallback @@ -35,8 +35,7 @@ from pymongo.auth_oidc_shared import OIDCCallback pytestmark = pytest.mark.auth _IS_SYNC = False - -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") +_TEST_PATH = get_test_path("auth") class TestAuthSpec(AsyncPyMongoTestCase): diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 89a215f14..3a34319ea 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -35,7 +35,7 @@ from test.asynchronous import ( async_client_context, unittest, ) -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, EventListener, @@ -1143,12 +1143,9 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): self.listener.reset() -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "change_streams") - - globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + get_test_path("change_streams", "unified"), module=__name__, ) ) diff --git a/test/asynchronous/test_client_metadata.py b/test/asynchronous/test_client_metadata.py index 2f175ccee..45c1bd1b3 100644 --- a/test/asynchronous/test_client_metadata.py +++ b/test/asynchronous/test_client_metadata.py @@ -19,7 +19,7 @@ import pathlib import time import unittest from test.asynchronous import AsyncIntegrationTest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import CMAPListener from typing import Any, Optional @@ -40,16 +40,8 @@ pytestmark = pytest.mark.mockupdb _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "handshake", "unified") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "handshake", "unified" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("handshake", "unified"), module=__name__)) def _get_handshake_driver_info(request): diff --git a/test/asynchronous/test_collection_management.py b/test/asynchronous/test_collection_management.py index c0edf9158..7a142dc65 100644 --- a/test/asynchronous/test_collection_management.py +++ b/test/asynchronous/test_collection_management.py @@ -22,20 +22,12 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "collection_management" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("collection_management"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_command_logging.py b/test/asynchronous/test_command_logging.py index f9b459c15..831dd0c10 100644 --- a/test/asynchronous/test_command_logging.py +++ b/test/asynchronous/test_command_logging.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_command_monitoring.py b/test/asynchronous/test_command_monitoring.py index 311fd1fdc..a04ba449a 100644 --- a/test/asynchronous/test_command_monitoring.py +++ b/test/asynchronous/test_command_monitoring.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_monitoring"), module=__name__, ) ) diff --git a/test/asynchronous/test_connection_logging.py b/test/asynchronous/test_connection_logging.py index 945c6c59b..4d03391dd 100644 --- a/test/asynchronous/test_connection_logging.py +++ b/test/asynchronous/test_connection_logging.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("connection_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_crud_unified.py b/test/asynchronous/test_crud_unified.py index 8b1f9b8e3..94e47a26e 100644 --- a/test/asynchronous/test_crud_unified.py +++ b/test/asynchronous/test_crud_unified.py @@ -22,18 +22,12 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("crud", "unified"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_csot.py b/test/asynchronous/test_csot.py index a978d1ccc..547ee20a5 100644 --- a/test/asynchronous/test_csot.py +++ b/test/asynchronous/test_csot.py @@ -22,7 +22,7 @@ from pathlib import Path sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.asynchronous.utils import flaky import pymongo @@ -31,14 +31,8 @@ from pymongo.errors import PyMongoError _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("csot"), module=__name__)) class TestCSOT(AsyncIntegrationTest): diff --git a/test/asynchronous/test_discovery_and_monitoring.py b/test/asynchronous/test_discovery_and_monitoring.py index 5820d00c4..0bbf471d8 100644 --- a/test/asynchronous/test_discovery_and_monitoring.py +++ b/test/asynchronous/test_discovery_and_monitoring.py @@ -40,7 +40,7 @@ from test.asynchronous import ( unittest, ) from test.asynchronous.pymongo_mocks import DummyMonitor -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.asynchronous.utils import ( async_get_pool, ) @@ -76,14 +76,7 @@ from pymongo.topology_description import TOPOLOGY_TYPE _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") -else: - SDAM_PATH = os.path.join( - Path(__file__).resolve().parent.parent, - "discovery_and_monitoring", - ) +SDAM_PATH = get_test_path("discovery_and_monitoring") async def create_mock_topology(uri, monitor_class=DummyMonitor): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 9cd0944ce..3f71b35e3 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -53,7 +53,7 @@ from test import ( unittest, ) from test.asynchronous.test_bulk import AsyncBulkTestBase -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.helpers_shared import ( ALL_KMS_PROVIDERS, AWS_CREDS, @@ -273,11 +273,7 @@ class AsyncEncryptionIntegrationTest(AsyncIntegrationTest): # Location of JSON test files. -if _IS_SYNC: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent, "client-side-encryption") -else: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-side-encryption") - +BASE = get_test_path("client-side-encryption") SPEC_PATH = os.path.join(BASE, "spec") OPTS = CodecOptions() diff --git a/test/asynchronous/test_gridfs_spec.py b/test/asynchronous/test_gridfs_spec.py index f3dc14fbd..ab1c8a0eb 100644 --- a/test/asynchronous/test_gridfs_spec.py +++ b/test/asynchronous/test_gridfs_spec.py @@ -22,18 +22,12 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("gridfs"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_index_management.py b/test/asynchronous/test_index_management.py index 890788fc5..ac096ec09 100644 --- a/test/asynchronous/test_index_management.py +++ b/test/asynchronous/test_index_management.py @@ -28,7 +28,7 @@ import pytest sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure @@ -40,12 +40,6 @@ _IS_SYNC = False pytestmark = pytest.mark.search_index -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") - _NAME = "test-search-index" @@ -370,7 +364,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("index_management"), module=__name__, ) ) diff --git a/test/asynchronous/test_load_balancer.py b/test/asynchronous/test_load_balancer.py index 17d85841f..8e1ee3e79 100644 --- a/test/asynchronous/test_load_balancer.py +++ b/test/asynchronous/test_load_balancer.py @@ -30,7 +30,7 @@ import pytest sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( async_wait_until, create_async_event, @@ -40,14 +40,8 @@ _IS_SYNC = False pytestmark = pytest.mark.load_balancer -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("load_balancer"), module=__name__)) class TestLB(AsyncIntegrationTest): diff --git a/test/asynchronous/test_read_write_concern_spec.py b/test/asynchronous/test_read_write_concern_spec.py index b5cb32932..2d08de780 100644 --- a/test/asynchronous/test_read_write_concern_spec.py +++ b/test/asynchronous/test_read_write_concern_spec.py @@ -24,7 +24,7 @@ from pathlib import Path sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path from test.utils_shared import OvertCommandListener from pymongo import DESCENDING @@ -42,11 +42,7 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") +TEST_PATH = get_test_path("read_write_concern") class TestReadWriteConcernSpec(AsyncIntegrationTest): diff --git a/test/asynchronous/test_retryable_reads_unified.py b/test/asynchronous/test_retryable_reads_unified.py index e62d60681..3de8aa96a 100644 --- a/test/asynchronous/test_retryable_reads_unified.py +++ b/test/asynchronous/test_retryable_reads_unified.py @@ -22,21 +22,15 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") - # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. globals().update( generate_test_classes( - TEST_PATH, + get_test_path("retryable_reads", "unified"), module=__name__, expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], ) diff --git a/test/asynchronous/test_retryable_writes_unified.py b/test/asynchronous/test_retryable_writes_unified.py index bb493e601..7d33c5252 100644 --- a/test/asynchronous/test_retryable_writes_unified.py +++ b/test/asynchronous/test_retryable_writes_unified.py @@ -22,18 +22,14 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("retryable_writes", "unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_run_command.py b/test/asynchronous/test_run_command.py index 3ac8c3270..cfd1adfab 100644 --- a/test/asynchronous/test_run_command.py +++ b/test/asynchronous/test_run_command.py @@ -18,20 +18,13 @@ from __future__ import annotations import os import unittest from pathlib import Path -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") - - globals().update( generate_test_classes( - os.path.join(TEST_PATH, "unified"), + get_test_path("run_command", "unified"), module=__name__, ) ) diff --git a/test/asynchronous/test_server_selection_logging.py b/test/asynchronous/test_server_selection_logging.py index 6b0975318..6f3ea207f 100644 --- a/test/asynchronous/test_server_selection_logging.py +++ b/test/asynchronous/test_server_selection_logging.py @@ -22,20 +22,14 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") - globals().update( generate_test_classes( - TEST_PATH, + get_test_path("server_selection_logging"), module=__name__, ) ) diff --git a/test/asynchronous/test_sessions_unified.py b/test/asynchronous/test_sessions_unified.py index b4cbac570..ee2b4d418 100644 --- a/test/asynchronous/test_sessions_unified.py +++ b/test/asynchronous/test_sessions_unified.py @@ -22,19 +22,12 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("sessions"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_transactions_unified.py b/test/asynchronous/test_transactions_unified.py index 8e5b1ae18..5f9b5d022 100644 --- a/test/asynchronous/test_transactions_unified.py +++ b/test/asynchronous/test_transactions_unified.py @@ -22,7 +22,7 @@ from pathlib import Path sys.path[0:0] = [""] from test import client_context, unittest -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path _IS_SYNC = False @@ -31,25 +31,13 @@ def setUpModule(): pass -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") +# Generate unified tests. +globals().update(generate_test_classes(get_test_path("transactions/unified"), module=__name__)) # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) - -# Location of JSON test specifications for transactions-convenient-api. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") -else: - TEST_PATH = os.path.join( - Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" - ) - -# Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("transactions-convenient-api/unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/asynchronous/test_unified_format.py b/test/asynchronous/test_unified_format.py index 58a1ea332..813664123 100644 --- a/test/asynchronous/test_unified_format.py +++ b/test/asynchronous/test_unified_format.py @@ -21,18 +21,18 @@ from typing import Any sys.path[0:0] = [""] from test import UnitTest, unittest -from test.asynchronous.unified_format import MatchEvaluatorUtil, generate_test_classes +from test.asynchronous.unified_format import ( + MatchEvaluatorUtil, + generate_test_classes, + get_test_path, +) from bson import ObjectId _IS_SYNC = False # Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") - +TEST_PATH = get_test_path("unified-test-format") globals().update( generate_test_classes( diff --git a/test/asynchronous/test_versioned_api_integration.py b/test/asynchronous/test_versioned_api_integration.py index 0f6b54446..7228b945a 100644 --- a/test/asynchronous/test_versioned_api_integration.py +++ b/test/asynchronous/test_versioned_api_integration.py @@ -16,7 +16,7 @@ from __future__ import annotations import os import sys from pathlib import Path -from test.asynchronous.unified_format import generate_test_classes +from test.asynchronous.unified_format import generate_test_classes, get_test_path sys.path[0:0] = [""] @@ -27,15 +27,8 @@ from pymongo.server_api import ServerApi _IS_SYNC = False -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("versioned-api"), module=__name__)) class TestServerApiIntegration(AsyncIntegrationTest): diff --git a/test/asynchronous/unified_format.py b/test/asynchronous/unified_format.py index 0c9e8c10c..e30883336 100644 --- a/test/asynchronous/unified_format.py +++ b/test/asynchronous/unified_format.py @@ -29,6 +29,7 @@ import time import traceback from collections import defaultdict from inspect import iscoroutinefunction +from pathlib import Path from test.asynchronous import ( AsyncIntegrationTest, async_client_context, @@ -1564,6 +1565,14 @@ _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { } +def get_test_path(*args): + if _IS_SYNC: + root_dir = Path(__file__).resolve().parent + else: + root_dir = Path(__file__).resolve().parent.parent + return os.path.join(root_dir, *args) + + def generate_test_classes( test_path, module=__name__, @@ -1596,10 +1605,12 @@ def generate_test_classes( return base + found_any = False for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: + found_any = True fpath = os.path.join(dirpath, filename) with open(fpath) as scenario_stream: # Use tz_aware=False to match how CodecOptions decodes @@ -1637,4 +1648,7 @@ def generate_test_classes( continue raise + if not found_any: + raise ValueError(f"No test files found in {test_path}") + return test_klasses diff --git a/test/test_auth_oidc.py b/test/test_auth_oidc.py index 877a5ca98..1defe8200 100644 --- a/test/test_auth_oidc.py +++ b/test/test_auth_oidc.py @@ -30,7 +30,7 @@ import pytest sys.path[0:0] = [""] -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import EventListener, OvertCommandListener from bson import SON @@ -54,14 +54,13 @@ from pymongo.synchronous.uri_parser import parse_uri _IS_SYNC = True ROOT = Path(__file__).parent.parent.resolve() -TEST_PATH = ROOT / "auth" / "unified" ENVIRON = os.environ.get("OIDC_ENV", "test") DOMAIN = os.environ.get("OIDC_DOMAIN", "") TOKEN_DIR = os.environ.get("OIDC_TOKEN_DIR", "") TOKEN_FILE = os.environ.get("OIDC_TOKEN_FILE", "") # Generate unified tests. -globals().update(generate_test_classes(str(TEST_PATH), module=__name__)) +globals().update(generate_test_classes(get_test_path("auth", "unified"), module=__name__)) pytestmark = pytest.mark.auth_oidc diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index ac6411cd8..93c5e7666 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -27,7 +27,7 @@ import pytest sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from pymongo import MongoClient from pymongo.auth_oidc_shared import OIDCCallback @@ -35,8 +35,7 @@ from pymongo.auth_oidc_shared import OIDCCallback pytestmark = pytest.mark.auth _IS_SYNC = True - -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") +_TEST_PATH = get_test_path("auth") class TestAuthSpec(PyMongoTestCase): diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 0b2a9e76b..f1d01458d 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -35,7 +35,7 @@ from test import ( client_context, unittest, ) -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, EventListener, @@ -1123,12 +1123,9 @@ class TestAllLegacyScenarios(IntegrationTest): self.listener.reset() -_TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "change_streams") - - globals().update( generate_test_classes( - os.path.join(_TEST_PATH, "unified"), + get_test_path("change_streams", "unified"), module=__name__, ) ) diff --git a/test/test_client_metadata.py b/test/test_client_metadata.py index a94c5aa25..5f103f739 100644 --- a/test/test_client_metadata.py +++ b/test/test_client_metadata.py @@ -19,7 +19,7 @@ import pathlib import time import unittest from test import IntegrationTest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import CMAPListener from typing import Any, Optional @@ -40,16 +40,8 @@ pytestmark = pytest.mark.mockupdb _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "handshake", "unified") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "handshake", "unified" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("handshake", "unified"), module=__name__)) def _get_handshake_driver_info(request): diff --git a/test/test_collection_management.py b/test/test_collection_management.py index 063c20df8..deb43677a 100644 --- a/test/test_collection_management.py +++ b/test/test_collection_management.py @@ -22,20 +22,12 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "collection_management") -else: - _TEST_PATH = os.path.join( - pathlib.Path(__file__).resolve().parent.parent, "collection_management" - ) - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("collection_management"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_command_logging.py b/test/test_command_logging.py index cf865920c..17bc319d9 100644 --- a/test/test_command_logging.py +++ b/test/test_command_logging.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_logging"), module=__name__, ) ) diff --git a/test/test_command_monitoring.py b/test/test_command_monitoring.py index 4f5ef06f2..eaa2af5ee 100644 --- a/test/test_command_monitoring.py +++ b/test/test_command_monitoring.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "command_monitoring") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "command_monitoring") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("command_monitoring"), module=__name__, ) ) diff --git a/test/test_connection_logging.py b/test/test_connection_logging.py index 253193cc4..9f5da0c43 100644 --- a/test/test_connection_logging.py +++ b/test/test_connection_logging.py @@ -22,20 +22,13 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "connection_logging") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "connection_logging") - - globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("connection_logging"), module=__name__, ) ) diff --git a/test/test_crud_unified.py b/test/test_crud_unified.py index 1b1abf360..45af155bd 100644 --- a/test/test_crud_unified.py +++ b/test/test_crud_unified.py @@ -22,18 +22,12 @@ import sys sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "crud", "unified") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "crud", "unified") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("crud", "unified"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_csot.py b/test/test_csot.py index 981af1ed0..d6dec51d3 100644 --- a/test/test_csot.py +++ b/test/test_csot.py @@ -22,7 +22,7 @@ from pathlib import Path sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils import flaky import pymongo @@ -31,14 +31,8 @@ from pymongo.errors import PyMongoError _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "csot") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "csot") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("csot"), module=__name__)) class TestCSOT(IntegrationTest): diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 67a82996b..8375d63e9 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -40,7 +40,7 @@ from test import ( unittest, ) from test.pymongo_mocks import DummyMonitor -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils import ( get_pool, ) @@ -76,14 +76,7 @@ from pymongo.topology_description import TOPOLOGY_TYPE _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - SDAM_PATH = os.path.join(Path(__file__).resolve().parent, "discovery_and_monitoring") -else: - SDAM_PATH = os.path.join( - Path(__file__).resolve().parent.parent, - "discovery_and_monitoring", - ) +SDAM_PATH = get_test_path("discovery_and_monitoring") def create_mock_topology(uri, monitor_class=DummyMonitor): diff --git a/test/test_encryption.py b/test/test_encryption.py index b0f046d63..4da307689 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -65,7 +65,7 @@ from test.helpers_shared import ( LOCAL_MASTER_KEY, ) from test.test_bulk import BulkTestBase -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( AllowListEventListener, OvertCommandListener, @@ -273,11 +273,7 @@ class EncryptionIntegrationTest(IntegrationTest): # Location of JSON test files. -if _IS_SYNC: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent, "client-side-encryption") -else: - BASE = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "client-side-encryption") - +BASE = get_test_path("client-side-encryption") SPEC_PATH = os.path.join(BASE, "spec") OPTS = CodecOptions() diff --git a/test/test_gridfs_spec.py b/test/test_gridfs_spec.py index e84e19725..8e1a37364 100644 --- a/test/test_gridfs_spec.py +++ b/test/test_gridfs_spec.py @@ -22,18 +22,12 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "gridfs") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "gridfs") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("gridfs"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_index_management.py b/test/test_index_management.py index dea8c0e2b..2d723bb4a 100644 --- a/test/test_index_management.py +++ b/test/test_index_management.py @@ -28,7 +28,7 @@ import pytest sys.path[0:0] = [""] from test import IntegrationTest, PyMongoTestCase, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import AllowListEventListener, OvertCommandListener from pymongo.errors import OperationFailure @@ -40,12 +40,6 @@ _IS_SYNC = True pytestmark = pytest.mark.search_index -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "index_management") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "index_management") - _NAME = "test-search-index" @@ -370,7 +364,7 @@ class TestSearchIndexProse(SearchIndexIntegrationBase): globals().update( generate_test_classes( - _TEST_PATH, + get_test_path("index_management"), module=__name__, ) ) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index 472ef51da..41663d988 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -30,7 +30,7 @@ import pytest sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import ( create_event, wait_until, @@ -40,14 +40,8 @@ _IS_SYNC = True pytestmark = pytest.mark.load_balancer -# Location of JSON test specifications. -if _IS_SYNC: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent, "load_balancer") -else: - _TEST_PATH = os.path.join(pathlib.Path(__file__).resolve().parent.parent, "load_balancer") - # Generate unified tests. -globals().update(generate_test_classes(_TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("load_balancer"), module=__name__)) class TestLB(IntegrationTest): diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 4b816b7af..54946c3ea 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,7 +24,7 @@ from pathlib import Path sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path from test.utils_shared import OvertCommandListener from pymongo import DESCENDING @@ -42,11 +42,7 @@ from pymongo.write_concern import WriteConcern _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "read_write_concern") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "read_write_concern") +TEST_PATH = get_test_path("read_write_concern") class TestReadWriteConcernSpec(IntegrationTest): diff --git a/test/test_retryable_reads_unified.py b/test/test_retryable_reads_unified.py index b1c6435c9..c47d89d04 100644 --- a/test/test_retryable_reads_unified.py +++ b/test/test_retryable_reads_unified.py @@ -22,21 +22,15 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_reads/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_reads/unified") - # Generate unified tests. # PyMongo does not support MapReduce, ListDatabaseObjects or ListCollectionObjects. globals().update( generate_test_classes( - TEST_PATH, + get_test_path("retryable_reads", "unified"), module=__name__, expected_failures=["ListDatabaseObjects .*", "ListCollectionObjects .*", "MapReduce .*"], ) diff --git a/test/test_retryable_writes_unified.py b/test/test_retryable_writes_unified.py index 036c410e2..d06ee206f 100644 --- a/test/test_retryable_writes_unified.py +++ b/test/test_retryable_writes_unified.py @@ -22,18 +22,14 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "retryable_writes/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "retryable_writes/unified") - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("retryable_writes", "unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/test_run_command.py b/test/test_run_command.py index d2ef43b97..df835fb6d 100644 --- a/test/test_run_command.py +++ b/test/test_run_command.py @@ -18,20 +18,13 @@ from __future__ import annotations import os import unittest from pathlib import Path -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "run_command") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "run_command") - - globals().update( generate_test_classes( - os.path.join(TEST_PATH, "unified"), + get_test_path("run_command", "unified"), module=__name__, ) ) diff --git a/test/test_server_selection_logging.py b/test/test_server_selection_logging.py index d53d8dc84..c48e166a1 100644 --- a/test/test_server_selection_logging.py +++ b/test/test_server_selection_logging.py @@ -22,20 +22,14 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "server_selection_logging") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "server_selection_logging") - globals().update( generate_test_classes( - TEST_PATH, + get_test_path("server_selection_logging"), module=__name__, ) ) diff --git a/test/test_sessions_unified.py b/test/test_sessions_unified.py index 3c80c70d3..3d15fac85 100644 --- a/test/test_sessions_unified.py +++ b/test/test_sessions_unified.py @@ -22,19 +22,12 @@ from pathlib import Path sys.path[0:0] = [""] from test import unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "sessions") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "sessions") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("sessions"), module=__name__)) if __name__ == "__main__": unittest.main() diff --git a/test/test_transactions_unified.py b/test/test_transactions_unified.py index 4ab4885e2..05e4a1e5c 100644 --- a/test/test_transactions_unified.py +++ b/test/test_transactions_unified.py @@ -22,7 +22,7 @@ from pathlib import Path sys.path[0:0] = [""] from test import client_context, unittest -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path _IS_SYNC = True @@ -31,25 +31,13 @@ def setUpModule(): pass -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions/unified") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "transactions/unified") +# Generate unified tests. +globals().update(generate_test_classes(get_test_path("transactions/unified"), module=__name__)) # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) - -# Location of JSON test specifications for transactions-convenient-api. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "transactions-convenient-api/unified") -else: - TEST_PATH = os.path.join( - Path(__file__).resolve().parent.parent, "transactions-convenient-api/unified" - ) - -# Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update( + generate_test_classes(get_test_path("transactions-convenient-api/unified"), module=__name__) +) if __name__ == "__main__": unittest.main() diff --git a/test/test_unified_format.py b/test/test_unified_format.py index f1cfd0139..a55f81047 100644 --- a/test/test_unified_format.py +++ b/test/test_unified_format.py @@ -21,18 +21,18 @@ from typing import Any sys.path[0:0] = [""] from test import UnitTest, unittest -from test.unified_format import MatchEvaluatorUtil, generate_test_classes +from test.unified_format import ( + MatchEvaluatorUtil, + generate_test_classes, + get_test_path, +) from bson import ObjectId _IS_SYNC = True # Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "unified-test-format") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "unified-test-format") - +TEST_PATH = get_test_path("unified-test-format") globals().update( generate_test_classes( diff --git a/test/test_versioned_api_integration.py b/test/test_versioned_api_integration.py index 066a1935c..c4ee7856f 100644 --- a/test/test_versioned_api_integration.py +++ b/test/test_versioned_api_integration.py @@ -16,7 +16,7 @@ from __future__ import annotations import os import sys from pathlib import Path -from test.unified_format import generate_test_classes +from test.unified_format import generate_test_classes, get_test_path sys.path[0:0] = [""] @@ -27,15 +27,8 @@ from pymongo.server_api import ServerApi _IS_SYNC = True -# Location of JSON test specifications. -if _IS_SYNC: - TEST_PATH = os.path.join(Path(__file__).resolve().parent, "versioned-api") -else: - TEST_PATH = os.path.join(Path(__file__).resolve().parent.parent, "versioned-api") - - # Generate unified tests. -globals().update(generate_test_classes(TEST_PATH, module=__name__)) +globals().update(generate_test_classes(get_test_path("versioned-api"), module=__name__)) class TestServerApiIntegration(IntegrationTest): diff --git a/test/unified_format.py b/test/unified_format.py index 0c5f68edd..277783a9a 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -29,6 +29,7 @@ import time import traceback from collections import defaultdict from inspect import iscoroutinefunction +from pathlib import Path from test import ( IntegrationTest, client_context, @@ -1549,6 +1550,14 @@ _SCHEMA_VERSION_MAJOR_TO_MIXIN_CLASS = { } +def get_test_path(*args): + if _IS_SYNC: + root_dir = Path(__file__).resolve().parent + else: + root_dir = Path(__file__).resolve().parent.parent + return os.path.join(root_dir, *args) + + def generate_test_classes( test_path, module=__name__, @@ -1581,10 +1590,12 @@ def generate_test_classes( return base + found_any = False for dirpath, _, filenames in os.walk(test_path): dirname = os.path.split(dirpath)[-1] for filename in filenames: + found_any = True fpath = os.path.join(dirpath, filename) with open(fpath) as scenario_stream: # Use tz_aware=False to match how CodecOptions decodes @@ -1622,4 +1633,7 @@ def generate_test_classes( continue raise + if not found_any: + raise ValueError(f"No test files found in {test_path}") + return test_klasses