PYTHON-5071 Use one event loop for all asyncio tests (#2086)
This commit is contained in:
parent
34ae214e33
commit
01f659cd8b
121
test/__init__.py
121
test/__init__.py
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -30,6 +31,33 @@ import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.helpers import (
|
||||
COMPRESSORS,
|
||||
IS_SRV,
|
||||
@ -52,31 +80,7 @@ from test.helpers import (
|
||||
sanitize_cmd,
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import contextmanager
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
from pymongo.synchronous.database import Database
|
||||
from pymongo.synchronous.mongo_client import MongoClient
|
||||
|
||||
_IS_SYNC = True
|
||||
|
||||
@ -863,18 +867,66 @@ class ClientContext:
|
||||
# Reusable client context
|
||||
client_context = ClientContext()
|
||||
|
||||
# Global event loop for async tests.
|
||||
LOOP = None
|
||||
|
||||
def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif client_context.client is not None:
|
||||
client_context.client.close()
|
||||
client_context.client = None
|
||||
client_context._init_client()
|
||||
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get the test suite's global event loop."""
|
||||
global LOOP
|
||||
if LOOP is None:
|
||||
try:
|
||||
LOOP = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# no running event loop, fallback to get_event_loop.
|
||||
try:
|
||||
# Ignore DeprecationWarning: There is no current event loop
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
LOOP = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
LOOP = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(LOOP)
|
||||
return LOOP
|
||||
|
||||
|
||||
class PyMongoTestCase(unittest.TestCase):
|
||||
if not _IS_SYNC:
|
||||
# An async TestCase that uses a single event loop for all tests.
|
||||
# Inspired by TestCase.
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
def tearDown(self):
|
||||
pass
|
||||
|
||||
def addCleanup(self, func, /, *args, **kwargs):
|
||||
self.addCleanup(*(func, *args), **kwargs)
|
||||
|
||||
def _callSetUp(self):
|
||||
self.setUp()
|
||||
self._callAsync(self.setUp)
|
||||
|
||||
def _callTestMethod(self, method):
|
||||
self._callMaybeAsync(method)
|
||||
|
||||
def _callTearDown(self):
|
||||
self._callAsync(self.tearDown)
|
||||
self.tearDown()
|
||||
|
||||
def _callCleanup(self, function, *args, **kwargs):
|
||||
self._callMaybeAsync(function, *args, **kwargs)
|
||||
|
||||
def _callAsync(self, func, /, *args, **kwargs):
|
||||
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
|
||||
def _callMaybeAsync(self, func, /, *args, **kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
|
||||
@ -1136,8 +1188,6 @@ class IntegrationTest(PyMongoTestCase):
|
||||
|
||||
@client_context.require_connection
|
||||
def setUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
reset_client_context()
|
||||
if client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
@ -1186,6 +1236,9 @@ class MockClientTest(UnitTest):
|
||||
|
||||
|
||||
def setup():
|
||||
if not _IS_SYNC:
|
||||
# Set up the event loop.
|
||||
get_loop()
|
||||
client_context.init()
|
||||
warnings.resetwarnings()
|
||||
warnings.simplefilter("always")
|
||||
|
||||
@ -17,6 +17,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import multiprocessing
|
||||
import os
|
||||
@ -30,6 +31,33 @@ import traceback
|
||||
import unittest
|
||||
import warnings
|
||||
from asyncio import iscoroutinefunction
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from functools import partial, wraps
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
|
||||
sys.path[0:0] = [""]
|
||||
|
||||
from test.helpers import (
|
||||
COMPRESSORS,
|
||||
IS_SRV,
|
||||
@ -52,31 +80,7 @@ from test.helpers import (
|
||||
sanitize_cmd,
|
||||
sanitize_reply,
|
||||
)
|
||||
|
||||
from pymongo.uri_parser import parse_uri
|
||||
|
||||
try:
|
||||
import ipaddress
|
||||
|
||||
HAVE_IPADDRESS = True
|
||||
except ImportError:
|
||||
HAVE_IPADDRESS = False
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from functools import partial, wraps
|
||||
from test.version import Version
|
||||
from typing import Any, Callable, Dict, Generator, overload
|
||||
from unittest import SkipTest
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import pymongo
|
||||
import pymongo.errors
|
||||
from bson.son import SON
|
||||
from pymongo.asynchronous.database import AsyncDatabase
|
||||
from pymongo.asynchronous.mongo_client import AsyncMongoClient
|
||||
from pymongo.common import partition_node
|
||||
from pymongo.hello import HelloCompat
|
||||
from pymongo.server_api import ServerApi
|
||||
from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined]
|
||||
|
||||
_IS_SYNC = False
|
||||
|
||||
@ -865,18 +869,66 @@ class AsyncClientContext:
|
||||
# Reusable client context
|
||||
async_client_context = AsyncClientContext()
|
||||
|
||||
|
||||
async def reset_client_context():
|
||||
if _IS_SYNC:
|
||||
# sync tests don't need to reset a client context
|
||||
return
|
||||
elif async_client_context.client is not None:
|
||||
await async_client_context.client.close()
|
||||
async_client_context.client = None
|
||||
await async_client_context._init_client()
|
||||
# Global event loop for async tests.
|
||||
LOOP = None
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.IsolatedAsyncioTestCase):
|
||||
def get_loop() -> asyncio.AbstractEventLoop:
|
||||
"""Get the test suite's global event loop."""
|
||||
global LOOP
|
||||
if LOOP is None:
|
||||
try:
|
||||
LOOP = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
# no running event loop, fallback to get_event_loop.
|
||||
try:
|
||||
# Ignore DeprecationWarning: There is no current event loop
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", DeprecationWarning)
|
||||
LOOP = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
LOOP = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(LOOP)
|
||||
return LOOP
|
||||
|
||||
|
||||
class AsyncPyMongoTestCase(unittest.TestCase):
|
||||
if not _IS_SYNC:
|
||||
# An async TestCase that uses a single event loop for all tests.
|
||||
# Inspired by IsolatedAsyncioTestCase.
|
||||
async def asyncSetUp(self):
|
||||
pass
|
||||
|
||||
async def asyncTearDown(self):
|
||||
pass
|
||||
|
||||
def addAsyncCleanup(self, func, /, *args, **kwargs):
|
||||
self.addCleanup(*(func, *args), **kwargs)
|
||||
|
||||
def _callSetUp(self):
|
||||
self.setUp()
|
||||
self._callAsync(self.asyncSetUp)
|
||||
|
||||
def _callTestMethod(self, method):
|
||||
self._callMaybeAsync(method)
|
||||
|
||||
def _callTearDown(self):
|
||||
self._callAsync(self.asyncTearDown)
|
||||
self.tearDown()
|
||||
|
||||
def _callCleanup(self, function, *args, **kwargs):
|
||||
self._callMaybeAsync(function, *args, **kwargs)
|
||||
|
||||
def _callAsync(self, func, /, *args, **kwargs):
|
||||
assert inspect.iscoroutinefunction(func), f"{func!r} is not an async function"
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
|
||||
def _callMaybeAsync(self, func, /, *args, **kwargs):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return get_loop().run_until_complete(func(*args, **kwargs))
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def assertEqualCommand(self, expected, actual, msg=None):
|
||||
self.assertEqual(sanitize_cmd(expected), sanitize_cmd(actual), msg)
|
||||
|
||||
@ -1154,8 +1206,6 @@ class AsyncIntegrationTest(AsyncPyMongoTestCase):
|
||||
|
||||
@async_client_context.require_connection
|
||||
async def asyncSetUp(self) -> None:
|
||||
if not _IS_SYNC:
|
||||
await reset_client_context()
|
||||
if async_client_context.load_balancer and not getattr(self, "RUN_ON_LOAD_BALANCER", False):
|
||||
raise SkipTest("this test does not support load balancers")
|
||||
if async_client_context.serverless and not getattr(self, "RUN_ON_SERVERLESS", False):
|
||||
@ -1204,6 +1254,9 @@ class AsyncMockClientTest(AsyncUnitTest):
|
||||
|
||||
|
||||
async def async_setup():
|
||||
if not _IS_SYNC:
|
||||
# Set up the event loop.
|
||||
get_loop()
|
||||
await async_client_context.init()
|
||||
warnings.resetwarnings()
|
||||
warnings.simplefilter("always")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user