SERVER-124427 Add unit tests for ReshardingCoordinator to exercise different provenance scenarios (#52801)

GitOrigin-RevId: 7b7cbaa46512168a393e313e1d440f1ae4881d2f
This commit is contained in:
Kruti Shah 2026-05-04 17:31:10 -04:00 committed by MongoDB Bot
parent ff831c39b5
commit b180d2f19b
9 changed files with 1418 additions and 762 deletions

View File

@ -1129,10 +1129,12 @@ mongo_cc_unit_test(
mongo_cc_unit_test(
name = "db_s_config_server_part2_test",
srcs = [
"//src/mongo/db/s/resharding:reshard_collection_coordinator_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_commit_monitor_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_dao_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_observer_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_service_external_state_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_service_provenance_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_service_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_service_util_test.cpp",
"//src/mongo/db/s/resharding:resharding_coordinator_test.cpp",

View File

@ -114,18 +114,10 @@ void ReshardCollectionCoordinator::appendCommandInfo(BSONObjBuilder* cmdInfoBuil
}
BSONObj ReshardCollectionCoordinator::_computeFinalShardKey(const CurrentChunkManager& cmOld) {
auto provenance = _doc.getProvenance();
if (resharding::isRewriteCollection(provenance)) {
// rewriteCollection reshards the collection on its existing key.
return cmOld.getShardKeyPattern().getKeyPattern().toBSON();
}
if (cmOld.isTimeseriesCollection() && resharding::isOrdinaryReshardCollection(provenance)) {
const auto& tsOptions = cmOld.getTimeseriesFields().get().getTimeseriesOptions();
return shardkeyutil::validateAndTranslateTimeseriesShardKey(tsOptions, *_doc.getKey());
}
return *_doc.getKey();
return resharding::computeReshardingShardKey(_doc.getProvenance(),
cmOld.getShardKeyPattern(),
cmOld.getTimeseriesFields(),
_doc.getKey());
}
ExecutorFuture<void> ReshardCollectionCoordinator::_runImpl(
@ -201,35 +193,21 @@ ExecutorFuture<void> ReshardCollectionCoordinator::_runImpl(
configsvrReshardCollection.setPerformVerification(_doc.getPerformVerification());
auto provenance = _doc.getProvenance();
if (resharding::isMoveCollection(provenance)) {
uassert(ErrorCodes::NamespaceNotFound,
str::stream()
<< "MoveCollection can only be called on an unsharded collection.",
!cmOld.isSharded());
} else if (resharding::isUnshardCollection(provenance)) {
// Pick the "to" shard if the client did not specify one.
if (!_doc.getShardDistribution()) {
auto toShard = sharding_util::selectLeastLoadedNonDrainingShard(opCtx);
mongo::ShardKeyRange destinationRange(toShard);
destinationRange.setMin(cluster::unsplittable::kUnsplittableCollectionMinKey);
destinationRange.setMax(cluster::unsplittable::kUnsplittableCollectionMaxKey);
std::vector<mongo::ShardKeyRange> distribution = {destinationRange};
configsvrReshardCollection.setShardDistribution(distribution);
}
} else {
uassert(ErrorCodes::NamespaceNotSharded,
"Collection has to be a sharded collection.",
cmOld.isSharded());
resharding::validateReshardCollectionRequest(
provenance,
cmOld.isSharded(),
cmOld.getShardKeyPattern(),
finalShardKey,
_doc.getForceRedistribution().value_or(false));
if (_doc.getForceRedistribution() && *_doc.getForceRedistribution()) {
uassert(ErrorCodes::InvalidOptions,
str::stream()
<< "The new shard key must be the same as the original shard key "
"when using the forceRedistribution option. The "
"forceRedistribution option is meant for redistributing the "
"collection to a different set of shards.",
cmOld.getShardKeyPattern().isShardKey(finalShardKey));
}
// For unshardCollection, pick the destination shard if the client did not specify one.
if (resharding::isUnshardCollection(provenance) && !_doc.getShardDistribution()) {
auto toShard = sharding_util::selectLeastLoadedNonDrainingShard(opCtx);
mongo::ShardKeyRange destinationRange(toShard);
destinationRange.setMin(cluster::unsplittable::kUnsplittableCollectionMinKey);
destinationRange.setMax(cluster::unsplittable::kUnsplittableCollectionMaxKey);
std::vector<mongo::ShardKeyRange> distribution = {destinationRange};
configsvrReshardCollection.setShardDistribution(distribution);
}
configsvrReshardCollection.setProvenance(provenance);

View File

@ -0,0 +1,180 @@
/**
* Copyright (C) 2026-present MongoDB, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the Server Side Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#include "mongo/base/error_codes.h"
#include "mongo/base/string_data.h"
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/db/global_catalog/shard_key_pattern.h"
#include "mongo/db/global_catalog/type_collection_common_types_gen.h"
#include "mongo/db/s/resharding/resharding_util.h"
#include "mongo/db/timeseries/timeseries_gen.h"
#include "mongo/idl/idl_parser.h"
#include "mongo/s/resharding/common_types_gen.h"
#include "mongo/unittest/assert.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
#include <string>
#include <boost/optional/optional.hpp>
#include <gtest/gtest.h>
namespace mongo {
namespace resharding {
namespace {
// Shard key pattern of the source collection that the tests in this file reshard from.
const ShardKeyPattern kSourceShardKey(BSON("a" << 1));
// Shard key supplied by the user in the request; it is on field "b" rather than "a", so it is
// distinguishable from kSourceShardKey.
const BSONObj kUserShardKey = BSON("b" << 1);
// Builds the TypeCollectionTimeseriesFields for a collection whose timeseries options use
// 'timeField' as the time field and, when one is supplied, 'metaField' as the meta field.
TypeCollectionTimeseriesFields makeTimeseriesFields(StringData timeField,
                                                    boost::optional<StringData> metaField) {
    TimeseriesOptions tsOptions{std::string{timeField}};
    if (metaField != boost::none) {
        tsOptions.setMetaField(metaField.get());
    }

    TypeCollectionTimeseriesFields tsFields;
    tsFields.setTimeseriesOptions(std::move(tsOptions));
    return tsFields;
}
/**
 * Fixture parameterized by ReshardingProvenanceEnum. GetParam() is the provenance under test.
 */
class ReshardingProvenanceTest : public ::testing::TestWithParam<ReshardingProvenanceEnum> {
protected:
    // Calls validateReshardCollectionRequest() for the provenance under test, always using
    // kSourceShardKey as the source collection's shard key. 'finalKey' defaults to the
    // user-supplied key and 'forceRedistribution' defaults to off.
    void validateCollection(bool isSharded,
                            const BSONObj& finalKey = kUserShardKey,
                            bool forceRedistribution = false) {
        const auto provenance = GetParam();
        validateReshardCollectionRequest(
            provenance, isSharded, kSourceShardKey, finalKey, forceRedistribution);
    }
};
// Runs every ReshardingProvenanceTest case once per listed provenance. Each instantiation is
// named after the serialized form of its enum value.
INSTANTIATE_TEST_SUITE_P(
    Provenance,
    ReshardingProvenanceTest,
    ::testing::Values(ReshardingProvenanceEnum::kReshardCollection,
                      ReshardingProvenanceEnum::kMoveCollection,
                      ReshardingProvenanceEnum::kUnshardCollection,
                      ReshardingProvenanceEnum::kRewriteCollection),
    [](const ::testing::TestParamInfo<ReshardingProvenanceEnum>& paramInfo) -> std::string {
        return std::string(idl::serialize(paramInfo.param));
    });
// =====================================================================================
// computeReshardingShardKey
// =====================================================================================
// For a non-timeseries collection, rewriteCollection is expected to yield the source shard key;
// every other provenance is expected to yield the user-supplied key unchanged.
TEST_P(ReshardingProvenanceTest, ComputeShardKey) {
    const auto provenance = GetParam();
    const auto computedKey =
        computeReshardingShardKey(provenance, kSourceShardKey, boost::none, kUserShardKey);

    if (!isRewriteCollection(provenance)) {
        ASSERT_BSONOBJ_EQ(computedKey, kUserShardKey);
    } else {
        ASSERT_BSONOBJ_EQ(computedKey, kSourceShardKey.getKeyPattern().toBSON());
    }
}
// Without a user-supplied key, rewriteCollection is still expected to yield the source shard key,
// while every other provenance is expected to fail with InvalidOptions.
TEST_P(ReshardingProvenanceTest, ComputeShardKeyMissingUserKey) {
    const auto provenance = GetParam();
    auto computeWithoutUserKey = [&] {
        return computeReshardingShardKey(provenance, kSourceShardKey, boost::none, boost::none);
    };

    if (!isRewriteCollection(provenance)) {
        ASSERT_THROWS_CODE(computeWithoutUserKey(), DBException, ErrorCodes::InvalidOptions);
        return;
    }
    ASSERT_BSONOBJ_EQ(computeWithoutUserKey(), kSourceShardKey.getKeyPattern().toBSON());
}
// For a timeseries collection, reshardCollection must translate the user key from the user-facing
// meta field name ("metadata.x") to the bucket-level name ("meta.x"); otherwise the request could
// incorrectly be rejected with InvalidOptions. rewriteCollection keeps the source key, and the
// remaining provenances are expected to return the user key untranslated.
TEST_P(ReshardingProvenanceTest, ComputeShardKeyTimeseries) {
    const auto provenance = GetParam();
    const ShardKeyPattern bucketLevelSourceKey(BSON("meta.x" << 1));
    const BSONObj userFacingKey = BSON("metadata.x" << 1);
    auto timeseriesFields = makeTimeseriesFields("time", StringData("metadata"));

    auto computedKey = computeReshardingShardKey(
        provenance, bucketLevelSourceKey, timeseriesFields, userFacingKey);

    const bool expectsBucketLevelKey =
        isRewriteCollection(provenance) || isOrdinaryReshardCollection(provenance);
    if (expectsBucketLevelKey) {
        ASSERT_BSONOBJ_EQ(computedKey, bucketLevelSourceKey.getKeyPattern().toBSON());
    } else {
        ASSERT_BSONOBJ_EQ(computedKey, userFacingKey);
    }
}
// =====================================================================================
// validateReshardCollectionRequest
// =====================================================================================
// moveCollection is validated against an unsharded source; every other provenance is validated
// against a sharded source. Neither call is expected to throw.
TEST_P(ReshardingProvenanceTest, ValidateAcceptsValidShardedness) {
    const bool sourceIsSharded = !isMoveCollection(GetParam());
    validateCollection(sourceIsSharded);
}
// A sharded source must be rejected for moveCollection (NamespaceNotFound), and an unsharded
// source must be rejected for reshardCollection / rewriteCollection (NamespaceNotSharded).
TEST_P(ReshardingProvenanceTest, ValidateRejectsInvalidShardedness) {
    const auto provenance = GetParam();

    if (isUnshardCollection(provenance)) {
        // unshardCollection does not validate shardedness at the ReshardCollectionCoordinator
        // level, so there is nothing to reject here.
        return;
    }

    if (!isMoveCollection(provenance)) {
        ASSERT_THROWS_CODE(
            validateCollection(/*isSharded=*/false), DBException, ErrorCodes::NamespaceNotSharded);
        return;
    }

    ASSERT_THROWS_CODE(
        validateCollection(/*isSharded=*/true), DBException, ErrorCodes::NamespaceNotFound);
}
// forceRedistribution combined with a final key equal to the source key must pass validation for
// every provenance (moveCollection is exercised with an unsharded source, the rest with a sharded
// one).
TEST_P(ReshardingProvenanceTest, ForceRedistributionIdentityKeyAccepted) {
    const BSONObj unchangedKey = kSourceShardKey.getKeyPattern().toBSON();
    const bool sourceIsSharded = !isMoveCollection(GetParam());
    validateCollection(sourceIsSharded, unchangedKey, /*forceRedistribution=*/true);
}
// forceRedistribution combined with a final key that differs from the source key must be rejected
// with InvalidOptions for the provenances where forceRedistribution applies.
TEST_P(ReshardingProvenanceTest, ForceRedistributionWithDifferentKey) {
    const auto provenance = GetParam();
    const bool forceRedistributionApplies =
        !isMoveCollection(provenance) && !isUnshardCollection(provenance);
    if (!forceRedistributionApplies) {
        // forceRedistribution is not applicable to moveCollection and unshardCollection.
        return;
    }

    ASSERT_THROWS_CODE(
        validateCollection(/*isSharded=*/true, kUserShardKey, /*forceRedistribution=*/true),
        DBException,
        ErrorCodes::InvalidOptions);
}
// A request without a provenance defaults to reshardCollection, so an unsharded source must be
// rejected with NamespaceNotSharded.
TEST(ValidateReshardCollectionRequestTest, NoProvenanceRequiresShardedSource) {
    auto validateUnshardedSourceWithoutProvenance = [] {
        validateReshardCollectionRequest(boost::none,
                                         /*isSharded=*/false,
                                         kSourceShardKey,
                                         kUserShardKey,
                                         /*forceRedistribution=*/false);
    };

    ASSERT_THROWS_CODE(validateUnshardedSourceWithoutProvenance(),
                       DBException,
                       ErrorCodes::NamespaceNotSharded);
}
} // namespace
} // namespace resharding
} // namespace mongo

View File

@ -0,0 +1,280 @@
/**
* Copyright (C) 2026-present MongoDB, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the Server Side Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/global_catalog/type_chunk.h"
#include "mongo/db/global_catalog/type_collection.h"
#include "mongo/db/global_catalog/type_database_gen.h"
#include "mongo/db/s/resharding/resharding_coordinator.h"
#include "mongo/db/s/resharding/resharding_coordinator_service_test_fixture.h"
#include "mongo/db/s/resharding/resharding_cumulative_metrics.h"
#include "mongo/db/s/resharding/resharding_util.h"
#include "mongo/db/versioning_protocol/database_version.h"
#include "mongo/idl/idl_parser.h"
#include "mongo/unittest/unittest.h"
#include <string>
#include <vector>
#include <gtest/gtest.h>
namespace mongo {
namespace {
using resharding_coordinator_test::ExternalStateForTest;
using resharding_coordinator_test::ReshardingCoordinatorServiceTestCommon;
/**
 * Test fixture parameterized by ReshardingProvenanceEnum. GetParam() is the provenance under test.
 *
 * Three shards are configured: shard0, shard1, shard2. Per-provenance shard layout:
 * - kReshardCollection / kRewriteCollection: donors {shard0, shard1}, recipients {shard0, shard1}.
 * - kMoveCollection: donor {shard0}, recipient {shard2}.
 * - kUnshardCollection: donors {shard0, shard1}, recipient {shard2}.
 */
class ReshardingCoordinatorServiceProvenanceTest
: public ReshardingCoordinatorServiceTestCommon,
public ::testing::WithParamInterface<ReshardingProvenanceEnum> {
public:
// Returns all three shards, independent of the provenance under test.
// NOTE(review): presumably consumed by the base fixture to register shards -- confirm there.
std::vector<ShardId> getShardIds() const override {
return {_shard0, _shard1, _shard2};
}
// Builds the external-state options with an entry for every donor shard of the current
// provenance: 50 documents to copy and a documents delta of 5.
ExternalStateForTest::Options getExternalStateOptions() const override {
ExternalStateForTest::Options opts;
for (const auto& id : donorShardIds()) {
opts.documentsToCopy.emplace(id, 50);
opts.documentsDelta.emplace(id, 5);
}
return opts;
}
// Donor shards for the current provenance: {shard0} for moveCollection, otherwise
// {shard0, shard1}.
std::vector<ShardId> donorShardIds() const {
return resharding::isMoveCollection(GetParam()) ? std::vector<ShardId>{_shard0}
: std::vector<ShardId>{_shard0, _shard1};
}
// Recipient shards for the current provenance: {shard2} for moveCollection and
// unshardCollection, otherwise {shard0, shard1}.
std::vector<ShardId> recipientShardIds() const {
if (resharding::isMoveCollection(GetParam()) ||
resharding::isUnshardCollection(GetParam())) {
return {_shard2};
}
return {_shard0, _shard1};
}
// True if the coordinator calls _stopMigrations on the source for this provenance.
bool sourceMigrationsAreBlocked() const {
return !resharding::isMoveCollection(GetParam());
}
// Returns the cumulative metrics instance that matches the current provenance. Hits
// MONGO_UNREACHABLE if the provenance matches none of the four handled cases.
ReshardingCumulativeMetrics* metricsBucket() const {
if (resharding::isOrdinaryReshardCollection(GetParam()))
return ReshardingCumulativeMetrics::getForResharding(getServiceContext());
if (resharding::isMoveCollection(GetParam()))
return ReshardingCumulativeMetrics::getForMoveCollection(getServiceContext());
if (resharding::isUnshardCollection(GetParam()))
return ReshardingCumulativeMetrics::getForUnshardCollection(getServiceContext());
if (resharding::isRewriteCollection(GetParam()))
return ReshardingCumulativeMetrics::getForRewriteCollection(getServiceContext());
MONGO_UNREACHABLE;
}
// Builds a coordinator document in state kUnused for the current provenance, using the
// donor/recipient layout above and the fixture's resharding metadata members. Demo mode is
// enabled on the document.
ReshardingCoordinatorDocument makeCoordinatorDoc() {
// One entry per donor; the second constructor argument is default-constructed.
std::vector<DonorShardEntry> donors;
for (auto&& id : donorShardIds()) {
donors.push_back(DonorShardEntry{id, {}});
}
// One entry per recipient; the second constructor argument is default-constructed.
std::vector<RecipientShardEntry> recipients;
for (auto&& id : recipientShardIds()) {
recipients.push_back(RecipientShardEntry{id, {}});
}
CommonReshardingMetadata meta(
_reshardingUUID, _originalNss, _originalUUID, _tempNss, _newShardKey.toBSON());
meta.setStartTime(getServiceContext()->getFastClockSource()->now());
meta.setProvenance(GetParam());
ReshardingCoordinatorDocument doc(CoordinatorStateEnum::kUnused, donors, recipients);
doc.setCommonReshardingMetadata(meta);
resharding::emplaceCloneTimestampIfExists(doc, _cloneTimestamp);
doc.setDemoMode(true);
// moveCollection / unshardCollection consult shardDistribution to determine the
// destination shard during _isReshardingOpRedundant.
if (resharding::isMoveCollection(GetParam()) ||
resharding::isUnshardCollection(GetParam())) {
// A single range spanning the whole new key space, owned by the first recipient.
ShardKeyRange dest{recipientShardIds().front()};
dest.setMin(_newShardKey.getKeyPattern().globalMin());
dest.setMax(_newShardKey.getKeyPattern().globalMax());
doc.setShardDistribution(std::vector<ShardKeyRange>{dest});
}
return doc;
}
// Inserts the catalog documents for the source collection: a collection entry built from the
// fixture's original namespace/epoch/timestamp/UUID and old shard key, and a database entry
// for the source namespace's database that is constructed with the first donor shard's id.
void insertCatalogEntries(const ReshardingCoordinatorDocument& doc) {
auto opCtx = operationContext();
DBDirectClient client(opCtx);
CollectionType origColl(_originalNss,
_originalEpoch,
_originalTimestamp,
opCtx->getServiceContext()->getPreciseClockSource()->now(),
_originalUUID,
_oldShardKey.getKeyPattern());
client.insert(NamespaceString::kConfigsvrCollectionsNamespace, origColl.toBSON());
// NOTE(review): the shard id argument is presumably the database's primary shard -- confirm
// against the DatabaseType definition.
DatabaseType dbDoc(doc.getSourceNss().dbName(),
doc.getDonorShards().front().getId(),
DatabaseVersion{UUID::gen(), Timestamp(1, 1)});
client.insert(NamespaceString::kConfigDatabasesNamespace, dbDoc.toBSON());
}
// Seed config.chunks for the source collection and return preset reshardedChunks for the
// recipient layout.
std::vector<ReshardedChunk> seedSourceChunksAndComputeReshardedChunks() {
auto opCtx = operationContext();
DBDirectClient client(opCtx);
const auto donors = donorShardIds();
ChunkVersion sourceVersion({_originalEpoch, _originalTimestamp}, {1, 0});
if (donors.size() == 1) {
// Single donor: one chunk covering the entire old key space.
ChunkType chunk(_originalUUID,
ChunkRange{_oldShardKey.getKeyPattern().globalMin(),
_oldShardKey.getKeyPattern().globalMax()},
sourceVersion,
donors.front());
chunk.setName(OID::gen());
client.insert(NamespaceString::kConfigsvrChunksNamespace, chunk.toConfigBSON());
} else {
// Two donors: split the old key space at {oldShardKey: 0}, one chunk per donor.
ChunkType chunk1(
_originalUUID,
ChunkRange{_oldShardKey.getKeyPattern().globalMin(), BSON("oldShardKey" << 0)},
sourceVersion,
donors[0]);
chunk1.setName(OID::gen());
// The minor version is bumped before the second chunk is constructed, so the two chunks
// are created with different versions.
sourceVersion.incMinor();
ChunkType chunk2(
_originalUUID,
ChunkRange{BSON("oldShardKey" << 0), _oldShardKey.getKeyPattern().globalMax()},
sourceVersion,
donors[1]);
chunk2.setName(OID::gen());
client.insert(NamespaceString::kConfigsvrChunksNamespace, chunk1.toConfigBSON());
client.insert(NamespaceString::kConfigsvrChunksNamespace, chunk2.toConfigBSON());
}
// Build the preset reshardedChunks across recipientShardIds().
std::vector<ReshardedChunk> presetChunks;
const auto recipients = recipientShardIds();
if (recipients.size() == 1) {
// Single recipient: one chunk covering the entire new key space.
presetChunks.emplace_back(recipients.front(),
_newShardKey.getKeyPattern().globalMin(),
_newShardKey.getKeyPattern().globalMax());
} else {
// Two recipients: split the new key space at {newShardKey: 0}, one chunk per recipient.
presetChunks.emplace_back(
recipients[0], _newShardKey.getKeyPattern().globalMin(), BSON("newShardKey" << 0));
presetChunks.emplace_back(
recipients[1], BSON("newShardKey" << 0), _newShardKey.getKeyPattern().globalMax());
}
return presetChunks;
}
protected:
// Identifiers of the three shards referred to as shard0 / shard1 / shard2 in this fixture.
const ShardId _shard0{"shard0000"};
const ShardId _shard1{"shard0001"};
const ShardId _shard2{"shard0002"};
};
// Runs every ReshardingCoordinatorServiceProvenanceTest case once per listed provenance. Each
// instantiation is named after the serialized form of its enum value.
INSTANTIATE_TEST_SUITE_P(
    Provenance,
    ReshardingCoordinatorServiceProvenanceTest,
    ::testing::Values(ReshardingProvenanceEnum::kReshardCollection,
                      ReshardingProvenanceEnum::kRewriteCollection,
                      ReshardingProvenanceEnum::kMoveCollection,
                      ReshardingProvenanceEnum::kUnshardCollection),
    [](const ::testing::TestParamInfo<ReshardingProvenanceEnum>& paramInfo) -> std::string {
        return std::string(idl::serialize(paramInfo.param));
    });
// Walks the coordinator through every state up to completion for each provenance, then checks
// the provenance-specific side effects: migration blocking on the source, the unsplittable flag
// for unshardCollection, and the started/succeeded counters of the matching metrics bucket.
TEST_P(ReshardingCoordinatorServiceProvenanceTest, FullLifecycleSucceeds) {
    auto opCtx = operationContext();

    // Seed the catalog and the coordinator document before starting the coordinator.
    auto coordinatorDoc = makeCoordinatorDoc();
    insertCatalogEntries(coordinatorDoc);
    auto reshardedChunks = seedSourceChunksAndComputeReshardedChunks();
    coordinatorDoc.setPresetReshardedChunks(reshardedChunks);

    auto coordinator =
        ReshardingCoordinator::getOrCreate(opCtx, _service, coordinatorDoc.toBSON());

    waitUntilCommittedCoordinatorDocReach(opCtx, CoordinatorStateEnum::kPreparingToDonate);

    // Migrations on the source stay allowed exactly when this provenance does not block them.
    ASSERT_EQ(getCollectionCatalogEntry(opCtx).getAllowMigrations(),
              !sourceMigrationsAreBlocked());

    makeDonorsReadyToDonateWithAssert(opCtx);
    waitUntilCommittedCoordinatorDocReach(opCtx, CoordinatorStateEnum::kCloning);

    makeRecipientsFinishedCloningWithAssert(opCtx);
    waitUntilCommittedCoordinatorDocReach(opCtx, CoordinatorStateEnum::kApplying);

    coordinator->onOkayToEnterCritical();
    waitUntilCommittedCoordinatorDocReach(opCtx, CoordinatorStateEnum::kBlockingWrites);

    makeRecipientsBeInStrictConsistencyWithAssert(opCtx);
    waitUntilCommittedCoordinatorDocReach(opCtx, CoordinatorStateEnum::kCommitting);

    makeDonorsProceedToDoneWithAssert(opCtx);
    makeRecipientsProceedToDoneWithAssert(opCtx);
    coordinator->getCompletionFuture().get(opCtx);

    // After commit the source collection's reshardingFields are removed and migrations are
    // re-allowed (default true).
    ASSERT_TRUE(getCollectionCatalogEntry(opCtx).getAllowMigrations());
    if (resharding::isUnshardCollection(GetParam())) {
        ASSERT_TRUE(getCollectionCatalogEntry(opCtx).getUnsplittable());
    }

    // The report's first element holds the counters for this provenance's metrics bucket. The
    // owning report object is kept alive while the nested object is read.
    BSONObjBuilder reportBuilder;
    metricsBucket()->reportForServerStatus(&reportBuilder);
    auto serverStatusReport = reportBuilder.obj();
    auto provenanceCounters = serverStatusReport.firstElement().Obj();
    ASSERT_EQ(provenanceCounters["countStarted"].numberInt(), 1);
    ASSERT_EQ(provenanceCounters["countSucceeded"].numberInt(), 1);
}
} // namespace
} // namespace mongo

View File

@ -30,63 +30,34 @@
#include "mongo/db/s/resharding/resharding_coordinator_service.h"
#include "mongo/base/error_codes.h"
#include "mongo/base/status_with.h"
#include "mongo/bson/bson_field.h"
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/client/dbclient_cursor.h"
#include "mongo/db/client.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/global_catalog/shard_key_pattern.h"
#include "mongo/db/global_catalog/type_collection.h"
#include "mongo/db/global_catalog/type_database_gen.h"
#include "mongo/db/global_catalog/type_shard.h"
#include "mongo/db/keypattern.h"
#include "mongo/db/op_observer/op_observer_registry.h"
#include "mongo/db/query/find_command.h"
#include "mongo/db/repl/member_state.h"
#include "mongo/db/repl/oplog.h"
#include "mongo/db/repl/optime.h"
#include "mongo/db/repl/replication_coordinator.h"
#include "mongo/db/repl/replication_coordinator_mock.h"
#include "mongo/db/repl/wait_for_majority_service.h"
#include "mongo/db/s/resharding/resharding_coordinator.h"
#include "mongo/db/s/resharding/resharding_coordinator_observer.h"
#include "mongo/db/s/resharding/resharding_coordinator_service_external_state.h"
#include "mongo/db/s/resharding/resharding_coordinator_service_test_fixture.h"
#include "mongo/db/s/resharding/resharding_cumulative_metrics.h"
#include "mongo/db/s/resharding/resharding_op_observer.h"
#include "mongo/db/s/resharding/resharding_service_test_helpers.h"
#include "mongo/db/s/resharding/resharding_util.h"
#include "mongo/db/s/transaction_coordinator_service.h"
#include "mongo/db/service_context.h"
#include "mongo/db/session/logical_session_cache.h"
#include "mongo/db/session/logical_session_cache_noop.h"
#include "mongo/db/session/session_catalog_mongod.h"
#include "mongo/db/shard_role/transaction_resources.h"
#include "mongo/db/sharding_environment/config_server_test_fixture.h"
#include "mongo/db/versioning_protocol/database_version.h"
#include "mongo/executor/mock_async_rpc.h"
#include "mongo/idl/idl_parser.h"
#include "mongo/idl/server_parameter_test_controller.h"
#include "mongo/logv2/log.h"
#include "mongo/s/resharding/resharding_coordinator_service_conflicting_op_in_progress_info.h"
#include "mongo/s/resharding/resharding_feature_flag_gen.h"
#include "mongo/s/resharding/type_collection_fields_gen.h"
#include "mongo/stdx/unordered_map.h"
#include "mongo/unittest/unittest.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/clock_source.h"
#include "mongo/util/duration.h"
#include "mongo/util/fail_point.h"
#include "mongo/util/string_map.h"
#include "mongo/util/time_support.h"
#include <functional>
#include <string>
#include <absl/container/node_hash_map.h>
#include <boost/move/utility_core.hpp>
#include <boost/none.hpp>
#include <boost/optional.hpp>
#include <boost/optional/optional.hpp>
@ -104,351 +75,11 @@ using OpObserverForTest = resharding_service_test_helpers::
using PauseDuringStateTransitions =
resharding_service_test_helpers::PauseDuringStateTransitions<CoordinatorStateEnum>;
class ExternalStateForTest : public ReshardingCoordinatorExternalState {
public:
struct Options {
std::map<ShardId, int64_t> documentsToCopy;
std::map<ShardId, int64_t> documentsDelta;
boost::optional<ErrorCodes::Error> getDocumentsToCopyErrorCode;
boost::optional<ErrorCodes::Error> getDocumentsDeltaErrorCode;
boost::optional<ErrorCodes::Error> verifyClonedErrorCode;
boost::optional<ErrorCodes::Error> verifyFinalErrorCode;
bool blockInGetDocumentsDelta = false;
};
using resharding_coordinator_test::ExternalStateForTest;
using resharding_coordinator_test::ReshardingCoordinatorServiceForTest;
using resharding_coordinator_test::ReshardingCoordinatorServiceTestCommon;
enum class ExternalFunction {
kTellAllDonorsToRefresh,
kEstablishAllDonorsAsParticipants,
kEstablishAllRecipientsAsParticipants,
kGetDocumentsToCopyFromDonors,
kGetDocumentsDeltaFromDonors,
};
ExternalStateForTest(Options options)
: ReshardingCoordinatorExternalState(), _options(options) {}
ParticipantShardsAndChunks calculateParticipantShardsAndChunks(
OperationContext* opCtx,
const ReshardingCoordinatorDocument& coordinatorDoc,
std::vector<ReshardingZoneType> zones) override {
std::vector<ChunkType> initialChunks;
auto version = calculateChunkVersionForInitialChunks(opCtx);
// Use the provided shardIds from presetReshardedChunks to construct the
// recipient list.
if (const auto& chunks = coordinatorDoc.getPresetReshardedChunks()) {
for (const auto& reshardedChunk : *chunks) {
initialChunks.emplace_back(
coordinatorDoc.getReshardingUUID(),
ChunkRange{reshardedChunk.getMin(), reshardedChunk.getMax()},
version,
reshardedChunk.getRecipientShardId());
version.incMinor();
}
}
return ParticipantShardsAndChunks(
{coordinatorDoc.getDonorShards(), coordinatorDoc.getRecipientShards(), initialChunks});
}
bool searchIndexExistsForCollection(OperationContext* opCtx,
const NamespaceString& nss) override {
if (_searchIndexResults.empty()) {
return _searchIndexDefaultResult;
}
auto result = std::move(_searchIndexResults.front());
_searchIndexResults.erase(_searchIndexResults.begin());
return uassertStatusOK(result);
}
void pushSearchIndexResult(bool result) {
_searchIndexResults.push_back(result);
}
void pushSearchIndexError(ErrorCodes::Error errorCode) {
_searchIndexResults.push_back(
Status{errorCode, "Failing call to searchIndexExistsForCollection"});
}
void tellAllDonorsToRefresh(OperationContext* opCtx,
const NamespaceString& sourceNss,
const UUID& reshardingUUID,
const std::vector<mongo::DonorShardEntry>& donorShards,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token) override {
_maybeThrowErrorForFunction(opCtx, ExternalFunction::kTellAllDonorsToRefresh);
auto donorShardIds = resharding::extractShardIdsFromParticipantEntries(donorShards);
resharding::sendFlushReshardingStateChangeToShards(
opCtx, sourceNss, reshardingUUID, donorShardIds, executor, token);
}
void tellAllRecipientsToRefresh(OperationContext* opCtx,
const NamespaceString& nssToRefresh,
const UUID& reshardingUUID,
const std::vector<mongo::RecipientShardEntry>& recipientShards,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token) override {
auto recipientShardIds = resharding::extractShardIdsFromParticipantEntries(recipientShards);
resharding::sendFlushReshardingStateChangeToShards(
opCtx, nssToRefresh, reshardingUUID, recipientShardIds, executor, token);
}
void establishAllDonorsAsParticipants(OperationContext* opCtx,
const NamespaceString& sourceNss,
const std::vector<mongo::DonorShardEntry>& donorShards,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token) override {
_maybeThrowErrorForFunction(opCtx, ExternalFunction::kEstablishAllDonorsAsParticipants);
auto donorShardIds = resharding::extractShardIdsFromParticipantEntries(donorShards);
resharding::sendFlushRoutingTableCacheUpdatesToShards(
opCtx, sourceNss, donorShardIds, executor, token);
}
void establishAllRecipientsAsParticipants(
OperationContext* opCtx,
const NamespaceString& tempNss,
const std::vector<mongo::RecipientShardEntry>& recipientShards,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token) override {
_maybeThrowErrorForFunction(opCtx, ExternalFunction::kEstablishAllRecipientsAsParticipants);
auto recipientShardIds = resharding::extractShardIdsFromParticipantEntries(recipientShards);
resharding::sendFlushRoutingTableCacheUpdatesToShards(
opCtx, tempNss, recipientShardIds, executor, token);
}
std::map<ShardId, int64_t> getDocumentsToCopyFromDonors(
OperationContext* opCtx,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token,
const UUID& reshardingUUID,
const NamespaceString& nss,
const Timestamp& cloneTimestamp,
const std::map<ShardId, ShardVersion>& shardVersions) override {
_maybeThrowErrorForFunction(opCtx, ExternalFunction::kGetDocumentsToCopyFromDonors);
if (_options.getDocumentsToCopyErrorCode) {
uasserted(*_options.getDocumentsToCopyErrorCode, "Failing call to getDocumentsToCopy.");
}
std::map<ShardId, int64_t> docsToCopy;
for (const auto& [shardId, _] : shardVersions) {
auto it = _options.documentsToCopy.find(shardId);
ASSERT(it != _options.documentsToCopy.end());
docsToCopy.emplace(shardId, it->second);
}
return docsToCopy;
}
std::map<ShardId, int64_t> getDocumentsDeltaFromDonors(
OperationContext* opCtx,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token,
const UUID& reshardingUUID,
const NamespaceString& nss,
const std::vector<ShardId>& shardIds) override {
_maybeThrowErrorForFunction(opCtx, ExternalFunction::kGetDocumentsDeltaFromDonors);
if (_options.getDocumentsDeltaErrorCode) {
uasserted(*_options.getDocumentsDeltaErrorCode, "Failing call to getDocumentsDelta");
}
if (_options.blockInGetDocumentsDelta) {
std::unique_lock lk(_mutex);
opCtx->waitForConditionOrInterrupt(_blockInGetDocumentsDeltaCV, lk, [this] {
return !_doKeepBlockingInGetDocumentsDelta;
});
}
std::map<ShardId, int64_t> docsDelta;
for (const auto& shardId : shardIds) {
auto it = _options.documentsDelta.find(shardId);
ASSERT(it != _options.documentsDelta.end());
docsDelta.emplace(shardId, it->second);
}
return docsDelta;
}
void verifyClonedCollection(OperationContext* opCtx,
const std::shared_ptr<executor::TaskExecutor>& executor,
CancellationToken token,
const ReshardingCoordinatorDocument& coordinatorDoc) override {
if (_options.verifyClonedErrorCode) {
uasserted(*_options.verifyClonedErrorCode, "Failing cloned collection verification");
}
}
void verifyFinalCollection(OperationContext* opCtx,
const ReshardingCoordinatorDocument& coordinatorDoc) override {
if (_options.verifyFinalErrorCode) {
uasserted(*_options.verifyFinalErrorCode, "Failing final collection verification");
}
}
void stopMigrations(OperationContext* opCtx,
const NamespaceString& nss,
const UUID& expectedCollectionUUID,
const OperationSessionInfo& osi) override {
DBDirectClient client(opCtx);
client.update(NamespaceString::kConfigsvrCollectionsNamespace,
BSON(CollectionType::kNssFieldName << NamespaceStringUtil::serialize(
nss, SerializationContext::stateDefault())),
BSON("$set" << BSON(CollectionType::kAllowMigrationsFieldName << false)));
}
/**
 * Test stand-in for resuming migrations: directly unsets the allow-migrations field on the
 * config collections entry matching 'nss'. The UUID and session info are ignored.
 */
void resumeMigrations(OperationContext* opCtx,
                      const NamespaceString& nss,
                      const UUID& expectedCollectionUUID,
                      const OperationSessionInfo& osi) override {
    DBDirectClient client(opCtx);

    const auto filter =
        BSON(CollectionType::kNssFieldName
             << NamespaceStringUtil::serialize(nss, SerializationContext::stateDefault()));
    const auto clearAllowMigrations =
        BSON("$unset" << BSON(CollectionType::kAllowMigrationsFieldName << ""));

    client.update(NamespaceString::kConfigsvrCollectionsNamespace, filter, clearAllowMigrations);
}
/**
 * Counts the invocation and hands back a barrier whose perform() does nothing.
 */
std::unique_ptr<CausalityBarrier> buildCausalityBarrier(std::vector<ShardId>,
                                                        std::shared_ptr<executor::TaskExecutor>,
                                                        CancellationToken) override {
    // Unit tests have no real shard servers. Skip the no-op retryable write to avoid
    // network errors.
    class NoOpBarrier : public CausalityBarrier {
    public:
        void perform(OperationContext*, const OperationSessionInfo&) override {}
    };

    ++_causalityBarrierInvokeCount;
    return std::make_unique<NoOpBarrier>();
}
/**
 * Returns how many times buildCausalityBarrier() has been called so far.
 */
int getCausalityBarrierInvokeCount() const {
    const int invocations = _causalityBarrierInvokeCount.load(std::memory_order_seq_cst);
    return invocations;
}
/**
 * Arms a simulated failure: afterwards, _maybeThrowErrorForFunction() throws when it is called
 * for 'func' while the on-disk coordinator state equals 'phase'.
 */
void throwUnrecoverableErrorIn(CoordinatorStateEnum phase, ExternalFunction func) {
    _errorFunction = std::tuple<CoordinatorStateEnum, ExternalFunction>{phase, func};
}
/**
 * Releases every caller currently parked in getDocumentsDeltaFromDonors() and prevents later
 * callers from blocking.
 */
void unblockGetDocumentsDelta() {
    std::lock_guard<std::mutex> guard(_mutex);
    _doKeepBlockingInGetDocumentsDelta = false;
    // Notify while still holding the mutex, as the waiters re-check the flag under it.
    _blockInGetDocumentsDeltaCV.notify_all();
}
private:
// Immutable test configuration supplied at construction time.
const Options _options;
// Incremented by buildCausalityBarrier(); read through getCausalityBarrierInvokeCount().
std::atomic<int> _causalityBarrierInvokeCount{0};
// (phase, function) pair armed by throwUnrecoverableErrorIn(); boost::none means nothing is
// armed and _maybeThrowErrorForFunction() is a no-op.
boost::optional<std::tuple<CoordinatorStateEnum, ExternalFunction>> _errorFunction =
boost::none;
// Held while reading or writing _doKeepBlockingInGetDocumentsDelta.
std::mutex _mutex;
// Signalled by unblockGetDocumentsDelta(); waited on in getDocumentsDeltaFromDonors().
stdx::condition_variable _blockInGetDocumentsDeltaCV;
// While true, getDocumentsDeltaFromDonors() keeps waiting (only when blocking is enabled in
// the options).
bool _doKeepBlockingInGetDocumentsDelta = true;
// NOTE(review): the users of the two search-index members are outside this excerpt —
// presumably scripted results and the fallback for the search-index existence check; confirm.
std::vector<StatusWith<bool>> _searchIndexResults;
bool _searchIndexDefaultResult{false};
/**
 * Reads one document from the resharding operations collection, parses it as a coordinator
 * state document and returns its state.
 */
CoordinatorStateEnum _getCurrentPhaseOnDisk(OperationContext* opCtx) {
    const auto rawDoc = DBDirectClient(opCtx).findOne(
        NamespaceString::kConfigReshardingOperationsNamespace, BSONObj{});
    IDLParserContext parserCtx("reshardingCoordFromTest");
    return ReshardingCoordinatorDocument::parse(rawDoc, parserCtx).getState();
}
/**
 * Throws InternalError when a failure has been armed through throwUnrecoverableErrorIn() for
 * 'func' and the coordinator document on disk is currently in the armed phase.
 *
 * The in-memory function comparison is done first so that calls for any other function do not
 * query and parse the coordinator document at all.
 */
void _maybeThrowErrorForFunction(OperationContext* opCtx, ExternalFunction func) {
    if (!_errorFunction) {
        return;
    }

    const auto [expectedPhase, expectedFunction] = *_errorFunction;
    if (func != expectedFunction) {
        return;
    }

    // Only now is the on-disk state relevant.
    if (_getCurrentPhaseOnDisk(opCtx) == expectedPhase) {
        uasserted(ErrorCodes::InternalError, "Simulating unrecoverable error for testing");
    }
}
};
/**
 * ReshardingCoordinatorService variant for unit tests: every coordinator instance it
 * constructs is wired to the shared ExternalStateForTest supplied at construction.
 */
class ReshardingCoordinatorServiceForTest : public ReshardingCoordinatorService {
public:
    /**
     * 'externalState' is a sink parameter: it is moved into the member rather than copied, which
     * avoids an extra reference-count round trip.
     */
    explicit ReshardingCoordinatorServiceForTest(
        ServiceContext* serviceContext, std::shared_ptr<ExternalStateForTest> externalState)
        : ReshardingCoordinatorService(serviceContext),
          _serviceContext(serviceContext),
          _externalState(std::move(externalState)) {}

    /**
     * Parses 'initialState' as a coordinator state document and builds a coordinator that uses
     * the test external state.
     */
    std::shared_ptr<PrimaryOnlyService::Instance> constructInstance(BSONObj initialState) override {
        return std::make_shared<ReshardingCoordinator>(
            this,
            ReshardingCoordinatorDocument::parse(initialState,
                                                 IDLParserContext("ReshardingCoordinatorStateDoc")),
            _externalState,
            _serviceContext);
    }

private:
    ServiceContext* _serviceContext;
    std::shared_ptr<ExternalStateForTest> _externalState;
};
/**
 * Decorator around another AsyncRPCRunner. Before forwarding a command it verifies, when the
 * ReshardingInitNoRefresh feature flag is enabled, that the command carries both `lsid` and
 * `txnNumber` (i.e. OperationSessionInfo), which are required for replay protection. Commands
 * listed in kOsiExemptCommands are forwarded without the check.
 */
class OsiCheckingAsyncRPCRunner : public async_rpc::detail::AsyncRPCRunner {
public:
    explicit OsiCheckingAsyncRPCRunner(std::unique_ptr<async_rpc::detail::AsyncRPCRunner> inner)
        : _inner(std::move(inner)) {}

    ExecutorFuture<async_rpc::detail::AsyncRPCInternalResponse> _sendCommand(
        std::shared_ptr<executor::TaskExecutor> exec,
        CancellationToken token,
        OperationContext* opCtx,
        async_rpc::Targeter* targeter,
        const TargetingMetadata& targetingMetadata,
        const DatabaseName& dbName,
        BSONObj cmdBSON,
        BatonHandle baton,
        boost::optional<UUID> clientOperationKey) final {
        // 'cmdName' views into 'cmdBSON', so it must not be used once 'cmdBSON' is moved below.
        auto cmdName = cmdBSON.firstElementFieldNameStringData();

        const bool isExempt = kOsiExemptCommands.count(cmdName) != 0;
        if (!isExempt) {
            // The feature flag is only consulted for commands that are subject to the check.
            if (resharding::gFeatureFlagReshardingInitNoRefresh.isEnabled(
                    VersionContext::getDecoration(opCtx),
                    serverGlobalParams.featureCompatibility.acquireFCVSnapshot())) {
                ASSERT(cmdBSON.hasField("lsid"))
                    << "ReshardingCoordinator RPC '" << cmdName << "' missing lsid (OSI)";
                ASSERT(cmdBSON.hasField("txnNumber"))
                    << "ReshardingCoordinator RPC '" << cmdName << "' missing txnNumber (OSI)";
            }
        }

        return _inner->_sendCommand(std::move(exec),
                                    std::move(token),
                                    opCtx,
                                    targeter,
                                    targetingMetadata,
                                    dbName,
                                    std::move(cmdBSON),
                                    std::move(baton),
                                    clientOperationKey);
    }

private:
    // Commands sent by the coordinator that are exempt from carrying OSI.
    // _flushReshardingStateChange is idempotent, so OSI-based deduplication is unnecessary.
    // One instance is also sent post-commit on a best-effort basis, after the coordinator
    // document and its associated session have already been removed, making it impossible
    // to include OSI. This command is expected to be removed once reshardingFields are no
    // longer written to config.collections, when shards authoritatively manage their own
    // filtering metadata.
    static inline const StringSet kOsiExemptCommands{
        "_flushReshardingStateChange",
    };

    std::unique_ptr<async_rpc::detail::AsyncRPCRunner> _inner;
};
class ReshardingCoordinatorServiceTestBase : service_context_test::WithSetupTransportLayer,
public ConfigServerTestFixture {
class ReshardingCoordinatorServiceTestBase : public ReshardingCoordinatorServiceTestCommon {
public:
struct ReshardingOptions {
const std::vector<ShardId> donorShardIds;
@ -478,84 +109,7 @@ public:
return {donorShardIds, recipientShardIds};
}
virtual ExternalStateForTest::Options getExternalStateOptions() const = 0;
/**
 * Builds the test flavour of the resharding coordinator service around 'externalState'.
 */
std::unique_ptr<repl::PrimaryOnlyService> makeService(
    ServiceContext* serviceContext, std::shared_ptr<ExternalStateForTest> externalState) {
    auto testService =
        std::make_unique<ReshardingCoordinatorServiceForTest>(serviceContext, externalState);
    return testService;
}
// Brings up the fixture: base config-server fixture, two shards, the collections used by the
// tests, supporting services, op observers, and finally the resharding coordinator service,
// which is registered and stepped up. The statements are order-dependent.
void setUp() override {
ConfigServerTestFixture::setUp();
// Register shardId0 and shardId1, each with a host derived from its name.
ShardType shard0;
shard0.setName(shardId0.toString());
shard0.setHost(shardId0.toString() + ":1234");
ShardType shard1;
shard1.setName(shardId1.toString());
shard1.setHost(shardId1.toString() + ":1234");
setupShards({shard0, shard1});
auto opCtx = operationContext();
DBDirectClient client(opCtx);
// Session transactions table (with its partial index), the resharding operations collection
// and the config collections namespace are created up front.
client.createCollection(NamespaceString::kSessionTransactionsTableNamespace);
client.createIndexes(NamespaceString::kSessionTransactionsTableNamespace,
{MongoDSessionCatalog::getConfigTxnPartialIndexSpec()});
client.createCollection(NamespaceString::kConfigReshardingOperationsNamespace);
client.createCollection(NamespaceString::kConfigsvrCollectionsNamespace);
LogicalSessionCache::set(getServiceContext(), std::make_unique<LogicalSessionCacheNoop>());
TransactionCoordinatorService::get(operationContext())
->initializeIfNeeded(operationContext(), /* term */ 1);
_controller = std::make_shared<CoordinatorStateTransitionController>();
WaitForMajorityService::get(getServiceContext()).startup(getServiceContext());
repl::createOplog(opCtx);
// Install a no-op RPC runner wrapped by the OSI-checking decorator.
auto asyncRPCMock = std::make_unique<OsiCheckingAsyncRPCRunner>(
std::make_unique<async_rpc::NoopMockAsyncRPCRunner>());
async_rpc::detail::AsyncRPCRunner::set(getServiceContext(), std::move(asyncRPCMock));
// The service context's op observer must be a registry so that observers can be added to it.
_opObserverRegistry =
dynamic_cast<OpObserverRegistry*>(getServiceContext()->getOpObserver());
invariant(_opObserverRegistry);
_opObserverRegistry->addObserver(std::make_unique<ReshardingOpObserver>());
// This observer is given the controller, the resharding operations namespace and a callback
// that extracts the state from a coordinator document.
_opObserverRegistry->addObserver(std::make_unique<OpObserverForTest>(
_controller,
NamespaceString::kConfigReshardingOperationsNamespace,
[](const ReshardingCoordinatorDocument& stateDoc) { return stateDoc.getState(); }));
_registry = repl::PrimaryOnlyServiceRegistry::get(getServiceContext());
_externalState = std::make_shared<ExternalStateForTest>(getExternalStateOptions());
auto service = makeService(getServiceContext(), _externalState);
// Capture the name before the service is moved into the registry, then look it back up.
auto serviceName = service->getServiceName();
_registry->registerService(std::move(service));
_service = _registry->lookupServiceByName(serviceName);
_registry->onStartup(opCtx);
auto term = replicationCoordinator()->getMyLastAppliedOpTime().getTerm();
_registry->onStepUpComplete(opCtx, term);
}
// Tears the fixture down. The statements are order-dependent.
void tearDown() override {
// Turn off every fail point and release any thread still parked in
// getDocumentsDeltaFromDonors() before services are shut down.
globalFailPointRegistry().disableAllFailpoints();
externalState()->unblockGetDocumentsDelta();
TransactionCoordinatorService::get(operationContext())->interruptForStepDown();
WaitForMajorityService::get(getServiceContext()).shutDown();
ConfigServerTestFixture::tearDown();
// NOTE(review): the registry is shut down after the base fixture's tearDown(); confirm the
// ordering is intentional.
_registry->onShutdown();
}
// Non-owning access to the state transition controller created in setUp().
CoordinatorStateTransitionController* controller() {
return _controller.get();
}
// Non-owning access to the ExternalStateForTest created in setUp().
ExternalStateForTest* externalState() {
return _externalState.get();
}
ExternalStateForTest::Options getExternalStateOptions() const override = 0;
ReshardingCoordinatorDocument makeCoordinatorDoc(CoordinatorStateEnum state,
UUID reshardingUUID,
@ -592,198 +146,6 @@ public:
return doc;
}
/**
 * Like getCoordinatorIfExists(), but asserts that the coordinator instance was found.
 */
std::shared_ptr<ReshardingCoordinator> getCoordinator(
    OperationContext* opCtx, repl::PrimaryOnlyService::InstanceID instanceId) {
    auto found = getCoordinatorIfExists(opCtx, instanceId);
    ASSERT_TRUE(static_cast<bool>(found));
    return found;
}
/**
 * Looks up the coordinator instance for 'instanceId' on the service under test. Returns
 * nullptr when the lookup yields no instance.
 */
std::shared_ptr<ReshardingCoordinator> getCoordinatorIfExists(
    OperationContext* opCtx, repl::PrimaryOnlyService::InstanceID instanceId) {
    auto [coordinatorOpt, _] = ReshardingCoordinator::lookup(opCtx, _service, instanceId);
    if (!coordinatorOpt) {
        return nullptr;
    }
    return *coordinatorOpt;
}
/**
 * Returns the raw result of findOne with an empty filter on the resharding operations
 * collection.
 */
BSONObj getCoordinatorDocBSON(OperationContext* opCtx) {
    const BSONObj matchAnything{};
    return DBDirectClient(opCtx).findOne(NamespaceString::kConfigReshardingOperationsNamespace,
                                         matchAnything);
}
/**
 * Fetches one document from the resharding operations collection and parses it into a
 * ReshardingCoordinatorDocument.
 */
ReshardingCoordinatorDocument getCoordinatorDoc(OperationContext* opCtx) {
    const auto rawDoc = getCoordinatorDocBSON(opCtx);
    IDLParserContext parserCtx("reshardingCoordFromTest");
    return ReshardingCoordinatorDocument::parse(rawDoc, parserCtx);
}
/**
 * Asserts that the resharding operations collection exists and contains no documents.
 */
void checkCoordinatorDocumentRemoved(OperationContext* opCtx) {
    const auto acquisition = acquireCollection(
        opCtx,
        CollectionAcquisitionRequest{NamespaceString::kConfigReshardingOperationsNamespace,
                                     PlacementConcern{boost::none, ShardVersion::UNTRACKED()},
                                     repl::ReadConcernArgs::get(opCtx),
                                     AcquisitionPrerequisites::kRead},
        MODE_IS);

    ASSERT_TRUE(acquisition.exists());

    const bool hasNoDocuments =
        static_cast<bool>(acquisition.getCollectionPtr()->isEmpty(opCtx));
    ASSERT_TRUE(hasNoDocuments);
}
/**
 * Reads the config collections entry whose namespace field matches the original namespace and
 * wraps it in a CollectionType.
 */
CollectionType getCollectionCatalogEntry(OperationContext* opCtx) {
    const auto byOriginalNss = BSON(CollectionType::kNssFieldName << _originalNss.ns_forTest());

    DBDirectClient client(opCtx);
    auto entry = client.findOne(NamespaceString::kConfigsvrCollectionsNamespace, byOriginalNss);
    return CollectionType{std::move(entry)};
}
/**
 * Reads the config collections entry whose namespace field matches the temporary resharding
 * namespace recorded in 'coordinatorDoc' and wraps it in a CollectionType.
 */
CollectionType getTemporaryCollectionCatalogEntry(
    OperationContext* opCtx, const ReshardingCoordinatorDocument& coordinatorDoc) {
    const auto byTempNss = BSON(CollectionType::kNssFieldName
                                << coordinatorDoc.getTempReshardingNss().ns_forTest());

    DBDirectClient client(opCtx);
    auto entry = client.findOne(NamespaceString::kConfigsvrCollectionsNamespace, byTempNss);
    return CollectionType{std::move(entry)};
}
/**
 * Applies 'updates' to the resharding operations document whose resharding UUID equals
 * 'reshardingUUID'.
 */
void updateCoordinatorDoc(OperationContext* opCtx,
                          const UUID& reshardingUUID,
                          const BSONObj& updates) {
    const BSONObj byReshardingUUID =
        BSON(ReshardingCoordinatorDocument::kReshardingUUIDFieldName << reshardingUUID);

    DBDirectClient(opCtx).update(
        NamespaceString::kConfigReshardingOperationsNamespace, byReshardingUUID, updates);
}
/**
 * Polls the coordinator document every 50 milliseconds until its "state" field equals the
 * serialized form of 'state'.
 *
 * There is no timeout: the call only returns once the state is observed.
 */
void waitUntilCommittedCoordinatorDocReach(OperationContext* opCtx,
                                           CoordinatorStateEnum state) {
    // The target string does not change between polls, so serialize it once.
    const auto targetState = idl::serialize(state);

    while (true) {
        const auto coordinatorDoc = getCoordinatorDocBSON(opCtx);
        if (coordinatorDoc.getStringField("state") == targetState) {
            return;
        }
        sleepmillis(50);
    }
}
/**
 * Asserts that the coordinator document has a start time other than Date_t::min(), then marks
 * every donor as donating initial data: sets the clone timestamp as minFetchTimestamp and
 * splits the approximate byte and document totals evenly across the donors.
 */
void makeDonorsReadyToDonateWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getStartTime(), Date_t::min());

    const auto donorCount = coordDoc.getDonorShards().size();

    // Builds "<donorShards>.$[].mutableState.<field>", which addresses every donor entry.
    const auto donorField = [](StringData field) {
        return fmt::format("{}.$[].mutableState.{}",
                           ReshardingCoordinatorDocument::kDonorShardsFieldName,
                           field);
    };

    BSONObjBuilder updates;
    {
        // The sub-builder is finalised when it goes out of scope at the end of this block.
        BSONObjBuilder setBuilder{updates.subobjStart("$set")};
        setBuilder.append(donorField("state"),
                          idl::serialize(DonorStateEnum::kDonatingInitialData));
        setBuilder.append(donorField("minFetchTimestamp"), _cloneTimestamp);
        setBuilder.append(donorField("bytesToClone"),
                          static_cast<long long>(totalApproxBytesToClone / donorCount));
        setBuilder.append(donorField("documentsToClone"),
                          static_cast<long long>(totalApproxDocumentsToClone / donorCount));
    }

    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), updates.obj());
}
/**
 * Asserts that the document-copy start time differs from Date_t::min(), then moves every
 * recipient to the applying state.
 */
void makeRecipientsFinishedCloningWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getMetrics()->getDocumentCopy()->getStart(), Date_t::min());

    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj updates = BSON(
        "$set" << BSON(recipientStatePath << idl::serialize(RecipientStateEnum::kApplying)));

    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), updates);
}
/**
 * Asserts that the oplog-application start time is not after its stop time, then moves every
 * recipient to the strict-consistency state.
 */
void makeRecipientsBeInStrictConsistencyWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    const auto oplogApplication = coordDoc.getMetrics()->getOplogApplication();
    ASSERT_LTE(oplogApplication->getStart(), oplogApplication->getStop());

    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj updates =
        BSON("$set" << BSON(recipientStatePath
                            << idl::serialize(RecipientStateEnum::kStrictConsistency)));

    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), updates);
}
/**
 * Moves every donor of the operation identified by 'reshardingUUID' to the done state.
 */
void makeDonorsProceedToDone(OperationContext* opCtx, UUID reshardingUUID) {
    const std::string donorStatePath =
        std::string(ReshardingCoordinatorDocument::kDonorShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj updates =
        BSON("$set" << BSON(donorStatePath << idl::serialize(DonorStateEnum::kDone)));

    updateCoordinatorDoc(opCtx, reshardingUUID, updates);
}
/**
 * Convenience overload: takes the resharding UUID from the coordinator document on disk.
 */
void makeDonorsProceedToDone(OperationContext* opCtx) {
    const auto onDiskDoc = getCoordinatorDoc(opCtx);
    const UUID currentUUID = onDiskDoc.getReshardingUUID();
    makeDonorsProceedToDone(opCtx, currentUUID);
}
/**
 * Asserts that the document-copy start time is not after its stop time, then moves every donor
 * to the done state.
 */
void makeDonorsProceedToDoneWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    const auto documentCopy = coordDoc.getMetrics()->getDocumentCopy();
    ASSERT_LTE(documentCopy->getStart(), documentCopy->getStop());

    makeDonorsProceedToDone(opCtx, coordDoc.getReshardingUUID());
}
/**
 * Moves every recipient of the operation identified by 'reshardingUUID' to the done state.
 */
