PYTHON-5071 Use one event loop for all asyncio tests (#2086)

This commit is contained in:
Shane Harvey 2025-01-30 12:34:59 -08:00 committed by GitHub
parent 34ae214e33
commit 01f659cd8b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 176 additions and 70 deletions

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
@ -30,6 +31,33 @@ import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
sys.path[0:0] = [""]
from test.helpers import (
COMPRESSORS,
IS_SRV,
@ -52,31 +80,7 @@ from test.helpers import (
sanitize_cmd,
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
_IS_SYNC = True
@ -863,18 +867,66 @@ class ClientContext:
# Reusable client context
client_context = ClientContext()
# Global event loop for async tests.
LOOP = None
def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif client_context.client is not None:
client_context.client.close()
client_context.client = None
client_context._init_client()
def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP
class PyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by TestCase.
def setUp(self):
pass
def tearDown(self):
pass
def addCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)
def _callSetUp(self):
self.setUp()
self._callAsync(self.setUp)
def _callTestMethod(self, method):
self._callMaybeAsync(method)
def _callTearDown(self):
self._callAsync(self.tearDown)
self.tearDown()
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))
def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
@client_context.require_connection
def setUp(self) -> None:
if not _IS_SYNC:
reset_client_context()
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest):
def setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")

View File

@ -17,6 +17,7 @@ from __future__ import annotations
import asyncio
import gc
import inspect
import logging
import multiprocessing
import os
@ -30,6 +31,33 @@ import traceback
import unittest
import warnings
from asyncio import iscoroutinefunction
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
sys.path[0:0] = [""]
from test.helpers import (
COMPRESSORS,
IS_SRV,
@ -52,31 +80,7 @@ from test.helpers import (
sanitize_cmd,
sanitize_reply,
)
from pymongo.uri_parser import parse_uri
try:
import ipaddress
HAVE_IPADDRESS = True
except ImportError:
HAVE_IPADDRESS = False
from contextlib import asynccontextmanager, contextmanager
from functools import partial, wraps
from test.version import Version
from typing import Any, Callable, Dict, Generator, overload
from unittest import SkipTest
from urllib.parse import quote_plus
import pymongo
import pymongo.errors
from bson.son import SON
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.common import partition_node
from pymongo.hello import HelloCompat
from pymongo.server_api import ServerApi
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
_IS_SYNC = False
@ -865,18 +869,66 @@ class AsyncClientContext:
# Reusable client context
async_client_context = AsyncClientContext()
async def reset_client_context():
if _IS_SYNC:
# sync tests don't need to reset a client context
return
elif async_client_context.client is not None:
await async_client_context.client.close()
async_client_context.client = None
await async_client_context._init_client()
# Global event loop for async tests.
LOOP = None
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
def get_loop() -> asyncio.AbstractEventLoop:
"""Get the test suite's global event loop."""
global LOOP
if LOOP is None:
try:
LOOP = asyncio.get_running_loop()
except RuntimeError:
# no running event loop, fallback to get_event_loop.
try:
# Ignore DeprecationWarning: There is no current event loop
with warnings.catch_warnings():
warnings.simplefilter("ignore", DeprecationWarning)
LOOP = asyncio.get_event_loop()
except RuntimeError:
LOOP = asyncio.new_event_loop()
asyncio.set_event_loop(LOOP)
return LOOP
class AsyncPyMongoTestCase(unittest.TestCase):
if not _IS_SYNC:
# An async TestCase that uses a single event loop for all tests.
# Inspired by IsolatedAsyncioTestCase.
async def asyncSetUp(self):
pass
async def asyncTearDown(self):
pass
def addAsyncCleanup(self, func, /, *args, **kwargs):
self.addCleanup(*(func, *args), **kwargs)
def _callSetUp(self):
self.setUp()
self._callAsync(self.asyncSetUp)
def _callTestMethod(self, method):
self._callMaybeAsync(method)
def _callTearDown(self):
self._callAsync(self.asyncTearDown)
self.tearDown()
def _callCleanup(self, function, *args, **kwargs):
self._callMaybeAsync(function, *args, **kwargs)
def _callAsync(self, func, /, *args, **kwargs):
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
return get_loop().run_until_complete(func(*args, **kwargs))
def _callMaybeAsync(self, func, /, *args, **kwargs):
if inspect.iscoroutinefunction(func):
return get_loop().run_until_complete(func(*args, **kwargs))
else:
return func(*args, **kwargs)
def assertEqualCommand(self, expected, actual, msg=None):
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
@async_client_context.require_connection
async def asyncSetUp(self) -> None:
if not _IS_SYNC:
await reset_client_context()
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
raise SkipTest("this test does not support load balancers")
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest):
async def async_setup():
if not _IS_SYNC:
# Set up the event loop.
get_loop()
await async_client_context.init()
warnings.resetwarnings()
warnings.simplefilter("always")