PYTHON-5053 - AsyncMongoClient.close() should await all background tasks (#2127)
This commit is contained in:
parent
f344eb7965
commit
1b818470fc
@ -1565,6 +1565,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
await self._encrypter.close()
|
||||
self._closed = True
|
||||
if not _IS_SYNC:
|
||||
await asyncio.gather(
|
||||
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
|
||||
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.aclosing.
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
async def join(self) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
await self._executor.join(timeout)
|
||||
await self._executor.join()
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -189,6 +189,11 @@ class Monitor(MonitorBase):
|
||||
self._rtt_monitor.gc_safe_close()
|
||||
self.cancel_check()
|
||||
|
||||
async def join(self) -> None:
|
||||
await asyncio.gather(
|
||||
self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
|
||||
) # type: ignore[func-returns-value]
|
||||
|
||||
async def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
await self._rtt_monitor.close()
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
@ -29,7 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
|
||||
|
||||
from pymongo import _csot, common, helpers_shared, periodic_executor
|
||||
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.asynchronous.monitor import SrvMonitor
|
||||
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
|
||||
from pymongo.asynchronous.pool import Pool
|
||||
from pymongo.asynchronous.server import Server
|
||||
from pymongo.errors import (
|
||||
@ -207,6 +208,9 @@ class Topology:
|
||||
if self._settings.fqdn is not None and not self._settings.load_balanced:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
# Stores all monitor tasks that need to be joined on close or server selection
|
||||
self._monitor_tasks: list[MonitorBase] = []
|
||||
|
||||
async def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
@ -241,6 +245,8 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
@ -283,6 +289,10 @@ class Topology:
|
||||
else:
|
||||
server_timeout = server_selection_timeout
|
||||
|
||||
# Cleanup any completed monitor tasks safely
|
||||
if not _IS_SYNC and self._monitor_tasks:
|
||||
await self.cleanup_monitors()
|
||||
|
||||
async with self._lock:
|
||||
server_descriptions = await self._select_servers_loop(
|
||||
selector, server_timeout, operation, operation_id, address
|
||||
@ -520,6 +530,8 @@ class Topology:
|
||||
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
|
||||
):
|
||||
await self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
@ -695,6 +707,8 @@ class Topology:
|
||||
old_td = self._description
|
||||
for server in self._servers.values():
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
@ -705,6 +719,8 @@ class Topology:
|
||||
# Stop SRV polling thread.
|
||||
if self._srv_monitor:
|
||||
await self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
self._opened = False
|
||||
self._closed = True
|
||||
@ -944,6 +960,8 @@ class Topology:
|
||||
for address, server in list(self._servers.items()):
|
||||
if not self._description.has_server(address):
|
||||
await server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
self._servers.pop(address)
|
||||
|
||||
def _create_pool_for_server(self, address: _Address) -> Pool:
|
||||
@ -1031,6 +1049,15 @@ class Topology:
|
||||
else:
|
||||
return ",".join(str(server.error) for server in servers if server.error)
|
||||
|
||||
async def cleanup_monitors(self) -> None:
|
||||
tasks = []
|
||||
try:
|
||||
while self._monitor_tasks:
|
||||
tasks.append(self._monitor_tasks.pop())
|
||||
except IndexError:
|
||||
pass
|
||||
await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ""
|
||||
if not self._opened:
|
||||
|
||||
@ -75,6 +75,8 @@ class AsyncPeriodicExecutor:
|
||||
callback; see monitor.py.
|
||||
"""
|
||||
self._stopped = True
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
|
||||
async def join(self, timeout: Optional[int] = None) -> None:
|
||||
if self._task is not None:
|
||||
|
||||
@ -1559,6 +1559,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
|
||||
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
|
||||
self._encrypter.close()
|
||||
self._closed = True
|
||||
if not _IS_SYNC:
|
||||
asyncio.gather(
|
||||
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
|
||||
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
if not _IS_SYNC:
|
||||
# Add support for contextlib.closing.
|
||||
|
||||
@ -112,9 +112,9 @@ class MonitorBase:
|
||||
"""
|
||||
self.gc_safe_close()
|
||||
|
||||
def join(self, timeout: Optional[int] = None) -> None:
|
||||
def join(self) -> None:
|
||||
"""Wait for the monitor to stop."""
|
||||
self._executor.join(timeout)
|
||||
self._executor.join()
|
||||
|
||||
def request_check(self) -> None:
|
||||
"""If the monitor is sleeping, wake it soon."""
|
||||
@ -189,6 +189,9 @@ class Monitor(MonitorBase):
|
||||
self._rtt_monitor.gc_safe_close()
|
||||
self.cancel_check()
|
||||
|
||||
def join(self) -> None:
|
||||
asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def close(self) -> None:
|
||||
self.gc_safe_close()
|
||||
self._rtt_monitor.close()
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
@ -61,7 +62,7 @@ from pymongo.server_selectors import (
|
||||
writable_server_selector,
|
||||
)
|
||||
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
|
||||
from pymongo.synchronous.monitor import SrvMonitor
|
||||
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
|
||||
from pymongo.synchronous.pool import Pool
|
||||
from pymongo.synchronous.server import Server
|
||||
from pymongo.topology_description import (
|
||||
@ -207,6 +208,9 @@ class Topology:
|
||||
if self._settings.fqdn is not None and not self._settings.load_balanced:
|
||||
self._srv_monitor = SrvMonitor(self, self._settings)
|
||||
|
||||
# Stores all monitor tasks that need to be joined on close or server selection
|
||||
self._monitor_tasks: list[MonitorBase] = []
|
||||
|
||||
def open(self) -> None:
|
||||
"""Start monitoring, or restart after a fork.
|
||||
|
||||
@ -241,6 +245,8 @@ class Topology:
|
||||
# Close servers and clear the pools.
|
||||
for server in self._servers.values():
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
# Reset the session pool to avoid duplicate sessions in
|
||||
# the child process.
|
||||
self._session_pool.reset()
|
||||
@ -283,6 +289,10 @@ class Topology:
|
||||
else:
|
||||
server_timeout = server_selection_timeout
|
||||
|
||||
# Cleanup any completed monitor tasks safely
|
||||
if not _IS_SYNC and self._monitor_tasks:
|
||||
self.cleanup_monitors()
|
||||
|
||||
with self._lock:
|
||||
server_descriptions = self._select_servers_loop(
|
||||
selector, server_timeout, operation, operation_id, address
|
||||
@ -520,6 +530,8 @@ class Topology:
|
||||
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
|
||||
):
|
||||
self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
# Clear the pool from a failed heartbeat.
|
||||
if reset_pool:
|
||||
@ -693,6 +705,8 @@ class Topology:
|
||||
old_td = self._description
|
||||
for server in self._servers.values():
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
|
||||
# Mark all servers Unknown.
|
||||
self._description = self._description.reset()
|
||||
@ -703,6 +717,8 @@ class Topology:
|
||||
# Stop SRV polling thread.
|
||||
if self._srv_monitor:
|
||||
self._srv_monitor.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(self._srv_monitor)
|
||||
|
||||
self._opened = False
|
||||
self._closed = True
|
||||
@ -942,6 +958,8 @@ class Topology:
|
||||
for address, server in list(self._servers.items()):
|
||||
if not self._description.has_server(address):
|
||||
server.close()
|
||||
if not _IS_SYNC:
|
||||
self._monitor_tasks.append(server._monitor)
|
||||
self._servers.pop(address)
|
||||
|
||||
def _create_pool_for_server(self, address: _Address) -> Pool:
|
||||
@ -1029,6 +1047,15 @@ class Topology:
|
||||
else:
|
||||
return ",".join(str(server.error) for server in servers if server.error)
|
||||
|
||||
def cleanup_monitors(self) -> None:
|
||||
tasks = []
|
||||
try:
|
||||
while self._monitor_tasks:
|
||||
tasks.append(self._monitor_tasks.pop())
|
||||
except IndexError:
|
||||
pass
|
||||
asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
msg = ""
|
||||
if not self._opened:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user