void makeRecipientsProceedToDone(OperationContext* opCtx, UUID reshardingUUID) {
    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj updates =
        BSON("$set" << BSON(recipientStatePath << idl::serialize(RecipientStateEnum::kDone)));

    updateCoordinatorDoc(opCtx, reshardingUUID, updates);
}
/**
 * Convenience overload: takes the resharding UUID from the coordinator document on disk.
 */
void makeRecipientsProceedToDone(OperationContext* opCtx) {
    const auto onDiskDoc = getCoordinatorDoc(opCtx);
    const UUID currentUUID = onDiskDoc.getReshardingUUID();
    makeRecipientsProceedToDone(opCtx, currentUUID);
}
/**
 * Asserts that the document-copy start time is not after its stop time, then moves every
 * recipient to the done state.
 */
void makeRecipientsProceedToDoneWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    const auto documentCopy = coordDoc.getMetrics()->getDocumentCopy();
    ASSERT_LTE(documentCopy->getStart(), documentCopy->getStop());

    makeRecipientsProceedToDone(opCtx, coordDoc.getReshardingUUID());
}
/**
 * Asserts that the document-copy start time differs from Date_t::min(), then puts every
 * recipient into the error state with a serialized SnapshotUnavailable status as the abort
 * reason.
 */
void makeRecipientsReturnErrorWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getMetrics()->getDocumentCopy()->getStart(), Date_t::min());

    Status simulatedFailure{ErrorCodes::SnapshotUnavailable, "test simulated error"};
    BSONObjBuilder abortReasonBuilder;
    simulatedFailure.serialize(&abortReasonBuilder);

    // Both updated fields live under the same per-recipient mutableState sub-document.
    const std::string mutableStatePrefix =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.";
    const std::string statePath = mutableStatePrefix + "state";
    const std::string abortReasonPath = mutableStatePrefix + "abortReason";

    BSONObj updates =
        BSON("$set" << BSON(statePath << idl::serialize(RecipientStateEnum::kError)
                                      << abortReasonPath << abortReasonBuilder.obj()));

    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), updates);
}
CollectionType makeOriginalCollectionCatalogEntry(
ReshardingCoordinatorDocument coordinatorDoc,
@ -909,51 +271,6 @@ public:
insertChunkAndZoneEntries(std::move(chunks), {});
}
// Simulates a step-up: bumps the fixture's term, makes the replication coordinator a primary
// in that term, advances the last applied opTime and notifies the primary-only service
// registry. The statements are order-dependent.
void stepUp(OperationContext* opCtx) {
auto replCoord = repl::ReplicationCoordinator::get(getServiceContext());
auto currOpTime = replCoord->getMyLastAppliedOpTime();
// Advance the term and last applied opTime. We retain the timestamp component of the
// current last applied opTime to avoid log messages from
// ReplClientInfo::setLastOpToSystemLastOpTime() about the opTime having moved backwards.
++_term;
auto newOpTime = repl::OpTime{currOpTime.getTimestamp(), _term};
ASSERT_OK(replCoord->setFollowerMode(repl::MemberState::RS_PRIMARY));
ASSERT_OK(replCoord->updateTerm(opCtx, _term));
replCoord->setMyLastAppliedOpTimeAndWallTimeForward({newOpTime, {}});
_registry->onStepUpComplete(opCtx, _term);
}
// Simulates a step-down: switches the replication coordinator to secondary, notifies the
// primary-only service registry, and then kills eligible operations on other clients.
void stepDown(OperationContext* opCtx) {
ASSERT_OK(repl::ReplicationCoordinator::get(getServiceContext())
->setFollowerMode(repl::MemberState::RS_SECONDARY));
_registry->onStepDown();
// Some opCtx can be created via AlternativeClientRegion and not tied to any resharding
// cancellation token, so we also need to simulate the repl step down killOp thread.
auto serviceCtx = opCtx->getServiceContext();
for (ServiceContext::LockedClientsCursor cursor(serviceCtx);
Client* client = cursor.next();) {
// Clients that opt out of being killed on step-down are left alone.
if (!client->canKillOperationInStepdown()) {
continue;
}
ClientLock lk(client);
OperationContext* toKill = client->getOperationContext();
// Skip clients with no operation, operations already being killed, and the caller's own
// operation.
if (toKill && !toKill->isKillPending() && toKill->getOpID() != opCtx->getOpID()) {
auto locker = shard_role_details::getLocker(toKill);
// Kill only operations that always interrupt on step-down/up or that took the global
// lock in a mode conflicting with writes.
if (toKill->shouldAlwaysInterruptAtStepDownOrUp() ||
locker->wasGlobalLockTakenInModeConflictingWithWrites()) {
serviceCtx->killOperation(lk, toKill);
}
}
}
}
void killAllReshardingCoordinatorOps() {
for (ServiceContext::LockedClientsCursor cursor(getServiceContext());
Client* client = cursor.next();) {
@ -1424,28 +741,6 @@ public:
}
}
repl::PrimaryOnlyService* _service = nullptr;
std::shared_ptr<CoordinatorStateTransitionController> _controller;
std::shared_ptr<ExternalStateForTest> _externalState;
OpObserverRegistry* _opObserverRegistry = nullptr;
repl::PrimaryOnlyServiceRegistry* _registry = nullptr;
NamespaceString _originalNss = NamespaceString::createNamespaceString_forTest("db.foo");
UUID _originalUUID = UUID::gen();
OID _originalEpoch = OID::gen();
Timestamp _originalTimestamp = Timestamp(1);
NamespaceString _tempNss = NamespaceString::createNamespaceString_forTest(
"db.system.resharding." + _originalUUID.toString());
UUID _reshardingUUID = UUID::gen();
OID _tempEpoch = OID::gen();
Timestamp _tempTimestamp = Timestamp(2);
ShardKeyPattern _oldShardKey = ShardKeyPattern(BSON("oldShardKey" << 1));
ShardKeyPattern _newShardKey = ShardKeyPattern(BSON("newShardKey" << 1));
const std::vector<ChunkRange> _oldChunkRanges = {
ChunkRange(_oldShardKey.getKeyPattern().globalMin(), BSON("oldShardKey" << 12345)),
ChunkRange(BSON("oldShardKey" << 12345), _oldShardKey.getKeyPattern().globalMax()),
@ -1455,24 +750,14 @@ public:
ChunkRange(BSON("newShardKey" << 0), _newShardKey.getKeyPattern().globalMax()),
};
Timestamp _cloneTimestamp = Timestamp(Date_t::now());
RAIIServerParameterControllerForTest _serverParamController{
"reshardingMinimumOperationDurationMillis", 0};
long long _term = 0;
protected:
std::vector<ShardId> getShardIds() const {
std::vector<ShardId> getShardIds() const override {
return {shardId0, shardId1};
}
const ShardId shardId0{"shard0000"};
const ShardId shardId1{"shard0001"};
const long totalApproxBytesToClone = 10000;
const long totalApproxDocumentsToClone = 100;
const std::map<ShardId, int64_t> documentsToCopy{
{shardId0, 65},
{shardId1, 55},

View File

@ -0,0 +1,761 @@
/**
* Copyright (C) 2026-present MongoDB, Inc.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the Server Side Public License, version 1,
* as published by MongoDB, Inc.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* Server Side Public License for more details.
*
* You should have received a copy of the Server Side Public License
* along with this program. If not, see
* <http://www.mongodb.com/licensing/server-side-public-license>.
*
* As a special exception, the copyright holders give permission to link the
* code of portions of this program with the OpenSSL library under certain
* conditions as described in each individual source file and distribute
* linked combinations including the program with the OpenSSL library. You
* must comply with the Server Side Public License in all respects for
* all of the code used other than as permitted herein. If you modify file(s)
* with this exception, you may extend this exception to your version of the
* file(s), but you are not obligated to do so. If you do not wish to do so,
* delete this exception statement from your version. If you delete this
* exception statement from all source files in the program, then also delete
* it in the license file.
*/
#pragma once
#include "mongo/base/error_codes.h"
#include "mongo/bson/bsonmisc.h"
#include "mongo/bson/bsonobj.h"
#include "mongo/db/dbdirectclient.h"
#include "mongo/db/global_catalog/type_collection.h"
#include "mongo/db/op_observer/op_observer_registry.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/repl/member_state.h"
#include "mongo/db/repl/oplog.h"
#include "mongo/db/repl/replication_coordinator.h"
#include "mongo/db/repl/wait_for_majority_service.h"
#include "mongo/db/s/resharding/resharding_coordinator.h"
#include "mongo/db/s/resharding/resharding_coordinator_service.h"
#include "mongo/db/s/resharding/resharding_coordinator_service_external_state.h"
#include "mongo/db/s/resharding/resharding_op_observer.h"
#include "mongo/db/s/resharding/resharding_service_test_helpers.h"
#include "mongo/db/s/resharding/resharding_util.h"
#include "mongo/db/s/transaction_coordinator_service.h"
#include "mongo/db/server_options.h"
#include "mongo/db/session/logical_session_cache.h"
#include "mongo/db/session/logical_session_cache_noop.h"
#include "mongo/db/session/session_catalog_mongod.h"
#include "mongo/db/shard_role/transaction_resources.h"
#include "mongo/db/sharding_environment/config_server_test_fixture.h"
#include "mongo/executor/mock_async_rpc.h"
#include "mongo/idl/idl_parser.h"
#include "mongo/idl/server_parameter_test_controller.h"
#include "mongo/s/resharding/resharding_feature_flag_gen.h"
#include "mongo/unittest/assert.h"
#include "mongo/util/assert_util.h"
#include "mongo/util/fail_point.h"
#include <atomic>
#include <condition_variable>
#include <map>
#include <memory>
#include <mutex>
#include <tuple>
#include <utility>
#include <vector>
#include <boost/optional.hpp>
namespace mongo {
namespace resharding_coordinator_test {
/**
 * Test double for ReshardingCoordinatorExternalState.
 *
 * Answers come from the Options supplied at construction instead of from real shards, and
 * failures can be injected either through per-call error codes in Options or by arming a
 * (coordinator phase, external function) pair with throwUnrecoverableErrorIn().
 */
class ExternalStateForTest : public ReshardingCoordinatorExternalState {
public:
    /**
     * Canned answers and failure switches for the test double.
     */
    struct Options {
        // Per-shard value returned by getDocumentsToCopyFromDonors().
        std::map<ShardId, int64_t> documentsToCopy;
        // Per-shard value returned by getDocumentsDeltaFromDonors().
        std::map<ShardId, int64_t> documentsDelta;
        // When set, the corresponding call throws this error code.
        boost::optional<ErrorCodes::Error> getDocumentsToCopyErrorCode;
        boost::optional<ErrorCodes::Error> getDocumentsDeltaErrorCode;
        boost::optional<ErrorCodes::Error> verifyClonedErrorCode;
        boost::optional<ErrorCodes::Error> verifyFinalErrorCode;
        // When true, getDocumentsDeltaFromDonors() blocks until unblockGetDocumentsDelta().
        bool blockInGetDocumentsDelta = false;
    };

    /**
     * External functions for which a simulated unrecoverable error can be armed.
     */
    enum class ExternalFunction {
        kTellAllDonorsToRefresh,
        kEstablishAllDonorsAsParticipants,
        kEstablishAllRecipientsAsParticipants,
        kGetDocumentsToCopyFromDonors,
        kGetDocumentsDeltaFromDonors,
    };

    explicit ExternalStateForTest(Options options)
        : ReshardingCoordinatorExternalState(), _options(std::move(options)) {}

    /**
     * Returns the donors and recipients recorded in 'coordinatorDoc' together with one chunk
     * per preset resharded chunk (none when the document has no preset chunks). The chunk
     * version's minor component is incremented after each chunk. 'zones' is ignored.
     */
    ParticipantShardsAndChunks calculateParticipantShardsAndChunks(
        OperationContext* opCtx,
        const ReshardingCoordinatorDocument& coordinatorDoc,
        std::vector<ReshardingZoneType> zones) override {
        std::vector<ChunkType> initialChunks;
        auto version = calculateChunkVersionForInitialChunks(opCtx);

        if (const auto& chunks = coordinatorDoc.getPresetReshardedChunks()) {
            for (const auto& reshardedChunk : *chunks) {
                initialChunks.emplace_back(
                    coordinatorDoc.getReshardingUUID(),
                    ChunkRange{reshardedChunk.getMin(), reshardedChunk.getMax()},
                    version,
                    reshardedChunk.getRecipientShardId());
                version.incMinor();
            }
        }

        return ParticipantShardsAndChunks{
            coordinatorDoc.getDonorShards(), coordinatorDoc.getRecipientShards(), initialChunks};
    }

    /**
     * Pops and returns the oldest scripted result (throwing if it is an error status). When
     * nothing is scripted, returns the default result.
     */
    bool searchIndexExistsForCollection(OperationContext* opCtx, const NamespaceString& nss) {
        if (_searchIndexResults.empty()) {
            return _searchIndexDefaultResult;
        }

        auto result = std::move(_searchIndexResults.front());
        _searchIndexResults.erase(_searchIndexResults.begin());
        return uassertStatusOK(result);
    }

    /** Queues a successful result for a later searchIndexExistsForCollection() call. */
    void pushSearchIndexResult(bool result) {
        _searchIndexResults.push_back(result);
    }

    /** Queues a failure for a later searchIndexExistsForCollection() call. */
    void pushSearchIndexError(ErrorCodes::Error errorCode) {
        _searchIndexResults.push_back(
            Status{errorCode, "Failing call to searchIndexExistsForCollection"});
    }

    /**
     * Applies any armed error for kTellAllDonorsToRefresh, then sends the flush-resharding-
     * state-change command to the donor shards.
     */
    void tellAllDonorsToRefresh(OperationContext* opCtx,
                                const NamespaceString& sourceNss,
                                const UUID& reshardingUUID,
                                const std::vector<DonorShardEntry>& donorShards,
                                const std::shared_ptr<executor::TaskExecutor>& executor,
                                CancellationToken token) override {
        _maybeThrowErrorForFunction(opCtx, ExternalFunction::kTellAllDonorsToRefresh);
        resharding::sendFlushReshardingStateChangeToShards(
            opCtx,
            sourceNss,
            reshardingUUID,
            resharding::extractShardIdsFromParticipantEntries(donorShards),
            executor,
            token);
    }

    /**
     * Sends the flush-resharding-state-change command to the recipient shards. No error can be
     * armed for this function.
     */
    void tellAllRecipientsToRefresh(OperationContext* opCtx,
                                    const NamespaceString& nssToRefresh,
                                    const UUID& reshardingUUID,
                                    const std::vector<RecipientShardEntry>& recipientShards,
                                    const std::shared_ptr<executor::TaskExecutor>& executor,
                                    CancellationToken token) override {
        resharding::sendFlushReshardingStateChangeToShards(
            opCtx,
            nssToRefresh,
            reshardingUUID,
            resharding::extractShardIdsFromParticipantEntries(recipientShards),
            executor,
            token);
    }

    /**
     * Applies any armed error for kEstablishAllDonorsAsParticipants, then sends the flush-
     * routing-table-cache-updates command to the donor shards.
     */
    void establishAllDonorsAsParticipants(OperationContext* opCtx,
                                          const NamespaceString& sourceNss,
                                          const std::vector<DonorShardEntry>& donorShards,
                                          const std::shared_ptr<executor::TaskExecutor>& executor,
                                          CancellationToken token) override {
        _maybeThrowErrorForFunction(opCtx, ExternalFunction::kEstablishAllDonorsAsParticipants);
        resharding::sendFlushRoutingTableCacheUpdatesToShards(
            opCtx,
            sourceNss,
            resharding::extractShardIdsFromParticipantEntries(donorShards),
            executor,
            token);
    }

    /**
     * Applies any armed error for kEstablishAllRecipientsAsParticipants, then sends the flush-
     * routing-table-cache-updates command to the recipient shards.
     */
    void establishAllRecipientsAsParticipants(
        OperationContext* opCtx,
        const NamespaceString& tempNss,
        const std::vector<RecipientShardEntry>& recipientShards,
        const std::shared_ptr<executor::TaskExecutor>& executor,
        CancellationToken token) override {
        _maybeThrowErrorForFunction(opCtx, ExternalFunction::kEstablishAllRecipientsAsParticipants);
        resharding::sendFlushRoutingTableCacheUpdatesToShards(
            opCtx,
            tempNss,
            resharding::extractShardIdsFromParticipantEntries(recipientShards),
            executor,
            token);
    }

    /**
     * Applies any armed error, then the configured error code, and otherwise returns the
     * configured documents-to-copy count for every shard in 'shardVersions'. A shard without a
     * configured count trips a test assertion.
     */
    std::map<ShardId, int64_t> getDocumentsToCopyFromDonors(
        OperationContext* opCtx,
        const std::shared_ptr<executor::TaskExecutor>&,
        CancellationToken,
        const UUID&,
        const NamespaceString&,
        const Timestamp&,
        const std::map<ShardId, ShardVersion>& shardVersions) override {
        _maybeThrowErrorForFunction(opCtx, ExternalFunction::kGetDocumentsToCopyFromDonors);
        if (_options.getDocumentsToCopyErrorCode) {
            uasserted(*_options.getDocumentsToCopyErrorCode, "Failing call to getDocumentsToCopy.");
        }

        std::map<ShardId, int64_t> docsToCopy;
        for (const auto& [shardId, _] : shardVersions) {
            auto it = _options.documentsToCopy.find(shardId);
            ASSERT(it != _options.documentsToCopy.end());
            docsToCopy.emplace(shardId, it->second);
        }
        return docsToCopy;
    }

    /**
     * Applies any armed error, then the configured error code, then optionally blocks until
     * unblockGetDocumentsDelta() is called (or 'opCtx' is interrupted), and finally returns
     * the configured delta for every shard in 'shardIds'. A shard without a configured delta
     * trips a test assertion.
     */
    std::map<ShardId, int64_t> getDocumentsDeltaFromDonors(
        OperationContext* opCtx,
        const std::shared_ptr<executor::TaskExecutor>&,
        CancellationToken,
        const UUID&,
        const NamespaceString&,
        const std::vector<ShardId>& shardIds) override {
        _maybeThrowErrorForFunction(opCtx, ExternalFunction::kGetDocumentsDeltaFromDonors);
        if (_options.getDocumentsDeltaErrorCode) {
            uasserted(*_options.getDocumentsDeltaErrorCode, "Failing call to getDocumentsDelta");
        }

        if (_options.blockInGetDocumentsDelta) {
            std::unique_lock lk(_mutex);
            opCtx->waitForConditionOrInterrupt(_blockInGetDocumentsDeltaCV, lk, [this] {
                return !_doKeepBlockingInGetDocumentsDelta;
            });
        }

        std::map<ShardId, int64_t> docsDelta;
        for (const auto& shardId : shardIds) {
            auto it = _options.documentsDelta.find(shardId);
            ASSERT(it != _options.documentsDelta.end());
            docsDelta.emplace(shardId, it->second);
        }
        return docsDelta;
    }

    /** Throws the configured 'verifyClonedErrorCode' when set; otherwise does nothing. */
    void verifyClonedCollection(OperationContext*,
                                const std::shared_ptr<executor::TaskExecutor>&,
                                CancellationToken,
                                const ReshardingCoordinatorDocument&) override {
        if (_options.verifyClonedErrorCode) {
            uasserted(*_options.verifyClonedErrorCode, "Failing cloned collection verification");
        }
    }

    /** Throws the configured 'verifyFinalErrorCode' when set; otherwise does nothing. */
    void verifyFinalCollection(OperationContext*, const ReshardingCoordinatorDocument&) override {
        if (_options.verifyFinalErrorCode) {
            uasserted(*_options.verifyFinalErrorCode, "Failing final collection verification");
        }
    }

    /**
     * Sets the allow-migrations field to false directly on the config collections entry
     * matching 'nss'.
     */
    void stopMigrations(OperationContext* opCtx,
                        const NamespaceString& nss,
                        const UUID&,
                        const OperationSessionInfo&) override {
        DBDirectClient client(opCtx);
        client.update(NamespaceString::kConfigsvrCollectionsNamespace,
                      BSON(CollectionType::kNssFieldName << NamespaceStringUtil::serialize(
                               nss, SerializationContext::stateDefault())),
                      BSON("$set" << BSON(CollectionType::kAllowMigrationsFieldName << false)));
    }

    /**
     * Unsets the allow-migrations field directly on the config collections entry matching
     * 'nss'.
     */
    void resumeMigrations(OperationContext* opCtx,
                          const NamespaceString& nss,
                          const UUID&,
                          const OperationSessionInfo&) override {
        DBDirectClient client(opCtx);
        client.update(NamespaceString::kConfigsvrCollectionsNamespace,
                      BSON(CollectionType::kNssFieldName << NamespaceStringUtil::serialize(
                               nss, SerializationContext::stateDefault())),
                      BSON("$unset" << BSON(CollectionType::kAllowMigrationsFieldName << "")));
    }

    /** Counts the invocation and returns a barrier whose perform() does nothing. */
    std::unique_ptr<CausalityBarrier> buildCausalityBarrier(std::vector<ShardId>,
                                                            std::shared_ptr<executor::TaskExecutor>,
                                                            CancellationToken) override {
        _causalityBarrierInvokeCount.fetch_add(1);

        class NoOpBarrier : public CausalityBarrier {
        public:
            // Unit tests have no real shard servers. Skip the no-op retryable write to avoid
            // network errors.
            void perform(OperationContext*, const OperationSessionInfo&) override {}
        };
        return std::make_unique<NoOpBarrier>();
    }

    /** Number of buildCausalityBarrier() calls made so far. */
    int getCausalityBarrierInvokeCount() const {
        return _causalityBarrierInvokeCount.load();
    }

    /**
     * Arms a simulated failure: the armed external function throws InternalError when it is
     * invoked while the on-disk coordinator state equals 'phase'.
     */
    void throwUnrecoverableErrorIn(CoordinatorStateEnum phase, ExternalFunction func) {
        _errorFunction = std::make_tuple(phase, func);
    }

    /** Releases callers blocked in getDocumentsDeltaFromDonors() and disables further blocking. */
    void unblockGetDocumentsDelta() {
        std::lock_guard lk(_mutex);
        _doKeepBlockingInGetDocumentsDelta = false;
        _blockInGetDocumentsDeltaCV.notify_all();
    }

private:
    /**
     * Reads one document from the resharding operations collection, parses it as a coordinator
     * state document and returns its state.
     */
    CoordinatorStateEnum _getCurrentPhaseOnDisk(OperationContext* opCtx) {
        DBDirectClient client(opCtx);
        auto doc = client.findOne(NamespaceString::kConfigReshardingOperationsNamespace, BSONObj{});
        IDLParserContext errCtx("reshardingCoordFromTest");
        return ReshardingCoordinatorDocument::parse(doc, errCtx).getState();
    }

    /**
     * Throws InternalError when a failure is armed for 'func' and the coordinator document on
     * disk is in the armed phase.
     *
     * The in-memory function comparison is done first so that calls for any other function do
     * not query and parse the coordinator document at all.
     */
    void _maybeThrowErrorForFunction(OperationContext* opCtx, ExternalFunction func) {
        if (!_errorFunction) {
            return;
        }

        const auto [expectedPhase, expectedFunction] = *_errorFunction;
        if (func != expectedFunction) {
            return;
        }

        if (_getCurrentPhaseOnDisk(opCtx) == expectedPhase) {
            uasserted(ErrorCodes::InternalError, "Simulating unrecoverable error for testing");
        }
    }

    // Immutable test configuration.
    const Options _options;
    std::atomic<int> _causalityBarrierInvokeCount{0};
    // Armed (phase, function) failure; boost::none when nothing is armed.
    boost::optional<std::tuple<CoordinatorStateEnum, ExternalFunction>> _errorFunction =
        boost::none;

    // Held while reading or writing _doKeepBlockingInGetDocumentsDelta.
    std::mutex _mutex;
    stdx::condition_variable _blockInGetDocumentsDeltaCV;
    bool _doKeepBlockingInGetDocumentsDelta = true;

    // Scripted results for searchIndexExistsForCollection(), consumed front to back; the
    // default is returned once the script is empty.
    std::vector<StatusWith<bool>> _searchIndexResults;
    bool _searchIndexDefaultResult{false};
};
/**
 * ReshardingCoordinatorService whose instances are constructed with a caller-supplied
 * ExternalStateForTest.
 */
class ReshardingCoordinatorServiceForTest : public ReshardingCoordinatorService {
public:
    ReshardingCoordinatorServiceForTest(ServiceContext* serviceContext,
                                        std::shared_ptr<ExternalStateForTest> externalState)
        : ReshardingCoordinatorService(serviceContext),
          _serviceContext(serviceContext),
          _externalState(std::move(externalState)) {}

    /**
     * Parses 'initialState' as a ReshardingCoordinatorDocument and builds a coordinator that
     * shares this service's external state.
     */
    std::shared_ptr<PrimaryOnlyService::Instance> constructInstance(BSONObj initialState) override {
        auto coordinatorDoc = ReshardingCoordinatorDocument::parse(
            initialState, IDLParserContext("ReshardingCoordinatorStateDoc"));
        return std::make_shared<ReshardingCoordinator>(
            this, std::move(coordinatorDoc), _externalState, _serviceContext);
    }

private:
    ServiceContext* _serviceContext;
    std::shared_ptr<ExternalStateForTest> _externalState;
};
/**
 * Decorator around another AsyncRPCRunner. Before forwarding a command sent by the
 * ReshardingCoordinator it asserts that the command carries both `lsid` and `txnNumber`
 * (i.e. OperationSessionInfo), which are required for replay protection. The check is skipped
 * for exempt commands and when gFeatureFlagReshardingInitNoRefresh is disabled.
 */
class OsiCheckingAsyncRPCRunner : public async_rpc::detail::AsyncRPCRunner {
public:
    explicit OsiCheckingAsyncRPCRunner(std::unique_ptr<async_rpc::detail::AsyncRPCRunner> inner)
        : _inner(std::move(inner)) {}

    ExecutorFuture<async_rpc::detail::AsyncRPCInternalResponse> _sendCommand(
        std::shared_ptr<executor::TaskExecutor> exec,
        CancellationToken token,
        OperationContext* opCtx,
        async_rpc::Targeter* targeter,
        const TargetingMetadata& targetingMetadata,
        const DatabaseName& dbName,
        BSONObj cmdBSON,
        BatonHandle baton,
        boost::optional<UUID> clientOperationKey) final {
        const auto cmdName = cmdBSON.firstElementFieldNameStringData();

        // The feature flag is only consulted for commands that are not exempt.
        const bool isExempt = kOsiExemptCommands.count(cmdName) != 0;
        const bool osiRequired = !isExempt &&
            resharding::gFeatureFlagReshardingInitNoRefresh.isEnabled(
                VersionContext::getDecoration(opCtx),
                serverGlobalParams.featureCompatibility.acquireFCVSnapshot());

        if (osiRequired) {
            ASSERT(cmdBSON.hasField("lsid"))
                << "ReshardingCoordinator RPC '" << cmdName << "' missing lsid (OSI)";
            ASSERT(cmdBSON.hasField("txnNumber"))
                << "ReshardingCoordinator RPC '" << cmdName << "' missing txnNumber (OSI)";
        }

        // Hand the (now validated) command to the wrapped runner unchanged.
        return _inner->_sendCommand(std::move(exec),
                                    std::move(token),
                                    opCtx,
                                    targeter,
                                    targetingMetadata,
                                    dbName,
                                    std::move(cmdBSON),
                                    std::move(baton),
                                    clientOperationKey);
    }

private:
    // Coordinator commands that do not have to carry OSI.
    // _flushReshardingStateChange is idempotent, so it does not need OSI-based deduplication.
    // One instance of it is additionally sent post-commit on a best-effort basis, after the
    // coordinator document and its session have been removed, so OSI cannot be attached at that
    // point. The command is expected to go away once reshardingFields are no longer written to
    // config.collections and shards authoritatively manage their own filtering metadata.
    static inline const StringSet kOsiExemptCommands{
        "_flushReshardingStateChange",
    };

    std::unique_ptr<async_rpc::detail::AsyncRPCRunner> _inner;
};
/**
 * Provenance-agnostic base fixture for ReshardingCoordinator tests.
 */
class ReshardingCoordinatorServiceTestCommon : public service_context_test::WithSetupTransportLayer,
public ConfigServerTestFixture {
public:
using CoordinatorStateTransitionController =
resharding_service_test_helpers::StateTransitionController<CoordinatorStateEnum>;
using OpObserverForTest = resharding_service_test_helpers::
StateTransitionControllerOpObserver<CoordinatorStateEnum, ReshardingCoordinatorDocument>;
virtual std::vector<ShardId> getShardIds() const = 0;
virtual ExternalStateForTest::Options getExternalStateOptions() const = 0;
/**
 * Installs an OSI-checking RPC runner (wrapping a no-op mock runner) on 'serviceContext' and
 * returns a ReshardingCoordinatorServiceForTest bound to 'externalState'.
 */
std::unique_ptr<repl::PrimaryOnlyService> makeService(
    ServiceContext* serviceContext, std::shared_ptr<ExternalStateForTest> externalState) {
    auto noopRunner = std::make_unique<async_rpc::NoopMockAsyncRPCRunner>();
    auto osiCheckingRunner =
        std::make_unique<OsiCheckingAsyncRPCRunner>(std::move(noopRunner));
    async_rpc::detail::AsyncRPCRunner::set(serviceContext, std::move(osiCheckingRunner));

    return std::make_unique<ReshardingCoordinatorServiceForTest>(serviceContext, externalState);
}
/**
 * Prepares the config-server environment for a coordinator test: registers the shards, creates
 * the collections the coordinator touches, installs the op observers, and registers and steps
 * up the resharding coordinator primary-only service.
 */
void setUp() override {
ConfigServerTestFixture::setUp();
// Register one shard per id supplied by the concrete fixture, with host "<id>:1234".
std::vector<ShardType> shards;
for (const auto& id : getShardIds()) {
ShardType s;
s.setName(id.toString());
s.setHost(id.toString() + ":1234");
shards.push_back(std::move(s));
}
setupShards(shards);
auto opCtx = operationContext();
DBDirectClient client(opCtx);
// Session transactions table (with its partial index), the coordinator state collection and
// the config collections catalog.
client.createCollection(NamespaceString::kSessionTransactionsTableNamespace);
client.createIndexes(NamespaceString::kSessionTransactionsTableNamespace,
{MongoDSessionCatalog::getConfigTxnPartialIndexSpec()});
client.createCollection(NamespaceString::kConfigReshardingOperationsNamespace);
client.createCollection(NamespaceString::kConfigsvrCollectionsNamespace);
LogicalSessionCache::set(getServiceContext(), std::make_unique<LogicalSessionCacheNoop>());
TransactionCoordinatorService::get(opCtx)->initializeIfNeeded(opCtx, /*term*/ 1);
_controller = std::make_shared<CoordinatorStateTransitionController>();
WaitForMajorityService::get(getServiceContext()).startup(getServiceContext());
repl::createOplog(opCtx);
// The fixture's op observer must be an OpObserverRegistry so more observers can be attached.
_opObserverRegistry =
dynamic_cast<OpObserverRegistry*>(getServiceContext()->getOpObserver());
invariant(_opObserverRegistry);
_opObserverRegistry->addObserver(std::make_unique<ReshardingOpObserver>());
// Test observer: watches the coordinator state collection and reports each document's state
// to _controller.
_opObserverRegistry->addObserver(std::make_unique<OpObserverForTest>(
_controller,
NamespaceString::kConfigReshardingOperationsNamespace,
[](const ReshardingCoordinatorDocument& doc) { return doc.getState(); }));
_registry = repl::PrimaryOnlyServiceRegistry::get(getServiceContext());
_externalState = std::make_shared<ExternalStateForTest>(getExternalStateOptions());
// Register the test service, then look it up again by name to obtain the non-owning pointer.
auto service = makeService(getServiceContext(), _externalState);
auto serviceName = service->getServiceName();
_registry->registerService(std::move(service));
_service = _registry->lookupServiceByName(serviceName);
_registry->onStartup(opCtx);
// Step up in the term of the last applied optime.
auto term = replicationCoordinator()->getMyLastAppliedOpTime().getTerm();
_registry->onStepUpComplete(opCtx, term);
}
/**
 * Tears the fixture down: disables all failpoints, unblocks the external state, interrupts the
 * transaction coordinator service, shuts down WaitForMajorityService, runs the base fixture
 * tearDown and finally shuts down the primary-only service registry.
 */
void tearDown() override {
globalFailPointRegistry().disableAllFailpoints();
// Clears the blocking flag and wakes waiters; presumably so nothing stays parked in
// getDocumentsDelta during shutdown — confirm against the waiter, which is not visible here.
externalState()->unblockGetDocumentsDelta();
TransactionCoordinatorService::get(operationContext())->interruptForStepDown();
WaitForMajorityService::get(getServiceContext()).shutDown();
ConfigServerTestFixture::tearDown();
// The registry is shut down last, after the base fixture's tearDown.
_registry->onShutdown();
}
/** Non-owning pointer to the state-transition controller held by this fixture. */
CoordinatorStateTransitionController* controller() {
    auto* const rawController = _controller.get();
    return rawController;
}
/** Non-owning pointer to the external-state test double held by this fixture. */
ExternalStateForTest* externalState() {
    auto* const rawExternalState = _externalState.get();
    return rawExternalState;
}
/**
 * Like getCoordinatorIfExists(), but asserts that a coordinator was found for 'instanceId'.
 */
std::shared_ptr<ReshardingCoordinator> getCoordinator(
    OperationContext* opCtx, repl::PrimaryOnlyService::InstanceID instanceId) {
    std::shared_ptr<ReshardingCoordinator> coordinator =
        getCoordinatorIfExists(opCtx, instanceId);
    ASSERT_TRUE(bool(coordinator));
    return coordinator;
}
/**
 * Looks up the coordinator instance for 'instanceId' on the registered service. Returns
 * nullptr when the lookup yields no instance.
 */
std::shared_ptr<ReshardingCoordinator> getCoordinatorIfExists(
    OperationContext* opCtx, repl::PrimaryOnlyService::InstanceID instanceId) {
    auto [maybeCoordinator, ignored] = ReshardingCoordinator::lookup(opCtx, _service, instanceId);
    if (!maybeCoordinator) {
        return nullptr;
    }
    return *maybeCoordinator;
}
/**
 * Returns the raw BSON that findOne() yields for config.reshardingOperations with an empty
 * filter.
 */
BSONObj getCoordinatorDocBSON(OperationContext* opCtx) {
    DBDirectClient directClient(opCtx);
    const BSONObj matchAnything{};
    return directClient.findOne(NamespaceString::kConfigReshardingOperationsNamespace,
                                matchAnything);
}
/**
 * Reads a document from config.reshardingOperations (empty filter) and parses it into a
 * ReshardingCoordinatorDocument.
 */
ReshardingCoordinatorDocument getCoordinatorDoc(OperationContext* opCtx) {
    DBDirectClient directClient(opCtx);
    const auto rawDoc =
        directClient.findOne(NamespaceString::kConfigReshardingOperationsNamespace, BSONObj{});
    IDLParserContext parserCtx("reshardingCoordFromTest");
    return ReshardingCoordinatorDocument::parse(rawDoc, parserCtx);
}
/**
 * Asserts that config.reshardingOperations still exists but contains no documents.
 */
void checkCoordinatorDocumentRemoved(OperationContext* opCtx) {
    CollectionAcquisitionRequest readRequest{
        NamespaceString::kConfigReshardingOperationsNamespace,
        PlacementConcern{boost::none, ShardVersion::UNTRACKED()},
        repl::ReadConcernArgs::get(opCtx),
        AcquisitionPrerequisites::kRead};
    const auto coordinatorColl = acquireCollection(opCtx, std::move(readRequest), MODE_IS);

    ASSERT_TRUE(coordinatorColl.exists());
    ASSERT_TRUE(bool(coordinatorColl.getCollectionPtr()->isEmpty(opCtx)));
}
/**
 * Fetches the config.collections entry whose namespace field matches the original (source)
 * namespace of this fixture.
 */
CollectionType getCollectionCatalogEntry(OperationContext* opCtx) {
    const BSONObj byOriginalNss =
        BSON(CollectionType::kNssFieldName << _originalNss.ns_forTest());

    DBDirectClient directClient(opCtx);
    auto catalogDoc =
        directClient.findOne(NamespaceString::kConfigsvrCollectionsNamespace, byOriginalNss);
    return CollectionType{std::move(catalogDoc)};
}
/**
 * Fetches the config.collections entry for the temporary resharding namespace recorded in
 * 'coordinatorDoc'.
 */
CollectionType getTemporaryCollectionCatalogEntry(
    OperationContext* opCtx, const ReshardingCoordinatorDocument& coordinatorDoc) {
    const BSONObj byTempNss = BSON(CollectionType::kNssFieldName
                                   << coordinatorDoc.getTempReshardingNss().ns_forTest());

    DBDirectClient directClient(opCtx);
    auto catalogDoc =
        directClient.findOne(NamespaceString::kConfigsvrCollectionsNamespace, byTempNss);
    return CollectionType{std::move(catalogDoc)};
}
/**
 * Applies 'updates' to the coordinator document in config.reshardingOperations whose
 * resharding UUID equals 'reshardingUUID'.
 */
void updateCoordinatorDoc(OperationContext* opCtx,
                          const UUID& reshardingUUID,
                          const BSONObj& updates) {
    const BSONObj byReshardingUUID =
        BSON(ReshardingCoordinatorDocument::kReshardingUUIDFieldName << reshardingUUID);

    DBDirectClient directClient(opCtx);
    directClient.update(
        NamespaceString::kConfigReshardingOperationsNamespace, byReshardingUUID, updates);
}
/**
 * Polls the persisted coordinator document every 50 ms until its "state" field equals the
 * serialized form of 'state'. There is no timeout.
 */
void waitUntilCommittedCoordinatorDocReach(OperationContext* opCtx,
                                           CoordinatorStateEnum state) {
    const auto hasReachedState = [&] {
        const auto persistedDoc = getCoordinatorDocBSON(opCtx);
        return persistedDoc.getStringField("state") == idl::serialize(state);
    };

    while (!hasReachedState()) {
        sleepmillis(50);
    }
}
/**
 * Asserts that the coordinator document has a start time, then moves every donor to
 * kDonatingInitialData, recording the clone timestamp as minFetchTimestamp and an even share of
 * the approximate bytes/documents to clone.
 */
void makeDonorsReadyToDonateWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getStartTime(), Date_t::min());

    // Each donor reports an equal slice of the totals.
    const auto donorCount = coordDoc.getDonorShards().size();
    const auto bytesPerDonor = static_cast<long long>(totalApproxBytesToClone / donorCount);
    const auto docsPerDonor = static_cast<long long>(totalApproxDocumentsToClone / donorCount);

    // Builds "<donorShards>.$[].mutableState.<suffix>".
    const auto donorMutableStatePath = [](StringData suffix) {
        return fmt::format("{}.$[].mutableState.{}",
                           ReshardingCoordinatorDocument::kDonorShardsFieldName,
                           suffix);
    };

    BSONObjBuilder updates;
    {
        // The sub-object is completed when 'setBuilder' leaves this scope.
        BSONObjBuilder setBuilder{updates.subobjStart("$set")};
        setBuilder.append(donorMutableStatePath("state"),
                          idl::serialize(DonorStateEnum::kDonatingInitialData));
        setBuilder.append(donorMutableStatePath("minFetchTimestamp"), _cloneTimestamp);
        setBuilder.append(donorMutableStatePath("bytesToClone"), bytesPerDonor);
        setBuilder.append(donorMutableStatePath("documentsToClone"), docsPerDonor);
    }
    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), updates.obj());
}
/**
 * Asserts that the document-copy metrics have a start time, then moves every recipient to
 * kApplying.
 */
void makeRecipientsFinishedCloningWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getMetrics()->getDocumentCopy()->getStart(), Date_t::min());

    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj setApplying =
        BSON("$set" << BSON(recipientStatePath << idl::serialize(RecipientStateEnum::kApplying)));
    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), setApplying);
}
/**
 * Asserts that oplog application started no later than it stopped, then moves every recipient
 * to kStrictConsistency.
 */
void makeRecipientsBeInStrictConsistencyWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_LTE(coordDoc.getMetrics()->getOplogApplication()->getStart(),
               coordDoc.getMetrics()->getOplogApplication()->getStop());

    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj setStrictConsistency = BSON(
        "$set" << BSON(recipientStatePath
                       << idl::serialize(RecipientStateEnum::kStrictConsistency)));
    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), setStrictConsistency);
}
/**
 * Sets the state of every donor to kDone in the coordinator document identified by
 * 'reshardingUUID'.
 */
void makeDonorsProceedToDone(OperationContext* opCtx, UUID reshardingUUID) {
    const std::string donorStatePath =
        std::string(ReshardingCoordinatorDocument::kDonorShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj setDonorsDone =
        BSON("$set" << BSON(donorStatePath << idl::serialize(DonorStateEnum::kDone)));
    updateCoordinatorDoc(opCtx, reshardingUUID, setDonorsDone);
}

/** Same as above, with the resharding UUID read from the persisted coordinator document. */
void makeDonorsProceedToDone(OperationContext* opCtx) {
    const auto persistedUUID = getCoordinatorDoc(opCtx).getReshardingUUID();
    makeDonorsProceedToDone(opCtx, persistedUUID);
}
/**
 * Asserts that document copy started no later than it stopped, then moves every donor to
 * kDone.
 */
void makeDonorsProceedToDoneWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_LTE(coordDoc.getMetrics()->getDocumentCopy()->getStart(),
               coordDoc.getMetrics()->getDocumentCopy()->getStop());

    const auto& persistedUUID = coordDoc.getReshardingUUID();
    makeDonorsProceedToDone(opCtx, persistedUUID);
}
/**
 * Sets the state of every recipient to kDone in the coordinator document identified by
 * 'reshardingUUID'.
 */
