SERVER-125838 Test extension match and project pushdown rules (#54043)

GitOrigin-RevId: a65d0df13e782d91121967e2df7c56282764896f
This commit is contained in:
Adithi Raghavan 2026-05-22 13:00:02 -04:00 committed by MongoDB Bot
parent e2917819ab
commit 989847bc89
3 changed files with 258 additions and 24 deletions

View File

@ -0,0 +1,126 @@
/**
* Tests $match and $project pushdown optimizations for the $readNDocuments test extension.
*
* The "applyMatchPushdown" rewrite rule on $produceIds folds a subsequent $match with an _id
* lower-bound filter into the stage's startId, eliminating the $match from the pipeline.
*
* The "applyProjectPushdown" in-place rule suppresses "value" and "label" fields not present in
* the downstream dep set. applyPipelineSuffixDependencies stores the getNeededFields() result and
* the rule consumes it.
*
* @tags: [
* featureFlagExtensionsAPI,
* featureFlagExtensionsOptimizations,
* ]
*/
import {checkPlatformCompatibleWithExtensions, withExtensions} from "jstests/noPassthrough/libs/extension_helpers.js";
checkPlatformCompatibleWithExtensions();
function runAllTests(coll) {
// --------------------------$match Pushdown Test Cases--------------------------
// Baseline: $readNDocuments produces _id, value, and label fields.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 5}}]).toArray();
assert.eq(result.length, 5, "baseline: expected 5 docs");
for (const doc of result) {
assert.eq(doc.value, doc._id * 2, "baseline: wrong value", {doc});
assert.eq(doc.label, `doc_${doc._id}`, "baseline: wrong label", {doc});
}
}
// $match pushdown with $gte: only docs with _id >= 5 are returned.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 10}}, {$match: {_id: {$gte: 5}}}]).toArray();
assert.eq(result.length, 5, "$match $gte pushdown: expected 5 docs");
for (const doc of result) {
assert.gte(doc._id, 5, "$match $gte pushdown: _id below lower bound", {doc});
}
}
// $match pushdown with $gt: only docs with _id > 5 are returned.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 10}}, {$match: {_id: {$gt: 5}}}]).toArray();
assert.eq(result.length, 4, "$match $gt pushdown: expected 4 docs");
for (const doc of result) {
assert.gt(doc._id, 5, "$match $gt pushdown: _id not above lower bound", {doc});
}
}
// $match lower-bound beyond num docs outputted by $readNDocuments: produces zero docs.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 5}}, {$match: {_id: {$gte: 10}}}]).toArray();
assert.eq(result.length, 0, "$match beyond numDocs: expected 0 docs");
}
// Non-absorbable $match: filter on "value" has no _id lower-bound so $match must not be erased.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 10}}, {$match: {value: {$gte: 10}}}]).toArray();
// docs 0..4 have value 0..8 (< 10); docs 5..9 have value 10..18 (>= 10)
assert.eq(result.length, 5, "non-absorbable $match: expected 5 docs");
for (const doc of result) {
assert.gte(doc.value, 10, "non-absorbable $match: value below bound", {doc});
}
}
// --------------------------$project Pushdown Test Cases--------------------------
// Neither "value" nor "label" should be produced when dep set = {_id}.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 5}}, {$project: {_id: 1}}]).toArray();
assert.eq(result.length, 5, "$project pushdown: expected 5 docs");
for (const doc of result) {
assert(!doc.hasOwnProperty("value"), "$project pushdown: unexpected value field", {doc});
assert(!doc.hasOwnProperty("label"), "$project pushdown: unexpected label field", {doc});
}
}
// Only "label" should be suppressed when dep set = {_id, value}.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 5}}, {$project: {_id: 1, value: 1}}]).toArray();
assert.eq(result.length, 5, "selective suppression (value): expected 5 docs");
for (const doc of result) {
assert(doc.hasOwnProperty("value"), "selective suppression (value): missing value", {doc});
assert.eq(doc.value, doc._id * 2, "selective suppression (value): wrong value", {doc});
assert(!doc.hasOwnProperty("label"), "selective suppression (value): unexpected label", {doc});
}
}
// Only "value" should be suppressed when dep set = {_id, label}.
{
const result = coll.aggregate([{$readNDocuments: {numDocs: 5}}, {$project: {_id: 1, label: 1}}]).toArray();
assert.eq(result.length, 5, "selective suppression (label): expected 5 docs");
for (const doc of result) {
assert(doc.hasOwnProperty("label"), "selective suppression (label): missing label", {doc});
assert.eq(doc.label, `doc_${doc._id}`, "selective suppression (label): wrong label", {doc});
assert(!doc.hasOwnProperty("value"), "selective suppression (label): unexpected value", {doc});
}
}
}
function runTests(conn, shardingTest) {
const db = conn.getDB("test");
const coll = db[jsTestName()];
assert.commandWorked(
coll.insertMany(Array.from({length: 10}, (_, i) => ({_id: i, value: i * 2, label: `doc_${i}`}))),
);
runAllTests(coll);
if (shardingTest) {
assert.commandWorked(db.adminCommand({shardCollection: coll.getFullName(), key: {_id: 1}}));
assert.commandWorked(db.adminCommand({split: coll.getFullName(), middle: {_id: 5}}));
assert.commandWorked(
db.adminCommand({
moveChunk: coll.getFullName(),
find: {_id: 5},
to: shardingTest.shard1.shardName,
}),
);
runAllTests(coll);
}
}
withExtensions({"libread_n_documents_mongo_extension.so": {}}, runTests, ["standalone", "sharded"], {shards: 2});

