SERVER-106251 Modify cost model calibration code to be able to parse classic execution trees (#37438)
GitOrigin-RevId: 2829d57a40eb743ce17d933830688505fcea21f7
This commit is contained in:
parent
a8e86bea4d
commit
d8beb34ebc
@ -35,7 +35,7 @@ from dataclasses import asdict, dataclass
|
||||
from typing import Sequence
|
||||
|
||||
import bson.json_util as json
|
||||
import execution_tree
|
||||
import execution_tree_sbe
|
||||
import physical_tree
|
||||
from config import BenchmarkConfig
|
||||
from database_instance import DatabaseInstance, Pipeline, get_database_parameter
|
||||
@ -154,7 +154,7 @@ class ExperimentResult:
|
||||
|
||||
explain: Sequence[dict[str, any]]
|
||||
physical_tree: Sequence[physical_tree.Node]
|
||||
execution_tree: Sequence[execution_tree.Node]
|
||||
execution_tree: Sequence[execution_tree_sbe.Node]
|
||||
mean: float
|
||||
|
||||
def print(self, index: int = None):
|
||||
@ -200,7 +200,7 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B
|
||||
def make_variant(explain: Sequence[dict[str, any]]) -> ExperimentResult:
|
||||
"""Make one variant of the A/B test."""
|
||||
pt = [physical_tree.build(e["queryPlanner"]["winningPlan"]["queryPlan"]) for e in explain]
|
||||
et = [execution_tree.build_execution_tree(e["executionStats"]) for e in explain]
|
||||
et = [execution_tree_sbe.build_execution_tree(e["executionStats"]) for e in explain]
|
||||
mean = sum(et.total_execution_time for et in et) / len(et)
|
||||
return ExperimentResult(explain=explain, physical_tree=pt, execution_tree=et, mean=mean)
|
||||
|
||||
|
||||
@ -30,6 +30,20 @@
|
||||
import functools
|
||||
import time
|
||||
|
||||
"""
|
||||
Returns a preorder traversal (node, child 1...child n) of the QSN/execution tree
|
||||
For example, a query with an OR over 3 indices could turn into a tree rooted with a FETCH,
|
||||
which has a single OR child, which in turn has 3 index scan children.
|
||||
This would return a preorder of [FETCH, OR, IXSCAN1, IXSCAN2, IXSCAN3].
|
||||
"""
|
||||
|
||||
|
||||
def preorder(node):
    """Return a preorder traversal (node, child 1...child n) of the QSN/execution tree.

    For example, a query with an OR over 3 indices could turn into a tree rooted
    with a FETCH, which has a single OR child, which in turn has 3 index scan
    children. This would return a preorder of [FETCH, OR, IXSCAN1, IXSCAN2, IXSCAN3].

    Args:
        node: Any object exposing an iterable ``children`` attribute whose items
            expose the same attribute.

    Returns:
        A new list containing ``node`` followed by the preorder of each child,
        in the order the children are listed.
    """
    res = [node]
    for child in node.children:
        # Each subtree is flattened recursively and appended in child order.
        res += preorder(child)
    return res
|
||||
|
||||
|
||||
def timer_decorator(func):
|
||||
"""Measure and print out execution time of the decorated function."""
|
||||
|
||||
@ -39,7 +39,7 @@ import os
|
||||
from typing import Callable, Sequence, Tuple
|
||||
|
||||
import config
|
||||
import execution_tree as et
|
||||
import execution_tree_sbe as et
|
||||
import experiment as exp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
148
buildscripts/cost_model/execution_tree_classic.py
Normal file
148
buildscripts/cost_model/execution_tree_classic.py
Normal file
@ -0,0 +1,148 @@
|
||||
# Copyright (C) 2025-present MongoDB, Inc.
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the Server Side Public License, version 1,
|
||||
# as published by MongoDB, Inc.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# Server Side Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the Server Side Public License
|
||||
# along with this program. If not, see
|
||||
# <http://www.mongodb.com/licensing/server-side-public-license>.
|
||||
#
|
||||
# As a special exception, the copyright holders give permission to link the
|
||||
# code of portions of this program with the OpenSSL library under certain
|
||||
# conditions as described in each individual source file and distribute
|
||||
# linked combinations including the program with the OpenSSL library. You
|
||||
# must comply with the Server Side Public License in all respects for
|
||||
# all of the code used other than as permitted herein. If you modify file(s)
|
||||
# with this exception, you may extend this exception to your version of the
|
||||
# file(s), but you are not obligated to do so. If you do not wish to do so,
|
||||
# delete this exception statement from your version. If you delete this
|
||||
# exception statement from all source files in the program, then also delete
|
||||
# it in the license file.
|
||||
#
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import bson.json_util as json
|
||||
|
||||
|
||||
@dataclass
class Node:
    """Represent Classic Execution Node.

    One instance is built per stage of the 'executionStages' document by the
    process_* functions of this module.
    """

    # Stage name copied from the explain stage's "stage" key (e.g. "IXSCAN").
    stage: str
    # Copied from the stage's "executionTimeNanos"; get_execution_time() subtracts
    # the children's values from it, i.e. it is treated as inclusive of children.
    execution_time_nanoseconds: int
    # Copied from the stage's "nReturned".
    n_returned: int
    # Stage-specific counter chosen by the process_* function that built the node
    # (e.g. "docsExamined" for COLLSCAN/FETCH, "keysExamined" for IXSCAN).
    n_processed: int
    # Value of the stage's "seeks" key, or None when the stage does not report it.
    seeks: Optional[int]
    # Child nodes, in the order the input stages were parsed.
    children: list[Node]

    def get_execution_time(self):
        """Execution time of this node without execution time of its children"""
        return self.execution_time_nanoseconds - sum(
            n.execution_time_nanoseconds for n in self.children
        )

    def print(self, level=0):
        """Pretty print the execution tree.

        Each node is written on its own line, prefixed with one "| " per nesting
        level; children are printed recursively at ``level + 1``.
        """
        # Inside a method body, `print` resolves to the builtin, not this method.
        print(
            f'{"| " * level}{self.stage}, totalExecutionTime: {self.execution_time_nanoseconds:,}ns, seeks: {self.seeks}, nReturned: {self.n_returned}, nProcessed: {self.n_processed}'
        )
        for child in self.children:
            child.print(level + 1)
|
||||
|
||||
|
||||
def build_execution_tree(execution_stats: dict[str, Any]) -> Node:
    """Build Classic execution tree from 'executionStats' field of query explain.

    Args:
        execution_stats: The 'executionStats' document; its "executionSuccess"
            and "executionStages" keys are read.

    Returns:
        The root Node produced by process_stage for "executionStages".

    Raises:
        AssertionError: If "executionSuccess" is falsy.
        KeyError: If either of the two keys is missing.
    """
    # An explicit raise is used instead of `assert` so the check is not stripped
    # under `python -O`; the exception type is kept for existing callers.
    if not execution_stats["executionSuccess"]:
        raise AssertionError("explain reported an unsuccessful execution (executionSuccess is falsy)")
    return process_stage(execution_stats["executionStages"])
|
||||
|
||||
|
||||
def process_stage(stage: dict[str, Any]) -> Node:
    """Parse the given execution stage.

    Dispatches on ``stage["stage"]`` to the matching process_* function and
    returns the Node it builds (which recursively parses any input stages).

    Raises:
        ValueError: If the stage name has no registered processor; the stage is
            dumped to stdout as indented JSON first to ease debugging.
        KeyError: If the stage document has no "stage" key.
    """
    # Stage name -> parser. Several stage names share one parser where the
    # fields read from the explain output are the same.
    processors = {
        "SUBPLAN": process_subplan,
        "COLLSCAN": process_collscan,
        "IXSCAN": process_ixscan,
        "FETCH": process_fetch,
        "AND_HASH": process_intersection,
        "AND_SORTED": process_intersection,
        "OR": process_or,
        "MERGE_SORT": process_mergesort,
        "SORT_MERGE": process_mergesort,
        "SORT": process_sort,
        "LIMIT": process_limitskip,
        "SKIP": process_limitskip,
    }
    processor = processors.get(stage["stage"])
    if processor is None:
        print(json.dumps(stage, indent=4))
        raise ValueError(f"Unknown stage: {stage}")

    return processor(stage)
|
||||
|
||||
|
||||
def process_subplan(stage: dict[str, Any]) -> Node:
    """Process SUBPLAN stage: one child from "inputStage", n_processed taken from "nReturned"."""
    input_stage = process_stage(stage["inputStage"])
    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[input_stage])
|
||||
|
||||
|
||||
def process_collscan(stage: dict[str, Any]) -> Node:
    """Process COLLSCAN stage: a leaf node whose n_processed is the stage's "docsExamined"."""
    return Node(**get_common_fields(stage), n_processed=stage["docsExamined"], children=[])
|
||||
|
||||
|
||||
def process_ixscan(stage: dict[str, Any]) -> Node:
    """Process IXSCAN stage: a leaf node whose n_processed is the stage's "keysExamined"."""
    return Node(**get_common_fields(stage), n_processed=stage["keysExamined"], children=[])
|
||||
|
||||
|
||||
def process_fetch(stage: dict[str, Any]) -> Node:
    """Process FETCH stage: one child from "inputStage", n_processed taken from "docsExamined"."""
    input_stage = process_stage(stage["inputStage"])
    return Node(
        **get_common_fields(stage), n_processed=stage["docsExamined"], children=[input_stage]
    )
|
||||
|
||||
|
||||
def process_or(stage: dict[str, Any]) -> Node:
    """Process OR stage: one child per entry of "inputStages", n_processed taken from "nReturned"."""
    children = [process_stage(child) for child in stage["inputStages"]]
    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
|
||||
|
||||
|
||||
def process_intersection(stage: dict[str, Any]) -> Node:
    """Process AND_HASH / AND_SORTED stage.

    Children come from "inputStages"; n_processed is the sum of the children's
    n_processed values rather than a field of the stage itself.
    """
    children = [process_stage(child) for child in stage["inputStages"]]
    n_processed = sum(child.n_processed for child in children)
    return Node(**get_common_fields(stage), n_processed=n_processed, children=children)
