diff --git a/buildscripts/cost_model/database_instance.py b/buildscripts/cost_model/database_instance.py index 1e1a61aded5..05ee018a40b 100644 --- a/buildscripts/cost_model/database_instance.py +++ b/buildscripts/cost_model/database_instance.py @@ -99,6 +99,16 @@ class DatabaseInstance: verbosity="executionStats", ) + async def explain_aggregate( + self, collection_name: str, pipeline: list, verbosity: str = "executionStats" + ) -> dict[str, any]: + """Return explain for the given aggregation pipeline.""" + return await self.database.command( + "explain", + {"aggregate": collection_name, "pipeline": pipeline, "cursor": {}}, + verbosity=verbosity, + ) + async def hide_index(self, collection_name: str, index_name: str) -> None: """Hide the given index from the query optimizer.""" await self.database.command( diff --git a/buildscripts/cost_model/join_calibration_settings.py b/buildscripts/cost_model/join_calibration_settings.py index a771d619f78..af2d0421587 100644 --- a/buildscripts/cost_model/join_calibration_settings.py +++ b/buildscripts/cost_model/join_calibration_settings.py @@ -48,11 +48,13 @@ import config from random_generator import DataType, RandomDistribution, RangeGenerator, StringRandomDistribution COLLECTION_CARDINALITY = 100_000 -STRING_FILLER_LENGTH = 900 +DEFAULT_STRING_FILLER_LENGTH = 900 +LARGE_STRING_FILLER_LENGTH = 3_000 def create_join_collection_template( name: str, + string_filler_length: int = DEFAULT_STRING_FILLER_LENGTH, create_indexes: bool = False, ) -> config.CollectionTemplate: """Create a collection template for join calibration.""" @@ -102,7 +104,7 @@ def create_join_collection_template( config.FieldTemplate( name="string_filler", data_type=DataType.STRING, - distribution=StringRandomDistribution(STRING_FILLER_LENGTH, pool_size=1000), + distribution=StringRandomDistribution(string_filler_length, pool_size=1000), indexed=False, ), ], @@ -115,11 +117,20 @@ join_coll_1 = create_join_collection_template( name="join_coll_1", create_indexes=False, ) - +join_coll_1_large = create_join_collection_template( + name="join_coll_1_large", + string_filler_length=LARGE_STRING_FILLER_LENGTH, + create_indexes=False, +) join_coll_2 = create_join_collection_template( name="join_coll_2", create_indexes=True, ) +join_coll_2_large = create_join_collection_template( + name="join_coll_2_large", + string_filler_length=LARGE_STRING_FILLER_LENGTH, + create_indexes=True, +) join_database = config.DatabaseConfig( connection_string=os.getenv("MONGODB_URI", "mongodb://localhost"), @@ -133,7 +144,7 @@ join_data_generator = config.DataGeneratorConfig( enabled=True, create_indexes=True, batch_size=COLLECTION_CARDINALITY, - collection_templates=[join_coll_1, join_coll_2], + collection_templates=[join_coll_1, join_coll_1_large, join_coll_2, join_coll_2_large], write_mode=config.WriteMode.REPLACE, collection_name_with_card=False, ) diff --git a/buildscripts/cost_model/join_start.py b/buildscripts/cost_model/join_start.py index 182fc9a4f96..72950e23092 100644 --- a/buildscripts/cost_model/join_start.py +++ b/buildscripts/cost_model/join_start.py @@ -29,8 +29,8 @@ Join Cost Model Calibration entry point. Assumptions: -- The WiredTiger cache is large enough (≥ 100 MB) to contain the join - calibration collection. If the WT cache were too small, the CPU +- The WiredTiger cache is large enough (≥ 100 MB) by default to contain the + join calibration collection. If the WT cache were too small, the CPU measurements would be invalid. """ @@ -44,6 +44,7 @@ from join_calibration_settings import ( join_data_generator, join_database, ) +from join_workload_execution import run_join_explain from mongod_manager import MongodManager from scipy.stats import trim_mean @@ -182,6 +183,124 @@ async def calibrate_random_io(manager: MongodManager, num_lookups: int = 100): return fetch_mean_ms +async def calibrate_join_algorithms(manager: MongodManager, num_runs: int = 10): + """ + Compares INLJ vs HJ by running $lookup queries across join fields and selectivities. + + The WT cache is set to 256 MB so neither ~300 MB "large" collection fits entirely + in memory, forcing relatively realistic disk I/O during joins. + + For each (join_field, predicate) combination we run three rounds, each after a + cold restart so cache conditions are comparable: + 1. Any (algorithm chosen by optimizer) + 2. Force INLJ + 3. Force HJ + + High-cardinality join fields (uniform_256, uniform_16) produce very large result + sets, so we don't actually run those join queries because they would take too long. + We just verify that the optimizer picks HJ, which is the best algorithm for these. + + Prints a per-combination summary showing which algorithm is genuinely faster + and whether the optimizer's majority pick is correct. + """ + print(f"\n=== Join Algorithm Calibration ({num_runs} runs per config) ===") + + join_cache_args = ["--wiredTigerCacheSizeGB", "0.25"] + join_fields = [ + # (name, low or high cardinality) + ("unique", "low"), + ("uniform_64k", "low"), + ("uniform_4k", "low"), + ("uniform_256", "high"), + ("uniform_16", "high"), + ] + predicate_constants = [256, 1024, 4096, 16384, 65536] + + header = ( + f"{'join_field':<16} {'pred':<8} {'optimizer_picks':<24} " + f"{'INLJ_ms':>10} {'HJ_ms':>10} {'faster':>8} {'correct':>10}" + ) + separator = "-" * len(header) + print(header) + print(separator) + + for join_field in join_fields: + for pred_const in predicate_constants: + pipeline = [ + {"$match": {"random": {"$lte": pred_const}}}, + { + "$lookup": { + "from": "join_coll_2_large", + "localField": join_field[0], + "foreignField": join_field[0], + "as": "joined", + } + }, + {"$unwind": "$joined"}, + {"$count": "total"}, + ] + high_join_cardinality = join_field[1] == "high" + verbosity = "queryPlanner" if high_join_cardinality else "executionStats" + + # Using the algorithm which the optimizer picks + manager.restart_cold(extra_start_args=join_cache_args) + await manager.database.set_parameter("internalJoinMethod", "any") + algo_freqs: dict[str, int] = {} + for _ in range(num_runs): + _, algo = await run_join_explain( + manager.database, "join_coll_1_large", pipeline, verbosity + ) + algo_freqs[algo] = algo_freqs.get(algo, 0) + 1 + + optimizer_picks_str = " ".join( + f"{algo} {freq}/{num_runs}" + for algo, freq in sorted(algo_freqs.items(), key=lambda x: -x[1]) + ) + majority_algo = max(algo_freqs, key=algo_freqs.get) + + inlj_mean, hj_mean = None, None + if not high_join_cardinality: + # Forcing INLJ + manager.restart_cold(extra_start_args=join_cache_args) + await manager.database.set_parameter("internalJoinMethod", "INLJ") + inlj_times = [] + for _ in range(num_runs): + t, algo = await run_join_explain( + manager.database, "join_coll_1_large", pipeline + ) + assert algo == "INLJ", f"Expected INLJ but got {algo}" + inlj_times.append(t) + + # Forcing HJ + manager.restart_cold(extra_start_args=join_cache_args) + await manager.database.set_parameter("internalJoinMethod", "HJ") + hj_times = [] + for _ in range(num_runs): + t, algo = await run_join_explain( + manager.database, "join_coll_1_large", pipeline + ) + assert algo == "HJ", f"Expected HJ but got {algo}" + hj_times.append(t) + + inlj_mean = trim_mean(inlj_times, proportiontocut=TRIMMED_MEAN_PROPORTION) + hj_mean = trim_mean(hj_times, proportiontocut=TRIMMED_MEAN_PROPORTION) + + # Print whether the optimizer is making the correct decision + faster = "HJ" if high_join_cardinality or hj_mean < inlj_mean else "INLJ" + correct = "✓" if majority_algo == faster else "✗" + + def fmt_mean(v): + return f"{v:>10.1f}" if v is not None else f"{'-':>10}" + + print( + f"{join_field[0]:<16} {pred_const:<8} {optimizer_picks_str:<24} " + f"{fmt_mean(inlj_mean)} {fmt_mean(hj_mean)} {faster:>8} {correct:>10}" + ) + + await manager.database.set_parameter("internalJoinMethod", "any") + print(separator) + + async def main(): """Entry point function.""" script_directory = os.path.abspath(os.path.dirname(__file__)) @@ -200,6 +319,8 @@ async def main(): "all", "--setParameter", "internalMeasureQueryExecutionTimeInNanoseconds=true", + "--setParameter", + "internalEnableJoinOptimization=true", ], ) as manager: generator = DataGenerator(manager.database, join_data_generator) @@ -222,6 +343,8 @@ async def main(): f" ({time_rand_page_ms:.4f}ms / {time_tuple_ms:.6f}ms)" ) + await calibrate_join_algorithms(manager) + print("DONE!") diff --git a/buildscripts/cost_model/join_workload_execution.py b/buildscripts/cost_model/join_workload_execution.py new file mode 100644 index 00000000000..ad7117c7b96 --- /dev/null +++ b/buildscripts/cost_model/join_workload_execution.py @@ -0,0 +1,65 @@ +# Copyright (C) 2026-present MongoDB, Inc. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the Server Side Public License, version 1, +# as published by MongoDB, Inc. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# Server Side Public License for more details. +# +# You should have received a copy of the Server Side Public License +# along with this program. If not, see +# . +# +# As a special exception, the copyright holders give permission to link the +# code of portions of this program with the OpenSSL library under certain +# conditions as described in each individual source file and distribute +# linked combinations including the program with the OpenSSL library. You +# must comply with the Server Side Public License in all respects for +# all of the code used other than as permitted herein. If you modify file(s) +# with this exception, you may extend this exception to your version of the +# file(s), but you are not obligated to do so. If you do not wish to do so, +# delete this exception statement from your version. If you delete this +# exception statement from all source files in the program, then also delete +# it in the license file. +# +"""Join workload execution helpers for calibration.""" + +from __future__ import annotations + +from database_instance import DatabaseInstance + + +def abbreviate_stage(stage_name: str) -> str: + """Abbreviate join algorithm stage names.""" + abbreviations = { + "HASH_JOIN_EMBEDDING": "HJ", + "NESTED_LOOP_JOIN_EMBEDDING": "NLJ", + "INDEXED_NESTED_LOOP_JOIN_EMBEDDING": "INLJ", + } + return abbreviations.get(stage_name, stage_name) + + +async def run_join_explain( + database: DatabaseInstance, + collection_name: str, + pipeline: list, + verbosity: str = "executionStats", +) -> tuple[float | None, str]: + """ + Run an aggregation explain and return (exec_time_ms, algorithm). + If verbosity is 'queryPlanner', exec_time_ms will be None. + """ + explain = await database.explain_aggregate(collection_name, pipeline, verbosity) + cursor = explain["stages"][0]["$cursor"] + exec_time_ms = ( + cursor["executionStats"]["executionTimeMillis"] if verbosity == "executionStats" else None + ) + query_plan = cursor["queryPlanner"]["winningPlan"]["queryPlan"] + algorithm = abbreviate_stage(query_plan["stage"]) + + assert query_plan["leftEmbeddingField"] == "none", f"Expected {collection_name} as outer table" + + return exec_time_ms, algorithm diff --git a/buildscripts/cost_model/mongod_manager.py b/buildscripts/cost_model/mongod_manager.py index b5c7964c670..3cc97555fcb 100644 --- a/buildscripts/cost_model/mongod_manager.py +++ b/buildscripts/cost_model/mongod_manager.py @@ -68,7 +68,7 @@ class MongodManager: def is_running(self): return self._proc is not None and self._proc.poll() is None - def start(self): + def start(self, extra_start_args=None): """Start the mongod process.""" if self.is_running: raise RuntimeError("mongod is already running") @@ -81,6 +81,7 @@ class MongodManager: "--dbpath", self.dbpath, *self.extra_args, + *(extra_start_args or []), ], stdout=log_fh, stderr=subprocess.STDOUT, @@ -103,11 +104,11 @@ class MongodManager: self._proc = None self._database = None - def restart_cold(self): + def restart_cold(self, extra_start_args=None): """Stop mongod, drop OS page cache, start fresh.""" self.stop() self.flush_os_cache() - self.start() + self.start(extra_start_args=extra_start_args) def flush_os_cache(self): """Flush the OS page cache."""