PYTHON-5053 - AsyncMongoClient.close() should await all background tasks (#2127)

This commit is contained in:
Noah Stapp 2025-02-05 15:05:41 -05:00 committed by GitHub
parent f344eb7965
commit 1b818470fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 82 additions and 6 deletions

View File

@ -1565,6 +1565,12 @@ class AsyncMongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
await self._encrypter.close()
self._closed = True
if not _IS_SYNC:
await asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)
if not _IS_SYNC:
# Add support for contextlib.aclosing.

View File

@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
async def join(self, timeout: Optional[int] = None) -> None:
async def join(self) -> None:
"""Wait for the monitor to stop."""
await self._executor.join(timeout)
await self._executor.join()
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -189,6 +189,11 @@ class Monitor(MonitorBase):
self._rtt_monitor.gc_safe_close()
self.cancel_check()
async def join(self) -> None:
await asyncio.gather(
self._executor.join(), self._rtt_monitor.join(), return_exceptions=True
) # type: ignore[func-returns-value]
async def close(self) -> None:
self.gc_safe_close()
await self._rtt_monitor.close()

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import asyncio
import logging
import os
import queue
@ -29,7 +30,7 @@ from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast
from pymongo import _csot, common, helpers_shared, periodic_executor
from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.asynchronous.monitor import SrvMonitor
from pymongo.asynchronous.monitor import MonitorBase, SrvMonitor
from pymongo.asynchronous.pool import Pool
from pymongo.asynchronous.server import Server
from pymongo.errors import (
@ -207,6 +208,9 @@ class Topology:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []
async def open(self) -> None:
"""Start monitoring, or restart after a fork.
@ -241,6 +245,8 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
@ -283,6 +289,10 @@ class Topology:
else:
server_timeout = server_selection_timeout
# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
await self.cleanup_monitors()
async with self._lock:
server_descriptions = await self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
@ -520,6 +530,8 @@ class Topology:
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
@ -695,6 +707,8 @@ class Topology:
old_td = self._description
for server in self._servers.values():
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Mark all servers Unknown.
self._description = self._description.reset()
@ -705,6 +719,8 @@ class Topology:
# Stop SRV polling thread.
if self._srv_monitor:
await self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
self._opened = False
self._closed = True
@ -944,6 +960,8 @@ class Topology:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
await server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)
def _create_pool_for_server(self, address: _Address) -> Pool:
@ -1031,6 +1049,15 @@ class Topology:
else:
return ",".join(str(server.error) for server in servers if server.error)
async def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
await asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
def __repr__(self) -> str:
msg = ""
if not self._opened:

View File

@ -75,6 +75,8 @@ class AsyncPeriodicExecutor:
callback; see monitor.py.
"""
self._stopped = True
if self._task is not None:
self._task.cancel()
async def join(self, timeout: Optional[int] = None) -> None:
if self._task is not None:

View File

@ -1559,6 +1559,12 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
# TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened.
self._encrypter.close()
self._closed = True
if not _IS_SYNC:
asyncio.gather(
self._topology.cleanup_monitors(), # type: ignore[func-returns-value]
self._kill_cursors_executor.join(), # type: ignore[func-returns-value]
return_exceptions=True,
)
if not _IS_SYNC:
# Add support for contextlib.closing.

View File

@ -112,9 +112,9 @@ class MonitorBase:
"""
self.gc_safe_close()
def join(self, timeout: Optional[int] = None) -> None:
def join(self) -> None:
"""Wait for the monitor to stop."""
self._executor.join(timeout)
self._executor.join()
def request_check(self) -> None:
"""If the monitor is sleeping, wake it soon."""
@ -189,6 +189,9 @@ class Monitor(MonitorBase):
self._rtt_monitor.gc_safe_close()
self.cancel_check()
def join(self) -> None:
asyncio.gather(self._executor.join(), self._rtt_monitor.join(), return_exceptions=True) # type: ignore[func-returns-value]
def close(self) -> None:
self.gc_safe_close()
self._rtt_monitor.close()

View File

@ -16,6 +16,7 @@
from __future__ import annotations
import asyncio
import logging
import os
import queue
@ -61,7 +62,7 @@ from pymongo.server_selectors import (
writable_server_selector,
)
from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool
from pymongo.synchronous.monitor import SrvMonitor
from pymongo.synchronous.monitor import MonitorBase, SrvMonitor
from pymongo.synchronous.pool import Pool
from pymongo.synchronous.server import Server
from pymongo.topology_description import (
@ -207,6 +208,9 @@ class Topology:
if self._settings.fqdn is not None and not self._settings.load_balanced:
self._srv_monitor = SrvMonitor(self, self._settings)
# Stores all monitor tasks that need to be joined on close or server selection
self._monitor_tasks: list[MonitorBase] = []
def open(self) -> None:
"""Start monitoring, or restart after a fork.
@ -241,6 +245,8 @@ class Topology:
# Close servers and clear the pools.
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Reset the session pool to avoid duplicate sessions in
# the child process.
self._session_pool.reset()
@ -283,6 +289,10 @@ class Topology:
else:
server_timeout = server_selection_timeout
# Cleanup any completed monitor tasks safely
if not _IS_SYNC and self._monitor_tasks:
self.cleanup_monitors()
with self._lock:
server_descriptions = self._select_servers_loop(
selector, server_timeout, operation, operation_id, address
@ -520,6 +530,8 @@ class Topology:
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
):
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
@ -693,6 +705,8 @@ class Topology:
old_td = self._description
for server in self._servers.values():
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
# Mark all servers Unknown.
self._description = self._description.reset()
@ -703,6 +717,8 @@ class Topology:
# Stop SRV polling thread.
if self._srv_monitor:
self._srv_monitor.close()
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
self._opened = False
self._closed = True
@ -942,6 +958,8 @@ class Topology:
for address, server in list(self._servers.items()):
if not self._description.has_server(address):
server.close()
if not _IS_SYNC:
self._monitor_tasks.append(server._monitor)
self._servers.pop(address)
def _create_pool_for_server(self, address: _Address) -> Pool:
@ -1029,6 +1047,15 @@ class Topology:
else:
return ",".join(str(server.error) for server in servers if server.error)
def cleanup_monitors(self) -> None:
tasks = []
try:
while self._monitor_tasks:
tasks.append(self._monitor_tasks.pop())
except IndexError:
pass
asyncio.gather(*[t.join() for t in tasks], return_exceptions=True) # type: ignore[func-returns-value]
def __repr__(self) -> str:
msg = ""
if not self._opened: