SERVER-111295 [v8.0] Set python as formatter in format_multirun (#41681)
GitOrigin-RevId: 0a5f595c13f329cc64a37f58e7369dd9469ee848
This commit is contained in:
parent
ad27dbb8da
commit
fbc2f1ea04
1
.gitattributes
vendored
1
.gitattributes
vendored
@ -6,6 +6,7 @@
|
||||
external rules-lint-ignored=true
|
||||
**/*.tpl.h rules-lint-ignored=true
|
||||
**/*.tpl.cpp rules-lint-ignored=true
|
||||
rpm/*.spec rules-lint-ignored=true
|
||||
src/mongo/bson/util/bson_column_compressed_data.inl rules-lint-ignored=true
|
||||
src/mongo/bson/util/simple8b.inl rules-lint-ignored=true
|
||||
*.idl linguist-language=yaml
|
||||
|
||||
@ -198,7 +198,9 @@ def get_releases_json(bazelisk_directory):
|
||||
pass
|
||||
|
||||
with open(releases, "wb") as f:
|
||||
body = read_remote_text_file("https://api.github.com/repos/bazelbuild/bazel/releases")
|
||||
body = read_remote_text_file(
|
||||
"https://api.github.com/repos/bazelbuild/bazel/releases"
|
||||
)
|
||||
f.write(body.encode("utf-8"))
|
||||
return json.loads(body)
|
||||
|
||||
@ -245,7 +247,9 @@ def get_operating_system():
|
||||
if operating_system not in ("linux", "darwin", "windows"):
|
||||
raise Exception(
|
||||
'Unsupported operating system "{}". '
|
||||
"Bazel currently only supports Linux, macOS and Windows.".format(operating_system)
|
||||
"Bazel currently only supports Linux, macOS and Windows.".format(
|
||||
operating_system
|
||||
)
|
||||
)
|
||||
return operating_system
|
||||
|
||||
@ -262,7 +266,10 @@ def determine_bazel_filename(version):
|
||||
if machine not in supported_machines:
|
||||
raise Exception(
|
||||
'Unsupported machine architecture "{}". Bazel {} only supports {} on {}.'.format(
|
||||
machine, version, ", ".join(supported_machines), operating_system.capitalize()
|
||||
machine,
|
||||
version,
|
||||
", ".join(supported_machines),
|
||||
operating_system.capitalize(),
|
||||
)
|
||||
)
|
||||
|
||||
@ -270,7 +277,9 @@ def determine_bazel_filename(version):
|
||||
bazel_flavor = "bazel"
|
||||
if get_env_or_config("BAZELISK_NOJDK", "0") != "0":
|
||||
bazel_flavor = "bazel_nojdk"
|
||||
return "{}-{}-{}-{}{}".format(bazel_flavor, version, operating_system, machine, filename_suffix)
|
||||
return "{}-{}-{}-{}{}".format(
|
||||
bazel_flavor, version, operating_system, machine, filename_suffix
|
||||
)
|
||||
|
||||
|
||||
def get_supported_machine_archs(version, operating_system):
|
||||
@ -498,7 +507,9 @@ def execute_bazel(bazel_path, argv):
|
||||
cmd = make_bazel_cmd(bazel_path, argv)
|
||||
|
||||
# We cannot use close_fds on Windows, so disable it there.
|
||||
p = subprocess.Popen([cmd["exec"]] + cmd["args"], close_fds=os.name != "nt", env=cmd["env"])
|
||||
p = subprocess.Popen(
|
||||
[cmd["exec"]] + cmd["args"], close_fds=os.name != "nt", env=cmd["env"]
|
||||
)
|
||||
while True:
|
||||
try:
|
||||
return p.wait()
|
||||
|
||||
@ -13,7 +13,9 @@ import textwrap
|
||||
from typing import Dict
|
||||
|
||||
|
||||
def write_config_header(input_path: str, output_path: str, definitions: Dict[str, str]) -> None:
|
||||
def write_config_header(
|
||||
input_path: str, output_path: str, definitions: Dict[str, str]
|
||||
) -> None:
|
||||
with open(input_path) as in_file:
|
||||
content = in_file.read()
|
||||
|
||||
@ -27,18 +29,28 @@ def write_config_header(input_path: str, output_path: str, definitions: Dict[str
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Generate a config header file")
|
||||
parser.add_argument("--compiler-path", help="Path to the compiler executable", required=True)
|
||||
parser.add_argument("--compiler-args", help="Extra compiler arguments", required=True)
|
||||
parser.add_argument(
|
||||
"--compiler-path", help="Path to the compiler executable", required=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"--compiler-args", help="Extra compiler arguments", required=True
|
||||
)
|
||||
parser.add_argument("--env-vars", help="Extra environment variables", required=True)
|
||||
parser.add_argument(
|
||||
"--output-path", help="Path to the output config header file", required=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"--template-path", help="Path to the config header's template file", required=True
|
||||
"--template-path",
|
||||
help="Path to the config header's template file",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument("--extra-definitions", help="Extra header definitions")
|
||||
parser.add_argument("--check-path", help="Path to the suppored configure checks", required=True)
|
||||
parser.add_argument("--log-path", help="Path to the suppored configure checks", required=True)
|
||||
parser.add_argument(
|
||||
"--check-path", help="Path to the suppored configure checks", required=True
|
||||
)
|
||||
parser.add_argument(
|
||||
"--log-path", help="Path to the suppored configure checks", required=True
|
||||
)
|
||||
parser.add_argument("--additional-input", help="extra files", action="append")
|
||||
|
||||
args = parser.parse_args()
|
||||
@ -72,7 +84,11 @@ if __name__ == "__main__":
|
||||
%s
|
||||
|
||||
"""
|
||||
% (generate_config_module, called_args, list(generate_config_header_args.keys()))
|
||||
% (
|
||||
generate_config_module,
|
||||
called_args,
|
||||
list(generate_config_header_args.keys()),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@ -17,11 +17,15 @@ bazel_cache = os.path.expanduser(args.bazel_cache)
|
||||
# the cc_library and cc_binaries in our build. There is not a good way from
|
||||
# within the build to get all those targets, so we will generate the list via query
|
||||
# https://sig-product-docs.synopsys.com/bundle/coverity-docs/page/coverity-analysis/topics/building_with_bazel.html#build_with_bazel
|
||||
cmd = [
|
||||
cmd = (
|
||||
[
|
||||
bazel_executable,
|
||||
bazel_cache,
|
||||
"aquery",
|
||||
] + bazel_cmd_args + [args.bazel_query]
|
||||
]
|
||||
+ bazel_cmd_args
|
||||
+ [args.bazel_query]
|
||||
)
|
||||
print(f"Running command: {cmd}")
|
||||
proc = subprocess.run(
|
||||
cmd,
|
||||
@ -33,9 +37,7 @@ proc = subprocess.run(
|
||||
print(proc.stderr)
|
||||
|
||||
targets = set()
|
||||
with open('coverity_targets.list', 'w') as f:
|
||||
with open("coverity_targets.list", "w") as f:
|
||||
for line in proc.stdout.splitlines():
|
||||
if line.startswith(" Target: "):
|
||||
f.write(line.split()[-1] + "\n")
|
||||
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ format_multirun(
|
||||
graphql = "//:prettier",
|
||||
html = "//:prettier",
|
||||
markdown = "//:prettier",
|
||||
python = "@aspect_rules_lint//format:ruff",
|
||||
sql = "//:prettier",
|
||||
starlark = "@buildifier_prebuilt//:buildifier",
|
||||
visibility = ["//visibility:public"],
|
||||
|
||||
@ -31,7 +31,9 @@ def _git_diff(args: list) -> str:
|
||||
return result.stdout.strip() + os.linesep
|
||||
|
||||
|
||||
def _get_files_changed_since_fork_point(origin_branch: str = "origin/master") -> List[str]:
|
||||
def _get_files_changed_since_fork_point(
|
||||
origin_branch: str = "origin/master",
|
||||
) -> List[str]:
|
||||
"""Query git to get a list of files in the repo from a diff."""
|
||||
# There are 3 diffs we run:
|
||||
# 1. List of commits between origin/master and HEAD of current branch
|
||||
@ -68,7 +70,9 @@ def run_rules_lint(
|
||||
print("Running rules_lint formatter")
|
||||
if files_to_format != "all":
|
||||
command += files_to_format
|
||||
repo_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
repo_path = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
subprocess.run(command, check=True, env=os.environ, cwd=repo_path)
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
@ -83,7 +87,9 @@ def run_shellscripts_linters(shellscripts_linters: pathlib.Path, check: bool) ->
|
||||
command.append("fix")
|
||||
else:
|
||||
print("Running shellscripts linter")
|
||||
repo_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
repo_path = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
)
|
||||
subprocess.run(command, check=True, env=os.environ, cwd=repo_path)
|
||||
except subprocess.CalledProcessError:
|
||||
return False
|
||||
@ -145,7 +151,9 @@ def main() -> int:
|
||||
# If we are running in bazel, default the directory to the workspace
|
||||
default_dir = os.environ.get("BUILD_WORKSPACE_DIRECTORY")
|
||||
if not default_dir:
|
||||
print("This script must be run though bazel. Please run 'bazel run //:format' instead")
|
||||
print(
|
||||
"This script must be run though bazel. Please run 'bazel run //:format' instead"
|
||||
)
|
||||
print("*** IF BAZEL IS NOT INSTALLED, RUN THE FOLLOWING: ***\n")
|
||||
print("python buildscripts/install_bazel.py")
|
||||
return 1
|
||||
@ -154,7 +162,9 @@ def main() -> int:
|
||||
prog="Format", description="This script formats code in mongodb"
|
||||
)
|
||||
|
||||
parser.add_argument("--check", help="Run in check mode", default=False, action="store_true")
|
||||
parser.add_argument(
|
||||
"--check", help="Run in check mode", default=False, action="store_true"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prettier", help="Set the path to prettier", required=True, type=pathlib.Path
|
||||
)
|
||||
@ -200,7 +210,9 @@ def main() -> int:
|
||||
print(
|
||||
f"The number of commits between current branch and origin branch ({args.origin_branch}) is too large: {distance} commits"
|
||||
)
|
||||
print("WARNING!!! Defaulting to formatting all files, this may take a while.")
|
||||
print(
|
||||
"WARNING!!! Defaulting to formatting all files, this may take a while."
|
||||
)
|
||||
print(
|
||||
"Please update your local branch with the latest changes from origin, or use `bazel run format -- --origin-branch other_branch` to select a different origin branch"
|
||||
)
|
||||
@ -217,12 +229,17 @@ def main() -> int:
|
||||
validate_bazel_groups(generate_report=True, fix=not args.check)
|
||||
|
||||
if files_to_format != "all":
|
||||
files_to_format = [str(file) for file in files_to_format if os.path.isfile(file)]
|
||||
files_to_format = [
|
||||
str(file) for file in files_to_format if os.path.isfile(file)
|
||||
]
|
||||
|
||||
return (
|
||||
0
|
||||
if run_rules_lint(
|
||||
args.rules_lint_format, args.rules_lint_format_check, args.check, files_to_format
|
||||
args.rules_lint_format,
|
||||
args.rules_lint_format_check,
|
||||
args.check,
|
||||
files_to_format,
|
||||
)
|
||||
and run_shellscripts_linters(shellscripts_linters_path, args.check)
|
||||
and run_prettier(prettier_path, args.check, files_to_format)
|
||||
|
||||
@ -7,7 +7,9 @@ parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument("--depfile", action="append")
|
||||
parser.add_argument("--install-dir")
|
||||
parser.add_argument("--install-mode", choices=["copy", "symlink", "hardlink"], default="hardlink")
|
||||
parser.add_argument(
|
||||
"--install-mode", choices=["copy", "symlink", "hardlink"], default="hardlink"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
if os.path.exists(args.install_dir):
|
||||
@ -67,12 +69,15 @@ def install(src, install_type):
|
||||
if os.path.isdir(src):
|
||||
for root, _, files in os.walk(src):
|
||||
for name in files:
|
||||
dest_dir = os.path.dirname(os.path.join(root, name)).replace(
|
||||
src, dst
|
||||
)
|
||||
dest_dir = os.path.dirname(
|
||||
os.path.join(root, name)
|
||||
).replace(src, dst)
|
||||
if not os.path.exists(dest_dir):
|
||||
os.makedirs(dest_dir)
|
||||
os.link(os.path.join(root, name), os.path.join(dest_dir, name))
|
||||
os.link(
|
||||
os.path.join(root, name),
|
||||
os.path.join(dest_dir, name),
|
||||
)
|
||||
else:
|
||||
try:
|
||||
os.link(src, dst)
|
||||
|
||||
@ -71,7 +71,9 @@ def merge_check_options_into_config(
|
||||
override = check_options_list_to_map(incoming_config.get("CheckOptions"))
|
||||
if override:
|
||||
base.update(override) # later wins
|
||||
target_config["CheckOptions"] = [{"key": k, "value": v} for k, v in sorted(base.items())]
|
||||
target_config["CheckOptions"] = [
|
||||
{"key": k, "value": v} for k, v in sorted(base.items())
|
||||
]
|
||||
|
||||
|
||||
def deep_merge_dicts(base: Any, override: Any) -> Any:
|
||||
@ -176,7 +178,9 @@ def main() -> None:
|
||||
# then generic merge:
|
||||
merged_config = deep_merge_dicts(merged_config, incoming_config)
|
||||
|
||||
merged_config["Checks"] = ",".join(split_checks_to_list(merged_config.get("Checks")))
|
||||
merged_config["Checks"] = ",".join(
|
||||
split_checks_to_list(merged_config.get("Checks"))
|
||||
)
|
||||
output_path = pathlib.Path(args.out)
|
||||
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
@ -184,4 +188,4 @@ def main() -> None:
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -10,13 +10,19 @@ from datetime import datetime
|
||||
|
||||
def log_subprocess_run(*args, **kwargs):
    """Echo a command to stdout, then execute it with ``subprocess.run``.

    The command is read from the ``args`` keyword when present, otherwise from
    the first positional argument. A list or tuple command is printed as one
    space-joined line; any other value (for example a shell string) is printed
    unchanged. Every argument is forwarded untouched to ``subprocess.run``.

    Returns:
        The value returned by ``subprocess.run``.

    Raises:
        IndexError: if no command is given positionally or via ``args``.
    """
    command = kwargs["args"] if "args" in kwargs else args[0]
    if isinstance(command, (list, tuple)):
        # str() each element so non-string entries (e.g. path objects) do not
        # make the join raise TypeError before the command even runs.
        print(" ".join(str(part) for part in command))
    else:
        print(command)
    return subprocess.run(*args, **kwargs)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--distro", type=str, help="Restrict to only update a single distro.")
|
||||
parser.add_argument(
|
||||
"--distro", type=str, help="Restrict to only update a single distro."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-cleanup",
|
||||
action="store_true",
|
||||
@ -42,14 +48,18 @@ Your docker images, volumes and containers will be purged if you continue. Enter
|
||||
code = compile(f.read(), container_file_path, "exec")
|
||||
exec(code, {}, remote_execution_containers)
|
||||
|
||||
for distro, re_container in remote_execution_containers["REMOTE_EXECUTION_CONTAINERS"].items():
|
||||
for distro, re_container in remote_execution_containers[
|
||||
"REMOTE_EXECUTION_CONTAINERS"
|
||||
].items():
|
||||
if args.distro is not None:
|
||||
if distro != args.distro:
|
||||
continue
|
||||
|
||||
if not args.skip_cleanup:
|
||||
# Clean host system between container builds to avoid running into disk space issues.
|
||||
print("Cleaning host system's docker images, containers, volumes, and networks...")
|
||||
print(
|
||||
"Cleaning host system's docker images, containers, volumes, and networks..."
|
||||
)
|
||||
for command in [
|
||||
"docker stop $(docker ps -a -q)", # Stop all running containers
|
||||
"docker rm $(docker ps -a -q)", # Remove all containers
|
||||
@ -65,7 +75,9 @@ Your docker images, volumes and containers will be purged if you continue. Enter
|
||||
print(f"Using dockerfile: {dockerfile}")
|
||||
print(f"Using tag: {tag}\n")
|
||||
|
||||
log_subprocess_run(["docker", "buildx", "create", "--use", "default"], check=True)
|
||||
log_subprocess_run(
|
||||
["docker", "buildx", "create", "--use", "default"], check=True
|
||||
)
|
||||
log_subprocess_run(
|
||||
[
|
||||
"docker",
|
||||
|
||||
@ -18,7 +18,12 @@ def main(
|
||||
with open(base_config, "rt") as fh:
|
||||
base_config_content = yaml.safe_load(fh)
|
||||
if "selector" in base_config_content:
|
||||
for x in ["roots", "exclude_files", "exclude_with_any_tags", "include_with_any_tags"]:
|
||||
for x in [
|
||||
"roots",
|
||||
"exclude_files",
|
||||
"exclude_with_any_tags",
|
||||
"include_with_any_tags",
|
||||
]:
|
||||
base_config_content["selector"].pop(x, None)
|
||||
|
||||
content = base_config_content
|
||||
|
||||
@ -62,9 +62,13 @@ if __name__ == "__main__":
|
||||
|
||||
if os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"):
|
||||
undeclared_output_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
|
||||
resmoke_args.append(f"--dbpathPrefix={os.path.join(undeclared_output_dir,'data')}")
|
||||
resmoke_args.append(
|
||||
f"--dbpathPrefix={os.path.join(undeclared_output_dir,'data')}"
|
||||
)
|
||||
resmoke_args.append(f"--taskWorkDir={undeclared_output_dir}")
|
||||
resmoke_args.append(f"--reportFile={os.path.join(undeclared_output_dir,'report.json')}")
|
||||
resmoke_args.append(
|
||||
f"--reportFile={os.path.join(undeclared_output_dir,'report.json')}"
|
||||
)
|
||||
|
||||
if os.environ.get("TEST_SRCDIR"):
|
||||
test_srcdir = os.environ.get("TEST_SRCDIR")
|
||||
|
||||
@ -52,7 +52,9 @@ PLATFORM_NAME_MAP = {
|
||||
REQUESTS_SESSION = requests.Session()
|
||||
REQUESTS_SESSION.mount(
|
||||
"https://",
|
||||
HTTPAdapter(max_retries=Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])),
|
||||
HTTPAdapter(
|
||||
max_retries=Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -61,7 +63,9 @@ def download_toolchain(toolchain_url: str, local_path: str) -> bool:
|
||||
|
||||
response = REQUESTS_SESSION.get(toolchain_url)
|
||||
if response.status_code != requests.codes.ok:
|
||||
print(f"WARNING: HTTP {response.status_code} status downloading {toolchain_url}")
|
||||
print(
|
||||
f"WARNING: HTTP {response.status_code} status downloading {toolchain_url}"
|
||||
)
|
||||
return False
|
||||
|
||||
with open(local_path, "wb") as f:
|
||||
@ -88,10 +92,14 @@ def main():
|
||||
help="The build id, this should be the toolchain revision (githash), or the evergreen task id (version and date) if it is a --patch toolchain.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"toolchain_version", choices=toolchain_versions + ["all"], help="Toolchain version"
|
||||
"toolchain_version",
|
||||
choices=toolchain_versions + ["all"],
|
||||
help="Toolchain version",
|
||||
)
|
||||
parser.add_argument(
|
||||
"toolchain_component", choices=toolchain_components + ["all"], help="Toolchain component"
|
||||
"toolchain_component",
|
||||
choices=toolchain_components + ["all"],
|
||||
help="Toolchain component",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--patch_toolchain",
|
||||
@ -162,9 +170,13 @@ def main():
|
||||
)
|
||||
print(f'TOOLCHAIN_ID = "{args.build_id}"', file=f)
|
||||
print(f"TOOLCHAIN_MAP_{version_str.upper()} = {{", file=f)
|
||||
for key, value in sorted(mongo_toolchain_version.items(), key=lambda x: x[0]):
|
||||
for key, value in sorted(
|
||||
mongo_toolchain_version.items(), key=lambda x: x[0]
|
||||
):
|
||||
print(f' "{key}": {{', file=f)
|
||||
for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]):
|
||||
for subkey, subvalue in sorted(
|
||||
value.items(), key=lambda x: x[0]
|
||||
):
|
||||
print(f' "{subkey}": "{subvalue}",', file=f)
|
||||
print(" },", file=f)
|
||||
print("}", file=f)
|
||||
|
||||
@ -8,6 +8,7 @@ import pathlib
|
||||
import tempfile
|
||||
import urllib.request
|
||||
|
||||
|
||||
def sha256_file(filename: str) -> str:
|
||||
sha256_hash = hashlib.sha256()
|
||||
with open(filename, "rb") as f:
|
||||
@ -15,23 +16,29 @@ def sha256_file(filename: str) -> str:
|
||||
sha256_hash.update(block)
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("patch_build_id",
|
||||
help="Patch build id from toolchain-builder project.")
|
||||
parser.add_argument("patch_build_date_string",
|
||||
help="Patch build date string from toolchain-builder project, get this at the task URL, ex the date is 24_01_09_16_10_07 for https://spruce.mongodb.com/task/toolchain_builder_amazon2023_compile_11bae3c145a48dd7be9ee8aa44e5591783f787aa_24_01_09_16_10_07/")
|
||||
parser.add_argument(
|
||||
"patch_build_id", help="Patch build id from toolchain-builder project."
|
||||
)
|
||||
parser.add_argument(
|
||||
"patch_build_date_string",
|
||||
help="Patch build date string from toolchain-builder project, get this at the task URL, ex the date is 24_01_09_16_10_07 for https://spruce.mongodb.com/task/toolchain_builder_amazon2023_compile_11bae3c145a48dd7be9ee8aa44e5591783f787aa_24_01_09_16_10_07/",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
mongo_toolchain_version = {}
|
||||
version_file_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "mongo_toolchain_version.bzl")
|
||||
version_file_path = os.path.join(
|
||||
pathlib.Path(__file__).parent.resolve(), "mongo_toolchain_version.bzl"
|
||||
)
|
||||
with open(version_file_path, "r") as f:
|
||||
code = compile(f.read(), version_file_path, "exec")
|
||||
exec(code, {}, mongo_toolchain_version)
|
||||
|
||||
|
||||
for toolchain_name, toolchain in mongo_toolchain_version["TOOLCHAIN_MAP"].items():
|
||||
underscore_platform_name = toolchain['platform_name'].replace("-", "_")
|
||||
underscore_platform_name = toolchain["platform_name"].replace("-", "_")
|
||||
|
||||
toolchain_url = mongo_toolchain_version["TOOLCHAIN_URL_FORMAT"].format(
|
||||
platform_name=toolchain["platform_name"],
|
||||
@ -39,9 +46,12 @@ def main():
|
||||
patch_build_id=args.patch_build_id,
|
||||
patch_build_date=args.patch_build_date_string,
|
||||
)
|
||||
|
||||
|
||||
temp_dir = tempfile.gettempdir()
|
||||
local_tarball_path = os.path.join(temp_dir, f"bazel_v4_toolchain_builder_{underscore_platform_name}_{args.patch_build_id}.tar.gz")
|
||||
local_tarball_path = os.path.join(
|
||||
temp_dir,
|
||||
f"bazel_v4_toolchain_builder_{underscore_platform_name}_{args.patch_build_id}.tar.gz",
|
||||
)
|
||||
|
||||
print(f"Downloading {toolchain_url}...")
|
||||
|
||||
@ -54,15 +64,23 @@ def main():
|
||||
|
||||
with open(version_file_path, "w") as f:
|
||||
print(f"Writing toolchain map to {version_file_path}...")
|
||||
print("# Use mongo/bazel/toolchains/toolchain_generator.py to generate this mapping for a given patch build.\n", file=f)
|
||||
print(f"TOOLCHAIN_URL_FORMAT = \"{mongo_toolchain_version['TOOLCHAIN_URL_FORMAT']}\"", file=f)
|
||||
print(f"TOOLCHAIN_PATCH_BUILD_ID = \"{args.patch_build_id}\"", file=f)
|
||||
print(f"TOOLCHAIN_PATCH_BUILD_DATE = \"{args.patch_build_date_string}\"", file=f)
|
||||
print(
|
||||
"# Use mongo/bazel/toolchains/toolchain_generator.py to generate this mapping for a given patch build.\n",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
f"TOOLCHAIN_URL_FORMAT = \"{mongo_toolchain_version['TOOLCHAIN_URL_FORMAT']}\"",
|
||||
file=f,
|
||||
)
|
||||
print(f'TOOLCHAIN_PATCH_BUILD_ID = "{args.patch_build_id}"', file=f)
|
||||
print(f'TOOLCHAIN_PATCH_BUILD_DATE = "{args.patch_build_date_string}"', file=f)
|
||||
print("TOOLCHAIN_MAP = {", file=f)
|
||||
for key, value in sorted(mongo_toolchain_version["TOOLCHAIN_MAP"].items(), key=lambda x: x[0]):
|
||||
print(f" \"{key}\": {{", file=f)
|
||||
for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]):
|
||||
print(f" \"{subkey}\": \"{subvalue}\",", file=f)
|
||||
for key, value in sorted(
|
||||
mongo_toolchain_version["TOOLCHAIN_MAP"].items(), key=lambda x: x[0]
|
||||
):
|
||||
print(f' "{key}": {{', file=f)
|
||||
for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]):
|
||||
print(f' "{subkey}": "{subvalue}",', file=f)
|
||||
print(" },", file=f)
|
||||
print("}", file=f)
|
||||
|
||||
@ -70,5 +88,6 @@ def main():
|
||||
print(f"Finished writing to {version_file_path}:")
|
||||
print(f.read())
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -27,7 +27,13 @@ bazel_tags_to_autogenerate = [
|
||||
def autogenerate_targets(args, bazel):
|
||||
bazel_autogenerate_flag = "include_autogenerated_targets"
|
||||
output_file = os.path.join(
|
||||
"src", "mongo", "db", "modules", "enterprise", "autogenerated_targets", "BUILD.bazel"
|
||||
"src",
|
||||
"mongo",
|
||||
"db",
|
||||
"modules",
|
||||
"enterprise",
|
||||
"autogenerated_targets",
|
||||
"BUILD.bazel",
|
||||
)
|
||||
if not any(bazel_autogenerate_flag in arg for arg in args):
|
||||
if os.path.exists(output_file):
|
||||
@ -40,7 +46,13 @@ def generate_targets(
|
||||
args=[],
|
||||
bazel="bazel",
|
||||
output_file=os.path.join(
|
||||
"src", "mongo", "db", "modules", "enterprise", "autogenerated_targets", "BUILD.bazel"
|
||||
"src",
|
||||
"mongo",
|
||||
"db",
|
||||
"modules",
|
||||
"enterprise",
|
||||
"autogenerated_targets",
|
||||
"BUILD.bazel",
|
||||
),
|
||||
):
|
||||
targets = []
|
||||
|
||||
@ -45,7 +45,6 @@ def run_pty_command(cmd):
|
||||
|
||||
|
||||
def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
|
||||
|
||||
# compiledb ignores command line args so just make a version rc file in anycase
|
||||
write_mongo_variables_bazelrc([])
|
||||
if persistent_compdb:
|
||||
@ -165,7 +164,8 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
|
||||
if external_link.exists():
|
||||
os.unlink(external_link)
|
||||
os.symlink(
|
||||
pathlib.Path(os.readlink(REPO_ROOT / "bazel-out")).parent.parent.parent / "external",
|
||||
pathlib.Path(os.readlink(REPO_ROOT / "bazel-out")).parent.parent.parent
|
||||
/ "external",
|
||||
external_link,
|
||||
target_is_directory=True,
|
||||
)
|
||||
@ -199,7 +199,9 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
|
||||
os.chmod(config, 0o744)
|
||||
with fileinput.FileInput(config, inplace=True) as file:
|
||||
for line in file:
|
||||
print(line.replace("bazel-out/", f"{symlink_prefix}out/"), end="")
|
||||
print(
|
||||
line.replace("bazel-out/", f"{symlink_prefix}out/"), end=""
|
||||
)
|
||||
shutil.copyfile(configs[1], clang_tidy_file)
|
||||
with open(".mongo_checks_module_path", "w") as f:
|
||||
f.write(
|
||||
@ -216,10 +218,14 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise):
|
||||
shutil.copyfile(pathlib.Path("bazel-bin") / ".clang-tidy", clang_tidy_file)
|
||||
if persistent_compdb:
|
||||
shutdown_proc = subprocess.run(
|
||||
[bazel_bin, f"--output_base={output_base}", "shutdown"], capture_output=True, text=True
|
||||
[bazel_bin, f"--output_base={output_base}", "shutdown"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if shutdown_proc.returncode != 0:
|
||||
print(f"Failed to shutdown compiledb output_base: {shutdown_proc.returncode}")
|
||||
print(
|
||||
f"Failed to shutdown compiledb output_base: {shutdown_proc.returncode}"
|
||||
)
|
||||
print("--- stdout ---:")
|
||||
print(shutdown_proc.stdout)
|
||||
print("--- stderr ---:")
|
||||
|
||||
@ -26,6 +26,9 @@ def engflow_auth(args):
|
||||
and "--config local" not in args_str
|
||||
and "--config public-release" not in args_str
|
||||
):
|
||||
if os.environ.get("CI") is None and platform.machine().lower() not in {"ppc64le", "s390x"}:
|
||||
if os.environ.get("CI") is None and platform.machine().lower() not in {
|
||||
"ppc64le",
|
||||
"s390x",
|
||||
}:
|
||||
setup_auth_wrapper()
|
||||
wrapper_debug(f"engflow auth time: {time.time() - start}")
|
||||
|
||||
@ -14,7 +14,9 @@ from bazel.wrapper_hook.wrapper_debug import wrapper_debug
|
||||
|
||||
|
||||
def get_deps_dirs(deps):
|
||||
tmp_dir = pathlib.Path(os.environ["Temp"] if platform.system() == "Windows" else "/tmp")
|
||||
tmp_dir = pathlib.Path(
|
||||
os.environ["Temp"] if platform.system() == "Windows" else "/tmp"
|
||||
)
|
||||
bazel_bin = REPO_ROOT / "bazel-bin"
|
||||
for dep in deps:
|
||||
try:
|
||||
@ -41,7 +43,9 @@ def add_module_to_path(poetry_dir, modules_added):
|
||||
|
||||
|
||||
def setup_python_path():
|
||||
tmp_dir = pathlib.Path(os.environ["Temp"] if platform.system() == "Windows" else "/tmp")
|
||||
tmp_dir = pathlib.Path(
|
||||
os.environ["Temp"] if platform.system() == "Windows" else "/tmp"
|
||||
)
|
||||
modules_added = set()
|
||||
|
||||
for out_dir in [
|
||||
@ -92,7 +96,9 @@ def search_for_modules(deps, deps_installed, lockfile_changed=False):
|
||||
def install_modules(bazel):
|
||||
need_to_install = False
|
||||
pwd_hash = hashlib.md5(str(REPO_ROOT).encode()).hexdigest()
|
||||
lockfile_hash_file = pathlib.Path(tempfile.gettempdir()) / f"{pwd_hash}_lockfile_hash"
|
||||
lockfile_hash_file = (
|
||||
pathlib.Path(tempfile.gettempdir()) / f"{pwd_hash}_lockfile_hash"
|
||||
)
|
||||
with open(REPO_ROOT / "poetry.lock", "rb") as f:
|
||||
current_hash = hashlib.md5(f.read()).hexdigest()
|
||||
|
||||
@ -132,7 +138,9 @@ def install_modules(bazel):
|
||||
]
|
||||
)
|
||||
if proc.returncode != 0:
|
||||
print("Failed to install modules using remote exec/cache, falling back to local...")
|
||||
print(
|
||||
"Failed to install modules using remote exec/cache, falling back to local..."
|
||||
)
|
||||
proc = subprocess.run(
|
||||
cmd
|
||||
+ [
|
||||
|
||||
@ -83,7 +83,9 @@ def list_files_without_targets(
|
||||
"src/mongo/util/processinfo_solaris.cpp",
|
||||
}
|
||||
|
||||
typed_files_in_targets = [line for line in files_with_targets if line.endswith(f".{ext}")]
|
||||
typed_files_in_targets = [
|
||||
line for line in files_with_targets if line.endswith(f".{ext}")
|
||||
]
|
||||
|
||||
print(f"Checking that all {type_name} files have BUILD.bazel targets...")
|
||||
|
||||
@ -184,7 +186,9 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
# so that the naive thing of pasting that flag to lint.sh will do what the user expects.
|
||||
if "--fix" in args:
|
||||
fix = "patch"
|
||||
args.extend(["--@aspect_rules_lint//lint:fix", "--output_groups=rules_lint_patch"])
|
||||
args.extend(
|
||||
["--@aspect_rules_lint//lint:fix", "--output_groups=rules_lint_patch"]
|
||||
)
|
||||
args.remove("--fix")
|
||||
|
||||
# the --dry-run flag must immediately follow the --fix flag
|
||||
@ -208,7 +212,15 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
# jq on windows outputs CRLF which breaks this script. https://github.com/jqlang/jq/issues/92
|
||||
valid_reports = (
|
||||
subprocess.run(
|
||||
["jq", "--arg", "ext", ".out", "--raw-output", filter_expr, buildevents_path],
|
||||
[
|
||||
"jq",
|
||||
"--arg",
|
||||
"ext",
|
||||
".out",
|
||||
"--raw-output",
|
||||
filter_expr,
|
||||
buildevents_path,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
@ -220,7 +232,11 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
failing_reports = 0
|
||||
for report in valid_reports:
|
||||
# Exclude coverage reports, and check if the output is empty.
|
||||
if "coverage.dat" in report or not os.path.exists(report) or not os.path.getsize(report):
|
||||
if (
|
||||
"coverage.dat" in report
|
||||
or not os.path.exists(report)
|
||||
or not os.path.getsize(report)
|
||||
):
|
||||
# Report is empty. No linting errors.
|
||||
continue
|
||||
with open(report, "r", encoding="utf-8") as f:
|
||||
@ -238,7 +254,15 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
if fix:
|
||||
valid_patches = (
|
||||
subprocess.run(
|
||||
["jq", "--arg", "ext", ".patch", "--raw-output", filter_expr, buildevents_path],
|
||||
[
|
||||
"jq",
|
||||
"--arg",
|
||||
"ext",
|
||||
".patch",
|
||||
"--raw-output",
|
||||
filter_expr,
|
||||
buildevents_path,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
@ -249,7 +273,11 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
|
||||
for patch in valid_patches:
|
||||
# Exclude coverage, and check if the patch is empty.
|
||||
if "coverage.dat" in patch or not os.path.exists(patch) or not os.path.getsize(patch):
|
||||
if (
|
||||
"coverage.dat" in patch
|
||||
or not os.path.exists(patch)
|
||||
or not os.path.getsize(patch)
|
||||
):
|
||||
# Patch is empty. No linting errors.
|
||||
continue
|
||||
|
||||
@ -260,7 +288,9 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool:
|
||||
print()
|
||||
elif fix == "patch":
|
||||
subprocess.run(
|
||||
["patch", "-p1"], check=True, stdin=open(patch, "r", encoding="utf-8")
|
||||
["patch", "-p1"],
|
||||
check=True,
|
||||
stdin=open(patch, "r", encoding="utf-8"),
|
||||
)
|
||||
else:
|
||||
print(f"ERROR: unknown fix type {fix}", file=sys.stderr)
|
||||
|
||||
@ -25,7 +25,9 @@ class DuplicateSourceNames(Exception):
|
||||
def get_buildozer_output(autocomplete_query):
|
||||
from buildscripts.install_bazel import install_bazel
|
||||
|
||||
buildozer_name = "buildozer" if not platform.system() == "Windows" else "buildozer.exe"
|
||||
buildozer_name = (
|
||||
"buildozer" if not platform.system() == "Windows" else "buildozer.exe"
|
||||
)
|
||||
buildozer = shutil.which(buildozer_name)
|
||||
if not buildozer:
|
||||
buildozer = str(pathlib.Path(f"~/.local/bin/{buildozer_name}").expanduser())
|
||||
@ -87,7 +89,12 @@ def test_runner_interface(
|
||||
|
||||
if autocomplete_query:
|
||||
str_args = " ".join(args)
|
||||
if "'//:*'" in str_args or "':*'" in str_args or "//:all" in str_args or ":all" in str_args:
|
||||
if (
|
||||
"'//:*'" in str_args
|
||||
or "':*'" in str_args
|
||||
or "//:all" in str_args
|
||||
or ":all" in str_args
|
||||
):
|
||||
plus_autocomplete_query = True
|
||||
|
||||
if os.environ.get("CI") is not None:
|
||||
@ -174,7 +181,9 @@ def test_runner_interface(
|
||||
if not real_target:
|
||||
for bin_target in set(sources_to_bin.values()):
|
||||
if (
|
||||
pathlib.Path(bin_target.replace("//", "").replace(":", "/")).name
|
||||
pathlib.Path(
|
||||
bin_target.replace("//", "").replace(":", "/")
|
||||
).name
|
||||
== test_name
|
||||
):
|
||||
bin_targets.append(bin_target)
|
||||
|
||||
@ -13,6 +13,7 @@ ARCH_NORMALIZE_MAP = {
|
||||
"s390x": "s390x",
|
||||
}
|
||||
|
||||
|
||||
def get_mongo_arch(args):
|
||||
arch = platform.machine().lower()
|
||||
if arch in ARCH_NORMALIZE_MAP:
|
||||
@ -20,14 +21,18 @@ def get_mongo_arch(args):
|
||||
else:
|
||||
return arch
|
||||
|
||||
|
||||
def get_mongo_version(args):
|
||||
proc = subprocess.run(["git", "describe", "--abbrev=0"], capture_output=True, text=True)
|
||||
proc = subprocess.run(
|
||||
["git", "describe", "--abbrev=0"], capture_output=True, text=True
|
||||
)
|
||||
return proc.stdout.strip()[1:]
|
||||
|
||||
|
||||
def write_mongo_variables_bazelrc(args):
|
||||
mongo_version = get_mongo_version(args)
|
||||
mongo_arch = get_mongo_arch(args)
|
||||
|
||||
|
||||
repo_root = pathlib.Path(os.path.abspath(__file__)).parent.parent.parent
|
||||
version_file = os.path.join(repo_root, ".bazelrc.mongo_variables")
|
||||
existing_hash = ""
|
||||
@ -42,4 +47,4 @@ common --define=MONGO_VERSION={mongo_version}
|
||||
current_hash = hashlib.md5(bazelrc_contents.encode()).hexdigest()
|
||||
if existing_hash != current_hash:
|
||||
with open(version_file, "w", encoding="utf-8") as f:
|
||||
f.write(bazelrc_contents)
|
||||
f.write(bazelrc_contents)
|
||||
|
||||
@ -32,7 +32,12 @@ def main():
|
||||
autogenerate_targets(sys.argv, sys.argv[1])
|
||||
|
||||
enterprise = True
|
||||
if check_bazel_command_type(sys.argv[1:]) not in ["clean", "shutdown", "version", None]:
|
||||
if check_bazel_command_type(sys.argv[1:]) not in [
|
||||
"clean",
|
||||
"shutdown",
|
||||
"version",
|
||||
None,
|
||||
]:
|
||||
args = sys.argv
|
||||
enterprise_mod = REPO_ROOT / "src" / "mongo" / "db" / "modules" / "enterprise"
|
||||
if not enterprise_mod.exists():
|
||||
|
||||
@ -13,14 +13,14 @@ from optparse import OptionParser
|
||||
|
||||
def aggregate(inputs, output):
|
||||
"""Aggregate the tracefiles given in inputs to a tracefile given by output."""
|
||||
args = ['lcov']
|
||||
args = ["lcov"]
|
||||
|
||||
for name in inputs:
|
||||
args += ['-a', name]
|
||||
args += ["-a", name]
|
||||
|
||||
args += ['-o', output]
|
||||
args += ["-o", output]
|
||||
|
||||
print(' '.join(args))
|
||||
print(" ".join(args))
|
||||
|
||||
return subprocess.call(args)
|
||||
|
||||
@ -46,17 +46,19 @@ def main():
|
||||
for path in args[:-1]:
|
||||
_, ext = os.path.splitext(path)
|
||||
|
||||
if ext == '.info':
|
||||
if ext == ".info":
|
||||
if getfilesize(path) > 0:
|
||||
inputs.append(path)
|
||||
|
||||
elif ext == '.txt':
|
||||
inputs += [line.strip() for line in open(path) if getfilesize(line.strip()) > 0]
|
||||
elif ext == ".txt":
|
||||
inputs += [
|
||||
line.strip() for line in open(path) if getfilesize(line.strip()) > 0
|
||||
]
|
||||
else:
|
||||
return "unrecognized file type"
|
||||
|
||||
return aggregate(inputs, args[-1])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Script to configure a sharded cluster in Antithesis from the mongos container."""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
from time import sleep
|
||||
@ -59,7 +60,8 @@ retry_until_success(mongo_process_running, {"host": "configsvr1", "port": 27019}
|
||||
retry_until_success(mongo_process_running, {"host": "configsvr2", "port": 27019})
|
||||
retry_until_success(mongo_process_running, {"host": "configsvr3", "port": 27019})
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongo",
|
||||
"--host",
|
||||
@ -70,14 +72,16 @@ retry_until_success(
|
||||
f"rs.initiate({json.dumps(CONFIGSVR_CONFIG)})",
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
# Create Shard1 once all nodes are up
|
||||
retry_until_success(mongo_process_running, {"host": "database1", "port": 27018})
|
||||
retry_until_success(mongo_process_running, {"host": "database2", "port": 27018})
|
||||
retry_until_success(mongo_process_running, {"host": "database3", "port": 27018})
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongo",
|
||||
"--host",
|
||||
@ -88,14 +92,16 @@ retry_until_success(
|
||||
f"rs.initiate({json.dumps(SHARD1_CONFIG)})",
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
# Create Shard2 once all nodes are up
|
||||
retry_until_success(mongo_process_running, {"host": "database4", "port": 27018})
|
||||
retry_until_success(mongo_process_running, {"host": "database5", "port": 27018})
|
||||
retry_until_success(mongo_process_running, {"host": "database6", "port": 27018})
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongo",
|
||||
"--host",
|
||||
@ -106,11 +112,13 @@ retry_until_success(
|
||||
f"rs.initiate({json.dumps(SHARD2_CONFIG)})",
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
# Start mongos
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongos",
|
||||
"--bind_ip",
|
||||
@ -126,11 +134,13 @@ retry_until_success(
|
||||
"--fork",
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
# Add shards to cluster
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongo",
|
||||
"--host",
|
||||
@ -141,9 +151,11 @@ retry_until_success(
|
||||
'sh.addShard("Shard1/database1:27018,database2:27018,database3:27018")',
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
retry_until_success(
|
||||
subprocess.run, {
|
||||
subprocess.run,
|
||||
{
|
||||
"args": [
|
||||
"mongo",
|
||||
"--host",
|
||||
@ -154,7 +166,8 @@ retry_until_success(
|
||||
'sh.addShard("Shard2/database4:27018,database5:27018,database6:27018")',
|
||||
],
|
||||
"check": True,
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
while True:
|
||||
sleep(10)
|
||||
|
||||
@ -1,12 +1,15 @@
|
||||
"""Util functions to assist in setting up a sharded cluster topology in Antithesis."""
|
||||
|
||||
import subprocess
|
||||
import time
|
||||
|
||||
|
||||
def mongo_process_running(host, port):
|
||||
"""Check to see if the process at the given host & port is running."""
|
||||
return subprocess.run(['mongo', '--host', host, '--port',
|
||||
str(port), '--eval', '"db.stats()"'], check=True)
|
||||
return subprocess.run(
|
||||
["mongo", "--host", host, "--port", str(port), "--eval", '"db.stats()"'],
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
def retry_until_success(func, kwargs=None, wait_time=1, timeout_period=30):
|
||||
@ -16,10 +19,13 @@ def retry_until_success(func, kwargs=None, wait_time=1, timeout_period=30):
|
||||
while True:
|
||||
if time.time() > timeout:
|
||||
raise TimeoutError(
|
||||
f"{func.__name__} called with {kwargs} timed out after {timeout_period} second(s).")
|
||||
f"{func.__name__} called with {kwargs} timed out after {timeout_period} second(s)."
|
||||
)
|
||||
try:
|
||||
func(**kwargs)
|
||||
break
|
||||
except: # pylint: disable=bare-except
|
||||
print(f"Retrying {func.__name__} called with {kwargs} after {wait_time} second(s).")
|
||||
print(
|
||||
f"Retrying {func.__name__} called with {kwargs} after {wait_time} second(s)."
|
||||
)
|
||||
time.sleep(wait_time)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Script to initialize a workload container in Antithesis."""
|
||||
|
||||
from time import sleep
|
||||
|
||||
while True:
|
||||
|
||||
@ -48,13 +48,17 @@ def get_replacements_to_apply(fixes_file) -> dict:
|
||||
fixes_data = json.load(fin)
|
||||
for clang_tidy_check in fixes_data:
|
||||
for main_source_file in fixes_data[clang_tidy_check]:
|
||||
for violation_instance in fixes_data[clang_tidy_check][main_source_file]:
|
||||
for violation_instance in fixes_data[clang_tidy_check][
|
||||
main_source_file
|
||||
]:
|
||||
replacements = fixes_data[clang_tidy_check][main_source_file][
|
||||
violation_instance
|
||||
]["replacements"]
|
||||
if can_replacements_be_applied(replacements):
|
||||
for replacement in replacements:
|
||||
replacements_to_apply[replacement["FilePath"]].append(replacement)
|
||||
replacements_to_apply[replacement["FilePath"]].append(
|
||||
replacement
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"""WARNING: not applying replacements for {clang_tidy_check} in {main_source_file} at offset {violation_instance}, at least one file that is part of the automatic replacement has changed since clang-tidy was run, or is not writeable."""
|
||||
@ -88,7 +92,9 @@ def _combine_errors(dir: str) -> str:
|
||||
if "Notes" in fix:
|
||||
fix_msg = fix["Notes"][0]
|
||||
if len(fix["Notes"]) > 1:
|
||||
print(f'Warning: this script may be missing values in [{fix["Notes"]}]')
|
||||
print(
|
||||
f'Warning: this script may be missing values in [{fix["Notes"]}]'
|
||||
)
|
||||
else:
|
||||
fix_msg = fix["DiagnosticMessage"]
|
||||
fix_data = (
|
||||
@ -105,7 +111,9 @@ def _combine_errors(dir: str) -> str:
|
||||
.setdefault(
|
||||
str(fix_msg.get("FileOffset", "FileOffset Not Found")),
|
||||
{
|
||||
"replacements": fix_msg.get("Replacements", "Replacements not found"),
|
||||
"replacements": fix_msg.get(
|
||||
"Replacements", "Replacements not found"
|
||||
),
|
||||
"message": fix_msg.get("Message", "Message not found"),
|
||||
"count": 0,
|
||||
"source_files": [],
|
||||
@ -171,7 +179,9 @@ def main(argv=sys.argv[1:]):
|
||||
] = replacement["ReplacementText"].encode()
|
||||
|
||||
if replacement["Length"] != len(replacement["ReplacementText"]):
|
||||
adjustments += len(replacement["ReplacementText"]) - replacement["Length"]
|
||||
adjustments += (
|
||||
len(replacement["ReplacementText"]) - replacement["Length"]
|
||||
)
|
||||
|
||||
with open(file, "wb") as fout:
|
||||
fout.write(bytes(file_bytes))
|
||||
|
||||
@ -31,13 +31,14 @@ def add_file_to_tree(root_node: FileNode, file_parts: List[str]):
|
||||
for i, dir in enumerate(file_parts[:-1]):
|
||||
node_dirs = current_node.dirs
|
||||
if dir not in node_dirs:
|
||||
directory = "/".join(file_parts[:i + 1])
|
||||
directory = "/".join(file_parts[: i + 1])
|
||||
node_dirs[dir] = FileNode(f"./{directory}")
|
||||
|
||||
current_node = node_dirs[dir]
|
||||
|
||||
assert (current_node.owners_file is
|
||||
None), f"{'/'.join(file_parts[:-1])} there are two OWNERS files in this directory"
|
||||
assert (
|
||||
current_node.owners_file is None
|
||||
), f"{'/'.join(file_parts[:-1])} there are two OWNERS files in this directory"
|
||||
current_node.owners_file = file_parts[-1]
|
||||
|
||||
|
||||
@ -65,7 +66,9 @@ def process_owners_file(output_lines: list[str], node: FileNode) -> None:
|
||||
with open(owners_file_path, "r") as file:
|
||||
contents = yaml.safe_load(file)
|
||||
assert "version" in contents, f"Version not found in {owners_file_path}"
|
||||
assert contents["version"] in parsers, f"Unsupported version in {owners_file_path}"
|
||||
assert (
|
||||
contents["version"] in parsers
|
||||
), f"Unsupported version in {owners_file_path}"
|
||||
parser = parsers[contents["version"]]
|
||||
owners_lines = parser.parse(directory, owners_file_path, contents)
|
||||
output_lines.extend(owners_lines)
|
||||
@ -91,7 +94,9 @@ def print_diff_and_instructions(old_codeowners_contents, new_codeowners_contents
|
||||
)
|
||||
sys.stdout.writelines(diff)
|
||||
|
||||
print("If you are seeing this message in CI you likely need to run `bazel run codeowners`")
|
||||
print(
|
||||
"If you are seeing this message in CI you likely need to run `bazel run codeowners`"
|
||||
)
|
||||
|
||||
|
||||
def validate_generated_codeowners(validator_path: str) -> int:
|
||||
@ -114,7 +119,9 @@ def validate_generated_codeowners(validator_path: str) -> int:
|
||||
|
||||
|
||||
@cache
|
||||
def get_unowned_files(codeowners_binary_path: str, codeowners_file: str = None) -> Set[str]:
|
||||
def get_unowned_files(
|
||||
codeowners_binary_path: str, codeowners_file: str = None
|
||||
) -> Set[str]:
|
||||
temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
|
||||
temp_output_file.close()
|
||||
codeowners_file_arg = ""
|
||||
@ -140,7 +147,9 @@ def get_unowned_files(codeowners_binary_path: str, codeowners_file: str = None)
|
||||
return unowned_files
|
||||
|
||||
|
||||
def check_new_files(codeowners_binary_path: str, expansions_file: str, branch: str) -> int:
|
||||
def check_new_files(
|
||||
codeowners_binary_path: str, expansions_file: str, branch: str
|
||||
) -> int:
|
||||
new_files = evergreen_git.get_new_files(expansions_file, branch)
|
||||
if not new_files:
|
||||
print("No new files were detected.")
|
||||
@ -168,27 +177,35 @@ def check_new_files(codeowners_binary_path: str, expansions_file: str, branch: s
|
||||
return 0
|
||||
|
||||
|
||||
def check_orphaned_files(codeowners_binary_path: str, expansions_file: str, branch: str,
|
||||
codeowners_file: str) -> int:
|
||||
def check_orphaned_files(
|
||||
codeowners_binary_path: str, expansions_file: str, branch: str, codeowners_file: str
|
||||
) -> int:
|
||||
# This compares the new codeowners file with the old codeowners file on the same working tree
|
||||
# This tells us which coverage is lost between codeowners file changes
|
||||
current_unowned_files = get_unowned_files(codeowners_binary_path)
|
||||
base_revision = evergreen_git.get_diff_revision(expansions_file, branch)
|
||||
previous_codeowners_file_contents = evergreen_git.get_file_at_revision(
|
||||
codeowners_file, base_revision)
|
||||
codeowners_file, base_revision
|
||||
)
|
||||
if previous_codeowners_file_contents is None:
|
||||
return 0
|
||||
temp_codeowners_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt")
|
||||
temp_codeowners_file = tempfile.NamedTemporaryFile(
|
||||
mode="w", delete=False, suffix=".txt"
|
||||
)
|
||||
temp_codeowners_file.write(previous_codeowners_file_contents)
|
||||
temp_codeowners_file.close()
|
||||
old_unowned_files = get_unowned_files(codeowners_binary_path, temp_codeowners_file.name)
|
||||
old_unowned_files = get_unowned_files(
|
||||
codeowners_binary_path, temp_codeowners_file.name
|
||||
)
|
||||
|
||||
unowned_files_difference = current_unowned_files - old_unowned_files
|
||||
if not unowned_files_difference:
|
||||
print("No files have lost ownership with these changes.")
|
||||
return 0
|
||||
|
||||
print("The following files lost ownership with CODEOWNERS changes:", file=sys.stderr)
|
||||
print(
|
||||
"The following files lost ownership with CODEOWNERS changes:", file=sys.stderr
|
||||
)
|
||||
for file in sorted(unowned_files_difference):
|
||||
print(f"- {file}", file=sys.stderr)
|
||||
|
||||
@ -196,21 +213,22 @@ def check_orphaned_files(codeowners_binary_path: str, expansions_file: str, bran
|
||||
|
||||
|
||||
def post_generation_checks(
|
||||
validator_path: str,
|
||||
should_run_validation: bool,
|
||||
codeowners_binary_path: str,
|
||||
should_check_new_files: bool,
|
||||
expansions_file: str,
|
||||
branch: str,
|
||||
codeowners_file_path: str,
|
||||
validator_path: str,
|
||||
should_run_validation: bool,
|
||||
codeowners_binary_path: str,
|
||||
should_check_new_files: bool,
|
||||
expansions_file: str,
|
||||
branch: str,
|
||||
codeowners_file_path: str,
|
||||
) -> int:
|
||||
status = 0
|
||||
if should_run_validation:
|
||||
status |= validate_generated_codeowners(validator_path)
|
||||
if should_check_new_files:
|
||||
status |= check_new_files(codeowners_binary_path, expansions_file, branch)
|
||||
status |= check_orphaned_files(codeowners_binary_path, expansions_file, branch,
|
||||
codeowners_file_path)
|
||||
status |= check_orphaned_files(
|
||||
codeowners_binary_path, expansions_file, branch, codeowners_file_path
|
||||
)
|
||||
|
||||
return status
|
||||
|
||||
@ -219,8 +237,12 @@ def main():
|
||||
# If we are running in bazel, default the directory to the workspace
|
||||
default_dir = os.environ.get("BUILD_WORKSPACE_DIRECTORY")
|
||||
if not default_dir:
|
||||
process = subprocess.run(["git", "rev-parse", "--show-toplevel"], capture_output=True,
|
||||
text=True, check=True)
|
||||
process = subprocess.run(
|
||||
["git", "rev-parse", "--show-toplevel"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
default_dir = process.stdout.strip()
|
||||
|
||||
codeowners_validator_path = os.environ.get("CODEOWNERS_VALIDATOR_PATH")
|
||||
@ -247,12 +269,14 @@ def main():
|
||||
help="Path of the CODEOWNERS file to be generated.",
|
||||
default=os.path.join(".github", "CODEOWNERS"),
|
||||
)
|
||||
parser.add_argument("--repo-dir", help="Root of the repo to scan for OWNER files.",
|
||||
default=default_dir)
|
||||
parser.add_argument(
|
||||
"--repo-dir",
|
||||
help="Root of the repo to scan for OWNER files.",
|
||||
default=default_dir,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--check",
|
||||
help=
|
||||
"When set, program exits 1 when the CODEOWNERS content changes. This will skip generation",
|
||||
help="When set, program exits 1 when the CODEOWNERS content changes. This will skip generation",
|
||||
default=False,
|
||||
action="store_true",
|
||||
)
|
||||
@ -276,8 +300,7 @@ def main():
|
||||
)
|
||||
parser.add_argument(
|
||||
"--branch",
|
||||
help=
|
||||
"Helps the script understand what branch to compare against to see what new files are added when run locally. Defaults to master or main.",
|
||||
help="Helps the script understand what branch to compare against to see what new files are added when run locally. Defaults to master or main.",
|
||||
default=None,
|
||||
action="store",
|
||||
)
|
||||
@ -301,9 +324,14 @@ def main():
|
||||
root_node = build_tree(files)
|
||||
process_dir(output_lines, root_node)
|
||||
except Exception as ex:
|
||||
print("An exception was found while generating the CODEOWNERS file.", file=sys.stderr)
|
||||
print("Please refer to the docs to see the spec for OWNERS.yml files here :",
|
||||
file=sys.stderr)
|
||||
print(
|
||||
"An exception was found while generating the CODEOWNERS file.",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
"Please refer to the docs to see the spec for OWNERS.yml files here :",
|
||||
file=sys.stderr,
|
||||
)
|
||||
print(
|
||||
"https://github.com/10gen/mongo/blob/master/docs/owners/owners_format.md",
|
||||
file=sys.stderr,
|
||||
@ -327,7 +355,8 @@ def main():
|
||||
should_check_new_files = True
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Invalid value for CODEOWNERS_CHECK_NEW_FILES: {should_check_new_files}")
|
||||
f"Invalid value for CODEOWNERS_CHECK_NEW_FILES: {should_check_new_files}"
|
||||
)
|
||||
else:
|
||||
should_check_new_files = args.check_new_files
|
||||
|
||||
@ -350,7 +379,9 @@ def main():
|
||||
|
||||
with open(output_file, "w") as file:
|
||||
file.write(new_contents)
|
||||
print(f"Successfully wrote to the CODEOWNERS file at: {os.path.abspath(output_file)}")
|
||||
print(
|
||||
f"Successfully wrote to the CODEOWNERS file at: {os.path.abspath(output_file)}"
|
||||
)
|
||||
|
||||
# Add validation after generating CODEOWNERS file
|
||||
return post_generation_checks(
|
||||
|
||||
@ -9,12 +9,16 @@ import yaml
|
||||
|
||||
# Parser for OWNERS.yml files version 1.0.0
|
||||
class OwnersParserV1:
|
||||
def parse(self, directory: str, owners_file_path: str, contents: Dict[str, any]) -> List[str]:
|
||||
def parse(
|
||||
self, directory: str, owners_file_path: str, contents: Dict[str, any]
|
||||
) -> List[str]:
|
||||
lines = []
|
||||
no_parent_owners = False
|
||||
if "options" in contents:
|
||||
options = contents["options"]
|
||||
no_parent_owners = "no_parent_owners" in options and options["no_parent_owners"]
|
||||
no_parent_owners = (
|
||||
"no_parent_owners" in options and options["no_parent_owners"]
|
||||
)
|
||||
|
||||
if no_parent_owners:
|
||||
# Specfying no owners will ensure that no file in this directory has an owner unless it
|
||||
@ -28,15 +32,18 @@ class OwnersParserV1:
|
||||
if "filters" in contents:
|
||||
filters = contents["filters"]
|
||||
for _filter in filters:
|
||||
assert ("approvers" in _filter
|
||||
), f"Filter in {owners_file_path} does not have approvers."
|
||||
assert (
|
||||
"approvers" in _filter
|
||||
), f"Filter in {owners_file_path} does not have approvers."
|
||||
approvers = _filter["approvers"]
|
||||
del _filter["approvers"]
|
||||
if "metadata" in _filter:
|
||||
del _filter["metadata"]
|
||||
|
||||
# the last key remaining should be the pattern for the filter
|
||||
assert len(_filter) == 1, f"Filter in {owners_file_path} has incorrect values."
|
||||
assert (
|
||||
len(_filter) == 1
|
||||
), f"Filter in {owners_file_path} has incorrect values."
|
||||
pattern = next(iter(_filter))
|
||||
owners: set[str] = set()
|
||||
|
||||
@ -44,7 +51,9 @@ class OwnersParserV1:
|
||||
if "@" in owner:
|
||||
# approver is email, just add as is
|
||||
if not owner.endswith("@mongodb.com"):
|
||||
raise RuntimeError("Any emails specified must be a mongodb.com email.")
|
||||
raise RuntimeError(
|
||||
"Any emails specified must be a mongodb.com email."
|
||||
)
|
||||
owners.add(owner)
|
||||
else:
|
||||
# approver is github username, need to prefix with @
|
||||
@ -52,8 +61,9 @@ class OwnersParserV1:
|
||||
|
||||
NOOWNERS_NAME = "NOOWNERS-DO-NOT-USE-DEPRECATED-2024-07-01"
|
||||
if NOOWNERS_NAME in approvers:
|
||||
assert (len(approvers) == 1
|
||||
), f"{NOOWNERS_NAME} must be the only approver when it is used."
|
||||
assert (
|
||||
len(approvers) == 1
|
||||
), f"{NOOWNERS_NAME} must be the only approver when it is used."
|
||||
else:
|
||||
for approver in approvers:
|
||||
if approver in aliases:
|
||||
@ -72,7 +82,8 @@ class OwnersParserV1:
|
||||
def process_alias_import(self, path: str) -> Dict[str, List[str]]:
|
||||
if not path.startswith("//"):
|
||||
raise RuntimeError(
|
||||
f"Alias file paths must start with // and be relative to the repo root: {path}")
|
||||
f"Alias file paths must start with // and be relative to the repo root: {path}"
|
||||
)
|
||||
|
||||
# remove // from beginning of path
|
||||
parsed_path = path[2::]
|
||||
@ -108,7 +119,9 @@ class OwnersParserV1:
|
||||
parsed_pattern = f"/{directory}/**/{pattern}"
|
||||
|
||||
if not self.test_pattern(parsed_pattern):
|
||||
raise (RuntimeError(f"Can not find any files that match pattern: `{pattern}`"))
|
||||
raise (
|
||||
RuntimeError(f"Can not find any files that match pattern: `{pattern}`")
|
||||
)
|
||||
|
||||
return self.get_line(parsed_pattern, owners)
|
||||
|
||||
|
||||
@ -24,6 +24,8 @@ class OwnersParserV2(OwnersParserV1):
|
||||
parsed_pattern = f"/{directory}/{pattern}"
|
||||
|
||||
if not self.test_pattern(parsed_pattern):
|
||||
raise (RuntimeError(f"Can not find any files that match pattern: `{pattern}`"))
|
||||
raise (
|
||||
RuntimeError(f"Can not find any files that match pattern: `{pattern}`")
|
||||
)
|
||||
|
||||
return self.get_line(parsed_pattern, owners)
|
||||
|
||||
@ -10,11 +10,13 @@ def get_validator_env() -> dict:
|
||||
"""Prepare the environment for the codeowners-validator."""
|
||||
env = os.environ.copy()
|
||||
|
||||
env.update({
|
||||
"REPOSITORY_PATH": ".",
|
||||
"CHECKS": "duppatterns,syntax",
|
||||
"EXPERIMENTAL_CHECKS": "avoid-shadowing",
|
||||
})
|
||||
env.update(
|
||||
{
|
||||
"REPOSITORY_PATH": ".",
|
||||
"CHECKS": "duppatterns,syntax",
|
||||
"EXPERIMENTAL_CHECKS": "avoid-shadowing",
|
||||
}
|
||||
)
|
||||
return env
|
||||
|
||||
|
||||
@ -28,8 +30,9 @@ def run_validator(validator_path: str) -> int:
|
||||
env = get_validator_env()
|
||||
|
||||
try:
|
||||
result = subprocess.run([validator_path], env=env, check=True, capture_output=True,
|
||||
text=True)
|
||||
result = subprocess.run(
|
||||
[validator_path], env=env, check=True, capture_output=True, text=True
|
||||
)
|
||||
if result.stdout:
|
||||
print(result.stdout)
|
||||
return 0
|
||||
@ -40,7 +43,10 @@ def run_validator(validator_path: str) -> int:
|
||||
print(exc.stderr, file=sys.stderr)
|
||||
return exc.returncode
|
||||
except FileNotFoundError:
|
||||
print("Error: Failed to run codeowners-validator after installation", file=sys.stderr)
|
||||
print(
|
||||
"Error: Failed to run codeowners-validator after installation",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
@ -15,21 +15,21 @@ from datetime import datetime
|
||||
|
||||
from retry import retry
|
||||
|
||||
NORMALIZED_ARCH = {"x86_64": "x64", "aarch64": "arm64", "arm64": "arm64", "AMD64": "x64"}
|
||||
NORMALIZED_ARCH = {
|
||||
"x86_64": "x64",
|
||||
"aarch64": "arm64",
|
||||
"arm64": "arm64",
|
||||
"AMD64": "x64",
|
||||
}
|
||||
|
||||
NORMALIZED_OS = {"Windows": "windows", "Darwin": "macos", "Linux": "linux"}
|
||||
|
||||
CHECKSUMS = {
|
||||
"engflow_auth_linux_arm64":
|
||||
"ad5ffee1e6db926f5066aa40ee35517b1993851d0063ac121dbf5b407c81e2bf",
|
||||
"engflow_auth_linux_x64":
|
||||
"b731bae21628b2be321c24b342854c6ed1ed0326010e62a2ecf0b5650a56cf1a",
|
||||
"engflow_auth_macos_arm64":
|
||||
"69057929b4515d41b1af861c9bfdbc47427cc5ce5a80c94d4776c8bef672292e",
|
||||
"engflow_auth_macos_x64":
|
||||
"a322373e41faa7750c34348f357c5a4a670a66cfd988e80b4343c72822d91292",
|
||||
"engflow_auth_windows_x64.exe":
|
||||
"cb9590ffcc6731389ded173250f604b37778417450b1dc92c6bafadeef342826",
|
||||
"engflow_auth_linux_arm64": "ad5ffee1e6db926f5066aa40ee35517b1993851d0063ac121dbf5b407c81e2bf",
|
||||
"engflow_auth_linux_x64": "b731bae21628b2be321c24b342854c6ed1ed0326010e62a2ecf0b5650a56cf1a",
|
||||
"engflow_auth_macos_arm64": "69057929b4515d41b1af861c9bfdbc47427cc5ce5a80c94d4776c8bef672292e",
|
||||
"engflow_auth_macos_x64": "a322373e41faa7750c34348f357c5a4a670a66cfd988e80b4343c72822d91292",
|
||||
"engflow_auth_windows_x64.exe": "cb9590ffcc6731389ded173250f604b37778417450b1dc92c6bafadeef342826",
|
||||
}
|
||||
GH_URL_PREFIX = "https://github.com/EngFlow/auth/releases/download/v0.0.13/"
|
||||
CLUSTER = "sodalite.cluster.engflow.com"
|
||||
@ -73,7 +73,9 @@ def install(verbose: bool) -> str:
|
||||
binary_path += ".exe"
|
||||
if os.path.exists(binary_path):
|
||||
if verbose:
|
||||
print(f"{binary_filename} already exists at {binary_path}, skipping download")
|
||||
print(
|
||||
f"{binary_filename} already exists at {binary_path}, skipping download"
|
||||
)
|
||||
else:
|
||||
url = GH_URL_PREFIX + tag
|
||||
print(f"Downloading {url}...")
|
||||
@ -109,7 +111,9 @@ def update_bazelrc(binary_path: str, verbose: bool):
|
||||
|
||||
def authenticate(binary_path: str, verbose: bool) -> bool:
|
||||
need_login = False
|
||||
p = subprocess.run(f"{binary_path} export {CLUSTER}", shell=True, capture_output=True)
|
||||
p = subprocess.run(
|
||||
f"{binary_path} export {CLUSTER}", shell=True, capture_output=True
|
||||
)
|
||||
if p.returncode != 0:
|
||||
need_login = True
|
||||
else:
|
||||
|
||||
@ -19,8 +19,10 @@ def write_file(repo: Repo, file_name: str) -> None:
|
||||
file.write("change\n")
|
||||
|
||||
|
||||
@unittest.skipIf(sys.platform == "win32",
|
||||
reason="This test breaks on windows and only needs to work on linux")
|
||||
@unittest.skipIf(
|
||||
sys.platform == "win32",
|
||||
reason="This test breaks on windows and only needs to work on linux",
|
||||
)
|
||||
class TestChangedFiles(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
@ -37,7 +39,8 @@ class TestChangedFiles(unittest.TestCase):
|
||||
|
||||
# get tracked files that have been changed that are tracked by git
|
||||
diff_output = root_repo.git.execute(
|
||||
["git", "diff", "--name-only", "--diff-filter=d", commit])
|
||||
["git", "diff", "--name-only", "--diff-filter=d", commit]
|
||||
)
|
||||
files_to_copy.update(diff_output.split("\n"))
|
||||
|
||||
# gets all the untracked changes in the current repo
|
||||
@ -86,7 +89,9 @@ class TestChangedFiles(unittest.TestCase):
|
||||
# make a new file that has not been commited yet
|
||||
write_file(self.repo, new_file_name)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as tmp:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w", encoding="utf-8", delete=False
|
||||
) as tmp:
|
||||
tmp.write("fake_expansion: true\n")
|
||||
self.expansions_file = tmp.name
|
||||
|
||||
@ -98,12 +103,18 @@ class TestChangedFiles(unittest.TestCase):
|
||||
def test_local_unchanged_files(self):
|
||||
evergreen_git.get_remote_branch_ref = MagicMock(return_value=self.base_revision)
|
||||
new_files = evergreen_git.get_new_files()
|
||||
self.assertEqual(new_files, [],
|
||||
msg="New files list was not empty when no new files were added to git.")
|
||||
self.assertEqual(
|
||||
new_files,
|
||||
[],
|
||||
msg="New files list was not empty when no new files were added to git.",
|
||||
)
|
||||
|
||||
changed_files = evergreen_git.get_changed_files()
|
||||
self.assertEqual(changed_files, [changed_file_name],
|
||||
msg="Changed file list was not as expected.")
|
||||
self.assertEqual(
|
||||
changed_files,
|
||||
[changed_file_name],
|
||||
msg="Changed file list was not as expected.",
|
||||
)
|
||||
|
||||
self.repo.git.execute(["git", "add", "."])
|
||||
|
||||
@ -131,10 +142,15 @@ class TestChangedFiles(unittest.TestCase):
|
||||
tmp.write(f"revision: {self.base_revision}\n")
|
||||
self.repo.git.execute(["git", "add", "."])
|
||||
new_files = evergreen_git.get_new_files(expansions_file=self.expansions_file)
|
||||
self.assertEqual(new_files, [new_file_name],
|
||||
msg="New file list did not contain the new file.")
|
||||
self.assertEqual(
|
||||
new_files,
|
||||
[new_file_name],
|
||||
msg="New file list did not contain the new file.",
|
||||
)
|
||||
|
||||
changed_files = evergreen_git.get_changed_files(expansions_file=self.expansions_file)
|
||||
changed_files = evergreen_git.get_changed_files(
|
||||
expansions_file=self.expansions_file
|
||||
)
|
||||
self.assertEqual(
|
||||
changed_files,
|
||||
[changed_file_name, new_file_name],
|
||||
@ -146,10 +162,15 @@ class TestChangedFiles(unittest.TestCase):
|
||||
self.repo.git.execute(["git", "add", "."])
|
||||
self.repo.git.execute(["git", "commit", "-m", "Fake waterfall changes"])
|
||||
new_files = evergreen_git.get_new_files(expansions_file=self.expansions_file)
|
||||
self.assertEqual(new_files, [new_file_name],
|
||||
msg="New file list did not contain the new file.")
|
||||
self.assertEqual(
|
||||
new_files,
|
||||
[new_file_name],
|
||||
msg="New file list did not contain the new file.",
|
||||
)
|
||||
|
||||
changed_files = evergreen_git.get_changed_files(expansions_file=self.expansions_file)
|
||||
changed_files = evergreen_git.get_changed_files(
|
||||
expansions_file=self.expansions_file
|
||||
)
|
||||
self.assertEqual(
|
||||
changed_files,
|
||||
[changed_file_name, new_file_name],
|
||||
|
||||
@ -36,7 +36,14 @@ def get_mongodb_remote(repo: Repo) -> Remote:
|
||||
assert len(parts) >= 2, f"Unexpected git remote url: {url}"
|
||||
owner = parts[-2].split(":")[-1]
|
||||
|
||||
if owner in ("10gen", "mongodb", "evergreen-ci", "mongodb-ets", "realm", "mongodb-js"):
|
||||
if owner in (
|
||||
"10gen",
|
||||
"mongodb",
|
||||
"evergreen-ci",
|
||||
"mongodb-ets",
|
||||
"realm",
|
||||
"mongodb-js",
|
||||
):
|
||||
picked_remote = remote
|
||||
print(f"Selected remote: {remote.url}")
|
||||
break
|
||||
@ -106,13 +113,15 @@ def get_diff_revision(expansions_file: str = None, branch: str = None) -> str:
|
||||
return diff_commit
|
||||
|
||||
|
||||
def get_changed_files(expansions_file: str = None, branch: str = None,
|
||||
diff_filter: str = "d") -> List[str]:
|
||||
def get_changed_files(
|
||||
expansions_file: str = None, branch: str = None, diff_filter: str = "d"
|
||||
) -> List[str]:
|
||||
diff_commit = get_diff_revision(expansions_file, branch)
|
||||
repo = Repo()
|
||||
|
||||
output = repo.git.execute(
|
||||
["git", "diff", "--name-only", f"--diff-filter={diff_filter}", diff_commit])
|
||||
["git", "diff", "--name-only", f"--diff-filter={diff_filter}", diff_commit]
|
||||
)
|
||||
files = output.split("\n")
|
||||
return [file for file in files if file]
|
||||
|
||||
@ -135,7 +144,10 @@ def get_files_to_lint() -> List[str]:
|
||||
tracked_files = repo.git.execute(["git", "ls-files"]).split("\n")
|
||||
# all untracked (and not ignored) files from git
|
||||
tracked_files.extend(
|
||||
repo.git.execute(["git", "ls-files", "--others", "--exclude-standard"]).split("\n"))
|
||||
repo.git.execute(["git", "ls-files", "--others", "--exclude-standard"]).split(
|
||||
"\n"
|
||||
)
|
||||
)
|
||||
# remove any empty entries
|
||||
tracked_files = list(filter(bool, tracked_files))
|
||||
return tracked_files
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Options used for building process."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
@ -18,7 +19,8 @@ class PathOptions:
|
||||
|
||||
if not self._compiled_shared_library_file_patterns:
|
||||
self._compiled_shared_library_file_patterns = tuple(
|
||||
map(re.compile, self._shared_library_file_patterns))
|
||||
map(re.compile, self._shared_library_file_patterns)
|
||||
)
|
||||
return self._compiled_shared_library_file_patterns
|
||||
|
||||
def is_shared_library_file(self, path: str) -> bool:
|
||||
|
||||
@ -19,16 +19,25 @@ def find_all_failed(bin_path: str) -> list[str]:
|
||||
if contents:
|
||||
ignored_paths.append(contents)
|
||||
|
||||
process = subprocess.run([bin_path, "--format=json", "--mode=check", "-r", "./"], check=True,
|
||||
capture_output=True)
|
||||
process = subprocess.run(
|
||||
[bin_path, "--format=json", "--mode=check", "-r", "./"],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
buildifier_results = json.loads(process.stdout)
|
||||
if buildifier_results["success"]:
|
||||
return []
|
||||
|
||||
return [
|
||||
result["filename"] for result in buildifier_results["files"]
|
||||
if (not result["formatted"] and \
|
||||
not any(result["filename"].startswith(ignored_path) for ignored_path in ignored_paths))
|
||||
result["filename"]
|
||||
for result in buildifier_results["files"]
|
||||
if (
|
||||
not result["formatted"]
|
||||
and not any(
|
||||
result["filename"].startswith(ignored_path)
|
||||
for ignored_path in ignored_paths
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@ -44,14 +53,18 @@ def fix_all(bin_path: str):
|
||||
|
||||
def lint(bin_path: str, files: list[str], generate_report: bool):
|
||||
for file in files:
|
||||
process = subprocess.run([bin_path, "--format=json", "--mode=check", file], check=True,
|
||||
capture_output=True)
|
||||
process = subprocess.run(
|
||||
[bin_path, "--format=json", "--mode=check", file],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
result = json.loads(process.stdout)
|
||||
if result["success"]:
|
||||
continue
|
||||
# This purposefully gives an exit code of 4 when there is a diff
|
||||
process = subprocess.run([bin_path, "--mode=diff", file], capture_output=True,
|
||||
encoding='utf-8')
|
||||
process = subprocess.run(
|
||||
[bin_path, "--mode=diff", file], capture_output=True, encoding="utf-8"
|
||||
)
|
||||
if process.returncode not in (0, 4):
|
||||
raise RuntimeError()
|
||||
diff = process.stdout
|
||||
@ -62,7 +75,8 @@ def lint(bin_path: str, files: list[str], generate_report: bool):
|
||||
header = (
|
||||
"There are linting errors in this file, fix them with one of the following commands:\n"
|
||||
"python3 buildscripts/buildifier.py fix-all\n"
|
||||
f"python3 buildscripts/buildifier.py fix {file}\n\n")
|
||||
f"python3 buildscripts/buildifier.py fix {file}\n\n"
|
||||
)
|
||||
report = make_report(f"{file} warnings", json.dumps(result, indent=2), 1)
|
||||
try_combine_reports(report)
|
||||
put_report(report)
|
||||
@ -79,15 +93,20 @@ def fix(bin_path: str, files: list[str]):
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='buildifier wrapper')
|
||||
parser = argparse.ArgumentParser(description="buildifier wrapper")
|
||||
parser.add_argument(
|
||||
"--binary-dir", "-b", type=str,
|
||||
"--binary-dir",
|
||||
"-b",
|
||||
type=str,
|
||||
help="Path to the buildifier binary, defaults to looking in the current directory.",
|
||||
default="")
|
||||
default="",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generate-report", action="store_true",
|
||||
"--generate-report",
|
||||
action="store_true",
|
||||
help="Whether or not a report of the lint errors should be generated for evergreen.",
|
||||
default=False)
|
||||
default=False,
|
||||
)
|
||||
parser.set_defaults(subcommand=None)
|
||||
|
||||
sub = parser.add_subparsers(title="buildifier subcommands", help="sub-command help")
|
||||
@ -108,7 +127,8 @@ def main():
|
||||
|
||||
args = parser.parse_args()
|
||||
assert os.path.abspath(os.curdir) == str(
|
||||
mongo_dir.absolute()), "buildifier.py must be run from the root of the mongo repo"
|
||||
mongo_dir.absolute()
|
||||
), "buildifier.py must be run from the root of the mongo repo"
|
||||
binary_name = "buildifier.exe" if platform.system() == "Windows" else "buildifier"
|
||||
if args.binary_dir:
|
||||
binary_path = os.path.join(args.binary_dir, binary_name)
|
||||
@ -127,7 +147,9 @@ def main():
|
||||
else:
|
||||
# we purposefully do not use sub.choices.keys() so it does not print as a dict_keys object
|
||||
choices = [key for key in sub.choices]
|
||||
raise RuntimeError(f"One of the following subcommands must be specified: {choices}")
|
||||
raise RuntimeError(
|
||||
f"One of the following subcommands must be specified: {choices}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Command line utility for determining what jstests have been added or modified."""
|
||||
|
||||
import collections
|
||||
import copy
|
||||
import json
|
||||
@ -61,18 +62,27 @@ SUITE_FILES = ["with_server"]
|
||||
|
||||
BURN_IN_TEST_MEMBERSHIP_FILE = "burn_in_test_membership_map_file_for_ci.json"
|
||||
|
||||
SUPPORTED_TEST_KINDS = ("fsm_workload_test", "js_test", "json_schema_test",
|
||||
"multi_stmt_txn_passthrough", "parallel_fsm_workload_test",
|
||||
"all_versions_js_test")
|
||||
SUPPORTED_TEST_KINDS = (
|
||||
"fsm_workload_test",
|
||||
"js_test",
|
||||
"json_schema_test",
|
||||
"multi_stmt_txn_passthrough",
|
||||
"parallel_fsm_workload_test",
|
||||
"all_versions_js_test",
|
||||
)
|
||||
RUN_ALL_FEATURE_FLAG_TESTS = "--runAllFeatureFlagTests"
|
||||
|
||||
|
||||
class RepeatConfig(object):
|
||||
"""Configuration for how tests should be repeated."""
|
||||
|
||||
def __init__(self, repeat_tests_secs: Optional[int] = None,
|
||||
repeat_tests_min: Optional[int] = None, repeat_tests_max: Optional[int] = None,
|
||||
repeat_tests_num: Optional[int] = None):
|
||||
def __init__(
|
||||
self,
|
||||
repeat_tests_secs: Optional[int] = None,
|
||||
repeat_tests_min: Optional[int] = None,
|
||||
repeat_tests_max: Optional[int] = None,
|
||||
repeat_tests_num: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Create a Repeat Config.
|
||||
|
||||
@ -97,10 +107,14 @@ class RepeatConfig(object):
|
||||
|
||||
if self.repeat_tests_max:
|
||||
if not self.repeat_tests_secs:
|
||||
raise ValueError("Must specify --repeat-tests-secs with --repeat-tests-max")
|
||||
raise ValueError(
|
||||
"Must specify --repeat-tests-secs with --repeat-tests-max"
|
||||
)
|
||||
|
||||
if self.repeat_tests_min and self.repeat_tests_min > self.repeat_tests_max:
|
||||
raise ValueError("--repeat-tests-secs-min is greater than --repeat-tests-max")
|
||||
raise ValueError(
|
||||
"--repeat-tests-secs-min is greater than --repeat-tests-max"
|
||||
)
|
||||
|
||||
if self.repeat_tests_min and not self.repeat_tests_secs:
|
||||
raise ValueError("Must specify --repeat-tests-secs with --repeat-tests-min")
|
||||
@ -120,15 +134,19 @@ class RepeatConfig(object):
|
||||
repeat_options += f" --repeatTestsMax={self.repeat_tests_max} "
|
||||
return repeat_options
|
||||
|
||||
repeat_suites = self.repeat_tests_num if self.repeat_tests_num else REPEAT_SUITES
|
||||
repeat_suites = (
|
||||
self.repeat_tests_num if self.repeat_tests_num else REPEAT_SUITES
|
||||
)
|
||||
return f" --repeatSuites={repeat_suites} "
|
||||
|
||||
def __repr__(self):
|
||||
"""Build string representation of object for debugging."""
|
||||
return "".join([
|
||||
f"RepeatConfig[num={self.repeat_tests_num}, secs={self.repeat_tests_secs}, ",
|
||||
f"min={self.repeat_tests_min}, max={self.repeat_tests_max}]",
|
||||
])
|
||||
return "".join(
|
||||
[
|
||||
f"RepeatConfig[num={self.repeat_tests_num}, secs={self.repeat_tests_secs}, ",
|
||||
f"min={self.repeat_tests_min}, max={self.repeat_tests_max}]",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def is_file_a_test_file(file_path: str) -> bool:
|
||||
@ -161,11 +179,15 @@ def find_excludes(selector_file: str) -> Tuple[List, List, List]:
|
||||
try:
|
||||
js_test = yml["selector"]["js_test"]
|
||||
except KeyError:
|
||||
raise Exception(f"The selector file {selector_file} is missing the 'selector.js_test' key")
|
||||
raise Exception(
|
||||
f"The selector file {selector_file} is missing the 'selector.js_test' key"
|
||||
)
|
||||
|
||||
return (default_if_none(js_test.get("exclude_suites"), []),
|
||||
default_if_none(js_test.get("exclude_tasks"), []),
|
||||
default_if_none(js_test.get("exclude_tests"), []))
|
||||
return (
|
||||
default_if_none(js_test.get("exclude_suites"), []),
|
||||
default_if_none(js_test.get("exclude_tasks"), []),
|
||||
default_if_none(js_test.get("exclude_tests"), []),
|
||||
)
|
||||
|
||||
|
||||
def filter_tests(tests: Set[str], exclude_tests: List[str]) -> Set[str]:
|
||||
@ -199,7 +221,9 @@ def create_executor_list(suites, exclude_suites):
|
||||
try:
|
||||
with open(BURN_IN_TEST_MEMBERSHIP_FILE) as file:
|
||||
test_membership = collections.defaultdict(list, json.load(file))
|
||||
LOGGER.info(f"Using cached test membership file {BURN_IN_TEST_MEMBERSHIP_FILE}.")
|
||||
LOGGER.info(
|
||||
f"Using cached test membership file {BURN_IN_TEST_MEMBERSHIP_FILE}."
|
||||
)
|
||||
except FileNotFoundError:
|
||||
LOGGER.info("Getting test membership data.")
|
||||
test_membership = create_test_membership_map(test_kind=SUPPORTED_TEST_KINDS)
|
||||
@ -208,7 +232,9 @@ def create_executor_list(suites, exclude_suites):
|
||||
for suite in suites:
|
||||
LOGGER.debug("Adding tests for suite", suite=suite, tests=suite.tests)
|
||||
for test in suite.tests:
|
||||
LOGGER.debug("membership for test", test=test, membership=test_membership[test])
|
||||
LOGGER.debug(
|
||||
"membership for test", test=test, membership=test_membership[test]
|
||||
)
|
||||
for executor in set(test_membership[test]) - set(exclude_suites):
|
||||
if test not in memberships[executor]:
|
||||
memberships[executor].append(test)
|
||||
@ -255,9 +281,9 @@ class TaskToBurnInInfo(NamedTuple):
|
||||
|
||||
@classmethod
|
||||
def from_task(
|
||||
cls,
|
||||
task: VariantTask,
|
||||
tests_by_suite: Dict[str, List[str]],
|
||||
cls,
|
||||
task: VariantTask,
|
||||
tests_by_suite: Dict[str, List[str]],
|
||||
) -> "TaskToBurnInInfo":
|
||||
"""
|
||||
Gather the information needed to run the given task.
|
||||
@ -271,7 +297,8 @@ class TaskToBurnInInfo(NamedTuple):
|
||||
name=suite_name,
|
||||
resmoke_args=resmoke_args,
|
||||
tests=tests_by_suite[suite_name],
|
||||
) for suite_name, resmoke_args in task.combined_suite_to_resmoke_args_map.items()
|
||||
)
|
||||
for suite_name, resmoke_args in task.combined_suite_to_resmoke_args_map.items()
|
||||
if len(tests_by_suite[suite_name]) > 0
|
||||
]
|
||||
return cls(
|
||||
@ -280,9 +307,12 @@ class TaskToBurnInInfo(NamedTuple):
|
||||
)
|
||||
|
||||
|
||||
def create_task_list(evergreen_conf: EvergreenProjectConfig, build_variant: str,
|
||||
tests_by_suite: Dict[str, List[str]],
|
||||
exclude_tasks: [str]) -> Dict[str, TaskToBurnInInfo]:
|
||||
def create_task_list(
|
||||
evergreen_conf: EvergreenProjectConfig,
|
||||
build_variant: str,
|
||||
tests_by_suite: Dict[str, List[str]],
|
||||
exclude_tasks: [str],
|
||||
) -> Dict[str, TaskToBurnInInfo]:
|
||||
"""
|
||||
Find associated tasks for the specified build_variant and suites.
|
||||
|
||||
@ -294,15 +324,20 @@ def create_task_list(evergreen_conf: EvergreenProjectConfig, build_variant: str,
|
||||
"""
|
||||
log = LOGGER.bind(build_variant=build_variant)
|
||||
|
||||
log.debug("creating task list for suites", suites=tests_by_suite, exclude_tasks=exclude_tasks)
|
||||
log.debug(
|
||||
"creating task list for suites",
|
||||
suites=tests_by_suite,
|
||||
exclude_tasks=exclude_tasks,
|
||||
)
|
||||
evg_build_variant = _get_evg_build_variant_by_name(evergreen_conf, build_variant)
|
||||
|
||||
# Find all the build variant tasks.
|
||||
exclude_tasks_set = set(exclude_tasks)
|
||||
all_variant_tasks = {
|
||||
task.name: task
|
||||
for task in evg_build_variant.tasks if task.name not in exclude_tasks_set and (
|
||||
task.is_run_tests_task or task.is_generate_resmoke_task)
|
||||
for task in evg_build_variant.tasks
|
||||
if task.name not in exclude_tasks_set
|
||||
and (task.is_run_tests_task or task.is_generate_resmoke_task)
|
||||
}
|
||||
|
||||
# Return the list of tasks to run for the specified suite.
|
||||
@ -327,10 +362,13 @@ def _set_resmoke_cmd(repeat_config: RepeatConfig, resmoke_args: [str]) -> [str]:
|
||||
return new_args
|
||||
|
||||
|
||||
def create_task_list_for_tests(changed_tests: Set[str], build_variant: str,
|
||||
evg_conf: EvergreenProjectConfig,
|
||||
exclude_suites: Optional[List] = None,
|
||||
exclude_tasks: Optional[List] = None) -> Dict[str, TaskToBurnInInfo]:
|
||||
def create_task_list_for_tests(
|
||||
changed_tests: Set[str],
|
||||
build_variant: str,
|
||||
evg_conf: EvergreenProjectConfig,
|
||||
exclude_suites: Optional[List] = None,
|
||||
exclude_tasks: Optional[List] = None,
|
||||
) -> Dict[str, TaskToBurnInInfo]:
|
||||
"""
|
||||
Create a list of tests by task for the given tests.
|
||||
|
||||
@ -355,9 +393,12 @@ def create_task_list_for_tests(changed_tests: Set[str], build_variant: str,
|
||||
return create_task_list(evg_conf, build_variant, tests_by_executor, exclude_tasks)
|
||||
|
||||
|
||||
def create_tests_by_task(build_variant: str, evg_conf: EvergreenProjectConfig,
|
||||
changed_tests: Set[str],
|
||||
install_dir: Optional[str]) -> Dict[str, TaskToBurnInInfo]:
|
||||
def create_tests_by_task(
|
||||
build_variant: str,
|
||||
evg_conf: EvergreenProjectConfig,
|
||||
changed_tests: Set[str],
|
||||
install_dir: Optional[str],
|
||||
) -> Dict[str, TaskToBurnInInfo]:
|
||||
"""
|
||||
Create a list of tests by task.
|
||||
|
||||
@ -378,8 +419,9 @@ def create_tests_by_task(build_variant: str, evg_conf: EvergreenProjectConfig,
|
||||
buildscripts.resmokelib.parser.set_run_options(run_options)
|
||||
|
||||
if changed_tests:
|
||||
return create_task_list_for_tests(changed_tests, build_variant, evg_conf, exclude_suites,
|
||||
exclude_tasks)
|
||||
return create_task_list_for_tests(
|
||||
changed_tests, build_variant, evg_conf, exclude_suites, exclude_tasks
|
||||
)
|
||||
|
||||
LOGGER.info("No new or modified tests found.")
|
||||
return {}
|
||||
@ -404,7 +446,9 @@ def run_tests(tests_by_task: Dict[str, TaskToBurnInInfo], resmoke_cmd: [str]) ->
|
||||
try:
|
||||
subprocess.check_call(new_resmoke_cmd, shell=False)
|
||||
except subprocess.CalledProcessError as err:
|
||||
log.warning("Resmoke returned an error with suite", error=err.returncode)
|
||||
log.warning(
|
||||
"Resmoke returned an error with suite", error=err.returncode
|
||||
)
|
||||
sys.exit(err.returncode)
|
||||
|
||||
|
||||
@ -424,7 +468,9 @@ def _configure_logging(verbose: bool):
|
||||
logging.getLogger(log_name).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def _get_evg_build_variant_by_name(evergreen_conf: EvergreenProjectConfig, name: str) -> Variant:
|
||||
def _get_evg_build_variant_by_name(
|
||||
evergreen_conf: EvergreenProjectConfig, name: str
|
||||
) -> Variant:
|
||||
"""
|
||||
Get the evergreen build variant by name from the evergreen config file.
|
||||
|
||||
@ -467,7 +513,11 @@ class FileChangeDetector(ABC):
|
||||
LOGGER.info("Calculated revision map", revision_map=revision_map)
|
||||
|
||||
changed_files = find_changed_files_in_repos(repos, revision_map)
|
||||
return {os.path.normpath(path) for path in changed_files if is_file_a_test_file(path)}
|
||||
return {
|
||||
os.path.normpath(path)
|
||||
for path in changed_files
|
||||
if is_file_a_test_file(path)
|
||||
}
|
||||
|
||||
|
||||
class LocalFileChangeDetector(FileChangeDetector):
|
||||
@ -588,23 +638,31 @@ class YamlBurnInExecutor(BurnInExecutor):
|
||||
|
||||
:param tests_by_task: Dictionary of tasks to run with tests to run in each.
|
||||
"""
|
||||
discovered_tasks = DiscoveredTaskList(discovered_tasks=[
|
||||
DiscoveredTask(
|
||||
task_name=task_name,
|
||||
suites=[
|
||||
DiscoveredSuite(suite_name=suite.name, test_list=suite.tests)
|
||||
for suite in task_info.suites
|
||||
],
|
||||
) for task_name, task_info in tests_by_task.items()
|
||||
])
|
||||
discovered_tasks = DiscoveredTaskList(
|
||||
discovered_tasks=[
|
||||
DiscoveredTask(
|
||||
task_name=task_name,
|
||||
suites=[
|
||||
DiscoveredSuite(suite_name=suite.name, test_list=suite.tests)
|
||||
for suite in task_info.suites
|
||||
],
|
||||
)
|
||||
for task_name, task_info in tests_by_task.items()
|
||||
]
|
||||
)
|
||||
print(yaml.safe_dump(discovered_tasks.dict()))
|
||||
|
||||
|
||||
class BurnInOrchestrator:
|
||||
"""Orchestrate the execution of burn_in_tests."""
|
||||
|
||||
def __init__(self, change_detector: FileChangeDetector, burn_in_executor: BurnInExecutor,
|
||||
evg_conf: EvergreenProjectConfig, install_dir: Optional[str]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
change_detector: FileChangeDetector,
|
||||
burn_in_executor: BurnInExecutor,
|
||||
evg_conf: EvergreenProjectConfig,
|
||||
install_dir: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Create a new orchestrator.
|
||||
|
||||
@ -627,8 +685,9 @@ class BurnInOrchestrator:
|
||||
changed_tests = self.change_detector.find_changed_tests(repos)
|
||||
LOGGER.info("Found changed tests", files=changed_tests)
|
||||
|
||||
tests_by_task = create_tests_by_task(build_variant, self.evg_conf, changed_tests,
|
||||
self.install_dir)
|
||||
tests_by_task = create_tests_by_task(
|
||||
build_variant, self.evg_conf, changed_tests, self.install_dir
|
||||
)
|
||||
LOGGER.debug("tests and tasks found", tests_by_task=tests_by_task)
|
||||
|
||||
self.burn_in_executor.execute(tests_by_task)
|
||||
@ -640,34 +699,92 @@ def cli():
|
||||
|
||||
|
||||
@cli.command(context_settings=dict(ignore_unknown_options=True))
|
||||
@click.option("--no-exec", "no_exec", default=False, is_flag=True,
|
||||
help="Do not execute the found tests.")
|
||||
@click.option("--build-variant", "build_variant", default=DEFAULT_VARIANT, metavar='BUILD_VARIANT',
|
||||
help="Tasks to run will be selected from this build variant.")
|
||||
@click.option("--repeat-tests", "repeat_tests_num", default=None, type=int,
|
||||
help="Number of times to repeat tests.")
|
||||
@click.option("--repeat-tests-min", "repeat_tests_min", default=None, type=int,
|
||||
help="The minimum number of times to repeat tests if time option is specified.")
|
||||
@click.option("--repeat-tests-max", "repeat_tests_max", default=None, type=int,
|
||||
help="The maximum number of times to repeat tests if time option is specified.")
|
||||
@click.option("--repeat-tests-secs", "repeat_tests_secs", default=None, type=int, metavar="SECONDS",
|
||||
help="Repeat tests for the given time (in secs).")
|
||||
@click.option("--yaml", "use_yaml", is_flag=True, default=False,
|
||||
help="Output discovered tasks in YAML. Tests will not be run.")
|
||||
@click.option("--verbose", "verbose", default=False, is_flag=True, help="Enable extra logging.")
|
||||
@click.option(
|
||||
"--origin-rev", "origin_rev", default=None,
|
||||
help="The revision in the mongo repo that changes will be compared against if specified.")
|
||||
@click.option("--install-dir", "install_dir", type=str,
|
||||
help="Path to bin directory of a testable installation")
|
||||
@click.option("--evg-project-file", "evg_project_file", default=DEFAULT_EVG_PROJECT_FILE,
|
||||
help="Evergreen project config file")
|
||||
"--no-exec",
|
||||
"no_exec",
|
||||
default=False,
|
||||
is_flag=True,
|
||||
help="Do not execute the found tests.",
|
||||
)
|
||||
@click.option(
|
||||
"--build-variant",
|
||||
"build_variant",
|
||||
default=DEFAULT_VARIANT,
|
||||
metavar="BUILD_VARIANT",
|
||||
help="Tasks to run will be selected from this build variant.",
|
||||
)
|
||||
@click.option(
|
||||
"--repeat-tests",
|
||||
"repeat_tests_num",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Number of times to repeat tests.",
|
||||
)
|
||||
@click.option(
|
||||
"--repeat-tests-min",
|
||||
"repeat_tests_min",
|
||||
default=None,
|
||||
type=int,
|
||||
help="The minimum number of times to repeat tests if time option is specified.",
|
||||
)
|
||||
@click.option(
|
||||
"--repeat-tests-max",
|
||||
"repeat_tests_max",
|
||||
default=None,
|
||||
type=int,
|
||||
help="The maximum number of times to repeat tests if time option is specified.",
|
||||
)
|
||||
@click.option(
|
||||
"--repeat-tests-secs",
|
||||
"repeat_tests_secs",
|
||||
default=None,
|
||||
type=int,
|
||||
metavar="SECONDS",
|
||||
help="Repeat tests for the given time (in secs).",
|
||||
)
|
||||
@click.option(
|
||||
"--yaml",
|
||||
"use_yaml",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Output discovered tasks in YAML. Tests will not be run.",
|
||||
)
|
||||
@click.option(
|
||||
"--verbose", "verbose", default=False, is_flag=True, help="Enable extra logging."
|
||||
)
|
||||
@click.option(
|
||||
"--origin-rev",
|
||||
"origin_rev",
|
||||
default=None,
|
||||
help="The revision in the mongo repo that changes will be compared against if specified.",
|
||||
)
|
||||
@click.option(
|
||||
"--install-dir",
|
||||
"install_dir",
|
||||
type=str,
|
||||
help="Path to bin directory of a testable installation",
|
||||
)
|
||||
@click.option(
|
||||
"--evg-project-file",
|
||||
"evg_project_file",
|
||||
default=DEFAULT_EVG_PROJECT_FILE,
|
||||
help="Evergreen project config file",
|
||||
)
|
||||
@click.argument("resmoke_args", nargs=-1, type=click.UNPROCESSED)
|
||||
def run(build_variant: str, no_exec: bool, repeat_tests_num: Optional[int],
|
||||
repeat_tests_min: Optional[int], repeat_tests_max: Optional[int],
|
||||
repeat_tests_secs: Optional[int], resmoke_args: str, verbose: bool,
|
||||
origin_rev: Optional[str], install_dir: Optional[str], use_yaml: bool,
|
||||
evg_project_file: Optional[str]) -> None:
|
||||
def run(
|
||||
build_variant: str,
|
||||
no_exec: bool,
|
||||
repeat_tests_num: Optional[int],
|
||||
repeat_tests_min: Optional[int],
|
||||
repeat_tests_max: Optional[int],
|
||||
repeat_tests_secs: Optional[int],
|
||||
resmoke_args: str,
|
||||
verbose: bool,
|
||||
origin_rev: Optional[str],
|
||||
install_dir: Optional[str],
|
||||
use_yaml: bool,
|
||||
evg_project_file: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Run new or changed tests in repeated mode to validate their stability.
|
||||
|
||||
@ -719,7 +836,9 @@ def run(build_variant: str, no_exec: bool, repeat_tests_num: Optional[int],
|
||||
elif no_exec:
|
||||
executor = NopBurnInExecutor()
|
||||
|
||||
burn_in_orchestrator = BurnInOrchestrator(change_detector, executor, evg_conf, install_dir)
|
||||
burn_in_orchestrator = BurnInOrchestrator(
|
||||
change_detector, executor, evg_conf, install_dir
|
||||
)
|
||||
burn_in_orchestrator.burn_in(repos, build_variant)
|
||||
|
||||
|
||||
@ -742,7 +861,8 @@ def generate_test_membership_map_file_for_ci():
|
||||
with open(BURN_IN_TEST_MEMBERSHIP_FILE, "w") as file:
|
||||
json.dump(test_membership, file)
|
||||
LOGGER.info(
|
||||
f"Finished writing burn_in test membership mapping to {BURN_IN_TEST_MEMBERSHIP_FILE}")
|
||||
f"Finished writing burn_in test membership mapping to {BURN_IN_TEST_MEMBERSHIP_FILE}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -42,14 +42,23 @@ def main():
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument('-o', nargs='?', type=argparse.FileType('w'), default=sys.stdout,
|
||||
help='output file (default sys.stdout)')
|
||||
parser.add_argument('template_file', help='Cheetah template file')
|
||||
parser.add_argument('template_arg', nargs='*', default=[], help='Cheetah template args')
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
nargs="?",
|
||||
type=argparse.FileType("w"),
|
||||
default=sys.stdout,
|
||||
help="output file (default sys.stdout)",
|
||||
)
|
||||
parser.add_argument("template_file", help="Cheetah template file")
|
||||
parser.add_argument(
|
||||
"template_arg", nargs="*", default=[], help="Cheetah template args"
|
||||
)
|
||||
opts = parser.parse_args()
|
||||
|
||||
opts.o.write(str(Template(file=opts.template_file, namespaces=[{'args': opts.template_arg}])))
|
||||
opts.o.write(
|
||||
str(Template(file=opts.template_file, namespaces=[{"args": opts.template_arg}]))
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -3,6 +3,7 @@
|
||||
The API also provides methods to access specific fields present in the mongodb/mongo
|
||||
configuration file.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
@ -33,15 +34,19 @@ def parse_evergreen_file(path, evergreen_binary="evergreen"):
|
||||
|
||||
prev_environ = os.environ.copy()
|
||||
if sys.platform in ("win32", "cygwin"):
|
||||
LOGGER.info(f"Previous os.environ={os.environ} before updating 'USERPROFILE'")
|
||||
if 'HOME' in os.environ:
|
||||
os.environ['USERPROFILE'] = os.environ['HOME']
|
||||
LOGGER.info(
|
||||
f"Previous os.environ={os.environ} before updating 'USERPROFILE'"
|
||||
)
|
||||
if "HOME" in os.environ:
|
||||
os.environ["USERPROFILE"] = os.environ["HOME"]
|
||||
else:
|
||||
LOGGER.warn(
|
||||
"'HOME' enviorment variable unset. This will likely cause us to be unable to find evergreen binary."
|
||||
)
|
||||
|
||||
default_evergreen_location = os.path.expanduser(os.path.join("~", "evergreen"))
|
||||
default_evergreen_location = os.path.expanduser(
|
||||
os.path.join("~", "evergreen")
|
||||
)
|
||||
|
||||
# Restore enviorment if it was modified above on windows
|
||||
os.environ.clear()
|
||||
@ -80,9 +85,12 @@ class EvergreenProjectConfig(object):
|
||||
self.tasks = [Task(task_dict) for task_dict in self._conf["tasks"]]
|
||||
self._tasks_by_name = {task.name: task for task in self.tasks}
|
||||
self.task_groups = [
|
||||
TaskGroup(task_group_dict) for task_group_dict in self._conf.get("task_groups", [])
|
||||
TaskGroup(task_group_dict)
|
||||
for task_group_dict in self._conf.get("task_groups", [])
|
||||
]
|
||||
self._task_groups_by_name = {task_group.name: task_group for task_group in self.task_groups}
|
||||
self._task_groups_by_name = {
|
||||
task_group.name: task_group for task_group in self.task_groups
|
||||
}
|
||||
self.variants = [
|
||||
Variant(variant_dict, self._tasks_by_name, self._task_groups_by_name)
|
||||
for variant_dict in self._conf["buildvariants"]
|
||||
@ -214,12 +222,17 @@ class Task(object):
|
||||
|
||||
if self.is_run_tests_task:
|
||||
return [command_vars.get("suite", self.name)]
|
||||
if self.is_generate_resmoke_task and not self.is_initialize_multiversion_tasks_task:
|
||||
if (
|
||||
self.is_generate_resmoke_task
|
||||
and not self.is_initialize_multiversion_tasks_task
|
||||
):
|
||||
return [command_vars.get("suite", self.generated_task_name)]
|
||||
if self.is_initialize_multiversion_tasks_task:
|
||||
return [
|
||||
suite
|
||||
for suite in self.initialize_multiversion_tasks_command.get("vars", {}).keys()
|
||||
for suite in self.initialize_multiversion_tasks_command.get(
|
||||
"vars", {}
|
||||
).keys()
|
||||
]
|
||||
|
||||
raise ValueError(f"{self.name} task does not run a resmoke.py test suite")
|
||||
@ -286,10 +299,18 @@ class Variant(object):
|
||||
# A task in conf_dict may be a task_group, containing a list of tasks.
|
||||
for task_in_group in task_group_map.get(task_name).tasks:
|
||||
self.tasks.append(
|
||||
VariantTask(task_map.get(task_in_group), task.get("distros", run_on), self))
|
||||
VariantTask(
|
||||
task_map.get(task_in_group),
|
||||
task.get("distros", run_on),
|
||||
self,
|
||||
)
|
||||
)
|
||||
else:
|
||||
self.tasks.append(
|
||||
VariantTask(task_map.get(task["name"]), task.get("distros", run_on), self))
|
||||
VariantTask(
|
||||
task_map.get(task["name"]), task.get("distros", run_on), self
|
||||
)
|
||||
)
|
||||
self.distro_names = set(run_on)
|
||||
for task in self.tasks:
|
||||
self.distro_names.update(task.run_on)
|
||||
|
||||
@ -67,7 +67,9 @@ def get_half_cpu_mask():
|
||||
|
||||
def count_running_clang_tidy(cmd_path):
|
||||
try:
|
||||
output = subprocess.check_output(["ps", "-axww", "-o", "pid=,command="], text=True)
|
||||
output = subprocess.check_output(
|
||||
["ps", "-axww", "-o", "pid=,command="], text=True
|
||||
)
|
||||
return sum(1 for line in output.splitlines() if cmd_path in line)
|
||||
except Exception as e:
|
||||
print(f"WARNING: failed to check running clang-tidy processes: {e}")
|
||||
@ -110,7 +112,9 @@ def main():
|
||||
if (
|
||||
(arg.endswith(".cpp") or arg.endswith(".h"))
|
||||
and rel.startswith("src/mongo")
|
||||
and not rel.startswith("src/mongo/db/modules/enterprise/src/streams/third_party")
|
||||
and not rel.startswith(
|
||||
"src/mongo/db/modules/enterprise/src/streams/third_party"
|
||||
)
|
||||
):
|
||||
files_to_check.append(rel)
|
||||
else:
|
||||
@ -125,7 +129,9 @@ def main():
|
||||
)
|
||||
|
||||
if not os.path.exists("compile_commands.json"):
|
||||
print("ERROR: failed to find compile_commands.json, run 'bazel build compiledb'")
|
||||
print(
|
||||
"ERROR: failed to find compile_commands.json, run 'bazel build compiledb'"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
with open("compile_commands.json") as f:
|
||||
@ -168,7 +174,9 @@ def main():
|
||||
|
||||
# probably a header, skip caching and let clang-tidy do its thing:
|
||||
else:
|
||||
proc = subprocess.run(clang_tidy_cmd + files_to_check + other_args, capture_output=True)
|
||||
proc = subprocess.run(
|
||||
clang_tidy_cmd + files_to_check + other_args, capture_output=True
|
||||
)
|
||||
|
||||
sys.stdout.buffer.write(proc.stdout)
|
||||
sys.stderr.buffer.write(proc.stderr)
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Functions for working with github."""
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
@ -38,7 +39,7 @@ class GithubApi(object):
|
||||
links = response.headers["Link"].split(",")
|
||||
for link in links:
|
||||
link_parts = link.split(";")
|
||||
link_type = link_parts[1].replace("rel=", "").strip(" \"")
|
||||
link_type = link_parts[1].replace("rel=", "").strip(' "')
|
||||
link_address = link_parts[0].strip("<> ")
|
||||
link_object[link_type] = link_address
|
||||
|
||||
@ -46,8 +47,9 @@ class GithubApi(object):
|
||||
|
||||
def get_commits(self, owner, project, params):
|
||||
"""Get the list of commits from a specified repository from github."""
|
||||
url = "{api_server}/repos/{owner}/{project}/commits".format(api_server=self.api_server,
|
||||
owner=owner, project=project)
|
||||
url = "{api_server}/repos/{owner}/{project}/commits".format(
|
||||
api_server=self.api_server, owner=owner, project=project
|
||||
)
|
||||
|
||||
LOGGER.debug("get_commits project=%s/%s, params: %s", owner, project, params)
|
||||
response = self._make_request(url, params)
|
||||
@ -61,7 +63,11 @@ class GithubApi(object):
|
||||
|
||||
links = self._parse_link(response)
|
||||
|
||||
LOGGER.debug("Commits from github (count=%d): [%s - %s]", len(commits), commits[-1]["sha"],
|
||||
commits[0]["sha"])
|
||||
LOGGER.debug(
|
||||
"Commits from github (count=%d): [%s - %s]",
|
||||
len(commits),
|
||||
commits[-1]["sha"],
|
||||
commits[0]["sha"],
|
||||
)
|
||||
|
||||
return commits
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Module to access a JIRA server."""
|
||||
|
||||
from enum import Enum
|
||||
|
||||
import jira
|
||||
|
||||
@ -22,14 +22,23 @@ def main():
|
||||
usage = "usage: %prog [options]"
|
||||
parser = optparse.OptionParser(description=__doc__, usage=usage)
|
||||
parser.add_option(
|
||||
"-i", "--interval", dest="interval", default=5, type="int",
|
||||
"-i",
|
||||
"--interval",
|
||||
dest="interval",
|
||||
default=5,
|
||||
type="int",
|
||||
help="Collect system resource information every <interval> seconds. "
|
||||
"Default is every 5 seconds.")
|
||||
"Default is every 5 seconds.",
|
||||
)
|
||||
parser.add_option(
|
||||
"-o", "--output-file", dest="outfile", default="-",
|
||||
"-o",
|
||||
"--output-file",
|
||||
dest="outfile",
|
||||
default="-",
|
||||
help="If '-', then the file is written to stdout."
|
||||
" Any other value is treated as the output file name. By default,"
|
||||
" output is written to stdout.")
|
||||
" output is written to stdout.",
|
||||
)
|
||||
|
||||
(options, _) = parser.parse_args()
|
||||
|
||||
@ -40,8 +49,11 @@ def main():
|
||||
response = requests.get("http://localhost:2285/status")
|
||||
if response.status_code != requests.codes.ok:
|
||||
print(
|
||||
"Received a {} HTTP response: {}".format(response.status_code,
|
||||
response.text), file=sys.stderr)
|
||||
"Received a {} HTTP response: {}".format(
|
||||
response.status_code, response.text
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
time.sleep(options.interval)
|
||||
continue
|
||||
|
||||
@ -49,8 +61,12 @@ def main():
|
||||
try:
|
||||
res_json = response.json()
|
||||
except ValueError:
|
||||
print("Invalid JSON object returned with response: {}".format(response.text),
|
||||
file=sys.stderr)
|
||||
print(
|
||||
"Invalid JSON object returned with response: {}".format(
|
||||
response.text
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
time.sleep(options.interval)
|
||||
continue
|
||||
|
||||
|
||||
@ -53,13 +53,24 @@ def main():
|
||||
usage = "usage: %prog [options] report1.json report2.json ..."
|
||||
parser = OptionParser(description=__doc__, usage=usage)
|
||||
parser.add_option(
|
||||
"-o", "--output-file", dest="outfile", default="-",
|
||||
help=("If '-', then the combined report file is written to stdout."
|
||||
" Any other value is treated as the output file name. By default,"
|
||||
" output is written to stdout."))
|
||||
parser.add_option("-x", "--no-report-exit", dest="report_exit", default=True,
|
||||
action="store_false",
|
||||
help="Do not exit with a non-zero code if any test in the report fails.")
|
||||
"-o",
|
||||
"--output-file",
|
||||
dest="outfile",
|
||||
default="-",
|
||||
help=(
|
||||
"If '-', then the combined report file is written to stdout."
|
||||
" Any other value is treated as the output file name. By default,"
|
||||
" output is written to stdout."
|
||||
),
|
||||
)
|
||||
parser.add_option(
|
||||
"-x",
|
||||
"--no-report-exit",
|
||||
dest="report_exit",
|
||||
default=True,
|
||||
action="store_false",
|
||||
help="Do not exit with a non-zero code if any test in the report fails.",
|
||||
)
|
||||
|
||||
(options, args) = parser.parse_args()
|
||||
|
||||
|
||||
@ -16,13 +16,13 @@ from enum import Enum
|
||||
|
||||
import yaml
|
||||
|
||||
_COMPARE_FIELDS_SERVER_PARAMETERS = ['default', 'set_at', 'validator', 'test_only']
|
||||
_COMPARE_FIELDS_CONFIGS = ['arg_vartype', 'requires', 'hidden', 'redact']
|
||||
_COMPARE_FIELDS_SERVER_PARAMETERS = ["default", "set_at", "validator", "test_only"]
|
||||
_COMPARE_FIELDS_CONFIGS = ["arg_vartype", "requires", "hidden", "redact"]
|
||||
|
||||
|
||||
class ComparisonType(str, Enum):
|
||||
CONFIGS = 'configs'
|
||||
SERVER_PARAMETERS = 'server_parameters'
|
||||
CONFIGS = "configs"
|
||||
SERVER_PARAMETERS = "server_parameters"
|
||||
|
||||
|
||||
class PropertyDiff:
|
||||
@ -44,7 +44,8 @@ def build_diff_fn(compare_fields: list) -> callable:
|
||||
for field in compare_fields:
|
||||
if prop_base.get(field) != prop_inc.get(field):
|
||||
change_diffs[field] = PropertyDiff(
|
||||
str(prop_base.get(field, "")), str(prop_inc.get(field, "")))
|
||||
str(prop_base.get(field, "")), str(prop_inc.get(field, ""))
|
||||
)
|
||||
return change_diffs
|
||||
|
||||
return diff_fn
|
||||
@ -76,13 +77,17 @@ class ComputeDiffsFromIncrementedVersionHandler:
|
||||
requires knowledge of a base dictionary of properties (base_properties) to execute.
|
||||
"""
|
||||
|
||||
def __init__(self, handler_type: ComparisonType, base_properties: dict, calc_diff_fn: callable):
|
||||
def __init__(
|
||||
self,
|
||||
handler_type: ComparisonType,
|
||||
base_properties: dict,
|
||||
calc_diff_fn: callable,
|
||||
):
|
||||
self.calc_diff = calc_diff_fn
|
||||
self.handler_type = handler_type
|
||||
self.properties_diff = PropertiesDiffs(base_properties, {}, {})
|
||||
|
||||
def _compare_and_partition(self, yaml_props: dict, yaml_file_name: str) -> None:
|
||||
|
||||
for yaml_key, yaml_val in yaml_props.items():
|
||||
compare_key = (yaml_key, yaml_file_name)
|
||||
|
||||
@ -124,17 +129,20 @@ def load_yaml(dirs: list, exclusions: list, idl_yaml_handlers: list) -> None:
|
||||
break
|
||||
|
||||
for name in filenames:
|
||||
if not name.endswith('.idl'):
|
||||
if not name.endswith(".idl"):
|
||||
continue
|
||||
|
||||
with io.open(os.path.join(dirpath, name), 'r', encoding='utf-8') as idl_yaml_stream:
|
||||
with io.open(
|
||||
os.path.join(dirpath, name), "r", encoding="utf-8"
|
||||
) as idl_yaml_stream:
|
||||
idl_yaml = yaml.safe_load(idl_yaml_stream)
|
||||
for handler in idl_yaml_handlers:
|
||||
handler.handle(idl_yaml, name)
|
||||
|
||||
|
||||
def get_properties_diffs(mode: ComparisonType, base_version_dirs: list, inc_version_dirs: list,
|
||||
exclude: list) -> PropertiesDiffs:
|
||||
def get_properties_diffs(
|
||||
mode: ComparisonType, base_version_dirs: list, inc_version_dirs: list, exclude: list
|
||||
) -> PropertiesDiffs:
|
||||
"""Returns a PropertiesDiffs object containing the changes between properties in base_version_dirs and inc_version_dirs."""
|
||||
|
||||
compare_fields = []
|
||||
@ -143,22 +151,22 @@ def get_properties_diffs(mode: ComparisonType, base_version_dirs: list, inc_vers
|
||||
elif mode == ComparisonType.CONFIGS:
|
||||
compare_fields = _COMPARE_FIELDS_CONFIGS
|
||||
else:
|
||||
raise Exception(f'Unknown option {mode}')
|
||||
raise Exception(f"Unknown option {mode}")
|
||||
|
||||
diff_fn = build_diff_fn(compare_fields)
|
||||
|
||||
base_handler = BuildBasePropertiesForComparisonHandler(mode)
|
||||
load_yaml(base_version_dirs, exclude, [base_handler])
|
||||
|
||||
increment_handler = ComputeDiffsFromIncrementedVersionHandler(mode, base_handler.properties,
|
||||
diff_fn)
|
||||
increment_handler = ComputeDiffsFromIncrementedVersionHandler(
|
||||
mode, base_handler.properties, diff_fn
|
||||
)
|
||||
load_yaml(inc_version_dirs, exclude, [increment_handler])
|
||||
|
||||
return increment_handler.properties_diff
|
||||
|
||||
|
||||
def output_diffs(mode: ComparisonType, diff: PropertiesDiffs) -> None:
|
||||
|
||||
pp = pprint.PrettyPrinter()
|
||||
|
||||
mode_format = ""
|
||||
@ -167,56 +175,67 @@ def output_diffs(mode: ComparisonType, diff: PropertiesDiffs) -> None:
|
||||
elif mode == ComparisonType.SERVER_PARAMETERS:
|
||||
mode_format = "server parameter"
|
||||
else:
|
||||
raise Exception(f'Unknown option {mode}')
|
||||
raise Exception(f"Unknown option {mode}")
|
||||
|
||||
for sp, val in diff.added.items():
|
||||
if not val.get('test_only'):
|
||||
print(f'Added {mode_format} {str(sp)}')
|
||||
if not val.get("test_only"):
|
||||
print(f"Added {mode_format} {str(sp)}")
|
||||
pp.pprint(val)
|
||||
print()
|
||||
|
||||
for sp, val in diff.removed.items():
|
||||
if not val.get('test_only'):
|
||||
print(f'Removed {mode_format} {str(sp)}')
|
||||
if not val.get("test_only"):
|
||||
print(f"Removed {mode_format} {str(sp)}")
|
||||
pp.pprint(val)
|
||||
print()
|
||||
|
||||
for sp, val in diff.modified.items():
|
||||
if not val.get('test_only'):
|
||||
print(f'Modified {mode_format} {str(sp)}')
|
||||
if not val.get("test_only"):
|
||||
print(f"Modified {mode_format} {str(sp)}")
|
||||
for property_name, delta in val.items():
|
||||
print(f'<{property_name}> changed from [{delta.base}] to [{delta.inc}]')
|
||||
print(f"<{property_name}> changed from [{delta.base}] to [{delta.inc}]")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
arg_parser = argparse.ArgumentParser(prog="Core Server IDL Parameter/Config Diff")
|
||||
|
||||
arg_parser.add_argument(
|
||||
'mode', choices=[ComparisonType.SERVER_PARAMETERS.value, ComparisonType.CONFIGS.value])
|
||||
"mode",
|
||||
choices=[ComparisonType.SERVER_PARAMETERS.value, ComparisonType.CONFIGS.value],
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
'-b', '--base_version_dirs',
|
||||
help='A colon-separated list of paths to the base version for comparison', required=True)
|
||||
"-b",
|
||||
"--base_version_dirs",
|
||||
help="A colon-separated list of paths to the base version for comparison",
|
||||
required=True,
|
||||
)
|
||||
|
||||
arg_parser.add_argument(
|
||||
'-i', '--incremented_version_dirs',
|
||||
help='A colon-separated list of paths to the incremented version for comparison',
|
||||
required=True)
|
||||
"-i",
|
||||
"--incremented_version_dirs",
|
||||
help="A colon-separated list of paths to the incremented version for comparison",
|
||||
required=True,
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
'-e', '--exclude_dirs',
|
||||
help='A colon-separated list of directory path strings to exclude from comparison, ' +
|
||||
'e.g. a path /foo/bar/dir will be excluded by an argument of any of foo/bar/dir, bar/dir,' +
|
||||
'foo, or bar, or dir ', required=False)
|
||||
"-e",
|
||||
"--exclude_dirs",
|
||||
help="A colon-separated list of directory path strings to exclude from comparison, "
|
||||
+ "e.g. a path /foo/bar/dir will be excluded by an argument of any of foo/bar/dir, bar/dir,"
|
||||
+ "foo, or bar, or dir ",
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
incremented_version_dirs = str.split(args.incremented_version_dirs, ':')
|
||||
base_version_dirs = str.split(args.base_version_dirs, ':')
|
||||
exclude = set(args.exclude_dirs.split(':')) if args.exclude_dirs else set()
|
||||
incremented_version_dirs = str.split(args.incremented_version_dirs, ":")
|
||||
base_version_dirs = str.split(args.base_version_dirs, ":")
|
||||
exclude = set(args.exclude_dirs.split(":")) if args.exclude_dirs else set()
|
||||
mode = ComparisonType(args.mode)
|
||||
|
||||
diffs = get_properties_diffs(mode, base_version_dirs, incremented_version_dirs, exclude)
|
||||
diffs = get_properties_diffs(
|
||||
mode, base_version_dirs, incremented_version_dirs, exclude
|
||||
)
|
||||
output_diffs(mode, diffs)
|
||||
|
||||
|
||||
@ -254,19 +273,25 @@ class TestBuildBasePropertiesForComparisonHandler(unittest.TestCase):
|
||||
"""
|
||||
yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.SERVER_PARAMETERS)
|
||||
fixture = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.SERVER_PARAMETERS
|
||||
)
|
||||
fixture.handle(yaml_obj, filename)
|
||||
|
||||
#should filter out configs, but parse server parameters
|
||||
self.assertIsNone(fixture.properties.get(("net.compression.compressors", filename)))
|
||||
# should filter out configs, but parse server parameters
|
||||
self.assertIsNone(
|
||||
fixture.properties.get(("net.compression.compressors", filename))
|
||||
)
|
||||
self.assertIsNotNone(fixture.properties[("changeStreamOptions", filename)])
|
||||
|
||||
fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.CONFIGS)
|
||||
fixture.handle(yaml_obj, filename)
|
||||
|
||||
#should filter out server parameters, but parse configs
|
||||
# should filter out server parameters, but parse configs
|
||||
self.assertIsNone(fixture.properties.get(("changeStreamOptions", filename)))
|
||||
self.assertIsNotNone(fixture.properties.get(("net.compression.compressors", filename)))
|
||||
self.assertIsNotNone(
|
||||
fixture.properties.get(("net.compression.compressors", filename))
|
||||
)
|
||||
|
||||
def test_empty_yaml_obj_does_nothing(self):
|
||||
filename = "test.yml"
|
||||
@ -277,7 +302,9 @@ class TestBuildBasePropertiesForComparisonHandler(unittest.TestCase):
|
||||
|
||||
yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.SERVER_PARAMETERS)
|
||||
fixture = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.SERVER_PARAMETERS
|
||||
)
|
||||
fixture.handle(yaml_obj, filename)
|
||||
self.assertTrue(len(fixture.properties) == 0)
|
||||
|
||||
@ -328,8 +355,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
|
||||
inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {},
|
||||
self.config_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, {}, self.config_diff_function
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -339,8 +367,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
self.assertIsNone(properties_diffs.added.get(("testOptions", filename)))
|
||||
self.assertIsNone(properties_diffs.added.get(("helloMorld", filename)))
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS,
|
||||
{}, self.parameter_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS, {}, self.parameter_diff_function
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -375,8 +404,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
|
||||
inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {},
|
||||
self.config_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, {}, self.config_diff_function
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -386,8 +416,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
self.assertTrue(len(properties_diffs.removed) == 0)
|
||||
self.assertTrue(len(properties_diffs.modified) == 0)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS,
|
||||
{}, self.parameter_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS, {}, self.parameter_diff_function
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -404,13 +435,16 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def get_base_data():
|
||||
return {("ok", "test.yaml"): {"yes": "no"}, ("also_ok", "blah.yaml"): {"no": "yes"}}
|
||||
return {
|
||||
("ok", "test.yaml"): {"yes": "no"},
|
||||
("also_ok", "blah.yaml"): {"no": "yes"},
|
||||
}
|
||||
|
||||
inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS,
|
||||
get_base_data(),
|
||||
self.config_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, get_base_data(), self.config_diff_function
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -422,9 +456,11 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
self.assertTrue(len(properties_diffs.added) == 0)
|
||||
self.assertTrue(len(properties_diffs.modified) == 0)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS,
|
||||
get_base_data(),
|
||||
self.parameter_diff_function)
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS,
|
||||
get_base_data(),
|
||||
self.parameter_diff_function,
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -475,8 +511,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
"""
|
||||
inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {},
|
||||
build_diff_fn(['default']))
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, {}, build_diff_fn(["default"])
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -485,8 +522,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
self.assertTrue(len(properties_diffs.removed) == 0)
|
||||
self.assertTrue(len(properties_diffs.modified) == 0)
|
||||
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS,
|
||||
{}, build_diff_fn(['set_at']))
|
||||
inc_fixture = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS, {}, build_diff_fn(["set_at"])
|
||||
)
|
||||
inc_fixture.handle(inc_yaml_obj, filename)
|
||||
|
||||
properties_diffs = inc_fixture.properties_diff
|
||||
@ -541,11 +579,13 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
diff_fn = build_diff_fn(_COMPARE_FIELDS_CONFIGS)
|
||||
|
||||
config_base_properties_handler = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.CONFIGS)
|
||||
ComparisonType.CONFIGS
|
||||
)
|
||||
config_base_properties_handler.handle(document_yaml, filename)
|
||||
|
||||
config_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn)
|
||||
ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn
|
||||
)
|
||||
config_inc_properties_handler.handle(document_inc_yaml, filename)
|
||||
|
||||
property_diff = config_inc_properties_handler.properties_diff
|
||||
@ -555,11 +595,15 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
diff_fn = build_diff_fn(_COMPARE_FIELDS_SERVER_PARAMETERS)
|
||||
|
||||
sp_base_properties_handler = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.SERVER_PARAMETERS)
|
||||
ComparisonType.SERVER_PARAMETERS
|
||||
)
|
||||
sp_base_properties_handler.handle(document_yaml, filename)
|
||||
|
||||
sp_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS, sp_base_properties_handler.properties, diff_fn)
|
||||
ComparisonType.SERVER_PARAMETERS,
|
||||
sp_base_properties_handler.properties,
|
||||
diff_fn,
|
||||
)
|
||||
sp_inc_properties_handler.handle(document_inc_yaml, filename)
|
||||
|
||||
property_diff = sp_inc_properties_handler.properties_diff
|
||||
@ -645,37 +689,50 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase):
|
||||
diff_fn = build_diff_fn(_COMPARE_FIELDS_CONFIGS)
|
||||
|
||||
config_base_properties_handler = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.CONFIGS)
|
||||
ComparisonType.CONFIGS
|
||||
)
|
||||
config_base_properties_handler.handle(document_yaml, filename)
|
||||
|
||||
config_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn)
|
||||
ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn
|
||||
)
|
||||
config_inc_properties_handler.handle(document_inc_yaml, filename)
|
||||
|
||||
property_diff = config_inc_properties_handler.properties_diff
|
||||
|
||||
self.assertEqual(
|
||||
property_diff.modified.get(("asdf", filename)).get("arg_vartype").base, "String")
|
||||
property_diff.modified.get(("asdf", filename)).get("arg_vartype").base,
|
||||
"String",
|
||||
)
|
||||
self.assertEqual(
|
||||
property_diff.modified.get(("asdf", filename)).get("arg_vartype").inc, "int")
|
||||
property_diff.modified.get(("asdf", filename)).get("arg_vartype").inc, "int"
|
||||
)
|
||||
self.assertIsNone(property_diff.modified.get(("zxcv", filename)))
|
||||
|
||||
diff_fn = build_diff_fn(_COMPARE_FIELDS_SERVER_PARAMETERS)
|
||||
|
||||
sp_base_properties_handler = BuildBasePropertiesForComparisonHandler(
|
||||
ComparisonType.SERVER_PARAMETERS)
|
||||
ComparisonType.SERVER_PARAMETERS
|
||||
)
|
||||
sp_base_properties_handler.handle(document_yaml, filename)
|
||||
|
||||
sp_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler(
|
||||
ComparisonType.SERVER_PARAMETERS, sp_base_properties_handler.properties, diff_fn)
|
||||
ComparisonType.SERVER_PARAMETERS,
|
||||
sp_base_properties_handler.properties,
|
||||
diff_fn,
|
||||
)
|
||||
sp_inc_properties_handler.handle(document_inc_yaml, filename)
|
||||
|
||||
property_diff = sp_inc_properties_handler.properties_diff
|
||||
|
||||
self.assertEqual(
|
||||
property_diff.modified.get(("testOptions", filename)).get("set_at").base, "cluster")
|
||||
property_diff.modified.get(("testOptions", filename)).get("set_at").base,
|
||||
"cluster",
|
||||
)
|
||||
self.assertEqual(
|
||||
property_diff.modified.get(("testOptions", filename)).get("set_at").inc, "runtime")
|
||||
property_diff.modified.get(("testOptions", filename)).get("set_at").inc,
|
||||
"runtime",
|
||||
)
|
||||
self.assertIsNone(property_diff.modified.get(("testParameter", filename)))
|
||||
|
||||
|
||||
|
||||
@ -37,7 +37,7 @@ from cost_estimator import estimate
|
||||
from database_instance import DatabaseInstance
|
||||
from sklearn.linear_model import LinearRegression
|
||||
|
||||
__all__ = ['calibrate']
|
||||
__all__ = ["calibrate"]
|
||||
|
||||
|
||||
async def calibrate(config: AbtCalibratorConfig, database: DatabaseInstance):
|
||||
@ -55,18 +55,21 @@ async def calibrate(config: AbtCalibratorConfig, database: DatabaseInstance):
|
||||
return result
|
||||
|
||||
|
||||
def calibrate_node(abt_df: pd.DataFrame, config: AbtCalibratorConfig,
|
||||
node_config: AbtNodeCalibrationConfig):
|
||||
def calibrate_node(
|
||||
abt_df: pd.DataFrame,
|
||||
config: AbtCalibratorConfig,
|
||||
node_config: AbtNodeCalibrationConfig,
|
||||
):
|
||||
abt_node_df = abt_df[abt_df.abt_type == node_config.type]
|
||||
if node_config.filter_function is not None:
|
||||
abt_node_df = node_config.filter_function(abt_node_df)
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
if node_config.variables_override is None:
|
||||
variables = ['n_processed']
|
||||
variables = ["n_processed"]
|
||||
else:
|
||||
variables = node_config.variables_override
|
||||
y = abt_node_df['execution_time']
|
||||
y = abt_node_df["execution_time"]
|
||||
X = abt_node_df[variables]
|
||||
|
||||
X = sm.add_constant(X)
|
||||
|
||||
@ -87,8 +87,8 @@ class CostModelCoefficients:
|
||||
|
||||
def to_camel_case(string):
|
||||
"""Convert a snake_case string to camelCase one."""
|
||||
words = string.split('_')
|
||||
return words[0] + ''.join(w.capitalize() for w in words[1:])
|
||||
words = string.split("_")
|
||||
return words[0] + "".join(w.capitalize() for w in words[1:])
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -109,12 +109,12 @@ class BenchmarkTask:
|
||||
|
||||
def print(self):
|
||||
"""Prints the task."""
|
||||
print('Cost Model Coefficients Overrides:')
|
||||
print(f'\tA: {self.cost_model_a.to_dict()}')
|
||||
print(f'\tB: {self.cost_model_b.to_dict()}')
|
||||
print(f'threshold: {self.threshold}')
|
||||
print(f'collection: {self.collection_name}')
|
||||
print(f'pipeline: {self.pipeline}')
|
||||
print("Cost Model Coefficients Overrides:")
|
||||
print(f"\tA: {self.cost_model_a.to_dict()}")
|
||||
print(f"\tB: {self.cost_model_b.to_dict()}")
|
||||
print(f"threshold: {self.threshold}")
|
||||
print(f"collection: {self.collection_name}")
|
||||
print(f"pipeline: {self.pipeline}")
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -131,20 +131,20 @@ class BenchmarkResult:
|
||||
|
||||
def print(self):
|
||||
"""Print the results."""
|
||||
print('## Benchmark Task')
|
||||
print("## Benchmark Task")
|
||||
self.task.print()
|
||||
print('\n## Result')
|
||||
print(f'Means: A: {self.variant_a.mean:,.2f}, B: {self.variant_b.mean:,.2f}.')
|
||||
print("\n## Result")
|
||||
print(f"Means: A: {self.variant_a.mean:,.2f}, B: {self.variant_b.mean:,.2f}.")
|
||||
print(f"t-test's p-value: {self.pvalue}.")
|
||||
if self.pvalue < self.task.threshold:
|
||||
print('The means are significantly different.')
|
||||
print("The means are significantly different.")
|
||||
else:
|
||||
print('The means are not significantly different.')
|
||||
print("The means are not significantly different.")
|
||||
|
||||
print('\n### A\n')
|
||||
print("\n### A\n")
|
||||
self.variant_a.print()
|
||||
|
||||
print('\n### B\n')
|
||||
print("\n### B\n")
|
||||
self.variant_b.print()
|
||||
|
||||
|
||||
@ -162,13 +162,15 @@ class ExperimentResult:
|
||||
if index is None:
|
||||
index = len(self.explain) // 2
|
||||
|
||||
print('ABT Physical Tree')
|
||||
print("ABT Physical Tree")
|
||||
self.physical_tree[index].print()
|
||||
print('\nSBE Execution Tree')
|
||||
print("\nSBE Execution Tree")
|
||||
self.execution_tree[index].print()
|
||||
|
||||
|
||||
async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: BenchmarkTask):
|
||||
async def benchmark(
|
||||
config: BenchmarkConfig, database: DatabaseInstance, task: BenchmarkTask
|
||||
):
|
||||
"""Run the A/B performance task.
|
||||
|
||||
It executes the given pipeline for both overrides of Cost Model Coefficients,
|
||||
@ -177,7 +179,9 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B
|
||||
value (usually 0.05 or 0.01) we can say that the Null hypothesis is proven and there is
|
||||
no significant difference in the execution times.
|
||||
"""
|
||||
async with get_database_parameter(database, 'internalCostModelCoefficients') as db_param:
|
||||
async with get_database_parameter(
|
||||
database, "internalCostModelCoefficients"
|
||||
) as db_param:
|
||||
await db_param.set(json.dumps(task.cost_model_a.to_dict()))
|
||||
result_a = await run(config, database, task.collection_name, task.pipeline)
|
||||
|
||||
@ -190,22 +194,34 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B
|
||||
execution_times_a = [et.total_execution_time for et in variant_a.execution_tree]
|
||||
execution_times_b = [et.total_execution_time for et in variant_b.execution_tree]
|
||||
|
||||
ttest_result = stats.ttest_ind(execution_times_a, execution_times_b, equal_var=False)
|
||||
ttest_result = stats.ttest_ind(
|
||||
execution_times_a, execution_times_b, equal_var=False
|
||||
)
|
||||
|
||||
return BenchmarkResult(task=task, variant_a=variant_a, variant_b=variant_b,
|
||||
pvalue=ttest_result.pvalue)
|
||||
return BenchmarkResult(
|
||||
task=task, variant_a=variant_a, variant_b=variant_b, pvalue=ttest_result.pvalue
|
||||
)
|
||||
|
||||
|
||||
def make_variant(explain: Sequence[dict[str, any]]) -> ExperimentResult:
|
||||
"""Make one variant of the A/B test."""
|
||||
pt = [physical_tree.build(e['queryPlanner']['winningPlan']['queryPlan']) for e in explain]
|
||||
et = [execution_tree.build_execution_tree(e['executionStats']) for e in explain]
|
||||
pt = [
|
||||
physical_tree.build(e["queryPlanner"]["winningPlan"]["queryPlan"])
|
||||
for e in explain
|
||||
]
|
||||
et = [execution_tree.build_execution_tree(e["executionStats"]) for e in explain]
|
||||
mean = sum(et.total_execution_time for et in et) / len(et)
|
||||
return ExperimentResult(explain=explain, physical_tree=pt, execution_tree=et, mean=mean)
|
||||
return ExperimentResult(
|
||||
explain=explain, physical_tree=pt, execution_tree=et, mean=mean
|
||||
)
|
||||
|
||||
|
||||
async def run(config: BenchmarkConfig, database: DatabaseInstance, collection: str,
|
||||
pipeline: Pipeline):
|
||||
async def run(
|
||||
config: BenchmarkConfig,
|
||||
database: DatabaseInstance,
|
||||
collection: str,
|
||||
pipeline: Pipeline,
|
||||
):
|
||||
"""Run one variant of the A/B test."""
|
||||
|
||||
# warmup
|
||||
@ -215,10 +231,10 @@ async def run(config: BenchmarkConfig, database: DatabaseInstance, collection: s
|
||||
result = []
|
||||
for _ in range(config.runs):
|
||||
explain = await database.explain(collection, pipeline)
|
||||
if explain['ok'] == 1:
|
||||
if explain["ok"] == 1:
|
||||
result.append(explain)
|
||||
else:
|
||||
logging.warn('Query execution failed: %s', explain)
|
||||
logging.warn("Query execution failed: %s", explain)
|
||||
return result
|
||||
|
||||
|
||||
@ -231,14 +247,18 @@ async def smoke_test():
|
||||
await database.enable_cascades(True)
|
||||
cost_model_a = CostModelCoefficients(index_scan_incremental_cost=0.0001)
|
||||
cost_model_b = CostModelCoefficients(index_scan_incremental_cost=0.9)
|
||||
task = BenchmarkTask(collection_name='c_str_05_45000',
|
||||
pipeline=[{'$match': {'choice1': 'hello', 'choice2': 'gaussian'}}],
|
||||
cost_model_a=cost_model_a, cost_model_b=cost_model_b, threshold=0.05)
|
||||
task = BenchmarkTask(
|
||||
collection_name="c_str_05_45000",
|
||||
pipeline=[{"$match": {"choice1": "hello", "choice2": "gaussian"}}],
|
||||
cost_model_a=cost_model_a,
|
||||
cost_model_b=cost_model_b,
|
||||
threshold=0.05,
|
||||
)
|
||||
|
||||
res = await benchmark(config, database, task)
|
||||
res.print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.get_event_loop()
|
||||
loop.run_until_complete(smoke_test())
|
||||
|
||||
@ -37,48 +37,52 @@ from random_generator import (
|
||||
RangeGenerator,
|
||||
)
|
||||
|
||||
__all__ = ['main_config', 'distributions']
|
||||
__all__ = ["main_config", "distributions"]
|
||||
|
||||
# A string value to fill up collections and not used in queries.
|
||||
HIDDEN_STRING_VALUE = '__hidden_string_value'
|
||||
HIDDEN_STRING_VALUE = "__hidden_string_value"
|
||||
|
||||
# Data distributions settings.
|
||||
distributions = {}
|
||||
|
||||
string_choice_values = [
|
||||
'h',
|
||||
'hi',
|
||||
'hi!',
|
||||
'hola',
|
||||
'hello',
|
||||
'square',
|
||||
'squared',
|
||||
'gaussian',
|
||||
'chisquare',
|
||||
'chisquared',
|
||||
'hello world',
|
||||
'distribution',
|
||||
"h",
|
||||
"hi",
|
||||
"hi!",
|
||||
"hola",
|
||||
"hello",
|
||||
"square",
|
||||
"squared",
|
||||
"gaussian",
|
||||
"chisquare",
|
||||
"chisquared",
|
||||
"hello world",
|
||||
"distribution",
|
||||
]
|
||||
|
||||
string_choice_weights = [10, 20, 5, 17, 30, 7, 9, 15, 40, 2, 12, 1]
|
||||
|
||||
distributions['string_choice'] = RandomDistribution.choice(string_choice_values,
|
||||
string_choice_weights)
|
||||
distributions["string_choice"] = RandomDistribution.choice(
|
||||
string_choice_values, string_choice_weights
|
||||
)
|
||||
|
||||
small_query_weights = [i for i in range(10, 201, 10)]
|
||||
small_query_cardinality = sum(small_query_weights)
|
||||
|
||||
int_choice_values = [i for i in range(1, 1000, 50)]
|
||||
random.shuffle(int_choice_values)
|
||||
distributions['int_choice'] = RandomDistribution.choice(int_choice_values, small_query_weights)
|
||||
distributions["int_choice"] = RandomDistribution.choice(
|
||||
int_choice_values, small_query_weights
|
||||
)
|
||||
|
||||
distributions['random_string'] = ArrayRandomDistribution(
|
||||
distributions["random_string"] = ArrayRandomDistribution(
|
||||
RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 5, 10, 2)),
|
||||
RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")))
|
||||
RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")),
|
||||
)
|
||||
|
||||
|
||||
def generate_random_str(num: int):
|
||||
strs = distributions['random_string'].generate(num)
|
||||
strs = distributions["random_string"].generate(num)
|
||||
str_list = []
|
||||
for char_array in strs:
|
||||
str_res = "".join(char_array)
|
||||
@ -90,45 +94,74 @@ def generate_random_str(num: int):
|
||||
def random_strings_distr(size: int, count: int):
|
||||
distr = ArrayRandomDistribution(
|
||||
RandomDistribution.uniform([size]),
|
||||
RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")))
|
||||
RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")),
|
||||
)
|
||||
|
||||
return RandomDistribution.uniform([''.join(s) for s in distr.generate(count)])
|
||||
return RandomDistribution.uniform(["".join(s) for s in distr.generate(count)])
|
||||
|
||||
|
||||
small_string_choice = generate_random_str(20)
|
||||
|
||||
distributions['string_choice_small'] = RandomDistribution.choice(small_string_choice,
|
||||
small_query_weights)
|
||||
distributions["string_choice_small"] = RandomDistribution.choice(
|
||||
small_string_choice, small_query_weights
|
||||
)
|
||||
|
||||
string_range_4 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "abca", "abc_"))
|
||||
string_range_5 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "abcda", "abcd_"))
|
||||
string_range_7 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "hello_a", "hello__"))
|
||||
string_range_4 = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.STRING, "abca", "abc_")
|
||||
)
|
||||
string_range_5 = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.STRING, "abcda", "abcd_")
|
||||
)
|
||||
string_range_7 = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.STRING, "hello_a", "hello__")
|
||||
)
|
||||
string_range_12 = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_"))
|
||||
RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_")
|
||||
)
|
||||
|
||||
distributions['string_mixed'] = RandomDistribution.mixed(
|
||||
[string_range_4, string_range_5, string_range_7, string_range_12], [0.1, 0.15, 0.25, 0.5])
|
||||
distributions["string_mixed"] = RandomDistribution.mixed(
|
||||
[string_range_4, string_range_5, string_range_7, string_range_12],
|
||||
[0.1, 0.15, 0.25, 0.5],
|
||||
)
|
||||
|
||||
distributions['string_uniform'] = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_"))
|
||||
distributions["string_uniform"] = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_")
|
||||
)
|
||||
|
||||
distributions['int_normal'] = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.INTEGER, 0, 1000, 2))
|
||||
distributions["int_normal"] = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.INTEGER, 0, 1000, 2)
|
||||
)
|
||||
|
||||
lengths_distr = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 1, 10))
|
||||
distributions['array_small'] = ArrayRandomDistribution(lengths_distr, distributions['int_normal'])
|
||||
distributions["array_small"] = ArrayRandomDistribution(
|
||||
lengths_distr, distributions["int_normal"]
|
||||
)
|
||||
|
||||
# Database settings
|
||||
database = config.DatabaseConfig(connection_string='mongodb://localhost',
|
||||
database_name='abt_calibration', dump_path='~/data/dump',
|
||||
restore_from_dump=config.RestoreMode.NEVER, dump_on_exit=False)
|
||||
database = config.DatabaseConfig(
|
||||
connection_string="mongodb://localhost",
|
||||
database_name="abt_calibration",
|
||||
dump_path="~/data/dump",
|
||||
restore_from_dump=config.RestoreMode.NEVER,
|
||||
dump_on_exit=False,
|
||||
)
|
||||
|
||||
|
||||
# Collection template settings
|
||||
def create_index_scan_collection_template(name: str, cardinality: int) -> config.CollectionTemplate:
|
||||
def create_index_scan_collection_template(
|
||||
name: str, cardinality: int
|
||||
) -> config.CollectionTemplate:
|
||||
values = [
|
||||
'iqtbr5b5is', 'vt5s3tf8o6', 'b0rgm58qsn', '9m59if353m', 'biw2l9ok17', 'b9ct0ue14d',
|
||||
'oxj0vxjsti', 'f3k8w9vb49', 'ec7v82k6nk', 'f49ufwaqx7'
|
||||
"iqtbr5b5is",
|
||||
"vt5s3tf8o6",
|
||||
"b0rgm58qsn",
|
||||
"9m59if353m",
|
||||
"biw2l9ok17",
|
||||
"b9ct0ue14d",
|
||||
"oxj0vxjsti",
|
||||
"f3k8w9vb49",
|
||||
"ec7v82k6nk",
|
||||
"f49ufwaqx7",
|
||||
]
|
||||
|
||||
start_weight = 10
|
||||
@ -143,80 +176,174 @@ def create_index_scan_collection_template(name: str, cardinality: int) -> config
|
||||
distr = RandomDistribution.choice(values, weights)
|
||||
|
||||
return config.CollectionTemplate(
|
||||
name=name, fields=[
|
||||
config.FieldTemplate(name="choice", data_type=config.DataType.STRING,
|
||||
distribution=distr, indexed=True),
|
||||
config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"], indexed=False),
|
||||
config.FieldTemplate(name="choice2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"], indexed=False),
|
||||
config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
], compound_indexes=[], cardinalities=[cardinality])
|
||||
name=name,
|
||||
fields=[
|
||||
config.FieldTemplate(
|
||||
name="choice",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distr,
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="uniform1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="choice2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
],
|
||||
compound_indexes=[],
|
||||
cardinalities=[cardinality],
|
||||
)
|
||||
|
||||
|
||||
def create_physical_scan_collection_template(name: str,
|
||||
payload_size: int = 0) -> config.CollectionTemplate:
|
||||
def create_physical_scan_collection_template(
|
||||
name: str, payload_size: int = 0
|
||||
) -> config.CollectionTemplate:
|
||||
template = config.CollectionTemplate(
|
||||
name=name, fields=[
|
||||
config.FieldTemplate(name="choice1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"], indexed=False),
|
||||
config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"], indexed=False),
|
||||
config.FieldTemplate(name="choice", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"], indexed=False),
|
||||
config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
], compound_indexes=[], cardinalities=[1000, 5000, 10000])
|
||||
name=name,
|
||||
fields=[
|
||||
config.FieldTemplate(
|
||||
name="choice1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="uniform1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="choice",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
],
|
||||
compound_indexes=[],
|
||||
cardinalities=[1000, 5000, 10000],
|
||||
)
|
||||
|
||||
if payload_size > 0:
|
||||
payload_distr = random_strings_distr(payload_size, 1000)
|
||||
template.fields.append(
|
||||
config.FieldTemplate(name="payload", data_type=config.DataType.STRING,
|
||||
distribution=payload_distr, indexed=False))
|
||||
config.FieldTemplate(
|
||||
name="payload",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=payload_distr,
|
||||
indexed=False,
|
||||
)
|
||||
)
|
||||
return template
|
||||
|
||||
|
||||
collection_caridinalities = list(range(10000, 50001, 10000))
|
||||
|
||||
c_int_05 = config.CollectionTemplate(
|
||||
name="c_int_05", fields=[
|
||||
config.FieldTemplate(name="in1", data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"], indexed=True),
|
||||
config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"], indexed=False),
|
||||
config.FieldTemplate(name="in2", data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"], indexed=True),
|
||||
config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
], compound_indexes=[], cardinalities=collection_caridinalities)
|
||||
name="c_int_05",
|
||||
fields=[
|
||||
config.FieldTemplate(
|
||||
name="in1",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"],
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="uniform1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="in2",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"],
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
],
|
||||
compound_indexes=[],
|
||||
cardinalities=collection_caridinalities,
|
||||
)
|
||||
|
||||
c_arr_01 = config.CollectionTemplate(
|
||||
name="c_arr_01", fields=[
|
||||
config.FieldTemplate(name="as", data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["array_small"], indexed=True)
|
||||
], compound_indexes=[], cardinalities=collection_caridinalities)
|
||||
name="c_arr_01",
|
||||
fields=[
|
||||
config.FieldTemplate(
|
||||
name="as",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["array_small"],
|
||||
indexed=True,
|
||||
)
|
||||
],
|
||||
compound_indexes=[],
|
||||
cardinalities=collection_caridinalities,
|
||||
)
|
||||
|
||||
index_scan = create_index_scan_collection_template('index_scan', 1000000)
|
||||
index_scan = create_index_scan_collection_template("index_scan", 1000000)
|
||||
|
||||
physical_scan = create_physical_scan_collection_template('physical_scan', 2000)
|
||||
physical_scan = create_physical_scan_collection_template("physical_scan", 2000)
|
||||
|
||||
# Data Generator settings
|
||||
data_generator = config.DataGeneratorConfig(
|
||||
enabled=True, create_indexes=True, batch_size=10000,
|
||||
enabled=True,
|
||||
create_indexes=True,
|
||||
batch_size=10000,
|
||||
collection_templates=[index_scan, physical_scan, c_int_05, c_arr_01],
|
||||
write_mode=config.WriteMode.REPLACE, collection_name_with_card=True)
|
||||
write_mode=config.WriteMode.REPLACE,
|
||||
collection_name_with_card=True,
|
||||
)
|
||||
|
||||
# Workload Execution settings
|
||||
workload_execution = config.WorkloadExecutionConfig(
|
||||
enabled=True, output_collection_name='calibrationData', write_mode=config.WriteMode.REPLACE,
|
||||
warmup_runs=3, runs=30)
|
||||
enabled=True,
|
||||
output_collection_name="calibrationData",
|
||||
write_mode=config.WriteMode.REPLACE,
|
||||
warmup_runs=3,
|
||||
runs=30,
|
||||
)
|
||||
|
||||
|
||||
def make_filter_by_note(note_value: any):
|
||||
@ -227,30 +354,45 @@ def make_filter_by_note(note_value: any):
|
||||
|
||||
|
||||
abt_nodes = [
|
||||
config.AbtNodeCalibrationConfig(type='PhysicalScan',
|
||||
filter_function=make_filter_by_note('PhysicalScan')),
|
||||
config.AbtNodeCalibrationConfig(type='IndexScan',
|
||||
filter_function=make_filter_by_note('IndexScan')),
|
||||
config.AbtNodeCalibrationConfig(type='Seek', filter_function=make_filter_by_note('IndexScan')),
|
||||
config.AbtNodeCalibrationConfig(type='Filter',
|
||||
filter_function=make_filter_by_note('PhysicalScan')),
|
||||
config.AbtNodeCalibrationConfig(type='Evaluation',
|
||||
filter_function=make_filter_by_note('Evaluation')),
|
||||
config.AbtNodeCalibrationConfig(type='NestedLoopJoin'),
|
||||
config.AbtNodeCalibrationConfig(type='HashJoin'),
|
||||
config.AbtNodeCalibrationConfig(type='MergeJoin'),
|
||||
config.AbtNodeCalibrationConfig(type='Union'),
|
||||
config.AbtNodeCalibrationConfig(type='LimitSkip',
|
||||
filter_function=make_filter_by_note('LimitSkip')),
|
||||
config.AbtNodeCalibrationConfig(type='GroupBy'),
|
||||
config.AbtNodeCalibrationConfig(type='Unwind'),
|
||||
config.AbtNodeCalibrationConfig(type='Unique'),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="PhysicalScan", filter_function=make_filter_by_note("PhysicalScan")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="IndexScan", filter_function=make_filter_by_note("IndexScan")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="Seek", filter_function=make_filter_by_note("IndexScan")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="Filter", filter_function=make_filter_by_note("PhysicalScan")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="Evaluation", filter_function=make_filter_by_note("Evaluation")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(type="NestedLoopJoin"),
|
||||
config.AbtNodeCalibrationConfig(type="HashJoin"),
|
||||
config.AbtNodeCalibrationConfig(type="MergeJoin"),
|
||||
config.AbtNodeCalibrationConfig(type="Union"),
|
||||
config.AbtNodeCalibrationConfig(
|
||||
type="LimitSkip", filter_function=make_filter_by_note("LimitSkip")
|
||||
),
|
||||
config.AbtNodeCalibrationConfig(type="GroupBy"),
|
||||
config.AbtNodeCalibrationConfig(type="Unwind"),
|
||||
config.AbtNodeCalibrationConfig(type="Unique"),
|
||||
]
|
||||
|
||||
# Calibrator settings
|
||||
abt_calibrator = config.AbtCalibratorConfig(
|
||||
enabled=True, test_size=0.2, input_collection_name=workload_execution.output_collection_name,
|
||||
trace=False, nodes=abt_nodes)
|
||||
enabled=True,
|
||||
test_size=0.2,
|
||||
input_collection_name=workload_execution.output_collection_name,
|
||||
trace=False,
|
||||
nodes=abt_nodes,
|
||||
)
|
||||
|
||||
main_config = config.Config(database=database, data_generator=data_generator,
|
||||
abt_calibrator=abt_calibrator, workload_execution=workload_execution)
|
||||
main_config = config.Config(
|
||||
database=database,
|
||||
data_generator=data_generator,
|
||||
abt_calibrator=abt_calibrator,
|
||||
workload_execution=workload_execution,
|
||||
)
|
||||
|
||||
@ -40,15 +40,18 @@ from random_generator import (
|
||||
RangeGenerator,
|
||||
)
|
||||
|
||||
__all__ = ['database_config', 'data_generator_config']
|
||||
__all__ = ["database_config", "data_generator_config"]
|
||||
|
||||
################################################################################
|
||||
# Data distributions
|
||||
################################################################################
|
||||
|
||||
|
||||
def add_distribution(distr_set: Sequence[RandomDistribution], distr_type: DistributionType,
|
||||
rg: RangeGenerator):
|
||||
def add_distribution(
|
||||
distr_set: Sequence[RandomDistribution],
|
||||
distr_type: DistributionType,
|
||||
rg: RangeGenerator,
|
||||
):
|
||||
distr = None
|
||||
if distr_type == DistributionType.UNIFORM:
|
||||
distr = RandomDistribution.uniform(rg)
|
||||
@ -104,27 +107,48 @@ for range_gen in int_ranges_2:
|
||||
# Mixes of distributions with different NDV and value distances
|
||||
int_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[int_distributions[0], int_distributions_offset[0], int_distributions[4]],
|
||||
weight=[1, 1, 1]))
|
||||
children=[
|
||||
int_distributions[0],
|
||||
int_distributions_offset[0],
|
||||
int_distributions[4],
|
||||
],
|
||||
weight=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
int_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[int_distributions[1], int_distributions[4], int_distributions[7]],
|
||||
weight=[1, 1, 1]))
|
||||
weight=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
int_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[
|
||||
int_distributions[1], int_distributions_offset[1], int_distributions[3],
|
||||
int_distributions[2], int_distributions_offset[2]
|
||||
], weight=[1, 1, 1, 1, 1]))
|
||||
int_distributions[1],
|
||||
int_distributions_offset[1],
|
||||
int_distributions[3],
|
||||
int_distributions[2],
|
||||
int_distributions_offset[2],
|
||||
],
|
||||
weight=[1, 1, 1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
int_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[
|
||||
int_distributions[2], int_distributions[3], int_distributions[6],
|
||||
int_distributions_offset[1], int_distributions_offset[2], int_distributions_offset[5]
|
||||
], weight=[1, 1, 1, 1, 1, 1]))
|
||||
int_distributions[2],
|
||||
int_distributions[3],
|
||||
int_distributions[6],
|
||||
int_distributions_offset[1],
|
||||
int_distributions_offset[2],
|
||||
int_distributions_offset[5],
|
||||
],
|
||||
weight=[1, 1, 1, 1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
#############################
|
||||
# Double number distributions
|
||||
@ -137,7 +161,7 @@ dbl_ranges = [
|
||||
# 10K unique doubles with different distances
|
||||
RangeGenerator(DataType.DOUBLE, 0.0, 1000.0, 0.1),
|
||||
RangeGenerator(DataType.DOUBLE, 0.0, 100000.0, 10),
|
||||
RangeGenerator(DataType.DOUBLE, 0.0, 10000000.0, 1000)
|
||||
RangeGenerator(DataType.DOUBLE, 0.0, 10000000.0, 1000),
|
||||
]
|
||||
|
||||
dbl_distributions = []
|
||||
@ -149,16 +173,25 @@ for range_gen in dbl_ranges:
|
||||
dbl_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[dbl_distributions[0], dbl_distributions[3], dbl_distributions[10]],
|
||||
weight=[1, 1, 1]))
|
||||
weight=[1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
dbl_distributions.append(
|
||||
RandomDistribution.mixed(
|
||||
children=[
|
||||
dbl_distributions[0],
|
||||
dbl_distributions[4],
|
||||
RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 500.0, 600.0, 0.1)),
|
||||
RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 3000200.0, 5000100.0, 3030)),
|
||||
], weight=[1, 1, 1, 1]))
|
||||
RandomDistribution.normal(
|
||||
RangeGenerator(DataType.DOUBLE, 500.0, 600.0, 0.1)
|
||||
),
|
||||
RandomDistribution.normal(
|
||||
RangeGenerator(DataType.DOUBLE, 3000200.0, 5000100.0, 3030)
|
||||
),
|
||||
],
|
||||
weight=[1, 1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
#############################
|
||||
# Date distributions
|
||||
@ -168,13 +201,27 @@ HOUR = MINUTE * 60
|
||||
DAY = HOUR * 24
|
||||
MONTH = DAY * 30
|
||||
|
||||
range_dtt_1y = RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), HOUR)
|
||||
range_dtt_1m_1 = RangeGenerator(DataType.DATE, datetime(2007, 2, 1), datetime(2008, 3, 1), HOUR)
|
||||
range_dtt_1m_2 = RangeGenerator(DataType.DATE, datetime(2007, 6, 1), datetime(2008, 7, 1), HOUR)
|
||||
range_dtt_1m_3 = RangeGenerator(DataType.DATE, datetime(2007, 10, 1), datetime(2008, 11, 1), HOUR)
|
||||
range_dtt_10y_1 = RangeGenerator(DataType.DATE, datetime(2006, 1, 1), datetime(2016, 1, 1), DAY)
|
||||
range_dtt_10y_2 = RangeGenerator(DataType.DATE, datetime(1995, 1, 1), datetime(2005, 1, 1), DAY)
|
||||
range_dtt_20y = RangeGenerator(DataType.DATE, datetime(1997, 10, 1), datetime(2017, 11, 1), MONTH)
|
||||
range_dtt_1y = RangeGenerator(
|
||||
DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), HOUR
|
||||
)
|
||||
range_dtt_1m_1 = RangeGenerator(
|
||||
DataType.DATE, datetime(2007, 2, 1), datetime(2008, 3, 1), HOUR
|
||||
)
|
||||
range_dtt_1m_2 = RangeGenerator(
|
||||
DataType.DATE, datetime(2007, 6, 1), datetime(2008, 7, 1), HOUR
|
||||
)
|
||||
range_dtt_1m_3 = RangeGenerator(
|
||||
DataType.DATE, datetime(2007, 10, 1), datetime(2008, 11, 1), HOUR
|
||||
)
|
||||
range_dtt_10y_1 = RangeGenerator(
|
||||
DataType.DATE, datetime(2006, 1, 1), datetime(2016, 1, 1), DAY
|
||||
)
|
||||
range_dtt_10y_2 = RangeGenerator(
|
||||
DataType.DATE, datetime(1995, 1, 1), datetime(2005, 1, 1), DAY
|
||||
)
|
||||
range_dtt_20y = RangeGenerator(
|
||||
DataType.DATE, datetime(1997, 10, 1), datetime(2017, 11, 1), MONTH
|
||||
)
|
||||
|
||||
dt_distributions = []
|
||||
|
||||
@ -182,25 +229,33 @@ add_distribution(dt_distributions, DistributionType.UNIFORM, range_dtt_1y)
|
||||
add_distribution(dt_distributions, DistributionType.NORMAL, range_dtt_10y_1)
|
||||
|
||||
dt_distributions.append(
|
||||
RandomDistribution.mixed([
|
||||
RandomDistribution.uniform(range_dtt_1y),
|
||||
RandomDistribution.uniform(range_dtt_1m_1),
|
||||
RandomDistribution.uniform(range_dtt_1m_2),
|
||||
RandomDistribution.uniform(range_dtt_1m_3)
|
||||
], [1, 1, 1, 1]))
|
||||
RandomDistribution.mixed(
|
||||
[
|
||||
RandomDistribution.uniform(range_dtt_1y),
|
||||
RandomDistribution.uniform(range_dtt_1m_1),
|
||||
RandomDistribution.uniform(range_dtt_1m_2),
|
||||
RandomDistribution.uniform(range_dtt_1m_3),
|
||||
],
|
||||
[1, 1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
dt_distributions.append(
|
||||
RandomDistribution.mixed([
|
||||
RandomDistribution.uniform(range_dtt_10y_1),
|
||||
RandomDistribution.uniform(range_dtt_10y_2),
|
||||
RandomDistribution.uniform(range_dtt_20y)
|
||||
], [1, 1, 1]))
|
||||
RandomDistribution.mixed(
|
||||
[
|
||||
RandomDistribution.uniform(range_dtt_10y_1),
|
||||
RandomDistribution.uniform(range_dtt_10y_2),
|
||||
RandomDistribution.uniform(range_dtt_20y),
|
||||
],
|
||||
[1, 1, 1],
|
||||
)
|
||||
)
|
||||
|
||||
#######################
|
||||
# String distributions
|
||||
|
||||
PRINTED_CHAR_MIN_CODE = ord('0')
|
||||
PRINTED_CHAR_MAX_CODE = ord('~')
|
||||
PRINTED_CHAR_MIN_CODE = ord("0")
|
||||
PRINTED_CHAR_MAX_CODE = ord("~")
|
||||
|
||||
ascii_printable_chars = [
|
||||
chr(code) for code in range(PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + 1)
|
||||
@ -209,19 +264,27 @@ ascii_printable_chars = [
|
||||
|
||||
def next_char(char: str, distance: int, min_char_code: int, max_char_code: int):
|
||||
char_code = ord(char)
|
||||
assert (min_char_code <= char_code <= max_char_code
|
||||
), f'char_code "{char_code}" is out of range ({min_char_code}, {max_char_code})'
|
||||
assert (
|
||||
min_char_code <= char_code <= max_char_code
|
||||
), f'char_code "{char_code}" is out of range ({min_char_code}, {max_char_code})'
|
||||
number_of_chars = max_char_code - min_char_code + 1
|
||||
new_char_code = ((char_code - min_char_code + distance) % number_of_chars) + min_char_code
|
||||
assert (min_char_code <= new_char_code <=
|
||||
max_char_code), f'new char code "{new_char_code}" is out of range'
|
||||
new_char_code = (
|
||||
(char_code - min_char_code + distance) % number_of_chars
|
||||
) + min_char_code
|
||||
assert (
|
||||
min_char_code <= new_char_code <= max_char_code
|
||||
), f'new char code "{new_char_code}" is out of range'
|
||||
return chr(new_char_code)
|
||||
|
||||
|
||||
def generate_str_by_distance(num_strings: int, seed_str: str, distance_distr_0: RandomDistribution,
|
||||
distance_distr_1: RandomDistribution,
|
||||
distance_distr_2: RandomDistribution,
|
||||
distance_distr_3: RandomDistribution):
|
||||
def generate_str_by_distance(
|
||||
num_strings: int,
|
||||
seed_str: str,
|
||||
distance_distr_0: RandomDistribution,
|
||||
distance_distr_1: RandomDistribution,
|
||||
distance_distr_2: RandomDistribution,
|
||||
distance_distr_3: RandomDistribution,
|
||||
):
|
||||
"""
|
||||
Generate a set of unique strings with different string distances.
|
||||
|
||||
@ -240,14 +303,18 @@ def generate_str_by_distance(num_strings: int, seed_str: str, distance_distr_0:
|
||||
cur_str = seed_str
|
||||
str_set.add(cur_str)
|
||||
for i in range(1, num_strings):
|
||||
new_str = next_char(cur_str[0], distances_0[i], PRINTED_CHAR_MIN_CODE,
|
||||
PRINTED_CHAR_MAX_CODE)
|
||||
new_str += next_char(cur_str[1], distances_1[i], PRINTED_CHAR_MIN_CODE,
|
||||
PRINTED_CHAR_MAX_CODE)
|
||||
new_str += next_char(cur_str[2], distances_2[i], PRINTED_CHAR_MIN_CODE,
|
||||
PRINTED_CHAR_MAX_CODE)
|
||||
new_str += next_char(cur_str[3], distances_3[i], PRINTED_CHAR_MIN_CODE,
|
||||
PRINTED_CHAR_MAX_CODE)
|
||||
new_str = next_char(
|
||||
cur_str[0], distances_0[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE
|
||||
)
|
||||
new_str += next_char(
|
||||
cur_str[1], distances_1[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE
|
||||
)
|
||||
new_str += next_char(
|
||||
cur_str[2], distances_2[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE
|
||||
)
|
||||
new_str += next_char(
|
||||
cur_str[3], distances_3[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE
|
||||
)
|
||||
str_set.add(new_str)
|
||||
cur_str = new_str
|
||||
return list(str_set)
|
||||
@ -268,17 +335,17 @@ d4 = RandomDistribution.uniform(range_int_20_30)
|
||||
# Sets of strings where characters at different positions have different distances
|
||||
string_sets = {}
|
||||
# 250 unique strings
|
||||
string_sets['set_1112_250'] = generate_str_by_distance(250, 'xxxx', d1, d1, d1, d2)
|
||||
string_sets['set_2221_250'] = generate_str_by_distance(250, 'azay', d2, d2, d3, d1)
|
||||
string_sets['set_5555_250'] = generate_str_by_distance(250, 'axbz', d4, d4, d4, d4)
|
||||
string_sets["set_1112_250"] = generate_str_by_distance(250, "xxxx", d1, d1, d1, d2)
|
||||
string_sets["set_2221_250"] = generate_str_by_distance(250, "azay", d2, d2, d3, d1)
|
||||
string_sets["set_5555_250"] = generate_str_by_distance(250, "axbz", d4, d4, d4, d4)
|
||||
# 1000 unique strings
|
||||
string_sets['set_1112_1000'] = generate_str_by_distance(1000, 'xxxx', d1, d1, d1, d2)
|
||||
string_sets['set_2221_1000'] = generate_str_by_distance(1000, 'azay', d2, d2, d3, d1)
|
||||
string_sets['set_5555_1000'] = generate_str_by_distance(1000, 'axbz', d4, d4, d4, d4)
|
||||
string_sets["set_1112_1000"] = generate_str_by_distance(1000, "xxxx", d1, d1, d1, d2)
|
||||
string_sets["set_2221_1000"] = generate_str_by_distance(1000, "azay", d2, d2, d3, d1)
|
||||
string_sets["set_5555_1000"] = generate_str_by_distance(1000, "axbz", d4, d4, d4, d4)
|
||||
# 10000 unique strings
|
||||
string_sets['set_1112_10000'] = generate_str_by_distance(10000, 'xxxx', d1, d1, d1, d2)
|
||||
string_sets['set_2221_10000'] = generate_str_by_distance(10000, 'azay', d2, d2, d3, d1)
|
||||
string_sets['set_5555_10000'] = generate_str_by_distance(10000, 'axbz', d4, d4, d4, d4)
|
||||
string_sets["set_1112_10000"] = generate_str_by_distance(10000, "xxxx", d1, d1, d1, d2)
|
||||
string_sets["set_2221_10000"] = generate_str_by_distance(10000, "azay", d2, d2, d3, d1)
|
||||
string_sets["set_5555_10000"] = generate_str_by_distance(10000, "axbz", d4, d4, d4, d4)
|
||||
|
||||
# Weights with different variance. For instance if the smallest weight is 1, and the biggest weight is 5
|
||||
# then some values in a choice distribution will be picked with at most 5 times higher probability.
|
||||
@ -291,19 +358,26 @@ weight_range_s = RangeGenerator(DataType.INTEGER, 95, 101, 1)
|
||||
weight_range_l = RangeGenerator(DataType.INTEGER, 25, 101, 2)
|
||||
|
||||
weights = {}
|
||||
weights['weight_unif_s'] = RandomDistribution.uniform(weight_range_s)
|
||||
weights['weight_unif_l'] = RandomDistribution.uniform(weight_range_l)
|
||||
weights["weight_unif_s"] = RandomDistribution.uniform(weight_range_s)
|
||||
weights["weight_unif_l"] = RandomDistribution.uniform(weight_range_l)
|
||||
|
||||
#weights['weight_norm_s'] = RandomDistribution.normal(weight_range_s)
|
||||
#weights['weight_norm_l'] = RandomDistribution.normal(weight_range_l)
|
||||
# weights['weight_norm_s'] = RandomDistribution.normal(weight_range_s)
|
||||
# weights['weight_norm_l'] = RandomDistribution.normal(weight_range_l)
|
||||
|
||||
#weights['chi2_s'] = RandomDistribution.noncentral_chisquare(weight_range_s)
|
||||
#weights['chi2_l'] = RandomDistribution.noncentral_chisquare(weight_range_l)
|
||||
# weights['chi2_s'] = RandomDistribution.noncentral_chisquare(weight_range_s)
|
||||
# weights['chi2_l'] = RandomDistribution.noncentral_chisquare(weight_range_l)
|
||||
|
||||
|
||||
def add_choice_distr(
    distr_set: Sequence[RandomDistribution],
    str_set: Sequence[str],
    weight_distr: RandomDistribution,
    v_name: str,
    w_name: str,
):
    """Append a weighted choice distribution over ``str_set`` to ``distr_set``.

    One weight per string is drawn from ``weight_distr``; ``v_name`` and
    ``w_name`` are forwarded unchanged to ``RandomDistribution.choice``.
    """
    # Draw exactly as many weights as there are candidate strings.
    drawn_weights = weight_distr.generate(len(str_set))
    distr_set.append(
        RandomDistribution.choice(str_set, drawn_weights, v_name, w_name)
    )
|
||||
|
||||
|
||||
@ -320,12 +394,19 @@ for set_name, cur_set in string_sets.items():
|
||||
|
||||
# array length distributions - they are all uniform
|
||||
arr_len_dist_s = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 1, 6, 1))
|
||||
arr_len_dist_m = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 90, 110, 3))
|
||||
arr_len_dist_l = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 900, 1100, 10))
|
||||
arr_len_dist_m = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.INTEGER, 90, 110, 3)
|
||||
)
|
||||
arr_len_dist_l = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.INTEGER, 900, 1100, 10)
|
||||
)
|
||||
|
||||
|
||||
def add_array_distr(
    distr_set: Sequence[RandomDistribution],
    lengths_distr: RandomDistribution,
    value_distr: RandomDistribution,
):
    """Append an ``ArrayRandomDistribution`` built from the two inputs.

    ``lengths_distr`` and ``value_distr`` are passed, in that order, to the
    ``ArrayRandomDistribution`` constructor; the result is added to the end
    of ``distr_set``.
    """
    array_distr = ArrayRandomDistribution(lengths_distr, value_distr)
    distr_set.append(array_distr)
|
||||
|
||||
|
||||
@ -349,24 +430,30 @@ add_array_distr(arr_distributions, arr_len_dist_l, str_distributions[-1])
|
||||
|
||||
# 30% scalars, 70% arrays
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.3, 0.7]))
|
||||
RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.3, 0.7])
|
||||
)
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.3, 0.7]))
|
||||
RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.3, 0.7])
|
||||
)
|
||||
# 70% scalars, 30% arrays
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.7, 0.3]))
|
||||
RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.7, 0.3])
|
||||
)
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.7, 0.3]))
|
||||
RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.7, 0.3])
|
||||
)
|
||||
|
||||
arr_zero_size = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 0, 1, 1))
|
||||
arr_empty_distr = ArrayRandomDistribution(arr_zero_size, int_distributions[0])
|
||||
|
||||
# 20% empty arrays
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.2, 0.8]))
|
||||
RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.2, 0.8])
|
||||
)
|
||||
# 80% empty arrays
|
||||
arr_distributions.append(
|
||||
RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.8, 0.2]))
|
||||
RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.8, 0.2])
|
||||
)
|
||||
|
||||
###############################
|
||||
# Mixed data type distributions
|
||||
@ -377,40 +464,68 @@ mix_distributions = []
|
||||
int_str_mix_1 = [int_distributions[0], str_distributions[0]]
|
||||
int_str_mix_2 = [int_distributions_offset[7], str_distributions[-1]]
|
||||
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_1, weight=[0.5, 0.5])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_2, weight=[0.5, 0.5])
|
||||
)
|
||||
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.1, 0.9]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.9, 0.1]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.1, 0.9]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.9, 0.1]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_1, weight=[0.1, 0.9])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_1, weight=[0.9, 0.1])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_2, weight=[0.1, 0.9])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=int_str_mix_2, weight=[0.9, 0.1])
|
||||
)
|
||||
|
||||
# Doubles and strings
|
||||
dbl_ascii_range = RangeGenerator(DataType.DOUBLE, float(PRINTED_CHAR_MIN_CODE),
|
||||
float(PRINTED_CHAR_MAX_CODE), 0.01)
|
||||
dbl_ascii_range = RangeGenerator(
|
||||
DataType.DOUBLE, float(PRINTED_CHAR_MIN_CODE), float(PRINTED_CHAR_MAX_CODE), 0.01
|
||||
)
|
||||
ascii_double_range_distr = RandomDistribution.normal(dbl_ascii_range)
|
||||
|
||||
dbl_str_mix_1 = [ascii_double_range_distr, str_distributions[1]]
|
||||
mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.1, 0.9]))
|
||||
mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.9, 0.1]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.5, 0.5])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.1, 0.9])
|
||||
)
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.9, 0.1])
|
||||
)
|
||||
|
||||
dbl_str_mix_2 = [dbl_distributions[5], str_distributions[0]]
|
||||
mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_2, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_mix_2, weight=[0.5, 0.5])
|
||||
)
|
||||
|
||||
dbl_str_mix_3 = [dbl_distributions[5], str_distributions[5]]
|
||||
mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_3, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_mix_3, weight=[0.5, 0.5])
|
||||
)
|
||||
|
||||
# Doubles and/or strings and dates
|
||||
|
||||
dbl_str_dt_mix_1 = [ascii_double_range_distr, str_distributions[4], dt_distributions[0]]
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=dbl_str_dt_mix_1, weight=[0.5, 0.5, 0.5]))
|
||||
RandomDistribution.mixed(children=dbl_str_dt_mix_1, weight=[0.5, 0.5, 0.5])
|
||||
)
|
||||
|
||||
str_dt_mix_1 = [str_distributions[0], dt_distributions[-1]]
|
||||
mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_1, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=str_dt_mix_1, weight=[0.5, 0.5])
|
||||
)
|
||||
str_dt_mix_2 = [str_distributions[-1], dt_distributions[0]]
|
||||
mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_2, weight=[0.5, 0.5]))
|
||||
mix_distributions.append(
|
||||
RandomDistribution.mixed(children=str_dt_mix_2, weight=[0.5, 0.5])
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Collection templates
|
||||
@ -424,46 +539,88 @@ mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_2, weight=
|
||||
collection_cardinalities = [500]
|
||||
|
||||
field_templates = [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.INTEGER, distribution=dist,
|
||||
indexed=False) for dist in int_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in int_distributions
|
||||
]
|
||||
field_templates += [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.STRING, distribution=dist,
|
||||
indexed=False) for dist in str_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in str_distributions
|
||||
]
|
||||
field_templates += [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.ARRAY, distribution=dist,
|
||||
indexed=False) for dist in arr_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.ARRAY,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in arr_distributions
|
||||
]
|
||||
field_templates += [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.DOUBLE, distribution=dist,
|
||||
indexed=False) for dist in dbl_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.DOUBLE,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in dbl_distributions
|
||||
]
|
||||
field_templates += [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.DATE, distribution=dist,
|
||||
indexed=False) for dist in dt_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.DATE,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in dt_distributions
|
||||
]
|
||||
field_templates += [
|
||||
config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.MIXDATA, distribution=dist,
|
||||
indexed=False) for dist in mix_distributions
|
||||
config.FieldTemplate(
|
||||
name=f"{str(dist)}",
|
||||
data_type=config.DataType.MIXDATA,
|
||||
distribution=dist,
|
||||
indexed=False,
|
||||
)
|
||||
for dist in mix_distributions
|
||||
]
|
||||
|
||||
ce_data = config.CollectionTemplate(name="ce_data", fields=field_templates, compound_indexes=[],
|
||||
cardinalities=collection_cardinalities)
|
||||
ce_data = config.CollectionTemplate(
|
||||
name="ce_data",
|
||||
fields=field_templates,
|
||||
compound_indexes=[],
|
||||
cardinalities=collection_cardinalities,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Database settings
|
||||
################################################################################
|
||||
|
||||
database_config = config.DatabaseConfig(
|
||||
connection_string='mongodb://localhost', database_name='ce_accuracy_test', dump_path=Path(
|
||||
'..', '..', 'jstests', 'query_golden', 'libs', 'data'),
|
||||
restore_from_dump=config.RestoreMode.NEVER, dump_on_exit=False)
|
||||
connection_string="mongodb://localhost",
|
||||
database_name="ce_accuracy_test",
|
||||
dump_path=Path("..", "..", "jstests", "query_golden", "libs", "data"),
|
||||
restore_from_dump=config.RestoreMode.NEVER,
|
||||
dump_on_exit=False,
|
||||
)
|
||||
|
||||
################################################################################
|
||||
# Data Generator settings
|
||||
################################################################################
|
||||
|
||||
data_generator_config = config.DataGeneratorConfig(
|
||||
enabled=True, create_indexes=False, batch_size=10000, collection_templates=[ce_data],
|
||||
write_mode=config.WriteMode.REPLACE, collection_name_with_card=True)
|
||||
enabled=True,
|
||||
create_indexes=False,
|
||||
batch_size=10000,
|
||||
collection_templates=[ce_data],
|
||||
write_mode=config.WriteMode.REPLACE,
|
||||
collection_name_with_card=True,
|
||||
)
|
||||
|
||||
@ -52,10 +52,15 @@ class CollectionTemplateEncoder(json.JSONEncoder):
|
||||
if isinstance(o, CollectionTemplate):
|
||||
collections = []
|
||||
for card in o.cardinalities:
|
||||
name = f'{o.name}_{card}'
|
||||
name = f"{o.name}_{card}"
|
||||
collections.append(
|
||||
dict(collectionName=name, fields=o.fields, compoundIndexes=o.compound_indexes,
|
||||
cardinality=card))
|
||||
dict(
|
||||
collectionName=name,
|
||||
fields=o.fields,
|
||||
compoundIndexes=o.compound_indexes,
|
||||
cardinality=card,
|
||||
)
|
||||
)
|
||||
return collections
|
||||
elif isinstance(o, FieldTemplate):
|
||||
return dict(fieldName=o.name, dataType=o.data_type, indexed=o.indexed)
|
||||
@ -82,11 +87,11 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size):
|
||||
"""Dump a collection into separate files each containing at most chunk_size documents."""
|
||||
|
||||
def open_chunk_file(chunk_id):
    """Open the chunk file number ``chunk_id`` and write its JavaScript preamble.

    Returns the still-open file object (the caller fills and closes it)
    together with the chunk name wrapped in single quotes.
    """
    chunk_name = f"{coll_name}_{chunk_id}"
    target_path = Path(dump_path) / chunk_name
    print(f"Writing chunk: {target_path}")
    handle = open(target_path, "w", encoding="utf-8")
    # Header comment followed by the opening of the JS object holding the data.
    handle.write(
        "// This is a generated file.\n"
        + f'{chunk_name} = {{collName: "{coll_name}", collData: [\n'
    )
    return handle, f"'{chunk_name}'"
|
||||
|
||||
@ -115,7 +120,7 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size):
|
||||
|
||||
chunk_file.write(json.dumps(doc, cls=OidEncoder))
|
||||
doc_pos += 1
|
||||
chunk_file.write(',')
|
||||
chunk_file.write(",")
|
||||
chunk_file.write("\n")
|
||||
close_chunk_file(chunk_file)
|
||||
return chunk_names
|
||||
@ -123,16 +128,17 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size):
|
||||
|
||||
async def dump_collections_to_json(db, dump_path, database_name, collections):
|
||||
chunk_size = 100 # number of documents per chunk file
|
||||
print(f'Dumping all collections into chunks of size {chunk_size}.')
|
||||
print(f"Dumping all collections into chunks of size {chunk_size}.")
|
||||
all_chunk_names = []
|
||||
for coll_name in collections:
|
||||
coll_chunk_names = await dump_collection(db, dump_path, database_name, coll_name,
|
||||
chunk_size)
|
||||
coll_chunk_names = await dump_collection(
|
||||
db, dump_path, database_name, coll_name, chunk_size
|
||||
)
|
||||
all_chunk_names.extend(coll_chunk_names)
|
||||
|
||||
# Generate a JS file that loads all chunk files
|
||||
load_file = open(Path(dump_path) / f'{database_name}.data', 'w')
|
||||
load_file.write('// This is a generated file.\n')
|
||||
load_file = open(Path(dump_path) / f"{database_name}.data", "w")
|
||||
load_file.write("// This is a generated file.\n")
|
||||
# Create an array named 'chunkNames' with all chunk file names to be loaded.
|
||||
load_file.write(f'const chunkNames = [{",".join(all_chunk_names)}];')
|
||||
|
||||
@ -146,9 +152,11 @@ async def generate_histograms(coll_template, coll, dump_path):
|
||||
doc_count = await coll.count_documents({})
|
||||
for field in coll_template.fields:
|
||||
field_data = []
|
||||
if re.match('^mixeddata_.*', field.name):
|
||||
if re.match("^mixeddata_.*", field.name):
|
||||
continue
|
||||
async for doc in coll.find({field.name: {"$exists": True}}, {"_id": 0, field.name: 1}):
|
||||
async for doc in coll.find(
|
||||
{field.name: {"$exists": True}}, {"_id": 0, field.name: 1}
|
||||
):
|
||||
field_val = doc[field.name]
|
||||
if isinstance(field_val, str):
|
||||
field_val = re.escape(field_val)
|
||||
@ -157,10 +165,11 @@ async def generate_histograms(coll_template, coll, dump_path):
|
||||
continue
|
||||
field_data.append(field_val)
|
||||
if len(field_data) > 0:
|
||||
fig_file_name = f'{dump_path}/{coll.name}_{field.name}.png'
|
||||
print(f'Generating histogram {fig_file_name}')
|
||||
hist = sns.displot(data=field_data, kind="hist",
|
||||
bins=round(math.sqrt(doc_count))).figure
|
||||
fig_file_name = f"{dump_path}/{coll.name}_{field.name}.png"
|
||||
print(f"Generating histogram {fig_file_name}")
|
||||
hist = sns.displot(
|
||||
data=field_data, kind="hist", bins=round(math.sqrt(doc_count))
|
||||
).figure
|
||||
hist.savefig(fig_file_name)
|
||||
plt.close(hist)
|
||||
|
||||
@ -173,7 +182,6 @@ async def main():
|
||||
# 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally
|
||||
# from the dump on creating and stores data optionally to the dump on closing.
|
||||
with DatabaseInstance(database_config) as database_instance:
|
||||
|
||||
# 2. Generate random data and populate collections with it.
|
||||
old_db_collections = await database_instance.database.list_collection_names()
|
||||
for coll_name in old_db_collections:
|
||||
@ -188,27 +196,40 @@ async def main():
|
||||
# TODO: This is an alternative way to export the data. It is better than what is implemented,
|
||||
# but cannot be used until we find a way to call 'mongoimport' from the corresponding JS test.
|
||||
#
|
||||
#for coll_name in db_collections:
|
||||
# for coll_name in db_collections:
|
||||
# subprocess.run([
|
||||
# 'mongoexport', f'--db={database_config.database_name}', f'--collection={coll_name}',
|
||||
# f'--out={coll_name}.dat'
|
||||
# ], cwd=database_config.dump_path, check=True)
|
||||
await dump_collections_to_json(database_instance.database, database_config.dump_path,
|
||||
database_config.database_name, db_collections)
|
||||
await dump_collections_to_json(
|
||||
database_instance.database,
|
||||
database_config.dump_path,
|
||||
database_config.database_name,
|
||||
db_collections,
|
||||
)
|
||||
|
||||
# 4. Export the collection templates used to create the test collections into JSON file
|
||||
with open(Path(database_config.dump_path) / f'{database_config.database_name}.schema',
|
||||
"w") as metadata_file:
|
||||
with open(
|
||||
Path(database_config.dump_path) / f"{database_config.database_name}.schema",
|
||||
"w",
|
||||
) as metadata_file:
|
||||
collections = []
|
||||
for coll_template in data_generator_config.collection_templates:
|
||||
for card in coll_template.cardinalities:
|
||||
name = f'{coll_template.name}_{card}'
|
||||
name = f"{coll_template.name}_{card}"
|
||||
collections.append(
|
||||
dict(collectionName=name, fields=coll_template.fields,
|
||||
compound_indexes=coll_template.compound_indexes, cardinality=card))
|
||||
dict(
|
||||
collectionName=name,
|
||||
fields=coll_template.fields,
|
||||
compound_indexes=coll_template.compound_indexes,
|
||||
cardinality=card,
|
||||
)
|
||||
)
|
||||
# Uncomment this to generate histograms in PNG format
|
||||
# await generate_histograms(coll_template, database_instance.database[name], database_config.dump_path)
|
||||
json_metadata = json.dumps(collections, indent=4, cls=CollectionTemplateEncoder)
|
||||
json_metadata = json.dumps(
|
||||
collections, indent=4, cls=CollectionTemplateEncoder
|
||||
)
|
||||
metadata_file.write("// This is a generated file.\nconst dbMetadata = ")
|
||||
metadata_file.write(json_metadata)
|
||||
metadata_file.write(";")
|
||||
@ -216,7 +237,7 @@ async def main():
|
||||
print("DONE!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
asyncio.run(main())
|
||||
|
||||
@ -39,7 +39,7 @@ def timer_decorator(func):
|
||||
t0 = time.perf_counter()
|
||||
result = func(*args, **kwargs)
|
||||
t1 = time.perf_counter()
|
||||
print(f'{func.__name__} took {t1-t0}s.')
|
||||
print(f"{func.__name__} took {t1-t0}s.")
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
@ -68,8 +68,9 @@ class LinearModel:
|
||||
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float,
|
||||
trace: bool = False) -> LinearModel:
|
||||
def estimate(
|
||||
fit, X: np.ndarray, y: np.ndarray, test_size: float, trace: bool = False
|
||||
) -> LinearModel:
|
||||
"""Estimate cost model parameters."""
|
||||
|
||||
if len(X) == 0:
|
||||
@ -80,7 +81,7 @@ def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float,
|
||||
X_training, X_test, y_training, y_test = train_test_split(X, y, test_size=test_size)
|
||||
|
||||
if trace:
|
||||
print(f'Training size: {len(X_training)}, test size: {len(X_test)}')
|
||||
print(f"Training size: {len(X_training)}, test size: {len(X_test)}")
|
||||
print(X_training)
|
||||
print(y_training)
|
||||
|
||||
@ -96,5 +97,11 @@ def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float,
|
||||
evs = explained_variance_score(y_test, y_predict)
|
||||
corrcoef = np.corrcoef(np.transpose(X[:, 1:]), y)
|
||||
|
||||
return LinearModel(coef=coef[1:], intercept=coef[0], mse=mse, r2=r2, evs=evs,
|
||||
corrcoef=corrcoef[0, 1:])
|
||||
return LinearModel(
|
||||
coef=coef[1:],
|
||||
intercept=coef[0],
|
||||
mse=mse,
|
||||
r2=r2,
|
||||
evs=evs,
|
||||
corrcoef=corrcoef[0, 1:],
|
||||
)
|
||||
|
||||
@ -41,7 +41,7 @@ from motor.motor_asyncio import AsyncIOMotorCollection
|
||||
from pymongo import IndexModel
|
||||
from random_generator import DataType, RandomDistribution
|
||||
|
||||
__all__ = ['DataGenerator']
|
||||
__all__ = ["DataGenerator"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -96,52 +96,78 @@ class DataGenerator:
|
||||
coll = self.database.database[coll_info.name]
|
||||
if self.config.write_mode == WriteMode.REPLACE:
|
||||
await coll.drop()
|
||||
tasks.append(asyncio.create_task(self._populate_collection(coll, coll_info)))
|
||||
tasks.append(
|
||||
asyncio.create_task(self._populate_collection(coll, coll_info))
|
||||
)
|
||||
if self.config.create_indexes:
|
||||
tasks.append(
|
||||
asyncio.create_task(create_single_field_indexes(coll, coll_info.fields)))
|
||||
tasks.append(asyncio.create_task(create_compound_indexes(coll, coll_info)))
|
||||
asyncio.create_task(
|
||||
create_single_field_indexes(coll, coll_info.fields)
|
||||
)
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(create_compound_indexes(coll, coll_info))
|
||||
)
|
||||
|
||||
for task in tasks:
|
||||
await task
|
||||
|
||||
t1 = time.time()
|
||||
print(f'\npopulate Collections took {t1-t0} s.')
|
||||
print(f"\npopulate Collections took {t1-t0} s.")
|
||||
|
||||
def _generate_collection_infos(self):
    """Yield a ``CollectionInfo`` for every template/cardinality combination.

    The field list is built once per collection template and shared by all
    cardinalities of that template. The collection name gets a
    ``_<cardinality>`` suffix only when ``config.collection_name_with_card``
    is exactly ``True``.
    """
    for template in self.config.collection_templates:
        field_infos = []
        for field_template in template.fields:
            field_infos.append(
                FieldInfo(
                    name=field_template.name,
                    type=field_template.data_type,
                    distribution=field_template.distribution,
                    indexed=field_template.indexed,
                )
            )
        for cardinality in template.cardinalities:
            if self.config.collection_name_with_card is True:
                coll_name = f"{template.name}_{cardinality}"
            else:
                coll_name = f"{template.name}"
            yield CollectionInfo(
                name=coll_name,
                fields=field_infos,
                documents_count=cardinality,
                compound_indexes=template.compound_indexes,
            )
|
||||
|
||||
async def _populate_collection(
    self, coll: AsyncIOMotorCollection, coll_info: CollectionInfo
) -> None:
    """Fill ``coll`` with ``coll_info.documents_count`` generated documents.

    The documents are produced in batches of ``self.config.batch_size`` (plus
    one smaller batch for any remainder). Every batch is scheduled as its own
    asyncio task and all tasks are awaited before returning.
    """
    # The message previously contained a stray "$" (JavaScript template
    # syntax) which was printed literally in front of the collection name.
    print(f"\nGenerating {coll_info.name} ...")
    batch_size = self.config.batch_size
    full_batches, remainder = divmod(coll_info.documents_count, batch_size)
    batch_sizes = [batch_size] * full_batches
    if remainder > 0:
        batch_sizes.append(remainder)

    tasks = [
        asyncio.create_task(populate_batch(coll, size, coll_info.fields))
        for size in batch_sizes
    ]
    for task in tasks:
        await task
|
||||
|
||||
|
||||
async def populate_batch(
    coll: AsyncIOMotorCollection, documents_count: int, fields: Sequence[FieldInfo]
) -> None:
    """Generate ``documents_count`` documents and insert them into ``coll``.

    The insert is unordered (``ordered=False``).
    """
    batch = generate_collection_data(documents_count, fields)
    await coll.insert_many(batch, ordered=False)
|
||||
|
||||
|
||||
def generate_collection_data(documents_count: int, fields: Sequence[FieldInfo]):
    """Build ``documents_count`` documents, one generated value per field.

    For every field its distribution is asked for ``documents_count`` values;
    the i-th value is stored under the field's name in the i-th document.
    """
    rows = [dict() for _ in range(documents_count)]
    for spec in fields:
        generated = spec.distribution.generate(documents_count)
        for position, value in enumerate(generated):
            rows[position][spec.name] = value
    return rows
|
||||
|
||||
|
||||
async def create_single_field_indexes(
    coll: AsyncIOMotorCollection, fields: Sequence[FieldInfo]
) -> None:
    """Create an ascending single-field index for every field marked as indexed.

    Nothing is sent to the server when no field is marked as indexed.
    """
    indexes = []
    for field in fields:
        if field.indexed:
            indexes.append(IndexModel([(field.name, pymongo.ASCENDING)]))

    if not indexes:
        return

    await coll.create_indexes(indexes)
    print(
        f"create_single_field_indexes done. {[index.document for index in indexes]}"
    )
|
||||
|
||||
|
||||
async def create_compound_indexes(coll: AsyncIOMotorCollection, coll_info: CollectionInfo) -> None:
|
||||
async def create_compound_indexes(
|
||||
coll: AsyncIOMotorCollection, coll_info: CollectionInfo
|
||||
) -> None:
|
||||
"""Create a coumpound indexes on the given collection."""
|
||||
|
||||
indexes_spec = []
|
||||
index_specs = []
|
||||
for compound_index in coll_info.compound_indexes:
|
||||
index_spec = IndexModel([(field, pymongo.ASCENDING) for field in compound_index])
|
||||
index_spec = IndexModel(
|
||||
[(field, pymongo.ASCENDING) for field in compound_index]
|
||||
)
|
||||
indexes_spec.append(index_spec)
|
||||
index_specs.append([(field, pymongo.ASCENDING) for field in compound_index])
|
||||
if len(indexes_spec) > 0:
|
||||
await coll.create_indexes(indexes_spec)
|
||||
print(f'create_compound_indexes done. {index_specs}')
|
||||
print(f"create_compound_indexes done. {index_specs}")
|
||||
|
||||
@ -36,9 +36,9 @@ from typing import Any, Mapping, NewType, Sequence
|
||||
from config import DatabaseConfig, RestoreMode
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
|
||||
__all__ = ['DatabaseInstance', 'Pipeline']
|
||||
__all__ = ["DatabaseInstance", "Pipeline"]
|
||||
"""MongoDB Aggregate's Pipeline"""
|
||||
Pipeline = NewType('Pipeline', Sequence[Mapping[str, Any]])
|
||||
Pipeline = NewType("Pipeline", Sequence[Mapping[str, Any]])
|
||||
|
||||
|
||||
class DatabaseInstance:
|
||||
@ -52,8 +52,9 @@ class DatabaseInstance:
|
||||
|
||||
def __enter__(self):
    """Enter the context, restoring the database from the dump when configured.

    A restore happens when the mode is ``RestoreMode.ALWAYS``, or when it is
    ``RestoreMode.ONLY_NEW`` and the configured database name is not among
    the names reported by the client.
    """
    if self.config.restore_from_dump == RestoreMode.ALWAYS:
        must_restore = True
    elif self.config.restore_from_dump == RestoreMode.ONLY_NEW:
        # NOTE(review): if self.client is an async (Motor) client this call
        # returns an awaitable rather than a list of names - confirm.
        known_databases = self.client.list_database_names()
        must_restore = self.config.database_name not in known_databases
    else:
        must_restore = False

    if must_restore:
        self.restore()
    return self
|
||||
|
||||
@ -68,74 +69,94 @@ class DatabaseInstance:
|
||||
|
||||
def restore(self):
    """Restore the database from the dump located in ``self.config.dump_path``.

    Runs ``mongorestore`` restricted to the namespaces of the configured
    database and drops existing collections before restoring them.

    Raises:
        subprocess.CalledProcessError: if ``mongorestore`` exits with a
            non-zero status.
    """
    # Deliberately no shell=True: combined with an argument list, POSIX runs
    # only "mongorestore" through /bin/sh and silently discards the remaining
    # arguments, so --nsInclude and --drop never reached the tool.
    subprocess.run(
        ["mongorestore", "--nsInclude", f"{self.config.database_name}.*", "--drop"],
        check=True,
        cwd=self.config.dump_path,
    )
|
||||
|
||||
def dump(self):
    """Dump the configured database by running ``mongodump`` in ``self.config.dump_path``.

    A non-zero exit status of ``mongodump`` raises
    ``subprocess.CalledProcessError`` (``check=True``).
    """
    command = ["mongodump", "--db", self.config.database_name]
    subprocess.run(command, cwd=self.config.dump_path, check=True)
|
||||
|
||||
async def set_parameter(self, name: str, value: Any) -> None:
    """Set the MongoDB server parameter ``name`` to ``value``.

    Sends a ``setParameter`` command through the client's ``admin`` database.
    """
    # "setParameter" must stay the first key of the command document.
    await self.client.admin.command({"setParameter": 1, name: value})
|
||||
|
||||
async def get_parameter(self, name: str) -> Any:
    """Return the current value of the MongoDB server parameter ``name``.

    Sends a ``getParameter`` command through the client's ``admin`` database
    and extracts the entry stored under ``name`` from the reply.
    """
    reply = await self.client.admin.command({"getParameter": 1, name: 1})
    return reply[name]
|
||||
|
||||
async def enable_sbe(self, state: bool) -> None:
    """Enable new query execution engine. Throw pymongo.errors.OperationFailure in case of failure.

    Sets ``internalQueryFrameworkControl`` to ``trySbeEngine`` when ``state``
    is truthy and to ``forceClassicEngine`` otherwise.
    """
    if state:
        framework = "trySbeEngine"
    else:
        framework = "forceClassicEngine"
    await self.set_parameter("internalQueryFrameworkControl", framework)
|
||||
|
||||
async def enable_cascades(self, state: bool) -> None:
    """Enable new query optimizer. Requires featureFlagCommonQueryFramework set to True.

    Regardless of ``state`` the feature compatibility version is set to the
    version reported for the feature flag and the ``enableExplainInBonsai``
    fail point is switched on; ``state`` only selects the value written to
    ``internalQueryFrameworkControl``.
    """
    admin = self.client.admin

    # Set FeatureCompatibilityVersion compatible with featureFlagCommonQueryFramework.
    flag_reply = await admin.command(
        {"getParameter": 1, "featureFlagCommonQueryFramework": 1}
    )
    version = flag_reply["featureFlagCommonQueryFramework"]["version"]
    await admin.command({"setFeatureCompatibilityVersion": version, "confirm": True})

    await admin.command(
        {"configureFailPoint": "enableExplainInBonsai", "mode": "alwaysOn"}
    )

    if state:
        framework = "forceBonsai"
    else:
        framework = "trySbeEngine"
    await self.set_parameter("internalQueryFrameworkControl", framework)
|
||||
|
||||
async def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, Any]:
    """Return explain for the given pipeline.

    Wraps an ``aggregate`` command over ``collection_name`` in an ``explain``
    database command with ``executionStats`` verbosity and returns the
    server's reply.
    """
    aggregate_command = {
        "aggregate": collection_name,
        "pipeline": pipeline,
        "cursor": {},
    }
    return await self.database.command(
        "explain", aggregate_command, verbosity="executionStats"
    )
|
||||
|
||||
async def hide_index(self, collection_name: str, index_name: str) -> None:
|
||||
"""Hide the given index from the query optimizer."""
|
||||
await self.database.command(
|
||||
{'collMod': collection_name, 'index': {'name': index_name, 'hidden': True}})
|
||||
{"collMod": collection_name, "index": {"name": index_name, "hidden": True}}
|
||||
)
|
||||
|
||||
async def unhide_index(self, collection_name: str, index_name: str) -> None:
|
||||
"""Make the given index visible for the query optimizer."""
|
||||
await self.database.command(
|
||||
{'collMod': collection_name, 'index': {'name': index_name, 'hidden': False}})
|
||||
{"collMod": collection_name, "index": {"name": index_name, "hidden": False}}
|
||||
)
|
||||
|
||||
async def hide_all_indexes(self, collection_name: str) -> None:
|
||||
"""Hide all indexes of the given collection from the query optimizer."""
|
||||
for index in self.database[collection_name].list_indexes():
|
||||
if index['name'] != '_id_':
|
||||
await self.hide_index(collection_name, index['name'])
|
||||
if index["name"] != "_id_":
|
||||
await self.hide_index(collection_name, index["name"])
|
||||
|
||||
async def unhide_all_indexes(self, collection_name: str) -> None:
|
||||
"""Make all indexes of the given collection visible for the query optimizer."""
|
||||
for index in self.database[collection_name].list_indexes():
|
||||
if index['name'] != '_id_':
|
||||
await self.unhide_index(collection_name, index['name'])
|
||||
if index["name"] != "_id_":
|
||||
await self.unhide_index(collection_name, index["name"])
|
||||
|
||||
async def drop_collection(self, collection_name: str) -> None:
|
||||
"""Drop collection."""
|
||||
await self.database[collection_name].drop()
|
||||
|
||||
async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None:
|
||||
async def insert_many(
|
||||
self, collection_name: str, docs: Sequence[Mapping[str, any]]
|
||||
) -> None:
|
||||
"""Insert documents into the collection with the given name."""
|
||||
if len(docs) > 0:
|
||||
await self.database[collection_name].insert_many(docs, ordered=False)
|
||||
@ -146,12 +167,12 @@ class DatabaseInstance:
|
||||
|
||||
async def get_stats(self, collection_name: str):
|
||||
"""Get collection statistics."""
|
||||
return await self.database.command('collstats', collection_name)
|
||||
return await self.database.command("collstats", collection_name)
|
||||
|
||||
async def get_average_document_size(self, collection_name: str) -> float:
|
||||
"""Get average document size for the given collection."""
|
||||
stats = await self.get_stats(collection_name)
|
||||
avg_size = stats.get('avgObjSize')
|
||||
avg_size = stats.get("avgObjSize")
|
||||
return avg_size if avg_size is not None else 0
|
||||
|
||||
|
||||
@ -177,7 +198,9 @@ class DatabaseParameter:
|
||||
if self.original_value is not None:
|
||||
await self.set(self.original_value)
|
||||
else:
|
||||
raise ValueError(f'The parameter "{self.parameter_name}" has not been remembered.')
|
||||
raise ValueError(
|
||||
f'The parameter "{self.parameter_name}" has not been remembered.'
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@ -65,20 +65,20 @@ class CostEstimator:
|
||||
self.cost_model = cost_model
|
||||
|
||||
self.estimators = {
|
||||
'PhysicalScan': self.physical_scan,
|
||||
'IndexScan': self.index_scan,
|
||||
'Seek': self.seek,
|
||||
'Filter': self.filter,
|
||||
'Evaluation': self.evaluation,
|
||||
'GroupBy': self.group_by,
|
||||
'Unwind': self.unwind,
|
||||
'NestedLoopJoin': self.nested_loop_join,
|
||||
'HashJoin': self.hash_join,
|
||||
'MergeJoin': self.merge_join,
|
||||
'Unique': self.unique,
|
||||
'Union': self.union,
|
||||
'LimitSkip': self.limit_skip,
|
||||
'Root': self.root,
|
||||
"PhysicalScan": self.physical_scan,
|
||||
"IndexScan": self.index_scan,
|
||||
"Seek": self.seek,
|
||||
"Filter": self.filter,
|
||||
"Evaluation": self.evaluation,
|
||||
"GroupBy": self.group_by,
|
||||
"Unwind": self.unwind,
|
||||
"NestedLoopJoin": self.nested_loop_join,
|
||||
"HashJoin": self.hash_join,
|
||||
"MergeJoin": self.merge_join,
|
||||
"Unique": self.unique,
|
||||
"Union": self.union,
|
||||
"LimitSkip": self.limit_skip,
|
||||
"Root": self.root,
|
||||
}
|
||||
|
||||
def estimate(self, abt_node_name: str, cardinality: int) -> float:
|
||||
@ -88,55 +88,93 @@ class CostEstimator:
|
||||
|
||||
def physical_scan(self, cardinality: int) -> float:
|
||||
"""Estimate PhysicalScan ABT node."""
|
||||
return self.cost_model.scan_startup_cost + cardinality * self.cost_model.scan_incremental_cost
|
||||
return (
|
||||
self.cost_model.scan_startup_cost
|
||||
+ cardinality * self.cost_model.scan_incremental_cost
|
||||
)
|
||||
|
||||
def index_scan(self, cardinality: int) -> float:
|
||||
"""Estimate IndexScan ABT node."""
|
||||
return self.cost_model.index_scan_startup_cost + cardinality * self.cost_model.index_scan_incremental_cost
|
||||
return (
|
||||
self.cost_model.index_scan_startup_cost
|
||||
+ cardinality * self.cost_model.index_scan_incremental_cost
|
||||
)
|
||||
|
||||
def seek(self, cardinality: int) -> float:
|
||||
"""Estimate Seek ABT node."""
|
||||
return self.cost_model.seek_startup_cost + cardinality * self.cost_model.seek_cost
|
||||
return (
|
||||
self.cost_model.seek_startup_cost + cardinality * self.cost_model.seek_cost
|
||||
)
|
||||
|
||||
def filter(self, cardinality: int) -> float:
|
||||
"""Estimate Filter ABT node."""
|
||||
return self.cost_model.filter_startup_cost + cardinality * self.cost_model.filter_incremental_cost
|
||||
return (
|
||||
self.cost_model.filter_startup_cost
|
||||
+ cardinality * self.cost_model.filter_incremental_cost
|
||||
)
|
||||
|
||||
def evaluation(self, cardinality: int) -> float:
|
||||
"""Estimate Evaluation ABT node."""
|
||||
return self.cost_model.eval_startup_cost + cardinality * self.cost_model.eval_incremental_cost
|
||||
return (
|
||||
self.cost_model.eval_startup_cost
|
||||
+ cardinality * self.cost_model.eval_incremental_cost
|
||||
)
|
||||
|
||||
def group_by(self, cardinality: int) -> float:
|
||||
"""Estimate GroupBy ABT node."""
|
||||
return self.cost_model.group_by_startup_cost + cardinality * self.cost_model.group_by_incremental_cost
|
||||
return (
|
||||
self.cost_model.group_by_startup_cost
|
||||
+ cardinality * self.cost_model.group_by_incremental_cost
|
||||
)
|
||||
|
||||
def unwind(self, cardinality: int) -> float:
|
||||
"""Estimate Unwind ABT node."""
|
||||
return self.cost_model.unwind_startup_cost + cardinality * self.cost_model.unwind_incremental_cost
|
||||
return (
|
||||
self.cost_model.unwind_startup_cost
|
||||
+ cardinality * self.cost_model.unwind_incremental_cost
|
||||
)
|
||||
|
||||
def nested_loop_join(self, cardinality: int) -> float:
|
||||
"""Estimate NestedLoopJoin ABT node."""
|
||||
return self.cost_model.nested_loop_join_startup_cost + cardinality * self.cost_model.nested_loop_join_incremental_cost
|
||||
return (
|
||||
self.cost_model.nested_loop_join_startup_cost
|
||||
+ cardinality * self.cost_model.nested_loop_join_incremental_cost
|
||||
)
|
||||
|
||||
def hash_join(self, cardinality: int) -> float:
|
||||
"""Estimate HashJoin ABT node."""
|
||||
return self.cost_model.hash_join_startup_cost + cardinality * self.cost_model.hash_join_incremental_cost
|
||||
return (
|
||||
self.cost_model.hash_join_startup_cost
|
||||
+ cardinality * self.cost_model.hash_join_incremental_cost
|
||||
)
|
||||
|
||||
def merge_join(self, cardinality: int) -> float:
|
||||
"""Estimate MergeJoin ABT node."""
|
||||
return self.cost_model.merge_join_startup_cost + cardinality * self.cost_model.merge_join_incremental_cost
|
||||
return (
|
||||
self.cost_model.merge_join_startup_cost
|
||||
+ cardinality * self.cost_model.merge_join_incremental_cost
|
||||
)
|
||||
|
||||
def unique(self, cardinality: int) -> float:
|
||||
"""Estimate Unique ABT node."""
|
||||
return self.cost_model.unique_startup_cost + cardinality * self.cost_model.unique_incremental_cost
|
||||
return (
|
||||
self.cost_model.unique_startup_cost
|
||||
+ cardinality * self.cost_model.unique_incremental_cost
|
||||
)
|
||||
|
||||
def union(self, cardinality: int) -> float:
|
||||
"""Estimate Union ABT node."""
|
||||
return self.cost_model.union_startup_cost + cardinality * self.cost_model.union_incremental_cost
|
||||
return (
|
||||
self.cost_model.union_startup_cost
|
||||
+ cardinality * self.cost_model.union_incremental_cost
|
||||
)
|
||||
|
||||
def limit_skip(self, cardinality: int) -> float:
|
||||
"""Estimate LimitSkip ABT node."""
|
||||
return self.cost_model.limit_skip_startup_cost + cardinality * self.cost_model.limit_skip_incremental_cost
|
||||
return (
|
||||
self.cost_model.limit_skip_startup_cost
|
||||
+ cardinality * self.cost_model.limit_skip_incremental_cost
|
||||
)
|
||||
|
||||
def root(self, _: int) -> float:
|
||||
"""Root ABT node is always 0."""
|
||||
@ -153,13 +191,23 @@ class AbtCostEstimator:
|
||||
def __init__(self, estimate_node: Callable[[str, int], float]):
|
||||
self.estimate_node = estimate_node
|
||||
|
||||
def estimate(self, abt: pt.Node, sbe: et.Node,
|
||||
estimations: Sequence[Tuple[str, ExecutionStats, float]], level=0):
|
||||
def estimate(
|
||||
self,
|
||||
abt: pt.Node,
|
||||
sbe: et.Node,
|
||||
estimations: Sequence[Tuple[str, ExecutionStats, float]],
|
||||
level=0,
|
||||
):
|
||||
stats = get_excution_stats(sbe, abt.plan_node_id)
|
||||
local_cost = self.estimate_node(abt.node_type, stats.n_processed)
|
||||
estimations.append((abt.node_type, stats, local_cost))
|
||||
child_cost = sum((self.estimate(child, sbe, estimations, level + 1)
|
||||
for child in abt.children), start=0.0)
|
||||
child_cost = sum(
|
||||
(
|
||||
self.estimate(child, sbe, estimations, level + 1)
|
||||
for child in abt.children
|
||||
),
|
||||
start=0.0,
|
||||
)
|
||||
return local_cost + child_cost
|
||||
|
||||
|
||||
@ -167,16 +215,29 @@ class AbtCostEstimator:
|
||||
class EndToEndStatisticsRow:
|
||||
"""Represents a row with descriptive statistics of one query execution."""
|
||||
|
||||
def __init__(self, pipeline: str = None, abt_type: str = None, abt_type_id: int = 0,
|
||||
execution_time: float = 0.0, estimated_cost: float = 0.0, n_documents: int = 0):
|
||||
self.pipeline = pipeline if pipeline is not None else ''
|
||||
self.abt_type = abt_type if abt_type is not None else ''
|
||||
def __init__(
|
||||
self,
|
||||
pipeline: str = None,
|
||||
abt_type: str = None,
|
||||
abt_type_id: int = 0,
|
||||
execution_time: float = 0.0,
|
||||
estimated_cost: float = 0.0,
|
||||
n_documents: int = 0,
|
||||
):
|
||||
self.pipeline = pipeline if pipeline is not None else ""
|
||||
self.abt_type = abt_type if abt_type is not None else ""
|
||||
self.abt_type_id = abt_type_id
|
||||
self.execution_time = execution_time
|
||||
self.estimated_cost = estimated_cost
|
||||
self.estimation_error = execution_time - estimated_cost
|
||||
self.estimation_error_per_doc = self.estimation_error / n_documents if n_documents != 0 else 0
|
||||
self.relative_error = self.estimation_error / self.execution_time if self.execution_time != 0 else 0
|
||||
self.estimation_error_per_doc = (
|
||||
self.estimation_error / n_documents if n_documents != 0 else 0
|
||||
)
|
||||
self.relative_error = (
|
||||
self.estimation_error / self.execution_time
|
||||
if self.execution_time != 0
|
||||
else 0
|
||||
)
|
||||
|
||||
pipeline: str
|
||||
abt_type: str
|
||||
@ -189,11 +250,20 @@ class EndToEndStatisticsRow:
|
||||
|
||||
|
||||
def make_config():
|
||||
def create_end2end_collection_template(name: str,
|
||||
cardinality: int) -> config.CollectionTemplate:
|
||||
def create_end2end_collection_template(
|
||||
name: str, cardinality: int
|
||||
) -> config.CollectionTemplate:
|
||||
values = [
|
||||
'iqtbr5b5is', 'vt5s3tf8o6', 'b0rgm58qsn', '9m59if353m', 'biw2l9ok17', 'b9ct0ue14d',
|
||||
'oxj0vxjsti', 'f3k8w9vb49', 'ec7v82k6nk', 'f49ufwaqx7'
|
||||
"iqtbr5b5is",
|
||||
"vt5s3tf8o6",
|
||||
"b0rgm58qsn",
|
||||
"9m59if353m",
|
||||
"biw2l9ok17",
|
||||
"b9ct0ue14d",
|
||||
"oxj0vxjsti",
|
||||
"f3k8w9vb49",
|
||||
"ec7v82k6nk",
|
||||
"f49ufwaqx7",
|
||||
]
|
||||
|
||||
start_weight = 30
|
||||
@ -208,108 +278,198 @@ def make_config():
|
||||
distr = RandomDistribution.choice(values, weights)
|
||||
|
||||
return config.CollectionTemplate(
|
||||
name=name, fields=[
|
||||
config.FieldTemplate(name="indexed_choice", data_type=config.DataType.STRING,
|
||||
distribution=distr, indexed=True),
|
||||
config.FieldTemplate(name="int1", data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"], indexed=True),
|
||||
config.FieldTemplate(name="non_indexed_choice", data_type=config.DataType.STRING,
|
||||
distribution=distributions['string_choice'], indexed=False),
|
||||
config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"], indexed=False),
|
||||
config.FieldTemplate(name="int2", data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"], indexed=True),
|
||||
config.FieldTemplate(name="choice2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"], indexed=False),
|
||||
config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"], indexed=False),
|
||||
], compound_indexes=[], cardinalities=[cardinality])
|
||||
name=name,
|
||||
fields=[
|
||||
config.FieldTemplate(
|
||||
name="indexed_choice",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distr,
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="int1",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"],
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="non_indexed_choice",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="uniform1",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_uniform"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="int2",
|
||||
data_type=config.DataType.INTEGER,
|
||||
distribution=distributions["int_normal"],
|
||||
indexed=True,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="choice2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_choice"],
|
||||
indexed=False,
|
||||
),
|
||||
config.FieldTemplate(
|
||||
name="mixed2",
|
||||
data_type=config.DataType.STRING,
|
||||
distribution=distributions["string_mixed"],
|
||||
indexed=False,
|
||||
),
|
||||
],
|
||||
compound_indexes=[],
|
||||
cardinalities=[cardinality],
|
||||
)
|
||||
|
||||
col_end2end = create_end2end_collection_template('end2end', 2000000)
|
||||
col_end2end = create_end2end_collection_template("end2end", 2000000)
|
||||
|
||||
data_generator_config = config.DataGeneratorConfig(
|
||||
enabled=True, create_indexes=True, batch_size=10000, collection_templates=[col_end2end],
|
||||
write_mode=config.WriteMode.REPLACE, collection_name_with_card=True)
|
||||
enabled=True,
|
||||
create_indexes=True,
|
||||
batch_size=10000,
|
||||
collection_templates=[col_end2end],
|
||||
write_mode=config.WriteMode.REPLACE,
|
||||
collection_name_with_card=True,
|
||||
)
|
||||
|
||||
workload_execution_config = config.WorkloadExecutionConfig(
|
||||
enabled=True, output_collection_name='end2endData', write_mode=config.WriteMode.APPEND,
|
||||
warmup_runs=3, runs=30)
|
||||
enabled=True,
|
||||
output_collection_name="end2endData",
|
||||
write_mode=config.WriteMode.APPEND,
|
||||
warmup_runs=3,
|
||||
runs=30,
|
||||
)
|
||||
|
||||
# The cost model to test.
|
||||
cost_model = CostModelCoefficients(
|
||||
scan_incremental_cost=422.31145989, scan_startup_cost=6175.527218993269,
|
||||
index_scan_incremental_cost=403.68075869, index_scan_startup_cost=14054.983953111061,
|
||||
seek_cost=1223.35513997, seek_startup_cost=7488.662376624863,
|
||||
filter_incremental_cost=83.7274685, filter_startup_cost=1461.3148783443378,
|
||||
eval_incremental_cost=430.6176946, eval_startup_cost=1103.4048573163343,
|
||||
group_by_incremental_cost=413.07932374, group_by_startup_cost=1199.8878012735659,
|
||||
unwind_incremental_cost=586.57200195, unwind_startup_cost=1.0,
|
||||
scan_incremental_cost=422.31145989,
|
||||
scan_startup_cost=6175.527218993269,
|
||||
index_scan_incremental_cost=403.68075869,
|
||||
index_scan_startup_cost=14054.983953111061,
|
||||
seek_cost=1223.35513997,
|
||||
seek_startup_cost=7488.662376624863,
|
||||
filter_incremental_cost=83.7274685,
|
||||
filter_startup_cost=1461.3148783443378,
|
||||
eval_incremental_cost=430.6176946,
|
||||
eval_startup_cost=1103.4048573163343,
|
||||
group_by_incremental_cost=413.07932374,
|
||||
group_by_startup_cost=1199.8878012735659,
|
||||
unwind_incremental_cost=586.57200195,
|
||||
unwind_startup_cost=1.0,
|
||||
nested_loop_join_incremental_cost=161.62301944,
|
||||
nested_loop_join_startup_cost=402.8455479458652, hash_join_incremental_cost=250.61365634,
|
||||
hash_join_startup_cost=1.0, merge_join_incremental_cost=111.23423304,
|
||||
merge_join_startup_cost=1517.7970800404169, unique_incremental_cost=269.71368614,
|
||||
unique_startup_cost=1.0, union_incremental_cost=111.94945268,
|
||||
union_startup_cost=69.88096657391543, limit_skip_incremental_cost=62.42111111,
|
||||
limit_skip_startup_cost=655.1342592592522)
|
||||
nested_loop_join_startup_cost=402.8455479458652,
|
||||
hash_join_incremental_cost=250.61365634,
|
||||
hash_join_startup_cost=1.0,
|
||||
merge_join_incremental_cost=111.23423304,
|
||||
merge_join_startup_cost=1517.7970800404169,
|
||||
unique_incremental_cost=269.71368614,
|
||||
unique_startup_cost=1.0,
|
||||
union_incremental_cost=111.94945268,
|
||||
union_startup_cost=69.88096657391543,
|
||||
limit_skip_incremental_cost=62.42111111,
|
||||
limit_skip_startup_cost=655.1342592592522,
|
||||
)
|
||||
|
||||
cost_estimator = CostEstimator(cost_model)
|
||||
|
||||
processor_config = config.End2EndProcessorConfig(
|
||||
enabled=True, estimator=cost_estimator.estimate,
|
||||
input_collection_name=workload_execution_config.output_collection_name)
|
||||
enabled=True,
|
||||
estimator=cost_estimator.estimate,
|
||||
input_collection_name=workload_execution_config.output_collection_name,
|
||||
)
|
||||
|
||||
return config.EntToEndTestingConfig(
|
||||
database=main_config.database, data_generator=data_generator_config,
|
||||
workload_execution=workload_execution_config, processor=processor_config,
|
||||
result_csv_filepath="end2end.csv")
|
||||
database=main_config.database,
|
||||
data_generator=data_generator_config,
|
||||
workload_execution=workload_execution_config,
|
||||
processor=processor_config,
|
||||
result_csv_filepath="end2end.csv",
|
||||
)
|
||||
|
||||
|
||||
async def execute_queries(database: DatabaseInstance, we_config: config.WorkloadExecutionConfig,
|
||||
collections: Sequence[CollectionInfo]):
|
||||
collection = [ci for ci in collections if ci.name.startswith('end2end')][0]
|
||||
async def execute_queries(
|
||||
database: DatabaseInstance,
|
||||
we_config: config.WorkloadExecutionConfig,
|
||||
collections: Sequence[CollectionInfo],
|
||||
):
|
||||
collection = [ci for ci in collections if ci.name.startswith("end2end")][0]
|
||||
|
||||
requests = []
|
||||
|
||||
limits = [5, 10, 15, 20, 25, 50]
|
||||
skips = [15, 10, 5]
|
||||
|
||||
for field in [f for f in collection.fields if f.name == 'indexed_choice']:
|
||||
for field in [f for f in collection.fields if f.name == "indexed_choice"]:
|
||||
for val in field.distribution.get_values():
|
||||
if val.startswith('_'):
|
||||
if val.startswith("_"):
|
||||
continue
|
||||
limit = limits[len(requests) % len(limits)]
|
||||
skip = skips[len(requests) % len(skips)]
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {field.name: val}}, {"$skip": skip}, {"$limit": limit},
|
||||
{"$project": {"int1": 1}}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {field.name: val}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
{"$project": {"int1": 1}},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
for field in [f for f in collection.fields if f.name == 'non_indexed_choice']:
|
||||
for val in ['chisquare', 'hi']:
|
||||
for field in [f for f in collection.fields if f.name == "non_indexed_choice"]:
|
||||
for val in ["chisquare", "hi"]:
|
||||
limit = limits[len(requests) % len(limits)]
|
||||
skip = skips[len(requests) % len(skips)]
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {field.name: val}}, {"$skip": skip}, {"$limit": limit},
|
||||
{"$project": {"int1": 1}}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {field.name: val}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
{"$project": {"int1": 1}},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
for i in range(100, 1000, 250):
|
||||
limit = limits[len(requests) % len(limits)]
|
||||
skip = skips[len(requests) % len(skips)]
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 -
|
||||
i}}, {"$skip": skip}, {"$limit": limit}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {"in1": i, "in2": 1000 - i}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}, {"$skip": skip},
|
||||
{"$limit": limit}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {"in1": {"$lte": i}, "in2": 1000 - i}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, we_config, [collection], requests)
|
||||
|
||||
|
||||
async def execute_index_intersect_queries(database: DatabaseInstance,
|
||||
we_config: config.WorkloadExecutionConfig,
|
||||
collections: Sequence[CollectionInfo]):
|
||||
collection = [ci for ci in collections if ci.name.startswith('end2end')][0]
|
||||
async def execute_index_intersect_queries(
|
||||
database: DatabaseInstance,
|
||||
we_config: config.WorkloadExecutionConfig,
|
||||
collections: Sequence[CollectionInfo],
|
||||
):
|
||||
collection = [ci for ci in collections if ci.name.startswith("end2end")][0]
|
||||
|
||||
requests = []
|
||||
|
||||
@ -321,19 +481,36 @@ async def execute_index_intersect_queries(database: DatabaseInstance,
|
||||
skip = skips[len(requests) % len(skips)]
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 -
|
||||
i}}, {"$skip": skip}, {"$limit": limit}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {"in1": i, "in2": 1000 - i}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}, {"$skip": skip},
|
||||
{"$limit": limit}]))
|
||||
Query(
|
||||
pipeline=[
|
||||
{"$match": {"in1": {"$lte": i}, "in2": 1000 - i}},
|
||||
{"$skip": skip},
|
||||
{"$limit": limit},
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
async with get_database_parameter(
|
||||
database, 'internalCostModelCoefficients') as cost_model_param, get_database_parameter(
|
||||
database, 'internalCascadesOptimizerDisableMergeJoinRIDIntersect'
|
||||
) as merge_join_param, get_database_parameter(
|
||||
database,
|
||||
'internalCascadesOptimizerDisableHashJoinRIDIntersect') as hash_join_param:
|
||||
async with (
|
||||
get_database_parameter(
|
||||
database, "internalCostModelCoefficients"
|
||||
) as cost_model_param,
|
||||
get_database_parameter(
|
||||
database, "internalCascadesOptimizerDisableMergeJoinRIDIntersect"
|
||||
) as merge_join_param,
|
||||
get_database_parameter(
|
||||
database, "internalCascadesOptimizerDisableHashJoinRIDIntersect"
|
||||
) as hash_join_param,
|
||||
):
|
||||
await cost_model_param.set('{"filterIncrementalCost": 10000.0}')
|
||||
await merge_join_param.set(False)
|
||||
await hash_join_param.set(False)
|
||||
@ -350,7 +527,7 @@ def extract_abt_nodes(df: pd.DataFrame, estimate_cost) -> pd.DataFrame:
|
||||
"""Extract ABT Nodes and execution statistics from calibration DataFrame."""
|
||||
|
||||
def extract(df_seq):
|
||||
es_dict = exp.extract_execution_stats(df_seq['sbe'], df_seq['abt'], [])
|
||||
es_dict = exp.extract_execution_stats(df_seq["sbe"], df_seq["abt"], [])
|
||||
|
||||
rows = []
|
||||
for abt_type, es in es_dict.items():
|
||||
@ -359,48 +536,70 @@ def extract_abt_nodes(df: pd.DataFrame, estimate_cost) -> pd.DataFrame:
|
||||
continue
|
||||
estimated_cost = estimate_cost(abt_type, stat.n_processed)
|
||||
rows.append(
|
||||
EndToEndStatisticsRow(abt_type=abt_type, execution_time=stat.execution_time,
|
||||
estimated_cost=estimated_cost,
|
||||
n_documents=stat.n_processed))
|
||||
EndToEndStatisticsRow(
|
||||
abt_type=abt_type,
|
||||
execution_time=stat.execution_time,
|
||||
estimated_cost=estimated_cost,
|
||||
n_documents=stat.n_processed,
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
return pd.DataFrame(list(df.apply(extract, axis=1).explode()))
|
||||
|
||||
|
||||
def build_abt_nodes_report(df: pd.DataFrame, processor_config: config.End2EndProcessorConfig):
|
||||
def build_abt_nodes_report(
|
||||
df: pd.DataFrame, processor_config: config.End2EndProcessorConfig
|
||||
):
|
||||
return extract_abt_nodes(df, processor_config.estimator)
|
||||
|
||||
|
||||
def build_queries_report(df: pd.DataFrame, processor_config: config.End2EndProcessorConfig):
|
||||
def build_queries_report(
|
||||
df: pd.DataFrame, processor_config: config.End2EndProcessorConfig
|
||||
):
|
||||
abt_estimator = AbtCostEstimator(processor_config.estimator)
|
||||
|
||||
def calculate_cost(row):
|
||||
rows = []
|
||||
estimations = []
|
||||
total_estimated_cost = abt_estimator.estimate(row['abt'], row['sbe'], estimations)
|
||||
total_estimated_cost = abt_estimator.estimate(
|
||||
row["abt"], row["sbe"], estimations
|
||||
)
|
||||
local_id = 0
|
||||
rows.append(
|
||||
EndToEndStatisticsRow(pipeline=row['pipeline'], abt_type_id=local_id,
|
||||
execution_time=row['total_execution_time'],
|
||||
estimated_cost=total_estimated_cost))
|
||||
for (abt_type, stats, local_cost) in estimations:
|
||||
EndToEndStatisticsRow(
|
||||
pipeline=row["pipeline"],
|
||||
abt_type_id=local_id,
|
||||
execution_time=row["total_execution_time"],
|
||||
estimated_cost=total_estimated_cost,
|
||||
)
|
||||
)
|
||||
for abt_type, stats, local_cost in estimations:
|
||||
local_id += 1
|
||||
rows.append(
|
||||
EndToEndStatisticsRow(pipeline=row['pipeline'], abt_type=abt_type,
|
||||
abt_type_id=local_id,
|
||||
execution_time=row['total_execution_time'],
|
||||
estimated_cost=local_cost, n_documents=stats.n_processed))
|
||||
EndToEndStatisticsRow(
|
||||
pipeline=row["pipeline"],
|
||||
abt_type=abt_type,
|
||||
abt_type_id=local_id,
|
||||
execution_time=row["total_execution_time"],
|
||||
estimated_cost=local_cost,
|
||||
n_documents=stats.n_processed,
|
||||
)
|
||||
)
|
||||
return rows
|
||||
|
||||
return pd.DataFrame(list(df.apply(calculate_cost, axis=1).explode()))
|
||||
|
||||
|
||||
async def conduct_end2end(database: DatabaseInstance,
|
||||
processor_config: config.End2EndProcessorConfig):
|
||||
async def conduct_end2end(
|
||||
database: DatabaseInstance, processor_config: config.End2EndProcessorConfig
|
||||
):
|
||||
if not processor_config.enabled:
|
||||
return {}
|
||||
|
||||
df = await exp.load_calibration_data(database, processor_config.input_collection_name)
|
||||
df = await exp.load_calibration_data(
|
||||
database, processor_config.input_collection_name
|
||||
)
|
||||
noout_df = exp.remove_outliers(df, 0.0, 0.90)
|
||||
|
||||
abt_report = build_abt_nodes_report(noout_df, processor_config)
|
||||
@ -409,21 +608,26 @@ async def conduct_end2end(database: DatabaseInstance,
|
||||
|
||||
report = pd.concat([abt_report, queries_report], axis=0)
|
||||
|
||||
group_columns = ['pipeline', 'abt_type', 'abt_type_id']
|
||||
group_columns = ["pipeline", "abt_type", "abt_type_id"]
|
||||
|
||||
def calc_r2(group):
|
||||
return r2_score(group['execution_time'], group['estimated_cost'])
|
||||
return r2_score(group["execution_time"], group["estimated_cost"])
|
||||
|
||||
r2_scores = report.groupby(group_columns).apply(calc_r2).reset_index()
|
||||
r2_scores.columns = [group_columns + ['r2'], [''] * (len(group_columns) + 1)]
|
||||
r2_scores.columns = [group_columns + ["r2"], [""] * (len(group_columns) + 1)]
|
||||
|
||||
agg_stats = report.groupby(group_columns)[[
|
||||
'execution_time', 'estimated_cost', 'estimation_error', 'estimation_error_per_doc',
|
||||
'relative_error'
|
||||
]].agg([np.mean, np.std, np.min, np.max])
|
||||
agg_stats = report.groupby(group_columns)[
|
||||
[
|
||||
"execution_time",
|
||||
"estimated_cost",
|
||||
"estimation_error",
|
||||
"estimation_error_per_doc",
|
||||
"relative_error",
|
||||
]
|
||||
].agg([np.mean, np.std, np.min, np.max])
|
||||
|
||||
report = pd.merge(r2_scores, agg_stats, on=group_columns)
|
||||
del report['abt_type_id']
|
||||
del report["abt_type_id"]
|
||||
return report
|
||||
|
||||
|
||||
@ -434,7 +638,6 @@ async def end2end(e2e_config: config.EntToEndTestingConfig):
|
||||
# 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally
|
||||
# from the dump on creating and stores data optionally to the dump on closing.
|
||||
with DatabaseInstance(e2e_config.database) as database:
|
||||
|
||||
# 2. Data generation (optional), generates random data and populates collections with it.
|
||||
generator = DataGenerator(database, e2e_config.data_generator)
|
||||
await generator.populate_collections()
|
||||
@ -443,10 +646,12 @@ async def end2end(e2e_config: config.EntToEndTestingConfig):
|
||||
# It runs the pipelines and stores explains to the database.
|
||||
execution_query_functions = [execute_queries, execute_index_intersect_queries]
|
||||
for execute_query in execution_query_functions:
|
||||
await execute_query(database, e2e_config.workload_execution, generator.collection_infos)
|
||||
await execute_query(
|
||||
database, e2e_config.workload_execution, generator.collection_infos
|
||||
)
|
||||
e2e_config.workload_execution.write_mode = config.WriteMode.APPEND
|
||||
|
||||
#4. Process end to end testing. Compare the estimated and actual costs and return results.
|
||||
# 4. Process end to end testing. Compare the estimated and actual costs and return results.
|
||||
report = await conduct_end2end(database, e2e_config.processor)
|
||||
if e2e_config.result_csv_filepath is not None:
|
||||
report.to_csv(e2e_config.result_csv_filepath, index=False)
|
||||
@ -457,7 +662,7 @@ async def main():
|
||||
await end2end(e2e_config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
@ -34,7 +34,7 @@ from typing import Optional
|
||||
|
||||
import bson.json_util as json
|
||||
|
||||
__all__ = ['Node', 'build_execution_tree']
|
||||
__all__ = ["Node", "build_execution_tree"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -51,7 +51,9 @@ class Node:
|
||||
|
||||
def get_execution_time(self):
|
||||
"""Execution time of the SBE node without execution time of its children."""
|
||||
return self.total_execution_time - sum(n.total_execution_time for n in self.children)
|
||||
return self.total_execution_time - sum(
|
||||
n.total_execution_time for n in self.children
|
||||
)
|
||||
|
||||
def print(self, level=0):
|
||||
"""Pretty print of the SBE tree."""
|
||||
@ -64,126 +66,152 @@ class Node:
|
||||
|
||||
def build_execution_tree(execution_stats: dict[str, any]) -> Node:
    """Build the SBE execution tree from the 'executionStats' field of a query explain.

    Asserts that the explain reports 'executionSuccess' and then converts the
    'executionStages' document into a tree of Node objects via process_stage.
    """
    assert execution_stats["executionSuccess"]
    return process_stage(execution_stats["executionStages"])
|
||||
|
||||
|
||||
def process_stage(stage: dict[str, any]) -> Node:
    """Parse the given SBE stage by dispatching on its 'stage' name.

    Args:
        stage: JSON document of a single SBE stage taken from a query explain.

    Returns:
        The Node produced by the processor registered for ``stage["stage"]``.

    Raises:
        ValueError: if no processor is registered for the stage name. The
            offending stage is printed as indented JSON before raising.
    """
    processors = {
        "filter": process_filter,
        "cfilter": process_filter,
        "traverse": process_traverse,
        "project": process_inner_node,
        "limit": process_inner_node,
        "ixscan_generic": process_seek,
        "scan": process_seek,
        "coscan": process_leaf_node,
        "nlj": process_nlj,
        "hj": process_hash_join_node,
        "mj": process_hash_join_node,
        "seek": process_seek,
        "ixseek": process_seek,
        "limitskip": process_inner_node,
        "group": process_inner_node,
        "union": process_union_node,
        "unique": process_unique_node,
        "unwind": process_unwind_node,
        "branch": process_branch_node,
    }

    processor = processors.get(stage["stage"])
    if processor is None:
        # Dump the whole stage so the unsupported shape can be inspected.
        print(json.dumps(stage, indent=4))
        raise ValueError(f"Unknown stage: {stage}")

    return processor(stage)
|
||||
|
||||
|
||||
def process_filter(stage: dict[str, any]) -> Node:
    """Process a filter stage.

    The single 'inputStage' becomes the only child; ``n_processed`` is taken
    from the stage's 'numTested' counter.
    """
    input_stage = process_stage(stage["inputStage"])
    return Node(
        **get_common_fields(stage),
        n_processed=stage["numTested"],
        children=[input_stage],
    )
|
||||
|
||||
|
||||
def process_traverse(stage: dict[str, any]) -> Node:
    """Process a traverse stage (not used by Bonsai).

    Children are the parsed 'outerStage' and 'innerStage', in that order;
    ``n_processed`` is the stage's own 'nReturned'.
    """
    outer_stage = process_stage(stage["outerStage"])
    inner_stage = process_stage(stage["innerStage"])
    return Node(
        **get_common_fields(stage),
        n_processed=stage["nReturned"],
        children=[outer_stage, inner_stage],
    )
|
||||
|
||||
|
||||
def process_hash_join_node(stage: dict[str, any]) -> Node:
    """Process a two-input join stage (registered for both 'hj' and 'mj').

    Children are the parsed 'outerStage' and 'innerStage', in that order;
    ``n_processed`` is the sum of the rows returned by both inputs.
    """
    outer_stage = process_stage(stage["outerStage"])
    inner_stage = process_stage(stage["innerStage"])
    n_processed = outer_stage.n_returned + inner_stage.n_returned
    return Node(
        **get_common_fields(stage),
        n_processed=n_processed,
        children=[outer_stage, inner_stage],
    )
|
||||
|
||||
|
||||
def process_nlj(stage: dict[str, any]) -> Node:
    """Process an nlj stage.

    Children are the parsed 'outerStage' and 'innerStage', in that order;
    ``n_processed`` is read from the stage's 'totalDocsExamined' counter.
    """
    outer_stage = process_stage(stage["outerStage"])
    inner_stage = process_stage(stage["innerStage"])
    n_processed = stage["totalDocsExamined"]
    return Node(
        **get_common_fields(stage),
        n_processed=n_processed,
        children=[outer_stage, inner_stage],
    )
|
||||
|
||||
|
||||
def process_inner_node(stage: dict[str, any]) -> Node:
    """Process an SBE stage with exactly one input stage.

    ``n_processed`` is the number of rows returned by the parsed 'inputStage'.
    """
    input_stage = process_stage(stage["inputStage"])
    return Node(
        **get_common_fields(stage),
        n_processed=input_stage.n_returned,
        children=[input_stage],
    )
|
||||
|
||||
|
||||
def process_leaf_node(stage: dict[str, any]) -> Node:
    """Process an SBE stage without input stages.

    The node has no children; ``n_processed`` is the stage's own 'nReturned'.
    """
    return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[])
|
||||
|
||||
|
||||
def process_seek(stage: dict[str, any]) -> Node:
    """Process a seek/scan-like stage.

    The node has no children; ``n_processed`` is the stage's 'numReads' counter.
    """
    return Node(**get_common_fields(stage), n_processed=stage["numReads"], children=[])
|
||||
|
||||
|
||||
def process_union_node(stage: dict[str, any]) -> Node:
    """Process a union stage.

    Every entry of 'inputStages' is parsed into a child, preserving order;
    ``n_processed`` is the stage's own 'nReturned'.
    """
    children = [process_stage(child) for child in stage["inputStages"]]
    return Node(
        **get_common_fields(stage),
        n_processed=stage["nReturned"],
        children=children,
    )
|
||||
|
||||
|
||||
def process_unwind_node(stage: dict[str, any]) -> Node:
    """Process an unwind stage.

    ``n_processed`` is the number of rows returned by the parsed 'inputStage',
    which is also the node's only child.
    """
    input_stage = process_stage(stage["inputStage"])
    return Node(
        **get_common_fields(stage),
        n_processed=input_stage.n_returned,
        children=[input_stage],
    )
|
||||
|
||||
|
||||
def process_unique_node(stage: dict[str, any]) -> Node:
    """Process a unique stage.

    The parsed 'inputStage' is the only child; ``n_processed`` is read from
    the stage's 'dupsTested' counter.
    """
    input_stage = process_stage(stage["inputStage"])
    n_processed = stage["dupsTested"]
    return Node(
        **get_common_fields(stage),
        n_processed=n_processed,
        children=[input_stage],
    )
|
||||
|
||||
|
||||
def process_branch_node(stage: dict[str, any]) -> Node:
    """Process a branch stage.

    Children are the parsed 'thenStage' and 'elseStage', in that order;
    ``n_processed`` is the sum of the rows returned by both branches.
    """
    then_stage = process_stage(stage["thenStage"])
    else_stage = process_stage(stage["elseStage"])
    n_processed = then_stage.n_returned + else_stage.n_returned
    return Node(
        **get_common_fields(stage),
        n_processed=n_processed,
        children=[then_stage, else_stage],
    )
|
||||
|
||||
|
||||
def get_common_fields(json_stage: dict[str, any]) -> dict[str, any]:
    """Extract the common fields from the JSON representation of an SBE stage.

    Returns:
        A dict with keys 'stage', 'plan_node_id', 'total_execution_time',
        'n_returned' and 'seeks'. 'seeks' is optional in the input and is
        None when absent; the other source keys are required.

    Raises:
        KeyError: if 'stage', 'planNodeId', 'executionTimeNanos' or
            'nReturned' is missing from ``json_stage``.
    """
    return {
        "stage": json_stage["stage"],
        "plan_node_id": json_stage["planNodeId"],
        "total_execution_time": json_stage["executionTimeNanos"],
        "n_returned": json_stage["nReturned"],
        "seeks": json_stage.get("seeks"),
    }
|
||||
|
||||
@ -109,21 +109,26 @@ from sklearn.linear_model import LinearRegression
|
||||
from sklearn.metrics import r2_score
|
||||
|
||||
|
||||
async def load_calibration_data(
    database: DatabaseInstance, collection_name: str
) -> pd.DataFrame:
    """Load workflow data containing explain output from the database and parse it.

    Returns the calibration DataFrame built from every document of
    ``collection_name``, extended with three columns:

    * ``sbe`` - SBE tree built from the explain's 'executionStats';
    * ``abt`` - ABT tree built from 'queryPlanner.winningPlan.queryPlan';
    * ``total_execution_time`` - ``total_execution_time`` of the SBE tree root.
    """
    data = await database.get_all_documents(collection_name)
    df = pd.DataFrame(data)

    # Decode each explain string once and reuse it for both trees.
    explains = df.explain.apply(json.loads)
    df["sbe"] = explains.apply(lambda e: sbe.build_execution_tree(e["executionStats"]))
    df["abt"] = explains.apply(
        lambda e: abt.build(e["queryPlanner"]["winningPlan"]["queryPlan"])
    )
    df["total_execution_time"] = df.sbe.apply(lambda t: t.total_execution_time)
    return df
|
||||
|
||||
|
||||
def remove_outliers(df: pd.DataFrame, lower_percentile: float = 0.1,
|
||||
upper_percentile: float = 0.9) -> pd.DataFrame:
|
||||
def remove_outliers(
|
||||
df: pd.DataFrame, lower_percentile: float = 0.1, upper_percentile: float = 0.9
|
||||
) -> pd.DataFrame:
|
||||
"""Remove the outliers from the parsed calibration DataFrame."""
|
||||
|
||||
def is_not_outlier(df_seq):
|
||||
@ -131,8 +136,11 @@ def remove_outliers(df: pd.DataFrame, lower_percentile: float = 0.1,
|
||||
high = df_seq.quantile(upper_percentile)
|
||||
return (df_seq >= low) & (df_seq <= high)
|
||||
|
||||
return df[df.groupby(['run_id', 'collection',
|
||||
'pipeline']).total_execution_time.transform(is_not_outlier).eq(1)]
|
||||
return df[
|
||||
df.groupby(["run_id", "collection", "pipeline"])
|
||||
.total_execution_time.transform(is_not_outlier)
|
||||
.eq(1)
|
||||
]
|
||||
|
||||
|
||||
def extract_sbe_stages(df: pd.DataFrame) -> pd.DataFrame:
|
||||
@ -140,18 +148,24 @@ def extract_sbe_stages(df: pd.DataFrame) -> pd.DataFrame:
|
||||
|
||||
def flatten_sbe_stages(explain):
|
||||
def traverse(node, stages):
|
||||
execution_time = node['executionTimeNanos']
|
||||
children_fields = ['innerStage', 'outerStage', 'inputStage', 'thenStage', 'elseStage']
|
||||
execution_time = node["executionTimeNanos"]
|
||||
children_fields = [
|
||||
"innerStage",
|
||||
"outerStage",
|
||||
"inputStage",
|
||||
"thenStage",
|
||||
"elseStage",
|
||||
]
|
||||
for field in children_fields:
|
||||
if field in node and node[field]:
|
||||
child = node[field]
|
||||
execution_time -= child['executionTimeNanos']
|
||||
execution_time -= child["executionTimeNanos"]
|
||||
traverse(child, stages)
|
||||
del node[field]
|
||||
node['executionTime'] = execution_time
|
||||
node["executionTime"] = execution_time
|
||||
stages.append(node)
|
||||
|
||||
sbe_tree = json.loads(explain)['executionStats']['executionStages']
|
||||
sbe_tree = json.loads(explain)["executionStats"]["executionStages"]
|
||||
result = []
|
||||
traverse(sbe_tree, result)
|
||||
return result
|
||||
@ -168,15 +182,18 @@ def extract_abt_nodes(df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""Extract ABT Nodes and execution statistics from calibration DataFrame."""
|
||||
|
||||
def extract(df_seq):
|
||||
es_dict = extract_execution_stats(df_seq['sbe'], df_seq['abt'], [])
|
||||
es_dict = extract_execution_stats(df_seq["sbe"], df_seq["abt"], [])
|
||||
|
||||
rows = []
|
||||
for abt_type, es in es_dict.items():
|
||||
for stat in es:
|
||||
row = {
|
||||
'abt_type': abt_type, **dataclasses.asdict(stat),
|
||||
**json.loads(df_seq['query_parameters']), 'run_id': df_seq.run_id,
|
||||
'pipeline': df_seq.pipeline, 'source': df_seq.name
|
||||
"abt_type": abt_type,
|
||||
**dataclasses.asdict(stat),
|
||||
**json.loads(df_seq["query_parameters"]),
|
||||
"run_id": df_seq.run_id,
|
||||
"pipeline": df_seq.pipeline,
|
||||
"source": df_seq.name,
|
||||
}
|
||||
rows.append(row)
|
||||
return rows
|
||||
@ -187,13 +204,15 @@ def extract_abt_nodes(df: pd.DataFrame) -> pd.DataFrame:
|
||||
def print_trees(calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0):
    """Print the SBE and ABT trees of one calibration row.

    The row is selected by taking the ``source`` value at position
    ``row_index`` of ``abt_df`` and using it as an index label into
    ``calibration_df``.
    """
    row = calibration_df.loc[abt_df.iloc[row_index].source]
    print("SBE")
    row.sbe.print()
    print("\nABT")
    row.abt.print()
|
||||
|
||||
|
||||
def print_explain(calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0):
|
||||
def print_explain(
|
||||
calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0
|
||||
):
|
||||
"""Print explain."""
|
||||
row = calibration_df.loc[abt_df.iloc[row_index].source]
|
||||
explain = json.loads(row.explain)
|
||||
@ -205,34 +224,38 @@ def calibrate(abt_node_df: pd.DataFrame, variables: list[str] = None):
|
||||
"""Calibrate the ABT node given in abd_node_df with the given model input variables."""
|
||||
# pylint: disable=invalid-name
|
||||
if variables is None:
|
||||
variables = ['n_processed']
|
||||
y = abt_node_df['execution_time']
|
||||
variables = ["n_processed"]
|
||||
y = abt_node_df["execution_time"]
|
||||
X = abt_node_df[variables]
|
||||
X = sm.add_constant(X)
|
||||
|
||||
nnls = LinearRegression(positive=True, fit_intercept=False)
|
||||
model = nnls.fit(X, y)
|
||||
y_pred = model.predict(X)
|
||||
print(f'R2: {r2_score(y, y_pred)}')
|
||||
print(f'Coefficients: {model.coef_}')
|
||||
print(f"R2: {r2_score(y, y_pred)}")
|
||||
print(f"Coefficients: {model.coef_}")
|
||||
|
||||
sns.scatterplot(x=abt_node_df['n_processed'], y=abt_node_df['execution_time'])
|
||||
sns.lineplot(x=abt_node_df['n_processed'], y=y_pred, color='red')
|
||||
sns.scatterplot(x=abt_node_df["n_processed"], y=abt_node_df["execution_time"])
|
||||
sns.lineplot(x=abt_node_df["n_processed"], y=y_pred, color="red")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
from config import DatabaseConfig
|
||||
|
||||
async def test():
|
||||
"""Smoke tests."""
|
||||
database_config = DatabaseConfig(connection_string='mongodb://localhost',
|
||||
database_name='abt_calibration', dump_path='',
|
||||
restore_from_dump=False, dump_on_exit=False)
|
||||
database_config = DatabaseConfig(
|
||||
connection_string="mongodb://localhost",
|
||||
database_name="abt_calibration",
|
||||
dump_path="",
|
||||
restore_from_dump=False,
|
||||
dump_on_exit=False,
|
||||
)
|
||||
database = DatabaseInstance(database_config)
|
||||
|
||||
raw_df = await load_calibration_data(database, 'calibrationData')
|
||||
raw_df = await load_calibration_data(database, "calibrationData")
|
||||
print(raw_df.head())
|
||||
|
||||
cleaned_df = remove_outliers(raw_df, 0.0, 0.9)
|
||||
@ -241,7 +264,7 @@ if __name__ == '__main__':
|
||||
sbe_stages_df = extract_sbe_stages(cleaned_df)
|
||||
print(sbe_stages_df.head())
|
||||
|
||||
seek_df = get_sbe_stage(sbe_stages_df, 'seek')
|
||||
seek_df = get_sbe_stage(sbe_stages_df, "seek")
|
||||
print(seek_df.head())
|
||||
|
||||
abt_nodes_df = extract_abt_nodes(cleaned_df)
|
||||
|
||||
@ -40,20 +40,20 @@ from cost_estimator import CostModelParameters, ExecutionStats
|
||||
from database_instance import DatabaseInstance
|
||||
from workload_execution import QueryParameters
|
||||
|
||||
__all__ = ['extract_parameters', 'extract_execution_stats']
|
||||
__all__ = ["extract_parameters", "extract_execution_stats"]
|
||||
|
||||
|
||||
async def extract_parameters(
|
||||
config: AbtCalibratorConfig, database: DatabaseInstance,
|
||||
abt_types: Sequence[str]) -> Mapping[str, Sequence[CostModelParameters]]:
|
||||
config: AbtCalibratorConfig, database: DatabaseInstance, abt_types: Sequence[str]
|
||||
) -> Mapping[str, Sequence[CostModelParameters]]:
|
||||
"""Read measurements from database and extract cost model parameters for the given ABT types."""
|
||||
|
||||
stats = defaultdict(list)
|
||||
|
||||
docs = await database.get_all_documents(config.input_collection_name)
|
||||
for result in docs:
|
||||
explain = json.loads(result['explain'])
|
||||
query_parameters = QueryParameters.from_json(result['query_parameters'])
|
||||
explain = json.loads(result["explain"])
|
||||
query_parameters = QueryParameters.from_json(result["query_parameters"])
|
||||
res = parse_explain(explain, abt_types)
|
||||
for abt_type, es in res.items():
|
||||
stats[abt_type] += [
|
||||
@ -65,10 +65,12 @@ async def extract_parameters(
|
||||
return stats
|
||||
|
||||
|
||||
Node = TypeVar('Node')
|
||||
Node = TypeVar("Node")
|
||||
|
||||
|
||||
def find_abt_node_by_type(
    root: physical_tree.Node, abt_type: str
) -> Sequence[physical_tree.Node]:
    """Find the ABT nodes under ``root`` whose ``node_type`` equals ``abt_type``.

    Delegates the traversal to find_nodes with an equality predicate.
    """
    return find_nodes(root, lambda node: node.node_type == abt_type)
|
||||
|
||||
@ -111,26 +113,28 @@ def get_excution_stats(root: execution_tree.Node, node_id: int) -> ExecutionStat
|
||||
|
||||
assert n_sbe_nodes <= 1
|
||||
|
||||
return ExecutionStats(execution_time=execution_time, n_returned=n_returned,
|
||||
n_processed=n_processed)
|
||||
return ExecutionStats(
|
||||
execution_time=execution_time, n_returned=n_returned, n_processed=n_processed
|
||||
)
|
||||
|
||||
|
||||
def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]):
    """Extract ExecutionStats from the given explain for the given ABT types.

    Builds the SBE execution tree from 'executionStats' and the ABT physical
    tree from 'queryPlanner.winningPlan.queryPlan', then hands both to
    extract_execution_stats.

    Raises:
        Exception: any error raised while building either tree is re-raised
            unchanged after the error and the offending explain are printed.
    """
    try:
        et = execution_tree.build_execution_tree(explain["executionStats"])
        pt = physical_tree.build(explain["queryPlanner"]["winningPlan"]["queryPlan"])
    except Exception as exception:
        # Print the failing explain for diagnosis, then propagate the
        # original exception with its traceback intact.
        print(f"*** Failed to parse explain with the following error: {exception}")
        print(explain)
        raise

    return extract_execution_stats(et, pt, abt_types)
|
||||
|
||||
|
||||
def extract_execution_stats(et: execution_tree.Node, pt: physical_tree.Node,
|
||||
abt_types: Sequence[str]) -> Mapping[str, Sequence[ExecutionStats]]:
|
||||
def extract_execution_stats(
|
||||
et: execution_tree.Node, pt: physical_tree.Node, abt_types: Sequence[str]
|
||||
) -> Mapping[str, Sequence[ExecutionStats]]:
|
||||
"""Extract ExecutionStats from the given SBE and ABT trees for the given ABT types."""
|
||||
|
||||
if len(abt_types) == 0:
|
||||
@ -144,7 +148,7 @@ def extract_execution_stats(et: execution_tree.Node, pt: physical_tree.Node,
|
||||
result[abt_type].append(execution_stats)
|
||||
return result
|
||||
except AssertionError as ae:
|
||||
print(f'{pt.node_type} {ae} {pt}')
|
||||
print(f"{pt.node_type} {ae} {pt}")
|
||||
raise ae
|
||||
|
||||
|
||||
|
||||
@ -31,7 +31,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
__all__ = ['Node', 'build']
|
||||
__all__ = ["Node", "build"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -62,20 +62,25 @@ def build(optimizer_plan: dict[str, any]) -> Node:
|
||||
def parse_optimizer_node(explain_node: dict[str, any]) -> Node:
    """Recursively parse an ABT node from a query explain's node.

    Children are resolved first via get_children; the remaining attributes
    come from 'nodeType' and from the node's 'properties' sub-document
    ('planNodeID', 'cost', 'localCost', 'adjustedCE').
    """
    children = get_children(explain_node)
    properties = explain_node["properties"]
    return Node(
        node_type=explain_node["nodeType"],
        plan_node_id=properties["planNodeID"],
        cost=properties["cost"],
        local_cost=properties["localCost"],
        adjusted_ce=properties["adjustedCE"],
        children=children,
    )
|
||||
|
||||
|
||||
def get_children(explain_node: dict[str, any]) -> list[Node]:
|
||||
"""Get children nodes of the ABT node."""
|
||||
if 'children' in explain_node:
|
||||
children = [parse_optimizer_node(child) for child in explain_node['children']]
|
||||
if "children" in explain_node:
|
||||
children = [parse_optimizer_node(child) for child in explain_node["children"]]
|
||||
else:
|
||||
children = []
|
||||
|
||||
children_refs = ['child', 'leftChild', 'rightChild']
|
||||
children_refs = ["child", "leftChild", "rightChild"]
|
||||
|
||||
for ref in children_refs:
|
||||
if ref in explain_node:
|
||||
|
||||
@ -34,7 +34,7 @@ import pandas as pd
|
||||
import query_solution_tree as qsn
|
||||
from workload_execution import QueryParameters
|
||||
|
||||
Node = TypeVar('Node')
|
||||
Node = TypeVar("Node")
|
||||
|
||||
|
||||
def parse_explain(explain: dict[str, any]) -> (qsn.Node, sbe.Node):
|
||||
@ -83,41 +83,69 @@ class ParametersBuilder:
|
||||
|
||||
def buildDataFrame(self) -> pd.DataFrame:
    """Return the collected rows as a DataFrame with named columns.

    Column order matches the order of the values in each entry of
    ``self.rows``.
    """
    return pd.DataFrame(
        self.rows,
        columns=[
            "stage",
            "execution_time",
            "n_processed",
            "seeks",
            "note",
            "keys_length_in_bytes",
            "average_document_size_in_bytes",
            "number_of_fields",
        ],
    )
|
||||
|
||||
def _process(self, qsn_node: qsn.Node, sbe_tree: sbe.Node, params: QueryParameters):
    """Append a row for ``qsn_node`` and then for each of its descendants.

    The processor is chosen by the node's ``node_type``; the traversal is
    pre-order and always searches the same full ``sbe_tree``.
    """
    processor = self._get_processor(qsn_node.node_type)
    row = processor(qsn_node.node_type, qsn_node.plan_node_id, sbe_tree, params)
    self.rows.append(row)
    for child in qsn_node.children:
        self._process(child, sbe_tree, params)
|
||||
|
||||
def _get_processor(self, stage: str):
    """Return the processor registered for ``stage``, or the default processor."""
    return self.processors.get(stage, self.default_processor)
|
||||
|
||||
def _process_generic(
    self, stage: str, node_id: int, sbe_tree: sbe.Node, params: QueryParameters
):
    """Build a row for ``stage`` from every SBE node with plan node id ``node_id``.

    Execution times are summed over the matching nodes, ``n_processed`` is
    the maximum over them, and ``seeks`` sums only the truthy seek counts.

    Raises:
        ValueError: if no SBE node carries ``node_id``.
    """
    nodes: list[sbe.Node] = find_nodes(
        sbe_tree, lambda node: node.plan_node_id == node_id
    )
    if not nodes:
        raise ValueError(f"Cannot find sbe nodes of {stage}")
    return ParametersBuilder._build_row(
        stage,
        params,
        execution_time=sum(node.get_execution_time() for node in nodes),
        n_processed=max(node.n_processed for node in nodes),
        # Nodes whose seeks is None (or 0) contribute nothing to the sum.
        seeks=sum(node.seeks for node in nodes if node.seeks),
    )
|
||||
|
||||
@staticmethod
def _build_row(
    stage: str,
    params: QueryParameters,
    execution_time: int = None,
    n_processed: int = None,
    seeks: int = None,
):
    """Assemble one result row as a list.

    The measured values come first, followed by the ``note``,
    ``keys_length_in_bytes``, ``average_document_size_in_bytes`` and
    ``number_of_fields`` attributes of ``params``.
    """
    return [
        stage,
        execution_time,
        n_processed,
        seeks,
        params.note,
        params.keys_length_in_bytes,
        params.average_document_size_in_bytes,
        params.number_of_fields,
    ]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
|
||||
explain = """
|
||||
{
|
||||
"explainVersion": "2",
|
||||
@ -746,7 +774,7 @@ if __name__ == "__main__":
|
||||
qsn_tree.print()
|
||||
sbe_tree.print()
|
||||
|
||||
params = QueryParameters(10, 2000, 'rooted-or')
|
||||
params = QueryParameters(10, 2000, "rooted-or")
|
||||
builder = ParametersBuilder()
|
||||
builder.process(explainJson, params)
|
||||
df = builder.buildDataFrame()
|
||||
|
||||
@ -31,7 +31,7 @@ from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
__all__ = ['Node', 'build']
|
||||
__all__ = ["Node", "build"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -59,17 +59,22 @@ def parse_optimizer_node(explain_node: dict[str, any]) -> Node:
|
||||
"""Recursively parse QSN from query explain's node."""
|
||||
|
||||
children = get_children(explain_node)
|
||||
return Node(node_type=explain_node['stage'], plan_node_id=explain_node['planNodeId'],
|
||||
children=children)
|
||||
return Node(
|
||||
node_type=explain_node["stage"],
|
||||
plan_node_id=explain_node["planNodeId"],
|
||||
children=children,
|
||||
)
|
||||
|
||||
|
||||
def get_children(explain_node: dict[str, any]) -> list[Node]:
    """Get the children nodes of the QSN.

    A node carries either a single 'inputStage', a list of 'inputStages',
    or neither. 'inputStage' takes precedence when both are present, and a
    node with neither has no children.
    """
    if "inputStage" in explain_node:
        children = [parse_optimizer_node(explain_node["inputStage"])]
    elif "inputStages" in explain_node:
        children = [
            parse_optimizer_node(child) for child in explain_node["inputStages"]
        ]
    else:
        children = []
    return children
|
||||
@ -77,6 +82,7 @@ def get_children(explain_node: dict[str, any]) -> list[Node]:
|
||||
|
||||
if __name__ == "__main__":
|
||||
import json
|
||||
|
||||
explain = """
|
||||
{
|
||||
"explainVersion": "2",
|
||||
|
||||
@ -38,9 +38,9 @@ from typing import Generic, Sequence, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['RangeGenerator', 'DataType', 'RandomDistribution']
|
||||
__all__ = ["RangeGenerator", "DataType", "RandomDistribution"]
|
||||
|
||||
TVar = TypeVar('TVar', str, int, float, datetime)
|
||||
TVar = TypeVar("TVar", str, int, float, datetime)
|
||||
|
||||
|
||||
class DataType(Enum):
|
||||
@ -61,18 +61,18 @@ class DataType(Enum):
|
||||
|
||||
def __str__(self):
    """Return the short type name used for this data type.

    Raises:
        KeyError: if this member has no entry in the name table below.
    """
    typenames = {
        DataType.DOUBLE: "dbl",
        DataType.STRING: "str",
        DataType.OBJECT: "obj",
        DataType.ARRAY: "arr",
        DataType.OBJECTID: "oid",
        DataType.BOOLEAN: "bool",
        DataType.DATE: "dt",
        DataType.NULL: "null",
        DataType.INTEGER: "int",
        DataType.TIMESTAMP: "ts",
        DataType.DECIMAL128: "dec",
        DataType.MIXDATA: "mixdata",
    }
    return typenames[self]
|
||||
|
||||
@ -89,7 +89,8 @@ class RangeGenerator(Generic[TVar]):
|
||||
|
||||
def __post_init__(self):
|
||||
assert type(self.interval_begin) == type(
|
||||
self.interval_end), 'Interval ends must of the same type.'
|
||||
self.interval_end
|
||||
), "Interval ends must of the same type."
|
||||
if type(self.interval_begin) == int or type(self.interval_begin) == float:
|
||||
self.ndv = round((self.interval_end - self.interval_begin) / self.step)
|
||||
elif type(self.interval_begin) == datetime:
|
||||
@ -101,37 +102,39 @@ class RangeGenerator(Generic[TVar]):
|
||||
"""Generate the range."""
|
||||
|
||||
gen_range_dict = {
|
||||
DataType.STRING:
|
||||
ansi_range,
|
||||
DataType.INTEGER:
|
||||
range,
|
||||
DataType.STRING: ansi_range,
|
||||
DataType.INTEGER: range,
|
||||
# The arange function produces equi-distant values which is too regular for CE testing.
|
||||
# It is left here as a possible way of generating doubles.
|
||||
# DataType.DOUBLE: np.arange
|
||||
DataType.DOUBLE:
|
||||
double_range,
|
||||
DataType.DATE:
|
||||
datetime_range,
|
||||
DataType.DOUBLE: double_range,
|
||||
DataType.DATE: datetime_range,
|
||||
}
|
||||
|
||||
gen_range = gen_range_dict.get(self.data_type)
|
||||
if gen_range is None:
|
||||
raise ValueError(f'Unsupported data type: {self.data_type}')
|
||||
raise ValueError(f"Unsupported data type: {self.data_type}")
|
||||
|
||||
return list(gen_range(self.interval_begin, self.interval_end, self.step))
|
||||
|
||||
def __str__(self):
    """Return a compact name '<type>_<begin>-<end>-<step>' for this range.

    datetime bounds are rendered by their date part only. Dots become
    commas and spaces become underscores so the result can be used as a
    field name.
    """
    # TODO: for now skip NDV from the name to make it shorter.
    # ndv_str = "_" if self.ndv <= 0 else f'_{self.ndv}_'
    begin_str = (
        str(self.interval_begin.date())
        if isinstance(self.interval_begin, datetime)
        else str(self.interval_begin)
    )
    end_str = (
        str(self.interval_end.date())
        if isinstance(self.interval_end, datetime)
        else str(self.interval_end)
    )

    str_rep = f"{str(self.data_type)}_{begin_str}-{end_str}-{self.step}"
    # Remove dots and spaces from field names in a single pass.
    return str_rep.translate(str.maketrans({".": ",", " ": "_"}))
|
||||
|
||||
|
||||
@ -145,14 +148,14 @@ def ansi_range(begin: str, end: str, step: int = 1):
|
||||
"""Produces a sequence of string from begin to end."""
|
||||
|
||||
alphabet_size = 28
|
||||
non_alpha_char = '_'
|
||||
non_alpha_char = "_"
|
||||
|
||||
def ansi_to_int(data: str) -> int:
|
||||
res = 0
|
||||
for char in data.lower():
|
||||
res = res * alphabet_size
|
||||
if 'a' <= char <= 'z':
|
||||
res += ord(char) - ord('a') + 1
|
||||
if "a" <= char <= "z":
|
||||
res += ord(char) - ord("a") + 1
|
||||
else:
|
||||
res += alphabet_size - 1
|
||||
return res
|
||||
@ -164,10 +167,10 @@ def ansi_range(begin: str, end: str, step: int = 1):
|
||||
if remainder == alphabet_size - 1:
|
||||
char = non_alpha_char
|
||||
else:
|
||||
char = chr(remainder + ord('a') - 1)
|
||||
char = chr(remainder + ord("a") - 1)
|
||||
result.append(char)
|
||||
result.reverse()
|
||||
return ''.join(result)
|
||||
return "".join(result)
|
||||
|
||||
def get_common_prefix_len(s1: str, s2: str):
|
||||
index = 0
|
||||
@ -188,7 +191,7 @@ def ansi_range(begin: str, end: str, step: int = 1):
|
||||
if prefix_len == 0:
|
||||
yield int_to_ansi(number)
|
||||
else:
|
||||
yield f'{prefix}{int_to_ansi(number)}'
|
||||
yield f"{prefix}{int_to_ansi(number)}"
|
||||
|
||||
|
||||
def datetime_range(begin: datetime, end: datetime, step: int = 60):
|
||||
@ -199,8 +202,8 @@ def datetime_range(begin: datetime, end: datetime, step: int = 60):
|
||||
for _ in range(0, num_values):
|
||||
random_ts = np.random.randint(begin_ts, end_ts)
|
||||
yield datetime.fromtimestamp(random_ts)
|
||||
#random_dates = [datetime.fromtimestamp(random_ts) for random_ts in random.sample(range(int(begin_ts), int(end_ts)), num_values)]
|
||||
#return random_dates
|
||||
# random_dates = [datetime.fromtimestamp(random_ts) for random_ts in random.sample(range(int(begin_ts), int(end_ts)), num_values)]
|
||||
# return random_dates
|
||||
|
||||
|
||||
class DistributionType(Enum):
|
||||
@ -226,8 +229,8 @@ class RandomDistribution:
|
||||
distribution_type: DistributionType
|
||||
values: Union[Sequence[TVar], RangeGenerator]
|
||||
weights: Union[Sequence[float], None]
|
||||
values_name: str = ''
|
||||
weights_name: str = ''
|
||||
values_name: str = ""
|
||||
weights_name: str = ""
|
||||
|
||||
def __str__(self):
|
||||
def print_values(vals):
|
||||
@ -235,56 +238,71 @@ class RandomDistribution:
|
||||
return str(vals)
|
||||
elif isinstance(vals[0], RandomDistribution):
|
||||
# Must be a mixed distribution
|
||||
res = ''
|
||||
res = ""
|
||||
for distr in vals:
|
||||
res += f'{str(distr)}_'
|
||||
res += f"{str(distr)}_"
|
||||
return res
|
||||
else:
|
||||
# All values are of the same type because of how RangeGenerator works
|
||||
return f'{type(vals[0]).__name__}_{min(vals)}_{max(vals)}_{len(vals)}'
|
||||
return f"{type(vals[0]).__name__}_{min(vals)}_{max(vals)}_{len(vals)}"
|
||||
|
||||
range_str = ''
|
||||
if hasattr(self, 'values'):
|
||||
range_str = ""
|
||||
if hasattr(self, "values"):
|
||||
range_str = print_values(self.values)
|
||||
if self.values_name != '':
|
||||
range_str += f'_{self.values_name}'
|
||||
if self.weights_name != '':
|
||||
range_str += f'_{self.weights_name}'
|
||||
if self.values_name != "":
|
||||
range_str += f"_{self.values_name}"
|
||||
if self.weights_name != "":
|
||||
range_str += f"_{self.weights_name}"
|
||||
|
||||
distr_str = f'{str(self.distribution_type)}_{range_str}'
|
||||
distr_str = f"{str(self.distribution_type)}_{range_str}"
|
||||
return distr_str
|
||||
|
||||
@staticmethod
|
||||
def choice(values: Sequence[TVar], weights: Union[Sequence[float], RangeGenerator],
|
||||
v_name: str = '', w_name: str = ''):
|
||||
def choice(
|
||||
values: Sequence[TVar],
|
||||
weights: Union[Sequence[float], RangeGenerator],
|
||||
v_name: str = "",
|
||||
w_name: str = "",
|
||||
):
|
||||
"""Create choice distribution."""
|
||||
return RandomDistribution(distribution_type=DistributionType.CHOICE, values=values,
|
||||
weights=weights, values_name=v_name, weights_name=w_name)
|
||||
return RandomDistribution(
|
||||
distribution_type=DistributionType.CHOICE,
|
||||
values=values,
|
||||
weights=weights,
|
||||
values_name=v_name,
|
||||
weights_name=w_name,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def normal(values: Union[Sequence[TVar], RangeGenerator]):
|
||||
"""Create normal distribution."""
|
||||
return RandomDistribution(distribution_type=DistributionType.NORMAL, values=values,
|
||||
weights=None)
|
||||
return RandomDistribution(
|
||||
distribution_type=DistributionType.NORMAL, values=values, weights=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def noncentral_chisquare(values: Union[Sequence[TVar], RangeGenerator]):
|
||||
"""Create Non Central Chi2 distribution."""
|
||||
return RandomDistribution(distribution_type=DistributionType.CHI2, values=values,
|
||||
weights=None)
|
||||
return RandomDistribution(
|
||||
distribution_type=DistributionType.CHI2, values=values, weights=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def uniform(values: Union[Sequence[TVar], RangeGenerator]):
|
||||
"""Create uniform distribution."""
|
||||
return RandomDistribution(distribution_type=DistributionType.UNIFORM, values=values,
|
||||
weights=None)
|
||||
return RandomDistribution(
|
||||
distribution_type=DistributionType.UNIFORM, values=values, weights=None
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def mixed(children: Sequence[RandomDistribution],
|
||||
weight: Union[Sequence[float], RangeGenerator]):
|
||||
def mixed(
|
||||
children: Sequence[RandomDistribution],
|
||||
weight: Union[Sequence[float], RangeGenerator],
|
||||
):
|
||||
"""Create mixed distribution."""
|
||||
return RandomDistribution(distribution_type=DistributionType.MIXDIST, values=children,
|
||||
weights=weight)
|
||||
return RandomDistribution(
|
||||
distribution_type=DistributionType.MIXDIST, values=children, weights=weight
|
||||
)
|
||||
|
||||
def generate(self, size: int) -> Sequence[TVar]:
|
||||
"""Generate random data sequence of the given size."""
|
||||
@ -306,7 +324,9 @@ class RandomDistribution:
|
||||
probs = None
|
||||
|
||||
if probs is not None and len(probs) != len(values):
|
||||
raise ValueError(f'values and probs must be the same size: {probs} !! {values}')
|
||||
raise ValueError(
|
||||
f"values and probs must be the same size: {probs} !! {values}"
|
||||
)
|
||||
|
||||
if len(values) == 0:
|
||||
raise ValueError(f"Values cannot be empty: {self.values}")
|
||||
@ -382,7 +402,9 @@ class RandomDistribution:
|
||||
index = len(values) - 1
|
||||
return values[index]
|
||||
|
||||
return [get_value(n) for n in _rng.noncentral_chisquare(df=df, nonc=nonc, size=size)]
|
||||
return [
|
||||
get_value(n) for n in _rng.noncentral_chisquare(df=df, nonc=nonc, size=size)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _uniform(size: int, values: Sequence[TVar], _: Sequence[float]):
|
||||
@ -393,15 +415,19 @@ class RandomDistribution:
|
||||
return [get_value(n) for n in _rng.uniform(low=0, high=len(values), size=size)]
|
||||
|
||||
@staticmethod
|
||||
def _mixed(size: int, children: Sequence[RandomDistribution], probs: Sequence[float]):
|
||||
def _mixed(
|
||||
size: int, children: Sequence[RandomDistribution], probs: Sequence[float]
|
||||
):
|
||||
if probs is None:
|
||||
raise ValueError(f'probs must be specified for mixed distributions: {str(children)}')
|
||||
raise ValueError(
|
||||
f"probs must be specified for mixed distributions: {str(children)}"
|
||||
)
|
||||
|
||||
result = []
|
||||
for child_distr, prob in zip(children, probs):
|
||||
if not isinstance(child_distr, RandomDistribution):
|
||||
raise ValueError(
|
||||
f'children must be of type RandomDistribution for mixed distribution, child_distr: {child_distr}'
|
||||
f"children must be of type RandomDistribution for mixed distribution, child_distr: {child_distr}"
|
||||
)
|
||||
child_size = int(size * prob)
|
||||
result.append(child_distr.generate(child_size))
|
||||
@ -419,14 +445,16 @@ class ArrayRandomDistribution(RandomDistribution):
|
||||
lengths_distr: RandomDistribution = _NO_DEFAULT
|
||||
value_distr: RandomDistribution = _NO_DEFAULT
|
||||
|
||||
def __init__(self, lengths_distr: RandomDistribution, value_distr: RandomDistribution):
|
||||
def __init__(
|
||||
self, lengths_distr: RandomDistribution, value_distr: RandomDistribution
|
||||
):
|
||||
self.lengths_distr = lengths_distr
|
||||
self.value_distr = value_distr
|
||||
self.distribution_type = value_distr.distribution_type
|
||||
|
||||
def __str__(self):
|
||||
distr_str = f'{super().__str__()}'
|
||||
distr_str += f'array_{str(self.value_distr)}_{str(self.lengths_distr)}'
|
||||
distr_str = f"{super().__str__()}"
|
||||
distr_str += f"array_{str(self.value_distr)}_{str(self.lengths_distr)}"
|
||||
return distr_str
|
||||
|
||||
def generate(self, size: int):
|
||||
@ -450,8 +478,12 @@ class DocumentRandomDistribution(RandomDistribution):
|
||||
fields_distr: RandomDistribution = _NO_DEFAULT
|
||||
field_to_distribution: dict = _NO_DEFAULT
|
||||
|
||||
def __init__(self, number_of_fields_distr: RandomDistribution, fields_distr: RandomDistribution,
|
||||
field_to_distribution: dict):
|
||||
def __init__(
|
||||
self,
|
||||
number_of_fields_distr: RandomDistribution,
|
||||
fields_distr: RandomDistribution,
|
||||
field_to_distribution: dict,
|
||||
):
|
||||
self.number_of_fields_distr = number_of_fields_distr
|
||||
self.fields_distr = fields_distr
|
||||
self.field_to_distribution = field_to_distribution
|
||||
@ -462,7 +494,7 @@ class DocumentRandomDistribution(RandomDistribution):
|
||||
raise ValueError("Must provide a RandomDistribution for each field")
|
||||
|
||||
def __str__(self):
|
||||
return f'{super().__str__()}'
|
||||
return f"{super().__str__()}"
|
||||
|
||||
def generate(self, size: int):
|
||||
"""Generate random document sequence of the given size."""
|
||||
@ -479,7 +511,9 @@ class DocumentRandomDistribution(RandomDistribution):
|
||||
for idx, num in enumerate(nums):
|
||||
doc = {}
|
||||
if not isinstance(num, int):
|
||||
raise ValueError("the number of fields must be an int for document generation")
|
||||
raise ValueError(
|
||||
"the number of fields must be an int for document generation"
|
||||
)
|
||||
|
||||
field_names = self.fields_distr.generate(num)
|
||||
for field in field_names:
|
||||
@ -494,12 +528,12 @@ class DocumentRandomDistribution(RandomDistribution):
|
||||
return self.fields_distr.get_values()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
from collections import Counter
|
||||
|
||||
def print_distr(title, distr, size=10000):
|
||||
"""Print distribution."""
|
||||
print(f'\n{title}: {str(distr)}\n')
|
||||
print(f"\n{title}: {str(distr)}\n")
|
||||
rs = distr.generate(size)
|
||||
has_arrays = any(isinstance(elem, list) for elem in rs)
|
||||
has_dict = any(isinstance(elem, dict) for elem in rs)
|
||||
@ -516,40 +550,61 @@ if __name__ == '__main__':
|
||||
for elem in rs:
|
||||
print(elem)
|
||||
|
||||
choice = RandomDistribution.choice(values=['pooh', 'rabbit', 'piglet', 'Chris'],
|
||||
weights=[0.5, 0.1, 0.1, 0.3])
|
||||
choice = RandomDistribution.choice(
|
||||
values=["pooh", "rabbit", "piglet", "Chris"], weights=[0.5, 0.1, 0.1, 0.3]
|
||||
)
|
||||
print_distr("Choice", choice, 1000)
|
||||
|
||||
string_generator = RangeGenerator(data_type=DataType.STRING, interval_begin='hello_a',
|
||||
interval_end='hello__')
|
||||
string_generator = RangeGenerator(
|
||||
data_type=DataType.STRING, interval_begin="hello_a", interval_end="hello__"
|
||||
)
|
||||
str_normal = RandomDistribution.normal(string_generator)
|
||||
print_distr("Normal for strings", str_normal)
|
||||
|
||||
int_noncentral_chisquare = RandomDistribution.noncentral_chisquare(list(range(1, 30)))
|
||||
int_noncentral_chisquare = RandomDistribution.noncentral_chisquare(
|
||||
list(range(1, 30))
|
||||
)
|
||||
print_distr("Noncentral Chisquare for integers", int_noncentral_chisquare)
|
||||
|
||||
float_uniform = RandomDistribution.uniform(RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37))
|
||||
float_uniform = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37)
|
||||
)
|
||||
print_distr("Uniform for floats", float_uniform)
|
||||
|
||||
float_normal = RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37))
|
||||
float_normal = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37)
|
||||
)
|
||||
print_distr("Normal for floats", float_normal)
|
||||
|
||||
FOUR_DAYS_IN_SECONDS = 60 * 20 * 24 * 12
|
||||
|
||||
date_uniform = RandomDistribution.uniform(
|
||||
RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1),
|
||||
FOUR_DAYS_IN_SECONDS))
|
||||
RangeGenerator(
|
||||
DataType.DATE,
|
||||
datetime(2007, 1, 1),
|
||||
datetime(2008, 1, 1),
|
||||
FOUR_DAYS_IN_SECONDS,
|
||||
)
|
||||
)
|
||||
print_distr("Uniform for dates", date_uniform, size=1000)
|
||||
|
||||
date_normal = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1),
|
||||
FOUR_DAYS_IN_SECONDS))
|
||||
RangeGenerator(
|
||||
DataType.DATE,
|
||||
datetime(2007, 1, 1),
|
||||
datetime(2008, 1, 1),
|
||||
FOUR_DAYS_IN_SECONDS,
|
||||
)
|
||||
)
|
||||
print_distr("Normal for dates", date_normal, size=1000)
|
||||
|
||||
str_chisquare2 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "aa", "ba"))
|
||||
str_chisquare2 = RandomDistribution.normal(
|
||||
RangeGenerator(DataType.STRING, "aa", "ba")
|
||||
)
|
||||
str_normal2 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "ap", "bp"))
|
||||
mixed = RandomDistribution.mixed(children=[float_uniform, str_chisquare2, str_normal2],
|
||||
weight=[0.3, 0.5, 0.2])
|
||||
mixed = RandomDistribution.mixed(
|
||||
children=[float_uniform, str_chisquare2, str_normal2], weight=[0.3, 0.5, 0.2]
|
||||
)
|
||||
print_distr("Mixed", mixed, 20_000)
|
||||
|
||||
int_normal = RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 2, 10))
|
||||
@ -557,24 +612,33 @@ if __name__ == '__main__':
|
||||
arr_distr = ArrayRandomDistribution(int_normal, mixed)
|
||||
print_distr("Mixed Arrays", arr_distr, 100)
|
||||
|
||||
mixed_with_arrays = RandomDistribution.mixed(children=[float_uniform, str_normal2, arr_distr],
|
||||
weight=[0.3, 0.2, 0.5])
|
||||
mixed_with_arrays = RandomDistribution.mixed(
|
||||
children=[float_uniform, str_normal2, arr_distr], weight=[0.3, 0.2, 0.5]
|
||||
)
|
||||
nested_arr_distr = ArrayRandomDistribution(int_normal, mixed_with_arrays)
|
||||
|
||||
print_distr("Mixed Nested Arrays", nested_arr_distr, 100)
|
||||
|
||||
simple_doc_distr = DocumentRandomDistribution(
|
||||
RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 1, 2)),
|
||||
RandomDistribution.uniform(["obj"]), {"obj": int_normal})
|
||||
RandomDistribution.uniform(["obj"]),
|
||||
{"obj": int_normal},
|
||||
)
|
||||
|
||||
field_name_choice = RandomDistribution.uniform(['a', 'b', 'c', 'd', 'e', 'f'])
|
||||
field_name_choice = RandomDistribution.uniform(["a", "b", "c", "d", "e", "f"])
|
||||
|
||||
field_to_distr = {
|
||||
'a': int_normal, 'b': str_normal, 'c': mixed, 'd': arr_distr, 'e': nested_arr_distr,
|
||||
'f': simple_doc_distr
|
||||
"a": int_normal,
|
||||
"b": str_normal,
|
||||
"c": mixed,
|
||||
"d": arr_distr,
|
||||
"e": nested_arr_distr,
|
||||
"f": simple_doc_distr,
|
||||
}
|
||||
nested_doc_distr = DocumentRandomDistribution(
|
||||
RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 0, 7)), field_name_choice,
|
||||
field_to_distr)
|
||||
RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 0, 7)),
|
||||
field_name_choice,
|
||||
field_to_distr,
|
||||
)
|
||||
|
||||
print_distr("Nested Document generation", nested_doc_distr, 100)
|
||||
|
||||
@ -46,160 +46,228 @@ from workload_execution import Query, QueryParameters
|
||||
__all__ = []
|
||||
|
||||
|
||||
def save_to_csv(parameters: Mapping[str, Sequence[CostModelParameters]], filepath: str) -> None:
|
||||
def save_to_csv(
|
||||
parameters: Mapping[str, Sequence[CostModelParameters]], filepath: str
|
||||
) -> None:
|
||||
"""Save model input parameters to a csv file."""
|
||||
abt_type_name = 'abt_type'
|
||||
abt_type_name = "abt_type"
|
||||
fieldnames = [
|
||||
abt_type_name, *[f.name for f in dataclasses.fields(ExecutionStats)],
|
||||
*[f.name for f in dataclasses.fields(QueryParameters)]
|
||||
abt_type_name,
|
||||
*[f.name for f in dataclasses.fields(ExecutionStats)],
|
||||
*[f.name for f in dataclasses.fields(QueryParameters)],
|
||||
]
|
||||
with open(filepath, 'w', newline='') as csvfile:
|
||||
with open(filepath, "w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
writer.writeheader()
|
||||
for abt_type, type_params_list in parameters.items():
|
||||
for type_params in type_params_list:
|
||||
fields = dataclasses.asdict(type_params.execution_stats) | dataclasses.asdict(
|
||||
type_params.query_params)
|
||||
fields = dataclasses.asdict(
|
||||
type_params.execution_stats
|
||||
) | dataclasses.asdict(type_params.query_params)
|
||||
fields[abt_type_name] = abt_type
|
||||
writer.writerow(fields)
|
||||
|
||||
|
||||
async def execute_index_scan_queries(database: DatabaseInstance,
|
||||
collections: Sequence[CollectionInfo]):
|
||||
collection = [ci for ci in collections if ci.name.startswith('index_scan')][0]
|
||||
fields = [f for f in collection.fields if f.name == 'choice']
|
||||
async def execute_index_scan_queries(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collection = [ci for ci in collections if ci.name.startswith("index_scan")][0]
|
||||
fields = [f for f in collection.fields if f.name == "choice"]
|
||||
|
||||
requests = []
|
||||
|
||||
for field in fields:
|
||||
for val in field.distribution.get_values():
|
||||
if val.startswith('_'):
|
||||
if val.startswith("_"):
|
||||
continue
|
||||
keys_length = len(val) + 2
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {field.name: val}}], keys_length_in_bytes=keys_length,
|
||||
note='IndexScan'))
|
||||
Query(
|
||||
pipeline=[{"$match": {field.name: val}}],
|
||||
keys_length_in_bytes=keys_length,
|
||||
note="IndexScan",
|
||||
)
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, [collection],
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, [collection], requests
|
||||
)
|
||||
|
||||
|
||||
async def execute_physical_scan_queries(database: DatabaseInstance,
|
||||
collections: Sequence[CollectionInfo]):
|
||||
collections = [ci for ci in collections if ci.name.startswith('physical_scan')]
|
||||
fields = [f for f in collections[0].fields if f.name == 'choice']
|
||||
async def execute_physical_scan_queries(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collections = [ci for ci in collections if ci.name.startswith("physical_scan")]
|
||||
fields = [f for f in collections[0].fields if f.name == "choice"]
|
||||
requests = []
|
||||
for field in fields:
|
||||
for val in field.distribution.get_values()[::3]:
|
||||
if val.startswith('_'):
|
||||
if val.startswith("_"):
|
||||
continue
|
||||
keys_length = len(val) + 2
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {field.name: val}}, {"$limit": 10}],
|
||||
keys_length_in_bytes=keys_length, note='PhysicalScan'))
|
||||
Query(
|
||||
pipeline=[{"$match": {field.name: val}}, {"$limit": 10}],
|
||||
keys_length_in_bytes=keys_length,
|
||||
note="PhysicalScan",
|
||||
)
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests
|
||||
)
|
||||
|
||||
|
||||
async def execute_index_intersections_with_requests(database: DatabaseInstance,
|
||||
collections: Sequence[CollectionInfo],
|
||||
requests: Sequence[Query]):
|
||||
async def execute_index_intersections_with_requests(
|
||||
database: DatabaseInstance,
|
||||
collections: Sequence[CollectionInfo],
|
||||
requests: Sequence[Query],
|
||||
):
|
||||
try:
|
||||
await database.set_parameter('internalCostModelCoefficients',
|
||||
'{"filterIncrementalCost": 10000.0}')
|
||||
await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', False)
|
||||
await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', False)
|
||||
await database.set_parameter(
|
||||
"internalCostModelCoefficients", '{"filterIncrementalCost": 10000.0}'
|
||||
)
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableMergeJoinRIDIntersect", False
|
||||
)
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableHashJoinRIDIntersect", False
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests
|
||||
)
|
||||
|
||||
await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', True)
|
||||
await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', True)
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableMergeJoinRIDIntersect", True
|
||||
)
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableHashJoinRIDIntersect", True
|
||||
)
|
||||
|
||||
main_config.workload_execution.write_mode = WriteMode.APPEND
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests[::4])
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests[::4]
|
||||
)
|
||||
|
||||
finally:
|
||||
await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', False)
|
||||
await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', False)
|
||||
await database.set_parameter('internalCostModelCoefficients', '')
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableMergeJoinRIDIntersect", False
|
||||
)
|
||||
await database.set_parameter(
|
||||
"internalCascadesOptimizerDisableHashJoinRIDIntersect", False
|
||||
)
|
||||
await database.set_parameter("internalCostModelCoefficients", "")
|
||||
|
||||
|
||||
async def execute_index_intersections(database: DatabaseInstance,
|
||||
collections: Sequence[CollectionInfo]):
|
||||
collections = [ci for ci in collections if ci.name.startswith('c_int')]
|
||||
async def execute_index_intersections(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collections = [ci for ci in collections if ci.name.startswith("c_int")]
|
||||
|
||||
requests = []
|
||||
|
||||
for i in range(0, 1000, 100):
|
||||
requests.append(Query(pipeline=[{'$match': {'in1': i, 'in2': i}}], keys_length_in_bytes=1))
|
||||
requests.append(
|
||||
Query(pipeline=[{"$match": {"in1": i, "in2": i}}], keys_length_in_bytes=1)
|
||||
)
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 - i}}], keys_length_in_bytes=1))
|
||||
Query(
|
||||
pipeline=[{"$match": {"in1": i, "in2": 1000 - i}}],
|
||||
keys_length_in_bytes=1,
|
||||
)
|
||||
)
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}],
|
||||
keys_length_in_bytes=1))
|
||||
Query(
|
||||
pipeline=[{"$match": {"in1": {"$lte": i}, "in2": 1000 - i}}],
|
||||
keys_length_in_bytes=1,
|
||||
)
|
||||
)
|
||||
|
||||
requests.append(
|
||||
Query(pipeline=[{'$match': {'in1': i, 'in2': {'$gt': 1000 - i}}}],
|
||||
keys_length_in_bytes=1))
|
||||
Query(
|
||||
pipeline=[{"$match": {"in1": i, "in2": {"$gt": 1000 - i}}}],
|
||||
keys_length_in_bytes=1,
|
||||
)
|
||||
)
|
||||
|
||||
await execute_index_intersections_with_requests(database, collections, requests)
|
||||
|
||||
|
||||
async def execute_evaluation(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
|
||||
collections = [ci for ci in collections if ci.name.startswith('c_int_05')]
|
||||
async def execute_evaluation(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collections = [ci for ci in collections if ci.name.startswith("c_int_05")]
|
||||
requests = []
|
||||
|
||||
for i in [100, 500, 1000]:
|
||||
requests.append(
|
||||
Query(pipeline=[{'$project': {'uniform1': 1, 'mixed2': 1}}, {"$limit": i}],
|
||||
keys_length_in_bytes=1, number_of_fields=1, note='Evaluation'))
|
||||
Query(
|
||||
pipeline=[{"$project": {"uniform1": 1, "mixed2": 1}}, {"$limit": i}],
|
||||
keys_length_in_bytes=1,
|
||||
number_of_fields=1,
|
||||
note="Evaluation",
|
||||
)
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests
|
||||
)
|
||||
|
||||
|
||||
async def execute_unwind(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
|
||||
collections = [ci for ci in collections if ci.name.startswith('c_arr_01')]
|
||||
async def execute_unwind(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collections = [ci for ci in collections if ci.name.startswith("c_arr_01")]
|
||||
requests = []
|
||||
# average size of arrays in the collection
|
||||
average_size_of_arrays = 10
|
||||
|
||||
for _ in range(500, 1000, 100):
|
||||
requests.append(
|
||||
Query(pipeline=[{"$unwind": "$as"}], number_of_fields=average_size_of_arrays))
|
||||
Query(
|
||||
pipeline=[{"$unwind": "$as"}], number_of_fields=average_size_of_arrays
|
||||
)
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests
|
||||
)
|
||||
|
||||
|
||||
async def execute_unique(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
|
||||
collections = [ci for ci in collections if ci.name.startswith('c_arr_01')]
|
||||
async def execute_unique(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collections = [ci for ci in collections if ci.name.startswith("c_arr_01")]
|
||||
requests = []
|
||||
|
||||
for i in range(500, 1000, 200):
|
||||
requests.append(Query(pipeline=[{"$match": {"as": {"$gt": i}}}]))
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, collections,
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, collections, requests
|
||||
)
|
||||
|
||||
|
||||
async def execute_limitskip(database: DatabaseInstance, collections: Sequence[CollectionInfo]):
|
||||
collection = [ci for ci in collections if ci.name.startswith('index_scan')][0]
|
||||
async def execute_limitskip(
|
||||
database: DatabaseInstance, collections: Sequence[CollectionInfo]
|
||||
):
|
||||
collection = [ci for ci in collections if ci.name.startswith("index_scan")][0]
|
||||
limits = [5, 10, 15, 20]
|
||||
skips = [5, 10, 15, 20]
|
||||
requests = []
|
||||
|
||||
for limit in limits:
|
||||
for skip in skips:
|
||||
requests.append(Query(pipeline=[{"$skip": skip}, {"$limit": limit}], note="LimitSkip"))
|
||||
requests.append(
|
||||
Query(pipeline=[{"$skip": skip}, {"$limit": limit}], note="LimitSkip")
|
||||
)
|
||||
|
||||
await workload_execution.execute(database, main_config.workload_execution, [collection],
|
||||
requests)
|
||||
await workload_execution.execute(
|
||||
database, main_config.workload_execution, [collection], requests
|
||||
)
|
||||
|
||||
|
||||
async def main():
|
||||
@ -210,7 +278,6 @@ async def main():
|
||||
# 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally
|
||||
# from the dump on creating and stores data optionally to the dump on closing.
|
||||
with DatabaseInstance(main_config.database) as database:
|
||||
|
||||
# 2. Data generation (optional), generates random data and populates collections with it.
|
||||
generator = DataGenerator(database, main_config.data_generator)
|
||||
await generator.populate_collections()
|
||||
@ -235,16 +302,17 @@ async def main():
|
||||
# parses the explains, and calibrates the cost model for the ABT nodes.
|
||||
models = await abt_calibrator.calibrate(main_config.abt_calibrator, database)
|
||||
for abt, model in models.items():
|
||||
print(f'{abt}\t\t{model}')
|
||||
print(f"{abt}\t\t{model}")
|
||||
|
||||
parameters = await parameters_extractor.extract_parameters(main_config.abt_calibrator,
|
||||
database, [])
|
||||
save_to_csv(parameters, 'parameters.csv')
|
||||
parameters = await parameters_extractor.extract_parameters(
|
||||
main_config.abt_calibrator, database, []
|
||||
)
|
||||
save_to_csv(parameters, "parameters.csv")
|
||||
|
||||
print("DONE!")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
|
||||
@ -38,7 +38,7 @@ from config import WorkloadExecutionConfig, WriteMode
|
||||
from data_generator import CollectionInfo
|
||||
from database_instance import DatabaseInstance, Pipeline
|
||||
|
||||
__all__ = ['execute']
|
||||
__all__ = ["execute"]
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -70,15 +70,19 @@ class QueryParameters:
|
||||
return QueryParameters(**json.loads(json_str))
|
||||
|
||||
|
||||
async def execute(database: DatabaseInstance, config: WorkloadExecutionConfig,
|
||||
collection_infos: Sequence[CollectionInfo], queries: Sequence[Query]):
|
||||
async def execute(
|
||||
database: DatabaseInstance,
|
||||
config: WorkloadExecutionConfig,
|
||||
collection_infos: Sequence[CollectionInfo],
|
||||
queries: Sequence[Query],
|
||||
):
|
||||
"""Run the given queries and write the collected explain into collection."""
|
||||
if not config.enabled:
|
||||
return
|
||||
|
||||
collector = WorkloadExecution(database, config)
|
||||
await collector.async_init()
|
||||
print('>>> running queries')
|
||||
print(">>> running queries")
|
||||
await collector.collect(collection_infos, queries)
|
||||
|
||||
|
||||
@ -97,33 +101,46 @@ class WorkloadExecution:
|
||||
if self.config.write_mode == WriteMode.REPLACE:
|
||||
await self.database.drop_collection(self.config.output_collection_name)
|
||||
|
||||
async def collect(self, collection_infos: Sequence[CollectionInfo], queries: Sequence[Query]):
|
||||
async def collect(
|
||||
self, collection_infos: Sequence[CollectionInfo], queries: Sequence[Query]
|
||||
):
|
||||
"""Run the given pipelines on the given collection to generate and collect execution statistics."""
|
||||
measurements = []
|
||||
|
||||
for coll_info in collection_infos:
|
||||
print(f'\n>>>>> running queries on collection {coll_info.name}')
|
||||
print(f"\n>>>>> running queries on collection {coll_info.name}")
|
||||
for query in queries:
|
||||
print(f'>>>>>>> running query {query.pipeline}')
|
||||
print(f">>>>>>> running query {query.pipeline}")
|
||||
await self._run_query(coll_info, query, measurements)
|
||||
|
||||
await self.database.insert_many(self.config.output_collection_name, measurements)
|
||||
await self.database.insert_many(
|
||||
self.config.output_collection_name, measurements
|
||||
)
|
||||
|
||||
async def _run_query(self, coll_info: CollectionInfo, query: Query, result: Sequence):
|
||||
async def _run_query(
|
||||
self, coll_info: CollectionInfo, query: Query, result: Sequence
|
||||
):
|
||||
# warm up
|
||||
for _ in range(self.config.warmup_runs):
|
||||
await self.database.explain(coll_info.name, query.pipeline)
|
||||
|
||||
run_id = ObjectId()
|
||||
avg_doc_size = await self.database.get_average_document_size(coll_info.name)
|
||||
parameters = QueryParameters(keys_length_in_bytes=query.keys_length_in_bytes,
|
||||
number_of_fields=query.number_of_fields,
|
||||
average_document_size_in_bytes=avg_doc_size, note=query.note)
|
||||
parameters = QueryParameters(
|
||||
keys_length_in_bytes=query.keys_length_in_bytes,
|
||||
number_of_fields=query.number_of_fields,
|
||||
average_document_size_in_bytes=avg_doc_size,
|
||||
note=query.note,
|
||||
)
|
||||
for _ in range(self.config.runs):
|
||||
explain = await self.database.explain(coll_info.name, query.pipeline)
|
||||
if explain['ok'] == 1:
|
||||
result.append({
|
||||
'run_id': run_id, 'collection': coll_info.name,
|
||||
'pipeline': json.dumps(query.pipeline), 'explain': json.dumps(explain),
|
||||
'query_parameters': parameters.to_json()
|
||||
})
|
||||
if explain["ok"] == 1:
|
||||
result.append(
|
||||
{
|
||||
"run_id": run_id,
|
||||
"collection": coll_info.name,
|
||||
"pipeline": json.dumps(query.pipeline),
|
||||
"explain": json.dumps(explain),
|
||||
"query_parameters": parameters.to_json(),
|
||||
}
|
||||
)
|
||||
|
||||
@ -55,13 +55,16 @@ def main(testlog_dir: str, silent_fail: bool = False):
|
||||
end = start + int(duration)
|
||||
|
||||
report["results"].append(
|
||||
Result({
|
||||
"test_file": test_file,
|
||||
"status": status,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"log_raw": log_raw,
|
||||
}))
|
||||
Result(
|
||||
{
|
||||
"test_file": test_file,
|
||||
"status": status,
|
||||
"start": start,
|
||||
"end": end,
|
||||
"log_raw": log_raw,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
if report["results"]:
|
||||
with open("report.json", "wt") as fh:
|
||||
|
||||
@ -24,8 +24,10 @@ def process_version(version: evergreen.Version) -> None:
|
||||
|
||||
build_variant_name = build_variant.build_variant
|
||||
unsupported_variants = ["tsan", "asan", "aubsan", "macos"]
|
||||
if any(unsupported_variant in build_variant_name
|
||||
for unsupported_variant in unsupported_variants):
|
||||
if any(
|
||||
unsupported_variant in build_variant_name
|
||||
for unsupported_variant in unsupported_variants
|
||||
):
|
||||
continue
|
||||
|
||||
for task_id in build_variant.tasks:
|
||||
@ -53,7 +55,9 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
is_patch = expansions.get("is_patch", False)
|
||||
|
||||
today = datetime.datetime.utcnow().date()
|
||||
start_of_today = datetime.datetime(today.year, today.month, today.day, tzinfo=tz.UTC)
|
||||
start_of_today = datetime.datetime(
|
||||
today.year, today.month, today.day, tzinfo=tz.UTC
|
||||
)
|
||||
|
||||
# STM daily cron runs everyday at 4 AM
|
||||
# We scan the day before yesterday so we do not have to worry about in progress tasks assuming
|
||||
@ -61,7 +65,7 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
start_of_window = start_of_today - datetime.timedelta(days=2)
|
||||
end_of_window = start_of_today - datetime.timedelta(days=1)
|
||||
|
||||
#We only care maintaining alerts as of v7.2
|
||||
# We only care maintaining alerts as of v7.2
|
||||
mongo_projects = ["mongodb-mongo-master", "mongodb-mongo-master-nightly"]
|
||||
for project in evg_api.all_projects():
|
||||
if not project.enabled:
|
||||
@ -79,8 +83,9 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor:
|
||||
futures = []
|
||||
for project in mongo_projects:
|
||||
patches = evg_api.patches_by_project_time_window(project, end_of_window,
|
||||
start_of_window)
|
||||
patches = evg_api.patches_by_project_time_window(
|
||||
project, end_of_window, start_of_window
|
||||
)
|
||||
# This covers all user patches generated with `evergreen patch`
|
||||
for patch in patches:
|
||||
if not patch.activated:
|
||||
@ -91,20 +96,33 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
# This covers all automated versions that are run by crons and other stuff
|
||||
for requester in ["adhoc", "gitter_request", "github_pull_request"]:
|
||||
for version in evg_api.versions_by_project_time_window(
|
||||
project, end_of_window, start_of_window, requester):
|
||||
project, end_of_window, start_of_window, requester
|
||||
):
|
||||
futures.append(executor.submit(process_version, version=version))
|
||||
|
||||
concurrent.futures.wait(futures)
|
||||
errors = []
|
||||
if timeouts_without_dumps:
|
||||
errors.append("ERROR: The following tasks timed out without core dumps uploaded:")
|
||||
errors.append(
|
||||
"ERROR: The following tasks timed out without core dumps uploaded:"
|
||||
)
|
||||
errors.extend(
|
||||
[f"https://spruce.mongodb.com/task/{task_id}" for task_id in timeouts_without_dumps])
|
||||
[
|
||||
f"https://spruce.mongodb.com/task/{task_id}"
|
||||
for task_id in timeouts_without_dumps
|
||||
]
|
||||
)
|
||||
|
||||
if passed_with_dumps:
|
||||
errors.append("ERROR: The following tasks had core dumps uploaded while being successful:")
|
||||
errors.append(
|
||||
"ERROR: The following tasks had core dumps uploaded while being successful:"
|
||||
)
|
||||
errors.extend(
|
||||
[f"https://spruce.mongodb.com/task/{task_id}" for task_id in passed_with_dumps])
|
||||
[
|
||||
f"https://spruce.mongodb.com/task/{task_id}"
|
||||
for task_id in passed_with_dumps
|
||||
]
|
||||
)
|
||||
|
||||
if not errors:
|
||||
return 0
|
||||
@ -124,11 +142,13 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
|
||||
if timeouts_without_dumps:
|
||||
msg.append(
|
||||
f"- {len(timeouts_without_dumps)} task(s) timed out without core dumps being uploaded")
|
||||
f"- {len(timeouts_without_dumps)} task(s) timed out without core dumps being uploaded"
|
||||
)
|
||||
|
||||
if passed_with_dumps:
|
||||
msg.append(
|
||||
f"- {len(passed_with_dumps)} task(s) had core dumps uploaded while being successful")
|
||||
f"- {len(passed_with_dumps)} task(s) had core dumps uploaded while being successful"
|
||||
)
|
||||
|
||||
msg.append(
|
||||
f"For more details view the `Task Errors` file <https://spruce.mongodb.com/task/{current_task_id}/files|here>."
|
||||
@ -136,24 +156,33 @@ def main(expansions_file: str, output_file: str) -> int:
|
||||
|
||||
if not is_patch:
|
||||
evg_api.send_slack_message(
|
||||
target="#sdp-test-alerts", #TODO SERVER-83205: change to #sdp-triager
|
||||
target="#sdp-test-alerts", # TODO SERVER-83205: change to #sdp-triager
|
||||
msg="\n".join(msg),
|
||||
)
|
||||
|
||||
#TODO SERVER-83205: change to return 1
|
||||
# TODO SERVER-83205: change to return 1
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='DailyTaskScanner',
|
||||
description='Iterates over all of the tasks in mongodb-mongo-master and '
|
||||
'mongodb-mongo-master-nightly over a day and looks for certain errors to send '
|
||||
'alerts of.')
|
||||
prog="DailyTaskScanner",
|
||||
description="Iterates over all of the tasks in mongodb-mongo-master and "
|
||||
"mongodb-mongo-master-nightly over a day and looks for certain errors to send "
|
||||
"alerts of.",
|
||||
)
|
||||
|
||||
parser.add_argument("--expansions-file", "-e", help="Expansions file to read task info from.",
|
||||
default="../expansions.yml")
|
||||
parser.add_argument("--output-file", "-f", help="File to output errors to.",
|
||||
default="task_errors.txt")
|
||||
parser.add_argument(
|
||||
"--expansions-file",
|
||||
"-e",
|
||||
help="Expansions file to read task info from.",
|
||||
default="../expansions.yml",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
"-f",
|
||||
help="File to output errors to.",
|
||||
default="task_errors.txt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
sys.exit(main(args.expansions_file, args.output_file))
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Script to generate & upload 'buildId -> debug symbols URL' mappings to symbolizer service."""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
@ -41,8 +42,13 @@ class CmdClient:
|
||||
:return: Command output.
|
||||
"""
|
||||
|
||||
out = subprocess.run(args, close_fds=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||
check=False)
|
||||
out = subprocess.run(
|
||||
args,
|
||||
close_fds=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
check=False,
|
||||
)
|
||||
return out.stdout.strip().decode()
|
||||
|
||||
|
||||
@ -73,8 +79,11 @@ class BinVersionOutput(NamedTuple):
|
||||
class CmdOutputExtractor:
|
||||
"""Data extractor from command output."""
|
||||
|
||||
def __init__(self, cmd_client: Optional[CmdClient] = None,
|
||||
json_decoder: Optional[JSONDecoder] = None) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
cmd_client: Optional[CmdClient] = None,
|
||||
json_decoder: Optional[JSONDecoder] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize.
|
||||
|
||||
@ -117,10 +126,10 @@ class CmdOutputExtractor:
|
||||
build_id = None
|
||||
for line in out.splitlines():
|
||||
line = line.strip()
|
||||
if line.startswith('Build ID'):
|
||||
if line.startswith("Build ID"):
|
||||
if build_id is not None:
|
||||
raise ValueError("Found multiple Build ID values.")
|
||||
build_id = line.split(': ')[1]
|
||||
build_id = line.split(": ")[1]
|
||||
return build_id
|
||||
|
||||
def _get_mongodb_version(self, out: str) -> Optional[str]:
|
||||
@ -143,8 +152,13 @@ class CmdOutputExtractor:
|
||||
class DownloadOptions(object):
|
||||
"""A class to collect download option configurations."""
|
||||
|
||||
def __init__(self, download_binaries=False, download_symbols=False, download_artifacts=False,
|
||||
download_python_venv=False):
|
||||
def __init__(
|
||||
self,
|
||||
download_binaries=False,
|
||||
download_symbols=False,
|
||||
download_artifacts=False,
|
||||
download_python_venv=False,
|
||||
):
|
||||
"""Initialize instance."""
|
||||
|
||||
self.download_binaries = download_binaries
|
||||
@ -158,16 +172,26 @@ class Mapper:
|
||||
|
||||
# This amount of attributes are necessary.
|
||||
|
||||
default_web_service_base_url: str = "https://symbolizer-service.server-tig.prod.corp.mongodb.com"
|
||||
default_cache_dir = os.path.join(os.getcwd(), 'build', 'symbols_cache')
|
||||
selected_binaries = ('mongos', 'mongod', 'mongo')
|
||||
default_web_service_base_url: str = (
|
||||
"https://symbolizer-service.server-tig.prod.corp.mongodb.com"
|
||||
)
|
||||
default_cache_dir = os.path.join(os.getcwd(), "build", "symbols_cache")
|
||||
selected_binaries = ("mongos", "mongod", "mongo")
|
||||
default_client_credentials_scope = "servertig-symbolizer-fullaccess"
|
||||
default_client_credentials_user_name = "client-user"
|
||||
default_creds_file_path = os.path.join(os.getcwd(), '.symbolizer_credentials.json')
|
||||
default_creds_file_path = os.path.join(os.getcwd(), ".symbolizer_credentials.json")
|
||||
|
||||
def __init__(self, evg_version: str, evg_variant: str, is_san_variant: bool, client_id: str,
|
||||
client_secret: str, cache_dir: str = None, web_service_base_url: str = None,
|
||||
logger: logging.Logger = None):
|
||||
def __init__(
|
||||
self,
|
||||
evg_version: str,
|
||||
evg_variant: str,
|
||||
is_san_variant: bool,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
cache_dir: str = None,
|
||||
web_service_base_url: str = None,
|
||||
logger: logging.Logger = None,
|
||||
):
|
||||
"""
|
||||
Initialize instance.
|
||||
|
||||
@ -184,11 +208,13 @@ class Mapper:
|
||||
self.evg_variant = evg_variant
|
||||
self.is_san_variant = is_san_variant
|
||||
self.cache_dir = cache_dir or self.default_cache_dir
|
||||
self.web_service_base_url = web_service_base_url or self.default_web_service_base_url
|
||||
self.web_service_base_url = (
|
||||
web_service_base_url or self.default_web_service_base_url
|
||||
)
|
||||
|
||||
if not logger:
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger('symbolizer')
|
||||
logger = logging.getLogger("symbolizer")
|
||||
logger.setLevel(logging.INFO)
|
||||
self.logger = logger
|
||||
|
||||
@ -196,12 +222,15 @@ class Mapper:
|
||||
|
||||
self.multiversion_setup = SetupMultiversion(
|
||||
DownloadOptions(download_symbols=True, download_binaries=True),
|
||||
variant=self.evg_variant, ignore_failed_push=True)
|
||||
variant=self.evg_variant,
|
||||
ignore_failed_push=True,
|
||||
)
|
||||
self.debug_symbols_url = None
|
||||
self.url = None
|
||||
self.configs = Configs(
|
||||
client_credentials_scope=self.default_client_credentials_scope,
|
||||
client_credentials_user_name=self.default_client_credentials_user_name)
|
||||
client_credentials_user_name=self.default_client_credentials_user_name,
|
||||
)
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.path_options = PathOptions()
|
||||
@ -219,23 +248,34 @@ class Mapper:
|
||||
if os.path.exists(self.default_creds_file_path):
|
||||
with open(self.default_creds_file_path) as cfile:
|
||||
data = json.loads(cfile.read())
|
||||
access_token, expire_time = data.get("access_token"), data.get("expire_time")
|
||||
access_token, expire_time = (
|
||||
data.get("access_token"),
|
||||
data.get("expire_time"),
|
||||
)
|
||||
if time.time() < expire_time:
|
||||
# credentials haven't expired yet
|
||||
self.http_client.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||
self.http_client.headers.update(
|
||||
{"Authorization": f"Bearer {access_token}"}
|
||||
)
|
||||
return
|
||||
|
||||
credentials = get_client_cred_oauth_credentials(self.client_id, self.client_secret,
|
||||
configs=self.configs)
|
||||
self.http_client.headers.update({"Authorization": f"Bearer {credentials.access_token}"})
|
||||
credentials = get_client_cred_oauth_credentials(
|
||||
self.client_id, self.client_secret, configs=self.configs
|
||||
)
|
||||
self.http_client.headers.update(
|
||||
{"Authorization": f"Bearer {credentials.access_token}"}
|
||||
)
|
||||
|
||||
# write credentials to local file for further usage
|
||||
with open(self.default_creds_file_path, "w") as cfile:
|
||||
cfile.write(
|
||||
json.dumps({
|
||||
"access_token": credentials.access_token,
|
||||
"expire_time": time.time() + credentials.expires_in
|
||||
}))
|
||||
json.dumps(
|
||||
{
|
||||
"access_token": credentials.access_token,
|
||||
"expire_time": time.time() + credentials.expires_in,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
def __enter__(self):
|
||||
"""Return instance when used as a context manager."""
|
||||
@ -261,7 +301,7 @@ class Mapper:
|
||||
:param url: download URL
|
||||
:return: full name for local file
|
||||
"""
|
||||
return url.split('/')[-1]
|
||||
return url.split("/")[-1]
|
||||
|
||||
def setup_urls(self):
|
||||
"""Set up URLs using multiversion."""
|
||||
@ -273,12 +313,16 @@ class Mapper:
|
||||
# Sanitizer builds are not stripped and contain debug symbols
|
||||
download_symbols_url = binaries_url
|
||||
else:
|
||||
download_symbols_url = urlinfo.urls.get("mongo-debugsymbols.tgz") or urlinfo.urls.get(
|
||||
"mongo-debugsymbols.zip")
|
||||
download_symbols_url = urlinfo.urls.get(
|
||||
"mongo-debugsymbols.tgz"
|
||||
) or urlinfo.urls.get("mongo-debugsymbols.zip")
|
||||
|
||||
if not download_symbols_url:
|
||||
self.logger.error("Couldn't find URL for debug symbols. Version: %s, URLs dict: %s",
|
||||
self.evg_version, urlinfo.urls)
|
||||
self.logger.error(
|
||||
"Couldn't find URL for debug symbols. Version: %s, URLs dict: %s",
|
||||
self.evg_version,
|
||||
urlinfo.urls,
|
||||
)
|
||||
raise ValueError(f"Debug symbols URL not found. URLs dict: {urlinfo.urls}")
|
||||
|
||||
self.debug_symbols_url = download_symbols_url
|
||||
@ -291,7 +335,7 @@ class Mapper:
|
||||
:param path: full path of file
|
||||
:return: full path of directory of unpacked file
|
||||
"""
|
||||
foldername = path.replace('.tgz', '', 1).split('/')[-1]
|
||||
foldername = path.replace(".tgz", "", 1).split("/")[-1]
|
||||
out_dir = os.path.join(self.cache_dir, foldername)
|
||||
|
||||
if not os.path.exists(out_dir):
|
||||
@ -335,26 +379,36 @@ class Mapper:
|
||||
# shared libraries folder holds shared libraries, tons of them.
|
||||
# some build variants do not contain shared libraries.
|
||||
|
||||
binaries_unpacked_path = os.path.join(binaries_unpacked_path, 'dist-test')
|
||||
binaries_unpacked_path = os.path.join(binaries_unpacked_path, "dist-test")
|
||||
|
||||
self.logger.info("INSIDE unpacked binaries/dist-test: %s",
|
||||
os.listdir(binaries_unpacked_path))
|
||||
self.logger.info(
|
||||
"INSIDE unpacked binaries/dist-test: %s", os.listdir(binaries_unpacked_path)
|
||||
)
|
||||
|
||||
mongod_bin = os.path.join(binaries_unpacked_path, self.path_options.main_binary_folder_name,
|
||||
MONGOD)
|
||||
mongod_bin = os.path.join(
|
||||
binaries_unpacked_path, self.path_options.main_binary_folder_name, MONGOD
|
||||
)
|
||||
bin_version_output = extractor.get_bin_version(mongod_bin)
|
||||
|
||||
if bin_version_output.mongodb_version is None:
|
||||
self.logger.error("mongodb version could not be extracted. \n`%s --version` output: %s",
|
||||
mongod_bin, bin_version_output.cmd_output)
|
||||
self.logger.error(
|
||||
"mongodb version could not be extracted. \n`%s --version` output: %s",
|
||||
mongod_bin,
|
||||
bin_version_output.cmd_output,
|
||||
)
|
||||
return
|
||||
else:
|
||||
self.logger.info("Extracted mongodb version: %s", bin_version_output.mongodb_version)
|
||||
self.logger.info(
|
||||
"Extracted mongodb version: %s", bin_version_output.mongodb_version
|
||||
)
|
||||
|
||||
# start with main binary folder
|
||||
for binary in self.selected_binaries:
|
||||
full_bin_path = os.path.join(binaries_unpacked_path,
|
||||
self.path_options.main_binary_folder_name, binary)
|
||||
full_bin_path = os.path.join(
|
||||
binaries_unpacked_path,
|
||||
self.path_options.main_binary_folder_name,
|
||||
binary,
|
||||
)
|
||||
|
||||
if not os.path.exists(full_bin_path):
|
||||
self.logger.error("Could not find binary at %s", full_bin_path)
|
||||
@ -363,35 +417,41 @@ class Mapper:
|
||||
build_id_output = extractor.get_build_id(full_bin_path)
|
||||
|
||||
if not build_id_output.build_id:
|
||||
self.logger.error("Build ID couldn't be extracted. \nReadELF output %s",
|
||||
build_id_output.cmd_output)
|
||||
self.logger.error(
|
||||
"Build ID couldn't be extracted. \nReadELF output %s",
|
||||
build_id_output.cmd_output,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
self.logger.info("Extracted build ID: %s", build_id_output.build_id)
|
||||
|
||||
yield {
|
||||
'url': self.url,
|
||||
'debug_symbols_url': self.debug_symbols_url,
|
||||
'build_id': build_id_output.build_id,
|
||||
'file_name': binary,
|
||||
'version': bin_version_output.mongodb_version,
|
||||
"url": self.url,
|
||||
"debug_symbols_url": self.debug_symbols_url,
|
||||
"build_id": build_id_output.build_id,
|
||||
"file_name": binary,
|
||||
"version": bin_version_output.mongodb_version,
|
||||
}
|
||||
|
||||
# move to shared libraries folder.
|
||||
# it contains all shared library binary files,
|
||||
# we run readelf on each of them.
|
||||
lib_folder_path = os.path.join(binaries_unpacked_path,
|
||||
self.path_options.shared_library_folder_name)
|
||||
lib_folder_path = os.path.join(
|
||||
binaries_unpacked_path, self.path_options.shared_library_folder_name
|
||||
)
|
||||
|
||||
if not os.path.exists(lib_folder_path):
|
||||
# sometimes we don't get lib folder, which means there is no shared libraries for current build variant.
|
||||
self.logger.info("'%s' folder does not exist.",
|
||||
self.path_options.shared_library_folder_name)
|
||||
self.logger.info(
|
||||
"'%s' folder does not exist.",
|
||||
self.path_options.shared_library_folder_name,
|
||||
)
|
||||
sofiles = []
|
||||
else:
|
||||
sofiles = os.listdir(lib_folder_path)
|
||||
self.logger.info("'%s' folder: %s", self.path_options.shared_library_folder_name,
|
||||
sofiles)
|
||||
self.logger.info(
|
||||
"'%s' folder: %s", self.path_options.shared_library_folder_name, sofiles
|
||||
)
|
||||
|
||||
for sofile in sofiles:
|
||||
sofile_path = os.path.join(lib_folder_path, sofile)
|
||||
@ -403,18 +463,20 @@ class Mapper:
|
||||
build_id_output = extractor.get_build_id(sofile_path)
|
||||
|
||||
if not build_id_output.build_id:
|
||||
self.logger.error("Build ID couldn't be extracted. \nReadELF out %s",
|
||||
build_id_output.cmd_output)
|
||||
self.logger.error(
|
||||
"Build ID couldn't be extracted. \nReadELF out %s",
|
||||
build_id_output.cmd_output,
|
||||
)
|
||||
continue
|
||||
else:
|
||||
self.logger.info("Extracted build ID: %s", build_id_output.build_id)
|
||||
|
||||
yield {
|
||||
'url': self.url,
|
||||
'debug_symbols_url': self.debug_symbols_url,
|
||||
'build_id': build_id_output.build_id,
|
||||
'file_name': sofile,
|
||||
'version': bin_version_output.mongodb_version,
|
||||
"url": self.url,
|
||||
"debug_symbols_url": self.debug_symbols_url,
|
||||
"build_id": build_id_output.build_id,
|
||||
"file_name": sofile,
|
||||
"version": bin_version_output.mongodb_version,
|
||||
}
|
||||
|
||||
def run(self):
|
||||
@ -428,12 +490,17 @@ class Mapper:
|
||||
# mappings is a generator, we iterate over to generate mappings on the go
|
||||
for mapping in mappings:
|
||||
self.logger.info("Creating mapping %s", mapping)
|
||||
response = self.http_client.post('/'.join((self.web_service_base_url, 'add')),
|
||||
json=mapping)
|
||||
response = self.http_client.post(
|
||||
"/".join((self.web_service_base_url, "add")), json=mapping
|
||||
)
|
||||
if response.status_code != 200:
|
||||
self.logger.error(
|
||||
"Could not store mapping, web service returned status code %s from URL %s. "
|
||||
"Response: %s", response.status_code, response.url, response.text)
|
||||
"Response: %s",
|
||||
response.status_code,
|
||||
response.url,
|
||||
response.text,
|
||||
)
|
||||
|
||||
|
||||
def make_argument_parser(parser=None, **kwargs):
|
||||
@ -454,10 +521,14 @@ def make_argument_parser(parser=None, **kwargs):
|
||||
def main(options):
|
||||
"""Execute mapper here. Main entry point."""
|
||||
|
||||
mapper = Mapper(evg_version=options.version, evg_variant=options.variant,
|
||||
is_san_variant=options.is_san_variant, client_id=options.client_id,
|
||||
client_secret=options.client_secret,
|
||||
web_service_base_url=options.web_service_base_url)
|
||||
mapper = Mapper(
|
||||
evg_version=options.version,
|
||||
evg_variant=options.variant,
|
||||
is_san_variant=options.is_san_variant,
|
||||
client_id=options.client_id,
|
||||
client_secret=options.client_secret,
|
||||
web_service_base_url=options.web_service_base_url,
|
||||
)
|
||||
|
||||
# when used as a context manager, mapper instance automatically cleans files/folders after finishing its job.
|
||||
# in other cases, mapper.cleanup() method should be called manually.
|
||||
@ -465,6 +536,6 @@ def main(options):
|
||||
mapper.run()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
mapper_options = make_argument_parser(description=__doc__).parse_args()
|
||||
main(mapper_options)
|
||||
|
||||
@ -9,8 +9,14 @@ _MAX_RUNS = 1000
|
||||
|
||||
|
||||
def print_pb_version_ids(run_id: str):
|
||||
result = subprocess.run(f"evergreen list-patches -j -n {_MAX_RUNS * 2}", shell=True, check=True,
|
||||
stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
||||
result = subprocess.run(
|
||||
f"evergreen list-patches -j -n {_MAX_RUNS * 2}",
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
data = json.loads(result.stdout)
|
||||
|
||||
filtered_objects = [
|
||||
@ -26,7 +32,7 @@ def print_pb_version_ids(run_id: str):
|
||||
"For analysis, visit",
|
||||
f"https://ui.honeycomb.io/mongodb-4b/environments/production?query=%7B%22time_range%22%3A86400%2C%22granularity%22%3A0%2C%22breakdowns%22%3A%5B%22evergreen.task.name%22%5D%2C%22calculations%22%3A%5B%7B%22op%22%3A%22COUNT%22%7D%5D%2C%22filters%22%3A%5B%7B%22column%22%3A%22evergreen.task.name%22%2C%22op%22%3A%22!%3D%22%2C%22value%22%3A%22lint_repo%22%7D%2C%7B%22column%22%3A%22evergreen.project.id%22%2C%22op%22%3A%22%3D%22%2C%22value%22%3A%22mongodb-mongo-master%22%7D%2C%7B%22column%22%3A%22evergreen.task.status%22%2C%22op%22%3A%22%3D%22%2C%22value%22%3A%22failed%22%7D%2C%7B%22column%22%3A%22evergreen.version.id%22%2C%22op%22%3A%22in%22%2C%22value%22%3A%5B%22\
|
||||
{'%22%2C%22'.join(versions)}\
|
||||
%22%5D%7D%5D%2C%22filter_combination%22%3A%22AND%22%2C%22orders%22%3A%5B%7B%22op%22%3A%22COUNT%22%2C%22order%22%3A%22descending%22%7D%5D%2C%22havings%22%3A%5B%5D%2C%22limit%22%3A1000%7D"
|
||||
%22%5D%7D%5D%2C%22filter_combination%22%3A%22AND%22%2C%22orders%22%3A%5B%7B%22op%22%3A%22COUNT%22%2C%22order%22%3A%22descending%22%7D%5D%2C%22havings%22%3A%5B%5D%2C%22limit%22%3A1000%7D",
|
||||
)
|
||||
print("---")
|
||||
|
||||
@ -38,11 +44,17 @@ def deflakinator(runs: int, evergreen_args: str):
|
||||
print("Kicking off evergreen patch builds:")
|
||||
print("---")
|
||||
|
||||
command = f"evergreen patch {evergreen_args} --yes --finalize --description \"flakinator run id: {run_id}\""
|
||||
command = f'evergreen patch {evergreen_args} --yes --finalize --description "flakinator run id: {run_id}"'
|
||||
for _ in range(runs):
|
||||
print(command)
|
||||
subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE, text=True)
|
||||
subprocess.run(
|
||||
command,
|
||||
shell=True,
|
||||
check=True,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
text=True,
|
||||
)
|
||||
|
||||
print_pb_version_ids(run_id)
|
||||
|
||||
@ -50,9 +62,16 @@ def deflakinator(runs: int, evergreen_args: str):
|
||||
def main():
|
||||
"""Run the main function."""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--runs", type=int, choices=range(1, _MAX_RUNS), metavar=f"[1-{_MAX_RUNS}]",
|
||||
help="Number of times to run each patch build")
|
||||
parser.add_argument("--evergreen-args", type=str, help="Arguments to pass to evergreen patch")
|
||||
parser.add_argument(
|
||||
"--runs",
|
||||
type=int,
|
||||
choices=range(1, _MAX_RUNS),
|
||||
metavar=f"[1-{_MAX_RUNS}]",
|
||||
help="Number of times to run each patch build",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evergreen-args", type=str, help="Arguments to pass to evergreen patch"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
deflakinator(args.runs, args.evergreen_args)
|
||||
|
||||
@ -39,17 +39,23 @@ def main():
|
||||
operating_system = determine_platform()
|
||||
architechture = determine_architecture()
|
||||
if operating_system == "windows" and architechture == "arm64":
|
||||
raise RuntimeError("There are no published arm windows releases for buildifier.")
|
||||
raise RuntimeError(
|
||||
"There are no published arm windows releases for buildifier."
|
||||
)
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='DownloadBuildifier',
|
||||
description='This downloads buildifier, it is intended for use in evergreen.'
|
||||
'This is our temperary solution to get bazel linting while we determine a '
|
||||
'long-term solution for getting buildifier on evergreen/development hosts.')
|
||||
prog="DownloadBuildifier",
|
||||
description="This downloads buildifier, it is intended for use in evergreen."
|
||||
"This is our temperary solution to get bazel linting while we determine a "
|
||||
"long-term solution for getting buildifier on evergreen/development hosts.",
|
||||
)
|
||||
|
||||
parser.add_argument("--download-location", "-f",
|
||||
help="Name of directory to download the buildifier binary to.",
|
||||
default="./")
|
||||
parser.add_argument(
|
||||
"--download-location",
|
||||
"-f",
|
||||
help="Name of directory to download the buildifier binary to.",
|
||||
default="./",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@ -5,63 +5,67 @@ import argparse
|
||||
|
||||
import requests
|
||||
|
||||
BASE_URI = 'https://evergreen.mongodb.com/rest/v2/'
|
||||
BASE_URI = "https://evergreen.mongodb.com/rest/v2/"
|
||||
|
||||
|
||||
def _get_auth_headers(evergreen_api_user, evergreen_api_key):
|
||||
return {
|
||||
'Api-User': evergreen_api_user,
|
||||
'Api-Key': evergreen_api_key,
|
||||
"Api-User": evergreen_api_user,
|
||||
"Api-Key": evergreen_api_key,
|
||||
}
|
||||
|
||||
|
||||
def _get_build_id(build_variant_name, version_id, auth_headers):
|
||||
url = BASE_URI + 'versions/' + version_id
|
||||
url = BASE_URI + "versions/" + version_id
|
||||
response = requests.get(url, headers=auth_headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError('Invalid version_id:', version_id)
|
||||
raise ValueError("Invalid version_id:", version_id)
|
||||
|
||||
version_json = response.json()
|
||||
build_variants = version_json['build_variants_status']
|
||||
build_variants = version_json["build_variants_status"]
|
||||
for build_variant in build_variants:
|
||||
if build_variant['build_variant'] == build_variant_name:
|
||||
return build_variant['build_id']
|
||||
if build_variant["build_variant"] == build_variant_name:
|
||||
return build_variant["build_id"]
|
||||
|
||||
raise RuntimeError('The compile-variant ' + build_variant_name +
|
||||
' does not exist for the build with version_id ' + version_id)
|
||||
raise RuntimeError(
|
||||
"The compile-variant "
|
||||
+ build_variant_name
|
||||
+ " does not exist for the build with version_id "
|
||||
+ version_id
|
||||
)
|
||||
|
||||
|
||||
# All of our sys-perf compile variants have exactly one task,
|
||||
# so we can safely select the first one from the build variants tasks.
|
||||
def _get_task_id(build_id, auth_headers):
|
||||
url = BASE_URI + 'builds/' + build_id
|
||||
url = BASE_URI + "builds/" + build_id
|
||||
response = requests.get(url, headers=auth_headers)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError('Unexpected error when trying to reach' + url)
|
||||
raise RuntimeError("Unexpected error when trying to reach" + url)
|
||||
|
||||
build_json = response.json()
|
||||
task_list = build_json['tasks']
|
||||
task_list = build_json["tasks"]
|
||||
if len(task_list) != 1:
|
||||
raise RuntimeError('Recieved unexpected tasklist:', task_list)
|
||||
raise RuntimeError("Recieved unexpected tasklist:", task_list)
|
||||
return task_list[0]
|
||||
|
||||
|
||||
# The API used here always grabs the latest execution
|
||||
def _get_binary_details(task_id, auth_headers):
|
||||
url = BASE_URI + 'tasks/' + task_id
|
||||
url = BASE_URI + "tasks/" + task_id
|
||||
response = requests.get(url, headers=auth_headers)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError('Unexpected error when trying to reach' + url)
|
||||
raise RuntimeError("Unexpected error when trying to reach" + url)
|
||||
|
||||
task_json = response.json()
|
||||
if task_json['status'] != 'success':
|
||||
raise RuntimeError('The task ' + task_id + ' did not run sucessfully')
|
||||
if task_json["status"] != "success":
|
||||
raise RuntimeError("The task " + task_id + " did not run sucessfully")
|
||||
|
||||
# The binary will always be the first artifact, unless we make large changes to system_perf.yml
|
||||
artifacts = task_json['artifacts']
|
||||
if len(artifacts) > 0 and artifacts[0]['name'].startswith('mongo'):
|
||||
artifacts = task_json["artifacts"]
|
||||
if len(artifacts) > 0 and artifacts[0]["name"].startswith("mongo"):
|
||||
return artifacts[0]
|
||||
raise RuntimeError('Unexpected list of artifacts:' + artifacts)
|
||||
raise RuntimeError("Unexpected list of artifacts:" + artifacts)
|
||||
|
||||
|
||||
def _get_binary_url(version_id, build_variant, evergreen_api_user, evergreen_api_key):
|
||||
@ -69,36 +73,55 @@ def _get_binary_url(version_id, build_variant, evergreen_api_user, evergreen_api
|
||||
build_id = _get_build_id(build_variant, version_id, auth_headers)
|
||||
task_id = _get_task_id(build_id, auth_headers)
|
||||
binary_json = _get_binary_details(task_id, auth_headers)
|
||||
return binary_json['url']
|
||||
return binary_json["url"]
|
||||
|
||||
|
||||
def _download_binary_file(url, save_path):
|
||||
response = requests.get(url, stream=True)
|
||||
if response.status_code == 200:
|
||||
with open(save_path, 'wb') as file:
|
||||
with open(save_path, "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=1024):
|
||||
file.write(chunk)
|
||||
else:
|
||||
raise RuntimeError('Failed to download the file ' + url)
|
||||
raise RuntimeError("Failed to download the file " + url)
|
||||
|
||||
|
||||
def _download_sys_perf_binaries(
    version_id, build_variant, evergreen_api_user, evergreen_api_key
):
    """Resolve the binary URL for a version/variant and download it.

    The four arguments are passed unchanged to ``_get_binary_url``; the file at
    the returned URL is then saved as ``binary.tar.gz`` (a relative path, so it
    lands in the current working directory).

    :param version_id: Evergreen version_id from which binaries will be downloaded.
    :param build_variant: Build variant for which binaries will be downloaded.
    :param evergreen_api_user: Evergreen API user.
    :param evergreen_api_key: Evergreen API key.
    """
    url = _get_binary_url(
        version_id, build_variant, evergreen_api_user, evergreen_api_key
    )
    _download_binary_file(url, "binary.tar.gz")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: read the four command-line options and hand them to
    # _download_sys_perf_binaries. None of the options is marked required, so
    # any that are omitted reach the downloader as None.
    argParser = argparse.ArgumentParser()
    argParser.add_argument(
        "-v",
        "--version_id",
        help="Evergreen version_id from which binaries will be downloaded",
    )
    argParser.add_argument(
        "-b",
        "--build_variant",
        help="Build variant for which binaries will be downloaded",
    )
    argParser.add_argument(
        "-u",
        "--evergreen_api_user",
        help="Evergreen API user, see https://spruce.mongodb.com/preferences/cli",
    )
    argParser.add_argument(
        "-k",
        "--evergreen_api_key",
        help="Evergreen API key, see https://spruce.mongodb.com/preferences/cli",
    )
    args = argParser.parse_args()

    _download_sys_perf_binaries(
        args.version_id,
        args.build_variant,
        args.evergreen_api_user,
        args.evergreen_api_key,
    )
|
||||
|
||||
@ -28,12 +28,15 @@ MAXIMUM_CODE = 99999999 # JIRA Ticket + XX
|
||||
codes = [] # type: ignore
|
||||
|
||||
# Each AssertLocation identifies the C++ source location of an assertion.
# Fields: sourceFile (path of the scanned file), byteOffset (position of the
# numeric code within the file text), lines (the matched source text), and
# code (the numeric code, kept as the matched string).
AssertLocation = namedtuple(
    "AssertLocation", ["sourceFile", "byteOffset", "lines", "code"]
)
|
||||
|
||||
list_files = False # pylint: disable=invalid-name
|
||||
|
||||
_CODE_PATTERNS = [
|
||||
re.compile(p + r'\s*(?P<code>\d+)', re.MULTILINE) for p in [
|
||||
re.compile(p + r"\s*(?P<code>\d+)", re.MULTILINE)
|
||||
for p in [
|
||||
# All the asserts and their optional variant suffixes
|
||||
r"(?:f|i|m|msg|t|u)(?:assert)"
|
||||
r"(?:ed)?"
|
||||
@ -56,20 +59,22 @@ _CODE_PATTERNS = [
|
||||
]
|
||||
]
|
||||
|
||||
_DIR_EXCLUDE_RE = re.compile(r'(\..*'
|
||||
r'|pcre2.*'
|
||||
r'|32bit.*'
|
||||
r'|mongodb-.*'
|
||||
r'|debian.*'
|
||||
r'|mongo-cxx-driver.*'
|
||||
r'|.*gotools.*'
|
||||
r'|.*mozjs.*'
|
||||
r')')
|
||||
_DIR_EXCLUDE_RE = re.compile(
|
||||
r"(\..*"
|
||||
r"|pcre2.*"
|
||||
r"|32bit.*"
|
||||
r"|mongodb-.*"
|
||||
r"|debian.*"
|
||||
r"|mongo-cxx-driver.*"
|
||||
r"|.*gotools.*"
|
||||
r"|.*mozjs.*"
|
||||
r")"
|
||||
)
|
||||
|
||||
_FILE_INCLUDE_RE = re.compile(r'.*\.(cpp|c|h|py|idl)')
|
||||
_FILE_INCLUDE_RE = re.compile(r".*\.(cpp|c|h|py|idl)")
|
||||
|
||||
|
||||
def get_all_source_files(prefix='.'):
|
||||
def get_all_source_files(prefix="."):
|
||||
"""Return source files."""
|
||||
|
||||
def walk(path):
|
||||
@ -92,8 +97,8 @@ def foreach_source_file(callback, src_root):
|
||||
"""Invoke a callback on the text of each source file."""
|
||||
for source_file in get_all_source_files(prefix=src_root):
|
||||
if list_files:
|
||||
print('scanning file: ' + source_file)
|
||||
with open(source_file, 'r', encoding='utf-8') as fh:
|
||||
print("scanning file: " + source_file)
|
||||
with open(source_file, "r", encoding="utf-8") as fh:
|
||||
callback(source_file, fh.read())
|
||||
|
||||
|
||||
@ -106,8 +111,12 @@ def parse_source_files(callback, src_root):
|
||||
# Note that this will include the text of the full match but will report the
|
||||
# position of the beginning of the code portion rather than the beginning of the
|
||||
# match. This is to position editors on the spot that needs to change.
|
||||
loc = AssertLocation(source_file, match.start('code'), match.group(0),
|
||||
match.group('code'))
|
||||
loc = AssertLocation(
|
||||
source_file,
|
||||
match.start("code"),
|
||||
match.group(0),
|
||||
match.group("code"),
|
||||
)
|
||||
callback(loc)
|
||||
|
||||
foreach_source_file(scan_for_codes, src_root)
|
||||
@ -134,7 +143,7 @@ def get_line_and_column_for_position(loc, _file_cache=None):
|
||||
def is_terminated(lines):
    """Determine if assert is terminated, from .cpp/.h source lines as text.

    The lines are joined with single spaces and the statement counts as
    terminated when the joined text contains a ``;`` or has no more opening
    than closing parentheses.

    :param lines: Iterable of source-line strings making up the assert so far.
    :return: True if the assert looks complete, False if an open parenthesis
        is still unmatched and no semicolon has been seen.
    """
    code_block = " ".join(lines)
    if ";" in code_block:
        return True
    # No semicolon yet: fall back to checking that every "(" has been closed.
    return code_block.count("(") - code_block.count(")") <= 0
|
||||
|
||||
|
||||
def get_next_code(seen, server_ticket=0):
|
||||
@ -167,7 +176,7 @@ def check_error_codes():
|
||||
return len(errors) == 0
|
||||
|
||||
|
||||
def read_error_codes(src_root='src/mongo'):
|
||||
def read_error_codes(src_root="src/mongo"):
|
||||
"""Define callback, call parse_source_files() with callback, save matches to global codes list."""
|
||||
seen = {}
|
||||
errors = []
|
||||
@ -247,7 +256,9 @@ def replace_bad_codes(errors, next_code_generator):
|
||||
|
||||
for loc in skip_errors:
|
||||
line, col = get_line_and_column_for_position(loc)
|
||||
print("SKIPPING NONZERO code=%s: %s:%d:%d" % (loc.code, loc.sourceFile, line, col))
|
||||
print(
|
||||
"SKIPPING NONZERO code=%s: %s:%d:%d" % (loc.code, loc.sourceFile, line, col)
|
||||
)
|
||||
|
||||
# Dedupe, sort, and reverse so we don't have to update offsets as we go.
|
||||
for assert_loc in reversed(sorted(set(zero_errors))):
|
||||
@ -257,16 +268,16 @@ def replace_bad_codes(errors, next_code_generator):
|
||||
|
||||
ln = line_num - 1
|
||||
|
||||
with open(source_file, 'r+') as fh:
|
||||
with open(source_file, "r+") as fh:
|
||||
print("LINE_%d_BEFORE:%s" % (line_num, fh.readlines()[ln].rstrip()))
|
||||
|
||||
fh.seek(0)
|
||||
text = fh.read()
|
||||
assert text[byte_offset] == '0'
|
||||
assert text[byte_offset] == "0"
|
||||
fh.seek(0)
|
||||
fh.write(text[:byte_offset])
|
||||
fh.write(str(next(next_code_generator)))
|
||||
fh.write(text[byte_offset + 1:])
|
||||
fh.write(text[byte_offset + 1 :])
|
||||
fh.seek(0)
|
||||
|
||||
print("LINE_%d_AFTER :%s" % (line_num, fh.readlines()[ln].rstrip()))
|
||||
@ -281,7 +292,7 @@ def coerce_to_number(ticket_value):
|
||||
if isinstance(ticket_value, int):
|
||||
return ticket_value
|
||||
|
||||
ticket_re = re.compile(r'(?:SERVER-)?(\d+)', re.IGNORECASE)
|
||||
ticket_re = re.compile(r"(?:SERVER-)?(\d+)", re.IGNORECASE)
|
||||
matches = ticket_re.fullmatch(ticket_value)
|
||||
if not matches:
|
||||
print("Unknown ticket number. Input: " + ticket_value)
|
||||
@ -293,16 +304,37 @@ def coerce_to_number(ticket_value):
|
||||
def main():
|
||||
"""Validate error codes."""
|
||||
parser = OptionParser(description=__doc__.strip())
|
||||
parser.add_option("--fix", dest="replace", action="store_true", default=False,
|
||||
help="Fix zero codes in source files [default: %default]")
|
||||
parser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False,
|
||||
help="Suppress output on success [default: %default]")
|
||||
parser.add_option("--list-files", dest="list_files", action="store_true", default=False,
|
||||
help="Print the name of each file as it is scanned [default: %default]")
|
||||
parser.add_option(
|
||||
"--ticket", dest="ticket", type="str", action="store", default=None,
|
||||
"--fix",
|
||||
dest="replace",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Fix zero codes in source files [default: %default]",
|
||||
)
|
||||
parser.add_option(
|
||||
"-q",
|
||||
"--quiet",
|
||||
dest="quiet",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Suppress output on success [default: %default]",
|
||||
)
|
||||
parser.add_option(
|
||||
"--list-files",
|
||||
dest="list_files",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Print the name of each file as it is scanned [default: %default]",
|
||||
)
|
||||
parser.add_option(
|
||||
"--ticket",
|
||||
dest="ticket",
|
||||
type="str",
|
||||
action="store",
|
||||
default=None,
|
||||
help="Generate error codes for a given SERVER ticket number. Inputs can be of"
|
||||
" the form: `--ticket=12345` or `--ticket=SERVER-12345`.")
|
||||
" the form: `--ticket=12345` or `--ticket=SERVER-12345`.",
|
||||
)
|
||||
options, extra = parser.parse_args()
|
||||
if extra:
|
||||
parser.error(f"Unrecognized arguments: {' '.join(extra)}")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Activate an evergreen task in the existing build."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
@ -63,7 +64,8 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None:
|
||||
if expansions.task == BURN_IN_TAGS:
|
||||
version = evg_api.version_by_id(expansions.version_id)
|
||||
burn_in_build_variants = [
|
||||
variant for variant in version.build_variants_map.keys()
|
||||
variant
|
||||
for variant in version.build_variants_map.keys()
|
||||
if variant.endswith(BURN_IN_VARIANT_SUFFIX)
|
||||
]
|
||||
for build_variant in burn_in_build_variants:
|
||||
@ -72,29 +74,42 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None:
|
||||
|
||||
for task in task_list:
|
||||
if task.display_name == BURN_IN_TESTS:
|
||||
LOGGER.info("Activating task", task_id=task.task_id,
|
||||
task_name=task.display_name)
|
||||
LOGGER.info(
|
||||
"Activating task",
|
||||
task_id=task.task_id,
|
||||
task_name=task.display_name,
|
||||
)
|
||||
try:
|
||||
evg_api.configure_task(task.task_id, activated=True)
|
||||
except Exception:
|
||||
LOGGER.warning("Could not activate task", task_id=task.task_id,
|
||||
task_name=task.display_name)
|
||||
LOGGER.warning(
|
||||
"Could not activate task",
|
||||
task_id=task.task_id,
|
||||
task_name=task.display_name,
|
||||
)
|
||||
tasks_not_activated.append(task.task_id)
|
||||
|
||||
else:
|
||||
task_list = evg_api.tasks_by_build(expansions.build_id)
|
||||
for task in task_list:
|
||||
if task.display_name == expansions.task:
|
||||
LOGGER.info("Activating task", task_id=task.task_id, task_name=task.display_name)
|
||||
LOGGER.info(
|
||||
"Activating task", task_id=task.task_id, task_name=task.display_name
|
||||
)
|
||||
try:
|
||||
evg_api.configure_task(task.task_id, activated=True)
|
||||
except Exception:
|
||||
LOGGER.warning("Could not activate task", task_id=task.task_id,
|
||||
task_name=task.display_name)
|
||||
LOGGER.warning(
|
||||
"Could not activate task",
|
||||
task_id=task.task_id,
|
||||
task_name=task.display_name,
|
||||
)
|
||||
tasks_not_activated.append(task.task_id)
|
||||
if len(tasks_not_activated) > 0:
|
||||
LOGGER.error("Some tasks were unable to be activated",
|
||||
unactivated_tasks=len(tasks_not_activated))
|
||||
LOGGER.error(
|
||||
"Some tasks were unable to be activated",
|
||||
unactivated_tasks=len(tasks_not_activated),
|
||||
)
|
||||
raise ValueError(
|
||||
"Some tasks were unable to be activated, failing the task to let the author know. "
|
||||
"This should not be a blocking issue but may mean that some tasks are missing from your patch."
|
||||
@ -102,10 +117,18 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None:
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option("--expansion-file", type=str, required=True,
|
||||
help="Location of expansions file generated by evergreen.")
|
||||
@click.option("--evergreen-config", type=str, default=EVG_CONFIG_FILE,
|
||||
help="Location of evergreen configuration file.")
|
||||
@click.option(
|
||||
"--expansion-file",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Location of expansions file generated by evergreen.",
|
||||
)
|
||||
@click.option(
|
||||
"--evergreen-config",
|
||||
type=str,
|
||||
default=EVG_CONFIG_FILE,
|
||||
help="Location of evergreen configuration file.",
|
||||
)
|
||||
@click.option("--verbose", is_flag=True, default=False, help="Enable verbose logging.")
|
||||
def main(expansion_file: str, evergreen_config: str, verbose: bool) -> None:
|
||||
"""
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
"""Convert Evergreen's expansions.yml to an eval-able shell script."""
|
||||
|
||||
import sys
|
||||
from shlex import quote
|
||||
|
||||
@ -12,16 +13,20 @@ try:
|
||||
import click
|
||||
import yaml
|
||||
except ModuleNotFoundError:
|
||||
_error("ERROR: Failed to import a dependency. This is almost certainly because "
|
||||
"the task did not initialize the venv immediately after cloning the repository.")
|
||||
_error(
|
||||
"ERROR: Failed to import a dependency. This is almost certainly because "
|
||||
"the task did not initialize the venv immediately after cloning the repository."
|
||||
)
|
||||
|
||||
|
||||
def _load_defaults(defaults_file: str) -> dict:
|
||||
with open(defaults_file) as fh:
|
||||
defaults = yaml.safe_load(fh)
|
||||
if not isinstance(defaults, dict):
|
||||
_error("ERROR: expected to read a dictionary. expansions.defaults.yml"
|
||||
"must be a dictionary. Check the indentation.")
|
||||
_error(
|
||||
"ERROR: expected to read a dictionary. expansions.defaults.yml"
|
||||
"must be a dictionary. Check the indentation."
|
||||
)
|
||||
|
||||
# expansions MUST be strings. Reject any that are not
|
||||
bad_expansions = set()
|
||||
@ -30,11 +35,13 @@ def _load_defaults(defaults_file: str) -> dict:
|
||||
bad_expansions.add(key)
|
||||
|
||||
if bad_expansions:
|
||||
_error("ERROR: all default expansions must be strings. You can "
|
||||
" fix this error by quoting the values in expansions.defaults.yml. "
|
||||
"Integers, floating points, 'true', 'false', and 'null' "
|
||||
"must be quoted. The following keys were interpreted as "
|
||||
f"other types: {bad_expansions}")
|
||||
_error(
|
||||
"ERROR: all default expansions must be strings. You can "
|
||||
" fix this error by quoting the values in expansions.defaults.yml. "
|
||||
"Integers, floating points, 'true', 'false', and 'null' "
|
||||
"must be quoted. The following keys were interpreted as "
|
||||
f"other types: {bad_expansions}"
|
||||
)
|
||||
|
||||
# These values show up if 1. Python's str is used to naively convert
|
||||
# a boolean to str, 2. A human manually entered one of those strings.
|
||||
@ -48,11 +55,13 @@ def _load_defaults(defaults_file: str) -> dict:
|
||||
risky_boolean_keys.add(key)
|
||||
|
||||
if risky_boolean_keys:
|
||||
_error("ERROR: Found keys which had 'True' or 'False' as values. "
|
||||
"Shell scripts assume that booleans are represented as 'true'"
|
||||
" or 'false' (leading lowercase). If you added a new boolean, "
|
||||
"ensure that it's represented in lowercase. If not, please report this in "
|
||||
f"#server-testing. Keys with bad values: {risky_boolean_keys}")
|
||||
_error(
|
||||
"ERROR: Found keys which had 'True' or 'False' as values. "
|
||||
"Shell scripts assume that booleans are represented as 'true'"
|
||||
" or 'false' (leading lowercase). If you added a new boolean, "
|
||||
"ensure that it's represented in lowercase. If not, please report this in "
|
||||
f"#server-testing. Keys with bad values: {risky_boolean_keys}"
|
||||
)
|
||||
|
||||
return defaults
|
||||
|
||||
@ -62,8 +71,10 @@ def _load_expansions(expansions_file) -> dict:
|
||||
expansions = yaml.safe_load(fh)
|
||||
|
||||
if not isinstance(expansions, dict):
|
||||
_error("ERROR: expected to read a dictionary. Has the output format "
|
||||
"of expansions.write changed?")
|
||||
_error(
|
||||
"ERROR: expected to read a dictionary. Has the output format "
|
||||
"of expansions.write changed?"
|
||||
)
|
||||
|
||||
if not expansions:
|
||||
_error("ERROR: found 0 expansions. This is almost certainly wrong.")
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate multiple powercycle tasks to run in evergreen."""
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Any, List, Set, Tuple
|
||||
|
||||
@ -18,17 +19,20 @@ from buildscripts.util.fileops import write_file
|
||||
from buildscripts.util.read_config import read_config_file
|
||||
from buildscripts.util.taskname import name_generated_task
|
||||
|
||||
# Immutable bundle of the settings that make_config() builds and main() reads.
# The namedtuple's type name is the lowercase string "config"; it is kept
# as-is because it shows up in repr() output.
Config = namedtuple(
    "config",
    [
        "current_task_name",
        "task_names",
        "num_tasks",
        "timeout_params",
        "remote_credentials_vars",
        "set_up_ec2_instance_vars",
        "run_powercycle_vars",
        "build_variant",
        "distro",
    ],
)
|
||||
|
||||
|
||||
def make_config(expansions_file: Any) -> Config:
|
||||
@ -58,8 +62,17 @@ def make_config(expansions_file: Any) -> Config:
|
||||
build_variant = expansions.get("build_variant")
|
||||
distro = expansions.get("distro_id")
|
||||
|
||||
return Config(current_task_name, task_names, num_tasks, timeout_params, remote_credentials_vars,
|
||||
set_up_ec2_instance_vars, run_powercycle_vars, build_variant, distro)
|
||||
return Config(
|
||||
current_task_name,
|
||||
task_names,
|
||||
num_tasks,
|
||||
timeout_params,
|
||||
remote_credentials_vars,
|
||||
set_up_ec2_instance_vars,
|
||||
run_powercycle_vars,
|
||||
build_variant,
|
||||
distro,
|
||||
)
|
||||
|
||||
|
||||
def get_setup_commands() -> Tuple[List[FunctionCall], Set[TaskDependency]]:
|
||||
@ -87,8 +100,9 @@ def get_skip_compile_setup_commands() -> Tuple[List[FunctionCall], set]:
|
||||
@click.command()
|
||||
@click.argument("expansions_file", type=str, default="expansions.yml")
|
||||
@click.argument("output_file", type=str, default="powercycle_tasks.json")
|
||||
def main(expansions_file: str = "expansions.yml",
|
||||
output_file: str = "powercycle_tasks.json") -> None:
|
||||
def main(
|
||||
expansions_file: str = "expansions.yml", output_file: str = "powercycle_tasks.json"
|
||||
) -> None:
|
||||
"""Generate multiple powercycle tasks to run in evergreen."""
|
||||
|
||||
config = make_config(expansions_file)
|
||||
@ -101,28 +115,41 @@ def main(expansions_file: str = "expansions.yml",
|
||||
else:
|
||||
commands, task_dependency = get_setup_commands()
|
||||
|
||||
commands.extend([
|
||||
FunctionCall("set up remote credentials", config.remote_credentials_vars),
|
||||
BuiltInCommand("timeout.update", config.timeout_params),
|
||||
FunctionCall("set up EC2 instance", config.set_up_ec2_instance_vars),
|
||||
FunctionCall("run powercycle test", config.run_powercycle_vars),
|
||||
])
|
||||
commands.extend(
|
||||
[
|
||||
FunctionCall(
|
||||
"set up remote credentials", config.remote_credentials_vars
|
||||
),
|
||||
BuiltInCommand("timeout.update", config.timeout_params),
|
||||
FunctionCall("set up EC2 instance", config.set_up_ec2_instance_vars),
|
||||
FunctionCall("run powercycle test", config.run_powercycle_vars),
|
||||
]
|
||||
)
|
||||
|
||||
sub_tasks.update({
|
||||
Task(
|
||||
name_generated_task(task_name, index, config.num_tasks, config.build_variant),
|
||||
commands, task_dependency)
|
||||
for index in range(config.num_tasks)
|
||||
})
|
||||
sub_tasks.update(
|
||||
{
|
||||
Task(
|
||||
name_generated_task(
|
||||
task_name, index, config.num_tasks, config.build_variant
|
||||
),
|
||||
commands,
|
||||
task_dependency,
|
||||
)
|
||||
for index in range(config.num_tasks)
|
||||
}
|
||||
)
|
||||
|
||||
build_variant.display_task(
|
||||
config.current_task_name.replace("_gen", ""), sub_tasks, distros=[config.distro],
|
||||
execution_existing_tasks={ExistingTask(config.current_task_name)})
|
||||
config.current_task_name.replace("_gen", ""),
|
||||
sub_tasks,
|
||||
distros=[config.distro],
|
||||
execution_existing_tasks={ExistingTask(config.current_task_name)},
|
||||
)
|
||||
shrub_project = ShrubProject.empty()
|
||||
shrub_project.add_build_variant(build_variant)
|
||||
|
||||
write_file(output_file, shrub_project.json())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -30,40 +30,54 @@ SYS_PLATFORM = sys.platform
|
||||
|
||||
# Apply factor for a task based on the build variant it is running on.
# Keys are build variant names; each value is a list of
# {"task": <pattern>, "factor": <float>} entries. The patterns look like
# regular expressions for task names -- confirm against get_task_factor().
VARIANT_TASK_FACTOR_OVERRIDES = {
    "enterprise-rhel-8-64-bit": [
        {"task": r"logical_session_cache_replication.*", "factor": 0.75}
    ],
    "enterprise-rhel-8-64-bit-inmem": [
        {"task": "secondary_reads_passthrough", "factor": 0.3},
        {"task": "multi_stmt_txn_jscore_passthrough_with_migration", "factor": 0.3},
    ],
    "enterprise-rhel8-debug-tsan": [
        {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.125}
    ],
    "rhel8-debug-aubsan-classic-engine": [
        {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25}
    ],
    "rhel8-debug-aubsan-all-feature-flags": [
        {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25}
    ],
    # TODO(SERVER-91466): figure out why noPassthrough tests are taking up more memory after switching
    # from Windows Server 2019 to Windows Server 2022
    "enterprise-windows-all-feature-flags-required": [
        {"task": "noPassthrough", "factor": 0.5}
    ],
    "enterprise-windows": [{"task": "noPassthrough", "factor": 0.5}],
    "windows-debug-suggested": [{"task": "noPassthrough", "factor": 0.5}],
    "windows": [{"task": "noPassthrough", "factor": 0.5}],
}

# Shared factor list; reused below as the override list for aarch64 machines.
TASKS_FACTORS = [
    {"task": r"replica_sets.*", "factor": 0.5},
    {"task": r"sharding.*", "factor": 0.5},
]

# Per-distro multiplier, keyed by distro name.
DISTRO_MULTIPLIERS = {"rhel8.8-large": 1.618}

# Apply factor for a task based on the machine type it is running on.
MACHINE_TASK_FACTOR_OVERRIDES = {
    # Same list object as TASKS_FACTORS, not a copy.
    "aarch64": TASKS_FACTORS,
    "ppc64le": [
        dict(
            task=r"causally_consistent_hedged_reads_jscore_passthrough.*", factor=0.125
        ),
        dict(
            task=r"causally_consistent_read_concern_snapshot_passthrough.*",
            factor=0.125,
        ),
        dict(
            task=r"sharded_causally_consistent_read_concern_snapshot_passthrough.*",
            factor=0.125,
        ),
    ],
}
|
||||
|
||||
@ -117,8 +131,12 @@ def determine_final_multiplier(distro):
|
||||
def determine_factor(task_name, variant, distro, factor):
|
||||
"""Determine the job factor."""
|
||||
factors = [
|
||||
get_task_factor(task_name, MACHINE_TASK_FACTOR_OVERRIDES, PLATFORM_MACHINE, factor),
|
||||
get_task_factor(task_name, PLATFORM_TASK_FACTOR_OVERRIDES, SYS_PLATFORM, factor),
|
||||
get_task_factor(
|
||||
task_name, MACHINE_TASK_FACTOR_OVERRIDES, PLATFORM_MACHINE, factor
|
||||
),
|
||||
get_task_factor(
|
||||
task_name, PLATFORM_TASK_FACTOR_OVERRIDES, SYS_PLATFORM, factor
|
||||
),
|
||||
get_task_factor(task_name, VARIANT_TASK_FACTOR_OVERRIDES, variant, factor),
|
||||
global_task_factor(task_name, GLOBAL_TASK_FACTOR_OVERRIDES, factor),
|
||||
]
|
||||
@ -153,34 +171,71 @@ def main():
|
||||
"""Determine the resmoke jobs value a task should use in Evergreen."""
|
||||
parser = argparse.ArgumentParser(description=main.__doc__)
|
||||
|
||||
parser.add_argument("--taskName", dest="task", required=True, help="Task being executed.")
|
||||
parser.add_argument("--buildVariant", dest="variant", required=True,
|
||||
help="Build variant task is being executed on.")
|
||||
parser.add_argument("--distro", dest="distro", required=True,
|
||||
help="Distro task is being executed on.")
|
||||
parser.add_argument(
|
||||
"--jobFactor", dest="jobs_factor", type=float, default=1.0,
|
||||
help=("Job factor to use as a mulitplier with the number of CPUs. Defaults"
|
||||
" to %(default)s."))
|
||||
"--taskName", dest="task", required=True, help="Task being executed."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jobsMax", dest="jobs_max", type=int, default=0,
|
||||
help=("Maximum number of jobs to use. Specify 0 to indicate the number of"
|
||||
" jobs is determined by --jobFactor and the number of CPUs. Defaults"
|
||||
" to %(default)s."))
|
||||
"--buildVariant",
|
||||
dest="variant",
|
||||
required=True,
|
||||
help="Build variant task is being executed on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outFile", dest="outfile", help=("File to write configuration to. If"
|
||||
" unspecified no file is generated."))
|
||||
"--distro",
|
||||
dest="distro",
|
||||
required=True,
|
||||
help="Distro task is being executed on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jobFactor",
|
||||
dest="jobs_factor",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help=(
|
||||
"Job factor to use as a mulitplier with the number of CPUs. Defaults"
|
||||
" to %(default)s."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--jobsMax",
|
||||
dest="jobs_max",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Maximum number of jobs to use. Specify 0 to indicate the number of"
|
||||
" jobs is determined by --jobFactor and the number of CPUs. Defaults"
|
||||
" to %(default)s."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--outFile",
|
||||
dest="outfile",
|
||||
help=(
|
||||
"File to write configuration to. If" " unspecified no file is generated."
|
||||
),
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
|
||||
structlog.configure(logger_factory=structlog.stdlib.LoggerFactory())
|
||||
|
||||
LOGGER.info("Finding job count", task=options.task, variant=options.variant,
|
||||
platform=PLATFORM_MACHINE, sys=SYS_PLATFORM, cpu_count=CPU_COUNT)
|
||||
LOGGER.info(
|
||||
"Finding job count",
|
||||
task=options.task,
|
||||
variant=options.variant,
|
||||
platform=PLATFORM_MACHINE,
|
||||
sys=SYS_PLATFORM,
|
||||
cpu_count=CPU_COUNT,
|
||||
)
|
||||
|
||||
jobs = determine_jobs(options.task, options.variant, options.distro, options.jobs_max,
|
||||
options.jobs_factor)
|
||||
jobs = determine_jobs(
|
||||
options.task,
|
||||
options.variant,
|
||||
options.distro,
|
||||
options.jobs_max,
|
||||
options.jobs_factor,
|
||||
)
|
||||
if jobs < CPU_COUNT:
|
||||
print("Reducing number of jobs to run from {} to {}".format(CPU_COUNT, jobs))
|
||||
output_jobs(jobs, options.outfile)
|
||||
|
||||
@ -20,17 +20,42 @@ def parse_command_line():
|
||||
"""Parse command line options."""
|
||||
parser = argparse.ArgumentParser(description=main.__doc__)
|
||||
|
||||
parser.add_argument("--list-tags", action="store_true", default=False,
|
||||
help="List all tags used by tasks in evergreen yml.")
|
||||
parser.add_argument("--list-tasks", type=str, help="List all tasks for the given buildvariant.")
|
||||
parser.add_argument("--list-variants-and-tasks", action="store_true",
|
||||
help="List all tasks for every buildvariant.")
|
||||
parser.add_argument("-t", "--tasks-for-tag", type=str, default=None, action="append",
|
||||
help="List all tasks that use the given tag.")
|
||||
parser.add_argument("-x", "--remove-tasks-for-tag-filter", type=str, default=None,
|
||||
action="append", help="Remove tasks tagged with given tag.")
|
||||
parser.add_argument("--evergreen-file", type=str, default=DEFAULT_EVERGREEN_FILE,
|
||||
help="Location of evergreen file.")
|
||||
parser.add_argument(
|
||||
"--list-tags",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="List all tags used by tasks in evergreen yml.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-tasks", type=str, help="List all tasks for the given buildvariant."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--list-variants-and-tasks",
|
||||
action="store_true",
|
||||
help="List all tasks for every buildvariant.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tasks-for-tag",
|
||||
type=str,
|
||||
default=None,
|
||||
action="append",
|
||||
help="List all tasks that use the given tag.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-x",
|
||||
"--remove-tasks-for-tag-filter",
|
||||
type=str,
|
||||
default=None,
|
||||
action="append",
|
||||
help="Remove tasks tagged with given tag.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evergreen-file",
|
||||
type=str,
|
||||
default=DEFAULT_EVERGREEN_FILE,
|
||||
help="Location of evergreen file.",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
@ -116,7 +141,9 @@ def get_tasks_with_tag(evg_config, tags, filters):
|
||||
:param filters: lst of tags to filter out.
|
||||
:return: list of tasks marked with the given tag.
|
||||
"""
|
||||
return sorted([task.name for task in evg_config.tasks if is_task_tagged(task, tags, filters)])
|
||||
return sorted(
|
||||
[task.name for task in evg_config.tasks if is_task_tagged(task, tags, filters)]
|
||||
)
|
||||
|
||||
|
||||
def list_tasks_with_tag(evg_config, tags, filters):
|
||||
@ -148,7 +175,9 @@ def main():
|
||||
list_all_tasks(evg_config, options.list_tasks)
|
||||
|
||||
if options.tasks_for_tag:
|
||||
list_tasks_with_tag(evg_config, options.tasks_for_tag, options.remove_tasks_for_tag_filter)
|
||||
list_tasks_with_tag(
|
||||
evg_config, options.tasks_for_tag, options.remove_tasks_for_tag_filter
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Determine the timeout value a task should use in evergreen."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
@ -59,8 +60,12 @@ class TimeoutOverride(BaseModel):
|
||||
idle_timeout: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_seconds(cls, task: str, exec_timeout_secs: Optional[float],
|
||||
idle_timeout_secs: Optional[float]) -> TimeoutOverride:
|
||||
def from_seconds(
|
||||
cls,
|
||||
task: str,
|
||||
exec_timeout_secs: Optional[float],
|
||||
idle_timeout_secs: Optional[float],
|
||||
) -> TimeoutOverride:
|
||||
"""Create an instance of an override from seconds."""
|
||||
exec_timeout = exec_timeout_secs / 60 if exec_timeout_secs else None
|
||||
idle_timeout = idle_timeout_secs / 60 if idle_timeout_secs else None
|
||||
@ -94,7 +99,9 @@ class TimeoutOverrides(BaseModel):
|
||||
with open(file_path) as file_handler:
|
||||
return cls(**yaml.safe_load(file_handler))
|
||||
|
||||
def _lookup_override(self, build_variant: str, task_name: str) -> Optional[TimeoutOverride]:
|
||||
def _lookup_override(
|
||||
self, build_variant: str, task_name: str
|
||||
) -> Optional[TimeoutOverride]:
|
||||
"""
|
||||
Check if the given task on the given build variant has an override defined.
|
||||
|
||||
@ -105,19 +112,27 @@ class TimeoutOverrides(BaseModel):
|
||||
:return: Timeout override if found.
|
||||
"""
|
||||
overrides = [
|
||||
override for override in self.overrides.get(build_variant, [])
|
||||
override
|
||||
for override in self.overrides.get(build_variant, [])
|
||||
if override.task == task_name
|
||||
]
|
||||
if overrides:
|
||||
if len(overrides) > 1:
|
||||
LOGGER.error("Found multiple overrides for the same task",
|
||||
build_variant=build_variant, task=task_name,
|
||||
overrides=[override.dict() for override in overrides])
|
||||
raise ValueError(f"Found multiple overrides for '{task_name}' on '{build_variant}'")
|
||||
LOGGER.error(
|
||||
"Found multiple overrides for the same task",
|
||||
build_variant=build_variant,
|
||||
task=task_name,
|
||||
overrides=[override.dict() for override in overrides],
|
||||
)
|
||||
raise ValueError(
|
||||
f"Found multiple overrides for '{task_name}' on '{build_variant}'"
|
||||
)
|
||||
return overrides[0]
|
||||
return None
|
||||
|
||||
def lookup_exec_override(self, build_variant: str, task_name: str) -> Optional[timedelta]:
|
||||
def lookup_exec_override(
|
||||
self, build_variant: str, task_name: str
|
||||
) -> Optional[timedelta]:
|
||||
"""
|
||||
Look up the exec timeout override of the given build variant/task.
|
||||
|
||||
@ -130,7 +145,9 @@ class TimeoutOverrides(BaseModel):
|
||||
return override.get_exec_timeout()
|
||||
return None
|
||||
|
||||
def lookup_idle_override(self, build_variant: str, task_name: str) -> Optional[timedelta]:
|
||||
def lookup_idle_override(
|
||||
self, build_variant: str, task_name: str
|
||||
) -> Optional[timedelta]:
|
||||
"""
|
||||
Look up the idle timeout override of the given build variant/task.
|
||||
|
||||
@ -144,8 +161,11 @@ class TimeoutOverrides(BaseModel):
|
||||
return None
|
||||
|
||||
|
||||
def output_timeout(exec_timeout: timedelta, idle_timeout: Optional[timedelta],
|
||||
output_file: Optional[str]) -> None:
|
||||
def output_timeout(
|
||||
exec_timeout: timedelta,
|
||||
idle_timeout: Optional[timedelta],
|
||||
output_file: Optional[str],
|
||||
) -> None:
|
||||
"""
|
||||
Output timeout configuration to the specified location.
|
||||
|
||||
@ -170,8 +190,12 @@ class TaskTimeoutOrchestrator:
|
||||
"""An orchestrator for determining task timeouts."""
|
||||
|
||||
@inject.autoparams()
|
||||
def __init__(self, timeout_service: TimeoutService, timeout_overrides: TimeoutOverrides,
|
||||
evg_project_config: EvergreenProjectConfig) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
timeout_service: TimeoutService,
|
||||
timeout_overrides: TimeoutOverrides,
|
||||
evg_project_config: EvergreenProjectConfig,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize the orchestrator.
|
||||
|
||||
@ -183,10 +207,15 @@ class TaskTimeoutOrchestrator:
|
||||
self.timeout_overrides = timeout_overrides
|
||||
self.evg_project_config = evg_project_config
|
||||
|
||||
def determine_exec_timeout(self, task_name: str, variant: str,
|
||||
idle_timeout: Optional[timedelta] = None,
|
||||
exec_timeout: Optional[timedelta] = None, evg_alias: str = "",
|
||||
historic_timeout: Optional[timedelta] = None) -> timedelta:
|
||||
def determine_exec_timeout(
|
||||
self,
|
||||
task_name: str,
|
||||
variant: str,
|
||||
idle_timeout: Optional[timedelta] = None,
|
||||
exec_timeout: Optional[timedelta] = None,
|
||||
evg_alias: str = "",
|
||||
historic_timeout: Optional[timedelta] = None,
|
||||
) -> timedelta:
|
||||
"""
|
||||
Determine what exec timeout should be used.
|
||||
|
||||
@ -205,36 +234,56 @@ class TaskTimeoutOrchestrator:
|
||||
override = self.timeout_overrides.lookup_exec_override(variant, task_name)
|
||||
|
||||
if exec_timeout and exec_timeout.total_seconds() != 0:
|
||||
LOGGER.info("Using timeout from cmd line",
|
||||
exec_timeout_secs=exec_timeout.total_seconds())
|
||||
LOGGER.info(
|
||||
"Using timeout from cmd line",
|
||||
exec_timeout_secs=exec_timeout.total_seconds(),
|
||||
)
|
||||
determined_timeout = exec_timeout
|
||||
|
||||
elif override is not None:
|
||||
LOGGER.info("Overriding configured timeout", exec_timeout_secs=override.total_seconds())
|
||||
LOGGER.info(
|
||||
"Overriding configured timeout",
|
||||
exec_timeout_secs=override.total_seconds(),
|
||||
)
|
||||
determined_timeout = override
|
||||
|
||||
elif self._is_required_build_variant(
|
||||
variant) and determined_timeout > DEFAULT_REQUIRED_BUILD_TIMEOUT:
|
||||
LOGGER.info("Overriding required-builder timeout",
|
||||
exec_timeout_secs=DEFAULT_REQUIRED_BUILD_TIMEOUT.total_seconds())
|
||||
elif (
|
||||
self._is_required_build_variant(variant)
|
||||
and determined_timeout > DEFAULT_REQUIRED_BUILD_TIMEOUT
|
||||
):
|
||||
LOGGER.info(
|
||||
"Overriding required-builder timeout",
|
||||
exec_timeout_secs=DEFAULT_REQUIRED_BUILD_TIMEOUT.total_seconds(),
|
||||
)
|
||||
determined_timeout = DEFAULT_REQUIRED_BUILD_TIMEOUT
|
||||
|
||||
elif evg_alias == COMMIT_QUEUE_ALIAS:
|
||||
LOGGER.info("Overriding commit-queue timeout",
|
||||
exec_timeout_secs=COMMIT_QUEUE_TIMEOUT.total_seconds())
|
||||
LOGGER.info(
|
||||
"Overriding commit-queue timeout",
|
||||
exec_timeout_secs=COMMIT_QUEUE_TIMEOUT.total_seconds(),
|
||||
)
|
||||
determined_timeout = COMMIT_QUEUE_TIMEOUT
|
||||
|
||||
# The timeout needs to be at least as large as the idle timeout.
|
||||
if idle_timeout and determined_timeout.total_seconds() < idle_timeout.total_seconds():
|
||||
LOGGER.info("Making exec timeout as large as idle timeout",
|
||||
exec_timeout_secs=idle_timeout.total_seconds())
|
||||
if (
|
||||
idle_timeout
|
||||
and determined_timeout.total_seconds() < idle_timeout.total_seconds()
|
||||
):
|
||||
LOGGER.info(
|
||||
"Making exec timeout as large as idle timeout",
|
||||
exec_timeout_secs=idle_timeout.total_seconds(),
|
||||
)
|
||||
return idle_timeout
|
||||
|
||||
return determined_timeout
|
||||
|
||||
def determine_idle_timeout(self, task_name: str, variant: str,
|
||||
idle_timeout: Optional[timedelta] = None,
|
||||
historic_timeout: Optional[timedelta] = None) -> Optional[timedelta]:
|
||||
def determine_idle_timeout(
|
||||
self,
|
||||
task_name: str,
|
||||
variant: str,
|
||||
idle_timeout: Optional[timedelta] = None,
|
||||
historic_timeout: Optional[timedelta] = None,
|
||||
) -> Optional[timedelta]:
|
||||
"""
|
||||
Determine what idle timeout should be used.
|
||||
|
||||
@ -249,18 +298,29 @@ class TaskTimeoutOrchestrator:
|
||||
override = self.timeout_overrides.lookup_idle_override(variant, task_name)
|
||||
|
||||
if idle_timeout and idle_timeout.total_seconds() != 0:
|
||||
LOGGER.info("Using timeout from cmd line",
|
||||
idle_timeout_secs=idle_timeout.total_seconds())
|
||||
LOGGER.info(
|
||||
"Using timeout from cmd line",
|
||||
idle_timeout_secs=idle_timeout.total_seconds(),
|
||||
)
|
||||
determined_timeout = idle_timeout
|
||||
|
||||
elif override is not None:
|
||||
LOGGER.info("Overriding configured timeout", idle_timeout_secs=override.total_seconds())
|
||||
LOGGER.info(
|
||||
"Overriding configured timeout",
|
||||
idle_timeout_secs=override.total_seconds(),
|
||||
)
|
||||
determined_timeout = override
|
||||
|
||||
return determined_timeout
|
||||
|
||||
def determine_historic_timeout(self, project: str, task: str, variant: str, suite_name: str,
|
||||
exec_timeout_factor: Optional[float]) -> TimeoutOverride:
|
||||
def determine_historic_timeout(
|
||||
self,
|
||||
project: str,
|
||||
task: str,
|
||||
variant: str,
|
||||
suite_name: str,
|
||||
exec_timeout_factor: Optional[float],
|
||||
) -> TimeoutOverride:
|
||||
"""
|
||||
Calculate the timeout based on historic test results.
|
||||
|
||||
@ -283,11 +343,15 @@ class TaskTimeoutOrchestrator:
|
||||
timeout_estimate = self.timeout_service.get_timeout_estimate(timeout_params)
|
||||
if timeout_estimate and timeout_estimate.is_specified():
|
||||
exec_timeout = timeout_estimate.calculate_task_timeout(
|
||||
repeat_factor=1, scaling_factor=exec_timeout_factor)
|
||||
repeat_factor=1, scaling_factor=exec_timeout_factor
|
||||
)
|
||||
idle_timeout = timeout_estimate.calculate_test_timeout(repeat_factor=1)
|
||||
if exec_timeout is not None or idle_timeout is not None:
|
||||
LOGGER.info("Getting historic based timeout", exec_timeout_secs=exec_timeout,
|
||||
idle_timeout_secs=idle_timeout)
|
||||
LOGGER.info(
|
||||
"Getting historic based timeout",
|
||||
exec_timeout_secs=exec_timeout,
|
||||
idle_timeout_secs=idle_timeout,
|
||||
)
|
||||
return TimeoutOverride.from_seconds(task, exec_timeout, idle_timeout)
|
||||
return TimeoutOverride(task=task, exec_timeout=None, idle_timeout=None)
|
||||
|
||||
@ -312,10 +376,18 @@ class TaskTimeoutOrchestrator:
|
||||
bv = self.evg_project_config.get_variant(build_variant)
|
||||
return "!" in bv.display_name
|
||||
|
||||
def determine_timeouts(self, cli_idle_timeout: Optional[timedelta],
|
||||
cli_exec_timeout: Optional[timedelta], outfile: Optional[str],
|
||||
project: str, task: str, variant: str, evg_alias: str, suite_name: str,
|
||||
exec_timeout_factor: Optional[float]) -> None:
|
||||
def determine_timeouts(
|
||||
self,
|
||||
cli_idle_timeout: Optional[timedelta],
|
||||
cli_exec_timeout: Optional[timedelta],
|
||||
outfile: Optional[str],
|
||||
project: str,
|
||||
task: str,
|
||||
variant: str,
|
||||
evg_alias: str,
|
||||
suite_name: str,
|
||||
exec_timeout_factor: Optional[float],
|
||||
) -> None:
|
||||
"""
|
||||
Determine the timeouts to use for the given task and write timeouts to expansion file.
|
||||
|
||||
@ -329,13 +401,21 @@ class TaskTimeoutOrchestrator:
|
||||
:param suite_name: Name of evergreen suite being run.
|
||||
:param exec_timeout_factor: Scaling factor to use when determining timeout.
|
||||
"""
|
||||
historic_timeout = self.determine_historic_timeout(project, task, variant, suite_name,
|
||||
exec_timeout_factor)
|
||||
historic_timeout = self.determine_historic_timeout(
|
||||
project, task, variant, suite_name, exec_timeout_factor
|
||||
)
|
||||
|
||||
idle_timeout = self.determine_idle_timeout(task, variant, cli_idle_timeout,
|
||||
historic_timeout.get_idle_timeout())
|
||||
exec_timeout = self.determine_exec_timeout(task, variant, idle_timeout, cli_exec_timeout,
|
||||
evg_alias, historic_timeout.get_exec_timeout())
|
||||
idle_timeout = self.determine_idle_timeout(
|
||||
task, variant, cli_idle_timeout, historic_timeout.get_idle_timeout()
|
||||
)
|
||||
exec_timeout = self.determine_exec_timeout(
|
||||
task,
|
||||
variant,
|
||||
idle_timeout,
|
||||
cli_exec_timeout,
|
||||
evg_alias,
|
||||
historic_timeout.get_exec_timeout(),
|
||||
)
|
||||
|
||||
output_timeout(exec_timeout, idle_timeout, outfile)
|
||||
|
||||
@ -344,42 +424,92 @@ def main():
|
||||
"""Determine the timeout value a task should use in evergreen."""
|
||||
parser = argparse.ArgumentParser(description=main.__doc__)
|
||||
|
||||
parser.add_argument("--install-dir", dest="install_dir", required=True,
|
||||
help="Path to bin directory of testable installation")
|
||||
parser.add_argument("--task-name", dest="task", required=True, help="Task being executed.")
|
||||
parser.add_argument("--suite-name", dest="suite_name", required=True,
|
||||
help="Resmoke suite being run against.")
|
||||
parser.add_argument("--build-variant", dest="variant", required=True,
|
||||
help="Build variant task is being executed on.")
|
||||
parser.add_argument("--project", dest="project", required=True,
|
||||
help="Evergreen project task is being executed on.")
|
||||
parser.add_argument("--evg-alias", dest="evg_alias", required=True,
|
||||
help="Evergreen alias used to trigger build.")
|
||||
parser.add_argument("--test-flags", dest="test_flags",
|
||||
help="Test flags that are used for `resmoke.py run` command call.")
|
||||
parser.add_argument("--timeout", dest="timeout", type=int, help="Timeout to use (in sec).")
|
||||
parser.add_argument("--exec-timeout", dest="exec_timeout", type=int,
|
||||
help="Exec timeout to use (in sec).")
|
||||
parser.add_argument("--exec-timeout-factor", dest="exec_timeout_factor", type=float,
|
||||
help="Exec timeout factor to use (in sec).")
|
||||
parser.add_argument("--out-file", dest="outfile", help="File to write configuration to.")
|
||||
parser.add_argument("--timeout-overrides", dest="timeout_overrides_file",
|
||||
default=DEFAULT_TIMEOUT_OVERRIDES,
|
||||
help="File containing timeout overrides to use.")
|
||||
parser.add_argument("--evg-api-config", dest="evg_api_config",
|
||||
default=DEFAULT_EVERGREEN_AUTH_CONFIG, help="Evergreen API config file.")
|
||||
parser.add_argument("--evg-project-config", dest="evg_project_config",
|
||||
default=DEFAULT_EVERGREEN_CONFIG, help="Evergreen project config file.")
|
||||
parser.add_argument(
|
||||
"--install-dir",
|
||||
dest="install_dir",
|
||||
required=True,
|
||||
help="Path to bin directory of testable installation",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task-name", dest="task", required=True, help="Task being executed."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--suite-name",
|
||||
dest="suite_name",
|
||||
required=True,
|
||||
help="Resmoke suite being run against.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--build-variant",
|
||||
dest="variant",
|
||||
required=True,
|
||||
help="Build variant task is being executed on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--project",
|
||||
dest="project",
|
||||
required=True,
|
||||
help="Evergreen project task is being executed on.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evg-alias",
|
||||
dest="evg_alias",
|
||||
required=True,
|
||||
help="Evergreen alias used to trigger build.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test-flags",
|
||||
dest="test_flags",
|
||||
help="Test flags that are used for `resmoke.py run` command call.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout", dest="timeout", type=int, help="Timeout to use (in sec)."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exec-timeout",
|
||||
dest="exec_timeout",
|
||||
type=int,
|
||||
help="Exec timeout to use (in sec).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--exec-timeout-factor",
|
||||
dest="exec_timeout_factor",
|
||||
type=float,
|
||||
help="Exec timeout factor to use (in sec).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--out-file", dest="outfile", help="File to write configuration to."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout-overrides",
|
||||
dest="timeout_overrides_file",
|
||||
default=DEFAULT_TIMEOUT_OVERRIDES,
|
||||
help="File containing timeout overrides to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evg-api-config",
|
||||
dest="evg_api_config",
|
||||
default=DEFAULT_EVERGREEN_AUTH_CONFIG,
|
||||
help="Evergreen API config file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--evg-project-config",
|
||||
dest="evg_project_config",
|
||||
default=DEFAULT_EVERGREEN_CONFIG,
|
||||
help="Evergreen project config file.",
|
||||
)
|
||||
|
||||
options = parser.parse_args()
|
||||
|
||||
timeout_override = timedelta(seconds=options.timeout) if options.timeout else None
|
||||
exec_timeout_override = timedelta(
|
||||
seconds=options.exec_timeout) if options.exec_timeout else None
|
||||
exec_timeout_override = (
|
||||
timedelta(seconds=options.exec_timeout) if options.exec_timeout else None
|
||||
)
|
||||
|
||||
task_name = determine_task_base_name(options.task, options.variant)
|
||||
timeout_overrides = TimeoutOverrides.from_yaml_file(
|
||||
os.path.expanduser(options.timeout_overrides_file))
|
||||
os.path.expanduser(options.timeout_overrides_file)
|
||||
)
|
||||
|
||||
enable_logging(verbose=False)
|
||||
LOGGER.info("Determining timeouts", cli_args=options)
|
||||
@ -387,22 +517,36 @@ def main():
|
||||
def dependencies(binder: inject.Binder) -> None:
|
||||
binder.bind(
|
||||
EvergreenApi,
|
||||
RetryingEvergreenApi.get_api(config_file=os.path.expanduser(options.evg_api_config)))
|
||||
RetryingEvergreenApi.get_api(
|
||||
config_file=os.path.expanduser(options.evg_api_config)
|
||||
),
|
||||
)
|
||||
binder.bind(TimeoutOverrides, timeout_overrides)
|
||||
binder.bind(EvergreenProjectConfig,
|
||||
parse_evergreen_file(os.path.expanduser(options.evg_project_config)))
|
||||
binder.bind(
|
||||
EvergreenProjectConfig,
|
||||
parse_evergreen_file(os.path.expanduser(options.evg_project_config)),
|
||||
)
|
||||
binder.bind(
|
||||
ResmokeProxyService,
|
||||
ResmokeProxyService(
|
||||
run_options=f"--installDir={shlex.quote(options.install_dir)} {options.test_flags}")
|
||||
run_options=f"--installDir={shlex.quote(options.install_dir)} {options.test_flags}"
|
||||
),
|
||||
)
|
||||
|
||||
inject.configure(dependencies)
|
||||
|
||||
task_timeout_orchestrator = inject.instance(TaskTimeoutOrchestrator)
|
||||
task_timeout_orchestrator.determine_timeouts(
|
||||
timeout_override, exec_timeout_override, options.outfile, options.project, task_name,
|
||||
options.variant, options.evg_alias, options.suite_name, options.exec_timeout_factor)
|
||||
timeout_override,
|
||||
exec_timeout_override,
|
||||
options.outfile,
|
||||
options.project,
|
||||
task_name,
|
||||
options.variant,
|
||||
options.evg_alias,
|
||||
options.suite_name,
|
||||
options.exec_timeout_factor,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@ -15,9 +15,19 @@ import requests
|
||||
from buildscripts.util.read_config import read_config_file
|
||||
|
||||
|
||||
def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant: str,
|
||||
version_id: str, revision: int, task_name: str, file_number: int, upload_name: str,
|
||||
start_time: int) -> Optional[Dict[str, str]]:
|
||||
def process_file(
|
||||
file: str,
|
||||
aws_secret: str,
|
||||
aws_key: str,
|
||||
project: str,
|
||||
variant: str,
|
||||
version_id: str,
|
||||
revision: int,
|
||||
task_name: str,
|
||||
file_number: int,
|
||||
upload_name: str,
|
||||
start_time: int,
|
||||
) -> Optional[Dict[str, str]]:
|
||||
print(f"{file} started compressing at {time.time() - start_time}")
|
||||
compressed_file = f"{file}.gz"
|
||||
with open(file, "rb") as f_in:
|
||||
@ -26,12 +36,16 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant
|
||||
|
||||
print(f"{file} finished compressing at {time.time() - start_time}")
|
||||
|
||||
s3_client = boto3.client('s3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret)
|
||||
s3_client = boto3.client(
|
||||
"s3", aws_access_key_id=aws_key, aws_secret_access_key=aws_secret
|
||||
)
|
||||
basename = os.path.basename(compressed_file)
|
||||
object_path = f"{project}/{variant}/{version_id}/{task_name}-{revision}-{file_number}/{basename}"
|
||||
extra_args = {"ContentType": "application/gzip", "ACL": "public-read"}
|
||||
try:
|
||||
s3_client.upload_file(compressed_file, "mciuploads", object_path, ExtraArgs=extra_args)
|
||||
s3_client.upload_file(
|
||||
compressed_file, "mciuploads", object_path, ExtraArgs=extra_args
|
||||
)
|
||||
except Exception as ex:
|
||||
print(f"ERROR: failed to upload file to s3 {file}", file=sys.stderr)
|
||||
print(ex, file=sys.stderr)
|
||||
@ -43,8 +57,10 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant
|
||||
# Sanity check to ensure the url exists
|
||||
r = requests.head(url)
|
||||
if r.status_code != 200:
|
||||
print(f"ERROR: Could not verify that {compressed_file} was uploaded to {url}",
|
||||
file=sys.stderr)
|
||||
print(
|
||||
f"ERROR: Could not verify that {compressed_file} was uploaded to {url}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return None
|
||||
|
||||
print(f"{compressed_file} uploaded at {time.time() - start_time} to {url}")
|
||||
@ -59,8 +75,9 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant
|
||||
return task_artifact
|
||||
|
||||
|
||||
def main(output_file: str, patterns: List[str], display_name: str, expansions_file: str) -> int:
|
||||
|
||||
def main(
|
||||
output_file: str, patterns: List[str], display_name: str, expansions_file: str
|
||||
) -> int:
|
||||
if not output_file.endswith(".json"):
|
||||
print("WARN: filename input should end with `.json`", file=sys.stderr)
|
||||
|
||||
@ -98,11 +115,21 @@ def main(output_file: str, patterns: List[str], display_name: str, expansions_fi
|
||||
for i, path in enumerate(files):
|
||||
file_number = i + 1
|
||||
futures.append(
|
||||
executor.submit(process_file, file=path, aws_secret=aws_secret_key,
|
||||
aws_key=aws_access_key, project=project, variant=build_variant,
|
||||
version_id=version_id, revision=revision, task_name=task_name,
|
||||
file_number=file_number, upload_name=display_name,
|
||||
start_time=start_time))
|
||||
executor.submit(
|
||||
process_file,
|
||||
file=path,
|
||||
aws_secret=aws_secret_key,
|
||||
aws_key=aws_access_key,
|
||||
project=project,
|
||||
variant=build_variant,
|
||||
version_id=version_id,
|
||||
revision=revision,
|
||||
task_name=task_name,
|
||||
file_number=file_number,
|
||||
upload_name=display_name,
|
||||
start_time=start_time,
|
||||
)
|
||||
)
|
||||
|
||||
for future in concurrent.futures.as_completed(futures):
|
||||
result = future.result()
|
||||
@ -118,21 +145,40 @@ def main(output_file: str, patterns: List[str], display_name: str, expansions_fi
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog='FastArchiver',
|
||||
description='This improves archiving times of a large amount of big files in evergreen '
|
||||
'by compressing and uploading them asynchronously. '
|
||||
'This also uses pigz, which is a multithreaded implementation of gzip, '
|
||||
'to improve gzipping times.')
|
||||
prog="FastArchiver",
|
||||
description="This improves archiving times of a large amount of big files in evergreen "
|
||||
"by compressing and uploading them asynchronously. "
|
||||
"This also uses pigz, which is a multithreaded implementation of gzip, "
|
||||
"to improve gzipping times.",
|
||||
)
|
||||
|
||||
parser.add_argument("--output-file", "-f", help="Name of output attach.artifacts file.",
|
||||
required=True)
|
||||
parser.add_argument("--pattern", "-p", help="glob patterns of files to be archived.",
|
||||
dest="patterns", action="append", default=[], required=True)
|
||||
parser.add_argument("--display-name", "-n", help="The display name of the file in evergreen",
|
||||
required=True)
|
||||
parser.add_argument("--expansions-file", "-e",
|
||||
help="Expansions file to read task info and aws credentials from.",
|
||||
default="../expansions.yml")
|
||||
parser.add_argument(
|
||||
"--output-file",
|
||||
"-f",
|
||||
help="Name of output attach.artifacts file.",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pattern",
|
||||
"-p",
|
||||
help="glob patterns of files to be archived.",
|
||||
dest="patterns",
|
||||
action="append",
|
||||
default=[],
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--display-name",
|
||||
"-n",
|
||||
help="The display name of the file in evergreen",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--expansions-file",
|
||||
"-e",
|
||||
help="Expansions file to read task info and aws credentials from.",
|
||||
default="../expansions.yml",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
exit(main(args.output_file, args.patterns, args.display_name, args.expansions_file))
|
||||
|
||||
@ -61,7 +61,9 @@ def main(diff_file, ent_path):
|
||||
base_feature_flags = fh.read().split()
|
||||
with open("patch_all_feature_flags.txt", "r") as fh:
|
||||
patch_feature_flags = fh.read().split()
|
||||
enabled_feature_flags = [flag for flag in base_feature_flags if flag not in patch_feature_flags]
|
||||
enabled_feature_flags = [
|
||||
flag for flag in base_feature_flags if flag not in patch_feature_flags
|
||||
]
|
||||
|
||||
if not enabled_feature_flags:
|
||||
print(
|
||||
@ -69,23 +71,30 @@ def main(diff_file, ent_path):
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
tests_with_feature_flag_tag = get_tests_with_feature_flag_tags(enabled_feature_flags, ent_path)
|
||||
tests_with_feature_flag_tag = get_tests_with_feature_flag_tags(
|
||||
enabled_feature_flags, ent_path
|
||||
)
|
||||
|
||||
_run_git_cmd(["apply", diff_file])
|
||||
_run_git_cmd(["apply", diff_file], cwd=ent_path)
|
||||
tests_missing_fcv_tag = get_tests_missing_fcv_tag(tests_with_feature_flag_tag)
|
||||
|
||||
if tests_missing_fcv_tag:
|
||||
print(f"Found tests missing `{REQUIRES_FCV_TAG_LATEST}` tag:\n" +
|
||||
"\n".join(tests_missing_fcv_tag))
|
||||
print(
|
||||
f"Found tests missing `{REQUIRES_FCV_TAG_LATEST}` tag:\n"
|
||||
+ "\n".join(tests_missing_fcv_tag)
|
||||
)
|
||||
sys.exit(1)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--diff-file-name", type=str,
|
||||
help="Name of the file containing the git diff")
|
||||
parser.add_argument("--enterprise-path", type=str, help="Path to the enterprise module")
|
||||
parser.add_argument(
|
||||
"--diff-file-name", type=str, help="Name of the file containing the git diff"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enterprise-path", type=str, help="Path to the enterprise module"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args.diff_file_name, args.enterprise_path)
|
||||
|
||||
@ -29,7 +29,9 @@ def _relink_binaries_with_symbols(failed_tests: List[str]):
|
||||
bazel_build_flags.replace("--config=strip-debug-during-link", "")
|
||||
|
||||
relink_command = [
|
||||
arg for arg in ["bazel", "build", *bazel_build_flags.split(" "), *failed_tests] if arg
|
||||
arg
|
||||
for arg in ["bazel", "build", *bazel_build_flags.split(" "), *failed_tests]
|
||||
if arg
|
||||
]
|
||||
|
||||
# Enable this when/if we want to strip debug symbols during linking unit tests
|
||||
@ -39,13 +41,17 @@ def _relink_binaries_with_symbols(failed_tests: List[str]):
|
||||
# check=True,
|
||||
# )
|
||||
|
||||
repro_test_command = " ".join(["test" if arg == "build" else arg for arg in relink_command])
|
||||
repro_test_command = " ".join(
|
||||
["test" if arg == "build" else arg for arg in relink_command]
|
||||
)
|
||||
with open(".failed_unittest_repro.txt", "w", encoding="utf-8") as f:
|
||||
f.write(repro_test_command)
|
||||
print(f"Repro command written to .failed_unittest_repro.txt: {repro_test_command}")
|
||||
|
||||
|
||||
def _copy_bins_to_upload(failed_tests: List[str], upload_bin_dir: str, upload_lib_dir: str) -> bool:
|
||||
def _copy_bins_to_upload(
|
||||
failed_tests: List[str], upload_bin_dir: str, upload_lib_dir: str
|
||||
) -> bool:
|
||||
success = True
|
||||
bazel_bin_dir = Path("./bazel-bin")
|
||||
for failed_test in failed_tests:
|
||||
@ -80,7 +86,9 @@ def _copy_bins_to_upload(failed_tests: List[str], upload_bin_dir: str, upload_li
|
||||
dsym_dir = full_binary_path.with_suffix(".dSYM")
|
||||
if dsym_dir.is_dir():
|
||||
print(f"Copying dsym {dsym_dir} to {upload_bin_dir}")
|
||||
shutil.copytree(dsym_dir, upload_bin_dir / dsym_dir.name, dirs_exist_ok=True)
|
||||
shutil.copytree(
|
||||
dsym_dir, upload_bin_dir / dsym_dir.name, dirs_exist_ok=True
|
||||
)
|
||||
|
||||
# Copy debug symbols for dynamic builds
|
||||
lib_dir = Path("bazel-bin/install/lib")
|
||||
@ -105,7 +113,9 @@ def main(testlog_dir: str = "bazel-testlogs"):
|
||||
print("No failed tests found.")
|
||||
exit(0)
|
||||
|
||||
print(f"Found {len(failed_tests)} failed tests. Gathering binaries and debug symbols.")
|
||||
print(
|
||||
f"Found {len(failed_tests)} failed tests. Gathering binaries and debug symbols."
|
||||
)
|
||||
_relink_binaries_with_symbols(failed_tests)
|
||||
|
||||
print("Copying binaries and debug symbols to upload directories.")
|
||||
|
||||
@ -22,25 +22,24 @@ if not gdb:
|
||||
|
||||
|
||||
def detect_toolchain(progspace):
|
||||
|
||||
readelf_bin = '/opt/mongodbtoolchain/v4/bin/readelf'
|
||||
readelf_bin = "/opt/mongodbtoolchain/v4/bin/readelf"
|
||||
if not os.path.exists(readelf_bin):
|
||||
readelf_bin = 'readelf'
|
||||
readelf_bin = "readelf"
|
||||
|
||||
gcc_version_regex = re.compile(r'.*\]\s*GCC: \(GNU\) (\d+\.\d+\.\d+)\s*$')
|
||||
clang_version_regex = re.compile(r'.*\]\s*MongoDB clang version (\d+\.\d+\.\d+).*')
|
||||
gcc_version_regex = re.compile(r".*\]\s*GCC: \(GNU\) (\d+\.\d+\.\d+)\s*$")
|
||||
clang_version_regex = re.compile(r".*\]\s*MongoDB clang version (\d+\.\d+\.\d+).*")
|
||||
|
||||
readelf_cmd = [readelf_bin, '-p', '.comment', progspace.filename]
|
||||
readelf_cmd = [readelf_bin, "-p", ".comment", progspace.filename]
|
||||
# take an educated guess as to where we could find the c++ printers, better than hardcoding
|
||||
result = subprocess.run(readelf_cmd, capture_output=True, text=True)
|
||||
|
||||
gcc_version = None
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
if match := re.search(gcc_version_regex, line):
|
||||
gcc_version = match.group(1)
|
||||
break
|
||||
clang_version = None
|
||||
for line in result.stdout.split('\n'):
|
||||
for line in result.stdout.split("\n"):
|
||||
if match := re.search(clang_version_regex, line):
|
||||
clang_version = match.group(1)
|
||||
break
|
||||
@ -55,13 +54,13 @@ def detect_toolchain(progspace):
|
||||
# default is v4 if we can't find a version
|
||||
toolchain_ver = None
|
||||
if gcc_version:
|
||||
if int(gcc_version.split('.')[0]) == 8:
|
||||
toolchain_ver = 'v3'
|
||||
elif int(gcc_version.split('.')[0]) == 11:
|
||||
toolchain_ver = 'v4'
|
||||
if int(gcc_version.split(".")[0]) == 8:
|
||||
toolchain_ver = "v3"
|
||||
elif int(gcc_version.split(".")[0]) == 11:
|
||||
toolchain_ver = "v4"
|
||||
|
||||
if not toolchain_ver:
|
||||
toolchain_ver = 'v4'
|
||||
toolchain_ver = "v4"
|
||||
print(f"""
|
||||
WARNING: could not detect a MongoDB toolchain to load matching stdcxx printers:
|
||||
-----------------
|
||||
@ -75,7 +74,8 @@ STDERR:
|
||||
Assuming {toolchain_ver} as a default, this could cause issues with the printers.""")
|
||||
|
||||
pp = glob.glob(
|
||||
f"/opt/mongodbtoolchain/{toolchain_ver}/share/gcc-*/python/libstdcxx/v6/printers.py")
|
||||
f"/opt/mongodbtoolchain/{toolchain_ver}/share/gcc-*/python/libstdcxx/v6/printers.py"
|
||||
)
|
||||
printers = pp[0]
|
||||
return os.path.dirname(os.path.dirname(os.path.dirname(printers)))
|
||||
|
||||
@ -91,6 +91,7 @@ def load_libstdcxx_printers(progspace):
|
||||
global stdlib_printers # pylint: disable=invalid-name,global-variable-not-assigned
|
||||
from libstdcxx.v6 import printers as stdlib_printers
|
||||
from libstdcxx.v6 import register_libstdcxx_printers
|
||||
|
||||
register_libstdcxx_printers(progspace)
|
||||
print(
|
||||
f"Loaded libstdc++ pretty printers from '{stdcxx_printer_toolchain_paths[progspace]}'"
|
||||
@ -112,7 +113,8 @@ def on_new_object_file(objfile) -> None:
|
||||
warnings.warn(
|
||||
"Unable to locate the libstdc++ GDB pretty printers without an executable"
|
||||
" file. Try running the `file` command with the path to the executable file"
|
||||
" and reloading the core dump with the `core-file` command")
|
||||
" and reloading the core dump with the `core-file` command"
|
||||
)
|
||||
return
|
||||
|
||||
load_libstdcxx_printers(objfile.new_objfile.progspace)
|
||||
@ -132,7 +134,8 @@ except ImportError:
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
raise gdb.GdbError(
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2")
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2"
|
||||
)
|
||||
|
||||
|
||||
def get_process_name():
|
||||
@ -190,12 +193,14 @@ def lookup_type(gdb_type_str: str) -> gdb.Type:
|
||||
except Exception as exc:
|
||||
exceptions.append(exc)
|
||||
|
||||
raise gdb.error("Failed to get type, tried:\n%s" % '\n'.join([str(exc) for exc in exceptions]))
|
||||
raise gdb.error(
|
||||
"Failed to get type, tried:\n%s" % "\n".join([str(exc) for exc in exceptions])
|
||||
)
|
||||
|
||||
|
||||
def get_current_thread_name():
|
||||
"""Return the name of the current GDB thread."""
|
||||
fallback_name = '"%s"' % (gdb.selected_thread().name or '')
|
||||
fallback_name = '"%s"' % (gdb.selected_thread().name or "")
|
||||
try:
|
||||
# This goes through the pretty printer for StringData which adds "" around the name.
|
||||
name = str(gdb.parse_and_eval("mongo::getThreadName()"))
|
||||
@ -208,7 +213,9 @@ def get_current_thread_name():
|
||||
|
||||
def get_global_service_context():
|
||||
"""Return the global ServiceContext object."""
|
||||
return gdb.parse_and_eval("'mongo::(anonymous namespace)::globalServiceContext'").dereference()
|
||||
return gdb.parse_and_eval(
|
||||
"'mongo::(anonymous namespace)::globalServiceContext'"
|
||||
).dereference()
|
||||
|
||||
|
||||
def get_session_catalog():
|
||||
@ -217,7 +224,9 @@ def get_session_catalog():
|
||||
Returns None if no SessionCatalog could be found.
|
||||
"""
|
||||
# The SessionCatalog is a decoration on the ServiceContext.
|
||||
session_catalog_dec = get_decoration(get_global_service_context(), "mongo::SessionCatalog")
|
||||
session_catalog_dec = get_decoration(
|
||||
get_global_service_context(), "mongo::SessionCatalog"
|
||||
)
|
||||
if session_catalog_dec is None:
|
||||
return None
|
||||
return session_catalog_dec[1]
|
||||
@ -251,7 +260,8 @@ def get_wt_session(recovery_unit, recovery_unit_impl_type):
|
||||
if not wt_session_handle.dereference().address:
|
||||
return None
|
||||
wt_session = wt_session_handle.dereference().cast(
|
||||
lookup_type("mongo::WiredTigerSession"))["_session"]
|
||||
lookup_type("mongo::WiredTigerSession")
|
||||
)["_session"]
|
||||
return wt_session
|
||||
|
||||
|
||||
@ -269,28 +279,32 @@ def get_decorations(obj):
|
||||
try:
|
||||
yield (deco_type_name, obj)
|
||||
except Exception as err:
|
||||
print("Failed to look up decoration type: " + deco_type_name + ": " + str(err))
|
||||
print(
|
||||
"Failed to look up decoration type: " + deco_type_name + ": " + str(err)
|
||||
)
|
||||
|
||||
|
||||
def get_object_decoration(decorable, start, index):
|
||||
decoration_data = get_unique_ptr_bytes(decorable["_decorations"]["_data"])
|
||||
entry = start[index]
|
||||
deco_type_info = str(entry["typeInfo"])
|
||||
deco_type_name = re.sub(r'.* <typeinfo for (.*)>', r'\1', deco_type_info)
|
||||
deco_type_name = re.sub(r".* <typeinfo for (.*)>", r"\1", deco_type_info)
|
||||
offset = int(entry["offset"])
|
||||
obj = decoration_data[offset]
|
||||
obj_addr = re.sub(r'^(.*) .*', r'\1', str(obj.address))
|
||||
obj_addr = re.sub(r"^(.*) .*", r"\1", str(obj.address))
|
||||
obj = _cast_decoration_value(deco_type_name, int(obj.address))
|
||||
return (deco_type_name, obj, obj_addr)
|
||||
|
||||
|
||||
def get_decorable_info(decorable):
|
||||
decorable_t = decorable.type.template_argument(0)
|
||||
reg_sym, _ = gdb.lookup_symbol("mongo::decorable_detail::gdbRegistry<{}>".format(decorable_t))
|
||||
reg_sym, _ = gdb.lookup_symbol(
|
||||
"mongo::decorable_detail::gdbRegistry<{}>".format(decorable_t)
|
||||
)
|
||||
decl_vector = reg_sym.value()["_entries"]
|
||||
start = decl_vector["_M_impl"]["_M_start"]
|
||||
finish = decl_vector["_M_impl"]["_M_finish"]
|
||||
decinfo_t = lookup_type('mongo::decorable_detail::Registry::Entry')
|
||||
decinfo_t = lookup_type("mongo::decorable_detail::Registry::Entry")
|
||||
count = int((int(finish) - int(start)) / decinfo_t.sizeof)
|
||||
return start, count
|
||||
|
||||
@ -327,10 +341,10 @@ def get_boost_optional(optional):
|
||||
|
||||
TODO: Import the boost pretty printers instead of using this custom function.
|
||||
"""
|
||||
if not optional['m_initialized']:
|
||||
if not optional["m_initialized"]:
|
||||
return None
|
||||
value_ref_type = optional.type.template_argument(0).pointer()
|
||||
storage = optional['m_storage']['dummy_']['data']
|
||||
storage = optional["m_storage"]["dummy_"]["data"]
|
||||
return storage.cast(value_ref_type).dereference()
|
||||
|
||||
|
||||
@ -409,7 +423,11 @@ class GetMongoDecoration(gdb.Command):
|
||||
(type_name, obj) = dec
|
||||
print(type_name, obj)
|
||||
else:
|
||||
print("No decoration found whose type name contains '" + type_name_substr + "'.")
|
||||
print(
|
||||
"No decoration found whose type name contains '"
|
||||
+ type_name_substr
|
||||
+ "'."
|
||||
)
|
||||
|
||||
|
||||
# Register command
|
||||
@ -448,21 +466,25 @@ class DumpMongoDSessionCatalog(gdb.Command):
|
||||
)
|
||||
return
|
||||
session_kv_pairs = get_session_kv_pairs()
|
||||
print("Dumping %d Session objects from the SessionCatalog" % len(session_kv_pairs))
|
||||
print(
|
||||
"Dumping %d Session objects from the SessionCatalog" % len(session_kv_pairs)
|
||||
)
|
||||
|
||||
# Optionally search for a specified session, based on its id.
|
||||
if lsid_to_find:
|
||||
print("Only printing information for session " + lsid_to_find + ", if found.")
|
||||
print(
|
||||
"Only printing information for session " + lsid_to_find + ", if found."
|
||||
)
|
||||
lsids_to_print = [lsid_to_find]
|
||||
else:
|
||||
lsids_to_print = [str(s['first']['_id']) for s in session_kv_pairs]
|
||||
lsids_to_print = [str(s["first"]["_id"]) for s in session_kv_pairs]
|
||||
|
||||
for session_kv in session_kv_pairs:
|
||||
# The Session objects are stored inside the SessionRuntimeInfo object.
|
||||
session_runtime_info = get_unique_ptr(session_kv['second']).dereference()
|
||||
parent_session = session_runtime_info['parentSession']
|
||||
child_sessions = absl_get_nodes(session_runtime_info['childSessions'])
|
||||
lsid = str(parent_session['_sessionId']['_id'])
|
||||
session_runtime_info = get_unique_ptr(session_kv["second"]).dereference()
|
||||
parent_session = session_runtime_info["parentSession"]
|
||||
child_sessions = absl_get_nodes(session_runtime_info["childSessions"])
|
||||
lsid = str(parent_session["_sessionId"]["_id"])
|
||||
|
||||
# If we are only interested in a specific session, then we print out the entire Session
|
||||
# objects, to aid more detailed debugging.
|
||||
@ -470,7 +492,7 @@ class DumpMongoDSessionCatalog(gdb.Command):
|
||||
self.dump_session_runtime_info(session_runtime_info)
|
||||
print(parent_session)
|
||||
for child_session_kv in child_sessions:
|
||||
child_session = child_session_kv['second']
|
||||
child_session = child_session_kv["second"]
|
||||
print(child_session)
|
||||
# Terminate if this is the only session we care about.
|
||||
break
|
||||
@ -484,24 +506,27 @@ class DumpMongoDSessionCatalog(gdb.Command):
|
||||
self.dump_session_runtime_info(session_runtime_info)
|
||||
self.dump_session(parent_session)
|
||||
for child_session_kv in child_sessions:
|
||||
child_session = child_session_kv['second']
|
||||
child_session = child_session_kv["second"]
|
||||
self.dump_session(child_session)
|
||||
|
||||
@staticmethod
|
||||
def dump_session_runtime_info(session_runtime_info):
|
||||
"""Dump the session runtime info."""
|
||||
|
||||
parent_session = session_runtime_info['parentSession']
|
||||
parent_session = session_runtime_info["parentSession"]
|
||||
# TODO: Add a custom pretty printer for LogicalSessionId.
|
||||
lsid = str(parent_session['_sessionId']['_id'])[1:-1]
|
||||
lsid = str(parent_session["_sessionId"]["_id"])[1:-1]
|
||||
print("SessionId =", lsid)
|
||||
fields_to_print = ['checkoutOpCtx', 'killsRequested']
|
||||
fields_to_print = ["checkoutOpCtx", "killsRequested"]
|
||||
for field in fields_to_print:
|
||||
# Skip fields that aren't found on the object.
|
||||
if field in get_field_names(session_runtime_info):
|
||||
print(field, "=", session_runtime_info[field])
|
||||
else:
|
||||
print("Could not find field '%s' on the SessionRuntimeInfo object." % field)
|
||||
print(
|
||||
"Could not find field '%s' on the SessionRuntimeInfo object."
|
||||
% field
|
||||
)
|
||||
print("")
|
||||
|
||||
@staticmethod
|
||||
@ -509,13 +534,16 @@ class DumpMongoDSessionCatalog(gdb.Command):
|
||||
"""Dump the session."""
|
||||
|
||||
print("Session (" + str(session.address) + "):")
|
||||
fields_to_print = ['_sessionId', '_numWaitingToCheckOut']
|
||||
fields_to_print = ["_sessionId", "_numWaitingToCheckOut"]
|
||||
for field in fields_to_print:
|
||||
# Skip fields that aren't found on the object.
|
||||
if field in get_field_names(session):
|
||||
print(field, "=", session[field])
|
||||
else:
|
||||
print("Could not find field '%s' on the SessionRuntimeInfo object." % field)
|
||||
print(
|
||||
"Could not find field '%s' on the SessionRuntimeInfo object."
|
||||
% field
|
||||
)
|
||||
|
||||
# Print the information from a TransactionParticipant if a session has one.
|
||||
txn_part_dec = get_decoration(session, "TransactionParticipant")
|
||||
@ -528,28 +556,33 @@ class DumpMongoDSessionCatalog(gdb.Command):
|
||||
# from that object. If, in the future, we want to print fields from the
|
||||
# 'PrivateState' object, we can inspect the TransactionParticipant's '_p' field.
|
||||
txn_part = txn_part_dec[1]
|
||||
txn_part_observable_state = txn_part['_o']
|
||||
fields_to_print = ['txnState', 'activeTxnNumberAndRetryCounter']
|
||||
txn_part_observable_state = txn_part["_o"]
|
||||
fields_to_print = ["txnState", "activeTxnNumberAndRetryCounter"]
|
||||
print("TransactionParticipant (" + str(txn_part.address) + "):")
|
||||
for field in fields_to_print:
|
||||
# Skip fields that aren't found on the object.
|
||||
if field in get_field_names(txn_part_observable_state):
|
||||
print(field, "=", txn_part_observable_state[field])
|
||||
else:
|
||||
print("Could not find field '%s' on the TransactionParticipant" % field)
|
||||
print(
|
||||
"Could not find field '%s' on the TransactionParticipant"
|
||||
% field
|
||||
)
|
||||
|
||||
# The 'txnResourceStash' field is a boost::optional so we unpack it manually if it
|
||||
# is non-empty. We are only interested in its Locker object for now. TODO: Load the
|
||||
# boost pretty printers so the object will be printed clearly by default, without
|
||||
# the need for special unpacking.
|
||||
val = get_boost_optional(txn_part_observable_state['txnResourceStash'])
|
||||
val = get_boost_optional(txn_part_observable_state["txnResourceStash"])
|
||||
if val:
|
||||
locker_addr = get_unique_ptr(val["_locker"])
|
||||
locker_obj = locker_addr.dereference().cast(lookup_type("mongo::Locker"))
|
||||
print('txnResourceStash._locker', "@", locker_addr)
|
||||
locker_obj = locker_addr.dereference().cast(
|
||||
lookup_type("mongo::Locker")
|
||||
)
|
||||
print("txnResourceStash._locker", "@", locker_addr)
|
||||
print("txnResourceStash._locker._id", "=", locker_obj["_id"])
|
||||
else:
|
||||
print('txnResourceStash', "=", None)
|
||||
print("txnResourceStash", "=", None)
|
||||
print("")
|
||||
|
||||
|
||||
@ -581,7 +614,8 @@ class DumpMongoDBMutexes(gdb.Command):
|
||||
|
||||
# Use the STL pretty-printer to iterate over the list
|
||||
printer = stdlib_printers.StdForwardListPrinter( # pylint: disable=undefined-variable
|
||||
str(diagnostic_info_list.type), diagnostic_info_list)
|
||||
str(diagnostic_info_list.type), diagnostic_info_list
|
||||
)
|
||||
|
||||
# Prepare structured output doc
|
||||
client_name = str(client["_desc"])
|
||||
@ -595,7 +629,9 @@ class DumpMongoDBMutexes(gdb.Command):
|
||||
output_doc["mutex"] = str(diagnostic_info["_captureName"])[1:-1]
|
||||
|
||||
millis = int(diagnostic_info["_timestamp"]["millis"])
|
||||
dt = datetime.datetime.fromtimestamp(millis / 1000, tz=datetime.timezone.utc)
|
||||
dt = datetime.datetime.fromtimestamp(
|
||||
millis / 1000, tz=datetime.timezone.utc
|
||||
)
|
||||
output_doc["since"] = dt.isoformat()
|
||||
print(json.dumps(output_doc))
|
||||
|
||||
@ -616,7 +652,7 @@ class MongoDBDumpLocks(gdb.Command):
|
||||
print("Running Hang Analyzer Supplement - MongoDBDumpLocks")
|
||||
|
||||
main_binary_name = get_process_name()
|
||||
if main_binary_name == 'mongod':
|
||||
if main_binary_name == "mongod":
|
||||
self.dump_mongod_locks()
|
||||
else:
|
||||
print("Not invoking mongod lock dump for: %s" % (main_binary_name))
|
||||
@ -628,7 +664,9 @@ class MongoDBDumpLocks(gdb.Command):
|
||||
try:
|
||||
# Call into mongod, and dump the state of lock manager
|
||||
# Note that output will go to mongod's standard output, not the debugger output window
|
||||
gdb.execute("call mongo::dumpLockManager()", from_tty=False, to_string=False)
|
||||
gdb.execute(
|
||||
"call mongo::dumpLockManager()", from_tty=False, to_string=False
|
||||
)
|
||||
except gdb.error as gdberr:
|
||||
print("Ignoring error '%s' in dump_mongod_locks" % str(gdberr))
|
||||
|
||||
@ -642,7 +680,9 @@ class MongoDBDumpRecoveryUnits(gdb.Command):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize MongoDBDumpRecoveryUnits."""
|
||||
RegisterMongoCommand.register(self, "mongodb-dump-recovery-units", gdb.COMMAND_DATA)
|
||||
RegisterMongoCommand.register(
|
||||
self, "mongodb-dump-recovery-units", gdb.COMMAND_DATA
|
||||
)
|
||||
|
||||
def invoke(self, arg, _from_tty):
|
||||
"""Invoke MongoDBDumpRecoveryUnits."""
|
||||
@ -684,12 +724,17 @@ class MongoDBDumpRecoveryUnits(gdb.Command):
|
||||
recovery_unit = None
|
||||
if operation_context_handle:
|
||||
operation_context = operation_context_handle.dereference()
|
||||
recovery_unit_handle = get_unique_ptr(operation_context["_recoveryUnit"])
|
||||
recovery_unit_handle = get_unique_ptr(
|
||||
operation_context["_recoveryUnit"]
|
||||
)
|
||||
# By default, cast the recovery unit as "mongo::WiredTigerRecoveryUnit"
|
||||
recovery_unit = recovery_unit_handle.dereference().cast(
|
||||
lookup_type(recovery_unit_impl_type))
|
||||
lookup_type(recovery_unit_impl_type)
|
||||
)
|
||||
|
||||
output_doc["recoveryUnit"] = hex(recovery_unit_handle) if recovery_unit else "0x0"
|
||||
output_doc["recoveryUnit"] = (
|
||||
hex(recovery_unit_handle) if recovery_unit else "0x0"
|
||||
)
|
||||
wt_session = get_wt_session(recovery_unit, recovery_unit_impl_type)
|
||||
if wt_session:
|
||||
output_doc["WT_SESSION"] = hex(wt_session)
|
||||
@ -700,14 +745,18 @@ class MongoDBDumpRecoveryUnits(gdb.Command):
|
||||
# Dump stashed recovery unit info for each session in a mongod process
|
||||
for session_kv in get_session_kv_pairs():
|
||||
# The Session objects are stored inside the SessionRuntimeInfo object.
|
||||
session_runtime_info = get_unique_ptr(session_kv['second']).dereference()
|
||||
parent_session = session_runtime_info['parentSession']
|
||||
child_sessions = absl_get_nodes(session_runtime_info['childSessions'])
|
||||
session_runtime_info = get_unique_ptr(session_kv["second"]).dereference()
|
||||
parent_session = session_runtime_info["parentSession"]
|
||||
child_sessions = absl_get_nodes(session_runtime_info["childSessions"])
|
||||
|
||||
MongoDBDumpRecoveryUnits.dump_session(parent_session, recovery_unit_impl_type)
|
||||
MongoDBDumpRecoveryUnits.dump_session(
|
||||
parent_session, recovery_unit_impl_type
|
||||
)
|
||||
for child_session_kv in child_sessions:
|
||||
child_session = child_session_kv['second']
|
||||
MongoDBDumpRecoveryUnits.dump_session(child_session, recovery_unit_impl_type)
|
||||
child_session = child_session_kv["second"]
|
||||
MongoDBDumpRecoveryUnits.dump_session(
|
||||
child_session, recovery_unit_impl_type
|
||||
)
|
||||
|
||||
if enabled_at_start:
|
||||
gdb.execute("set print static-members on")
|
||||
@ -726,15 +775,21 @@ class MongoDBDumpRecoveryUnits(gdb.Command):
|
||||
if txn_participant_dec:
|
||||
txn_participant_observable_state = txn_participant_dec[1]["_o"]
|
||||
txn_resource_stash = get_boost_optional(
|
||||
txn_participant_observable_state["txnResourceStash"])
|
||||
txn_participant_observable_state["txnResourceStash"]
|
||||
)
|
||||
if txn_resource_stash:
|
||||
output_doc["txnResourceStash"] = str(txn_resource_stash.address)
|
||||
recovery_unit_handle = get_unique_ptr(txn_resource_stash["_recoveryUnit"])
|
||||
recovery_unit_handle = get_unique_ptr(
|
||||
txn_resource_stash["_recoveryUnit"]
|
||||
)
|
||||
# By default, cast the recovery unit as "mongo::WiredTigerRecoveryUnit"
|
||||
recovery_unit = recovery_unit_handle.dereference().cast(
|
||||
lookup_type(recovery_unit_impl_type))
|
||||
lookup_type(recovery_unit_impl_type)
|
||||
)
|
||||
|
||||
output_doc["recoveryUnit"] = hex(recovery_unit_handle) if recovery_unit else "0x0"
|
||||
output_doc["recoveryUnit"] = (
|
||||
hex(recovery_unit_handle) if recovery_unit else "0x0"
|
||||
)
|
||||
wt_session = get_wt_session(recovery_unit, recovery_unit_impl_type)
|
||||
if wt_session:
|
||||
output_doc["WT_SESSION"] = hex(wt_session)
|
||||
@ -754,17 +809,22 @@ class MongoDBDumpStorageEngineInfo(gdb.Command):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize MongoDBDumpStorageEngineInfo."""
|
||||
RegisterMongoCommand.register(self, "mongodb-dump-storage-engine-info", gdb.COMMAND_DATA)
|
||||
RegisterMongoCommand.register(
|
||||
self, "mongodb-dump-storage-engine-info", gdb.COMMAND_DATA
|
||||
)
|
||||
|
||||
def invoke(self, arg, _from_tty): # pylint: disable=unused-argument
|
||||
"""Invoke MongoDBDumpStorageEngineInfo."""
|
||||
print("Running Hang Analyzer Supplement - MongoDBDumpStorageEngineInfo")
|
||||
|
||||
main_binary_name = get_process_name()
|
||||
if main_binary_name == 'mongod':
|
||||
if main_binary_name == "mongod":
|
||||
self.dump_mongod_storage_engine_info()
|
||||
else:
|
||||
print("Not invoking mongod storage engine info dump for: %s" % (main_binary_name))
|
||||
print(
|
||||
"Not invoking mongod storage engine info dump for: %s"
|
||||
% (main_binary_name)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def dump_mongod_storage_engine_info():
|
||||
@ -775,9 +835,13 @@ class MongoDBDumpStorageEngineInfo(gdb.Command):
|
||||
# Note that output will go to mongod's standard output, not the debugger output window
|
||||
gdb.execute(
|
||||
"call mongo::getGlobalServiceContext()->_storageEngine._ptr._value._M_b._M_p->dump()",
|
||||
from_tty=False, to_string=False)
|
||||
from_tty=False,
|
||||
to_string=False,
|
||||
)
|
||||
except gdb.error as gdberr:
|
||||
print("Ignoring error '%s' in dump_mongod_storage_engine_info" % str(gdberr))
|
||||
print(
|
||||
"Ignoring error '%s' in dump_mongod_storage_engine_info" % str(gdberr)
|
||||
)
|
||||
|
||||
|
||||
# Register command
|
||||
@ -794,7 +858,9 @@ class BtIfActive(gdb.Command):
|
||||
def invoke(self, arg, _from_tty): # pylint: disable=unused-argument
|
||||
"""Invoke GDB to print stack trace."""
|
||||
try:
|
||||
idle_location = gdb.parse_and_eval("mongo::for_debuggers::idleThreadLocation")
|
||||
idle_location = gdb.parse_and_eval(
|
||||
"mongo::for_debuggers::idleThreadLocation"
|
||||
)
|
||||
except gdb.error:
|
||||
idle_location = None # If unsure, print a stack trace.
|
||||
|
||||
@ -821,7 +887,7 @@ class MongoDBUniqueStack(gdb.Command):
|
||||
"""Invoke GDB to dump stacks."""
|
||||
stacks = {}
|
||||
if not arg:
|
||||
arg = 'bt' # default to 'bt'
|
||||
arg = "bt" # default to 'bt'
|
||||
|
||||
current_thread = gdb.selected_thread()
|
||||
try:
|
||||
@ -839,24 +905,28 @@ class MongoDBUniqueStack(gdb.Command):
|
||||
def _process_thread_stack(arg, stacks, thread):
|
||||
"""Process the thread stack."""
|
||||
thread_info = {} # thread dict to hold per thread data
|
||||
thread_info['pthread'] = get_thread_id()
|
||||
thread_info['gdb_thread_num'] = thread.num
|
||||
thread_info['lwpid'] = thread.ptid[1]
|
||||
thread_info['name'] = get_current_thread_name()
|
||||
thread_info["pthread"] = get_thread_id()
|
||||
thread_info["gdb_thread_num"] = thread.num
|
||||
thread_info["lwpid"] = thread.ptid[1]
|
||||
thread_info["name"] = get_current_thread_name()
|
||||
|
||||
if sys.platform.startswith("linux"):
|
||||
header_format = "Thread {gdb_thread_num}: {name} (Thread 0x{pthread:x} (LWP {lwpid}))"
|
||||
header_format = (
|
||||
"Thread {gdb_thread_num}: {name} (Thread 0x{pthread:x} (LWP {lwpid}))"
|
||||
)
|
||||
elif sys.platform.startswith("sunos"):
|
||||
(_, _, thread_tid) = thread.ptid
|
||||
if thread_tid != 0 and thread_info['lwpid'] != 0:
|
||||
header_format = "Thread {gdb_thread_num}: {name} (Thread {pthread} (LWP {lwpid}))"
|
||||
elif thread_info['lwpid'] != 0:
|
||||
if thread_tid != 0 and thread_info["lwpid"] != 0:
|
||||
header_format = (
|
||||
"Thread {gdb_thread_num}: {name} (Thread {pthread} (LWP {lwpid}))"
|
||||
)
|
||||
elif thread_info["lwpid"] != 0:
|
||||
header_format = "Thread {gdb_thread_num}: {name} (LWP {lwpid})"
|
||||
else:
|
||||
header_format = "Thread {gdb_thread_num}: {name} (Thread {pthread})"
|
||||
else:
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
thread_info['header'] = header_format.format(**thread_info)
|
||||
thread_info["header"] = header_format.format(**thread_info)
|
||||
|
||||
addrs = [] # list of return addresses from frames
|
||||
frame = gdb.newest_frame()
|
||||
@ -865,17 +935,17 @@ class MongoDBUniqueStack(gdb.Command):
|
||||
try:
|
||||
frame = frame.older()
|
||||
except gdb.error as err:
|
||||
print("{} {}".format(thread_info['header'], err))
|
||||
print("{} {}".format(thread_info["header"], err))
|
||||
break
|
||||
addrs_tuple = tuple(addrs) # tuples are hashable, lists aren't.
|
||||
|
||||
unique = stacks.setdefault(addrs_tuple, {'threads': []})
|
||||
unique['threads'].append(thread_info)
|
||||
if 'output' not in unique:
|
||||
unique = stacks.setdefault(addrs_tuple, {"threads": []})
|
||||
unique["threads"].append(thread_info)
|
||||
if "output" not in unique:
|
||||
try:
|
||||
unique['output'] = gdb.execute(arg, to_string=True).rstrip()
|
||||
unique["output"] = gdb.execute(arg, to_string=True).rstrip()
|
||||
except gdb.error as err:
|
||||
print("{} {}".format(thread_info['header'], err))
|
||||
print("{} {}".format(thread_info["header"], err))
|
||||
|
||||
@staticmethod
|
||||
def _dump_unique_stacks(stacks):
|
||||
@ -883,13 +953,13 @@ class MongoDBUniqueStack(gdb.Command):
|
||||
|
||||
def first_tid(stack):
|
||||
"""Return the first tid."""
|
||||
return stack['threads'][0]['gdb_thread_num']
|
||||
return stack["threads"][0]["gdb_thread_num"]
|
||||
|
||||
for stack in sorted(list(stacks.values()), key=first_tid, reverse=True):
|
||||
for i, thread in enumerate(stack['threads']):
|
||||
prefix = '' if i == 0 else 'Duplicate '
|
||||
print(prefix + thread['header'])
|
||||
print(stack['output'])
|
||||
for i, thread in enumerate(stack["threads"]):
|
||||
prefix = "" if i == 0 else "Duplicate "
|
||||
print(prefix + thread["header"])
|
||||
print(stack["output"])
|
||||
print() # leave extra blank line after each thread stack
|
||||
|
||||
|
||||
@ -910,14 +980,16 @@ class MongoDBJavaScriptStack(gdb.Command):
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize MongoDBJavaScriptStack."""
|
||||
RegisterMongoCommand.register(self, "mongodb-javascript-stack", gdb.COMMAND_STATUS)
|
||||
RegisterMongoCommand.register(
|
||||
self, "mongodb-javascript-stack", gdb.COMMAND_STATUS
|
||||
)
|
||||
|
||||
def invoke(self, arg, _from_tty): # pylint: disable=unused-argument
|
||||
"""Invoke GDB to dump JS stacks."""
|
||||
print("Running Print JavaScript Stack Supplement")
|
||||
|
||||
main_binary_name = get_process_name()
|
||||
if main_binary_name.endswith('mongod') or main_binary_name.endswith('mongo'):
|
||||
if main_binary_name.endswith("mongod") or main_binary_name.endswith("mongo"):
|
||||
self.javascript_stack()
|
||||
else:
|
||||
print("No JavaScript stack print done for: %s" % (main_binary_name))
|
||||
@ -930,13 +1002,14 @@ class MongoDBJavaScriptStack(gdb.Command):
|
||||
# `'_M_b' in atomic_scope`, so exceptions for flow control it is. :|
|
||||
try:
|
||||
# reach into std::atomic and grab the pointer. This is for libc++
|
||||
return atomic_scope['_M_b']['_M_p']
|
||||
return atomic_scope["_M_b"]["_M_p"]
|
||||
except gdb.error:
|
||||
# Worst case scenario: try and use .load(), but it's probably
|
||||
# inlined. parse_and_eval required because you can't call methods
|
||||
# in gdb on the Python API
|
||||
return gdb.parse_and_eval(
|
||||
f"((std::atomic<mongo::mozjs::MozJSImplScope*> *)({atomic_scope.address}))->load()")
|
||||
f"((std::atomic<mongo::mozjs::MozJSImplScope*> *)({atomic_scope.address}))->load()"
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@ -973,16 +1046,18 @@ class MongoDBJavaScriptStack(gdb.Command):
|
||||
continue
|
||||
|
||||
scope = ptr.dereference()
|
||||
if scope['_inOp'] == 0:
|
||||
if scope["_inOp"] == 0:
|
||||
continue
|
||||
|
||||
gdb.execute('thread', from_tty=False, to_string=False)
|
||||
gdb.execute("thread", from_tty=False, to_string=False)
|
||||
# gdb continues to not support calling methods through Python,
|
||||
# so work around it by casting the raw ptr back to its type,
|
||||
# and calling the method through execute darkness
|
||||
gdb.execute(
|
||||
f'printf "%s\\n", ((mongo::mozjs::MozJSImplScope*)({ptr}))->buildStackString().c_str()',
|
||||
from_tty=False, to_string=False)
|
||||
from_tty=False,
|
||||
to_string=False,
|
||||
)
|
||||
|
||||
except gdb.error as err:
|
||||
print("Ignoring GDB error '%s' in javascript_stack" % str(err))
|
||||
@ -1002,7 +1077,7 @@ class MongoDBPPrintBsonAtPointer(gdb.Command):
|
||||
|
||||
def invoke(self, args, _from_tty):
|
||||
"""Invoke."""
|
||||
args = args.split(' ')
|
||||
args = args.split(" ")
|
||||
if len(args) == 0 or (len(args) == 1 and len(args[0]) == 0):
|
||||
print("Usage: mongodb-pprint-bson <ptr> <optional length>")
|
||||
return
|
||||
@ -1017,6 +1092,7 @@ class MongoDBPPrintBsonAtPointer(gdb.Command):
|
||||
bsonobj = next(bson.decode_iter(memory))
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
pprint(bsonobj)
|
||||
|
||||
|
||||
|
||||
@ -19,7 +19,8 @@ if not gdb:
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
raise gdb.GdbError(
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2")
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2"
|
||||
)
|
||||
|
||||
|
||||
class NonExecutingThread(object):
|
||||
@ -69,7 +70,9 @@ class Thread(object):
|
||||
return not self == other
|
||||
|
||||
def __str__(self):
|
||||
return "{} (Thread 0x{:012x} (LWP {}))".format(self.name, self.thread_id, self.lwpid)
|
||||
return "{} (Thread 0x{:012x} (LWP {}))".format(
|
||||
self.name, self.thread_id, self.lwpid
|
||||
)
|
||||
|
||||
def key(self):
|
||||
"""Return thread key."""
|
||||
@ -125,7 +128,7 @@ class Graph(object):
|
||||
def add_node(self, node):
|
||||
"""Add node to graph."""
|
||||
if not self.find_node(node):
|
||||
self.nodes[node.key()] = {'node': node, 'next_nodes': []}
|
||||
self.nodes[node.key()] = {"node": node, "next_nodes": []}
|
||||
|
||||
def find_node(self, node):
|
||||
"""Find node in graph."""
|
||||
@ -137,8 +140,8 @@ class Graph(object):
|
||||
"""Find from node."""
|
||||
for node_key in self.nodes:
|
||||
node = self.nodes[node_key]
|
||||
for next_node in node['next_nodes']:
|
||||
if next_node == from_node['node'].key():
|
||||
for next_node in node["next_nodes"]:
|
||||
if next_node == from_node["node"].key():
|
||||
return node
|
||||
return None
|
||||
|
||||
@ -148,7 +151,7 @@ class Graph(object):
|
||||
temp_nodes = {}
|
||||
for node_key in self.nodes:
|
||||
node = self.nodes[node_key]
|
||||
if node['next_nodes'] or self.find_from_node(node) is not None:
|
||||
if node["next_nodes"] or self.find_from_node(node) is not None:
|
||||
temp_nodes[node_key] = self.nodes[node_key]
|
||||
self.nodes = temp_nodes
|
||||
|
||||
@ -164,50 +167,62 @@ class Graph(object):
|
||||
self.add_node(to_node)
|
||||
t_node = self.nodes[to_node.key()]
|
||||
|
||||
for n_node in f_node['next_nodes']:
|
||||
for n_node in f_node["next_nodes"]:
|
||||
if n_node == to_node.key():
|
||||
return
|
||||
self.nodes[from_node.key()]['next_nodes'].append(to_node.key())
|
||||
self.nodes[from_node.key()]["next_nodes"].append(to_node.key())
|
||||
|
||||
def print(self):
|
||||
"""Print graph."""
|
||||
for node_key in self.nodes:
|
||||
print("Node", self.nodes[node_key]['node'])
|
||||
for to_node in self.nodes[node_key]['next_nodes']:
|
||||
print(" ->", self.nodes[to_node]['node'])
|
||||
print("Node", self.nodes[node_key]["node"])
|
||||
for to_node in self.nodes[node_key]["next_nodes"]:
|
||||
print(" ->", self.nodes[to_node]["node"])
|
||||
|
||||
def _get_node_escaped(self, node_key):
|
||||
"""Return the name of the node with any double quotes escaped.
|
||||
|
||||
The DOT language requires that literal double quotes be escaped using a backslash character.
|
||||
"""
|
||||
return str(self.nodes[node_key]['node']).replace('"', '\\"')
|
||||
return str(self.nodes[node_key]["node"]).replace('"', '\\"')
|
||||
|
||||
def to_graph(self, nodes=None, message=None):
|
||||
"""Return the 'to_graph'."""
|
||||
sb = []
|
||||
sb.append('# Legend:')
|
||||
sb.append('# Thread 1 -> Lock C (MODE_IX) indicates Thread 1 is waiting on Lock C and'
|
||||
' Lock C is currently held in MODE_IX')
|
||||
sb.append('# Lock C (MODE_IX) -> Thread 2 indicates Lock C is held by Thread 2 in'
|
||||
' MODE_IX')
|
||||
sb.append("# Legend:")
|
||||
sb.append(
|
||||
"# Thread 1 -> Lock C (MODE_IX) indicates Thread 1 is waiting on Lock C and"
|
||||
" Lock C is currently held in MODE_IX"
|
||||
)
|
||||
sb.append(
|
||||
"# Lock C (MODE_IX) -> Thread 2 indicates Lock C is held by Thread 2 in"
|
||||
" MODE_IX"
|
||||
)
|
||||
if message is not None:
|
||||
sb.append(message)
|
||||
sb.append('digraph "mongod+lock-status" {')
|
||||
# Draw the graph from left to right. There can be hundreds of threads blocked by the same
|
||||
# resource, but only a few resources involved in a deadlock, so we prefer a long graph
|
||||
# than a super wide one. Long resource / thread names would make a wide graph even wider.
|
||||
sb.append(' rankdir=LR;')
|
||||
sb.append(" rankdir=LR;")
|
||||
for node_key in self.nodes:
|
||||
for next_node_key in self.nodes[node_key]['next_nodes']:
|
||||
sb.append(' "{}" -> "{}";'.format(
|
||||
self._get_node_escaped(node_key), self._get_node_escaped(next_node_key)))
|
||||
for next_node_key in self.nodes[node_key]["next_nodes"]:
|
||||
sb.append(
|
||||
' "{}" -> "{}";'.format(
|
||||
self._get_node_escaped(node_key),
|
||||
self._get_node_escaped(next_node_key),
|
||||
)
|
||||
)
|
||||
for node_key in self.nodes:
|
||||
color = ""
|
||||
if nodes and node_key in nodes:
|
||||
color = "color = red"
|
||||
|
||||
sb.append(' "{0}" [label="{0}" {1}]'.format(self._get_node_escaped(node_key), color))
|
||||
sb.append(
|
||||
' "{0}" [label="{0}" {1}]'.format(
|
||||
self._get_node_escaped(node_key), color
|
||||
)
|
||||
)
|
||||
sb.append("}")
|
||||
return "\n".join(sb)
|
||||
|
||||
@ -221,10 +236,10 @@ class Graph(object):
|
||||
nodes_in_cycle = []
|
||||
nodes_visited.add(node_key)
|
||||
nodes_in_cycle.append(node_key)
|
||||
for node in self.nodes[node_key]['next_nodes']:
|
||||
for node in self.nodes[node_key]["next_nodes"]:
|
||||
if node in nodes_in_cycle:
|
||||
# The graph cycle starts at the index of node in nodes_in_cycle.
|
||||
return nodes_in_cycle[nodes_in_cycle.index(node):]
|
||||
return nodes_in_cycle[nodes_in_cycle.index(node) :]
|
||||
if node not in nodes_visited:
|
||||
dfs_nodes = self.depth_first_search(node, nodes_visited, nodes_in_cycle)
|
||||
if dfs_nodes:
|
||||
@ -241,13 +256,15 @@ class Graph(object):
|
||||
if node not in nodes_visited:
|
||||
cycle_path = self.depth_first_search(node, nodes_visited)
|
||||
if cycle_path:
|
||||
return [str(self.nodes[node_key]['node']) for node_key in cycle_path]
|
||||
return [
|
||||
str(self.nodes[node_key]["node"]) for node_key in cycle_path
|
||||
]
|
||||
return None
|
||||
|
||||
|
||||
def find_thread(thread_dict, search_thread_id):
|
||||
"""Find thread."""
|
||||
for (_, thread) in list(thread_dict.items()):
|
||||
for _, thread in list(thread_dict.items()):
|
||||
if thread.thread_id == search_thread_id:
|
||||
return thread
|
||||
return None
|
||||
@ -286,7 +303,7 @@ def find_frame(function_name_pattern):
|
||||
|
||||
def find_mutex_holder(graph, thread_dict, show):
|
||||
"""Find mutex holder."""
|
||||
frame = find_frame(r'std::mutex::lock\(\)')
|
||||
frame = find_frame(r"std::mutex::lock\(\)")
|
||||
if frame is None:
|
||||
return
|
||||
|
||||
@ -301,9 +318,12 @@ def find_mutex_holder(graph, thread_dict, show):
|
||||
# At time thread_dict was initialized, the mutex holder may not have been found.
|
||||
# Use the thread LWP as a substitute for showing output or generating the graph.
|
||||
if mutex_holder_lwpid not in thread_dict:
|
||||
print("Warning: Mutex at {} held by thread with LWP {}"
|
||||
" not found in thread_dict. Using LWP to track thread.".format(
|
||||
mutex_value, mutex_holder_lwpid))
|
||||
print(
|
||||
"Warning: Mutex at {} held by thread with LWP {}"
|
||||
" not found in thread_dict. Using LWP to track thread.".format(
|
||||
mutex_value, mutex_holder_lwpid
|
||||
)
|
||||
)
|
||||
mutex_holder = Thread(mutex_holder_lwpid, mutex_holder_lwpid, '"[unknown]"')
|
||||
else:
|
||||
mutex_holder = thread_dict[mutex_holder_lwpid]
|
||||
@ -311,8 +331,11 @@ def find_mutex_holder(graph, thread_dict, show):
|
||||
(_, mutex_waiter_lwpid, _) = gdb.selected_thread().ptid
|
||||
mutex_waiter = thread_dict[mutex_waiter_lwpid]
|
||||
if show:
|
||||
print("Mutex at {} held by {} waited on by {}".format(mutex_value, mutex_holder,
|
||||
mutex_waiter))
|
||||
print(
|
||||
"Mutex at {} held by {} waited on by {}".format(
|
||||
mutex_value, mutex_holder, mutex_waiter
|
||||
)
|
||||
)
|
||||
if graph:
|
||||
graph.add_edge(mutex_waiter, Lock(int(mutex_value), "Mutex"))
|
||||
graph.add_edge(Lock(int(mutex_value), "Mutex"), mutex_holder)
|
||||
@ -320,7 +343,7 @@ def find_mutex_holder(graph, thread_dict, show):
|
||||
|
||||
def find_lock_manager_holders(graph, thread_dict, show):
|
||||
"""Find lock manager holders."""
|
||||
frame = find_frame(r'mongo::Locker::')
|
||||
frame = find_frame(r"mongo::Locker::")
|
||||
if not frame:
|
||||
return
|
||||
|
||||
@ -349,8 +372,11 @@ def find_lock_manager_holders(graph, thread_dict, show):
|
||||
else:
|
||||
lock_holder = find_thread(thread_dict, lock_holder_id)
|
||||
if show:
|
||||
print("MongoDB Lock at {} held by {} ({}) waited on by {}".format(
|
||||
lock_head, lock_holder, lock_request["mode"], lock_waiter))
|
||||
print(
|
||||
"MongoDB Lock at {} held by {} ({}) waited on by {}".format(
|
||||
lock_head, lock_holder, lock_request["mode"], lock_waiter
|
||||
)
|
||||
)
|
||||
if graph:
|
||||
graph.add_edge(lock_waiter, Lock(int(lock_head), lock_request["mode"]))
|
||||
graph.add_edge(Lock(int(lock_head), lock_request["mode"]), lock_holder)
|
||||
@ -446,7 +472,7 @@ class MongoDBWaitsForGraph(gdb.Command):
|
||||
cycle_message = "# Cycle detected in the graph nodes %s" % cycle_nodes
|
||||
if graph_file:
|
||||
print("Saving digraph to %s" % graph_file)
|
||||
with open(graph_file, 'w') as fh:
|
||||
with open(graph_file, "w") as fh:
|
||||
fh.write(graph.to_graph(nodes=cycle_nodes, message=cycle_message))
|
||||
print(cycle_message.split("# ")[1])
|
||||
else:
|
||||
|
||||
@ -39,7 +39,8 @@ except ImportError:
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
raise gdb.GdbError(
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2")
|
||||
"MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2"
|
||||
)
|
||||
|
||||
|
||||
def get_unique_ptr_bytes(obj):
|
||||
@ -50,7 +51,9 @@ def get_unique_ptr_bytes(obj):
|
||||
mongo::Decorable<> types which store the decorations as a slab of memory with
|
||||
std::unique_ptr<unsigned char[]>. In all other cases get_unique_ptr() can be preferred.
|
||||
"""
|
||||
return obj.cast(gdb.lookup_type('std::_Head_base<0, unsigned char*, false>'))['_M_head_impl']
|
||||
return obj.cast(gdb.lookup_type("std::_Head_base<0, unsigned char*, false>"))[
|
||||
"_M_head_impl"
|
||||
]
|
||||
|
||||
|
||||
def get_unique_ptr(obj):
|
||||
@ -71,20 +74,20 @@ class StatusPrinter(object):
|
||||
@staticmethod
|
||||
def extract_error(val):
|
||||
"""Extract the error object (if any) from a Status/StatusWith."""
|
||||
error = val['_error']
|
||||
if 'px' in error.type.iterkeys():
|
||||
return error['px']
|
||||
error = val["_error"]
|
||||
if "px" in error.type.iterkeys():
|
||||
return error["px"]
|
||||
return error
|
||||
|
||||
@staticmethod
|
||||
def generate_error_details(error):
|
||||
"""Generate a (code,reason) tuple from a Status/StatusWith error object."""
|
||||
info = error.dereference()
|
||||
code = info['code']
|
||||
code = info["code"]
|
||||
# Remove the mongo::ErrorCodes:: prefix. Does nothing if not a real ErrorCode.
|
||||
code = str(code).split('::')[-1]
|
||||
code = str(code).split("::")[-1]
|
||||
|
||||
return (code, info['reason'])
|
||||
return (code, info["reason"])
|
||||
|
||||
def __init__(self, val):
|
||||
"""Initialize StatusPrinter."""
|
||||
@ -94,8 +97,8 @@ class StatusPrinter(object):
|
||||
"""Return status for printing."""
|
||||
error = StatusPrinter.extract_error(self.val)
|
||||
if not error:
|
||||
return 'Status::OK()'
|
||||
return 'Status(%s, %s)' % StatusPrinter.generate_error_details(error)
|
||||
return "Status::OK()"
|
||||
return "Status(%s, %s)" % StatusPrinter.generate_error_details(error)
|
||||
|
||||
|
||||
class StatusWithPrinter(object):
|
||||
@ -107,10 +110,10 @@ class StatusWithPrinter(object):
|
||||
|
||||
def to_string(self):
|
||||
"""Return status for printing."""
|
||||
error = StatusPrinter.extract_error(self.val['_status'])
|
||||
error = StatusPrinter.extract_error(self.val["_status"])
|
||||
if not error:
|
||||
return 'StatusWith(OK, %s)' % (self.val['_t'])
|
||||
return 'StatusWith(%s, %s)' % StatusPrinter.generate_error_details(error)
|
||||
return "StatusWith(OK, %s)" % (self.val["_t"])
|
||||
return "StatusWith(%s, %s)" % StatusPrinter.generate_error_details(error)
|
||||
|
||||
|
||||
class StringDataPrinter(object):
|
||||
@ -123,20 +126,20 @@ class StringDataPrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'string'
|
||||
return "string"
|
||||
|
||||
def to_string(self):
|
||||
"""Return data for printing."""
|
||||
# As of SERVER-82604, StringData is based on std::string_view, so try with that first
|
||||
sv = self.val['_sv']
|
||||
sv = self.val["_sv"]
|
||||
if sv is not None:
|
||||
return sv
|
||||
|
||||
# ... back-off to the legacy format otherwise
|
||||
size = self.val['_size']
|
||||
size = self.val["_size"]
|
||||
if size == -1:
|
||||
return self.val['_data'].lazy_string()
|
||||
return self.val['_data'].lazy_string(length=size)
|
||||
return self.val["_data"].lazy_string()
|
||||
return self.val["_data"].lazy_string(length=size)
|
||||
|
||||
|
||||
class BoostOptionalPrinter(object):
|
||||
@ -157,7 +160,7 @@ class BSONObjPrinter(object):
|
||||
def __init__(self, val):
|
||||
"""Initialize BSONObjPrinter."""
|
||||
self.val = val
|
||||
self.ptr = self.val['_objdata'].cast(lookup_type('void').pointer())
|
||||
self.ptr = self.val["_objdata"].cast(lookup_type("void").pointer())
|
||||
self.is_valid = False
|
||||
|
||||
# Handle the endianness of the BSON object size, which is represented as a 32-bit integer
|
||||
@ -168,29 +171,36 @@ class BSONObjPrinter(object):
|
||||
self.size = -1
|
||||
self.raw_memory = None
|
||||
else:
|
||||
self.size = struct.unpack('<I', inferior.read_memory(self.ptr, 4))[0]
|
||||
self.raw_memory = bytes(memoryview(inferior.read_memory(self.ptr, self.size)))
|
||||
self.size = struct.unpack("<I", inferior.read_memory(self.ptr, 4))[0]
|
||||
self.raw_memory = bytes(
|
||||
memoryview(inferior.read_memory(self.ptr, self.size))
|
||||
)
|
||||
if bson:
|
||||
self.is_valid = bson.is_valid(self.raw_memory)
|
||||
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'map'
|
||||
return "map"
|
||||
|
||||
def children(self):
|
||||
"""Children."""
|
||||
# Do not decode a BSONObj with an invalid size, or that is considered
|
||||
# invalid by pymongo.
|
||||
if not bson or not self.is_valid or self.size < 5 or self.size > 17 * 1024 * 1024:
|
||||
if (
|
||||
not bson
|
||||
or not self.is_valid
|
||||
or self.size < 5
|
||||
or self.size > 17 * 1024 * 1024
|
||||
):
|
||||
return
|
||||
|
||||
options = CodecOptions(document_class=collections.OrderedDict)
|
||||
bsondoc = bson.decode(self.raw_memory, codec_options=options)
|
||||
|
||||
for key, val in list(bsondoc.items()):
|
||||
yield 'key', key
|
||||
yield 'value', bson.json_util.dumps(val)
|
||||
yield "key", key
|
||||
yield "value", bson.json_util.dumps(val)
|
||||
|
||||
def to_string(self):
|
||||
"""Return BSONObj for printing."""
|
||||
@ -198,7 +208,11 @@ class BSONObjPrinter(object):
|
||||
if self.size == -1:
|
||||
return "BSONObj @ %s - optimized out" % (self.ptr)
|
||||
|
||||
ownership = "owned" if self.val['_ownedBuffer']['_buffer']['_holder']['px'] else "unowned"
|
||||
ownership = (
|
||||
"owned"
|
||||
if self.val["_ownedBuffer"]["_buffer"]["_holder"]["px"]
|
||||
else "unowned"
|
||||
)
|
||||
|
||||
size = self.size
|
||||
# Print an invalid BSONObj size in hex.
|
||||
@ -232,12 +246,17 @@ class OplogEntryPrinter(object):
|
||||
|
||||
def to_string(self):
|
||||
"""Return OplogEntry for printing."""
|
||||
optime = self.val['_entry']['_opTimeBase']
|
||||
optime_str = "ts(%s, %s)" % (optime['_timestamp']['secs'], optime['_timestamp']['i'])
|
||||
optime = self.val["_entry"]["_opTimeBase"]
|
||||
optime_str = "ts(%s, %s)" % (
|
||||
optime["_timestamp"]["secs"],
|
||||
optime["_timestamp"]["i"],
|
||||
)
|
||||
return "OplogEntry(%s, %s, %s, %s)" % (
|
||||
str(self.val['_entry']['_durableReplOperation']['_opType']).split('::')[-1],
|
||||
str(self.val['_entry']['_commandType']).split('::')[-1],
|
||||
self.val['_entry']['_durableReplOperation']['_nss'], optime_str)
|
||||
str(self.val["_entry"]["_durableReplOperation"]["_opType"]).split("::")[-1],
|
||||
str(self.val["_entry"]["_commandType"]).split("::")[-1],
|
||||
self.val["_entry"]["_durableReplOperation"]["_nss"],
|
||||
optime_str,
|
||||
)
|
||||
|
||||
|
||||
class UUIDPrinter(object):
|
||||
@ -250,11 +269,11 @@ class UUIDPrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'string'
|
||||
return "string"
|
||||
|
||||
def to_string(self):
|
||||
"""Return UUID for printing."""
|
||||
raw_bytes = [self.val['_uuid']['_M_elems'][i] for i in range(16)]
|
||||
raw_bytes = [self.val["_uuid"]["_M_elems"][i] for i in range(16)]
|
||||
uuid_hex_bytes = [hex(int(b))[2:].zfill(2) for b in raw_bytes]
|
||||
return str(uuid.UUID("".join(uuid_hex_bytes)))
|
||||
|
||||
@ -269,11 +288,11 @@ class OIDPrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'string'
|
||||
return "string"
|
||||
|
||||
def to_string(self):
|
||||
"""Return OID for printing."""
|
||||
raw_bytes = [int(self.val['_data'][i]) for i in range(OBJECT_ID_WIDTH)]
|
||||
raw_bytes = [int(self.val["_data"][i]) for i in range(OBJECT_ID_WIDTH)]
|
||||
oid_hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes]
|
||||
return "ObjectID('%s')" % "".join(oid_hex_bytes)
|
||||
|
||||
@ -288,12 +307,12 @@ class RecordIdPrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'string'
|
||||
return "string"
|
||||
|
||||
## Get the address at given offset of data as the selected pointer type
|
||||
def __get_data_address(self, ptr, offset):
|
||||
ptr_type = gdb.lookup_type(ptr).pointer()
|
||||
return self.val['_data']['_M_elems'][offset].address.cast(ptr_type)
|
||||
return self.val["_data"]["_M_elems"][offset].address.cast(ptr_type)
|
||||
|
||||
def to_string(self):
|
||||
"""Return RecordId for printing."""
|
||||
@ -301,27 +320,39 @@ class RecordIdPrinter(object):
|
||||
if rid_format == 0:
|
||||
return "null RecordId"
|
||||
elif rid_format == 1:
|
||||
koffset = 8 - 1 ## std::alignment_of_v<int64_t> - sizeof(Format); (see record_id.h)
|
||||
rid_address = self.__get_data_address('int64_t', koffset)
|
||||
koffset = (
|
||||
8 - 1
|
||||
) ## std::alignment_of_v<int64_t> - sizeof(Format); (see record_id.h)
|
||||
rid_address = self.__get_data_address("int64_t", koffset)
|
||||
return "RecordId long: %d" % int(rid_address.dereference())
|
||||
elif rid_format == 2:
|
||||
str_len = self.__get_data_address('int8_t', 0).dereference()
|
||||
array_address = self.__get_data_address('int8_t', 1)
|
||||
str_len = self.__get_data_address("int8_t", 0).dereference()
|
||||
array_address = self.__get_data_address("int8_t", 1)
|
||||
raw_bytes = [array_address[i] for i in range(0, str_len)]
|
||||
hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes]
|
||||
return "RecordId small string %d hex bytes: %s" % (str_len, str("".join(hex_bytes)))
|
||||
return "RecordId small string %d hex bytes: %s" % (
|
||||
str_len,
|
||||
str("".join(hex_bytes)),
|
||||
)
|
||||
elif rid_format == 3:
|
||||
koffset = 8 - 1 ## std::alignment_of_v<ConstSharedBuffer> - sizeof(Format); (see record_id.h)
|
||||
buffer = self.__get_data_address('mongo::ConstSharedBuffer', koffset).dereference()
|
||||
holder_ptr = holder = buffer['_buffer']["_holder"]["px"]
|
||||
koffset = (
|
||||
8 - 1
|
||||
) ## std::alignment_of_v<ConstSharedBuffer> - sizeof(Format); (see record_id.h)
|
||||
buffer = self.__get_data_address(
|
||||
"mongo::ConstSharedBuffer", koffset
|
||||
).dereference()
|
||||
holder_ptr = holder = buffer["_buffer"]["_holder"]["px"]
|
||||
holder = holder.dereference()
|
||||
str_len = int(holder["_capacity"])
|
||||
# Start of data is immediately after pointer for holder
|
||||
start_ptr = (holder_ptr + 1).dereference().cast(lookup_type("char")).address
|
||||
raw_bytes = [start_ptr[i] for i in range(0, str_len)]
|
||||
hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes]
|
||||
return "RecordId big string %d hex bytes @ %s: %s" % (str_len, holder_ptr + 1,
|
||||
str("".join(hex_bytes)))
|
||||
return "RecordId big string %d hex bytes @ %s: %s" % (
|
||||
str_len,
|
||||
holder_ptr + 1,
|
||||
str("".join(hex_bytes)),
|
||||
)
|
||||
else:
|
||||
return "unknown RecordId format: %d" % rid_format
|
||||
|
||||
@ -353,22 +384,22 @@ class DatabaseNamePrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'string'
|
||||
return "string"
|
||||
|
||||
def _get_storage_data(self):
|
||||
"""Return the data pointer from the _data Storage class."""
|
||||
data = self.val['_data']
|
||||
footer = data['_footer']
|
||||
data = self.val["_data"]
|
||||
footer = data["_footer"]
|
||||
f_size = footer.type.sizeof
|
||||
|
||||
# The last byte of _footer contain the flags (and the size when using small string).
|
||||
flags = footer.cast(gdb.lookup_type('char').array(f_size))[f_size - 1]
|
||||
flags = footer.cast(gdb.lookup_type("char").array(f_size))[f_size - 1]
|
||||
|
||||
data_ptr = data['_data']
|
||||
data_ptr = data["_data"]
|
||||
if is_small_string(flags):
|
||||
return data_ptr.address, small_string_size(flags)
|
||||
else:
|
||||
return data_ptr, data['_length']
|
||||
return data_ptr, data["_length"]
|
||||
|
||||
def _get_string(self, address, size):
|
||||
data = gdb.selected_inferior().read_memory(address, size).tobytes()
|
||||
@ -396,21 +427,30 @@ class DecorablePrinter(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'map'
|
||||
return "map"
|
||||
|
||||
def to_string(self):
|
||||
"""Return Decorable for printing."""
|
||||
return "Decorable<{}> with {} elems ".format(self.val.type.template_argument(0), self.count)
|
||||
return "Decorable<{}> with {} elems ".format(
|
||||
self.val.type.template_argument(0), self.count
|
||||
)
|
||||
|
||||
def children(self):
|
||||
"""Children."""
|
||||
for index in range(self.count):
|
||||
try:
|
||||
deco_type_name, obj, obj_addr = get_object_decoration(self.val, self.start, index)
|
||||
yield ('key', "{}:{}:{}".format(index, obj_addr, deco_type_name))
|
||||
yield ('value', obj)
|
||||
deco_type_name, obj, obj_addr = get_object_decoration(
|
||||
self.val, self.start, index
|
||||
)
|
||||
yield ("key", "{}:{}:{}".format(index, obj_addr, deco_type_name))
|
||||
yield ("value", obj)
|
||||
except Exception as err:
|
||||
print("Failed to look up decoration type: " + deco_type_name + ": " + str(err))
|
||||
print(
|
||||
"Failed to look up decoration type: "
|
||||
+ deco_type_name
|
||||
+ ": "
|
||||
+ str(err)
|
||||
)
|
||||
|
||||
|
||||
def _get_flags(flag_val, flags):
|
||||
@ -443,7 +483,9 @@ class WtCursorPrinter(object):
|
||||
"""
|
||||
|
||||
try:
|
||||
with open("./src/third_party/wiredtiger/src/include/wiredtiger.in") as wiredtiger_header:
|
||||
with open(
|
||||
"./src/third_party/wiredtiger/src/include/wiredtiger.in"
|
||||
) as wiredtiger_header:
|
||||
file_contents = wiredtiger_header.read()
|
||||
cursor_flags_re = re.compile(r"#define\s+WT_CURSTD_(\w+)\s+0x(\d+)u")
|
||||
cursor_flags = cursor_flags_re.findall(file_contents)[::-1]
|
||||
@ -463,8 +505,12 @@ class WtCursorPrinter(object):
|
||||
for field in self.val.type.fields():
|
||||
field_val = self.val[field.name]
|
||||
if field.name == "flags":
|
||||
yield ("flags", "{} ({})".format(field_val,
|
||||
str(_get_flags(field_val, self.cursor_flags))))
|
||||
yield (
|
||||
"flags",
|
||||
"{} ({})".format(
|
||||
field_val, str(_get_flags(field_val, self.cursor_flags))
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield (field.name, field_val)
|
||||
|
||||
@ -477,7 +523,9 @@ class WtSessionImplPrinter(object):
|
||||
"""
|
||||
|
||||
try:
|
||||
with open("./src/third_party/wiredtiger/src/include/session.h") as session_header:
|
||||
with open(
|
||||
"./src/third_party/wiredtiger/src/include/session.h"
|
||||
) as session_header:
|
||||
file_contents = session_header.read()
|
||||
session_flags_re = re.compile(r"#define\s+WT_SESSION_(\w+)\s+0x(\d+)u")
|
||||
session_flags = session_flags_re.findall(file_contents)[::-1]
|
||||
@ -497,8 +545,12 @@ class WtSessionImplPrinter(object):
|
||||
for field in self.val.type.fields():
|
||||
field_val = self.val[field.name]
|
||||
if field.name == "flags":
|
||||
yield ("flags", "{} ({})".format(field_val,
|
||||
str(_get_flags(field_val, self.session_flags))))
|
||||
yield (
|
||||
"flags",
|
||||
"{} ({})".format(
|
||||
field_val, str(_get_flags(field_val, self.session_flags))
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield (field.name, field_val)
|
||||
|
||||
@ -531,8 +583,12 @@ class WtTxnPrinter(object):
|
||||
for field in self.val.type.fields():
|
||||
field_val = self.val[field.name]
|
||||
if field.name == "flags":
|
||||
yield ("flags", "{} ({})".format(field_val,
|
||||
str(_get_flags(field_val, self.txn_flags))))
|
||||
yield (
|
||||
"flags",
|
||||
"{} ({})".format(
|
||||
field_val, str(_get_flags(field_val, self.txn_flags))
|
||||
),
|
||||
)
|
||||
else:
|
||||
yield (field.name, field_val)
|
||||
|
||||
@ -551,7 +607,11 @@ def absl_insert_version_after_absl(cpp_name):
|
||||
absl_ns_end = absl_ns_start + len(absl_ns_str)
|
||||
|
||||
return (
|
||||
cpp_name[:absl_ns_end] + ABSL_OPTION_INLINE_NAMESPACE_NAME + "::" + cpp_name[absl_ns_end:])
|
||||
cpp_name[:absl_ns_end]
|
||||
+ ABSL_OPTION_INLINE_NAMESPACE_NAME
|
||||
+ "::"
|
||||
+ cpp_name[absl_ns_end:]
|
||||
)
|
||||
|
||||
|
||||
def absl_get_settings(val):
|
||||
@ -559,8 +619,12 @@ def absl_get_settings(val):
|
||||
try:
|
||||
common_fields_storage_type = gdb.lookup_type(
|
||||
absl_insert_version_after_absl(
|
||||
"absl::container_internal::internal_compressed_tuple::Storage") +
|
||||
absl_insert_version_after_absl("<absl::container_internal::CommonFields, 0, false>"))
|
||||
"absl::container_internal::internal_compressed_tuple::Storage"
|
||||
)
|
||||
+ absl_insert_version_after_absl(
|
||||
"<absl::container_internal::CommonFields, 0, false>"
|
||||
)
|
||||
)
|
||||
except gdb.error as err:
|
||||
if not err.args[0].startswith("No type named "):
|
||||
raise
|
||||
@ -571,7 +635,9 @@ def absl_get_settings(val):
|
||||
common_fields_storage_type = gdb.lookup_type(
|
||||
absl_insert_version_after_absl(
|
||||
"absl::container_internal::internal_compressed_tuple::Storage"
|
||||
"<absl::container_internal::CommonFields, 0, false>"))
|
||||
"<absl::container_internal::CommonFields, 0, false>"
|
||||
)
|
||||
)
|
||||
|
||||
# The Hash, Eq, or Alloc functors may not be zero-sized objects.
|
||||
# mongo::LogicalSessionIdHash is one such example. An explicit cast is needed to
|
||||
@ -595,7 +661,9 @@ def absl_get_nodes(val):
|
||||
ctrl = settings["control_"]
|
||||
|
||||
# Derive the underlying type stored in the container.
|
||||
slot_type = lookup_type(str(val.type.strip_typedefs()) + "::slot_type").strip_typedefs()
|
||||
slot_type = lookup_type(
|
||||
str(val.type.strip_typedefs()) + "::slot_type"
|
||||
).strip_typedefs()
|
||||
|
||||
# Using the array of ctrl bytes, search for in-use slots and return them
|
||||
# https://github.com/abseil/abseil-cpp/blob/8a3caf7dea955b513a6c1b572a2423c6b4213402/absl/container/internal/raw_hash_set.h#L2108-L2113
|
||||
@ -616,7 +684,7 @@ class AbslHashSetPrinterBase(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'array'
|
||||
return "array"
|
||||
|
||||
def to_string(self):
|
||||
"""Return absl::[node/flat]_hash_set for printing."""
|
||||
@ -668,7 +736,7 @@ class AbslHashMapPrinterBase(object):
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""Display hint."""
|
||||
return 'map'
|
||||
return "map"
|
||||
|
||||
def to_string(self):
|
||||
"""Return absl::[node/flat]_hash_map for printing."""
|
||||
@ -690,8 +758,8 @@ class AbslNodeHashMapPrinter(AbslHashMapPrinterBase):
|
||||
def children(self):
|
||||
"""Children."""
|
||||
for kvp in absl_get_nodes(self.val):
|
||||
yield ('key', kvp['first'])
|
||||
yield ('value', kvp['second'])
|
||||
yield ("key", kvp["first"])
|
||||
yield ("value", kvp["second"])
|
||||
|
||||
|
||||
class AbslFlatHashMapPrinter(AbslHashMapPrinterBase):
|
||||
@ -704,8 +772,8 @@ class AbslFlatHashMapPrinter(AbslHashMapPrinterBase):
|
||||
def children(self):
|
||||
"""Children."""
|
||||
for kvp in absl_get_nodes(self.val):
|
||||
yield ('key', kvp['key'])
|
||||
yield ('value', kvp['value']["second"])
|
||||
yield ("key", kvp["key"])
|
||||
yield ("value", kvp["value"]["second"])
|
||||
|
||||
|
||||
class ImmutableMapIter(ImmerListIter):
|
||||
@ -717,7 +785,7 @@ class ImmutableMapIter(ImmerListIter):
|
||||
|
||||
def __next__(self):
|
||||
if self.pair:
|
||||
result = ('value', self.pair['second'])
|
||||
result = ("value", self.pair["second"])
|
||||
self.pair = None
|
||||
self.i += 1
|
||||
return result
|
||||
@ -726,8 +794,9 @@ class ImmutableMapIter(ImmerListIter):
|
||||
if self.i < self.curr[1] or self.i >= self.curr[2]:
|
||||
self.curr = self.region()
|
||||
self.pair = self.curr[0][self.i - self.curr[1]].cast(
|
||||
gdb.lookup_type(self.v.type.template_argument(0).name))
|
||||
result = ('key', self.pair['first'])
|
||||
gdb.lookup_type(self.v.type.template_argument(0).name)
|
||||
)
|
||||
result = ("key", self.pair["first"])
|
||||
return result
|
||||
|
||||
|
||||
@ -738,13 +807,16 @@ class ImmutableMapPrinter:
|
||||
self.val = val
|
||||
|
||||
def to_string(self):
|
||||
return '%s of size %d' % (self.val.type, int(self.val['_storage']['impl_']['size']))
|
||||
return "%s of size %d" % (
|
||||
self.val.type,
|
||||
int(self.val["_storage"]["impl_"]["size"]),
|
||||
)
|
||||
|
||||
def children(self):
|
||||
return ImmutableMapIter(self.val['_storage'])
|
||||
return ImmutableMapIter(self.val["_storage"])
|
||||
|
||||
def display_hint(self):
|
||||
return 'map'
|
||||
return "map"
|
||||
|
||||
|
||||
class ImmutableSetPrinter:
|
||||
@ -754,16 +826,19 @@ class ImmutableSetPrinter:
|
||||
self.val = val
|
||||
|
||||
def to_string(self):
|
||||
return '%s of size %d' % (self.val.type, int(self.val['_storage']['impl_']['size']))
|
||||
return "%s of size %d" % (
|
||||
self.val.type,
|
||||
int(self.val["_storage"]["impl_"]["size"]),
|
||||
)
|
||||
|
||||
def children(self):
|
||||
return ImmerListIter(self.val['_storage'])
|
||||
return ImmerListIter(self.val["_storage"])
|
||||
|
||||
def display_hint(self):
|
||||
return 'array'
|
||||
return "array"
|
||||
|
||||
|
||||
def find_match_brackets(search, opening='<', closing='>'):
|
||||
def find_match_brackets(search, opening="<", closing=">"):
|
||||
"""Return the index of the closing bracket that matches the first opening bracket.
|
||||
|
||||
Return -1 if no last matching bracket is found, i.e. not a template.
|
||||
@ -817,7 +892,9 @@ class MongoPrettyPrinterCollection(gdb.printing.PrettyPrinter):
|
||||
|
||||
def add(self, name, prefix, is_template, printer):
|
||||
"""Add a subprinter."""
|
||||
self.subprinters.append(MongoSubPrettyPrinter(name, prefix, is_template, printer))
|
||||
self.subprinters.append(
|
||||
MongoSubPrettyPrinter(name, prefix, is_template, printer)
|
||||
)
|
||||
|
||||
def __call__(self, val):
|
||||
"""Return matched printer type."""
|
||||
@ -837,7 +914,10 @@ class MongoPrettyPrinterCollection(gdb.printing.PrettyPrinter):
|
||||
# Ignore subtypes of templated classes.
|
||||
# We do not want HashTable<T>::iterator as an example, just HashTable<T>
|
||||
if printer.is_template:
|
||||
if (index + 1 == len(lookup_tag) and lookup_tag.find(printer.prefix) == 0):
|
||||
if (
|
||||
index + 1 == len(lookup_tag)
|
||||
and lookup_tag.find(printer.prefix) == 0
|
||||
):
|
||||
return printer.printer(val)
|
||||
elif lookup_tag == printer.prefix:
|
||||
return printer.printer(val)
|
||||
@ -851,13 +931,13 @@ class WtUpdateToBsonPrinter(object):
|
||||
def __init__(self, val):
|
||||
"""Initializer."""
|
||||
self.val = val
|
||||
self.size = self.val['size']
|
||||
self.ptr = self.val['data']
|
||||
self.size = self.val["size"]
|
||||
self.ptr = self.val["data"]
|
||||
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
"""DisplayHint."""
|
||||
return 'map'
|
||||
return "map"
|
||||
|
||||
def to_string(self):
|
||||
"""ToString."""
|
||||
@ -866,11 +946,11 @@ class WtUpdateToBsonPrinter(object):
|
||||
fld = self.val.type.fields()[idx]
|
||||
val = self.val[fld.name]
|
||||
elems.append(str((fld.name, str(val))))
|
||||
return "WT_UPDATE: \n %s" % ('\n '.join(elems))
|
||||
return "WT_UPDATE: \n %s" % ("\n ".join(elems))
|
||||
|
||||
def children(self):
|
||||
"""children."""
|
||||
if self.val['type'] != 3:
|
||||
if self.val["type"] != 3:
|
||||
# Type 3 is a "normal" update. Notably type 4 is a deletion and type 1 represents a
|
||||
# delta relative to the previous committed version in the update chain. Only attempt
|
||||
# to parse type 3 as bson.
|
||||
@ -884,8 +964,8 @@ class WtUpdateToBsonPrinter(object):
|
||||
return
|
||||
|
||||
for key, value in list(bsonobj.items()):
|
||||
yield 'key', key
|
||||
yield 'value', bson.json_util.dumps(value)
|
||||
yield "key", key
|
||||
yield "value", bson.json_util.dumps(value)
|
||||
|
||||
|
||||
def make_inverse_enum_dict(enum_type_name):
|
||||
@ -898,7 +978,7 @@ def make_inverse_enum_dict(enum_type_name):
|
||||
enum_dict = gdb.types.make_enum_dict(lookup_type(enum_type_name))
|
||||
enum_inverse_dic = dict()
|
||||
for key, value in enum_dict.items():
|
||||
enum_inverse_dic[int(value)] = key.split('::')[-1] # take last element
|
||||
enum_inverse_dic[int(value)] = key.split("::")[-1] # take last element
|
||||
return enum_inverse_dic
|
||||
|
||||
|
||||
@ -941,19 +1021,22 @@ class SbeCodeFragmentPrinter(object):
|
||||
# The instructions stream is stored using 'absl::InlinedVector<uint8_t, 16>' type, which can
|
||||
# either use an inline buffer or an allocated one. The choice of storage is decoded in the
|
||||
# last bit of the 'metadata_' field.
|
||||
storage = self.val['_instrs']['storage_']
|
||||
meta = storage['metadata_'].cast(lookup_type('size_t'))
|
||||
self.is_inlined = (meta % 2 == 0)
|
||||
self.size = (meta >> 1)
|
||||
self.pdata = \
|
||||
storage['data_']['inlined']['inlined_data'].cast(lookup_type('uint8_t').pointer()) \
|
||||
if self.is_inlined \
|
||||
else storage['data_']['allocated']['allocated_data']
|
||||
storage = self.val["_instrs"]["storage_"]
|
||||
meta = storage["metadata_"].cast(lookup_type("size_t"))
|
||||
self.is_inlined = meta % 2 == 0
|
||||
self.size = meta >> 1
|
||||
self.pdata = (
|
||||
storage["data_"]["inlined"]["inlined_data"].cast(
|
||||
lookup_type("uint8_t").pointer()
|
||||
)
|
||||
if self.is_inlined
|
||||
else storage["data_"]["allocated"]["allocated_data"]
|
||||
)
|
||||
|
||||
# Precompute lookup tables for Instructions and Builtins.
|
||||
self.optags_lookup = make_inverse_enum_dict('mongo::sbe::vm::Instruction::Tags')
|
||||
self.builtins_lookup = make_inverse_enum_dict('mongo::sbe::vm::Builtin')
|
||||
self.valuetags_lookup = make_inverse_enum_dict('mongo::sbe::value::TypeTags')
|
||||
self.optags_lookup = make_inverse_enum_dict("mongo::sbe::vm::Instruction::Tags")
|
||||
self.builtins_lookup = make_inverse_enum_dict("mongo::sbe::vm::Builtin")
|
||||
self.valuetags_lookup = make_inverse_enum_dict("mongo::sbe::value::TypeTags")
|
||||
|
||||
def to_string(self):
|
||||
"""Return sbe::vm::CodeFragment for printing."""
|
||||
@ -961,26 +1044,29 @@ class SbeCodeFragmentPrinter(object):
|
||||
|
||||
def children(self):
|
||||
"""children."""
|
||||
yield '_instrs', '{... (to see raw output, run "disable pretty-printer")}'
|
||||
yield '_fixUps', self.val['_fixUps']
|
||||
yield '_stackSize', self.val['_stackSize']
|
||||
yield "_instrs", '{... (to see raw output, run "disable pretty-printer")}'
|
||||
yield "_fixUps", self.val["_fixUps"]
|
||||
yield "_stackSize", self.val["_stackSize"]
|
||||
|
||||
yield 'inlined', self.is_inlined
|
||||
yield 'instrs data at', '[{} - {}]'.format(hex(self.pdata), hex(self.pdata + self.size))
|
||||
yield 'instrs total size', self.size
|
||||
yield "inlined", self.is_inlined
|
||||
yield (
|
||||
"instrs data at",
|
||||
"[{} - {}]".format(hex(self.pdata), hex(self.pdata + self.size)),
|
||||
)
|
||||
yield "instrs total size", self.size
|
||||
|
||||
# Sizes for types we'll use when parsing the insructions stream.
|
||||
int_size = lookup_type('int').sizeof
|
||||
ptr_size = lookup_type('void').pointer().sizeof
|
||||
tag_size = lookup_type('mongo::sbe::value::TypeTags').sizeof
|
||||
value_size = lookup_type('mongo::sbe::value::Value').sizeof
|
||||
uint8_size = lookup_type('uint8_t').sizeof
|
||||
uint32_size = lookup_type('uint32_t').sizeof
|
||||
uint64_size = lookup_type('uint64_t').sizeof
|
||||
builtin_size = lookup_type('mongo::sbe::vm::Builtin').sizeof
|
||||
time_unit_size = lookup_type('mongo::TimeUnit').sizeof
|
||||
timezone_size = lookup_type('mongo::TimeZone').sizeof
|
||||
day_of_week_size = lookup_type('mongo::DayOfWeek').sizeof
|
||||
int_size = lookup_type("int").sizeof
|
||||
ptr_size = lookup_type("void").pointer().sizeof
|
||||
tag_size = lookup_type("mongo::sbe::value::TypeTags").sizeof
|
||||
value_size = lookup_type("mongo::sbe::value::Value").sizeof
|
||||
uint8_size = lookup_type("uint8_t").sizeof
|
||||
uint32_size = lookup_type("uint32_t").sizeof
|
||||
uint64_size = lookup_type("uint64_t").sizeof
|
||||
builtin_size = lookup_type("mongo::sbe::vm::Builtin").sizeof
|
||||
time_unit_size = lookup_type("mongo::TimeUnit").sizeof
|
||||
timezone_size = lookup_type("mongo::TimeZone").sizeof
|
||||
day_of_week_size = lookup_type("mongo::DayOfWeek").sizeof
|
||||
|
||||
cur_op = self.pdata
|
||||
end_op = self.pdata + self.size
|
||||
@ -991,7 +1077,7 @@ class SbeCodeFragmentPrinter(object):
|
||||
op_tag = read_as_integer(op_addr, 1)
|
||||
|
||||
if op_tag not in self.optags_lookup:
|
||||
yield hex(op_addr), 'unknown op tag: {}'.format(op_tag)
|
||||
yield hex(op_addr), "unknown op tag: {}".format(op_tag)
|
||||
error = True
|
||||
break
|
||||
op_name = self.optags_lookup[op_tag]
|
||||
@ -1000,70 +1086,93 @@ class SbeCodeFragmentPrinter(object):
|
||||
instr_count += 1
|
||||
|
||||
# Some instructions have extra arguments, embedded into the ops stream.
|
||||
args = ''
|
||||
if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda']:
|
||||
args = 'arg: ' + str(read_as_integer(cur_op, int_size))
|
||||
args = ""
|
||||
if op_name in ["pushLocalVal", "pushMoveLocalVal", "pushLocalLambda"]:
|
||||
args = "arg: " + str(read_as_integer(cur_op, int_size))
|
||||
cur_op += int_size
|
||||
elif op_name in ['jmp', 'jmpTrue', 'jmpFalse', 'jmpNothing', 'jmpNotNothing']:
|
||||
elif op_name in [
|
||||
"jmp",
|
||||
"jmpTrue",
|
||||
"jmpFalse",
|
||||
"jmpNothing",
|
||||
"jmpNotNothing",
|
||||
]:
|
||||
offset = read_as_integer_signed(cur_op, int_size)
|
||||
cur_op += int_size
|
||||
args = 'offset: ' + str(offset) + ', target: ' + hex(cur_op + offset)
|
||||
elif op_name in ['pushConstVal', 'getFieldImm']:
|
||||
args = "offset: " + str(offset) + ", target: " + hex(cur_op + offset)
|
||||
elif op_name in ["pushConstVal", "getFieldImm"]:
|
||||
tag = read_as_integer(cur_op, tag_size)
|
||||
args = 'tag: ' + self.valuetags_lookup.get(tag, "unknown") + \
|
||||
', value: ' + hex(read_as_integer(cur_op + tag_size, value_size))
|
||||
cur_op += (tag_size + value_size)
|
||||
elif op_name in ['pushAccessVal', 'pushMoveVal']:
|
||||
args = 'accessor: ' + hex(read_as_integer(cur_op, ptr_size))
|
||||
args = (
|
||||
"tag: "
|
||||
+ self.valuetags_lookup.get(tag, "unknown")
|
||||
+ ", value: "
|
||||
+ hex(read_as_integer(cur_op + tag_size, value_size))
|
||||
)
|
||||
cur_op += tag_size + value_size
|
||||
elif op_name in ["pushAccessVal", "pushMoveVal"]:
|
||||
args = "accessor: " + hex(read_as_integer(cur_op, ptr_size))
|
||||
cur_op += ptr_size
|
||||
elif op_name in ['numConvert']:
|
||||
args = 'convert to: ' + \
|
||||
self.valuetags_lookup.get(read_as_integer(cur_op, tag_size), "unknown")
|
||||
elif op_name in ["numConvert"]:
|
||||
args = "convert to: " + self.valuetags_lookup.get(
|
||||
read_as_integer(cur_op, tag_size), "unknown"
|
||||
)
|
||||
cur_op += tag_size
|
||||
elif op_name in ['typeMatchImm']:
|
||||
args = 'mask: ' + hex(read_as_integer(cur_op, uint32_size))
|
||||
elif op_name in ["typeMatchImm"]:
|
||||
args = "mask: " + hex(read_as_integer(cur_op, uint32_size))
|
||||
cur_op += uint32_size
|
||||
elif op_name in ['function', 'functionSmall']:
|
||||
arity_size = \
|
||||
lookup_type('mongo::sbe::vm::ArityType').sizeof \
|
||||
if op_name == 'function' \
|
||||
else lookup_type('mongo::sbe::vm::SmallArityType').sizeof
|
||||
elif op_name in ["function", "functionSmall"]:
|
||||
arity_size = (
|
||||
lookup_type("mongo::sbe::vm::ArityType").sizeof
|
||||
if op_name == "function"
|
||||
else lookup_type("mongo::sbe::vm::SmallArityType").sizeof
|
||||
)
|
||||
builtin_id = read_as_integer(cur_op, builtin_size)
|
||||
args = 'builtin: ' + self.builtins_lookup.get(builtin_id, "unknown")
|
||||
args += ' arity: ' + str(read_as_integer(cur_op + builtin_size, arity_size))
|
||||
cur_op += (builtin_size + arity_size)
|
||||
elif op_name in ['fillEmptyImm']:
|
||||
args = 'Instruction::Constants: ' + str(read_as_integer(cur_op, uint8_size))
|
||||
args = "builtin: " + self.builtins_lookup.get(builtin_id, "unknown")
|
||||
args += " arity: " + str(
|
||||
read_as_integer(cur_op + builtin_size, arity_size)
|
||||
)
|
||||
cur_op += builtin_size + arity_size
|
||||
elif op_name in ["fillEmptyImm"]:
|
||||
args = "Instruction::Constants: " + str(
|
||||
read_as_integer(cur_op, uint8_size)
|
||||
)
|
||||
cur_op += uint8_size
|
||||
elif op_name in ['traverseFImm', 'traversePImm']:
|
||||
elif op_name in ["traverseFImm", "traversePImm"]:
|
||||
const_enum = read_as_integer(cur_op, uint8_size)
|
||||
cur_op += uint8_size
|
||||
args = \
|
||||
'Instruction::Constants: ' + str(const_enum) + \
|
||||
", offset: " + str(read_as_integer_signed(cur_op, int_size))
|
||||
args = (
|
||||
"Instruction::Constants: "
|
||||
+ str(const_enum)
|
||||
+ ", offset: "
|
||||
+ str(read_as_integer_signed(cur_op, int_size))
|
||||
)
|
||||
cur_op += int_size
|
||||
elif op_name in ['dateTruncImm']:
|
||||
elif op_name in ["dateTruncImm"]:
|
||||
unit = read_as_integer(cur_op, time_unit_size)
|
||||
cur_op += time_unit_size
|
||||
args = 'unit: ' + str(unit)
|
||||
args = "unit: " + str(unit)
|
||||
bin_size = read_as_integer(cur_op, uint64_size)
|
||||
cur_op += uint64_size
|
||||
args += ', binSize: ' + str(bin_size)
|
||||
args += ", binSize: " + str(bin_size)
|
||||
timezone = read_as_integer(cur_op, timezone_size)
|
||||
cur_op += timezone_size
|
||||
args += ', timezone: ' + hex(timezone)
|
||||
args += ", timezone: " + hex(timezone)
|
||||
day_of_week = read_as_integer(cur_op, day_of_week_size)
|
||||
cur_op += day_of_week_size
|
||||
args += ', dayOfWeek: ' + str(day_of_week)
|
||||
elif op_name in ['traverseCsiCellValues', 'traverseCsiCellTypes']:
|
||||
args += ", dayOfWeek: " + str(day_of_week)
|
||||
elif op_name in ["traverseCsiCellValues", "traverseCsiCellTypes"]:
|
||||
offset = read_as_integer_signed(cur_op, int_size)
|
||||
cur_op += int_size
|
||||
args = 'lambda at: ' + hex(cur_op + offset)
|
||||
args = "lambda at: " + hex(cur_op + offset)
|
||||
|
||||
yield hex(op_addr), '{} ({})'.format(op_name, args)
|
||||
yield hex(op_addr), "{} ({})".format(op_name, args)
|
||||
|
||||
yield 'instructions count', \
|
||||
instr_count if not error else '? (successfully parsed {})'.format(instr_count)
|
||||
yield (
|
||||
"instructions count",
|
||||
instr_count
|
||||
if not error
|
||||
else "? (successfully parsed {})".format(instr_count),
|
||||
)
|
||||
|
||||
|
||||
def build_pretty_printer():
|
||||
@ -1108,7 +1217,9 @@ def build_pretty_printer():
|
||||
pp.add("__wt_session_impl", "__wt_session_impl", False, WtSessionImplPrinter)
|
||||
pp.add("__wt_txn", "__wt_txn", False, WtTxnPrinter)
|
||||
pp.add("__wt_update", "__wt_update", False, WtUpdateToBsonPrinter)
|
||||
pp.add("CodeFragment", "mongo::sbe::vm::CodeFragment", False, SbeCodeFragmentPrinter)
|
||||
pp.add(
|
||||
"CodeFragment", "mongo::sbe::vm::CodeFragment", False, SbeCodeFragmentPrinter
|
||||
)
|
||||
pp.add("boost::optional", "boost::optional", True, BoostOptionalPrinter)
|
||||
pp.add("immutable::map", "mongo::immutable::map", True, ImmutableMapPrinter)
|
||||
pp.add("immutable::set", "mongo::immutable::set", True, ImmutableSetPrinter)
|
||||
@ -1126,6 +1237,8 @@ def build_pretty_printer():
|
||||
###################################################################################################
|
||||
|
||||
# Register pretty-printers, replace existing mongo printers
|
||||
gdb.printing.register_pretty_printer(gdb.current_objfile(), build_pretty_printer(), True)
|
||||
gdb.printing.register_pretty_printer(
|
||||
gdb.current_objfile(), build_pretty_printer(), True
|
||||
)
|
||||
|
||||
print("MongoDB GDB pretty-printers loaded")
|
||||
|
||||
@ -28,7 +28,7 @@ def eval_print_fn(val, print_fn):
|
||||
# replace them with a single EOL character so that GDB prints multi-line
|
||||
# explains nicely.
|
||||
pp_result = print_fn(val)
|
||||
pp_str = str(pp_result).replace("\"", "").replace("\\n", "\n")
|
||||
pp_str = str(pp_result).replace('"', "").replace("\\n", "\n")
|
||||
return pp_str
|
||||
|
||||
|
||||
@ -145,7 +145,9 @@ class FixedArityNodePrinter(object):
|
||||
global operator_indent_level
|
||||
|
||||
prior_indent = operator_indent_level
|
||||
current_indent = operator_indent_level + self.arity + len(self.custom_children) - 1
|
||||
current_indent = (
|
||||
operator_indent_level + self.arity + len(self.custom_children) - 1
|
||||
)
|
||||
for child in self.custom_children:
|
||||
lhs = "\n"
|
||||
for _ in range(current_indent):
|
||||
@ -178,8 +180,8 @@ class FixedArityNodePrinter(object):
|
||||
class Vector(object):
|
||||
def __init__(self, vec):
|
||||
self.vec = vec
|
||||
self.start = vec['_M_impl']['_M_start']
|
||||
self.finish = vec['_M_impl']['_M_finish']
|
||||
self.start = vec["_M_impl"]["_M_start"]
|
||||
self.finish = vec["_M_impl"]["_M_finish"]
|
||||
|
||||
def __iter__(self):
|
||||
item = self.start
|
||||
@ -194,8 +196,11 @@ class Vector(object):
|
||||
|
||||
def get(self, index):
|
||||
if index > self.count() - 1:
|
||||
raise gdb.GdbError("Invalid Vector access at index {} with size {}".format(
|
||||
index, self.count()))
|
||||
raise gdb.GdbError(
|
||||
"Invalid Vector access at index {} with size {}".format(
|
||||
index, self.count()
|
||||
)
|
||||
)
|
||||
item = self.start + index
|
||||
return item.dereference()
|
||||
|
||||
@ -284,7 +289,9 @@ class ScanNodePrinter(object):
|
||||
return str(bound_projections.get(0))
|
||||
|
||||
def to_string(self):
|
||||
return "Scan[{}, {}]".format(self.val["_scanDefName"], self.get_bound_projection())
|
||||
return "Scan[{}, {}]".format(
|
||||
self.val["_scanDefName"], self.get_bound_projection()
|
||||
)
|
||||
|
||||
|
||||
class FilterNodePrinter(FixedArityNodePrinter):
|
||||
@ -320,13 +327,16 @@ class ConstantPrinter(object):
|
||||
value_print_fn = "mongo::sbe::value::print"
|
||||
(print_fn_symbol, _) = gdb.lookup_symbol(value_print_fn)
|
||||
if print_fn_symbol is None:
|
||||
raise gdb.GdbError("Could not find pretty print function: " + value_print_fn)
|
||||
raise gdb.GdbError(
|
||||
"Could not find pretty print function: " + value_print_fn
|
||||
)
|
||||
print_fn = print_fn_symbol.value()
|
||||
return print_fn(tag, value)
|
||||
|
||||
def to_string(self):
|
||||
return "Constant[{}]".format(
|
||||
ConstantPrinter.print_sbe_value(self.val["_tag"], self.val["_val"]))
|
||||
ConstantPrinter.print_sbe_value(self.val["_tag"], self.val["_val"])
|
||||
)
|
||||
|
||||
|
||||
class VariablePrinter(object):
|
||||
@ -609,8 +619,12 @@ class FieldProjectionMapPrinter(object):
|
||||
res += "<root>: " + str(root_proj) + ", "
|
||||
|
||||
# Python reformats the string with embedded "=" characters, avoid that by replacing here.
|
||||
res += str(self.val["_fieldProjections"]).replace("=", ":").replace("{", "(").replace(
|
||||
"}", ")")
|
||||
res += (
|
||||
str(self.val["_fieldProjections"])
|
||||
.replace("=", ":")
|
||||
.replace("{", "(")
|
||||
.replace("}", ")")
|
||||
)
|
||||
res += "}"
|
||||
return res
|
||||
|
||||
@ -624,7 +638,8 @@ class PhysicalScanNodePrinter(FixedArityNodePrinter):
|
||||
|
||||
def to_string(self):
|
||||
return "PhysicalScan[{}, {}]".format(
|
||||
str(self.val["_fieldProjectionMap"]), str(self.val["_scanDefName"]))
|
||||
str(self.val["_fieldProjectionMap"]), str(self.val["_scanDefName"])
|
||||
)
|
||||
|
||||
|
||||
class ValueScanNodePrinter(FixedArityNodePrinter):
|
||||
@ -635,8 +650,9 @@ class ValueScanNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 1, "ValueScan")
|
||||
|
||||
def to_string(self):
|
||||
return "ValueScan[hasRID={},arraySize={}]".format(self.val["_hasRID"],
|
||||
self.val["_arraySize"])
|
||||
return "ValueScan[hasRID={},arraySize={}]".format(
|
||||
self.val["_hasRID"], self.val["_arraySize"]
|
||||
)
|
||||
|
||||
|
||||
class CoScanNodePrinter(FixedArityNodePrinter):
|
||||
@ -656,8 +672,11 @@ class IndexScanNodePrinter(FixedArityNodePrinter):
|
||||
|
||||
def to_string(self):
|
||||
return "IndexScan[{{{}}}, scanDef={}, indexDef={}, interval={}]".format(
|
||||
self.val["_fieldProjectionMap"], self.val["_scanDefName"], self.val["_indexDefName"],
|
||||
self.val["_indexInterval"]).replace("\n", "")
|
||||
self.val["_fieldProjectionMap"],
|
||||
self.val["_scanDefName"],
|
||||
self.val["_indexDefName"],
|
||||
self.val["_indexInterval"],
|
||||
).replace("\n", "")
|
||||
|
||||
|
||||
class SeekNodePrinter(FixedArityNodePrinter):
|
||||
@ -668,9 +687,11 @@ class SeekNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 2, "Seek")
|
||||
|
||||
def to_string(self):
|
||||
return "Seek[rid_projection: {}, {}, scanDef: {}]".format(self.val["_ridProjectionName"],
|
||||
self.val["_fieldProjectionMap"],
|
||||
self.val["_scanDefName"])
|
||||
return "Seek[rid_projection: {}, {}, scanDef: {}]".format(
|
||||
self.val["_ridProjectionName"],
|
||||
self.val["_fieldProjectionMap"],
|
||||
self.val["_scanDefName"],
|
||||
)
|
||||
|
||||
|
||||
class MemoLogicalDelegatorNodePrinter(FixedArityNodePrinter):
|
||||
@ -692,8 +713,9 @@ class MemoPhysicalDelegatorNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 0, "MemoPhysicalDelegator")
|
||||
|
||||
def to_string(self):
|
||||
return "MemoPhysicalDelegator[group: {}, index: {}]".format(self.val["_nodeId"]["_groupId"],
|
||||
self.val["_nodeId"]["_index"])
|
||||
return "MemoPhysicalDelegator[group: {}, index: {}]".format(
|
||||
self.val["_nodeId"]["_groupId"], self.val["_nodeId"]["_index"]
|
||||
)
|
||||
|
||||
|
||||
class ResidualRequirementPrinter(object):
|
||||
@ -710,10 +732,18 @@ class ResidualRequirementPrinter(object):
|
||||
if get_boost_optional(key["_projectionName"]) is not None:
|
||||
res += "refProj: " + str(get_boost_optional(key["_projectionName"])) + ", "
|
||||
|
||||
res += "path: '" + str(key["_path"]).replace("| ", "").replace("\n", " -> ") + "'"
|
||||
res += (
|
||||
"path: '"
|
||||
+ str(key["_path"]).replace("| ", "").replace("\n", " -> ")
|
||||
+ "'"
|
||||
)
|
||||
|
||||
if get_boost_optional(req["_boundProjectionName"]) is not None:
|
||||
res += "boundProj: " + str(get_boost_optional(req["_boundProjectionName"])) + ", "
|
||||
res += (
|
||||
"boundProj: "
|
||||
+ str(get_boost_optional(req["_boundProjectionName"]))
|
||||
+ ", "
|
||||
)
|
||||
|
||||
res += ">"
|
||||
return res
|
||||
@ -805,9 +835,16 @@ class BinaryJoinNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 3, "BinaryJoin")
|
||||
|
||||
def to_string(self):
|
||||
correlated = print_correlated_projections(self.val["_correlatedProjectionNames"])
|
||||
return "BinaryJoin[type=" + str(strip_namespace(
|
||||
self.val["_joinType"])) + ", " + correlated + "]"
|
||||
correlated = print_correlated_projections(
|
||||
self.val["_correlatedProjectionNames"]
|
||||
)
|
||||
return (
|
||||
"BinaryJoin[type="
|
||||
+ str(strip_namespace(self.val["_joinType"]))
|
||||
+ ", "
|
||||
+ correlated
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
def print_eq_join_condition(leftKeys, rightKeys):
|
||||
@ -843,8 +880,9 @@ class MergeJoinNodePrinter(FixedArityNodePrinter):
|
||||
|
||||
# Manually add the collation ops.
|
||||
collationOps = Vector(self.val["_collation"])
|
||||
collationChild = "Collation[" + ", ".join(str(collation)
|
||||
for collation in collationOps) + "]"
|
||||
collationChild = (
|
||||
"Collation[" + ", ".join(str(collation) for collation in collationOps) + "]"
|
||||
)
|
||||
self.add_child(collationChild)
|
||||
|
||||
# Manually add the child which prints the sets of keys.
|
||||
@ -859,7 +897,8 @@ class MergeJoinNodePrinter(FixedArityNodePrinter):
|
||||
def print_collation_req(req):
|
||||
spec = Vector(req["_spec"])
|
||||
return ", ".join(
|
||||
str(entry["first"]) + ": " + strip_namespace(entry["second"]) for entry in spec)
|
||||
str(entry["first"]) + ": " + strip_namespace(entry["second"]) for entry in spec
|
||||
)
|
||||
|
||||
|
||||
class SortedMergeNodePrinter(DynamicArityNodePrinter):
|
||||
@ -869,7 +908,9 @@ class SortedMergeNodePrinter(DynamicArityNodePrinter):
|
||||
"""Initialize SortedMergeNodePrinter."""
|
||||
super().__init__(val, 2, "MergeJoin")
|
||||
|
||||
self.add_child("collation[" + print_collation_req(self.val["_collationReq"]) + "]")
|
||||
self.add_child(
|
||||
"collation[" + print_collation_req(self.val["_collationReq"]) + "]"
|
||||
)
|
||||
|
||||
def to_string(self):
|
||||
return "SortedMerge"
|
||||
@ -883,9 +924,16 @@ class NestedLoopJoinNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 3, "NestedLoopJoin")
|
||||
|
||||
def to_string(self):
|
||||
correlated = print_correlated_projections(self.val["_correlatedProjectionNames"])
|
||||
return "NestedLoopJoin[type=" + strip_namespace(
|
||||
self.val["_joinType"]) + ", " + correlated + "]"
|
||||
correlated = print_correlated_projections(
|
||||
self.val["_correlatedProjectionNames"]
|
||||
)
|
||||
return (
|
||||
"NestedLoopJoin[type="
|
||||
+ strip_namespace(self.val["_joinType"])
|
||||
+ ", "
|
||||
+ correlated
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
class UnwindNodePrinter(FixedArityNodePrinter):
|
||||
@ -912,8 +960,13 @@ class SpoolProducerNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 4, "SpoolProducer")
|
||||
|
||||
def to_string(self):
|
||||
return "SpoolProducer[" + strip_namespace(self.val["_type"]) + ", id:" + str(
|
||||
self.val["_spoolId"]) + "]"
|
||||
return (
|
||||
"SpoolProducer["
|
||||
+ strip_namespace(self.val["_type"])
|
||||
+ ", id:"
|
||||
+ str(self.val["_spoolId"])
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
class SpoolConsumerNodePrinter(FixedArityNodePrinter):
|
||||
@ -924,8 +977,13 @@ class SpoolConsumerNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 1, "SpoolConsumer")
|
||||
|
||||
def to_string(self):
|
||||
return "SpoolConsumer[" + strip_namespace(self.val["_type"]) + ", id:" + str(
|
||||
self.val["_spoolId"]) + "]"
|
||||
return (
|
||||
"SpoolConsumer["
|
||||
+ strip_namespace(self.val["_type"])
|
||||
+ ", id:"
|
||||
+ str(self.val["_spoolId"])
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
class CollationNodePrinter(FixedArityNodePrinter):
|
||||
@ -949,8 +1007,13 @@ class LimitSkipNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 1, "LimitSkip")
|
||||
|
||||
def to_string(self):
|
||||
return "LimitSkip[limit: " + str(self.val["_property"]["_limit"]) + ", skip: " + str(
|
||||
self.val["_property"]["_skip"]) + "]"
|
||||
return (
|
||||
"LimitSkip[limit: "
|
||||
+ str(self.val["_property"]["_limit"])
|
||||
+ ", skip: "
|
||||
+ str(self.val["_property"]["_skip"])
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
class ExchangeNodePrinter(FixedArityNodePrinter):
|
||||
@ -961,10 +1024,17 @@ class ExchangeNodePrinter(FixedArityNodePrinter):
|
||||
super().__init__(val, 2, "Exchange")
|
||||
|
||||
def to_string(self):
|
||||
return "Exchange[type: " + str(
|
||||
self.val["_distribution"]["_distributionAndProjections"]
|
||||
["_type"]) + ", projections: " + str(
|
||||
self.val["_distribution"]["_distributionAndProjections"]["_projectionNames"]) + "]"
|
||||
return (
|
||||
"Exchange[type: "
|
||||
+ str(self.val["_distribution"]["_distributionAndProjections"]["_type"])
|
||||
+ ", projections: "
|
||||
+ str(
|
||||
self.val["_distribution"]["_distributionAndProjections"][
|
||||
"_projectionNames"
|
||||
]
|
||||
)
|
||||
+ "]"
|
||||
)
|
||||
|
||||
|
||||
class ReferencesPrinter(DynamicArityNodePrinter):
|
||||
@ -991,12 +1061,17 @@ class PolyValuePrinter(object):
|
||||
self.type_set = str(self.poly_type).split("<", 1)[1]
|
||||
|
||||
if self.tag < 0:
|
||||
raise gdb.GdbError("Invalid PolyValue tag: {}, must be at least 0".format(self.tag))
|
||||
raise gdb.GdbError(
|
||||
"Invalid PolyValue tag: {}, must be at least 0".format(self.tag)
|
||||
)
|
||||
|
||||
# Check if the tag is out of range for the set of types that we know about.
|
||||
if self.tag > len(self.type_set.split(",")):
|
||||
raise gdb.GdbError("Unknown PolyValue tag: {} (max: {}), did you add a new one?".format(
|
||||
self.tag, str(self.type_set)))
|
||||
raise gdb.GdbError(
|
||||
"Unknown PolyValue tag: {} (max: {}), did you add a new one?".format(
|
||||
self.tag, str(self.type_set)
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def display_hint():
|
||||
@ -1004,8 +1079,11 @@ class PolyValuePrinter(object):
|
||||
return None
|
||||
|
||||
def cast_control_block(self, target_type):
|
||||
return self.control_block.dereference().address.cast(
|
||||
target_type.pointer()).dereference()["_t"]
|
||||
return (
|
||||
self.control_block.dereference()
|
||||
.address.cast(target_type.pointer())
|
||||
.dereference()["_t"]
|
||||
)
|
||||
|
||||
def get_dynamic_type(self):
|
||||
# Build up the dynamic type for the particular variant of this PolyValue instance. This is
|
||||
@ -1026,9 +1104,12 @@ class PolyValuePrinter(object):
|
||||
return "Unknown PolyValue tag: {}, did you add a new one?".format(self.tag)
|
||||
# GDB automatically formats types with children, remove the extra characters to get the
|
||||
# output that we want.
|
||||
return str(self.cast_control_block(dynamic_type)).replace(" = ", "").replace("{",
|
||||
"").replace(
|
||||
"}", "")
|
||||
return (
|
||||
str(self.cast_control_block(dynamic_type))
|
||||
.replace(" = ", "")
|
||||
.replace("{", "")
|
||||
.replace("}", "")
|
||||
)
|
||||
|
||||
|
||||
class AtomPrinter(object):
|
||||
@ -1076,9 +1157,12 @@ class DisjunctionPrinter(ConjunctionPrinter):
|
||||
|
||||
|
||||
def bool_expr_type(T):
|
||||
return (f"{OPTIMIZER_NS}::algebra::PolyValue<" + f"{OPTIMIZER_NS}::BoolExpr<{T}>::Atom, " +
|
||||
f"{OPTIMIZER_NS}::BoolExpr<{T}>::Conjunction, " +
|
||||
f"{OPTIMIZER_NS}::BoolExpr<{T}>::Disjunction>")
|
||||
return (
|
||||
f"{OPTIMIZER_NS}::algebra::PolyValue<"
|
||||
+ f"{OPTIMIZER_NS}::BoolExpr<{T}>::Atom, "
|
||||
+ f"{OPTIMIZER_NS}::BoolExpr<{T}>::Conjunction, "
|
||||
+ f"{OPTIMIZER_NS}::BoolExpr<{T}>::Disjunction>"
|
||||
)
|
||||
|
||||
|
||||
def register_optimizer_printers(pp):
|
||||
@ -1086,8 +1170,12 @@ def register_optimizer_printers(pp):
|
||||
|
||||
# IntervalRequirement printer.
|
||||
pp.add("Interval", f"{OPTIMIZER_NS}::IntervalRequirement", False, IntervalPrinter)
|
||||
pp.add("CompoundInterval", f"{OPTIMIZER_NS}::CompoundIntervalRequirement", False,
|
||||
CompoundIntervalPrinter)
|
||||
pp.add(
|
||||
"CompoundInterval",
|
||||
f"{OPTIMIZER_NS}::CompoundIntervalRequirement",
|
||||
False,
|
||||
CompoundIntervalPrinter,
|
||||
)
|
||||
|
||||
# IntervalReqExpr::Node printer.
|
||||
pp.add(
|
||||
@ -1111,12 +1199,20 @@ def register_optimizer_printers(pp):
|
||||
pp.add("Memo", f"{OPTIMIZER_NS}::cascades::Memo", False, MemoPrinter)
|
||||
|
||||
# ResidualRequirement printer.
|
||||
pp.add("ResidualRequirement", f"{OPTIMIZER_NS}::ResidualRequirement", False,
|
||||
ResidualRequirementPrinter)
|
||||
pp.add(
|
||||
"ResidualRequirement",
|
||||
f"{OPTIMIZER_NS}::ResidualRequirement",
|
||||
False,
|
||||
ResidualRequirementPrinter,
|
||||
)
|
||||
|
||||
# CandidateIndexEntry printer.
|
||||
pp.add("CandidateIndexEntry", f"{OPTIMIZER_NS}::CandidateIndexEntry", False,
|
||||
CandidateIndexEntryPrinter)
|
||||
pp.add(
|
||||
"CandidateIndexEntry",
|
||||
f"{OPTIMIZER_NS}::CandidateIndexEntry",
|
||||
False,
|
||||
CandidateIndexEntryPrinter,
|
||||
)
|
||||
|
||||
# BoolExpr<ResidualRequirement> is handled by the PolyValue printer, but still need to add
|
||||
# printers for each of the possible bool expr types.
|
||||
@ -1124,14 +1220,26 @@ def register_optimizer_printers(pp):
|
||||
bool_exprs = ["ResidualRequirement"]
|
||||
for bool_type in bool_expr_types:
|
||||
for expr in bool_exprs:
|
||||
pp.add(bool_type, f"{OPTIMIZER_NS}::BoolExpr<{OPTIMIZER_NS}::{expr}>::{bool_type}",
|
||||
False, getattr(sys.modules[__name__], bool_type + "Printer"))
|
||||
pp.add(
|
||||
bool_type,
|
||||
f"{OPTIMIZER_NS}::BoolExpr<{OPTIMIZER_NS}::{expr}>::{bool_type}",
|
||||
False,
|
||||
getattr(sys.modules[__name__], bool_type + "Printer"),
|
||||
)
|
||||
|
||||
# Utility types within the optimizer.
|
||||
pp.add("StrongStringAlias", f"{OPTIMIZER_NS}::StrongStringAlias", True,
|
||||
StrongStringAliasPrinter)
|
||||
pp.add("FieldProjectionMap", f"{OPTIMIZER_NS}::FieldProjectionMap", False,
|
||||
FieldProjectionMapPrinter)
|
||||
pp.add(
|
||||
"StrongStringAlias",
|
||||
f"{OPTIMIZER_NS}::StrongStringAlias",
|
||||
True,
|
||||
StrongStringAliasPrinter,
|
||||
)
|
||||
pp.add(
|
||||
"FieldProjectionMap",
|
||||
f"{OPTIMIZER_NS}::FieldProjectionMap",
|
||||
False,
|
||||
FieldProjectionMapPrinter,
|
||||
)
|
||||
|
||||
pp.add("ScanParams", f"{OPTIMIZER_NS}::ScanParams", False, ScanParamsPrinter)
|
||||
|
||||
@ -1202,9 +1310,13 @@ def register_optimizer_printers(pp):
|
||||
"ExpressionBinder",
|
||||
]
|
||||
for abt_type in abt_type_set:
|
||||
pp.add(abt_type, f"{OPTIMIZER_NS}::{abt_type}", False,
|
||||
getattr(sys.modules[__name__], abt_type + "Printer"))
|
||||
pp.add(
|
||||
abt_type,
|
||||
f"{OPTIMIZER_NS}::{abt_type}",
|
||||
False,
|
||||
getattr(sys.modules[__name__], abt_type + "Printer"),
|
||||
)
|
||||
|
||||
# Add the generic PolyValue printer which determines the exact type at runtime and attempts to
|
||||
# invoke the printer for that type.
|
||||
pp.add('PolyValue', OPTIMIZER_NS + "::algebra::PolyValue", True, PolyValuePrinter)
|
||||
pp.add("PolyValue", OPTIMIZER_NS + "::algebra::PolyValue", True, PolyValuePrinter)
|
||||
|
||||
@ -7,12 +7,13 @@ import gdb
|
||||
|
||||
# Pattern to match output of 'info files'
|
||||
PATTERN_ELF_SECTIONS = re.compile(
|
||||
r'(?P<begin>[0x0-9a-fA-F]+)\s-\s(?P<end>[0x0-9a-fA-F]+)\s\bis\b\s(?P<section>\.[a-z]+$)')
|
||||
r"(?P<begin>[0x0-9a-fA-F]+)\s-\s(?P<end>[0x0-9a-fA-F]+)\s\bis\b\s(?P<section>\.[a-z]+$)"
|
||||
)
|
||||
|
||||
|
||||
def parse_sections():
|
||||
"""Find addresses for .text, .data, and .bss sections."""
|
||||
file_info = gdb.execute('info files', to_string=True)
|
||||
file_info = gdb.execute("info files", to_string=True)
|
||||
section_map = {}
|
||||
for line in file_info.splitlines():
|
||||
line = line.strip()
|
||||
@ -20,10 +21,10 @@ def parse_sections():
|
||||
if match is None:
|
||||
continue
|
||||
|
||||
section = match.group('section')
|
||||
if section not in ('.text', '.data', '.bss'):
|
||||
section = match.group("section")
|
||||
if section not in (".text", ".data", ".bss"):
|
||||
continue
|
||||
begin = match.group('begin')
|
||||
begin = match.group("begin")
|
||||
section_map[section] = begin
|
||||
|
||||
return section_map
|
||||
@ -31,8 +32,9 @@ def parse_sections():
|
||||
|
||||
def load_sym_file_at_addrs(dbg_file, smap):
|
||||
"""Invoke add-symbol-file with addresses."""
|
||||
cmd = 'add-symbol-file {} {} -s .data {} -s .bss {}'.format(dbg_file, smap['.text'],
|
||||
smap['.data'], smap['.bss'])
|
||||
cmd = "add-symbol-file {} {} -s .data {} -s .bss {}".format(
|
||||
dbg_file, smap[".text"], smap[".data"], smap[".bss"]
|
||||
)
|
||||
gdb.execute(cmd, to_string=True)
|
||||
|
||||
|
||||
@ -41,18 +43,20 @@ class LoadDebugFile(gdb.Command):
|
||||
|
||||
def __init__(self):
|
||||
"""GDB Command API init."""
|
||||
super(LoadDebugFile, self).__init__('load-debug-symbols', gdb.COMPLETE_EXPRESSION)
|
||||
super(LoadDebugFile, self).__init__(
|
||||
"load-debug-symbols", gdb.COMPLETE_EXPRESSION
|
||||
)
|
||||
|
||||
def invoke(self, args, from_tty):
|
||||
"""GDB Command API invoke."""
|
||||
arglist = args.split()
|
||||
if len(arglist) != 1:
|
||||
print('Usage: load-debug-symbols <file_path>')
|
||||
print("Usage: load-debug-symbols <file_path>")
|
||||
return
|
||||
|
||||
dbg_file = arglist[0]
|
||||
if not os.path.exists(dbg_file):
|
||||
print('{} is not a valid file path'.format(dbg_file))
|
||||
print("{} is not a valid file path".format(dbg_file))
|
||||
return
|
||||
|
||||
try:
|
||||
@ -65,13 +69,13 @@ class LoadDebugFile(gdb.Command):
|
||||
LoadDebugFile()
|
||||
|
||||
PATTERN_ELF_SOLIB_SECTIONS = re.compile(
|
||||
r'(?P<begin>[0x0-9a-fA-F]+)\s-\s(?P<end>[0x0-9a-fA-F]+)\s\bis\b\s(?P<section>\.[a-z]+)\s\bin\b\s(?P<file>.*$)'
|
||||
r"(?P<begin>[0x0-9a-fA-F]+)\s-\s(?P<end>[0x0-9a-fA-F]+)\s\bis\b\s(?P<section>\.[a-z]+)\s\bin\b\s(?P<file>.*$)"
|
||||
)
|
||||
|
||||
|
||||
def parse_solib_sections():
|
||||
"""Find addresses for .text, .data, and .bss sections."""
|
||||
file_info = gdb.execute('info files', to_string=True)
|
||||
file_info = gdb.execute("info files", to_string=True)
|
||||
section_map = {}
|
||||
for line in file_info.splitlines():
|
||||
line = line.strip()
|
||||
@ -79,15 +83,18 @@ def parse_solib_sections():
|
||||
if match is None:
|
||||
continue
|
||||
|
||||
section = match.group('section')
|
||||
if section not in ('.text', '.data', '.bss'):
|
||||
section = match.group("section")
|
||||
if section not in (".text", ".data", ".bss"):
|
||||
continue
|
||||
begin = match.group('begin')
|
||||
begin = match.group("begin")
|
||||
# TODO duplicate fnames?
|
||||
fname = os.path.basename(match.group('file'))
|
||||
fname = os.path.basename(match.group("file"))
|
||||
|
||||
if fname.startswith("system-supplied DSO") or match.group('file').startswith(
|
||||
"/lib") or match.group('file').startswith("/usr/lib"):
|
||||
if (
|
||||
fname.startswith("system-supplied DSO")
|
||||
or match.group("file").startswith("/lib")
|
||||
or match.group("file").startswith("/usr/lib")
|
||||
):
|
||||
continue
|
||||
fname = f"{fname}.debug"
|
||||
section_map.setdefault(fname, {})
|
||||
@ -111,7 +118,9 @@ def find_dwarf_files(path):
|
||||
return out
|
||||
|
||||
|
||||
SOLIB_SEARCH_PATH_PREFIX = "The search path for loading non-absolute shared library symbol files is "
|
||||
SOLIB_SEARCH_PATH_PREFIX = (
|
||||
"The search path for loading non-absolute shared library symbol files is "
|
||||
)
|
||||
|
||||
|
||||
def extend_solib_search_path(new_path: str):
|
||||
@ -119,7 +128,7 @@ def extend_solib_search_path(new_path: str):
|
||||
solib_search_path = gdb.execute("show solib-search-path", to_string=True)
|
||||
# remove the prefix and suffix (which is a period and space) from the
|
||||
# search path
|
||||
solib_search_path = solib_search_path[len(SOLIB_SEARCH_PATH_PREFIX):-2]
|
||||
solib_search_path = solib_search_path[len(SOLIB_SEARCH_PATH_PREFIX) : -2]
|
||||
solib_search_path = f"{new_path}:{solib_search_path}"
|
||||
if solib_search_path.endswith(":"):
|
||||
solib_search_path = solib_search_path[:-1]
|
||||
@ -127,7 +136,9 @@ def extend_solib_search_path(new_path: str):
|
||||
gdb.execute(f"set solib-search-path {solib_search_path}", to_string=True)
|
||||
|
||||
|
||||
DEBUG_FILE_DIRECTORY_PREFIX = 'The directory where separate debug symbols are searched for is "'
|
||||
DEBUG_FILE_DIRECTORY_PREFIX = (
|
||||
'The directory where separate debug symbols are searched for is "'
|
||||
)
|
||||
|
||||
|
||||
def extend_debug_file_directory(new_path: str):
|
||||
@ -135,7 +146,7 @@ def extend_debug_file_directory(new_path: str):
|
||||
debug_file_directory = gdb.execute("show debug-file-directory", to_string=True)
|
||||
# remove the prefix and the three-character suffix (the closing quote, the
|
||||
# period and the trailing newline) from the search path
|
||||
debug_file_directory = debug_file_directory[len(DEBUG_FILE_DIRECTORY_PREFIX):-3]
|
||||
debug_file_directory = debug_file_directory[len(DEBUG_FILE_DIRECTORY_PREFIX) : -3]
|
||||
debug_file_directory = f"{new_path}:{debug_file_directory}"
|
||||
if debug_file_directory.endswith(":"):
|
||||
debug_file_directory = debug_file_directory[:-1]
|
||||
@ -152,7 +163,7 @@ class LoadDistTest(gdb.Command):
|
||||
|
||||
def __init__(self):
|
||||
"""GDB Command API init."""
|
||||
super(LoadDistTest, self).__init__('load-dist-test', gdb.COMPLETE_EXPRESSION)
|
||||
super(LoadDistTest, self).__init__("load-dist-test", gdb.COMPLETE_EXPRESSION)
|
||||
|
||||
try:
|
||||
# test if we're running udb
|
||||
@ -166,11 +177,11 @@ class LoadDistTest(gdb.Command):
|
||||
"""Fetch the name of the binary gdb is attached to."""
|
||||
main_binary_name = gdb.objfiles()[0].filename
|
||||
main_binary_name = os.path.splitext(os.path.basename(main_binary_name))[0]
|
||||
if main_binary_name.endswith('mongod'):
|
||||
if main_binary_name.endswith("mongod"):
|
||||
return "mongod"
|
||||
if main_binary_name.endswith('mongo'):
|
||||
if main_binary_name.endswith("mongo"):
|
||||
return "mongo"
|
||||
if main_binary_name.endswith('mongos'):
|
||||
if main_binary_name.endswith("mongos"):
|
||||
return "mongos"
|
||||
|
||||
return None
|
||||
@ -183,7 +194,7 @@ class LoadDistTest(gdb.Command):
|
||||
print(f"No path provided, assuming '{arglist[0]}'")
|
||||
|
||||
if len(arglist) != 1:
|
||||
print('Usage: load-dist-test <path/to/dist-test>')
|
||||
print("Usage: load-dist-test <path/to/dist-test>")
|
||||
return
|
||||
|
||||
dist_test = arglist[0]
|
||||
|
||||
@ -11,7 +11,7 @@ if not gdb:
|
||||
from buildscripts.gdb.mongo import lookup_type
|
||||
|
||||
DEBUGGING = False
|
||||
'''
|
||||
"""
|
||||
Public API to be called by users. The input `ident` is a string of the form:
|
||||
'collection-2--4547167393143767234'.
|
||||
From within gdb type:
|
||||
@ -25,18 +25,20 @@ Some behaviors/limitations:
|
||||
more raw output.
|
||||
* Any `file:*.wt` can be output, e.g: `_mdb_catalog` or `WiredTiger`. Though the output may be less
|
||||
supported/of lower quality.
|
||||
'''
|
||||
"""
|
||||
|
||||
|
||||
def dump_pages_for_table(ident):
|
||||
conn_impl_type = lookup_type("WT_CONNECTION_IMPL")
|
||||
if not conn_impl_type:
|
||||
print('WT_CONNECTION_IMPL type not found. Try invoking this function from a different \
|
||||
thread and frame.')
|
||||
print(
|
||||
"WT_CONNECTION_IMPL type not found. Try invoking this function from a different \
|
||||
thread and frame."
|
||||
)
|
||||
return
|
||||
|
||||
conn_impl_ptr_type = conn_impl_type.pointer()
|
||||
dbg('impl', conn_impl_ptr_type)
|
||||
dbg("impl", conn_impl_ptr_type)
|
||||
|
||||
conn_ptr = None
|
||||
try:
|
||||
@ -46,18 +48,19 @@ thread and frame.')
|
||||
|
||||
if not conn_ptr or not conn_ptr.address:
|
||||
print(
|
||||
'Failed to find a suitable `WT_SESSION session` object to extract a connection object \
|
||||
"Failed to find a suitable `WT_SESSION session` object to extract a connection object \
|
||||
from. Try finding an eviction thread and frame, e.g: `__wt_evict_thread_run`. If the session is \
|
||||
optimized out, try going up stack frames until the variable is in a local scope rather than a \
|
||||
function input.')
|
||||
function input."
|
||||
)
|
||||
return
|
||||
|
||||
conn = conn_ptr.reinterpret_cast(conn_impl_ptr_type).dereference()
|
||||
dbg('conn', conn)
|
||||
data_handle, all_dhs = get_data_handle(conn, 'file:{}.wt'.format(ident))
|
||||
dbg("conn", conn)
|
||||
data_handle, all_dhs = get_data_handle(conn, "file:{}.wt".format(ident))
|
||||
if not data_handle:
|
||||
print('Data handle not found for ident. Ident: `{}`'.format(ident))
|
||||
print('All known data handles:')
|
||||
print("Data handle not found for ident. Ident: `{}`".format(ident))
|
||||
print("All known data handles:")
|
||||
pprint(all_dhs)
|
||||
return
|
||||
|
||||
@ -69,153 +72,161 @@ def dbg(ident, var):
|
||||
if not DEBUGGING:
|
||||
return
|
||||
|
||||
print('----------')
|
||||
print("----------")
|
||||
if type(var) == gdb.Value:
|
||||
print('{}: ({}*){}'.format(ident, var.type, var.address))
|
||||
print("{}: ({}*){}".format(ident, var.type, var.address))
|
||||
else:
|
||||
print(ident)
|
||||
print(' ' + str(type(var)))
|
||||
print(" " + str(type(var)))
|
||||
methods = dir(var)
|
||||
out = [name for name in methods if not name.startswith("__")]
|
||||
for item in out:
|
||||
print(' ' + item)
|
||||
print(" " + item)
|
||||
|
||||
if type(var) == gdb.Value:
|
||||
print('\n Fields:')
|
||||
print('\t' + '\n\t'.join(str(var).split('\n')))
|
||||
print("\n Fields:")
|
||||
print("\t" + "\n\t".join(str(var).split("\n")))
|
||||
|
||||
|
||||
def walk_wt_list(lst):
|
||||
ret = []
|
||||
node = lst['tqh_first']
|
||||
dbg('node', node)
|
||||
node = lst["tqh_first"]
|
||||
dbg("node", node)
|
||||
while True:
|
||||
if not node:
|
||||
break
|
||||
ret.append(node.dereference())
|
||||
node = node['q']['tqe_next']
|
||||
node = node["q"]["tqe_next"]
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def get_data_handle(conn, handle_name):
|
||||
dbg('datahandles', conn['dhqh'])
|
||||
dbg("datahandles", conn["dhqh"])
|
||||
ret = None
|
||||
all_file_dhs = []
|
||||
for handle in walk_wt_list(conn['dhqh']):
|
||||
if handle['name'].string().startswith('file:'):
|
||||
all_file_dhs.append(handle['name'].string()[5:-3])
|
||||
if handle['name'].string() == handle_name:
|
||||
for handle in walk_wt_list(conn["dhqh"]):
|
||||
if handle["name"].string().startswith("file:"):
|
||||
all_file_dhs.append(handle["name"].string()[5:-3])
|
||||
if handle["name"].string() == handle_name:
|
||||
ret = handle
|
||||
|
||||
return ret, all_file_dhs
|
||||
|
||||
|
||||
def get_btree_handle(dhandle):
|
||||
btree = lookup_type('WT_BTREE').pointer()
|
||||
return dhandle['handle'].reinterpret_cast(btree).dereference()
|
||||
btree = lookup_type("WT_BTREE").pointer()
|
||||
return dhandle["handle"].reinterpret_cast(btree).dereference()
|
||||
|
||||
|
||||
def dump_update_chain(update_chain):
|
||||
while True:
|
||||
if not update_chain:
|
||||
print(' λ (End of update chain)')
|
||||
print(" λ (End of update chain)")
|
||||
break
|
||||
dbg('update', update_chain)
|
||||
dbg("update", update_chain)
|
||||
wt_val = update_chain.dereference()
|
||||
obj = None
|
||||
dbg('wt_val', wt_val)
|
||||
val_bytes = gdb.selected_inferior().read_memory(wt_val['data'], wt_val['size'])
|
||||
can_bson = wt_val['type'] == 3
|
||||
dbg("wt_val", wt_val)
|
||||
val_bytes = gdb.selected_inferior().read_memory(wt_val["data"], wt_val["size"])
|
||||
can_bson = wt_val["type"] == 3
|
||||
if can_bson:
|
||||
try:
|
||||
obj = bson.decode_all(val_bytes)[0]
|
||||
except:
|
||||
pass
|
||||
print(' ' + '\n '.join(str(wt_val).split('\n')) + " " + str(obj) + " =>")
|
||||
print(" " + "\n ".join(str(wt_val).split("\n")) + " " + str(obj) + " =>")
|
||||
|
||||
update_chain = update_chain['next']
|
||||
update_chain = update_chain["next"]
|
||||
|
||||
|
||||
def dump_insert_list(wt_insert):
|
||||
key_struct = wt_insert['u']['key']
|
||||
key = gdb.selected_inferior().read_memory(
|
||||
int(wt_insert.address) + key_struct['offset'], key_struct['size']).tobytes()
|
||||
print('Key: ' + str(key))
|
||||
print('Value:')
|
||||
update_chain = wt_insert['upd']
|
||||
key_struct = wt_insert["u"]["key"]
|
||||
key = (
|
||||
gdb.selected_inferior()
|
||||
.read_memory(int(wt_insert.address) + key_struct["offset"], key_struct["size"])
|
||||
.tobytes()
|
||||
)
|
||||
print("Key: " + str(key))
|
||||
print("Value:")
|
||||
update_chain = wt_insert["upd"]
|
||||
dump_update_chain(update_chain)
|
||||
|
||||
|
||||
def dump_skip_list(wt_insert_head):
|
||||
if not wt_insert_head['head'].address:
|
||||
if not wt_insert_head["head"].address:
|
||||
return
|
||||
wt_insert = wt_insert_head['head'][0]
|
||||
wt_insert = wt_insert_head["head"][0]
|
||||
idx = 0
|
||||
while True:
|
||||
if not wt_insert:
|
||||
break
|
||||
dump_insert_list(wt_insert.dereference())
|
||||
dbg('insert' + str(idx), wt_insert.dereference())
|
||||
dbg("insert" + str(idx), wt_insert.dereference())
|
||||
idx += 1
|
||||
wt_insert = wt_insert['next'][0]
|
||||
wt_insert = wt_insert["next"][0]
|
||||
|
||||
|
||||
def dump_modified(leaf_page):
|
||||
print("Modify:")
|
||||
if not leaf_page['modify']:
|
||||
if not leaf_page["modify"]:
|
||||
print("No modifies")
|
||||
return
|
||||
|
||||
leaf_modify = leaf_page['modify'].dereference()
|
||||
dbg('modify', leaf_modify)
|
||||
row_leaf_insert = leaf_modify['u2']['row_leaf']['insert']
|
||||
dbg('row store', row_leaf_insert)
|
||||
leaf_modify = leaf_page["modify"].dereference()
|
||||
dbg("modify", leaf_modify)
|
||||
row_leaf_insert = leaf_modify["u2"]["row_leaf"]["insert"]
|
||||
dbg("row store", row_leaf_insert)
|
||||
if not row_leaf_insert:
|
||||
print("No insert list")
|
||||
else:
|
||||
print("Insert list:")
|
||||
dump_skip_list(row_leaf_insert.dereference().dereference())
|
||||
|
||||
row_leaf_update = leaf_modify['u2']['row_leaf']['update']
|
||||
row_leaf_update = leaf_modify["u2"]["row_leaf"]["update"]
|
||||
if not row_leaf_update:
|
||||
print("No update list")
|
||||
else:
|
||||
print("Update list:")
|
||||
leaf_num_entries = int(leaf_page['entries'])
|
||||
leaf_num_entries = int(leaf_page["entries"])
|
||||
for i in range(0, leaf_num_entries):
|
||||
dump_update_chain(row_leaf_update[i])
|
||||
|
||||
|
||||
def dump_disk(leaf_page):
|
||||
dbg('in-memory page:', leaf_page)
|
||||
dsk = leaf_page['dsk'].dereference()
|
||||
dbg("in-memory page:", leaf_page)
|
||||
dsk = leaf_page["dsk"].dereference()
|
||||
if int(dsk.address) == 0:
|
||||
print("No page loaded from disk.")
|
||||
return
|
||||
dbg('on-disk page:', dsk)
|
||||
dbg("on-disk page:", dsk)
|
||||
wt_page_header_size = 28
|
||||
wt_block_header_size = 12
|
||||
page_bytes = gdb.selected_inferior().read_memory(
|
||||
int(dsk.address) + wt_page_header_size + wt_block_header_size,
|
||||
int(dsk['mem_size'])).tobytes()
|
||||
page_bytes = (
|
||||
gdb.selected_inferior()
|
||||
.read_memory(
|
||||
int(dsk.address) + wt_page_header_size + wt_block_header_size,
|
||||
int(dsk["mem_size"]),
|
||||
)
|
||||
.tobytes()
|
||||
)
|
||||
print("Dsk:\n" + str(page_bytes))
|
||||
|
||||
|
||||
def dump_handle(dhandle):
|
||||
print("Dumping: " + dhandle['name'].string())
|
||||
print("Dumping: " + dhandle["name"].string())
|
||||
btree = get_btree_handle(dhandle)
|
||||
root = btree['root']
|
||||
root_page = root['page'].dereference()
|
||||
dbg('btree', btree)
|
||||
dbg('root', btree['root'])
|
||||
dbg('root page', root_page)
|
||||
rpindex = root_page['u']['intl']['__index'].dereference()
|
||||
leaf_num_entries = int(rpindex['entries'])
|
||||
root = btree["root"]
|
||||
root_page = root["page"].dereference()
|
||||
dbg("btree", btree)
|
||||
dbg("root", btree["root"])
|
||||
dbg("root page", root_page)
|
||||
rpindex = root_page["u"]["intl"]["__index"].dereference()
|
||||
leaf_num_entries = int(rpindex["entries"])
|
||||
for idx in range(0, leaf_num_entries):
|
||||
dbg('rpindex', rpindex)
|
||||
dbg('rp-pre-index', rpindex['index'].dereference().dereference())
|
||||
leaf_page = rpindex['index'][idx].dereference()['page'].dereference()
|
||||
dbg('leaf', leaf_page)
|
||||
dbg("rpindex", rpindex)
|
||||
dbg("rp-pre-index", rpindex["index"].dereference().dereference())
|
||||
leaf_page = rpindex["index"][idx].dereference()["page"].dereference()
|
||||
dbg("leaf", leaf_page)
|
||||
dump_disk(leaf_page)
|
||||
dump_modified(leaf_page)
|
||||
|
||||
@ -46,7 +46,7 @@ def generate_version_expansions():
|
||||
with open(VERSION_JSON, "r") as fh:
|
||||
data = fh.read()
|
||||
version_data = json.loads(data)
|
||||
version_line = version_data['version']
|
||||
version_line = version_data["version"]
|
||||
version_parts = match_verstr(version_line)
|
||||
if not version_parts:
|
||||
raise ValueError("Unable to parse version.json")
|
||||
@ -58,7 +58,9 @@ def generate_version_expansions():
|
||||
version_line = version_line.lstrip("r")
|
||||
version_parts = match_verstr(version_line)
|
||||
if not version_parts:
|
||||
raise ValueError("Unable to parse version from stdin and no version.json provided")
|
||||
raise ValueError(
|
||||
"Unable to parse version from stdin and no version.json provided"
|
||||
)
|
||||
|
||||
if version_parts[0]:
|
||||
expansions["suffix"] = "v8.0-latest"
|
||||
@ -83,7 +85,7 @@ def match_verstr(verstr):
|
||||
r2.3.4-git234, r2.3.4-rc0-234-githash If the version is invalid (i.e.
|
||||
doesn't start with "2.3.4" or "2.3.4-rc0", this will return False.
|
||||
"""
|
||||
res = re.match(r'^r?(?:\d+\.\d+\.\d+(?:-rc\d+|-alpha\d+)?)(-.*)?', verstr)
|
||||
res = re.match(r"^r?(?:\d+\.\d+\.\d+(?:-rc\d+|-alpha\d+)?)(-.*)?", verstr)
|
||||
if not res:
|
||||
return False
|
||||
return res.groups()
|
||||
|
||||
@ -71,7 +71,7 @@ def replace_variables(pattern: str, variables: dict) -> str:
|
||||
|
||||
def get_path_name_regex(pattern: str) -> str:
|
||||
"""Return the regex pattern for output names."""
|
||||
return '[0-9a-f]'.join([re.escape(part) for part in pattern.split('%')])
|
||||
return "[0-9a-f]".join([re.escape(part) for part in pattern.split("%")])
|
||||
|
||||
|
||||
# For compatibility with version<3.8 that does not support shutil.copytree with dirs_exist_ok=True
|
||||
@ -81,7 +81,9 @@ def copytree_dirs_exist_ok_compatibility(src, dest):
|
||||
os.makedirs(dest)
|
||||
files = os.listdir(src)
|
||||
for file in files:
|
||||
copytree_dirs_exist_ok_compatibility(os.path.join(src, file), os.path.join(dest, file))
|
||||
copytree_dirs_exist_ok_compatibility(
|
||||
os.path.join(src, file), os.path.join(dest, file)
|
||||
)
|
||||
else:
|
||||
shutil.copyfile(src, dest)
|
||||
|
||||
@ -94,10 +96,13 @@ def copytree_dirs_exist_ok(src, dest):
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option('-n', '--dry-run', is_flag=True)
|
||||
@click.option('-v', '--verbose', is_flag=True)
|
||||
@click.option('--config', envvar='GOLDEN_TEST_CONFIG_PATH',
|
||||
help='Config file path. Also GOLDEN_TEST_CONFIG_PATH environment variable.')
|
||||
@click.option("-n", "--dry-run", is_flag=True)
|
||||
@click.option("-v", "--verbose", is_flag=True)
|
||||
@click.option(
|
||||
"--config",
|
||||
envvar="GOLDEN_TEST_CONFIG_PATH",
|
||||
help="Config file path. Also GOLDEN_TEST_CONFIG_PATH environment variable.",
|
||||
)
|
||||
@click.pass_context
|
||||
def cli(ctx, dry_run, verbose, config):
|
||||
"""Manage test results from golden data test framework.
|
||||
@ -150,31 +155,40 @@ class GoldenTestApp(object):
|
||||
def get_git_root(self):
|
||||
"""Return the root for git repo."""
|
||||
self.vprint("Querying git repo root")
|
||||
repo_root = check_output("git rev-parse --show-toplevel", shell=True, text=True).strip()
|
||||
repo_root = check_output(
|
||||
"git rev-parse --show-toplevel", shell=True, text=True
|
||||
).strip()
|
||||
self.vprint(f"Found git repo root: '{repo_root}'")
|
||||
return repo_root
|
||||
|
||||
def load_config(self, config_path):
|
||||
"""Load configuration file."""
|
||||
if config_path is None:
|
||||
raise AppError((
|
||||
"Can't load config. GOLDEN_TEST_CONFIG_PATH envrionment variable is not set. Golden test CLI must be configured before use.\n"
|
||||
"To configure it, follow the instructions in https://github.com/mongodb/mongo/blob/master/docs/golden_data_test_framework.md#how-to-diff-and-accept-new-test-outputs-on-a-workstation\n"
|
||||
"Note: After setup you may need to rerun the tests for this utility to find them."))
|
||||
raise AppError(
|
||||
(
|
||||
"Can't load config. GOLDEN_TEST_CONFIG_PATH envrionment variable is not set. Golden test CLI must be configured before use.\n"
|
||||
"To configure it, follow the instructions in https://github.com/mongodb/mongo/blob/master/docs/golden_data_test_framework.md#how-to-diff-and-accept-new-test-outputs-on-a-workstation\n"
|
||||
"Note: After setup you may need to rerun the tests for this utility to find them."
|
||||
)
|
||||
)
|
||||
|
||||
self.vprint(f"Loading config from path: '{config_path}'")
|
||||
config = GoldenTestConfig.from_yaml_file(config_path)
|
||||
|
||||
if config.outputRootPattern is None:
|
||||
raise AppError("Invalid config. outputRootPattern config parameter is not set")
|
||||
raise AppError(
|
||||
"Invalid config. outputRootPattern config parameter is not set"
|
||||
)
|
||||
|
||||
return config
|
||||
|
||||
def get_output_path(self, output_name):
|
||||
"""Return the path for given output name."""
|
||||
if not re.match(self.output_name_regex, output_name):
|
||||
raise AppError(f"Invalid name: '{output_name}'. " +
|
||||
f"Does not match configured pattern: {self.output_name_pattern}")
|
||||
raise AppError(
|
||||
f"Invalid name: '{output_name}'. "
|
||||
+ f"Does not match configured pattern: {self.output_name_pattern}"
|
||||
)
|
||||
output_path = os.path.join(self.output_parent_path, output_name)
|
||||
if not os.path.isdir(output_path):
|
||||
raise AppError(f"No such directory: '{output_path}'")
|
||||
@ -182,13 +196,17 @@ class GoldenTestApp(object):
|
||||
|
||||
def list_outputs(self):
|
||||
"""Return names of all available outputs."""
|
||||
self.vprint(f"Listing outputs in path: '{self.output_parent_path}' " +
|
||||
f"matching '{self.output_name_pattern}'")
|
||||
self.vprint(
|
||||
f"Listing outputs in path: '{self.output_parent_path}' "
|
||||
+ f"matching '{self.output_name_pattern}'"
|
||||
)
|
||||
|
||||
if not os.path.isdir(self.output_parent_path):
|
||||
return []
|
||||
return [
|
||||
o for o in os.listdir(self.output_parent_path) if re.match(self.output_name_regex, o)
|
||||
o
|
||||
for o in os.listdir(self.output_parent_path)
|
||||
if re.match(self.output_name_regex, o)
|
||||
and os.path.isdir(os.path.join(self.output_parent_path, o))
|
||||
]
|
||||
|
||||
@ -206,8 +224,10 @@ class GoldenTestApp(object):
|
||||
if latest_name is None:
|
||||
raise AppError("No outputs found")
|
||||
|
||||
self.vprint(f"Found output with latest creation time: {latest_name} " +
|
||||
f"created at {latest_ctime}")
|
||||
self.vprint(
|
||||
f"Found output with latest creation time: {latest_name} "
|
||||
+ f"created at {latest_ctime}"
|
||||
)
|
||||
|
||||
return latest_name
|
||||
|
||||
@ -215,20 +235,22 @@ class GoldenTestApp(object):
|
||||
"""Return actual and expected paths for given output name."""
|
||||
output_path = self.get_output_path(output_name)
|
||||
return OutputPaths(
|
||||
actual=os.path.join(output_path, "actual"), expected=os.path.join(
|
||||
output_path, "expected"))
|
||||
actual=os.path.join(output_path, "actual"),
|
||||
expected=os.path.join(output_path, "expected"),
|
||||
)
|
||||
|
||||
def setup_linux(self):
|
||||
# Create config file
|
||||
config_path = os.path.join(os.path.expanduser('~'), ".golden_test_config.yml")
|
||||
config_path = os.path.join(os.path.expanduser("~"), ".golden_test_config.yml")
|
||||
if not os.path.isfile(config_path):
|
||||
print(f"Creating {config_path}")
|
||||
config_contents = (
|
||||
r"""outputRootPattern: '/var/tmp/test_output/out-%%%%-%%%%-%%%%-%%%%'"""
|
||||
"\n"
|
||||
r"""diffCmd: 'git diff --no-index "{{expected}}" "{{actual}}"'"""
|
||||
"\n")
|
||||
with open(config_path, 'w') as file:
|
||||
"\n"
|
||||
)
|
||||
with open(config_path, "w") as file:
|
||||
file.write(config_contents)
|
||||
else:
|
||||
print(f"Skipping creating {config_path}, file exists.")
|
||||
@ -237,14 +259,21 @@ class GoldenTestApp(object):
|
||||
etc_environment_path = "/etc/environment"
|
||||
env_var_defined = False
|
||||
if os.path.isfile(etc_environment_path):
|
||||
with open(etc_environment_path, 'r') as file:
|
||||
with open(etc_environment_path, "r") as file:
|
||||
for line in file.readlines():
|
||||
if line.startswith("GOLDEN_TEST_CONFIG_PATH="):
|
||||
env_var_defined = True
|
||||
if not env_var_defined:
|
||||
print(f"Adding GOLDEN_TEST_CONFIG_PATH to {etc_environment_path}")
|
||||
env_var_contents = (f"GOLDEN_TEST_CONFIG_PATH=\"{config_path}\"")
|
||||
call(["sudo", "/bin/sh", "-c", f"echo '{env_var_contents}' >> {etc_environment_path}"])
|
||||
env_var_contents = f'GOLDEN_TEST_CONFIG_PATH="{config_path}"'
|
||||
call(
|
||||
[
|
||||
"sudo",
|
||||
"/bin/sh",
|
||||
"-c",
|
||||
f"echo '{env_var_contents}' >> {etc_environment_path}",
|
||||
]
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"Skipping adding GOLDEN_TEST_CONFIG_PATH to {etc_environment_path}, variable already defined."
|
||||
@ -252,15 +281,18 @@ class GoldenTestApp(object):
|
||||
|
||||
def setup_windows(self):
|
||||
# Create config file
|
||||
config_path = os.path.join(os.path.expandvars('%LocalAppData%'), ".golden_test_config.yml")
|
||||
config_path = os.path.join(
|
||||
os.path.expandvars("%LocalAppData%"), ".golden_test_config.yml"
|
||||
)
|
||||
if not os.path.isfile(config_path):
|
||||
print(f"Creating {config_path}")
|
||||
config_contents = (
|
||||
r"outputRootPattern: 'C:\Users\Administrator\AppData\Local\Temp\test_output\out-%%%%-%%%%-%%%%-%%%%'"
|
||||
"\n"
|
||||
r"""diffCmd: 'git diff --no-index "{{expected}}" "{{actual}}"'"""
|
||||
"\n")
|
||||
with open(config_path, 'w') as file:
|
||||
"\n"
|
||||
)
|
||||
with open(config_path, "w") as file:
|
||||
file.write(config_contents)
|
||||
else:
|
||||
print(f"Skipping creating {config_path}, file exists.")
|
||||
@ -268,17 +300,21 @@ class GoldenTestApp(object):
|
||||
# Add global GOLDEN_TEST_CONFIG_PATH environment variable
|
||||
if os.environ.get("GOLDEN_TEST_CONFIG_PATH") is None:
|
||||
print("Setting GOLDEN_TEST_CONFIG_PATH global variable")
|
||||
call([
|
||||
"runas", "/profile", "/user:administrator",
|
||||
f"setx GOLDEN_TEST_CONFIG_PATH \"{config_path}\""
|
||||
])
|
||||
call(
|
||||
[
|
||||
"runas",
|
||||
"/profile",
|
||||
"/user:administrator",
|
||||
f'setx GOLDEN_TEST_CONFIG_PATH "{config_path}"',
|
||||
]
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"Skipping setting GOLDEN_TEST_CONFIG_PATH global variable, variable already defined."
|
||||
)
|
||||
|
||||
@cli.command('diff', help='Diff the expected and actual folders of the test output')
|
||||
@click.argument('output_name', required=False)
|
||||
@cli.command("diff", help="Diff the expected and actual folders of the test output")
|
||||
@click.argument("output_name", required=False)
|
||||
@click.pass_obj
|
||||
def command_diff(self, output_name):
|
||||
"""Diff the expected and actual folders of the test output."""
|
||||
@ -290,13 +326,14 @@ class GoldenTestApp(object):
|
||||
self.vprint(f"Diffing results from output '{output_name}'")
|
||||
|
||||
paths = self.get_paths(output_name)
|
||||
diff_cmd = replace_variables(self.config.diffCmd,
|
||||
{'actual': paths.actual, 'expected': paths.expected})
|
||||
diff_cmd = replace_variables(
|
||||
self.config.diffCmd, {"actual": paths.actual, "expected": paths.expected}
|
||||
)
|
||||
self.vprint(f"Running command: '{diff_cmd}'")
|
||||
self.call_shell(diff_cmd)
|
||||
|
||||
@cli.command('get-path', help='Get the root folder path of the test output.')
|
||||
@click.argument('output_name', required=False)
|
||||
@cli.command("get-path", help="Get the root folder path of the test output.")
|
||||
@click.argument("output_name", required=False)
|
||||
@click.pass_obj
|
||||
def command_get_path(self, output_name):
|
||||
"""Get the root folder path of the test output."""
|
||||
@ -308,9 +345,10 @@ class GoldenTestApp(object):
|
||||
print(self.get_output_path(output_name))
|
||||
|
||||
@cli.command(
|
||||
'accept',
|
||||
help='Accept the actual test output and copy it as new golden data to the source repo.')
|
||||
@click.argument('output_name', required=False)
|
||||
"accept",
|
||||
help="Accept the actual test output and copy it as new golden data to the source repo.",
|
||||
)
|
||||
@click.argument("output_name", required=False)
|
||||
@click.pass_obj
|
||||
def command_accept(self, output_name):
|
||||
"""Accept the actual test output and copy it as new golden data to the source repo."""
|
||||
@ -328,7 +366,7 @@ class GoldenTestApp(object):
|
||||
if not self.dry_run:
|
||||
copytree_dirs_exist_ok(paths.actual, repo_root)
|
||||
|
||||
@cli.command('clean', help='Remove all test outputs')
|
||||
@cli.command("clean", help="Remove all test outputs")
|
||||
@click.pass_obj
|
||||
def command_clean(self):
|
||||
"""Remove all test outputs."""
|
||||
@ -342,7 +380,7 @@ class GoldenTestApp(object):
|
||||
if not self.dry_run:
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
@cli.command('latest', help='Get the name of the most recent test output')
|
||||
@cli.command("latest", help="Get the name of the most recent test output")
|
||||
@click.pass_obj
|
||||
def command_latest(self):
|
||||
"""Get the name of the most recent test output."""
|
||||
@ -351,7 +389,7 @@ class GoldenTestApp(object):
|
||||
output_name = self.get_latest_output()
|
||||
print(output_name)
|
||||
|
||||
@cli.command('list', help='List all names of the available test outputs')
|
||||
@cli.command("list", help="List all names of the available test outputs")
|
||||
@click.pass_obj
|
||||
def command_list(self):
|
||||
"""List all names of the available test outputs."""
|
||||
@ -360,7 +398,7 @@ class GoldenTestApp(object):
|
||||
for output_name in self.list_outputs():
|
||||
print(output_name)
|
||||
|
||||
@cli.command('setup', help='Performs default setup based on current platform')
|
||||
@cli.command("setup", help="Performs default setup based on current platform")
|
||||
@click.pass_obj
|
||||
def command_setup(self):
|
||||
"""Performs default setup based on current platform."""
|
||||
@ -369,7 +407,9 @@ class GoldenTestApp(object):
|
||||
elif platform.platform().startswith("Windows"):
|
||||
self.setup_windows()
|
||||
else:
|
||||
raise AppError(f"Platform not supported by this setup utility: {platform.platform()}")
|
||||
raise AppError(
|
||||
f"Platform not supported by this setup utility: {platform.platform()}"
|
||||
)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@ -2,6 +2,8 @@
|
||||
"""Stub file pointing users to the new invocation."""
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Hello! It seems you've executed 'buildscripts/hang_analyzer.py'. We have recently\n"
|
||||
"repackaged the hang analyzer as a subcommand of resmoke. It can now be invoked as\n"
|
||||
"'./buildscripts/resmoke.py hang-analyzer' with all of the same arguments as before.")
|
||||
print(
|
||||
"Hello! It seems you've executed 'buildscripts/hang_analyzer.py'. We have recently\n"
|
||||
"repackaged the hang analyzer as a subcommand of resmoke. It can now be invoked as\n"
|
||||
"'./buildscripts/resmoke.py hang-analyzer' with all of the same arguments as before."
|
||||
)
|
||||
|
||||
@ -40,7 +40,7 @@ from typing import Dict, List, Set
|
||||
from pymongo import MongoClient
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.idl.lib import list_idls, parse_idl
|
||||
@ -52,7 +52,7 @@ from idl import syntax
|
||||
|
||||
# pylint: enable=wrong-import-position
|
||||
|
||||
LOGGER_NAME = 'check-idl-definitions'
|
||||
LOGGER_NAME = "check-idl-definitions"
|
||||
LOGGER = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
|
||||
@ -67,8 +67,9 @@ def is_test_or_third_party_idl(idl_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_command_definitions(api_version: str, directory: str,
|
||||
import_directories: List[str]) -> Dict[str, syntax.Command]:
|
||||
def get_command_definitions(
|
||||
api_version: str, directory: str, import_directories: List[str]
|
||||
) -> Dict[str, syntax.Command]:
|
||||
"""Get parsed IDL definitions of commands in a given API version."""
|
||||
|
||||
LOGGER.info("Searching for command definitions in %s", directory)
|
||||
@ -76,16 +77,22 @@ def get_command_definitions(api_version: str, directory: str,
|
||||
def gen():
|
||||
for idl_path in sorted(list_idls(directory)):
|
||||
if not is_test_or_third_party_idl(idl_path):
|
||||
for command in parse_idl(idl_path, import_directories).spec.symbols.commands:
|
||||
for command in parse_idl(
|
||||
idl_path, import_directories
|
||||
).spec.symbols.commands:
|
||||
if command.api_version == api_version:
|
||||
yield command.command_name, command
|
||||
|
||||
idl_commands = dict(gen())
|
||||
LOGGER.debug("Found %s IDL commands in API Version %s", len(idl_commands), api_version)
|
||||
LOGGER.debug(
|
||||
"Found %s IDL commands in API Version %s", len(idl_commands), api_version
|
||||
)
|
||||
return idl_commands
|
||||
|
||||
|
||||
def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir: str) -> Set[str]:
|
||||
def list_commands_for_api(
|
||||
api_version: str, mongod_or_mongos: str, install_dir: str
|
||||
) -> Set[str]:
|
||||
"""Get a list of commands in a given API version by calling listCommands."""
|
||||
assert mongod_or_mongos in ("mongod", "mongos")
|
||||
logging.info("Calling listCommands on %s", mongod_or_mongos)
|
||||
@ -97,28 +104,43 @@ def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir:
|
||||
logger = loggers.new_fixture_logger("MongoDFixture", 0)
|
||||
logger.parent = LOGGER
|
||||
fixture: interface.Fixture = fixturelib.make_fixture(
|
||||
"MongoDFixture", logger, 0, dbpath_prefix=dbpath.name,
|
||||
mongod_executable=mongod_executable, mongod_options={"set_parameters": {}})
|
||||
"MongoDFixture",
|
||||
logger,
|
||||
0,
|
||||
dbpath_prefix=dbpath.name,
|
||||
mongod_executable=mongod_executable,
|
||||
mongod_options={"set_parameters": {}},
|
||||
)
|
||||
else:
|
||||
logger = loggers.new_fixture_logger("ShardedClusterFixture", 0)
|
||||
logger.parent = LOGGER
|
||||
fixture = fixturelib.make_fixture(
|
||||
"ShardedClusterFixture", logger, 0, dbpath_prefix=dbpath.name,
|
||||
mongos_executable=mongos_executable, mongod_executable=mongod_executable,
|
||||
mongod_options={"set_parameters": {}})
|
||||
"ShardedClusterFixture",
|
||||
logger,
|
||||
0,
|
||||
dbpath_prefix=dbpath.name,
|
||||
mongos_executable=mongos_executable,
|
||||
mongod_executable=mongod_executable,
|
||||
mongod_options={"set_parameters": {}},
|
||||
)
|
||||
|
||||
fixture.setup()
|
||||
fixture.await_ready()
|
||||
|
||||
try:
|
||||
client = MongoClient(fixture.get_driver_connection_url()) # type: MongoClient
|
||||
reply = client.admin.command('listCommands') # type: Mapping[str, Any]
|
||||
reply = client.admin.command("listCommands") # type: Mapping[str, Any]
|
||||
commands = {
|
||||
name
|
||||
for name, info in reply['commands'].items() if api_version in info['apiVersions']
|
||||
for name, info in reply["commands"].items()
|
||||
if api_version in info["apiVersions"]
|
||||
}
|
||||
logging.info("Found %s commands in API Version %s on %s", len(commands), api_version,
|
||||
mongod_or_mongos)
|
||||
logging.info(
|
||||
"Found %s commands in API Version %s on %s",
|
||||
len(commands),
|
||||
api_version,
|
||||
mongod_or_mongos,
|
||||
)
|
||||
return commands
|
||||
finally:
|
||||
fixture.teardown()
|
||||
@ -128,7 +150,9 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str]
|
||||
"""Check that all sources have the same set of commands for a given API version."""
|
||||
LOGGER.info("Comparing %s command sets", len(command_sets))
|
||||
for name, commands in command_sets.items():
|
||||
LOGGER.info("--------- %s API Version %s commands --------------", name, api_version)
|
||||
LOGGER.info(
|
||||
"--------- %s API Version %s commands --------------", name, api_version
|
||||
)
|
||||
for command in sorted(commands):
|
||||
LOGGER.info("%s", command)
|
||||
|
||||
@ -138,13 +162,22 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str]
|
||||
for other_name, other_commands in it:
|
||||
if commands != other_commands:
|
||||
if commands - other_commands:
|
||||
LOGGER.error("%s has commands not in %s: %s", name, other_name,
|
||||
commands - other_commands)
|
||||
LOGGER.error(
|
||||
"%s has commands not in %s: %s",
|
||||
name,
|
||||
other_name,
|
||||
commands - other_commands,
|
||||
)
|
||||
if other_commands - commands:
|
||||
LOGGER.error("%s has commands not in %s: %s", other_name, name,
|
||||
other_commands - commands)
|
||||
LOGGER.error(
|
||||
"%s has commands not in %s: %s",
|
||||
other_name,
|
||||
name,
|
||||
other_commands - commands,
|
||||
)
|
||||
raise AssertionError(
|
||||
f"{name} and {other_name} have different commands in API Version {api_version}")
|
||||
f"{name} and {other_name} have different commands in API Version {api_version}"
|
||||
)
|
||||
|
||||
|
||||
def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
|
||||
@ -167,12 +200,24 @@ def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
|
||||
def main():
|
||||
"""Run the script."""
|
||||
arg_parser = argparse.ArgumentParser(description=__doc__)
|
||||
arg_parser.add_argument("--include", type=str, action="append",
|
||||
help="Directory to search for IDL import files")
|
||||
arg_parser.add_argument("--install-dir", dest="install_dir", required=True,
|
||||
help="Directory to search for MongoDB binaries")
|
||||
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
|
||||
arg_parser.add_argument("api_version", metavar="API_VERSION", help="API Version to check")
|
||||
arg_parser.add_argument(
|
||||
"--include",
|
||||
type=str,
|
||||
action="append",
|
||||
help="Directory to search for IDL import files",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--install-dir",
|
||||
dest="install_dir",
|
||||
required=True,
|
||||
help="Directory to search for MongoDB binaries",
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"-v", "--verbose", action="count", help="Enable verbose logging"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"api_version", metavar="API_VERSION", help="API Version to check"
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
class FakeArgs:
|
||||
@ -190,12 +235,20 @@ def main():
|
||||
loggers.configure_loggers()
|
||||
loggers.new_job_logger(sys.argv[0], 0)
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger(LOGGER_NAME).setLevel(logging.DEBUG if args.verbose else logging.INFO)
|
||||
logging.getLogger(LOGGER_NAME).setLevel(
|
||||
logging.DEBUG if args.verbose else logging.INFO
|
||||
)
|
||||
|
||||
command_sets = {}
|
||||
command_sets["mongod"] = list_commands_for_api(args.api_version, "mongod", args.install_dir)
|
||||
command_sets["mongos"] = list_commands_for_api(args.api_version, "mongos", args.install_dir)
|
||||
command_sets["idl"] = set(get_command_definitions(args.api_version, os.getcwd(), args.include))
|
||||
command_sets["mongod"] = list_commands_for_api(
|
||||
args.api_version, "mongod", args.install_dir
|
||||
)
|
||||
command_sets["mongos"] = list_commands_for_api(
|
||||
args.api_version, "mongos", args.install_dir
|
||||
)
|
||||
command_sets["idl"] = set(
|
||||
get_command_definitions(args.api_version, os.getcwd(), args.include)
|
||||
)
|
||||
remove_skipped_commands(command_sets)
|
||||
assert_command_sets_equal(args.api_version, command_sets)
|
||||
|
||||
|
||||
@ -38,7 +38,9 @@ from packaging.version import Version
|
||||
|
||||
# Get relative imports to work when the package is not installed on the PYTHONPATH.
|
||||
if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
sys.path.append(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.resmokelib.multiversionconstants import (
|
||||
@ -49,7 +51,7 @@ from buildscripts.resmokelib.multiversionconstants import (
|
||||
|
||||
# pylint: enable=wrong-import-position
|
||||
|
||||
LOGGER_NAME = 'checkout-idl'
|
||||
LOGGER_NAME = "checkout-idl"
|
||||
LOGGER = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
|
||||
@ -57,9 +59,9 @@ def get_tags() -> List[str]:
|
||||
"""Get a list of git tags that the IDL compatibility script should check against."""
|
||||
|
||||
def gen_versions_and_tags():
|
||||
for tag in check_output(['git', 'tag']).decode().split():
|
||||
for tag in check_output(["git", "tag"]).decode().split():
|
||||
# Releases are like "r5.6.7". Older ones aren't r-prefixed but we don't care about them.
|
||||
if not tag.startswith('r'):
|
||||
if not tag.startswith("r"):
|
||||
continue
|
||||
|
||||
try:
|
||||
@ -90,7 +92,9 @@ def get_tags() -> List[str]:
|
||||
for version, tag in sorted(gen_versions_and_tags(), reverse=True):
|
||||
major_minor_version = Version(f"{version.major}.{version.minor}")
|
||||
if major_minor_version in results:
|
||||
candidate_tag, candidate_is_prerelease_version = results[major_minor_version]
|
||||
candidate_tag, candidate_is_prerelease_version = results[
|
||||
major_minor_version
|
||||
]
|
||||
if candidate_tag is None:
|
||||
# This is the first tag we have seen for this version. Set our first
|
||||
# candidate tag and if this tag is a prerelease version.
|
||||
@ -115,27 +119,36 @@ def make_idl_directories(tags: List[str], destination: str) -> None:
|
||||
for tag in tags:
|
||||
LOGGER.info("Checking out IDL files in %s", tag)
|
||||
directory = os.path.join(destination, tag)
|
||||
for path in check_output(['git', 'ls-tree', '--name-only', '-r', tag]).decode().split():
|
||||
if not path.endswith('.idl'):
|
||||
for path in (
|
||||
check_output(["git", "ls-tree", "--name-only", "-r", tag]).decode().split()
|
||||
):
|
||||
if not path.endswith(".idl"):
|
||||
continue
|
||||
|
||||
contents = check_output(['git', 'show', f'{tag}:{path}']).decode()
|
||||
contents = check_output(["git", "show", f"{tag}:{path}"]).decode()
|
||||
output_path = os.path.join(directory, path)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, 'w+') as fd:
|
||||
with open(output_path, "w+") as fd:
|
||||
fd.write(contents)
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the script."""
|
||||
arg_parser = argparse.ArgumentParser(description=__doc__)
|
||||
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
|
||||
arg_parser.add_argument("destination", metavar="DESTINATION",
|
||||
help="Directory to check out past IDL file versions")
|
||||
arg_parser.add_argument(
|
||||
"-v", "--verbose", action="count", help="Enable verbose logging"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"destination",
|
||||
metavar="DESTINATION",
|
||||
help="Directory to check out past IDL file versions",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
logging.getLogger(LOGGER_NAME).setLevel(logging.DEBUG if args.verbose else logging.INFO)
|
||||
logging.getLogger(LOGGER_NAME).setLevel(
|
||||
logging.DEBUG if args.verbose else logging.INFO
|
||||
)
|
||||
|
||||
tags = get_tags()
|
||||
LOGGER.info("Fetching IDL files for past tags: %s", tags)
|
||||
|
||||
@ -37,7 +37,7 @@ from typing import List
|
||||
import yaml
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.idl import lib
|
||||
@ -59,7 +59,7 @@ def get_all_feature_flags(idl_dirs: List[str] = None):
|
||||
# Most IDL files do not contain feature flags.
|
||||
# We can discard these quickly without expensive YAML parsing.
|
||||
with open(idl_path) as idl_file:
|
||||
if 'feature_flags' not in idl_file.read():
|
||||
if "feature_flags" not in idl_file.read():
|
||||
continue
|
||||
with open(idl_path) as idl_file:
|
||||
doc = parser.parse_file(idl_file, idl_path)
|
||||
@ -81,7 +81,9 @@ def get_all_feature_flags_turned_off_by_default(idl_dirs: List[str] = None):
|
||||
all_flags = get_all_feature_flags(idl_dirs)
|
||||
all_default_false_flags = [flag for flag in all_flags if all_flags[flag] != "true"]
|
||||
|
||||
with open("buildscripts/resmokeconfig/fully_disabled_feature_flags.yml") as fully_disabled_ffs:
|
||||
with open(
|
||||
"buildscripts/resmokeconfig/fully_disabled_feature_flags.yml"
|
||||
) as fully_disabled_ffs:
|
||||
force_disabled_flags = yaml.safe_load(fully_disabled_ffs)
|
||||
|
||||
return list(set(all_default_false_flags) - set(force_disabled_flags))
|
||||
@ -99,5 +101,5 @@ def main():
|
||||
gen_all_feature_flags_file()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -35,7 +35,7 @@ import sys
|
||||
from typing import List
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.idl import lib
|
||||
@ -57,7 +57,7 @@ def gen_all_server_params(idl_dirs: List[str] = None):
|
||||
# Most IDL files do not contain server parameters.
|
||||
# We can discard these quickly without expensive YAML parsing.
|
||||
with open(idl_path) as idl_file:
|
||||
if 'server_parameters' not in idl_file.read():
|
||||
if "server_parameters" not in idl_file.read():
|
||||
continue
|
||||
with open(idl_path) as idl_file:
|
||||
doc = parser.parse_file(idl_file, idl_path)
|
||||
@ -79,5 +79,5 @@ def main():
|
||||
gen_all_server_params_file()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -33,6 +33,7 @@ Represents the derived IDL specification after type resolution in the binding pa
|
||||
This is a lossy translation from the IDL Syntax tree as the IDL AST only contains information about
|
||||
the enums and structs that need code generated for them, and just enough information to do that.
|
||||
"""
|
||||
|
||||
import enum
|
||||
|
||||
from . import common
|
||||
@ -44,8 +45,9 @@ class IDLBoundSpec(object):
|
||||
def __init__(self, spec, error_collection):
|
||||
# type: (IDLAST, errors.ParserErrorCollection) -> None
|
||||
"""Must specify either an IDL document or errors, not both."""
|
||||
assert (spec is None and error_collection is not None) or (spec is not None
|
||||
and error_collection is None)
|
||||
assert (spec is None and error_collection is not None) or (
|
||||
spec is not None and error_collection is None
|
||||
)
|
||||
self.spec = spec
|
||||
self.errors = error_collection
|
||||
|
||||
@ -290,12 +292,16 @@ class Field(common.SourceLocation):
|
||||
# type: () -> bool
|
||||
"""Returns true if the IDL compiler should add a call to serialization options for this field."""
|
||||
return self.query_shape is not None and self.query_shape in [
|
||||
QueryShapeFieldType.LITERAL, QueryShapeFieldType.ANONYMIZE
|
||||
QueryShapeFieldType.LITERAL,
|
||||
QueryShapeFieldType.ANONYMIZE,
|
||||
]
|
||||
|
||||
@property
|
||||
def should_shapify(self):
|
||||
return self.query_shape is not None and self.query_shape != QueryShapeFieldType.PARAMETER
|
||||
return (
|
||||
self.query_shape is not None
|
||||
and self.query_shape != QueryShapeFieldType.PARAMETER
|
||||
)
|
||||
|
||||
|
||||
class Privilege(common.SourceLocation):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -35,34 +35,34 @@ Utilities for validating bson types, etc.
|
||||
# scalar: True if the type is not an array or object
|
||||
# bson_type_enum: The BSONType enum value for the given type
|
||||
_BSON_TYPE_INFORMATION = {
|
||||
"double": {'scalar': True, 'bson_type_enum': 'NumberDouble'},
|
||||
"string": {'scalar': True, 'bson_type_enum': 'String'},
|
||||
"object": {'scalar': False, 'bson_type_enum': 'Object'},
|
||||
"array": {'scalar': False, 'bson_type_enum': 'Array'},
|
||||
"bindata": {'scalar': True, 'bson_type_enum': 'BinData'},
|
||||
"undefined": {'scalar': True, 'bson_type_enum': 'Undefined'},
|
||||
"objectid": {'scalar': True, 'bson_type_enum': 'jstOID'},
|
||||
"bool": {'scalar': True, 'bson_type_enum': 'Bool'},
|
||||
"date": {'scalar': True, 'bson_type_enum': 'Date'},
|
||||
"null": {'scalar': True, 'bson_type_enum': 'jstNULL'},
|
||||
"regex": {'scalar': True, 'bson_type_enum': 'RegEx'},
|
||||
"int": {'scalar': True, 'bson_type_enum': 'NumberInt'},
|
||||
"timestamp": {'scalar': True, 'bson_type_enum': 'bsonTimestamp'},
|
||||
"long": {'scalar': True, 'bson_type_enum': 'NumberLong'},
|
||||
"decimal": {'scalar': True, 'bson_type_enum': 'NumberDecimal'},
|
||||
"double": {"scalar": True, "bson_type_enum": "NumberDouble"},
|
||||
"string": {"scalar": True, "bson_type_enum": "String"},
|
||||
"object": {"scalar": False, "bson_type_enum": "Object"},
|
||||
"array": {"scalar": False, "bson_type_enum": "Array"},
|
||||
"bindata": {"scalar": True, "bson_type_enum": "BinData"},
|
||||
"undefined": {"scalar": True, "bson_type_enum": "Undefined"},
|
||||
"objectid": {"scalar": True, "bson_type_enum": "jstOID"},
|
||||
"bool": {"scalar": True, "bson_type_enum": "Bool"},
|
||||
"date": {"scalar": True, "bson_type_enum": "Date"},
|
||||
"null": {"scalar": True, "bson_type_enum": "jstNULL"},
|
||||
"regex": {"scalar": True, "bson_type_enum": "RegEx"},
|
||||
"int": {"scalar": True, "bson_type_enum": "NumberInt"},
|
||||
"timestamp": {"scalar": True, "bson_type_enum": "bsonTimestamp"},
|
||||
"long": {"scalar": True, "bson_type_enum": "NumberLong"},
|
||||
"decimal": {"scalar": True, "bson_type_enum": "NumberDecimal"},
|
||||
}
|
||||
|
||||
# Dictionary of BinData subtype type Information
|
||||
# scalar: True if the type is not an array or object
|
||||
# bindata_enum: The BinDataType enum value for the given type
|
||||
_BINDATA_SUBTYPE = {
|
||||
"generic": {'scalar': True, 'bindata_enum': 'BinDataGeneral'},
|
||||
"function": {'scalar': True, 'bindata_enum': 'Function'},
|
||||
"generic": {"scalar": True, "bindata_enum": "BinDataGeneral"},
|
||||
"function": {"scalar": True, "bindata_enum": "Function"},
|
||||
# Also simply known as type 2, deprecated, and requires special handling
|
||||
#"binary": {
|
||||
# "binary": {
|
||||
# 'scalar': False,
|
||||
# 'bindata_enum': 'ByteArrayDeprecated'
|
||||
#},
|
||||
# },
|
||||
# Deprecated
|
||||
# "uuid_old": {
|
||||
# 'scalar': False,
|
||||
@ -86,14 +86,14 @@ def is_scalar_bson_type(name):
|
||||
# type: (str) -> bool
|
||||
"""Return True if this bson type is a scalar."""
|
||||
assert is_valid_bson_type(name)
|
||||
return _BSON_TYPE_INFORMATION[name]['scalar'] # type: ignore
|
||||
return _BSON_TYPE_INFORMATION[name]["scalar"] # type: ignore
|
||||
|
||||
|
||||
def cpp_bson_type_name(name):
|
||||
# type: (str) -> str
|
||||
"""Return the C++ type name for a bson type."""
|
||||
assert is_valid_bson_type(name)
|
||||
return _BSON_TYPE_INFORMATION[name]['bson_type_enum'] # type: ignore
|
||||
return _BSON_TYPE_INFORMATION[name]["bson_type_enum"] # type: ignore
|
||||
|
||||
|
||||
def list_valid_types():
|
||||
@ -112,4 +112,4 @@ def cpp_bindata_subtype_type_name(name):
|
||||
# type: (str) -> str
|
||||
"""Return the C++ type name for a bindata subtype."""
|
||||
assert is_valid_bindata_subtype(name)
|
||||
return _BINDATA_SUBTYPE[name]['bindata_enum'] # type: ignore
|
||||
return _BINDATA_SUBTYPE[name]["bindata_enum"] # type: ignore
|
||||
|
||||
@ -47,7 +47,7 @@ def title_case(name):
|
||||
# Only capitalize the last part of a fully-qualified name
|
||||
pos = name.rfind("::")
|
||||
if pos > -1:
|
||||
return name[:pos + 2] + name[pos + 2:pos + 3].upper() + name[pos + 3:]
|
||||
return name[: pos + 2] + name[pos + 2 : pos + 3].upper() + name[pos + 3 :]
|
||||
|
||||
return name[0:1].upper() + name[1:]
|
||||
|
||||
@ -71,9 +71,9 @@ def _escape_template_string(template):
|
||||
# type: (str) -> str
|
||||
"""Escape the '$' in template strings unless followed by '{'."""
|
||||
# See https://docs.python.org/2/library/string.html#template-strings
|
||||
template = template.replace('${', '#{')
|
||||
template = template.replace('$', '$$')
|
||||
return template.replace('#{', '${')
|
||||
template = template.replace("${", "#{")
|
||||
template = template.replace("$", "$$")
|
||||
return template.replace("#{", "${")
|
||||
|
||||
|
||||
def template_format(template, template_params=None):
|
||||
@ -114,5 +114,9 @@ class SourceLocation(object):
|
||||
Example location message:
|
||||
test.idl: (17, 4)
|
||||
"""
|
||||
msg = "%s: (%d, %d)" % (os.path.basename(self.file_name), self.line, self.column)
|
||||
msg = "%s: (%d, %d)" % (
|
||||
os.path.basename(self.file_name),
|
||||
self.line,
|
||||
self.column,
|
||||
)
|
||||
return msg # type: ignore
|
||||
|
||||
@ -72,29 +72,52 @@ class CompilerImportResolver(parser.ImportResolverBase):
|
||||
# type: (str, str) -> str
|
||||
"""Return the complete path to an imported file name."""
|
||||
|
||||
logging.debug("Resolving imported file '%s' for file '%s'", imported_file_name, base_file)
|
||||
logging.debug(
|
||||
"Resolving imported file '%s' for file '%s'", imported_file_name, base_file
|
||||
)
|
||||
|
||||
# Check for fully-qualified paths
|
||||
logging.debug("Checking for imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, imported_file_name)
|
||||
logging.debug(
|
||||
"Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
imported_file_name,
|
||||
)
|
||||
if os.path.isabs(imported_file_name) and os.path.exists(imported_file_name):
|
||||
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, imported_file_name)
|
||||
logging.debug(
|
||||
"Found imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
imported_file_name,
|
||||
)
|
||||
return imported_file_name
|
||||
|
||||
for candidate_dir in self._import_directories or []:
|
||||
base_dir = os.path.abspath(candidate_dir)
|
||||
resolved_file_name = os.path.normpath(os.path.join(base_dir, imported_file_name))
|
||||
resolved_file_name = os.path.normpath(
|
||||
os.path.join(base_dir, imported_file_name)
|
||||
)
|
||||
|
||||
logging.debug("Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name, base_file, resolved_file_name)
|
||||
logging.debug(
|
||||
"Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
resolved_file_name,
|
||||
)
|
||||
|
||||
if os.path.exists(resolved_file_name):
|
||||
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, resolved_file_name)
|
||||
logging.debug(
|
||||
"Found imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
resolved_file_name,
|
||||
)
|
||||
return resolved_file_name
|
||||
|
||||
msg = ("Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file))
|
||||
msg = "Cannot find imported file '%s' for file '%s'" % (
|
||||
imported_file_name,
|
||||
base_file,
|
||||
)
|
||||
logging.error(msg)
|
||||
|
||||
raise errors.IDLError(msg)
|
||||
@ -102,7 +125,7 @@ class CompilerImportResolver(parser.ImportResolverBase):
|
||||
def open(self, resolved_file_name):
|
||||
# type: (str) -> Any
|
||||
"""Return an io.Stream for the requested file."""
|
||||
return io.open(resolved_file_name, encoding='utf-8')
|
||||
return io.open(resolved_file_name, encoding="utf-8")
|
||||
|
||||
|
||||
def _write_dependencies(spec, write_dependencies_inline):
|
||||
@ -133,7 +156,8 @@ def _update_import_includes(args, spec, header_file_name):
|
||||
|
||||
if args.output_base_dir:
|
||||
base_include_h_file_name = os.path.relpath(
|
||||
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir))
|
||||
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir)
|
||||
)
|
||||
else:
|
||||
base_include_h_file_name = os.path.abspath(header_file_name)
|
||||
|
||||
@ -144,21 +168,29 @@ def _update_import_includes(args, spec, header_file_name):
|
||||
if not spec.globals:
|
||||
spec.globals = syntax.Global(args.input_file, -1, -1)
|
||||
|
||||
first_dir = base_include_h_file_name.split('/')[0]
|
||||
first_dir = base_include_h_file_name.split("/")[0]
|
||||
|
||||
for resolved_file_name in spec.imports.resolved_imports:
|
||||
# Guess: the file naming rules are consistent across IDL invocations
|
||||
include_h_file_name = os.path.splitext(resolved_file_name)[0] + args.output_suffix + ".h"
|
||||
include_h_file_name = (
|
||||
os.path.splitext(resolved_file_name)[0] + args.output_suffix + ".h"
|
||||
)
|
||||
|
||||
if args.output_base_dir:
|
||||
base_dir = os.path.normpath(args.output_base_dir)
|
||||
include_h_file_name = os.path.relpath(os.path.normpath(include_h_file_name), base_dir)
|
||||
include_h_file_name = os.path.relpath(
|
||||
os.path.normpath(include_h_file_name), base_dir
|
||||
)
|
||||
|
||||
if os.path.isabs(base_dir):
|
||||
include_h_file_name = os.path.join(
|
||||
base_dir, include_h_file_name[include_h_file_name.rfind(first_dir):])
|
||||
base_dir,
|
||||
include_h_file_name[include_h_file_name.rfind(first_dir) :],
|
||||
)
|
||||
else:
|
||||
include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir):]
|
||||
include_h_file_name = include_h_file_name[
|
||||
include_h_file_name.find(first_dir) :
|
||||
]
|
||||
else:
|
||||
include_h_file_name = os.path.abspath(include_h_file_name)
|
||||
|
||||
@ -176,9 +208,12 @@ def compile_idl(args):
|
||||
logging.error("File '%s' not found", args.input_file)
|
||||
|
||||
if args.output_source is None:
|
||||
if '.' not in args.input_file:
|
||||
logging.error("File name '%s' must be end with a filename extension, such as '%s.idl'",
|
||||
args.input_file, args.input_file)
|
||||
if "." not in args.input_file:
|
||||
logging.error(
|
||||
"File name '%s' must be end with a filename extension, such as '%s.idl'",
|
||||
args.input_file,
|
||||
args.input_file,
|
||||
)
|
||||
return False
|
||||
|
||||
file_name_prefix = os.path.splitext(args.input_file)[0]
|
||||
@ -194,9 +229,12 @@ def compile_idl(args):
|
||||
args.target_arch = platform.machine()
|
||||
|
||||
# Compile the IDL through the 3 passes
|
||||
with io.open(args.input_file, encoding='utf-8') as file_stream:
|
||||
parsed_doc = parser.parse(file_stream, args.input_file,
|
||||
CompilerImportResolver(args.import_directories))
|
||||
with io.open(args.input_file, encoding="utf-8") as file_stream:
|
||||
parsed_doc = parser.parse(
|
||||
file_stream,
|
||||
args.input_file,
|
||||
CompilerImportResolver(args.import_directories),
|
||||
)
|
||||
|
||||
if not parsed_doc.errors:
|
||||
if args.write_dependencies or args.write_dependencies_inline:
|
||||
@ -210,8 +248,13 @@ def compile_idl(args):
|
||||
|
||||
bound_doc = binder.bind(parsed_doc.spec)
|
||||
if not bound_doc.errors:
|
||||
generator.generate_code(bound_doc.spec, args.target_arch, args.output_base_dir,
|
||||
header_file_name, source_file_name)
|
||||
generator.generate_code(
|
||||
bound_doc.spec,
|
||||
args.target_arch,
|
||||
args.output_base_dir,
|
||||
header_file_name,
|
||||
source_file_name,
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
|
||||
@ -32,7 +32,7 @@ from abc import ABCMeta, abstractmethod
|
||||
|
||||
from . import bson, common, writer
|
||||
|
||||
_STD_ARRAY_UINT8_16 = 'std::array<std::uint8_t,16>'
|
||||
_STD_ARRAY_UINT8_16 = "std::array<std::uint8_t,16>"
|
||||
|
||||
|
||||
def is_primitive_scalar_type(cpp_type):
|
||||
@ -42,26 +42,31 @@ def is_primitive_scalar_type(cpp_type):
|
||||
|
||||
Primitive scalar types need to have a default value to prevent warnings from Coverity.
|
||||
"""
|
||||
cpp_type = cpp_type.replace(' ', '')
|
||||
cpp_type = cpp_type.replace(" ", "")
|
||||
# TODO (SERVER-50101): Remove 'multiversion::FeatureCompatibilityVersion' once IDL supports
|
||||
# a commmand cpp_type of C++ enum.
|
||||
return cpp_type in [
|
||||
'bool', 'double', 'std::int32_t', 'std::uint32_t', 'std::uint64_t', 'std::int64_t',
|
||||
'multiversion::FeatureCompatibilityVersion'
|
||||
"bool",
|
||||
"double",
|
||||
"std::int32_t",
|
||||
"std::uint32_t",
|
||||
"std::uint64_t",
|
||||
"std::int64_t",
|
||||
"multiversion::FeatureCompatibilityVersion",
|
||||
]
|
||||
|
||||
|
||||
def is_primitive_type(cpp_type):
|
||||
# type: (str) -> bool
|
||||
"""Return True if a cpp_type is a primitive type and should not be returned as reference."""
|
||||
cpp_type = cpp_type.replace(' ', '')
|
||||
cpp_type = cpp_type.replace(" ", "")
|
||||
return is_primitive_scalar_type(cpp_type) or cpp_type == _STD_ARRAY_UINT8_16
|
||||
|
||||
|
||||
def _qualify_optional_type(cpp_type):
|
||||
# type: (str) -> str
|
||||
"""Qualify the type as optional."""
|
||||
return 'boost::optional<%s>' % (cpp_type)
|
||||
return "boost::optional<%s>" % (cpp_type)
|
||||
|
||||
|
||||
def _qualify_array_type(cpp_type):
|
||||
@ -74,7 +79,7 @@ def _optionally_make_call(method_name, param):
|
||||
# type: (str, str) -> str
|
||||
"""Return a call to method_name if it is not None, otherwise return an empty string."""
|
||||
if not method_name:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
return "%s(%s);" % (method_name, param)
|
||||
|
||||
@ -142,7 +147,9 @@ class CppTypeBase(metaclass=ABCMeta):
|
||||
return common.template_args(
|
||||
"${optionally_call_validator} ${member_name} = std::move(value);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "value"),
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "value"
|
||||
),
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
@ -175,7 +182,9 @@ class _CppTypeBasic(CppTypeBase):
|
||||
|
||||
def return_by_reference(self):
|
||||
# type: () -> bool
|
||||
return not is_primitive_type(self.get_type_name()) and not self._field.type.is_enum
|
||||
return (
|
||||
not is_primitive_type(self.get_type_name()) and not self._field.type.is_enum
|
||||
)
|
||||
|
||||
def is_view_type(self):
|
||||
# type: () -> bool
|
||||
@ -187,14 +196,17 @@ class _CppTypeBasic(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'${optionally_call_validator} ${member_name} = std::move(value);',
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name,
|
||||
'value'), member_name=member_name)
|
||||
"${optionally_call_validator} ${member_name} = std::move(value);",
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "value"
|
||||
),
|
||||
member_name=member_name,
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
@ -240,15 +252,18 @@ class _CppTypeView(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name,
|
||||
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
|
||||
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "_tmpValue"
|
||||
),
|
||||
value=self.get_transform_to_storage_type("value"),
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
@ -257,7 +272,7 @@ class _CppTypeView(CppTypeBase):
|
||||
def get_transform_to_storage_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args(
|
||||
'${expression}.toString()',
|
||||
"${expression}.toString()",
|
||||
expression=expression,
|
||||
)
|
||||
|
||||
@ -267,7 +282,7 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def __init__(self, field):
|
||||
# type: (ast.Field) -> None
|
||||
super(_CppTypeVector, self).__init__(field, 'std::vector<std::uint8_t>')
|
||||
super(_CppTypeVector, self).__init__(field, "std::vector<std::uint8_t>")
|
||||
|
||||
def get_type_name(self):
|
||||
# type: () -> str
|
||||
@ -279,7 +294,7 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def get_getter_setter_type(self):
|
||||
# type: () -> str
|
||||
return 'ConstDataRange'
|
||||
return "ConstDataRange"
|
||||
|
||||
def return_by_reference(self):
|
||||
# type: () -> bool
|
||||
@ -295,27 +310,34 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ConstDataRange(${member_name});',
|
||||
member_name=member_name)
|
||||
return common.template_args(
|
||||
"return ConstDataRange(${member_name});", member_name=member_name
|
||||
)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name,
|
||||
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
|
||||
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "_tmpValue"
|
||||
),
|
||||
value=self.get_transform_to_storage_type("value"),
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args('ConstDataRange(${expression});', expression=expression)
|
||||
return common.template_args(
|
||||
"ConstDataRange(${expression});", expression=expression
|
||||
)
|
||||
|
||||
def get_transform_to_storage_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args(
|
||||
'std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), ' +
|
||||
'reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())',
|
||||
expression=expression)
|
||||
"std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), "
|
||||
+ "reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())",
|
||||
expression=expression,
|
||||
)
|
||||
|
||||
|
||||
class _CppTypeDelegating(CppTypeBase):
|
||||
@ -360,7 +382,9 @@ class _CppTypeDelegating(CppTypeBase):
|
||||
|
||||
def get_storage_type_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> Optional[str]
|
||||
return self._base.get_storage_type_setter_body(member_name, validator_method_name)
|
||||
return self._base.get_storage_type_setter_body(
|
||||
member_name, validator_method_name
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
@ -392,7 +416,7 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
# type: (str) -> str
|
||||
convert = self.get_transform_to_getter_type(member_name)
|
||||
if convert:
|
||||
return common.template_args('return ${convert};', convert=convert)
|
||||
return common.template_args("return ${convert};", convert=convert)
|
||||
return self._base.get_getter_body(member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
@ -400,16 +424,20 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
convert = self.get_transform_to_storage_type("value")
|
||||
if convert:
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, '_tmpValue'), convert=convert)
|
||||
"auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "_tmpValue"
|
||||
),
|
||||
convert=convert,
|
||||
)
|
||||
return self._base.get_setter_body(member_name, validator_method_name)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
if self._base.get_storage_type() != self._base.get_getter_setter_type():
|
||||
return common.template_args(
|
||||
'transformVector(${expression})',
|
||||
"transformVector(${expression})",
|
||||
expression=expression,
|
||||
)
|
||||
return None
|
||||
@ -418,7 +446,7 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
# type: (str) -> Optional[str]
|
||||
if self._base.get_storage_type() != self._base.get_getter_setter_type():
|
||||
return common.template_args(
|
||||
'transformVector(${expression})',
|
||||
"transformVector(${expression})",
|
||||
expression=expression,
|
||||
)
|
||||
return None
|
||||
@ -443,7 +471,9 @@ class _CppTypeOptional(_CppTypeDelegating):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
base_expression = common.template_args("*${member_name}", member_name=member_name)
|
||||
base_expression = common.template_args(
|
||||
"*${member_name}", member_name=member_name
|
||||
)
|
||||
|
||||
convert = self._base.get_transform_to_getter_type(base_expression)
|
||||
if convert:
|
||||
@ -457,13 +487,18 @@ class _CppTypeOptional(_CppTypeDelegating):
|
||||
} else {
|
||||
return boost::none;
|
||||
}
|
||||
"""), member_name=member_name, convert=convert)
|
||||
"""),
|
||||
member_name=member_name,
|
||||
convert=convert,
|
||||
)
|
||||
elif self.is_view_type():
|
||||
# For optionals around view types, do an explicit construction
|
||||
return common.template_args('return ${param_type}{${member_name}};',
|
||||
param_type=self.get_getter_setter_type(),
|
||||
member_name=member_name)
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args(
|
||||
"return ${param_type}{${member_name}};",
|
||||
param_type=self.get_getter_setter_type(),
|
||||
member_name=member_name,
|
||||
)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def _get_setter_body(self, member_name, validator_method_name, convert):
|
||||
# type: (str, str, str) -> str
|
||||
@ -479,8 +514,13 @@ class _CppTypeOptional(_CppTypeDelegating):
|
||||
} else {
|
||||
${member_name} = boost::none;
|
||||
}
|
||||
"""), member_name=member_name, convert=convert,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, '_tmpValue'))
|
||||
"""),
|
||||
member_name=member_name,
|
||||
convert=convert,
|
||||
optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, "_tmpValue"
|
||||
),
|
||||
)
|
||||
return self._base.get_setter_body(member_name, validator_method_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
@ -498,9 +538,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array):
|
||||
# type: (ast.Field, str, bool) -> CppTypeBase
|
||||
"""Get the C++ Type information for the given C++ type name, e.g. std::string."""
|
||||
cpp_type_info: CppTypeBase
|
||||
if cpp_type_name == 'std::string':
|
||||
cpp_type_info = _CppTypeView(field, 'std::string', 'std::string', 'StringData')
|
||||
elif cpp_type_name == 'std::vector<std::uint8_t>':
|
||||
if cpp_type_name == "std::string":
|
||||
cpp_type_info = _CppTypeView(field, "std::string", "std::string", "StringData")
|
||||
elif cpp_type_name == "std::vector<std::uint8_t>":
|
||||
cpp_type_info = _CppTypeVector(field)
|
||||
else:
|
||||
cpp_type_info = _CppTypeBasic(field, cpp_type_name)
|
||||
@ -514,7 +554,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array):
|
||||
def get_cpp_type_without_optional(field):
|
||||
# type: (ast.Field) -> CppTypeBase
|
||||
"""Get the C++ Type information for the given field but ignore optional."""
|
||||
return get_cpp_type_from_cpp_type_name(field, field.type.cpp_type, field.type.is_array)
|
||||
return get_cpp_type_from_cpp_type_name(
|
||||
field, field.type.cpp_type, field.type.is_array
|
||||
)
|
||||
|
||||
|
||||
def get_cpp_type(field):
|
||||
@ -550,15 +592,17 @@ class BsonCppTypeBase(object, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
"""Generate code with the text writer and return an expression to serialize the type."""
|
||||
pass
|
||||
|
||||
|
||||
def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def _call_method_or_global_function(
|
||||
expression, ast_type, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (str, ast.Type, bool, bool) -> str
|
||||
"""
|
||||
Given a fully-qualified method name, call it correctly.
|
||||
@ -568,25 +612,27 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
enum deserializers/serializers which are not methods.
|
||||
"""
|
||||
method_name = ast_type.serializer
|
||||
serialization_context = 'getSerializationContext()' if ast_type.deserialize_with_tenant else ''
|
||||
shape_options = ''
|
||||
serialization_context = (
|
||||
"getSerializationContext()" if ast_type.deserialize_with_tenant else ""
|
||||
)
|
||||
shape_options = ""
|
||||
if should_shapify:
|
||||
shape_options = 'options'
|
||||
shape_options = "options"
|
||||
|
||||
short_method_name = writer.get_method_name(method_name)
|
||||
if writer.is_function(method_name):
|
||||
if ast_type.deserialize_with_tenant:
|
||||
if is_catalog_ctxt:
|
||||
# serializeForCatalog doesn't need a serializationContext
|
||||
serialization_context = ''
|
||||
serialization_context = ""
|
||||
method_name = method_name.replace("serialize", "serializeForCatalog")
|
||||
else:
|
||||
serialization_context = ', ' + serialization_context
|
||||
serialization_context = ", " + serialization_context
|
||||
if should_shapify:
|
||||
shape_options = ', ' + shape_options
|
||||
shape_options = ", " + shape_options
|
||||
|
||||
return common.template_args(
|
||||
'${method_name}(${expression}${shape_options}${serialization_context})',
|
||||
"${method_name}(${expression}${shape_options}${serialization_context})",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
shape_options=shape_options,
|
||||
@ -594,7 +640,7 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
)
|
||||
|
||||
return common.template_args(
|
||||
'${expression}.${method_name}(${shape_options}${serialization_context})',
|
||||
"${expression}.${method_name}(${shape_options}${serialization_context})",
|
||||
expression=expression,
|
||||
method_name=short_method_name,
|
||||
shape_options=shape_options,
|
||||
@ -612,19 +658,23 @@ class _CommonBsonCppTypeBase(BsonCppTypeBase):
|
||||
|
||||
def gen_deserializer_expression(self, indented_writer, object_instance):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
return common.template_args('${object_instance}.${method_name}()',
|
||||
object_instance=object_instance,
|
||||
method_name=self._deserialize_method_name)
|
||||
return common.template_args(
|
||||
"${object_instance}.${method_name}()",
|
||||
object_instance=object_instance,
|
||||
method_name=self._deserialize_method_name,
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
return _call_method_or_global_function(expression, self._ast_type, should_shapify,
|
||||
is_catalog_ctxt)
|
||||
return _call_method_or_global_function(
|
||||
expression, self._ast_type, should_shapify, is_catalog_ctxt
|
||||
)
|
||||
|
||||
|
||||
class _ObjectBsonCppTypeBase(BsonCppTypeBase):
|
||||
@ -635,34 +685,43 @@ class _ObjectBsonCppTypeBase(BsonCppTypeBase):
|
||||
if self._ast_type.deserializer:
|
||||
# Call a method like: Class::method(const BSONObj& value)
|
||||
indented_writer.write_line(
|
||||
common.template_args('const BSONObj localObject = ${object_instance}.Obj();',
|
||||
object_instance=object_instance))
|
||||
common.template_args(
|
||||
"const BSONObj localObject = ${object_instance}.Obj();",
|
||||
object_instance=object_instance,
|
||||
)
|
||||
)
|
||||
return "localObject"
|
||||
|
||||
# Just pass the BSONObj through without trying to parse it.
|
||||
return common.template_args('${object_instance}.Obj()', object_instance=object_instance)
|
||||
return common.template_args(
|
||||
"${object_instance}.Obj()", object_instance=object_instance
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
function_arguments = []
|
||||
# SerializationContext is tied to tenant deserialization
|
||||
if self._ast_type.deserialize_with_tenant:
|
||||
function_arguments.append('getSerializationContext()')
|
||||
function_arguments.append("getSerializationContext()")
|
||||
# Provide options if custom shapification required.
|
||||
if should_shapify:
|
||||
function_arguments.append('options')
|
||||
function_arguments.append("options")
|
||||
|
||||
indented_writer.write_line(
|
||||
common.template_args(
|
||||
'const BSONObj localObject = ${expression}.${method_name}(${function_arguments});',
|
||||
expression=expression, method_name=method_name,
|
||||
function_arguments=', '.join(function_arguments)))
|
||||
"const BSONObj localObject = ${expression}.${method_name}(${function_arguments});",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
function_arguments=", ".join(function_arguments),
|
||||
)
|
||||
)
|
||||
return "localObject"
|
||||
|
||||
|
||||
@ -673,25 +732,34 @@ class _ArrayBsonCppTypeBase(BsonCppTypeBase):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
if self._ast_type.deserializer:
|
||||
indented_writer.write_line(
|
||||
common.template_args('BSONArray localArray(${object_instance}.Obj());',
|
||||
object_instance=object_instance))
|
||||
common.template_args(
|
||||
"BSONArray localArray(${object_instance}.Obj());",
|
||||
object_instance=object_instance,
|
||||
)
|
||||
)
|
||||
return "localArray"
|
||||
|
||||
# Just pass the BSONObj through without trying to parse it.
|
||||
return common.template_args('BSONArray(${object_instance}.Obj())',
|
||||
object_instance=object_instance)
|
||||
return common.template_args(
|
||||
"BSONArray(${object_instance}.Obj())", object_instance=object_instance
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
indented_writer.write_line(
|
||||
common.template_args('BSONArray localArray(${expression}.${method_name}());',
|
||||
expression=expression, method_name=method_name))
|
||||
common.template_args(
|
||||
"BSONArray localArray(${expression}.${method_name}());",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
)
|
||||
)
|
||||
return "localArray"
|
||||
|
||||
|
||||
@ -700,32 +768,45 @@ class _BinDataBsonCppTypeBase(BsonCppTypeBase):
|
||||
|
||||
def gen_deserializer_expression(self, indented_writer, object_instance):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
if self._ast_type.bindata_subtype == 'uuid':
|
||||
return common.template_args('uassertStatusOK(UUID::parse(${object_instance}))',
|
||||
object_instance=object_instance)
|
||||
return common.template_args('${object_instance}._binDataVector()',
|
||||
object_instance=object_instance)
|
||||
if self._ast_type.bindata_subtype == "uuid":
|
||||
return common.template_args(
|
||||
"uassertStatusOK(UUID::parse(${object_instance}))",
|
||||
object_instance=object_instance,
|
||||
)
|
||||
return common.template_args(
|
||||
"${object_instance}._binDataVector()", object_instance=object_instance
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return True
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
if self._ast_type.serializer:
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
indented_writer.write_line(
|
||||
common.template_args('ConstDataRange tempCDR = ${expression}.${method_name}();',
|
||||
expression=expression, method_name=method_name))
|
||||
common.template_args(
|
||||
"ConstDataRange tempCDR = ${expression}.${method_name}();",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
indented_writer.write_line(
|
||||
common.template_args('ConstDataRange tempCDR(${expression});',
|
||||
expression=expression))
|
||||
common.template_args(
|
||||
"ConstDataRange tempCDR(${expression});", expression=expression
|
||||
)
|
||||
)
|
||||
|
||||
return common.template_args(
|
||||
'BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})',
|
||||
bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype))
|
||||
"BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})",
|
||||
bindata_subtype=bson.cpp_bindata_subtype_type_name(
|
||||
self._ast_type.bindata_subtype
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# For some types, we want to support custom serialization but defer most of the serialization to
|
||||
@ -739,19 +820,19 @@ def get_bson_cpp_type(ast_type):
|
||||
if len(ast_type.bson_serialization_type) > 1:
|
||||
return None
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'string':
|
||||
if ast_type.bson_serialization_type[0] == "string":
|
||||
return _CommonBsonCppTypeBase(ast_type, "valueStringData")
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'object':
|
||||
if ast_type.bson_serialization_type[0] == "object":
|
||||
return _ObjectBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'array':
|
||||
if ast_type.bson_serialization_type[0] == "array":
|
||||
return _ArrayBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'bindata':
|
||||
if ast_type.bson_serialization_type[0] == "bindata":
|
||||
return _BinDataBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'int':
|
||||
if ast_type.bson_serialization_type[0] == "int":
|
||||
return _CommonBsonCppTypeBase(ast_type, "_numberInt")
|
||||
|
||||
# Unsupported type
|
||||
|
||||
@ -52,7 +52,9 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
def get_qualified_cpp_type_name(self):
|
||||
# type: () -> str
|
||||
"""Get the fully qualified C++ type name for an enum."""
|
||||
return common.qualify_cpp_name(self._enum.cpp_namespace, self.get_cpp_type_name())
|
||||
return common.qualify_cpp_name(
|
||||
self._enum.cpp_namespace, self.get_cpp_type_name()
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_cpp_type_name(self):
|
||||
@ -69,32 +71,37 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
def _get_enum_deserializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of deserializer function without prefix."""
|
||||
return common.template_args("${enum_name}_parse",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_parse", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_enum_deserializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of deserializer function with non-method prefix."""
|
||||
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
|
||||
self._get_enum_deserializer_name())
|
||||
return "::" + common.qualify_cpp_name(
|
||||
self._enum.cpp_namespace, self._get_enum_deserializer_name()
|
||||
)
|
||||
|
||||
def _get_enum_serializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of serializer function without prefix."""
|
||||
return common.template_args("${enum_name}_serializer",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_serializer", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_enum_serializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of serializer function with non-method prefix."""
|
||||
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
|
||||
self._get_enum_serializer_name())
|
||||
return "::" + common.qualify_cpp_name(
|
||||
self._enum.cpp_namespace, self._get_enum_serializer_name()
|
||||
)
|
||||
|
||||
def _get_enum_extra_data_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of the get_extra_data function without prefix."""
|
||||
return common.template_args("${enum_name}_get_extra_data",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_get_extra_data", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_cpp_value_assignment(self, enum_value):
|
||||
@ -137,9 +144,11 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
if len(self._get_populated_extra_values()) == 0:
|
||||
return None
|
||||
|
||||
return common.template_args("BSONObj ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_extra_data_name())
|
||||
return common.template_args(
|
||||
"BSONObj ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_extra_data_name(),
|
||||
)
|
||||
|
||||
def gen_extra_data_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
@ -152,43 +161,56 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
|
||||
# Generate an anonymous namespace full of BSON constants.
|
||||
#
|
||||
with writer.NamespaceScopeBlock(indented_writer, ['']):
|
||||
with writer.NamespaceScopeBlock(indented_writer, [""]):
|
||||
for enum_value in extra_values:
|
||||
indented_writer.write_line(
|
||||
common.template_args('// %s' % json.dumps(enum_value.extra_data)))
|
||||
common.template_args("// %s" % json.dumps(enum_value.extra_data))
|
||||
)
|
||||
|
||||
bson_value = ''.join(
|
||||
[('\\x%02x' % (b)) for b in bson.BSON.encode(enum_value.extra_data)])
|
||||
bson_value = "".join(
|
||||
[("\\x%02x" % (b)) for b in bson.BSON.encode(enum_value.extra_data)]
|
||||
)
|
||||
indented_writer.write_line(
|
||||
common.template_args(
|
||||
'const BSONObj ${const_name}("${bson_value}");',
|
||||
const_name=_get_constant_enum_extra_data_name(self._enum, enum_value),
|
||||
bson_value=bson_value))
|
||||
const_name=_get_constant_enum_extra_data_name(
|
||||
self._enum, enum_value
|
||||
),
|
||||
bson_value=bson_value,
|
||||
)
|
||||
)
|
||||
|
||||
indented_writer.write_empty_line()
|
||||
|
||||
# Generate implementation of get_extra_data function.
|
||||
#
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_extra_data_declaration(),
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_extra_data_declaration(),
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"):
|
||||
with writer.IndentedScopedBlock(indented_writer, "switch (value) {", "}"):
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, "switch (value) {", "}"
|
||||
):
|
||||
for enum_value in extra_values:
|
||||
indented_writer.write_template(
|
||||
'case ${enum_name}::%s: return %s;' %
|
||||
(enum_value.name,
|
||||
_get_constant_enum_extra_data_name(self._enum, enum_value)))
|
||||
"case ${enum_name}::%s: return %s;"
|
||||
% (
|
||||
enum_value.name,
|
||||
_get_constant_enum_extra_data_name(
|
||||
self._enum, enum_value
|
||||
),
|
||||
)
|
||||
)
|
||||
if len(extra_values) != len(self._enum.values):
|
||||
# One or more enums does not have associated extra data.
|
||||
indented_writer.write_line('default: return BSONObj();')
|
||||
indented_writer.write_line("default: return BSONObj();")
|
||||
|
||||
if len(extra_values) == len(self._enum.values):
|
||||
# All enum cases handled, the compiler should know this.
|
||||
indented_writer.write_line('MONGO_UNREACHABLE;')
|
||||
indented_writer.write_line("MONGO_UNREACHABLE;")
|
||||
|
||||
|
||||
class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
@ -210,17 +232,21 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
# type: () -> str
|
||||
return common.template_args(
|
||||
"${enum_name} ${function_name}(const IDLParserContext& ctxt, std::int32_t value)",
|
||||
enum_name=self.get_cpp_type_name(), function_name=self._get_enum_deserializer_name())
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_deserializer_name(),
|
||||
)
|
||||
|
||||
def gen_deserializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
enum_values = sorted(cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value))
|
||||
enum_values = sorted(
|
||||
cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value)
|
||||
)
|
||||
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_deserializer_declaration(),
|
||||
'min_value': enum_values[0].name,
|
||||
'max_value': enum_values[-1].name,
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_deserializer_declaration(),
|
||||
"min_value": enum_values[0].name,
|
||||
"max_value": enum_values[-1].name,
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
@ -235,33 +261,41 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
} else {
|
||||
return static_cast<${enum_name}>(value);
|
||||
}
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
def get_serializer_declaration(self):
|
||||
# type: () -> str
|
||||
"""Get the serializer function declaration minus trailing semicolon."""
|
||||
return common.template_args("std::int32_t ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_serializer_name())
|
||||
return common.template_args(
|
||||
"std::int32_t ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_serializer_name(),
|
||||
)
|
||||
|
||||
def gen_serializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
"""Generate the serializer function definition."""
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_serializer_declaration(),
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_serializer_declaration(),
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"):
|
||||
indented_writer.write_template('return static_cast<std::int32_t>(value);')
|
||||
indented_writer.write_template(
|
||||
"return static_cast<std::int32_t>(value);"
|
||||
)
|
||||
|
||||
|
||||
def _get_constant_enum_extra_data_name(idl_enum, enum_value):
|
||||
# type: (Union[syntax.Enum,ast.Enum], Union[syntax.EnumValue,ast.EnumValue]) -> str
|
||||
"""Return the C++ name for a string constant of enum extra data value."""
|
||||
return common.template_args('k${enum_name}_${name}_extra_data',
|
||||
enum_name=common.title_case(idl_enum.name), name=enum_value.name)
|
||||
return common.template_args(
|
||||
"k${enum_name}_${name}_extra_data",
|
||||
enum_name=common.title_case(idl_enum.name),
|
||||
name=enum_value.name,
|
||||
)
|
||||
|
||||
|
||||
class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
@ -269,8 +303,9 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
|
||||
def get_cpp_type_name(self):
|
||||
# type: () -> str
|
||||
return common.template_args("${enum_name}Enum",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}Enum", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_bson_types(self):
|
||||
# type: () -> List[str]
|
||||
@ -278,65 +313,76 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
|
||||
def get_cpp_value_assignment(self, enum_value):
|
||||
# type: (ast.EnumValue) -> str
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def get_deserializer_declaration(self):
|
||||
# type: () -> str
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_deserializer_name()
|
||||
return f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)'
|
||||
return f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)"
|
||||
|
||||
def gen_deserializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_deserializer_name()
|
||||
with writer.NamespaceScopeBlock(indented_writer, ['']):
|
||||
with writer.IndentedScopedBlock(indented_writer,
|
||||
f'constexpr std::array {cpp_type}_values{{', '};'):
|
||||
with writer.NamespaceScopeBlock(indented_writer, [""]):
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"constexpr std::array {cpp_type}_values{{", "};"
|
||||
):
|
||||
for e in self._enum.values:
|
||||
indented_writer.write_line(f'{cpp_type}::{e.name},')
|
||||
with writer.IndentedScopedBlock(indented_writer,
|
||||
f'constexpr std::array {cpp_type}_names{{', '};'):
|
||||
indented_writer.write_line(f"{cpp_type}::{e.name},")
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"constexpr std::array {cpp_type}_names{{", "};"
|
||||
):
|
||||
for e in self._enum.values:
|
||||
indented_writer.write_line(f'"{e.value}"_sd,')
|
||||
indented_writer.write_empty_line()
|
||||
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer,
|
||||
f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{",
|
||||
"}",
|
||||
):
|
||||
indented_writer.write_line(
|
||||
f"static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};"
|
||||
)
|
||||
indented_writer.write_line(
|
||||
f"auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};"
|
||||
)
|
||||
writer.gen_string_table_find_function_block(
|
||||
indented_writer,
|
||||
f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{', '}'):
|
||||
indented_writer.write_line(
|
||||
f'static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};')
|
||||
indented_writer.write_line(
|
||||
f'auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};')
|
||||
writer.gen_string_table_find_function_block(indented_writer, 'value', 'onMatch({})',
|
||||
'onFail()',
|
||||
[e.value for e in self._enum.values])
|
||||
"value",
|
||||
"onMatch({})",
|
||||
"onFail()",
|
||||
[e.value for e in self._enum.values],
|
||||
)
|
||||
|
||||
def get_serializer_declaration(self):
|
||||
# type: () -> str
|
||||
"""Get the serializer function declaration minus trailing semicolon."""
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_serializer_name()
|
||||
return f'StringData {func}({cpp_type} value)'
|
||||
return f"StringData {func}({cpp_type} value)"
|
||||
|
||||
def gen_serializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
"""Generate the serializer function definition."""
|
||||
func = self._get_enum_serializer_name()
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
with writer.IndentedScopedBlock(indented_writer, f'StringData {func}({cpp_type} value) {{',
|
||||
'}'):
|
||||
indented_writer.write_line('auto idx = static_cast<size_t>(value);')
|
||||
indented_writer.write_line(f'invariant(idx < {cpp_type}_names.size());')
|
||||
indented_writer.write_line(f'return {cpp_type}_names[idx];')
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"StringData {func}({cpp_type} value) {{", "}"
|
||||
):
|
||||
indented_writer.write_line("auto idx = static_cast<size_t>(value);")
|
||||
indented_writer.write_line(f"invariant(idx < {cpp_type}_names.size());")
|
||||
indented_writer.write_line(f"return {cpp_type}_names[idx];")
|
||||
|
||||
|
||||
def get_type_info(idl_enum):
|
||||
# type: (Union[syntax.Enum,ast.Enum]) -> Optional[EnumTypeInfoBase]
|
||||
"""Get the type information for a given enumeration type, return None if not supported."""
|
||||
if idl_enum.type == 'int':
|
||||
if idl_enum.type == "int":
|
||||
return _EnumTypeInt(idl_enum)
|
||||
elif idl_enum.type == 'string':
|
||||
elif idl_enum.type == "string":
|
||||
return _EnumTypeString(idl_enum)
|
||||
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user