SERVER-106251 Modify cost model calibration code to be able to parse classic execution trees (#37438)

GitOrigin-RevId: 2829d57a40eb743ce17d933830688505fcea21f7
This commit is contained in:
Asher Kornfeld 2025-06-20 16:25:03 -04:00 committed by MongoDB Bot
parent a8e86bea4d
commit d8beb34ebc
9 changed files with 806 additions and 666 deletions

View File

@ -35,7 +35,7 @@ from dataclasses import asdict, dataclass
from typing import Sequence
import bson.json_util as json
import execution_tree
import execution_tree_sbe
import physical_tree
from config import BenchmarkConfig
from database_instance import DatabaseInstance, Pipeline, get_database_parameter
@ -154,7 +154,7 @@ class ExperimentResult:
explain: Sequence[dict[str, any]]
physical_tree: Sequence[physical_tree.Node]
execution_tree: Sequence[execution_tree.Node]
execution_tree: Sequence[execution_tree_sbe.Node]
mean: float
def print(self, index: int = None):
@ -200,7 +200,7 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B
def make_variant(explain: Sequence[dict[str, any]]) -> ExperimentResult:
"""Make one variant of the A/B test."""
pt = [physical_tree.build(e["queryPlanner"]["winningPlan"]["queryPlan"]) for e in explain]
et = [execution_tree.build_execution_tree(e["executionStats"]) for e in explain]
et = [execution_tree_sbe.build_execution_tree(e["executionStats"]) for e in explain]
mean = sum(et.total_execution_time for et in et) / len(et)
return ExperimentResult(explain=explain, physical_tree=pt, execution_tree=et, mean=mean)

View File

@ -30,6 +30,20 @@
import functools
import time
"""
Returns a preorder traversal (node, child 1...child n) of the QSN/execution tree
For example, a query with an OR over 3 indices could turn into a tree rooted with a FETCH,
who has a single OR child, which in turn has 3 index scan children.
This would return a preorder of [FETCH, OR, IXSCAN1, IXSCAN2, IXSCAN3].
"""
def preorder(node):
res = [node]
for child in node.children:
res += preorder(child)
return res
def timer_decorator(func):
"""Measure and print out execution time of the decorated function."""

View File

@ -39,7 +39,7 @@ import os
from typing import Callable, Sequence, Tuple
import config
import execution_tree as et
import execution_tree_sbe as et
import experiment as exp
import numpy as np
import pandas as pd

View File

@ -0,0 +1,148 @@
# Copyright (C) 2025-present MongoDB, Inc.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the Server Side Public License, version 1,
# as published by MongoDB, Inc.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# Server Side Public License for more details.
#
# You should have received a copy of the Server Side Public License
# along with this program. If not, see
# <http://www.mongodb.com/licensing/server-side-public-license>.
#
# As a special exception, the copyright holders give permission to link the
# code of portions of this program with the OpenSSL library under certain
# conditions as described in each individual source file and distribute
# linked combinations including the program with the OpenSSL library. You
# must comply with the Server Side Public License in all respects for
# all of the code used other than as permitted herein. If you modify file(s)
# with this exception, you may extend this exception to your version of the
# file(s), but you are not obligated to do so. If you do not wish to do so,
# delete this exception statement from your version. If you delete this
# exception statement from all source files in the program, then also delete
# it in the license file.
#
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Optional
import bson.json_util as json
@dataclass
class Node:
"""Represent Classic Execution Node"""
stage: str
execution_time_nanoseconds: int
n_returned: int
n_processed: int
seeks: Optional[int]
children: list[Node]
def get_execution_time(self):
"""Execution time of this node without execution time of its children"""
return self.execution_time_nanoseconds - sum(
n.execution_time_nanoseconds for n in self.children
)
def print(self, level=0):
"""Pretty print the execution tree"""
print(
f'{"| " * level}{self.stage}, totalExecutionTime: {self.execution_time_nanoseconds:,}ns, seeks: {self.seeks}, nReturned: {self.n_returned}, nProcessed: {self.n_processed}'
)
for child in self.children:
child.print(level + 1)
def build_execution_tree(execution_stats: dict[str, Any]) -> Node:
"""Build Classic execution tree from 'executionStats' field of query explain"""
assert execution_stats["executionSuccess"]
return process_stage(execution_stats["executionStages"])
def process_stage(stage: dict[str, Any]) -> Node:
"""Parse the given execution stage"""
processors = {
"SUBPLAN": process_subplan,
"COLLSCAN": process_collscan,
"IXSCAN": process_ixscan,
"FETCH": process_fetch,
"AND_HASH": process_intersection,
"AND_SORTED": process_intersection,
"OR": process_or,
"MERGE_SORT": process_mergesort,
"SORT_MERGE": process_mergesort,
"SORT": process_sort,
"LIMIT": process_limitskip,
"SKIP": process_limitskip,
}
processor = processors.get(stage["stage"])
if processor is None:
print(json.dumps(stage, indent=4))
raise ValueError(f"Unknown stage: {stage}")
return processor(stage)
def process_subplan(stage: dict[str, Any]) -> Node:
input_stage = process_stage(stage["inputStage"])
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[input_stage])
def process_collscan(stage: dict[str, Any]) -> Node:
return Node(**get_common_fields(stage), n_processed=stage["docsExamined"], children=[])
def process_ixscan(stage: dict[str, Any]) -> Node:
return Node(**get_common_fields(stage), n_processed=stage["keysExamined"], children=[])
def process_fetch(stage: dict[str, Any]) -> Node:
input_stage = process_stage(stage["inputStage"])
return Node(
**get_common_fields(stage), n_processed=stage["docsExamined"], children=[input_stage]
)
def process_or(stage: dict[str, Any]) -> Node:
children = [process_stage(child) for child in stage["inputStages"]]
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
def process_intersection(stage: dict[str, Any]) -> Node:
children = [process_stage(child) for child in stage["inputStages"]]
n_processed = sum(child.n_processed for child in children)
return Node(**get_common_fields(stage), n_processed=n_processed, children=children)
def process_mergesort(stage: dict[str, Any]) -> Node:
children = [process_stage(child) for child in stage["inputStages"]]
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
def process_sort(stage: dict[str, Any]) -> Node:
input_stage = process_stage(stage["inputStage"])
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[input_stage])
def process_limitskip(stage: dict[str, Any]) -> Node:
input_stage = process_stage(stage["inputStage"])
return Node(
**get_common_fields(stage), n_processed=input_stage.n_processed, children=[input_stage]
)
def get_common_fields(json_stage: dict[str, Any]) -> dict[str, Any]:
"""Extract common fields from classic nodes"""
return {
"stage": json_stage["stage"],
"execution_time_nanoseconds": json_stage["executionTimeNanos"],
"n_returned": json_stage["nReturned"],
"seeks": json_stage.get("seeks"),
}

