SERVER-113679 Take cost of SORT spilling into account (#43934)

GitOrigin-RevId: 4eeaf518cd4902fcaaf30a8af2def65260ea9df9
This commit is contained in:
Max Verbinnen 2025-11-13 20:47:23 +00:00 committed by MongoDB Bot
parent c18d3a8e39
commit 5b7db75440
4 changed files with 68 additions and 7 deletions

View File

@ -351,8 +351,14 @@ doc_scan_collection = create_coll_scan_collection_template(
)
# Collection template for the "sort" calibration collections.
# Only one value may be given per keyword: the superseded small-cardinality /
# small-payload arguments were dropped because repeating a keyword argument in
# a call is a SyntaxError.
sort_collections = create_coll_scan_collection_template(
    "sort",
    # We add '2' here to calibrate the startup cost in qsn_calibrator
    cardinalities=[2] + list(range(1000, 10_001, 1000)),
    payload_size=1000,
)
# Collection template for the "large_sort" collections: six cardinalities,
# 100_000 through 150_000 inclusive in steps of 10_000.
# NOTE(review): presumably sized so that sorting them spills to disk — confirm
# against the server's sort memory limit.
large_sort_collections = create_coll_scan_collection_template(
    "large_sort",
    cardinalities=[*range(100_000, 150_001, 10_000)],
    payload_size=1000,
)
merge_sort_collections = create_indexed_fields_template(
"merge_sort",
@ -401,6 +407,7 @@ data_generator = config.DataGeneratorConfig(
index_scan_collection,
doc_scan_collection,
sort_collections,
large_sort_collections,
merge_sort_collections,
or_collections,
intersection_sorted_collections,
@ -525,8 +532,19 @@ qsn_nodes = [
[
(df["n_processed"] * np.log2(df["n_processed"])).rename(
"n_processed * log2(n_processed)"
),
df["n_processed"],
)
],
axis=1,
),
),
config.QsNodeCalibrationConfig(
name="SORT_DEFAULT_SPILL",
type="SORT",
variables_override=lambda df: pd.concat(
[
(df["n_processed"] * np.log2(df["n_processed"])).rename(
"n_processed * log2(n_processed)"
)
],
axis=1,
),
@ -540,7 +558,18 @@ qsn_nodes = [
(df["n_processed"] * np.log2(df["n_processed"])).rename(
"n_processed * log2(n_processed)"
),
df["n_processed"],
],
axis=1,
),
),
config.QsNodeCalibrationConfig(
name="SORT_SIMPLE_SPILL",
type="SORT",
variables_override=lambda df: pd.concat(
[
(df["n_processed"] * np.log2(df["n_processed"])).rename(
"n_processed * log2(n_processed)"
)
],
axis=1,
),

View File

@ -81,6 +81,13 @@ def calibrate_node(
labels = X_vars.columns.tolist()
X = sm.add_constant(X_vars)
# TODO(SERVER-110398): Clean this up by not relying on fixed '2' value
if node_config.type == "SORT" and not X_vars.empty and "SPILL" not in node_name:
# This prints the average execution time for input cardinality '2'
mask = X_vars.iloc[:, 0] == 2.0
avg_exec_time = y[mask].mean()
print(f"{node_name}: Startup Cost -> {avg_exec_time} ns")
def fit(X, y):
nnls = LinearRegression(positive=True, fit_intercept=False)
model = nnls.fit(X, y)

View File

@ -308,7 +308,7 @@ async def execute_sorts(database: DatabaseInstance, collections: Sequence[Collec
# By combining a sort with a limit, we trigger the top-K sorting algorithm, which works
# for both the simple and default sort algorithms.
limits = [5, 10, 50, 75, 100, 150, 300, 500, 1000]
limits = [2, 5, 10, 50, 75, 100, 150, 300, 500, 1000]
for limit in limits:
requests.append(
Query(
@ -334,6 +334,30 @@ async def execute_sorts(database: DatabaseInstance, collections: Sequence[Collec
)
async def execute_sorts_spill(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
    """Run the spilling SORT calibration queries on the "large_sort" collections.

    Two queries are issued, tagged ``SORT_SIMPLE_SPILL`` and
    ``SORT_DEFAULT_SPILL``. Each one expects a SORT stage that reports
    ``usedDisk: True``.

    Args:
        database: Database handle passed through to ``workload_execution.execute``.
        collections: All known collections; only those whose name starts with
            "large_sort" are used.

    Raises:
        AssertionError: If the number of "large_sort" collections is not six.
    """
    spill_targets = [coll for coll in collections if coll.name.startswith("large_sort")]
    # Exactly six "large_sort" collections are expected.
    assert len(spill_targets) == 6

    # A standard sort applies the simple sort algorithm.
    simple_spill = Query(
        {"sort": {"payload": 1}},
        note="SORT_SIMPLE_SPILL",
        expected_stage={"SORT": {"usedDisk": True}},
    )
    # Including the recordId explicitly forces the use of the default sort algorithm.
    default_spill = Query(
        {"projection": {"$recordId": {"$meta": "recordId"}}, "sort": {"payload": 1}},
        note="SORT_DEFAULT_SPILL",
        expected_stage={"SORT": {"usedDisk": True}},
    )

    await workload_execution.execute(
        database,
        main_config.workload_execution,
        spill_targets,
        [simple_spill, default_spill],
    )
async def execute_merge_sorts(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
collections = [c for c in collections if c.name.startswith("merge_sort")]
assert len(collections) == 11
@ -650,6 +674,7 @@ async def main():
execute_limits,
execute_skips,
execute_sorts,
execute_sorts_spill,
execute_merge_sorts,
execute_ors,
execute_sort_intersections,

View File

@ -146,7 +146,7 @@ class WorkloadExecution:
return False
def _check_explain(self, expected_stage: Any, explain: dict) -> bool:
    """Check an explain document against ``expected_stage``.

    The root stage is taken from ``explain["executionStats"]["executionStages"]``
    and handed to ``_check_explain_helper`` together with ``expected_stage``.
    The earlier lookup of ``explain["queryPlanner"]["winningPlan"]`` was a dead
    store (its result was overwritten immediately) and has been removed, so an
    explain document without a ``queryPlanner`` section no longer raises.

    Args:
        expected_stage: Stage description to look for; its format is whatever
            ``_check_explain_helper`` accepts.
        explain: Explain output that must contain
            ``executionStats.executionStages``.

    Returns:
        The result of ``_check_explain_helper`` for the execution-stage root.

    Raises:
        KeyError: If ``explain`` has no ``executionStats.executionStages`` entry.
    """
    # NOTE(review): presumably the execution stages are used so that
    # runtime-only fields (e.g. "usedDisk") can be matched — confirm.
    root = explain["executionStats"]["executionStages"]
    return self._check_explain_helper(expected_stage, root)
async def _run_query(self, coll_info: CollectionInfo, query: Query, result: Sequence):