SERVER-105713 Allow callers to provide TickSource to RateLimiter for unit testing (#42173)

GitOrigin-RevId: 6930a411d24e6230b21edfec6f897bea396f7f5d
This commit is contained in:
Erin McNulty 2025-10-07 10:16:43 -04:00 committed by MongoDB Bot
parent 4211210a74
commit 6d56055c32
3 changed files with 91 additions and 82 deletions

View File

@ -54,7 +54,7 @@ double calculateBurstSize(double refreshRate, double burstCapacitySecs) {
} // namespace
struct RateLimiter::RateLimiterPrivate {
RateLimiterPrivate(double r, double b, int64_t m, std::string n)
RateLimiterPrivate(double r, double b, int64_t m, TickSource* clock, std::string n)
// Initialize the token bucket with one "burst" of tokens. The third parameter to
// tokenBucket's constructor ("zeroTime") is interpreted as a number of seconds from the
// epoch of the clock used by the token bucket. The clock is
@ -62,13 +62,11 @@ struct RateLimiter::RateLimiterPrivate {
// of the machine. Rather than have an initial accumulation of tokens based on some
// unknown point in the past, set the zero time to a known time in the past: enough time
// for burst size (b) tokens to have accumulated.
: tokenBucket{r, b, folly::TokenBucket::defaultClockNow() - b / r},
maxQueueDepth(m),
: maxQueueDepth(m),
queued(0),
name(std::move(n)) {}
WriteRarelyRWMutex rwMutex;
folly::TokenBucket tokenBucket;
name(std::move(n)),
tickSource(clock),
tokenBucket{r, b, nowInSeconds() - b / r} {}
Stats stats;
@ -77,6 +75,10 @@ struct RateLimiter::RateLimiterPrivate {
std::string name;
TickSource* tickSource;
WriteRarelyRWMutex rwMutex;
folly::TokenBucket tokenBucket;
Status enqueue() {
const auto maxDepth = maxQueueDepth.loadRelaxed();
int64_t expected = queued.load();
@ -91,12 +93,19 @@ struct RateLimiter::RateLimiterPrivate {
return Status::OK();
}
double nowInSeconds() {
return std::chrono::duration<double>(
std::chrono::nanoseconds(tickSource->getTicks())) // NOLINT
.count();
}
};
RateLimiter::RateLimiter(double refreshRatePerSec,
double burstCapacitySecs,
int64_t maxQueueDepth,
std::string name) {
std::string name,
TickSource* tickSource) {
uassert(ErrorCodes::InvalidOptions,
fmt::format("burstCapacitySecs cannot be less than or equal to 0.0. "
"burstCapacitySecs={}; rateLimiterName={}",
@ -105,7 +114,7 @@ RateLimiter::RateLimiter(double refreshRatePerSec,
burstCapacitySecs > 0.0);
auto burstSize = calculateBurstSize(refreshRatePerSec, burstCapacitySecs);
_impl = std::make_unique<RateLimiterPrivate>(
refreshRatePerSec, burstSize, maxQueueDepth, std::move(name));
refreshRatePerSec, burstSize, maxQueueDepth, tickSource, std::move(name));
}
RateLimiter::~RateLimiter() = default;
@ -121,7 +130,8 @@ Status RateLimiter::acquireToken(OperationContext* opCtx) {
waitForTokenSecs = 60 * 60; // 1 hour
} else {
auto lk = _impl->rwMutex.readLock();
waitForTokenSecs = _impl->tokenBucket.consumeWithBorrowNonBlocking(1.0).value_or(0);
waitForTokenSecs =
_impl->tokenBucket.consumeWithBorrowNonBlocking(1.0, _impl->nowInSeconds()).value_or(0);
}
if (auto napTime = doubleToMillis(waitForTokenSecs); napTime > Milliseconds{0}) {
@ -174,7 +184,7 @@ Status RateLimiter::acquireToken(OperationContext* opCtx) {
Status RateLimiter::tryAcquireToken() {
_impl->stats.attemptedAdmissions.incrementRelaxed();
if (!_impl->tokenBucket.consume(1.0)) {
if (!_impl->tokenBucket.consume(1.0, _impl->nowInSeconds())) {
_impl->stats.rejectedAdmissions.incrementRelaxed();
return Status{kRejectedErrorCode,
fmt::format("Rate limiter '{}' rate exceeded", _impl->name)};
@ -229,12 +239,12 @@ void RateLimiter::appendStats(BSONObjBuilder* bob) const {
double RateLimiter::tokensAvailable() const {
auto lk = _impl->rwMutex.readLock();
return _impl->tokenBucket.available();
return _impl->tokenBucket.available(_impl->nowInSeconds());
}
double RateLimiter::tokenBalance() const {
auto lk = _impl->rwMutex.readLock();
return _impl->tokenBucket.balance();
return _impl->tokenBucket.balance(_impl->nowInSeconds());
}
int64_t RateLimiter::queued() const {

View File

@ -33,6 +33,7 @@
#include "mongo/bson/bsonobjbuilder.h"
#include "mongo/db/operation_context.h"
#include "mongo/util/moving_average.h"
#include "mongo/util/system_tick_source.h"
namespace mongo::admission {
@ -95,7 +96,8 @@ public:
RateLimiter(double refreshRatePerSec,
double burstCapacitySecs,
int64_t maxQueueDepth,
std::string name);
std::string name,
TickSource* tickSource = globalSystemTickSource());
~RateLimiter();

View File

@ -36,6 +36,7 @@
#include "mongo/unittest/unittest.h"
#include "mongo/util/clock_source_mock.h"
#include "mongo/util/concurrency/notification.h"
#include "mongo/util/tick_source_mock.h"
#define MONGO_LOGV2_DEFAULT_COMPONENT ::mongo::logv2::LogComponent::kTest
@ -59,20 +60,38 @@ auto constexpr convertBurstSizeToBurstCapacitySecs(double refreshRate, double bu
return burstSize / refreshRate;
}
auto assertApproxEqualUnevenBounds(double value, double lowerBound, double upperBound) {
ASSERT_GTE(value, lowerBound);
ASSERT_LT(value, upperBound);
}
class RateLimiterWithMockClockTest : public ClockSourceMockServiceContextTest {
public:
void setUp() override {
static_cast<ClockSourceMock*>(getServiceContext()->getFastClockSource())->reset();
static_cast<TickSourceMock<Milliseconds>*>(getServiceContext()->getTickSource())->reset(0);
}
RateLimiter makeRateLimiter(std::string name,
double refreshRate = 1.0,
double burstCapacitySecs = 1.0,
int maxQueueDepth = INT_MAX) {
return RateLimiter(refreshRate,
burstCapacitySecs,
maxQueueDepth,
name,
getServiceContext()->getTickSource());
}
void advanceTime(Milliseconds d) {
static_cast<ClockSourceMock*>(getServiceContext()->getFastClockSource())->advance(d);
static_cast<TickSourceMock<Milliseconds>*>(getServiceContext()->getTickSource())
->advance(d);
}
class RateLimiterTest : public ServiceContextTest {
private:
unittest::MinimumLoggedSeverityGuard logSeverityGuard{logv2::LogComponent::kDefault,
logv2::LogSeverity::Debug(4)};
};
// Verify that a RateLimiter with sufficient capacity will dispense a token.
TEST_F(RateLimiterTest, BasicTokenAcquisition) {
RateLimiter rateLimiter(DBL_MAX, DBL_MAX, INT_MAX, "BasicTokenAcquisition");
TEST_F(RateLimiterWithMockClockTest, BasicTokenAcquisition) {
RateLimiter rateLimiter = makeRateLimiter("BasicTokenAcquisition");
auto opCtx = makeOperationContext();
ASSERT_OK(rateLimiter.acquireToken(opCtx.get()));
@ -89,11 +108,11 @@ TEST_F(RateLimiterTest, BasicTokenAcquisition) {
}
// Verify that RateLimiter::setBurstSize range checks its input.
TEST_F(RateLimiterTest, InvalidBurstSize) {
TEST_F(RateLimiterWithMockClockTest, InvalidBurstSize) {
ASSERT_THROWS_CODE(
RateLimiter(1.0, 0, 0, "InvalidBurstSize"), DBException, ErrorCodes::InvalidOptions);
RateLimiter rateLimiter(DBL_MAX, DBL_MAX, INT_MAX, "InvalidBurstSize");
RateLimiter rateLimiter = makeRateLimiter("InvalidBurstSize");
ASSERT_THROWS_CODE(
rateLimiter.updateRateParameters(1.0, 0), DBException, ErrorCodes::InvalidOptions);
}
@ -101,14 +120,15 @@ TEST_F(RateLimiterTest, InvalidBurstSize) {
// Verify that RateLimiter will reject a request for a token if:
// - the request would otherwise be enqueued (there are insufficent tokens),
// - but there are already the maximum number of threads enqueued.
TEST_F(RateLimiterTest, RejectOverMaxWaiters) {
TEST_F(RateLimiterWithMockClockTest, RejectOverMaxWaiters) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = 0.1;
constexpr double refreshRate = 1.0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(refreshRate, burstCapacitySecs, 1, "RejectOverMaxWaiters");
RateLimiter rateLimiter =
makeRateLimiter("RejectOverMaxWaiters", refreshRate, burstCapacitySecs, 1);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 3);
Notification<void> firstTokenAcquired;
Notification<void> hasFailed;
@ -122,8 +142,8 @@ TEST_F(RateLimiterTest, RejectOverMaxWaiters) {
firstTokenAcquired.set();
}));
// Enqueue two requests, both of which will queue due to the low refreshRate. Whichever
// request comes in second will be rejected, because maxQueueDepth is 1.
// Enqueue two requests, both of which will queue because the mock clock hasn't advanced.
// Whichever request comes in second will be rejected, because maxQueueDepth is 1.
threads.emplace_back(monitor.spawn([&]() {
firstTokenAcquired.get();
status1 = rateLimiter.acquireToken(clientsWithOps[1].second.get());
@ -152,11 +172,8 @@ TEST_F(RateLimiterTest, RejectOverMaxWaiters) {
(status2.code() == RateLimiter::kRejectedErrorCode));
ASSERT_EQ(rateLimiter.stats().rejectedAdmissions.get(), 1);
// Assert that the token balance is between ~0 and -1, as only one token should have been
// borrowed from the bucket. The bucket may have refilled slightly past 0 on slow machines,
// and so we account for that in the assertion.
ASSERT_LT(rateLimiter.tokenBalance(), 0.2);
ASSERT_GT(rateLimiter.tokenBalance(), -1);
// One token is "borrowed" from the bucket due to the queued request.
ASSERT_EQ(rateLimiter.tokenBalance(), -1);
// Interrupt the other token acquisition.
getServiceContext()->setKillAllOperations();
@ -165,14 +182,15 @@ TEST_F(RateLimiterTest, RejectOverMaxWaiters) {
// Verify that if the maximum queue depth is configured to be zero, then any requests that would
// otherwise queue are instead rejected.
TEST_F(RateLimiterTest, QueueingDisabled) {
TEST_F(RateLimiterWithMockClockTest, QueueingDisabled) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = .01;
constexpr double refreshRate = 1.0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(refreshRate, burstCapacitySecs, 0, "QueueingDisabled");
RateLimiter rateLimiter =
makeRateLimiter("QueueingDisabled", refreshRate, burstCapacitySecs, 0);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 2);
Notification<void> firstTokenAcquired;
@ -184,9 +202,7 @@ TEST_F(RateLimiterTest, QueueingDisabled) {
firstTokenAcquired.set();
}));
// The next token acquisition attempt fails because it would need to queue. We'd need to
// wait another 100 seconds for the next token, and the max queue depth is zero, so the
// request is rejected.
// The next token acquisition attempt fails because it would need to queue.
threads.emplace_back(monitor.spawn([&]() {
firstTokenAcquired.get();
Status token = rateLimiter.acquireToken(clientsWithOps[1].second.get());
@ -197,25 +213,24 @@ TEST_F(RateLimiterTest, QueueingDisabled) {
ASSERT_EQ(rateLimiter.stats().rejectedAdmissions.get(), 1);
ASSERT_EQ(rateLimiter.stats().attemptedAdmissions.get(), 2);
// Assert that the token balance is ~0, as only one token should have been
// Assert that the token balance is 0, as only one token should have been
// consumed from the bucket.
ASSERT_APPROX_EQUAL(rateLimiter.tokenBalance(), 0, .1);
ASSERT_EQ(rateLimiter.tokenBalance(), 0);
}));
});
}
// Verify that if a client disconnects while their session thread is asleep in the rate limiter,
// the rate limiter wakes up the thread and returns the appropriate error status.
TEST_F(RateLimiterTest, InterruptedDueToOperationKilled) {
TEST_F(RateLimiterWithMockClockTest, InterruptedDueToOperationKilled) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = .01;
constexpr double refreshRate = 1.0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(
refreshRate, burstCapacitySecs, INT_MAX, "InterruptedDueToOperationKilled");
RateLimiter rateLimiter =
makeRateLimiter("InterruptedDueToOperationKilled", refreshRate, burstCapacitySecs);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 2);
Notification<void> firstTokenAcquired;
std::vector<unittest::JoinThread> threads;
@ -243,9 +258,9 @@ TEST_F(RateLimiterTest, InterruptedDueToOperationKilled) {
ASSERT_EQ(rateLimiter.stats().removedFromQueue.get(), 1);
ASSERT_NE(rateLimiter.stats().averageTimeQueuedMicros.get(), boost::none);
ASSERT_GTE(*rateLimiter.stats().averageTimeQueuedMicros.get(), 0.0);
// Assert that the token balance is ~0, as only one token should have been
// Assert that the token balance is 0, as only one token should have been
// consumed from the bucket.
ASSERT_APPROX_EQUAL(rateLimiter.tokenBalance(), 0, .1);
ASSERT_EQ(rateLimiter.tokenBalance(), 0);
}));
firstTokenAcquired.get();
@ -253,18 +268,17 @@ TEST_F(RateLimiterTest, InterruptedDueToOperationKilled) {
});
}
// This is like the previous two tests, but instead of a client disconnect or a timeout, it's a
// service shutdown.
TEST_F(RateLimiterTest, InterruptedDueToKillAllOperations) {
TEST_F(RateLimiterWithMockClockTest, InterruptedDueToKillAllOperations) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = .01;
constexpr double refreshRate = 1.0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(
refreshRate, burstCapacitySecs, INT_MAX, "InterruptedDueToKillAllOperations");
RateLimiter rateLimiter = makeRateLimiter(
"InterruptedDueToKillAllOperations", refreshRate, burstCapacitySecs, INT_MAX);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 2);
Notification<void> firstTokenAcquired;
Notification<void> secondTokenInterrupted;
@ -312,8 +326,8 @@ TEST_F(RateLimiterTest, InterruptedDueToKillAllOperations) {
}
// Verify that `RateLimiter::recordExemption()` increments the exemption metric but no others.
TEST_F(RateLimiterTest, RecordExemption) {
RateLimiter rateLimiter(INT_MAX, INT_MAX, INT_MAX, "RecordExemption");
TEST_F(RateLimiterWithMockClockTest, RecordExemption) {
RateLimiter rateLimiter = makeRateLimiter("RecordExemption");
rateLimiter.recordExemption();
ASSERT_EQ(rateLimiter.stats().exemptedAdmissions.get(), 1);
@ -326,22 +340,6 @@ TEST_F(RateLimiterTest, RecordExemption) {
ASSERT_EQ(rateLimiter.stats().averageTimeQueuedMicros.get(), boost::none);
}
class RateLimiterWithMockClockTest : public ClockSourceMockServiceContextTest {
public:
/**
* Note that the mocked clock in this test applies to sleeping on the opCtx rather than to the
* internal folly::TokenBucket's calculations.
*/
void advanceTime(Milliseconds d) {
static_cast<ClockSourceMock*>(getServiceContext()->getFastClockSource())->advance(d);
}
private:
unittest::MinimumLoggedSeverityGuard logSeverityGuard{logv2::LogComponent::kDefault,
logv2::LogSeverity::Debug(4)};
};
// Verify that a newly initialized RateLimiter can immediately dispense up to its burst rate of
// tokens, and thereafter releases tokens at its configured rate.
// If multiple threads concurrently request tokens, then some of the threads will be admitted
@ -355,8 +353,8 @@ TEST_F(RateLimiterWithMockClockTest, ConcurrentTokenAcquisitionWithQueueing) {
const int64_t numThreads = 10;
const Milliseconds tokenInterval = Milliseconds(1000) / refreshRate;
RateLimiter rateLimiter(
refreshRate, burstCapacitySecs, INT_MAX, "ConcurrentTokenAcquisitionWithQueueing");
RateLimiter rateLimiter = makeRateLimiter(
"ConcurrentTokenAcquisitionWithQueueing", refreshRate, burstCapacitySecs, INT_MAX);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), numThreads);
@ -476,13 +474,13 @@ TEST_F(RateLimiterWithMockClockTest, ConcurrentTokenAcquisitionWithQueueing) {
TEST_F(RateLimiterWithMockClockTest, TokenBalanceIsValidAfterRejectedRequests) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = .01;
constexpr double refreshRate = 1.0;
constexpr int maxQueueDepth = 0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(
refreshRate, burstCapacitySecs, maxQueueDepth, "RateLimitIsValidAfterRejectedRequests");
RateLimiter rateLimiter = makeRateLimiter(
"RateLimitIsValidAfterRejectedRequests", refreshRate, burstCapacitySecs, maxQueueDepth);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 2);
Notification<void> firstTokenAcquired;
std::vector<unittest::JoinThread> threads;
@ -497,28 +495,27 @@ TEST_F(RateLimiterWithMockClockTest, TokenBalanceIsValidAfterRejectedRequests) {
// balance.
threads.emplace_back(monitor.spawn([&]() {
firstTokenAcquired.get();
ASSERT_APPROX_EQUAL(rateLimiter.tokenBalance(), 0, .5);
ASSERT_EQ(rateLimiter.tokenBalance(), 0);
Status token = rateLimiter.acquireToken(clientsWithOps[1].second.get());
ASSERT_EQ(token, Status(RateLimiter::kRejectedErrorCode, ""));
ASSERT_APPROX_EQUAL(rateLimiter.tokenBalance(), 0, .5);
ASSERT_EQ(rateLimiter.tokenBalance(), 0);
}));
firstTokenAcquired.get();
});
}
// Verify that if the sleep within an enqueued thread is interrupted, then the rate limiter
// returns the error status corresponding to the reason for the interruption.
TEST_F(RateLimiterWithMockClockTest, InterruptedDueToOperationDeadline) {
unittest::threadAssertionMonitoredTest([&](auto& monitor) {
constexpr double burstSize = 1.0;
constexpr double refreshRate = .01;
constexpr double refreshRate = 1.0;
constexpr double burstCapacitySecs =
convertBurstSizeToBurstCapacitySecs(refreshRate, burstSize);
RateLimiter rateLimiter(
refreshRate, burstCapacitySecs, INT_MAX, "InterruptedDueToOperationDeadline");
RateLimiter rateLimiter = makeRateLimiter(
"InterruptedDueToOperationDeadline", refreshRate, burstCapacitySecs, INT_MAX);
auto clientsWithOps = makeClientsWithOpCtxs(getServiceContext(), 2);
Notification<void> firstTokenAcquired;
std::vector<unittest::JoinThread> threads;