View File

@ -30,7 +30,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional
import bson.json_util as json
@ -50,7 +50,7 @@ class Node:
children: list[Node]
def get_execution_time(self):
"""Execution time of the SBE node without execuion time of its children."""
"""Execution time of the SBE node without execution time of its children."""
return self.total_execution_time - sum(n.total_execution_time for n in self.children)
def print(self, level=0):
@ -62,13 +62,13 @@ class Node:
child.print(level + 1)
def build_execution_tree(execution_stats: dict[str, any]) -> Node:
"""Build SBE executioon tree from 'executionStats' field of query explain."""
def build_execution_tree(execution_stats: dict[str, Any]) -> Node:
"""Build SBE execution tree from 'executionStats' field of query explain."""
assert execution_stats["executionSuccess"]
return process_stage(execution_stats["executionStages"])
def process_stage(stage: dict[str, any]) -> Node:
def process_stage(stage: dict[str, Any]) -> Node:
"""Parse the given SBE stage."""
processors = {
"filter": process_filter,
@ -100,13 +100,13 @@ def process_stage(stage: dict[str, any]) -> Node:
return processor(stage)
def process_filter(stage: dict[str, any]) -> Node:
def process_filter(stage: dict[str, Any]) -> Node:
"""Process filter stage."""
input_stage = process_stage(stage["inputStage"])
return Node(**get_common_fields(stage), n_processed=stage["numTested"], children=[input_stage])
def process_traverse(stage: dict[str, any]) -> Node:
def process_traverse(stage: dict[str, Any]) -> Node:
"""Process traverse"""
outer_stage = process_stage(stage["outerStage"])
inner_stage = process_stage(stage["innerStage"])
@ -117,7 +117,7 @@ def process_traverse(stage: dict[str, any]) -> Node:
)
def process_hash_join_node(stage: dict[str, any]) -> Node:
def process_hash_join_node(stage: dict[str, Any]) -> Node:
"""Process hj node."""
outer_stage = process_stage(stage["outerStage"])
inner_stage = process_stage(stage["innerStage"])
@ -127,7 +127,7 @@ def process_hash_join_node(stage: dict[str, any]) -> Node:
)
def process_nlj(stage: dict[str, any]) -> Node:
def process_nlj(stage: dict[str, Any]) -> Node:
"""Process nlj stage."""
outer_stage = process_stage(stage["outerStage"])
inner_stage = process_stage(stage["innerStage"])
@ -137,7 +137,7 @@ def process_nlj(stage: dict[str, any]) -> Node:
)
def process_inner_node(stage: dict[str, any]) -> Node:
def process_inner_node(stage: dict[str, Any]) -> Node:
"""Process SBE stage with one input stage."""
input_stage = process_stage(stage["inputStage"])
return Node(
@ -145,23 +145,23 @@ def process_inner_node(stage: dict[str, any]) -> Node:
)
def process_leaf_node(stage: dict[str, any]) -> Node:
def process_leaf_node(stage: dict[str, Any]) -> Node:
"""Process SBE stage without input stages."""
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[])
def process_seek(stage: dict[str, any]) -> Node:
def process_seek(stage: dict[str, Any]) -> Node:
"""Process seek stage."""
return Node(**get_common_fields(stage), n_processed=stage["numReads"], children=[])
def process_union_node(stage: dict[str, any]) -> Node:
def process_union_node(stage: dict[str, Any]) -> Node:
"""Process union stage."""
children = [process_stage(child) for child in stage["inputStages"]]
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
def process_unwind_node(stage: dict[str, any]) -> Node:
def process_unwind_node(stage: dict[str, Any]) -> Node:
"""Process unwind stage."""
input_stage = process_stage(stage["inputStage"])
return Node(
@ -169,14 +169,14 @@ def process_unwind_node(stage: dict[str, any]) -> Node:
)
def process_unique_node(stage: dict[str, any]) -> Node:
def process_unique_node(stage: dict[str, Any]) -> Node:
"""Process unique stage."""
input_stage = process_stage(stage["inputStage"])
n_processed = stage["dupsTested"]
return Node(**get_common_fields(stage), n_processed=n_processed, children=[input_stage])
def process_branch_node(stage: dict[str, any]) -> Node:
def process_branch_node(stage: dict[str, Any]) -> Node:
"""Process unique stage."""
then_stage = process_stage(stage["thenStage"])
else_stage = process_stage(stage["elseStage"])
@ -186,8 +186,8 @@ def process_branch_node(stage: dict[str, any]) -> Node:
)
def get_common_fields(json_stage: dict[str, any]) -> dict[str, any]:
"""Exctract common field from json representation of SBE stage."""
def get_common_fields(json_stage: dict[str, Any]) -> dict[str, Any]:
"""Extract common field from json representation of SBE stage."""
return {
"stage": json_stage["stage"],
"plan_node_id": json_stage["planNodeId"],

View File

@ -98,7 +98,7 @@ from __future__ import annotations
import dataclasses
import bson.json_util as json
import execution_tree as sbe
import execution_tree_sbe as sbe
import pandas as pd
import physical_tree as abt
import seaborn as sns

View File

@ -33,7 +33,7 @@ from collections import defaultdict, deque
from typing import Callable, Mapping, Sequence, TypeVar
import bson.json_util as json
import execution_tree
import execution_tree_sbe
import physical_tree
from config import AbtCalibratorConfig
from cost_estimator import CostModelParameters, ExecutionStats
@ -87,9 +87,9 @@ def find_nodes(root: Node, predicate: Callable[[Node], bool]) -> list[Node]:
return result
def get_excution_stats(root: execution_tree.Node, node_id: int) -> ExecutionStats:
def get_excution_stats(root: execution_tree_sbe.Node, node_id: int) -> ExecutionStats:
"""Extract execution stats from the given Execution Tree for the ABT node defined with the given node_id."""
queue: deque[execution_tree.Node] = deque()
queue: deque[execution_tree_sbe.Node] = deque()
queue.append(root)
execution_time: int = 0
@ -120,7 +120,7 @@ def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]):
"""Extract ExecutionStats from the given explain for the given ABT types."""
try:
et = execution_tree.build_execution_tree(explain["executionStats"])
et = execution_tree_sbe.build_execution_tree(explain["executionStats"])
pt = physical_tree.build(explain["queryPlanner"]["winningPlan"]["queryPlan"])
except Exception as exception:
print(f"*** Failed to parse explain with the followinf error: {exception}")
@ -131,7 +131,7 @@ def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]):
def extract_execution_stats(
et: execution_tree.Node, pt: physical_tree.Node, abt_types: Sequence[str]
et: execution_tree_sbe.Node, pt: physical_tree.Node, abt_types: Sequence[str]
) -> Mapping[str, Sequence[ExecutionStats]]:
"""Extract ExecutionStats from the given SBE and ABT trees for the given ABT types."""

