SERVER-120017 Initial calibration of HJ vs INLJ (#48386)

GitOrigin-RevId: 36c1de462f4d4ace7d674fba91209d4d1fb0dba2
This commit is contained in:
Max Verbinnen 2026-03-20 09:01:36 +00:00 committed by MongoDB Bot
parent 054f892670
commit eefbef199f
5 changed files with 219 additions and 9 deletions

View File

@ -99,6 +99,16 @@ class DatabaseInstance:
verbosity="executionStats",
)
async def explain_aggregate(
self, collection_name: str, pipeline: list, verbosity: str = "executionStats"
) -> dict[str, any]:
"""Return explain for the given aggregation pipeline."""
return await self.database.command(
"explain",
{"aggregate": collection_name, "pipeline": pipeline, "cursor": {}},
verbosity=verbosity,
)
async def hide_index(self, collection_name: str, index_name: str) -> None:
"""Hide the given index from the query optimizer."""
await self.database.command(

View File

@ -48,11 +48,13 @@ import config
from random_generator import DataType, RandomDistribution, RangeGenerator, StringRandomDistribution
COLLECTION_CARDINALITY = 100_000
STRING_FILLER_LENGTH = 900
DEFAULT_STRING_FILLER_LENGTH = 900
LARGE_STRING_FILLER_LENGTH = 3_000
def create_join_collection_template(
name: str,
string_filler_length: int = DEFAULT_STRING_FILLER_LENGTH,
create_indexes: bool = False,
) -> config.CollectionTemplate:
"""Create a collection template for join calibration."""
@ -102,7 +104,7 @@ def create_join_collection_template(
config.FieldTemplate(
name="string_filler",
data_type=DataType.STRING,
distribution=StringRandomDistribution(STRING_FILLER_LENGTH, pool_size=1000),
distribution=StringRandomDistribution(string_filler_length, pool_size=1000),
indexed=False,
),
],
@ -115,11 +117,20 @@ join_coll_1 = create_join_collection_template(
name="join_coll_1",
create_indexes=False,
)
join_coll_1_large = create_join_collection_template(
name="join_coll_1_large",
string_filler_length=LARGE_STRING_FILLER_LENGTH,
create_indexes=False,
)
join_coll_2 = create_join_collection_template(
name="join_coll_2",
create_indexes=True,
)
join_coll_2_large = create_join_collection_template(
name="join_coll_2_large",
string_filler_length=LARGE_STRING_FILLER_LENGTH,
create_indexes=True,
)
join_database = config.DatabaseConfig(
connection_string=os.getenv("MONGODB_URI", "mongodb://localhost"),
@ -133,7 +144,7 @@ join_data_generator = config.DataGeneratorConfig(
enabled=True,
create_indexes=True,
batch_size=COLLECTION_CARDINALITY,
collection_templates=[join_coll_1, join_coll_2],
collection_templates=[join_coll_1, join_coll_1_large, join_coll_2, join_coll_2_large],
write_mode=config.WriteMode.REPLACE,
collection_name_with_card=False,
)

View File

@ -29,8 +29,8 @@
Join Cost Model Calibration entry point.
Assumptions:
- The WiredTiger cache is large enough ( 100 MB) to contain the join
calibration collection. If the WT cache were too small, the CPU
- The WiredTiger cache is large enough ( 100 MB) by default to contain the
join calibration collection. If the WT cache were too small, the CPU
measurements would be invalid.
"""
@ -44,6 +44,7 @@ from join_calibration_settings import (
join_data_generator,
join_database,
)
from join_workload_execution import run_join_explain
from mongod_manager import MongodManager
from scipy.stats import trim_mean
@ -182,6 +183,124 @@ async def calibrate_random_io(manager: MongodManager, num_lookups: int = 100):
return fetch_mean_ms
async def calibrate_join_algorithms(manager: MongodManager, num_runs: int = 10):
"""
Compares INLJ vs HJ by running $lookup queries across join fields and selectivities.
The WT cache is set to 256 MB so neither ~300 MB "large" collection fits entirely
in memory, forcing relatively realistic disk I/O during joins.
For each (join_field, predicate) combination we run three rounds, each after a
cold restart so cache conditions are comparable:
1. Any (algorithm chosen by optimizer)
2. Force INLJ
3. Force HJ
High-cardinality join fields (uniform_256, uniform_16) produce very large result
sets, so we don't actually run those join queries because they would take too long.
We just verify that the optimizer picks HJ, which is the best algorithm for these.
Prints a per-combination summary showing which algorithm is genuinely faster
and whether the optimizer's majority pick is correct.
"""
print(f"\n=== Join Algorithm Calibration ({num_runs} runs per config) ===")
join_cache_args = ["--wiredTigerCacheSizeGB", "0.25"]
join_fields = [
# (name, low or high cardinality)
("unique", "low"),
("uniform_64k", "low"),
("uniform_4k", "low"),
("uniform_256", "high"),
("uniform_16", "high"),
]
predicate_constants = [256, 1024, 4096, 16384, 65536]
header = (
f"{'join_field':<16} {'pred':<8} {'optimizer_picks':<24} "
f"{'INLJ_ms':>10} {'HJ_ms':>10} {'faster':>8} {'correct':>10}"
)
separator = "-" * len(header)
print(header)
print(separator)
for join_field in join_fields:
for pred_const in predicate_constants:
pipeline = [
{"$match": {"random": {"$lte": pred_const}}},
{
"$lookup": {
"from": "join_coll_2_large",
"localField": join_field[0],
"foreignField": join_field[0],
"as": "joined",
}
},
{"$unwind": "$joined"},
{"$count": "total"},
]
high_join_cardinality = join_field[1] == "high"
verbosity = "queryPlanner" if high_join_cardinality else "executionStats"
# Using the algorithm which the optimizer picks
manager.restart_cold(extra_start_args=join_cache_args)
await manager.database.set_parameter("internalJoinMethod", "any")
algo_freqs: dict[str, int] = {}
for _ in range(num_runs):
_, algo = await run_join_explain(
manager.database, "join_coll_1_large", pipeline, verbosity
)
algo_freqs[algo] = algo_freqs.get(algo, 0) + 1
optimizer_picks_str = " ".join(
f"{algo} {freq}/{num_runs}"
for algo, freq in sorted(algo_freqs.items(), key=lambda x: -x[1])
)
majority_algo = max(algo_freqs, key=algo_freqs.get)
inlj_mean, hj_mean = None, None
if not high_join_cardinality:
# Forcing INLJ
manager.restart_cold(extra_start_args=join_cache_args)
await manager.database.set_parameter("internalJoinMethod", "INLJ")
inlj_times = []
for _ in range(num_runs):
t, algo = await run_join_explain(
manager.database, "join_coll_1_large", pipeline
)
assert algo == "INLJ", f"Expected INLJ but got {algo}"
inlj_times.append(t)
# Forcing HJ
manager.restart_cold(extra_start_args=join_cache_args)
await manager.database.set_parameter("internalJoinMethod", "HJ")
hj_times = []
for _ in range(num_runs):
t, algo = await run_join_explain(
manager.database, "join_coll_1_large", pipeline
)
assert algo == "HJ", f"Expected HJ but got {algo}"
hj_times.append(t)
inlj_mean = trim_mean(inlj_times, proportiontocut=TRIMMED_MEAN_PROPORTION)
hj_mean = trim_mean(hj_times, proportiontocut=TRIMMED_MEAN_PROPORTION)
# Print whether the optimizer is making the correct decision
faster = "HJ" if high_join_cardinality or hj_mean < inlj_mean else "INLJ"
correct = "" if majority_algo == faster else ""
def fmt_mean(v):
return f"{v:>10.1f}" if v is not None else f"{'-':>10}"
print(
f"{join_field[0]:<16} {pred_const:<8} {optimizer_picks_str:<24} "
f"{fmt_mean(inlj_mean)} {fmt_mean(hj_mean)} {faster:>8} {correct:>10}"
)
await manager.database.set_parameter("internalJoinMethod", "any")
print(separator)
async def main():
"""Entry point function."""
script_directory = os.path.abspath(os.path.dirname(__file__))
@ -200,6 +319,8 @@ async def main():
"all",
"--setParameter",
"internalMeasureQueryExecutionTimeInNanoseconds=true",
"--setParameter",
"internalEnableJoinOptimization=true",
],
) as manager:
generator = DataGenerator(manager.database, join_data_generator)
@ -222,6 +343,8 @@ async def main():
f" ({time_rand_page_ms:.4f}ms / {time_tuple_ms:.6f}ms)"
)
await calibrate_join_algorithms(manager)
print("DONE!")

View File

@ -0,0 +1,65 @@
# Copyright (C) 2026-present MongoDB, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the Server Side Public License, version 1,
# as published by MongoDB, Inc.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Server Side Public License for more details.
#
# You should have received a copy of the Server Side Public License
# along with this program. If not, see
# <http://www.mongodb.com/licensing/server-side-public-license>.
#
# As a special exception, the copyright holders give permission to link the
# code of portions of this program with the OpenSSL library under certain
# conditions as described in each individual source file and distribute
# linked combinations including the program with the OpenSSL library. You
# must comply with the Server Side Public License in all respects for
# all of the code used other than as permitted herein. If you modify file(s)
# with this exception, you may extend this exception to your version of the
# file(s), but you are not obligated to do so. If you do not wish to do so,
# delete this exception statement from your version. If you delete this
# exception statement from all source files in the program, then also delete
# it in the license file.
#
"""Join workload execution helpers for calibration."""
from __future__ import annotations
from database_instance import DatabaseInstance
def abbreviate_stage(stage_name: str) -> str:
"""Abbreviate join algorithm stage names."""
abbreviations = {
"HASH_JOIN_EMBEDDING": "HJ",
"NESTED_LOOP_JOIN_EMBEDDING": "NLJ",
"INDEXED_NESTED_LOOP_JOIN_EMBEDDING": "INLJ",
}
return abbreviations.get(stage_name, stage_name)
async def run_join_explain(
database: DatabaseInstance,
collection_name: str,
pipeline: list,
verbosity: str = "executionStats",
) -> tuple[float | None, str]:
"""
Run an aggregation explain and return (exec_time_ms, algorithm).
If verbosity is 'queryPlanner', exec_time_ms will be None.
"""
explain = await database.explain_aggregate(collection_name, pipeline, verbosity)
cursor = explain["stages"][0]["$cursor"]
exec_time_ms = (
cursor["executionStats"]["executionTimeMillis"] if verbosity == "executionStats" else None
)
query_plan = cursor["queryPlanner"]["winningPlan"]["queryPlan"]
algorithm = abbreviate_stage(query_plan["stage"])
assert query_plan["leftEmbeddingField"] == "none", f"Expected {collection_name} as outer table"
return exec_time_ms, algorithm

View File

@ -68,7 +68,7 @@ class MongodManager:
def is_running(self):
return self._proc is not None and self._proc.poll() is None
def start(self):
def start(self, extra_start_args=None):
"""Start the mongod process."""
if self.is_running:
raise RuntimeError("mongod is already running")
@ -81,6 +81,7 @@ class MongodManager:
"--dbpath",
self.dbpath,
*self.extra_args,
*(extra_start_args or []),
],
stdout=log_fh,
stderr=subprocess.STDOUT,
@ -103,11 +104,11 @@ class MongodManager:
self._proc = None
self._database = None
def restart_cold(self):
def restart_cold(self, extra_start_args=None):
"""Stop mongod, drop OS page cache, start fresh."""
self.stop()
self.flush_os_cache()
self.start()
self.start(extra_start_args=extra_start_args)
def flush_os_cache(self):
"""Flush the OS page cache."""