diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 7c8f7180b..7744a75d9 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -878,6 +878,7 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): self._opened = False self._closed = False + self._loop: Optional[asyncio.AbstractEventLoop] = None if not is_srv: self._init_background() @@ -1709,6 +1710,13 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]): If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use AsyncMongoClient in different event loop. AsyncMongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: if self._resolve_srv_info["is_srv"]: await self._resolve_srv() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 14fdefcb6..1c0adb5d6 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -876,6 +876,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): self._opened = False self._closed = False + self._loop: Optional[asyncio.AbstractEventLoop] = None if not is_srv: self._init_background() @@ -1703,6 +1704,13 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]): If this client was created with "connect=False", calling _get_topology launches the connection process in the background. """ + if not _IS_SYNC: + if self._loop is None: + self._loop = asyncio.get_running_loop() + elif self._loop != asyncio.get_running_loop(): + raise RuntimeError( + "Cannot use MongoClient in different event loop. MongoClient uses low-level asyncio APIs that bind it to the event loop it was created on." + ) if not self._opened: if self._resolve_srv_info["is_srv"]: self._resolve_srv() diff --git a/pyproject.toml b/pyproject.toml index 611cac13a..4da75b4c1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -117,6 +117,8 @@ filterwarnings = [ "module:unclosed bool: """Return True for async tests that should not be converted to sync.""" - return f in ["test_locks.py", "test_concurrency.py", "test_async_cancellation.py"] + return f in [ + "test_locks.py", + "test_concurrency.py", + "test_async_cancellation.py", + "test_async_loop_safety.py", + ] test_files = [