SERVER-124955 Use tenacity for retry/timeout in ContinuousAddRemoveShard hook (#53074)

Co-authored-by: Tommaso Tocci <58224719+toto-dev@users.noreply.github.com>
GitOrigin-RevId: a982e37a3f587ab91ddbbda06ce075f5cd846879
This commit is contained in:
Daniele Alessandrelli 2026-05-08 11:36:18 +01:00 committed by MongoDB Bot
parent bd351fac69
commit 21739dc963
2 changed files with 198 additions and 185 deletions

View File

@ -2,6 +2,7 @@
which case it is transitioned in/out of config shard mode.
"""
import logging
import os.path
import random
import re
@ -10,6 +11,15 @@ import time
import bson
import pymongo.errors
from tenacity import (
RetryError,
Retrying,
before_sleep_log,
retry_if_exception,
retry_if_result,
stop_after_delay,
wait_exponential,
)
from buildscripts.resmokelib import errors
from buildscripts.resmokelib.hang_analyzer.hang_analyzer import HangAnalyzer
@ -17,6 +27,7 @@ from buildscripts.resmokelib.testing.fixtures import interface as fixture_interf
from buildscripts.resmokelib.testing.fixtures import shardedcluster
from buildscripts.resmokelib.testing.hooks import interface
from buildscripts.resmokelib.testing.hooks import lifecycle as lifecycle_interface
from buildscripts.resmokelib.testing.retry import is_retryable_error
from buildscripts.resmokelib.testing.retry import (
retryable_code_names as retryable_network_err_names,
)
@ -488,18 +499,50 @@ class _AddRemoveShardThread(threading.Thread):
return False
def _decomission_removed_shard(self, shard_obj):
start_time = time.monotonic()
def _retry_with_timeout(self, func, *, retry_predicate, step_name, shard_id=None):
    """Call `func` under a tenacity retry loop bounded by `TRANSITION_TIMEOUT_SECS`.

    `retry_predicate` is handed to tenacity as the retry condition. `step_name` (plus the
    shard id, unless it is empty or "config") is used to label the log and error messages.

    Returns the value produced by the retry loop and logs the elapsed time in milliseconds.
    If tenacity gives up with RetryError, logs an error, calls `_dump_stacks_on_timeout`
    and raises errors.ServerFailure chained to the RetryError. Any other exception is not
    handled here and propagates to the caller.
    """
    began_at = time.monotonic()

    # Only name the shard in messages when it is a real (non-config) shard id.
    shard_suffix = ""
    if shard_id and shard_id != "config":
        shard_suffix = f" for shard '{shard_id}'"
    step_label = step_name + shard_suffix

    # Exponential backoff clamped to [0.1s, 1s]; stop once the configured timeout elapsed.
    backoff = wait_exponential(multiplier=0.1, min=0.1, max=1)
    deadline = stop_after_delay(self.TRANSITION_TIMEOUT_SECS)
    log_before_sleep = before_sleep_log(self.logger, logging.INFO)
    retryer = Retrying(
        retry=retry_predicate,
        wait=backoff,
        stop=deadline,
        before_sleep=log_before_sleep,
    )

    try:
        outcome = retryer(func)
    except RetryError as retry_err:
        final_attempt = retry_err.last_attempt
        # Report whichever the last attempt ended with: an exception or a rejected result.
        if final_attempt.failed:
            reason = f"last exception: {final_attempt.exception()}"
        else:
            reason = f"last result: {final_attempt.result()}"
        elapsed = time.monotonic() - began_at
        msg = (
            f"Timed out during {step_label} after {final_attempt.attempt_number} attempts "
            f"in {elapsed:.1f}s "
            f"(configured timeout {self.TRANSITION_TIMEOUT_SECS:.0f}s, {reason})"
        )
        self.logger.error(msg)
        self._dump_stacks_on_timeout(msg)
        raise errors.ServerFailure(msg) from retry_err
    else:
        elapsed_ms = (time.monotonic() - began_at) * 1000
        self.logger.info("Completed %s in %0d ms", step_label, elapsed_ms)
        return outcome
def _decomission_removed_shard(self, shard_obj):
direct_shard_conn = pymongo.MongoClient(shard_obj.get_driver_connection_url())
with shard_obj.removeshard_teardown_mutex:
while True:
if time.monotonic() - start_time > self.TRANSITION_TIMEOUT_SECS:
msg = "Timed out waiting for removed shard to finish data clean up"
self.logger.error(msg)
raise errors.ServerFailure(msg)
def attempt():
# Wait until any DDL, resharding, transactions, and migration ops are cleaned up.
# TODO SERVER-90782 Change these to be assertions, rather than waiting for the collections
# to be empty
@ -507,8 +550,7 @@ class _AddRemoveShardThread(threading.Thread):
self.logger.info(
"Waiting for config.system.sharding_ddl_coordinators to be empty before decomissioning."
)
time.sleep(1)
continue
return False
if (
len(list(direct_shard_conn.config.localReshardingOperations.recipient.find()))
@ -517,23 +559,20 @@ class _AddRemoveShardThread(threading.Thread):
self.logger.info(
"Waiting for config.localReshardingOperations.recipient to be empty before decomissioning."
)
time.sleep(1)
continue
return False
if len(list(direct_shard_conn.config.transaction_coordinators.find())) != 0:
self.logger.info(
"Waiting for config.transaction_coordinators to be empty before decomissioning."
)
time.sleep(1)
continue
return False
# TODO SERVER-91474 Wait for ongoing transactions to finish on participants
if self._get_number_of_ongoing_transactions(direct_shard_conn) != 0:
self.logger.info(
"Waiting for ongoing transactions to commit or abort before decomissioning."
)
time.sleep(1)
continue
return False
# TODO SERVER-50144 Wait for config.rangeDeletions to be empty before decomissioning
@ -547,7 +586,13 @@ class _AddRemoveShardThread(threading.Thread):
self.logger.error(msg)
raise errors.ServerFailure(msg)
break
return True
self._retry_with_timeout(
attempt,
retry_predicate=retry_if_result(lambda done: not done),
step_name="decommissioning of removed shard",
)
for db_name in direct_shard_conn.list_database_names():
if db_name in ["admin", "config", "local"]:
@ -1117,76 +1162,60 @@ class _AddRemoveShardThread(threading.Thread):
)
time.sleep(1)
def _is_retryable_config_shard_error(self, err):
    """Retry predicate for _transition_to_config_shard_or_add_shard.

    Returns True when `err` should trigger another attempt, False otherwise.
    """
    # Retryable network error codes cover suites that run with forced failovers;
    # FailedToSatisfyReadPreference covers suites that kill the primary.
    retryable_error_codes = {
        *retryable_network_errs,
        self._FAILED_TO_SATISFY_READ_PREFERENCE,
    }
    if is_retryable_error(err, retryable_error_codes):
        return True

    # Everything below inspects the server error code, which only OperationFailure carries.
    if not isinstance(err, pymongo.errors.OperationFailure):
        return False

    # Some workloads kill sessions, which may interrupt the transition.
    if self._is_expected_transition_error_code(err.code):
        return True

    if err.code != self._OPERATION_FAILED:
        return False

    # A node killed right before transition/addShard makes the command fail with
    # OperationFailed (not a retryable code), so the error message itself is inspected.
    message = str(err)
    transient_markers = (
        "Connection refused",
        "does not belong to replica set",
        *retryable_network_err_names,
    )
    return any(marker in message for marker in transient_markers)
def _transition_to_config_shard_or_add_shard(self, shard_id, shard_host):
if shard_id == "config":
self.logger.info("Starting transition from " + self._current_config_mode)
else:
self.logger.info("Starting to add shard " + shard_id)
start_time = time.monotonic()
msg = "transitioning from dedicated" if shard_id == "config" else "adding shard"
while True:
# Ensure the hook doesn't get stuck in a silent retry loop.
self._check_transition_timeout(start_time, msg, shard_id)
try:
if shard_id == "config":
self._client.admin.command({"transitionFromDedicatedConfigServer": 1})
else:
original_shard_id = (
shard_id if self._shard_name_suffix == 0 else shard_id.split("_")[0]
)
shard_name = original_shard_id + "_" + str(self._shard_name_suffix)
self.logger.info("Adding shard with new shardId: " + shard_name)
self._client.admin.command({"addShard": shard_host, "name": shard_name})
self._shard_name_suffix = self._shard_name_suffix + 1
self.logger.info(
"Completed " + msg + " in %0d ms", (time.monotonic() - start_time) * 1000
def attempt():
if shard_id == "config":
self._client.admin.command({"transitionFromDedicatedConfigServer": 1})
else:
original_shard_id = (
shard_id if self._shard_name_suffix == 0 else shard_id.split("_")[0]
)
return
except pymongo.errors.AutoReconnect:
self.logger.info("AutoReconnect exception thrown, retrying...")
time.sleep(0.1)
except pymongo.errors.OperationFailure as err:
# Some suites run with forced failovers, if transitioning fails with a retryable
# network error, we should retry.
if err.code in set(retryable_network_errs):
self.logger.info(
"Network error when " + msg + " server, will retry. err: " + str(err)
)
time.sleep(1)
continue
shard_name = original_shard_id + "_" + str(self._shard_name_suffix)
self.logger.info("Adding shard with new shardId: " + shard_name)
self._client.admin.command({"addShard": shard_host, "name": shard_name})
self._shard_name_suffix = self._shard_name_suffix + 1
# If one of the nodes in the shard is killed just before the attempt to
# transition/addShard, addShard will fail because it will not be able to connect. The
# error code returned is not retryable (it is OperationFailed), so we check the specific
# error message as well.
if err.code in [self._OPERATION_FAILED] and (
"Connection refused" in str(err)
or "does not belong to replica set" in str(err)
or any(err_name in str(err) for err_name in retryable_network_err_names)
):
self.logger.info(
"Transient error when " + msg + ", will retry. err: " + str(err)
)
time.sleep(1)
continue
# Some suites kill the primary causing the request to fail with
# FailedToSatisfyReadPreference
if err.code in [self._FAILED_TO_SATISFY_READ_PREFERENCE]:
self.logger.info(
"Primary not found when " + msg + ", will retry. err: " + str(err)
)
time.sleep(1)
continue
# Some workloads kill sessions which may interrupt the transition.
if not self._is_expected_transition_error_code(err.code):
raise err
self.logger.info("Ignoring error " + msg + " : " + str(err))
self._retry_with_timeout(
attempt,
retry_predicate=retry_if_exception(self._is_retryable_config_shard_error),
step_name=msg,
shard_id=shard_id,
)
def _get_other_shard_info(self, shard_id):
res = self._client.admin.command({"listShards": 1})
@ -1231,117 +1260,96 @@ class _AddRemoveShardThread(threading.Thread):
)
return res[0]["num_ongoing_txns"] if res else 0
def _is_retryable_post_checks_error(self, err):
    """Retry predicate for _run_post_remove_shard_checks.

    Delegates to is_retryable_error with the retryable network error codes plus
    `_INTERRUPTED`, and returns its verdict.
    """
    # Why these are retried:
    # - The checks talk directly to shard nodes, so a killed node can surface as a
    #   connection failure (AutoReconnect, NotPrimaryError), which is_retryable_error
    #   treats as retryable.
    # - Suites with forced failovers may throw retryable network errors.
    # - Some workloads kill sessions, which may interrupt the checks (Interrupted).
    codes_to_retry = {*retryable_network_errs, self._INTERRUPTED}
    return is_retryable_error(err, codes_to_retry)
def _run_post_remove_shard_checks(self, removed_shard_fixture, removed_shard_name):
start_time = time.monotonic()
while True:
# Ensure the hook doesn't get stuck in a silent retry loop.
self._check_transition_timeout(
start_time, "post-remove-shard checks", removed_shard_name
def attempt():
# Configsvr metadata checks:
## Check that the removed shard no longer exists on config.shards.
assert (
self._client["config"]["shards"].count_documents({"_id": removed_shard_name}) == 0
), f"Removed shard still exists on config.shards: {removed_shard_name}"
## Check that no database has the removed shard as its primary shard.
databasesPointingToRemovedShard = [
doc
for doc in self._client["config"]["databases"].find({"primary": removed_shard_name})
]
assert not databasesPointingToRemovedShard, f"Found databases whose primary shard is a removed shard: {databasesPointingToRemovedShard}"
## Check that no chunk has the removed shard as its owner.
chunksPointingToRemovedShard = [
doc for doc in self._client["config"]["chunks"].find({"shard": removed_shard_name})
]
assert (
not chunksPointingToRemovedShard
), f"Found chunks whose owner is a removed shard: {chunksPointingToRemovedShard}"
## Check that all tags in config.tags refer to at least one existing shard.
tagsWithoutShardPipeline = [
{
"$lookup": {
"from": "shards",
"localField": "tag",
"foreignField": "tags",
"as": "shards",
}
},
{"$match": {"shards": []}},
]
tagsWithoutShardPipelineResultCursor = self._client["config"]["tags"].aggregate(
tagsWithoutShardPipeline
)
try:
# Configsvr metadata checks:
## Check that the removed shard no longer exists on config.shards.
assert (
self._client["config"]["shards"].count_documents({"_id": removed_shard_name})
== 0
), f"Removed shard still exists on config.shards: {removed_shard_name}"
## Check that no database has the removed shard as its primary shard.
databasesPointingToRemovedShard = [
doc
for doc in self._client["config"]["databases"].find(
{"primary": removed_shard_name}
)
]
assert not databasesPointingToRemovedShard, f"Found databases whose primary shard is a removed shard: {databasesPointingToRemovedShard}"
## Check that no chunk has the removed shard as its owner.
chunksPointingToRemovedShard = [
doc
for doc in self._client["config"]["chunks"].find({"shard": removed_shard_name})
]
assert (
not chunksPointingToRemovedShard
), f"Found chunks whose owner is a removed shard: {chunksPointingToRemovedShard}"
## Check that all tags in config.tags refer to at least one existing shard.
tagsWithoutShardPipeline = [
{
"$lookup": {
"from": "shards",
"localField": "tag",
"foreignField": "tags",
"as": "shards",
}
},
{"$match": {"shards": []}},
]
tagsWithoutShardPipelineResultCursor = self._client["config"]["tags"].aggregate(
tagsWithoutShardPipeline
)
tagsWithoutShardPipelineResult = [
doc for doc in tagsWithoutShardPipelineResultCursor
]
assert not tagsWithoutShardPipelineResult, f"Found tags in config.tags that are not owned by any shard: {tagsWithoutShardPipelineResult}"
if removed_shard_name != "config":
return
# Check that there is no user data left on the removed shard. (Note: This can only be
# checked on transitionToDedicatedConfigServer)
removed_shard_primary_client = removed_shard_fixture.get_primary().mongo_client()
dbs = removed_shard_primary_client.list_database_names()
assert all(
databaseName in {"local", "admin", "config"} for databaseName in dbs
), f"Expected to not have any user database on removed shard: {dbs}"
# Check the filtering metadata on removed shard. Expect that the shard knows that it does
# not own any chunk anymore. Check on all replica set nodes.
# First, await secondaries to replicate the last optime
removed_shard_fixture.await_last_op_committed(
removed_shard_fixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60
)
for removed_shard_node in [
removed_shard_fixture.get_primary()
] + removed_shard_fixture.get_secondaries():
sharding_state_response = removed_shard_node.mongo_client().admin.command(
{"shardingState": 1}
)
for nss, metadata in sharding_state_response["versions"].items():
# placementVersion == Timestamp(0, 0) means that this shard owns no chunk for the
# collection.
# TODO (SERVER-90810): Re-enable this check for resharding temporary collections.
if "system.resharding" in nss or "system.buckets.resharding" in nss:
continue
assert (
metadata["placementVersion"] == bson.Timestamp(0, 0)
), f"Expected remove shard's filtering information to reflect that the shard does not own any chunk for collection {nss}, but found {metadata} on node {removed_shard_node.get_driver_connection_url()}"
tagsWithoutShardPipelineResult = [doc for doc in tagsWithoutShardPipelineResultCursor]
assert not tagsWithoutShardPipelineResult, f"Found tags in config.tags that are not owned by any shard: {tagsWithoutShardPipelineResult}"
if removed_shard_name != "config":
return
except (pymongo.errors.AutoReconnect, pymongo.errors.NotPrimaryError) as err:
# The above operations run directly on a shard, so they may fail getting a
# connection if the shard node is killed.
self.logger.info(
"Connection error when running post removal checks, will retry. err: "
+ str(err)
)
continue
except pymongo.errors.OperationFailure as err:
# Retry on retryable errors that might be thrown in suites with forced failovers.
if err.code in set(retryable_network_errs):
self.logger.info(
"Retryable error when running post removal checks, will retry. err: "
+ str(err)
)
continue
if err.code in set([self._INTERRUPTED]):
# Some workloads kill sessions which may interrupt the transition.
self.logger.info(
"Received 'Interrupted' error when running post removal checks, will retry. err: "
+ str(err)
)
continue
raise err
# Check that there is no user data left on the removed shard. (Note: This can only be
# checked on transitionToDedicatedConfigServer)
removed_shard_primary_client = removed_shard_fixture.get_primary().mongo_client()
dbs = removed_shard_primary_client.list_database_names()
assert all(
databaseName in {"local", "admin", "config"} for databaseName in dbs
), f"Expected to not have any user database on removed shard: {dbs}"
# Check the filtering metadata on removed shard. Expect that the shard knows that it does
# not own any chunk anymore. Check on all replica set nodes.
# First, await secondaries to replicate the last optime
removed_shard_fixture.await_last_op_committed(
removed_shard_fixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60
)
for removed_shard_node in [
removed_shard_fixture.get_primary()
] + removed_shard_fixture.get_secondaries():
sharding_state_response = removed_shard_node.mongo_client().admin.command(
{"shardingState": 1}
)
for nss, metadata in sharding_state_response["versions"].items():
# placementVersion == Timestamp(0, 0) means that this shard owns no chunk for the
# collection.
# TODO (SERVER-90810): Re-enable this check for resharding temporary collections.
if "system.resharding" in nss or "system.buckets.resharding" in nss:
continue
assert (
metadata["placementVersion"] == bson.Timestamp(0, 0)
), f"Expected remove shard's filtering information to reflect that the shard does not own any chunk for collection {nss}, but found {metadata} on node {removed_shard_node.get_driver_connection_url()}"
self._retry_with_timeout(
attempt,
retry_predicate=retry_if_exception(self._is_retryable_post_checks_error),
step_name="post-remove-shard checks",
shard_id=removed_shard_name,
)

View File

@ -43,6 +43,11 @@ retryable_code_names = [
def is_retryable_error(exc, retryable_error_codes):
# Guard against non-PyMongoError exceptions: has_error_label() only exists on
# PyMongoError, so calling it on e.g. AssertionError or ServerFailure would raise
# AttributeError. Return False immediately for anything that isn't a pymongo error.
if not isinstance(exc, PyMongoError):
return False
if isinstance(exc, ConnectionFailure):
return True
if exc.has_error_label("RetryableWriteError"):