File diff suppressed because it is too large Load Diff

View File

@ -30,6 +30,7 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
__all__ = ["Node", "build"]
@ -39,32 +40,29 @@ class Node:
"""Represent Query Solution node."""
node_type: str
plan_node_id: int
children: list[Node]
def print(self, level=0):
"""Pretty print of the QSN tree."""
print(f'{"| "*level}{self.node_type}, planNodeId: {self.plan_node_id}')
print(f'{"| "*level}{self.node_type}')
for child in self.children:
child.print(level + 1)
def build(optimizer_plan: dict[str, any]) -> Node:
def build(optimizer_plan: dict[str, Any]) -> Node:
"""Build QSN tree from query explain."""
return parse_optimizer_node(optimizer_plan)
def parse_optimizer_node(explain_node: dict[str, any]) -> Node:
def parse_optimizer_node(explain_node: dict[str, Any]) -> Node:
"""Recursively parse QSN from query explain's node."""
children = get_children(explain_node)
return Node(
node_type=explain_node["stage"], plan_node_id=explain_node["planNodeId"], children=children
)
return Node(node_type=explain_node["stage"], children=children)
def get_children(explain_node: dict[str, any]) -> list[Node]:
def get_children(explain_node: dict[str, Any]) -> list[Node]:
"""Get children nodes of the QSN."""
if "inputStage" in explain_node: