PYTHON-5212 - Do not hold Topology lock while resetting pool (#2301)

This commit is contained in:
Noah Stapp 2025-04-23 15:13:38 -04:00 committed by GitHub
parent e2e673edeb
commit 09897b698e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 230 additions and 25 deletions

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import logging
@ -860,8 +861,14 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -891,8 +898,14 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.STALE)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
await conn.close_conn(ConnectionClosedReason.STALE)
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -938,8 +951,14 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
close_conns.append(self.conns.pop())
for conn in close_conns:
await conn.close_conn(ConnectionClosedReason.IDLE)
if not _IS_SYNC:
await asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
return_exceptions=True,
)
else:
for conn in close_conns:
await conn.close_conn(ConnectionClosedReason.IDLE)
while True:
async with self.size_cond:

View File

@ -529,12 +529,6 @@ class Topology:
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)
# Wake anything waiting in select_servers().
self._condition.notify_all()
@ -557,6 +551,11 @@ class Topology:
# that didn't include this server.
if self._opened and self._description.has_server(server_description.address):
await self._process_change(server_description, reset_pool, interrupt_connections)
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
await server.pool.reset(interrupt_connections=interrupt_connections)
async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.

View File

@ -14,6 +14,7 @@
from __future__ import annotations
import asyncio
import collections
import contextlib
import logging
@ -858,8 +859,14 @@ class Pool:
# PoolClosedEvent but that reset() SHOULD close sockets *after*
# publishing the PoolClearedEvent.
if close:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
if self.enabled_for_cmap:
assert listeners is not None
listeners.publish_pool_closed(self.address)
@ -889,8 +896,14 @@ class Pool:
serverPort=self.address[1],
serviceId=service_id,
)
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
return_exceptions=True,
)
else:
for conn in sockets:
conn.close_conn(ConnectionClosedReason.STALE)
def update_is_writable(self, is_writable: Optional[bool]) -> None:
"""Updates the is_writable attribute on all sockets currently in the
@ -934,8 +947,14 @@ class Pool:
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
):
close_conns.append(self.conns.pop())
for conn in close_conns:
conn.close_conn(ConnectionClosedReason.IDLE)
if not _IS_SYNC:
asyncio.gather(
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
return_exceptions=True,
)
else:
for conn in close_conns:
conn.close_conn(ConnectionClosedReason.IDLE)
while True:
with self.size_cond:

View File

@ -529,12 +529,6 @@ class Topology:
if not _IS_SYNC:
self._monitor_tasks.append(self._srv_monitor)
# Clear the pool from a failed heartbeat.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
# Wake anything waiting in select_servers().
self._condition.notify_all()
@ -557,6 +551,11 @@ class Topology:
# that didn't include this server.
if self._opened and self._description.has_server(server_description.address):
self._process_change(server_description, reset_pool, interrupt_connections)
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
if reset_pool:
server = self._servers.get(server_description.address)
if server:
server.pool.reset(interrupt_connections=interrupt_connections)
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
"""Process a new seedlist on an opened topology.

View File

@ -826,6 +826,14 @@ class ClientContext:
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def require_async(self, func):
"""Run a test only if using the asynchronous API.""" # unasync: off
# Delegates to self._require with a predicate that is true only when
# _IS_SYNC is false, plus the message used when the predicate fails.
# NOTE(review): the trailing marker comments on the docstring and message
# lines appear to be consumed by the sync-code generator (process_ignores
# restores the original wording on marked lines) - keep them on those lines.
return self._require(
lambda: not _IS_SYNC,
"This test only works with the asynchronous API", # unasync: off
func=func,
)
def mongos_seeds(self):
    """Return the known mongos addresses as one comma-separated seed string.

    Each entry of ``self.mongoses`` is rendered with ``"{}:{}".format``,
    so its first two items become ``first:second``.
    """
    seeds = ["{}:{}".format(*address) for address in self.mongoses]
    return ",".join(seeds)

View File