void makeRecipientsProceedToDone(OperationContext* opCtx, UUID reshardingUUID) {
    const std::string recipientStatePath =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState.state";
    const BSONObj setRecipientsDone =
        BSON("$set" << BSON(recipientStatePath << idl::serialize(RecipientStateEnum::kDone)));
    updateCoordinatorDoc(opCtx, reshardingUUID, setRecipientsDone);
}

/** Same as above, with the resharding UUID read from the persisted coordinator document. */
void makeRecipientsProceedToDone(OperationContext* opCtx) {
    const auto persistedUUID = getCoordinatorDoc(opCtx).getReshardingUUID();
    makeRecipientsProceedToDone(opCtx, persistedUUID);
}
/**
 * Asserts that document copy started no later than it stopped, then moves every recipient to
 * kDone.
 */
void makeRecipientsProceedToDoneWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_LTE(coordDoc.getMetrics()->getDocumentCopy()->getStart(),
               coordDoc.getMetrics()->getDocumentCopy()->getStop());

    const auto& persistedUUID = coordDoc.getReshardingUUID();
    makeRecipientsProceedToDone(opCtx, persistedUUID);
}
/**
 * Asserts that the document-copy metrics have a start time, then moves every recipient to
 * kError with a serialized SnapshotUnavailable status as its abortReason.
 */
