SERVER-120017 Initial calibration of HJ vs INLJ (#48386)
GitOrigin-RevId: 36c1de462f4d4ace7d674fba91209d4d1fb0dba2
This commit is contained in:
parent
054f892670
commit
eefbef199f
@ -99,6 +99,16 @@ class DatabaseInstance:
|
||||
verbosity="executionStats",
|
||||
)
|
||||
|
||||
async def explain_aggregate(
    self, collection_name: str, pipeline: list, verbosity: str = "executionStats"
) -> dict[str, any]:
    """Explain an aggregation pipeline and hand back the server's reply.

    The pipeline is wrapped in an ``aggregate`` command document with an empty
    ``cursor`` option and sent through ``self.database.command`` as the target
    of ``explain``, using the requested verbosity.
    """
    aggregate_command = {
        "aggregate": collection_name,
        "pipeline": pipeline,
        "cursor": {},
    }
    explain_output = await self.database.command(
        "explain",
        aggregate_command,
        verbosity=verbosity,
    )
    return explain_output
|
||||
|
||||
async def hide_index(self, collection_name: str, index_name: str) -> None:
|
||||
"""Hide the given index from the query optimizer."""
|
||||
await self.database.command(
|
||||
|
||||
@ -48,11 +48,13 @@ import config
|
||||
from random_generator import DataType, RandomDistribution, RangeGenerator, StringRandomDistribution
|
||||
|
||||
COLLECTION_CARDINALITY = 100_000
|
||||
STRING_FILLER_LENGTH = 900
|
||||
DEFAULT_STRING_FILLER_LENGTH = 900
|
||||
LARGE_STRING_FILLER_LENGTH = 3_000
|
||||
|
||||
|
||||
def create_join_collection_template(
|
||||
name: str,
|
||||
string_filler_length: int = DEFAULT_STRING_FILLER_LENGTH,
|
||||
create_indexes: bool = False,
|
||||
) -> config.CollectionTemplate:
|
||||
"""Create a collection template for join calibration."""
|
||||
@ -102,7 +104,7 @@ def create_join_collection_template(
|
||||
config.FieldTemplate(
|
||||
name="string_filler",
|
||||
data_type=DataType.STRING,
|
||||
distribution=StringRandomDistribution(STRING_FILLER_LENGTH, pool_size=1000),
|
||||
distribution=StringRandomDistribution(string_filler_length, pool_size=1000),
|
||||
indexed=False,
|
||||
),
|
||||
],
|
||||
@ -115,11 +117,20 @@ join_coll_1 = create_join_collection_template(
|
||||
name="join_coll_1",
|
||||
create_indexes=False,
|
||||
)
|
||||
|
||||
join_coll_1_large = create_join_collection_template(
|
||||
name="join_coll_1_large",
|
||||
string_filler_length=LARGE_STRING_FILLER_LENGTH,
|
||||
create_indexes=False,
|
||||
)
|
||||
join_coll_2 = create_join_collection_template(
|
||||
name="join_coll_2",
|
||||
create_indexes=True,
|
||||
)
|
||||
join_coll_2_large = create_join_collection_template(
|
||||
name="join_coll_2_large",
|
||||
string_filler_length=LARGE_STRING_FILLER_LENGTH,
|
||||
create_indexes=True,
|
||||
)
|
||||
|
||||
join_database = config.DatabaseConfig(
|
||||
connection_string=os.getenv("MONGODB_URI", "mongodb://localhost"),
|
||||
@ -133,7 +144,7 @@ join_data_generator = config.DataGeneratorConfig(
|
||||
enabled=True,
|
||||
create_indexes=True,
|
||||
batch_size=COLLECTION_CARDINALITY,
|
||||
collection_templates=[join_coll_1, join_coll_2],
|
||||
collection_templates=[join_coll_1, join_coll_1_large, join_coll_2, join_coll_2_large],
|
||||
write_mode=config.WriteMode.REPLACE,
|
||||
collection_name_with_card=False,
|
||||
)
|
||||
|
||||
@ -29,8 +29,8 @@
|
||||
Join Cost Model Calibration entry point.
|
||||
|
||||
Assumptions:
|
||||
- The WiredTiger cache is large enough (≥ 100 MB) to contain the join
|
||||
calibration collection. If the WT cache were too small, the CPU
|
||||
- The WiredTiger cache is large enough (≥ 100 MB) by default to contain the
|
||||
join calibration collection. If the WT cache were too small, the CPU
|
||||
measurements would be invalid.
|
||||
"""
|
||||
|
||||
@ -44,6 +44,7 @@ from join_calibration_settings import (
|
||||
join_data_generator,
|
||||
join_database,
|
||||
)
|
||||
from join_workload_execution import run_join_explain
|
||||
from mongod_manager import MongodManager
|
||||
from scipy.stats import trim_mean
|
||||
|
||||
@ -182,6 +183,124 @@ async def calibrate_random_io(manager: MongodManager, num_lookups: int = 100):
|
||||
return fetch_mean_ms
|
||||
|
||||
|
||||
async def _time_forced_join(manager, join_method, pipeline, num_runs, start_args):
    """Cold-restart mongod, force `join_method`, and collect `num_runs` execution times.

    Every run is asserted to have actually used the forced algorithm.
    """
    manager.restart_cold(extra_start_args=start_args)
    await manager.database.set_parameter("internalJoinMethod", join_method)
    times = []
    for _ in range(num_runs):
        t, algo = await run_join_explain(manager.database, "join_coll_1_large", pipeline)
        assert algo == join_method, f"Expected {join_method} but got {algo}"
        times.append(t)
    return times


def _format_mean(value):
    """Right-align a mean in a 10-wide column; print '-' when nothing was timed."""
    return f"{value:>10.1f}" if value is not None else f"{'-':>10}"


async def calibrate_join_algorithms(manager: MongodManager, num_runs: int = 10):
    """
    Compares INLJ vs HJ by running $lookup queries across join fields and selectivities.

    mongod is restarted with --wiredTigerCacheSizeGB 0.25 (256 MB) so neither ~300 MB
    "large" collection fits entirely in memory, forcing relatively realistic disk I/O
    during joins.

    For each (join_field, predicate) combination we run three rounds, each after a
    cold restart so cache conditions are comparable:
      1. Any (algorithm chosen by optimizer)
      2. Force INLJ
      3. Force HJ

    Join fields tagged "high" (uniform_256, uniform_16) produce very large join
    results, so those queries are never executed: they are explained with
    "queryPlanner" verbosity only, the forced rounds are skipped, and HJ is taken
    to be the faster algorithm.

    Prints a per-combination summary showing which algorithm is genuinely faster
    and whether the optimizer's majority pick is correct.

    Args:
        manager: running MongodManager; it is cold-restarted repeatedly.
        num_runs: explain runs per round; must be at least 1.

    Raises:
        ValueError: if num_runs is smaller than 1.
        AssertionError: if a forced round does not use the forced algorithm.
    """
    # Fail before any restart; with zero runs there is no majority pick to report.
    if num_runs < 1:
        raise ValueError("num_runs must be at least 1")

    print(f"\n=== Join Algorithm Calibration ({num_runs} runs per config) ===")

    join_cache_args = ["--wiredTigerCacheSizeGB", "0.25"]
    join_fields = [
        # (name, low or high join-result cardinality)
        ("unique", "low"),
        ("uniform_64k", "low"),
        ("uniform_4k", "low"),
        ("uniform_256", "high"),
        ("uniform_16", "high"),
    ]
    predicate_constants = [256, 1024, 4096, 16384, 65536]

    header = (
        f"{'join_field':<16} {'pred':<8} {'optimizer_picks':<24} "
        f"{'INLJ_ms':>10} {'HJ_ms':>10} {'faster':>8} {'correct':>10}"
    )
    separator = "-" * len(header)
    print(header)
    print(separator)

    for field_name, result_cardinality in join_fields:
        high_join_cardinality = result_cardinality == "high"
        # Planner-only explain avoids executing joins with huge result sets.
        verbosity = "queryPlanner" if high_join_cardinality else "executionStats"

        for pred_const in predicate_constants:
            pipeline = [
                {"$match": {"random": {"$lte": pred_const}}},
                {
                    "$lookup": {
                        "from": "join_coll_2_large",
                        "localField": field_name,
                        "foreignField": field_name,
                        "as": "joined",
                    }
                },
                {"$unwind": "$joined"},
                {"$count": "total"},
            ]

            # Using the algorithm which the optimizer picks
            manager.restart_cold(extra_start_args=join_cache_args)
            await manager.database.set_parameter("internalJoinMethod", "any")
            algo_freqs: dict[str, int] = {}
            for _ in range(num_runs):
                _, algo = await run_join_explain(
                    manager.database, "join_coll_1_large", pipeline, verbosity
                )
                algo_freqs[algo] = algo_freqs.get(algo, 0) + 1

            optimizer_picks_str = " ".join(
                f"{algo} {freq}/{num_runs}"
                for algo, freq in sorted(algo_freqs.items(), key=lambda x: -x[1])
            )
            majority_algo = max(algo_freqs, key=algo_freqs.get)

            inlj_mean, hj_mean = None, None
            if not high_join_cardinality:
                # Forced rounds run INLJ first, then HJ, each after its own cold restart.
                inlj_times = await _time_forced_join(
                    manager, "INLJ", pipeline, num_runs, join_cache_args
                )
                hj_times = await _time_forced_join(
                    manager, "HJ", pipeline, num_runs, join_cache_args
                )
                inlj_mean = trim_mean(inlj_times, proportiontocut=TRIMMED_MEAN_PROPORTION)
                hj_mean = trim_mean(hj_times, proportiontocut=TRIMMED_MEAN_PROPORTION)

            # Print whether the optimizer is making the correct decision
            faster = "HJ" if high_join_cardinality or hj_mean < inlj_mean else "INLJ"
            correct = "✓" if majority_algo == faster else "✗"

            print(
                f"{field_name:<16} {pred_const:<8} {optimizer_picks_str:<24} "
                f"{_format_mean(inlj_mean)} {_format_mean(hj_mean)} {faster:>8} {correct:>10}"
            )

    # Leave the server choosing its own join algorithm again.
    await manager.database.set_parameter("internalJoinMethod", "any")
    print(separator)
|
||||
|
||||
|
||||
async def main():
|
||||
"""Entry point function."""
|
||||
script_directory = os.path.abspath(os.path.dirname(__file__))
|
||||
@ -200,6 +319,8 @@ async def main():
|
||||
"all",
|
||||
"--setParameter",
|
||||
"internalMeasureQueryExecutionTimeInNanoseconds=true",
|
||||
"--setParameter",
|
||||
"internalEnableJoinOptimization=true",
|
||||
],
|
||||
) as manager:
|
||||
generator = DataGenerator(manager.database, join_data_generator)
|
||||
@ -222,6 +343,8 @@ async def main():
|
||||
f" ({time_rand_page_ms:.4f}ms / {time_tuple_ms:.6f}ms)"
|
||||
)
|
||||
|
||||
await calibrate_join_algorithms(manager)
|
||||
|
||||
print("DONE!")
|
||||
|
||||
|
||||
|
||||
65
buildscripts/cost_model/join_workload_execution.py
Normal file
65
buildscripts/cost_model/join_workload_execution.py
Normal file
@ -0,0 +1,65 @@
|
||||
# Copyright (C) 2026-present MongoDB, Inc.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the Server Side Public License, version 1,
|
||||
# as published by MongoDB, Inc.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# Server Side Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the Server Side Public License
|
||||
# along with this program. If not, see
|
||||
# <http://www.mongodb.com/licensing/server-side-public-license>.
|
||||
#
|
||||
# As a special exception, the copyright holders give permission to link the
|
||||
# code of portions of this program with the OpenSSL library under certain
|
||||
# conditions as described in each individual source file and distribute
|
||||
# linked combinations including the program with the OpenSSL library. You
|
||||
# must comply with the Server Side Public License in all respects for
|
||||
# all of the code used other than as permitted herein. If you modify file(s)
|
||||
# with this exception, you may extend this exception to your version of the
|
||||
# file(s), but you are not obligated to do so. If you do not wish to do so,
|
||||
# delete this exception statement from your version. If you delete this
|
||||
# exception statement from all source files in the program, then also delete
|
||||
# it in the license file.
|
||||
#
|
||||
"""Join workload execution helpers for calibration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from database_instance import DatabaseInstance
|
||||
|
||||
|
||||
def abbreviate_stage(stage_name: str) -> str:
    """Return the short label for a join stage name.

    The three known join embedding stages map to "HJ", "NLJ" and "INLJ";
    any other stage name is returned unchanged.
    """
    short_labels = {
        "HASH_JOIN_EMBEDDING": "HJ",
        "NESTED_LOOP_JOIN_EMBEDDING": "NLJ",
        "INDEXED_NESTED_LOOP_JOIN_EMBEDDING": "INLJ",
    }
    try:
        return short_labels[stage_name]
    except KeyError:
        # Not a join stage we abbreviate: pass the name through as-is.
        return stage_name
|
||||
|
||||
|
||||
async def run_join_explain(
    database: DatabaseInstance,
    collection_name: str,
    pipeline: list,
    verbosity: str = "executionStats",
) -> tuple[float | None, str]:
    """
    Run an aggregation explain and return (exec_time_ms, algorithm).

    The plan is read from the first stage's "$cursor" section of the explain
    output. exec_time_ms is that section's executionStats.executionTimeMillis;
    it is None only when verbosity is 'queryPlanner', because every other
    verbosity is expected to carry an executionStats section. algorithm is the
    abbreviated stage name of the winning query plan.

    Raises:
        AssertionError: if the winning plan's leftEmbeddingField is not "none",
            i.e. collection_name is not the outer table of the join.
    """
    explain = await database.explain_aggregate(collection_name, pipeline, verbosity)
    cursor = explain["stages"][0]["$cursor"]
    # Only planner-only explains lack execution statistics.
    exec_time_ms = (
        None if verbosity == "queryPlanner" else cursor["executionStats"]["executionTimeMillis"]
    )
    query_plan = cursor["queryPlanner"]["winningPlan"]["queryPlan"]
    algorithm = abbreviate_stage(query_plan["stage"])

    assert query_plan["leftEmbeddingField"] == "none", f"Expected {collection_name} as outer table"

    return exec_time_ms, algorithm
|
||||
@ -68,7 +68,7 @@ class MongodManager:
|
||||
def is_running(self):
|
||||
return self._proc is not None and self._proc.poll() is None
|
||||
|
||||
def start(self):
|
||||
def start(self, extra_start_args=None):
|
||||
"""Start the mongod process."""
|
||||
if self.is_running:
|
||||
raise RuntimeError("mongod is already running")
|
||||
@ -81,6 +81,7 @@ class MongodManager:
|
||||
"--dbpath",
|
||||
self.dbpath,
|
||||
*self.extra_args,
|
||||
*(extra_start_args or []),
|
||||
],
|
||||
stdout=log_fh,
|
||||
stderr=subprocess.STDOUT,
|
||||
@ -103,11 +104,11 @@ class MongodManager:
|
||||
self._proc = None
|
||||
self._database = None
|
||||
|
||||
def restart_cold(self, extra_start_args=None):
    """Stop mongod, drop OS page cache, start fresh.

    Args:
        extra_start_args: optional extra arguments, passed through unchanged
            to ``start`` for this one launch.
    """
    self.stop()
    self.flush_os_cache()
    self.start(extra_start_args=extra_start_args)
|
||||
|
||||
def flush_os_cache(self):
|
||||
"""Flush the OS page cache."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user