From 01f659cd8bd6ae970b044c3043c9ed2ca6d89bf4 Mon Sep 17 00:00:00 2001 From: Shane Harvey Date: Thu, 30 Jan 2025 12:34:59 -0800 Subject: [PATCH] PYTHON-5071 Use one event loop for all asyncio tests (#2086) --- test/__init__.py | 121 +++++++++++++++++++++++--------- test/asynchronous/__init__.py | 125 ++++++++++++++++++++++++---------- 2 files changed, 176 insertions(+), 70 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index d3a63db2d..b49eee99a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,6 +31,33 @@ import traceback import unittest import warnings from asyncio import iscoroutinefunction + +from pymongo.uri_parser import parse_uri + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from contextlib import contextmanager +from functools import partial, wraps +from typing import Any, Callable, Dict, Generator, overload +from unittest import SkipTest +from urllib.parse import quote_plus + +import pymongo +import pymongo.errors +from bson.son import SON +from pymongo.common import partition_node +from pymongo.hello import HelloCompat +from pymongo.server_api import ServerApi +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] +from pymongo.synchronous.database import Database +from pymongo.synchronous.mongo_client import MongoClient + +sys.path[0:0] = [""] + from test.helpers import ( COMPRESSORS, IS_SRV, @@ -52,31 +80,7 @@ from test.helpers import ( sanitize_cmd, sanitize_reply, ) - -from pymongo.uri_parser import parse_uri - -try: - import ipaddress - - HAVE_IPADDRESS = True -except ImportError: - HAVE_IPADDRESS = False -from contextlib import contextmanager -from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, overload -from unittest import SkipTest -from urllib.parse import quote_plus - -import pymongo -import pymongo.errors -from bson.son import SON -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous.database import Database -from pymongo.synchronous.mongo_client import MongoClient _IS_SYNC = True @@ -863,18 +867,66 @@ class ClientContext: # Reusable client context client_context = ClientContext() +# Global event loop for async tests. +LOOP = None -def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif client_context.client is not None: - client_context.client.close() - client_context.client = None - client_context._init_client() + +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP class PyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by TestCase. + def setUp(self): + pass + + def tearDown(self): + pass + + def addCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.setUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.tearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase): @client_context.require_connection def setUp(self) -> None: - if not _IS_SYNC: - reset_client_context() if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest): def setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() client_context.init() warnings.resetwarnings() warnings.simplefilter("always") diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 73e282474..76fae407d 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -17,6 +17,7 @@ from __future__ import annotations import asyncio import gc +import inspect import logging import multiprocessing import os @@ -30,6 +31,33 @@ import traceback import unittest import warnings from asyncio import iscoroutinefunction + +from pymongo.uri_parser import parse_uri + +try: + import ipaddress + + HAVE_IPADDRESS = True +except ImportError: + HAVE_IPADDRESS = False +from contextlib import asynccontextmanager, contextmanager +from functools import partial, wraps +from typing import Any, Callable, Dict, Generator, overload +from unittest import SkipTest +from urllib.parse import quote_plus + +import pymongo +import pymongo.errors +from bson.son import SON +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.common import partition_node +from pymongo.hello import HelloCompat +from pymongo.server_api import ServerApi +from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] + +sys.path[0:0] = [""] + from test.helpers import ( COMPRESSORS, IS_SRV, @@ -52,31 +80,7 @@ from test.helpers import ( sanitize_cmd, sanitize_reply, ) - -from pymongo.uri_parser import parse_uri - -try: - import ipaddress - - HAVE_IPADDRESS = True -except ImportError: - HAVE_IPADDRESS = False -from contextlib import asynccontextmanager, contextmanager -from functools import partial, wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, overload -from unittest import SkipTest -from urllib.parse import quote_plus - -import pymongo -import pymongo.errors -from bson.son import SON -from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.common import partition_node -from pymongo.hello import HelloCompat -from pymongo.server_api import ServerApi -from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] _IS_SYNC = False @@ -865,18 +869,66 @@ class AsyncClientContext: # Reusable client context async_client_context = AsyncClientContext() - -async def reset_client_context(): - if _IS_SYNC: - # sync tests don't need to reset a client context - return - elif async_client_context.client is not None: - await async_client_context.client.close() - async_client_context.client = None - await async_client_context._init_client() +# Global event loop for async tests. +LOOP = None -class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase): +def get_loop() -> asyncio.AbstractEventLoop: + """Get the test suite's global event loop.""" + global LOOP + if LOOP is None: + try: + LOOP = asyncio.get_running_loop() + except RuntimeError: + # no running event loop, fallback to get_event_loop. + try: + # Ignore DeprecationWarning: There is no current event loop + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + LOOP = asyncio.get_event_loop() + except RuntimeError: + LOOP = asyncio.new_event_loop() + asyncio.set_event_loop(LOOP) + return LOOP + + +class AsyncPyMongoTestCase(unittest.TestCase): + if not _IS_SYNC: + # An async TestCase that uses a single event loop for all tests. + # Inspired by IsolatedAsyncioTestCase. + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass + + def addAsyncCleanup(self, func, /, *args, **kwargs): + self.addCleanup(*(func, *args), **kwargs) + + def _callSetUp(self): + self.setUp() + self._callAsync(self.asyncSetUp) + + def _callTestMethod(self, method): + self._callMaybeAsync(method) + + def _callTearDown(self): + self._callAsync(self.asyncTearDown) + self.tearDown() + + def _callCleanup(self, function, *args, **kwargs): + self._callMaybeAsync(function, *args, **kwargs) + + def _callAsync(self, func, /, *args, **kwargs): + assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function" + return get_loop().run_until_complete(func(*args, **kwargs)) + + def _callMaybeAsync(self, func, /, *args, **kwargs): + if inspect.iscoroutinefunction(func): + return get_loop().run_until_complete(func(*args, **kwargs)) + else: + return func(*args, **kwargs) + def assertEqualCommand(self, expected, actual, msg=None): self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg) @@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase): @async_client_context.require_connection async def asyncSetUp(self) -> None: - if not _IS_SYNC: - await reset_client_context() if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False): raise SkipTest("this test does not support load balancers") if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False): @@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest): async def async_setup(): + if not _IS_SYNC: + # Set up the event loop. + get_loop() await async_client_context.init() warnings.resetwarnings() warnings.simplefilter("always")