void makeRecipientsReturnErrorWithAssert(OperationContext* opCtx) {
    auto coordDoc = getCoordinatorDoc(opCtx);
    ASSERT_NE(coordDoc.getMetrics()->getDocumentCopy()->getStart(), Date_t::min());

    // Serialize the simulated failure up front so it can be embedded in the update below.
    const Status simulatedFailure{ErrorCodes::SnapshotUnavailable, "test simulated error"};
    BSONObjBuilder abortReasonBuilder;
    simulatedFailure.serialize(&abortReasonBuilder);
    const BSONObj abortReason = abortReasonBuilder.obj();

    const std::string recipientMutableState =
        std::string(ReshardingCoordinatorDocument::kRecipientShardsFieldName) +
        ".$[].mutableState";
    const BSONObj setRecipientsErrored =
        BSON("$set" << BSON(recipientMutableState + ".state"
                            << idl::serialize(RecipientStateEnum::kError)
                            << recipientMutableState + ".abortReason" << abortReason));
    updateCoordinatorDoc(opCtx, coordDoc.getReshardingUUID(), setRecipientsErrored);
}
/**
 * Simulates a step-up: bumps the fixture's term, makes this node primary in the new term,
 * advances the last applied optime to that term and notifies the primary-only service registry.
 */
void stepUp(OperationContext* opCtx) {
auto replCoord = repl::ReplicationCoordinator::get(getServiceContext());
auto currOpTime = replCoord->getMyLastAppliedOpTime();
++_term;
// Keep the current timestamp; only the term changes.
auto newOpTime = repl::OpTime{currOpTime.getTimestamp(), _term};
ASSERT_OK(replCoord->setFollowerMode(repl::MemberState::RS_PRIMARY));
ASSERT_OK(replCoord->updateTerm(opCtx, _term));
replCoord->setMyLastAppliedOpTimeAndWallTimeForward({newOpTime, {}});
_registry->onStepUpComplete(opCtx, _term);
}
/**
 * Simulates a step-down: switches this node to secondary, notifies the primary-only service
 * registry, and then kills the operations of other clients that are eligible for interruption.
 */