View File

@ -701,7 +701,7 @@ REGISTER_RULES_WITH_FEATURE_FLAG(
.name = "EXTENSION_APPLY_PIPELINE_SUFFIX_DEPENDENCIES",
.precondition = mongo::extension::host::extensionApplyDependenciesPrecondition,
.transform = mongo::extension::host::extensionApplyDependenciesTransform,
.priority = kDefaultOptimizeInPlacePriority,
.priority = kDefaultOptimizeInPlacePriority + 1,
.tags = PipelineRewriteContext::Tags::InPlace,
});
REGISTER_RULES(DocumentSourceExtensionOptimizable,

View File

@ -40,19 +40,30 @@ using namespace mongo;
/**
* Execution stage that produces sequential integer _ids. When _produceScore is true (determined
* by whether the downstream pipeline suffix needs "score" metadata), each document also gets
* $score metadata set to _id * 5.
* $score metadata set to _id * 5. When _valueNeeded is true, each document also includes a
* "value" field equal to _id * 2. When _labelNeeded is true, each document also includes a
* "label" field equal to "doc_<id>".
*/
class ProduceIdsExecStage : public sdk::ExecAggStageSource {
public:
ProduceIdsExecStage(std::string_view stageName, bool sortById, int numDocs, bool produceScore)
ProduceIdsExecStage(std::string_view stageName,
bool sortById,
int numDocs,
bool produceScore,
bool valueNeeded,
int startId,
bool labelNeeded)
: sdk::ExecAggStageSource(stageName),
_sortById(sortById),
_numDocs(numDocs),
_produceScore(produceScore) {}
_produceScore(produceScore),
_valueNeeded(valueNeeded),
_labelNeeded(labelNeeded),
_currentDoc(startId) {}
extension::ExtensionGetNextResult getNext(const sdk::QueryExecutionContextHandle& execCtx,
::MongoExtensionExecAggStage* execStage) override {
if (_currentDoc == _numDocs) {
if (_currentDoc >= _numDocs) {
// At EOF, all idLookup processing for previously produced IDs is complete. Capture the
// accumulated idLookup metrics into this stage's extension operation metrics.
auto metrics = execCtx->getMetrics(execStage);
@ -62,9 +73,15 @@ public:
return extension::ExtensionGetNextResult::eof();
}
// Generate zero-indexed, ascending ids.
// Generate ascending ids starting from _startId.
auto currentId = _currentDoc++;
auto document = extension::ExtensionBSONObj::makeAsByteBuf(BSON("_id" << currentId));
BSONObjBuilder docBuilder;
docBuilder.append("_id", currentId);
if (_valueNeeded)
docBuilder.append("value", currentId * 2);
if (_labelNeeded)
docBuilder.append("label", "doc_" + std::to_string(currentId));
auto document = extension::ExtensionBSONObj::makeAsByteBuf(docBuilder.obj());
if (!_sortById && !_produceScore) {
// We haven't been asked for sorted results or a $score, no need to generate metadata.
@ -99,32 +116,49 @@ private:
const bool _sortById;
const int _numDocs;
const bool _produceScore;
int _currentDoc = 0;
const bool _valueNeeded;
const bool _labelNeeded;
int _currentDoc;
};
/**
* Logical stage for $produceIds. Overrides applyPipelineSuffixDependencies() to conditionally
* enable $score metadata production: if the downstream pipeline suffix needs "score" metadata,
* _produceScore is set to true and the compiled exec stage will emit score = _id * 5.
* Logical stage for $produceIds.
*
* Match pushdown "applyMatchPushdown" (reordering rule): folds a subsequent $match with an _id
* lower-bound filter into _startId, eliminating the $match from the pipeline.
*
* Project pushdown "applyProjectPushdown" (in-place rule): reads the required-fields set
* computed by dependency analysis (stored by applyPipelineSuffixDependencies) and suppresses
* any output fields not needed by the downstream pipeline.
*/
class ProduceIdsLogicalStage : public sdk::LogicalAggStage {
public:
ProduceIdsLogicalStage(std::string_view stageName, const BSONObj& arguments)
: ProduceIdsLogicalStage(stageName,
arguments["sortById"] && arguments["sortById"].booleanSafe(),
arguments["numDocs"] && arguments["numDocs"].isNumber()
? arguments["numDocs"].safeNumberInt()
: 1,
arguments["produceScore"].booleanSafe()) {}
: ProduceIdsLogicalStage(
stageName,
arguments["sortById"] && arguments["sortById"].booleanSafe(),
arguments["numDocs"] && arguments["numDocs"].isNumber()
? arguments["numDocs"].safeNumberInt()
: 1,
arguments["produceScore"].booleanSafe(),
!arguments["skipValue"].booleanSafe(),
arguments["startId"].isNumber() ? arguments["startId"].safeNumberInt() : 0,
!arguments["skipLabel"].booleanSafe()) {}
ProduceIdsLogicalStage(std::string_view stageName,
bool sortById,
int numDocs,
bool produceScore)
bool produceScore,
bool valueNeeded = true,
int startId = 0,
bool labelNeeded = true)
: sdk::LogicalAggStage(stageName),
_sortById(sortById),
_numDocs(numDocs),
_produceScore(produceScore) {}
_produceScore(produceScore),
_valueNeeded(valueNeeded),
_labelNeeded(labelNeeded),
_startId(startId) {}
BSONObj serialize() const override {
BSONObjBuilder spec;
@ -133,6 +167,12 @@ public:
spec.appendBool("sortById", true);
if (_produceScore)
spec.appendBool("produceScore", true);
if (!_valueNeeded)
spec.appendBool("skipValue", true);
if (!_labelNeeded)
spec.appendBool("skipLabel", true);
if (_startId != 0)
spec.append("startId", _startId);
return BSON(_name << spec.obj());
}
@ -142,11 +182,13 @@ public:
}
std::unique_ptr<sdk::ExecAggStageBase> compile() const override {
return std::make_unique<ProduceIdsExecStage>(_name, _sortById, _numDocs, _produceScore);
return std::make_unique<ProduceIdsExecStage>(
_name, _sortById, _numDocs, _produceScore, _valueNeeded, _startId, _labelNeeded);
}
std::unique_ptr<sdk::LogicalAggStage> clone() const override {
return std::make_unique<ProduceIdsLogicalStage>(_name, _sortById, _numDocs, _produceScore);
return std::make_unique<ProduceIdsLogicalStage>(
_name, _sortById, _numDocs, _produceScore, _valueNeeded, _startId, _labelNeeded);
}
BSONObj getSortPattern() const override {
@ -177,21 +219,81 @@ public:
}
BSONObj getFilter() const override {
// We will generate _ids from [0, numDocs). We can turn this range into a filter that can be
// used for shard targeting.
auto rangeFilter = BSON("$gte" << 0 << "$lt" << _numDocs);
// We will generate _ids from [_startId, _numDocs). We can turn this range into a filter
// that can be used for shard targeting.
auto rangeFilter = BSON("$gte" << _startId << "$lt" << _numDocs);
return BSON("_id" << rangeFilter);
}
void applyPipelineSuffixDependencies(
const extension::PipelineDependenciesHandle& deps) override {
_produceScore = deps->needsMetadata("score");
// Store the required-fields set for the "applyProjectPushdown" in-place rule to consume.
// boost::none means the whole document is needed. After the rule fires it
// resets _neededFields back to boost::none to prevent re-firing on later passes.
_neededFields = deps->getNeededFields();
}
bool evaluateRulePrecondition(std::string_view ruleName,
extension::ConstPipelineRewriteContextHandle ctx) const override {
// $readNDocuments desugars to [$produceIds, $_internalSearchIdLookup, ...]. From
// $produceIds, index 1 is the id-lookup stage and index 2 is the first user-added stage.
if (ruleName == "applyMatchPushdown") {
if (!ctx->hasAtLeastNNextStages(2) || ctx->getNthNextStage(2)->getName() != "$match")
return false;
auto bound = extractIdLowerBound(ctx->getNthNextStage(2)->getFilter());
if (!bound)
return false;
_startId = *bound;
return true;
}
if (ruleName == "applyProjectPushdown") {
return _neededFields.has_value();
}
return false;
}
bool evaluateRuleTransform(std::string_view ruleName,
extension::PipelineRewriteContextHandle ctx) override {
if (ruleName == "applyMatchPushdown") {
ctx->eraseNthNext(2);
return true;
}
if (ruleName == "applyProjectPushdown") {
_valueNeeded = false;
_labelNeeded = false;
for (auto elem : *_neededFields) {
auto name = elem.valueStringDataSafe();
if (name == "value")
_valueNeeded = true;
else if (name == "label")
_labelNeeded = true;
}
_neededFields = boost::none;
return false;
}
return false;
}
static boost::optional<int> extractIdLowerBound(const BSONObj& filter) {
auto idElem = filter["_id"];
if (idElem.eoo())
return boost::none;
if (auto gt = idElem["$gt"]; !gt.eoo())
return gt.safeNumberInt() + 1;
if (auto gte = idElem["$gte"]; !gte.eoo())
return gte.safeNumberInt();
return boost::none;
}
private:
const bool _sortById;
const int _numDocs;
bool _produceScore;
bool _valueNeeded;
bool _labelNeeded;
mutable int _startId;
boost::optional<BSONObj> _neededFields;
};
class ProduceIdsAstNode : public sdk::TestAstNode<ProduceIdsLogicalStage> {
@ -269,6 +371,12 @@ public:
void initialize(const sdk::HostPortalHandle& portal) override {
_registerStage<ReadNDocumentsStageDescriptor>(portal);
_registerStage<ProduceIdsStageDescriptor>(portal);
_registerStageRules<ProduceIdsStageDescriptor>(
portal,
{
{"applyMatchPushdown", extension::kReordering},
{"applyProjectPushdown", extension::kInPlace},
});
}
};
REGISTER_EXTENSION(ReadNDocumentsExtension)