@ -828,6 +828,14 @@ class AsyncClientContext:
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
)
def require_async(self, func):
"""Run a test only if using the asynchronous API.""" # unasync: off
# Delegates to self._require with a predicate that is true only when
# _IS_SYNC is false, plus the message used when the predicate fails.
# NOTE(review): the trailing marker comments on the docstring and message
# lines appear to be consumed by the sync-code generator (process_ignores
# restores the original wording on marked lines) - keep them on those lines.
return self._require(
lambda: not _IS_SYNC,
"This test only works with the asynchronous API", # unasync: off
func=func,
)
def mongos_seeds(self):
    """Return the known mongos addresses as one comma-separated seed string.

    Each entry of ``self.mongoses`` is rendered with ``"{}:{}".format``,
    so its first two items become ``first:second``.
    """
    seeds = ["{}:{}".format(*address) for address in self.mongoses]
    return ",".join(seeds)

View File

@ -20,10 +20,15 @@ import os
import socketserver
import sys
import threading
import time
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.asynchronous.helpers import ConcurrentRunner
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
sys.path[0:0] = [""]
from test.asynchronous import (
@ -370,6 +375,74 @@ class TestPoolManagement(AsyncIntegrationTest):
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
@async_client_context.require_failCommand_appName
@async_client_context.require_test_commands
@async_client_context.require_async
async def test_connection_close_does_not_block_other_operations(self):
"""Check that slow connection closes do not stall concurrent operations.

AsyncConnection.close_conn is patched to sleep for ``close_delay`` before
closing, hello/legacy-hello commands are made to fail through a fail point,
and a background task keeps timing ``find_one`` calls. The test passes when
the slowest recorded ``find_one`` took at most ``close_delay * 5.0`` seconds.
"""
listener = CMAPHeartbeatListener()
client = await self.async_single_client(
appName="SDAMConnectionCloseTest",
event_listeners=[listener],
heartbeatFrequencyMS=500,
minPoolSize=10,
)
server = await (await client._get_topology()).select_server(
writable_server_selector, _Op.TEST
)
await async_wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
await client.db.test.insert_one({"x": 1})
# Seconds each patched close_conn call sleeps before really closing.
close_delay = 0.1
# Per-operation find_one durations recorded by run_task.
latencies = []
# Becomes non-empty to tell run_task to stop after its current find_one.
should_exit = []
async def run_task():
while True:
start_time = time.monotonic()
await client.db.test.find_one({})
elapsed = time.monotonic() - start_time
latencies.append(elapsed)
if should_exit:
break
await asyncio.sleep(0.001)
task = ConcurrentRunner(target=run_task)
await task.start()
# Keep the real method so it can be restored in the finally clause.
original_close = AsyncConnection.close_conn
try:
# Artificially delay the close operation to simulate a slow close
async def mock_close(self, reason):
await asyncio.sleep(close_delay)
await original_close(self, reason)
AsyncConnection.close_conn = mock_close
fail_hello = {
"mode": {"times": 4},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 91,
"appName": "SDAMConnectionCloseTest",
},
}
async with self.fail_point(fail_hello):
# Wait for server heartbeat to fail
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
# Wait until all idle connections are closed to simulate real-world conditions
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
# Wait for one more find to complete after the pool has been reset, then shutdown the task
n = len(latencies)
await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find")
should_exit.append(True)
await task.join()
# No operation's latency should significantly exceed close_delay
self.assertLessEqual(max(latencies), close_delay * 5.0)
finally:
# Undo the class-level patch whether or not the test passed.
AsyncConnection.close_conn = original_close
class TestServerMonitoringMode(AsyncIntegrationTest):
@async_client_context.require_no_serverless

View File