void stepDown(OperationContext* opCtx) {
ASSERT_OK(repl::ReplicationCoordinator::get(getServiceContext())
->setFollowerMode(repl::MemberState::RS_SECONDARY));
_registry->onStepDown();
auto serviceCtx = opCtx->getServiceContext();
// Walk every client registered on the service context.
for (ServiceContext::LockedClientsCursor cursor(serviceCtx);
Client* client = cursor.next();) {
if (!client->canKillOperationInStepdown()) {
continue;
}
ClientLock lk(client);
OperationContext* toKill = client->getOperationContext();
// Skip clients without an operation, operations already being killed, and the caller's own
// operation.
if (toKill && !toKill->isKillPending() && toKill->getOpID() != opCtx->getOpID()) {
auto locker = shard_role_details::getLocker(toKill);
// Kill only operations that are always interrupted at step-down/up or that took the global
// lock in a mode conflicting with writes.
if (toKill->shouldAlwaysInterruptAtStepDownOrUp() ||
locker->wasGlobalLockTakenInModeConflictingWithWrites()) {
serviceCtx->killOperation(lk, toKill);
}
}
}
}
protected:
// Non-owning pointer to the registered service; assigned in setUp() via lookupServiceByName().
repl::PrimaryOnlyService* _service = nullptr;
// Shared with the OpObserverForTest installed in setUp().
std::shared_ptr<CoordinatorStateTransitionController> _controller;
// Test double passed to makeService() in setUp().
std::shared_ptr<ExternalStateForTest> _externalState;
OpObserverRegistry* _opObserverRegistry = nullptr;
repl::PrimaryOnlyServiceRegistry* _registry = nullptr;
// Identity of the source collection; getCollectionCatalogEntry() looks it up by _originalNss.
NamespaceString _originalNss = NamespaceString::createNamespaceString_forTest("db.foo");
UUID _originalUUID = UUID::gen();
OID _originalEpoch = OID::gen();
Timestamp _originalTimestamp = Timestamp(1);
// Temporary namespace, named "db.system.resharding.<_originalUUID>".
NamespaceString _tempNss = NamespaceString::createNamespaceString_forTest(
"db.system.resharding." + _originalUUID.toString());
UUID _reshardingUUID = UUID::gen();
OID _tempEpoch = OID::gen();
Timestamp _tempTimestamp = Timestamp(2);
ShardKeyPattern _oldShardKey = ShardKeyPattern(BSON("oldShardKey" << 1));
ShardKeyPattern _newShardKey = ShardKeyPattern(BSON("newShardKey" << 1));
// Written as every donor's minFetchTimestamp by makeDonorsReadyToDonateWithAssert().
Timestamp _cloneTimestamp = Timestamp(Date_t::now());
// Sets reshardingMinimumOperationDurationMillis to 0; presumably restored when the fixture is
// destroyed — confirm against RAIIServerParameterControllerForTest.
RAIIServerParameterControllerForTest _serverParamController{
"reshardingMinimumOperationDurationMillis", 0};
// Term used by stepUp(), which increments it before each step-up.
long long _term = 0;
// Totals that makeDonorsReadyToDonateWithAssert() divides evenly among the donor shards.
static constexpr long totalApproxBytesToClone = 10000;
static constexpr long totalApproxDocumentsToClone = 100;
};
} // namespace resharding_coordinator_test
} // namespace mongo