|
||||
|
||||
|
||||
def process_mergesort(stage: dict[str, Any]) -> Node:
    """Process MERGE_SORT / SORT_MERGE stage: children from "inputStages", n_processed from "nReturned"."""
    children = [process_stage(child) for child in stage["inputStages"]]
    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
|
||||
|
||||
|
||||
def process_sort(stage: dict[str, Any]) -> Node:
    """Process SORT stage: one child from "inputStage", n_processed taken from "nReturned"."""
    input_stage = process_stage(stage["inputStage"])
    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[input_stage])
|
||||
|
||||
|
||||
def process_limitskip(stage: dict[str, Any]) -> Node:
    """Process LIMIT / SKIP stage.

    The single child comes from "inputStage"; n_processed is copied from that
    child's n_processed rather than read from the stage itself.
    """
    input_stage = process_stage(stage["inputStage"])
    return Node(
        **get_common_fields(stage), n_processed=input_stage.n_processed, children=[input_stage]
    )
|
||||
|
||||
|
||||
def get_common_fields(json_stage: dict[str, Any]) -> dict[str, Any]:
    """Extract common fields from classic nodes.

    Args:
        json_stage: One stage document of the classic 'executionStages' tree.

    Returns:
        Keyword arguments for Node: "stage", "execution_time_nanoseconds",
        "n_returned" and "seeks".

    Raises:
        KeyError: If "stage", "executionTimeNanos" or "nReturned" is missing.
    """
    return {
        "stage": json_stage["stage"],
        "execution_time_nanoseconds": json_stage["executionTimeNanos"],
        "n_returned": json_stage["nReturned"],
        # "seeks" is optional in the stage document; absent maps to None.
        "seeks": json_stage.get("seeks"),
    }