@ -20,10 +20,15 @@ import os
import socketserver
import sys
import threading
import time
from asyncio import StreamReader, StreamWriter
from pathlib import Path
from test.helpers import ConcurrentRunner
from pymongo.operations import _Op
from pymongo.server_selectors import writable_server_selector
from pymongo.synchronous.pool import Connection
sys.path[0:0] = [""]
from test import (
@ -370,6 +375,72 @@ class TestPoolManagement(IntegrationTest):
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
@client_context.require_failCommand_appName
@client_context.require_test_commands
@client_context.require_async
def test_connection_close_does_not_block_other_operations(self):
"""Check that slow connection closes do not stall concurrent operations.

Connection.close_conn is patched to sleep for ``close_delay`` before
closing, hello/legacy-hello commands are made to fail through a fail point,
and a background runner keeps timing ``find_one`` calls. The test passes when
the slowest recorded ``find_one`` took at most ``close_delay * 5.0`` seconds.
"""
listener = CMAPHeartbeatListener()
client = self.single_client(
appName="SDAMConnectionCloseTest",
event_listeners=[listener],
heartbeatFrequencyMS=500,
minPoolSize=10,
)
server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST)
wait_until(
lambda: len(server._pool.conns) == 10,
"pool initialized with 10 connections",
)
client.db.test.insert_one({"x": 1})
# Seconds each patched close_conn call sleeps before really closing.
close_delay = 0.1
# Per-operation find_one durations recorded by run_task.
latencies = []
# Becomes non-empty to tell run_task to stop after its current find_one.
should_exit = []
def run_task():
while True:
start_time = time.monotonic()
client.db.test.find_one({})
elapsed = time.monotonic() - start_time
latencies.append(elapsed)
if should_exit:
break
time.sleep(0.001)
task = ConcurrentRunner(target=run_task)
task.start()
# Keep the real method so it can be restored in the finally clause.
original_close = Connection.close_conn
try:
# Artificially delay the close operation to simulate a slow close
def mock_close(self, reason):
time.sleep(close_delay)
original_close(self, reason)
Connection.close_conn = mock_close
fail_hello = {
"mode": {"times": 4},
"data": {
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
"errorCode": 91,
"appName": "SDAMConnectionCloseTest",
},
}
with self.fail_point(fail_hello):
# Wait for server heartbeat to fail
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
# Wait until all idle connections are closed to simulate real-world conditions
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
# Wait for one more find to complete after the pool has been reset, then shutdown the task
n = len(latencies)
wait_until(lambda: len(latencies) >= n + 1, "run one more find")
should_exit.append(True)
task.join()
# No operation's latency should significantly exceed close_delay
self.assertLessEqual(max(latencies), close_delay * 5.0)
finally:
# Undo the class-level patch whether or not the test passed.
Connection.close_conn = original_close
class TestServerMonitoringMode(IntegrationTest):
@client_context.require_no_serverless

View File

@ -288,7 +288,8 @@ def process_files(
if file in docstring_translate_files:
lines = translate_docstrings(lines)
if file in sync_test_files:
translate_imports(lines)
lines = translate_imports(lines)
lines = process_ignores(lines)
f.seek(0)
f.writelines(lines)
f.truncate()
@ -390,6 +391,14 @@ def translate_docstrings(lines: list[str]) -> list[str]:
return [line for line in lines if line != "DOCSTRING_REMOVED"]
def process_ignores(lines: list[str]) -> list[str]:
    """Reverse substitutions on lines that opt out of translation.

    For each line containing the ``unasync: off`` marker, every value of
    the module-level ``replacements`` mapping that occurs in the line is
    replaced by its key, i.e. the mapping is applied backwards.
    (Presumably keys are async names and values their sync counterparts;
    confirm against the ``replacements`` table.)

    The list is modified in place and the same list object is returned.
    """
    marker = "unasync: off"
    for index, _ in enumerate(lines):
        for original, translated in replacements.items():
            # Re-read the line for every pair: an earlier pair may already
            # have rewritten it.
            current = lines[index]
            if marker in current and translated in current:
                lines[index] = current.replace(translated, original)
    return lines
def unasync_directory(files: list[str], src: str, dest: str, replacements: dict[str, str]) -> None:
unasync_files(
files,