View File

@ -39,6 +39,9 @@
#include "mongo/otel/traces/span/span.h"
#include "mongo/otel/traces/telemetry_context_serialization.h"
#include "mongo/unittest/unittest.h"
#include <gtest/gtest.h>
namespace mongo {
namespace resharding {
@ -324,5 +327,100 @@ TEST_F(ReshardingCoordinatorServiceUtilTest, RegistryPathReturnsReshardingUUID)
ASSERT_EQ(retrieveReshardingUUID(opCtx.get(), kSourceNss), reshardingUUID);
}
/**
 * Value-parameterized fixture (parameter: ReshardingProvenanceEnum) covering the
 * provenance-dependent behavior of createReshardingFieldsUpdateForOriginalNss and
 * createTempReshardingCollectionType.
 */
class ReshardingCoordinatorServiceUtilProvenanceTest
    : public ReshardingCoordinatorServiceUtilTest,
      public ::testing::WithParamInterface<ReshardingProvenanceEnum> {
protected:
    /**
     * Builds a coordinator document in 'state' whose provenance is the current test parameter,
     * with one donor ("donor0", kPreparingToDonate) and one recipient ("recipient0", kUnused).
     */
    ReshardingCoordinatorDocument makeCoordinatorDocWithProvenance(CoordinatorStateEnum state) {
        // Participant contexts first; they do not depend on the document.
        DonorShardContext donorContext;
        donorContext.setState(DonorStateEnum::kPreparingToDonate);

        RecipientShardContext recipientContext;
        recipientContext.setState(RecipientStateEnum::kUnused);

        auto commonMetadata = makeMetadata();
        commonMetadata.setStartTime(Date_t::now());
        commonMetadata.setProvenance(GetParam());

        ReshardingCoordinatorDocument coordinatorDoc;
        coordinatorDoc.setCommonReshardingMetadata(std::move(commonMetadata));
        coordinatorDoc.setState(state);
        coordinatorDoc.setDonorShards({DonorShardEntry(ShardId("donor0"), donorContext)});
        coordinatorDoc.setRecipientShards(
            {RecipientShardEntry(ShardId("recipient0"), recipientContext)});
        return coordinatorDoc;
    }
};
// Runs every TEST_P of the fixture once per provenance; each instance is named after the
// serialized provenance value.
INSTANTIATE_TEST_SUITE_P(
    Provenance,
    ReshardingCoordinatorServiceUtilProvenanceTest,
    ::testing::Values(ReshardingProvenanceEnum::kReshardCollection,
                      ReshardingProvenanceEnum::kMoveCollection,
                      ReshardingProvenanceEnum::kUnshardCollection,
                      ReshardingProvenanceEnum::kRewriteCollection),
    [](const ::testing::TestParamInfo<ReshardingProvenanceEnum>& paramInfo) {
        const auto serializedProvenance = idl::serialize(paramInfo.param);
        return std::string(serializedProvenance);
    });
// At commit time the update for the original namespace sets "unsplittable: true" for
// unshardCollection and leaves the field out for every other provenance.
TEST_P(ReshardingCoordinatorServiceUtilProvenanceTest,
       CollectionUpdateAtCommitSetsUnsplittableForUnshardOnly) {
    auto opCtx = makeOperationContext();
    auto coordinatorDoc = makeCoordinatorDocWithProvenance(CoordinatorStateEnum::kCommitting);

    auto update = createReshardingFieldsUpdateForOriginalNss(
        opCtx.get(), coordinatorDoc, OID::gen(), Timestamp(1, 2));
    auto setFields = update.getObjectField("$set");

    if (!isUnshardCollection(GetParam())) {
        ASSERT_FALSE(setFields.hasField("unsplittable"));
        return;
    }
    ASSERT_TRUE(setFields.hasField("unsplittable"));
    ASSERT_TRUE(setFields["unsplittable"].Bool());
}
// The temporary collection entry allows migrations unless the operation is an ordinary
// reshardCollection or a rewriteCollection.
TEST_P(ReshardingCoordinatorServiceUtilProvenanceTest,
       TempCollectionBlocksMigrationsForReshardAndRewriteOnly) {
    auto opCtx = makeOperationContext();
    auto coordinatorDoc =
        makeCoordinatorDocWithProvenance(CoordinatorStateEnum::kPreparingToDonate);

    ChunkVersion chunkVersion(CollectionGeneration{OID::gen(), Timestamp(5, 0)},
                              CollectionPlacement(10, 1));
    auto collType = createTempReshardingCollectionType(
        opCtx.get(), coordinatorDoc, chunkVersion, BSONObj() /* collation */, boost::none);

    const bool expectMigrationsAllowed =
        !isOrdinaryReshardCollection(GetParam()) && !isRewriteCollection(GetParam());
    ASSERT_EQ(collType.getAllowMigrations(), expectMigrationsAllowed);
}
// While initializing, the reshardingFields written for the original namespace carry the
// provenance of the coordinator document.
TEST_P(ReshardingCoordinatorServiceUtilProvenanceTest,
       CollectionUpdateAtInitializingCopiesProvenance) {
    auto opCtx = makeOperationContext();
    auto coordinatorDoc = makeCoordinatorDocWithProvenance(CoordinatorStateEnum::kInitializing);

    auto update = createReshardingFieldsUpdateForOriginalNss(
        opCtx.get(), coordinatorDoc, boost::none, boost::none);
    auto setFields = update.getObjectField("$set");
    auto reshardingFields = setFields.getObjectField("reshardingFields");

    ASSERT_EQ(reshardingFields.getStringField("provenance"), idl::serialize(GetParam()));
}
// The temporary collection entry has reshardingFields whose provenance equals the coordinator
// document's provenance.
TEST_P(ReshardingCoordinatorServiceUtilProvenanceTest, TempCollectionTypeCopiesProvenance) {
    auto opCtx = makeOperationContext();
    auto coordinatorDoc =
        makeCoordinatorDocWithProvenance(CoordinatorStateEnum::kPreparingToDonate);

    ChunkVersion chunkVersion(CollectionGeneration{OID::gen(), Timestamp(5, 0)},
                              CollectionPlacement(10, 1));
    auto collType = createTempReshardingCollectionType(
        opCtx.get(), coordinatorDoc, chunkVersion, BSONObj(), boost::none);

    ASSERT_TRUE(collType.getReshardingFields().has_value());
    ASSERT_EQ(collType.getReshardingFields()->getProvenance(), GetParam());
}
} // namespace resharding
} // namespace mongo

View File

@ -528,6 +528,57 @@ bool isOrdinaryReshardCollection(const boost::optional<ReshardingProvenanceEnum>
return provenance && provenance.get() == ReshardingProvenanceEnum::kReshardCollection;
}
/**
 * Determines the shard key the resharding operation ends up with.
 *
 * rewriteCollection keeps the source collection's key and ignores 'userKey'. Every other
 * provenance requires 'userKey' (InvalidOptions otherwise); for an ordinary reshardCollection
 * with 'timeseriesFields' set, the key is validated and translated through
 * shardkeyutil::validateAndTranslateTimeseriesShardKey before being returned.
 */
BSONObj computeReshardingShardKey(
    const boost::optional<ReshardingProvenanceEnum>& provenance,
    const ShardKeyPattern& sourceShardKey,
    const boost::optional<TypeCollectionTimeseriesFields>& timeseriesFields,
    const boost::optional<BSONObj>& userKey) {
    if (isRewriteCollection(provenance)) {
        // The collection is rewritten on the key it already has.
        const auto& existingKeyPattern = sourceShardKey.getKeyPattern();
        return existingKeyPattern.toBSON();
    }

    uassert(
        ErrorCodes::InvalidOptions, "Resharding operation requires shard key", userKey.has_value());

    const bool needsTimeseriesTranslation =
        timeseriesFields.has_value() && isOrdinaryReshardCollection(provenance);
    if (!needsTimeseriesTranslation) {
        return *userKey;
    }

    const auto& timeseriesOptions = timeseriesFields->getTimeseriesOptions();
    return shardkeyutil::validateAndTranslateTimeseriesShardKey(timeseriesOptions, *userKey);
}
/**
 * Checks that the source collection's sharding state fits the requested operation.
 *
 *  - moveCollection: the source must be unsharded (NamespaceNotFound otherwise).
 *  - unshardCollection: nothing is checked here.
 *  - anything else: the source must be sharded (NamespaceNotSharded otherwise) and, when
 *    'forceRedistribution' is set, 'finalShardKey' must be a shard key of 'sourceShardKey'
 *    (InvalidOptions otherwise).
 */
void validateReshardCollectionRequest(const boost::optional<ReshardingProvenanceEnum>& provenance,
                                      bool sourceIsSharded,
                                      const ShardKeyPattern& sourceShardKey,
                                      const BSONObj& finalShardKey,
                                      bool forceRedistribution) {
    if (isMoveCollection(provenance)) {
        const bool sourceIsUnsharded = !sourceIsSharded;
        uassert(ErrorCodes::NamespaceNotFound,
                "MoveCollection can only be called on an unsharded collection.",
                sourceIsUnsharded);
        return;
    }

    // unshardCollection is validated later in the resharding flow, so there is nothing to do.
    if (isUnshardCollection(provenance)) {
        return;
    }

    uassert(ErrorCodes::NamespaceNotSharded,
            "Collection has to be a sharded collection.",
            sourceIsSharded);

    if (!forceRedistribution) {
        return;
    }

    const bool keyIsUnchanged = sourceShardKey.isShardKey(finalShardKey);
    uassert(ErrorCodes::InvalidOptions,
            "The new shard key must be the same as the original shard key when using the "
            "forceRedistribution option. The forceRedistribution option is meant for "
            "redistributing the collection to a different set of shards.",
            keyIsUnchanged);
}
std::shared_ptr<ThreadPool> makeThreadPoolForMarkKilledExecutor(const std::string& poolName) {
return std::make_shared<ThreadPool>([&] {
ThreadPool::Options options;

View File

@ -491,6 +491,27 @@ bool isUnshardCollection(const boost::optional<ReshardingProvenanceEnum>& proven
*/
bool isRewriteCollection(const boost::optional<ReshardingProvenanceEnum>& provenance);
/**
 * Returns the final shard key for the operation:
 * - kRewriteCollection: always returns the existing source key; userKey is not consulted.
 * - kReshardCollection on timeseries (timeseriesFields is set): validates userKey and translates
 *   it from user-facing field to bucket-level.
 * - Everything else: returns userKey unchanged.
 *
 * Throws InvalidOptions if userKey is not provided and the provenance is not kRewriteCollection.
 */
BSONObj computeReshardingShardKey(
const boost::optional<ReshardingProvenanceEnum>& provenance,
const ShardKeyPattern& sourceShardKey,
const boost::optional<TypeCollectionTimeseriesFields>& timeseriesFields,
const boost::optional<BSONObj>& userKey);
/**
 * Validates source collection sharding state for the given provenance.
 * - moveCollection: throws NamespaceNotFound if the source collection is sharded; nothing else
 *   is checked.
 * - unshardCollection: no checks are performed here (it is validated later in the resharding
 *   flow).
 * - Everything else: throws NamespaceNotSharded if the source collection is not sharded. If
 *   forceRedistribution is set, also throws InvalidOptions unless finalShardKey is a shard key
 *   of sourceShardKey.
 */
void validateReshardCollectionRequest(const boost::optional<ReshardingProvenanceEnum>& provenance,
bool sourceIsSharded,
const ShardKeyPattern& sourceShardKey,
const BSONObj& finalShardKey,
bool forceRedistribution);
/**
* Helper function to create a thread pool for _markKilledExecutor member of resharding POS.
*/