|
||||
@ -30,7 +30,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import bson.json_util as json
|
||||
|
||||
@ -50,7 +50,7 @@ class Node:
|
||||
children: list[Node]
|
||||
|
||||
def get_execution_time(self):
|
||||
"""Execution time of the SBE node without execuion time of its children."""
|
||||
"""Execution time of the SBE node without execution time of its children."""
|
||||
return self.total_execution_time - sum(n.total_execution_time for n in self.children)
|
||||
|
||||
def print(self, level=0):
|
||||
@ -62,13 +62,13 @@ class Node:
|
||||
child.print(level + 1)
|
||||
|
||||
|
||||
def build_execution_tree(execution_stats: dict[str, any]) -> Node:
|
||||
"""Build SBE executioon tree from 'executionStats' field of query explain."""
|
||||
def build_execution_tree(execution_stats: dict[str, Any]) -> Node:
|
||||
"""Build SBE execution tree from 'executionStats' field of query explain."""
|
||||
assert execution_stats["executionSuccess"]
|
||||
return process_stage(execution_stats["executionStages"])
|
||||
|
||||
|
||||
def process_stage(stage: dict[str, any]) -> Node:
|
||||
def process_stage(stage: dict[str, Any]) -> Node:
|
||||
"""Parse the given SBE stage."""
|
||||
processors = {
|
||||
"filter": process_filter,
|
||||
@ -100,13 +100,13 @@ def process_stage(stage: dict[str, any]) -> Node:
|
||||
return processor(stage)
|
||||
|
||||
|
||||
def process_filter(stage: dict[str, any]) -> Node:
|
||||
def process_filter(stage: dict[str, Any]) -> Node:
|
||||
"""Process filter stage."""
|
||||
input_stage = process_stage(stage["inputStage"])
|
||||
return Node(**get_common_fields(stage), n_processed=stage["numTested"], children=[input_stage])
|
||||
|
||||
|
||||
def process_traverse(stage: dict[str, any]) -> Node:
|
||||
def process_traverse(stage: dict[str, Any]) -> Node:
|
||||
"""Process traverse"""
|
||||
outer_stage = process_stage(stage["outerStage"])
|
||||
inner_stage = process_stage(stage["innerStage"])
|
||||
@ -117,7 +117,7 @@ def process_traverse(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def process_hash_join_node(stage: dict[str, any]) -> Node:
|
||||
def process_hash_join_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process hj node."""
|
||||
outer_stage = process_stage(stage["outerStage"])
|
||||
inner_stage = process_stage(stage["innerStage"])
|
||||
@ -127,7 +127,7 @@ def process_hash_join_node(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def process_nlj(stage: dict[str, any]) -> Node:
|
||||
def process_nlj(stage: dict[str, Any]) -> Node:
|
||||
"""Process nlj stage."""
|
||||
outer_stage = process_stage(stage["outerStage"])
|
||||
inner_stage = process_stage(stage["innerStage"])
|
||||
@ -137,7 +137,7 @@ def process_nlj(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def process_inner_node(stage: dict[str, any]) -> Node:
|
||||
def process_inner_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process SBE stage with one input stage."""
|
||||
input_stage = process_stage(stage["inputStage"])
|
||||
return Node(
|
||||
@ -145,23 +145,23 @@ def process_inner_node(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def process_leaf_node(stage: dict[str, any]) -> Node:
|
||||
def process_leaf_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process SBE stage without input stages."""
|
||||
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[])
|
||||
|
||||
|
||||
def process_seek(stage: dict[str, any]) -> Node:
|
||||
def process_seek(stage: dict[str, Any]) -> Node:
|
||||
"""Process seek stage."""
|
||||
return Node(**get_common_fields(stage), n_processed=stage["numReads"], children=[])
|
||||
|
||||
|
||||
def process_union_node(stage: dict[str, any]) -> Node:
|
||||
def process_union_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process union stage."""
|
||||
children = [process_stage(child) for child in stage["inputStages"]]
|
||||
return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=children)
|
||||
|
||||
|
||||
def process_unwind_node(stage: dict[str, any]) -> Node:
|
||||
def process_unwind_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process unwind stage."""
|
||||
input_stage = process_stage(stage["inputStage"])
|
||||
return Node(
|
||||
@ -169,14 +169,14 @@ def process_unwind_node(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def process_unique_node(stage: dict[str, any]) -> Node:
|
||||
def process_unique_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process unique stage."""
|
||||
input_stage = process_stage(stage["inputStage"])
|
||||
n_processed = stage["dupsTested"]
|
||||
return Node(**get_common_fields(stage), n_processed=n_processed, children=[input_stage])
|
||||
|
||||
|
||||
def process_branch_node(stage: dict[str, any]) -> Node:
|
||||
def process_branch_node(stage: dict[str, Any]) -> Node:
|
||||
"""Process unique stage."""
|
||||
then_stage = process_stage(stage["thenStage"])
|
||||
else_stage = process_stage(stage["elseStage"])
|
||||
@ -186,8 +186,8 @@ def process_branch_node(stage: dict[str, any]) -> Node:
|
||||
)
|
||||
|
||||
|
||||
def get_common_fields(json_stage: dict[str, any]) -> dict[str, any]:
|
||||
"""Exctract common field from json representation of SBE stage."""
|
||||
def get_common_fields(json_stage: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract common field from json representation of SBE stage."""
|
||||
return {
|
||||
"stage": json_stage["stage"],
|
||||
"plan_node_id": json_stage["planNodeId"],
|
||||
@ -98,7 +98,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
|
||||
import bson.json_util as json
|
||||
import execution_tree as sbe
|
||||
import execution_tree_sbe as sbe
|
||||
import pandas as pd
|
||||
import physical_tree as abt
|
||||
import seaborn as sns
|
||||
|
||||
@ -33,7 +33,7 @@ from collections import defaultdict, deque
|
||||
from typing import Callable, Mapping, Sequence, TypeVar
|
||||
|
||||
import bson.json_util as json
|
||||
import execution_tree
|
||||
import execution_tree_sbe
|
||||
import physical_tree
|
||||
from config import AbtCalibratorConfig
|
||||
from cost_estimator import CostModelParameters, ExecutionStats
|
||||
@ -87,9 +87,9 @@ def find_nodes(root: Node, predicate: Callable[[Node], bool]) -> list[Node]:
|
||||
return result
|
||||
|
||||
|
||||
def get_excution_stats(root: execution_tree.Node, node_id: int) -> ExecutionStats:
|
||||
def get_excution_stats(root: execution_tree_sbe.Node, node_id: int) -> ExecutionStats:
|
||||
"""Extract execution stats from the given Execution Tree for the ABT node defined with the given node_id."""
|
||||
queue: deque[execution_tree.Node] = deque()
|
||||
queue: deque[execution_tree_sbe.Node] = deque()
|
||||
queue.append(root)
|
||||
|
||||
execution_time: int = 0
|
||||
@ -120,7 +120,7 @@ def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]):
|
||||
"""Extract ExecutionStats from the given explain for the given ABT types."""
|
||||
|
||||
try:
|
||||
et = execution_tree.build_execution_tree(explain["executionStats"])
|
||||
et = execution_tree_sbe.build_execution_tree(explain["executionStats"])
|
||||
pt = physical_tree.build(explain["queryPlanner"]["winningPlan"]["queryPlan"])
|
||||
except Exception as exception:
|
||||
print(f"*** Failed to parse explain with the followinf error: {exception}")
|
||||
@ -131,7 +131,7 @@ def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]):
|
||||
|
||||
|
||||
def extract_execution_stats(
|
||||
et: execution_tree.Node, pt: physical_tree.Node, abt_types: Sequence[str]
|
||||
et: execution_tree_sbe.Node, pt: physical_tree.Node, abt_types: Sequence[str]
|
||||
) -> Mapping[str, Sequence[ExecutionStats]]:
|
||||
"""Extract ExecutionStats from the given SBE and ABT trees for the given ABT types."""
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -30,6 +30,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["Node", "build"]
|
||||
|
||||
@ -39,32 +40,29 @@ class Node:
|
||||
"""Represent Query Solution node."""
|
||||
|
||||
node_type: str
|
||||
plan_node_id: int
|
||||
children: list[Node]
|
||||
|
||||
def print(self, level=0):
|
||||
"""Pretty print of the QSN tree."""
|
||||
print(f'{"| "*level}{self.node_type}, planNodeId: {self.plan_node_id}')
|
||||
print(f'{"| "*level}{self.node_type}')
|
||||
for child in self.children:
|
||||
child.print(level + 1)
|
||||
|
||||
|
||||
def build(optimizer_plan: dict[str, any]) -> Node:
|
||||
def build(optimizer_plan: dict[str, Any]) -> Node:
|
||||
"""Build QSN tree from query explain."""
|
||||
|
||||
return parse_optimizer_node(optimizer_plan)
|
||||
|
||||
|
||||
def parse_optimizer_node(explain_node: dict[str, any]) -> Node:
|
||||
def parse_optimizer_node(explain_node: dict[str, Any]) -> Node:
|
||||
"""Recursively parse QSN from query explain's node."""
|
||||
|
||||
children = get_children(explain_node)
|
||||
return Node(
|
||||
node_type=explain_node["stage"], plan_node_id=explain_node["planNodeId"], children=children
|
||||
)
|
||||
return Node(node_type=explain_node["stage"], children=children)
|
||||
|
||||
|
||||
def get_children(explain_node: dict[str, any]) -> list[Node]:
|
||||
def get_children(explain_node: dict[str, Any]) -> list[Node]:
|
||||
"""Get children nodes of the QSN."""
|
||||
|
||||
if "inputStage" in explain_node:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user