diff --git a/.gitattributes b/.gitattributes index 97539eea5a6..c82856f037a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -6,6 +6,7 @@ external rules-lint-ignored=true **/*.tpl.h rules-lint-ignored=true **/*.tpl.cpp rules-lint-ignored=true +rpm/*.spec rules-lint-ignored=true src/mongo/bson/util/bson_column_compressed_data.inl rules-lint-ignored=true src/mongo/bson/util/simple8b.inl rules-lint-ignored=true *.idl linguist-language=yaml diff --git a/bazel/bazelisk.py b/bazel/bazelisk.py index 985004179df..ae52b951087 100755 --- a/bazel/bazelisk.py +++ b/bazel/bazelisk.py @@ -198,7 +198,9 @@ def get_releases_json(bazelisk_directory): pass with open(releases, "wb") as f: - body = read_remote_text_file("https://api.github.com/repos/bazelbuild/bazel/releases") + body = read_remote_text_file( + "https://api.github.com/repos/bazelbuild/bazel/releases" + ) f.write(body.encode("utf-8")) return json.loads(body) @@ -245,7 +247,9 @@ def get_operating_system(): if operating_system not in ("linux", "darwin", "windows"): raise Exception( 'Unsupported operating system "{}". ' - "Bazel currently only supports Linux, macOS and Windows.".format(operating_system) + "Bazel currently only supports Linux, macOS and Windows.".format( + operating_system + ) ) return operating_system @@ -262,7 +266,10 @@ def determine_bazel_filename(version): if machine not in supported_machines: raise Exception( 'Unsupported machine architecture "{}". Bazel {} only supports {} on {}.'.format( - machine, version, ", ".join(supported_machines), operating_system.capitalize() + machine, + version, + ", ".join(supported_machines), + operating_system.capitalize(), ) ) @@ -270,7 +277,9 @@ def determine_bazel_filename(version): bazel_flavor = "bazel" if get_env_or_config("BAZELISK_NOJDK", "0") != "0": bazel_flavor = "bazel_nojdk" - return "{}-{}-{}-{}{}".format(bazel_flavor, version, operating_system, machine, filename_suffix) + return "{}-{}-{}-{}{}".format( + bazel_flavor, version, operating_system, machine, filename_suffix + ) def get_supported_machine_archs(version, operating_system): @@ -498,7 +507,9 @@ def execute_bazel(bazel_path, argv): cmd = make_bazel_cmd(bazel_path, argv) # We cannot use close_fds on Windows, so disable it there. - p = subprocess.Popen([cmd["exec"]] + cmd["args"], close_fds=os.name != "nt", env=cmd["env"]) + p = subprocess.Popen( + [cmd["exec"]] + cmd["args"], close_fds=os.name != "nt", env=cmd["env"] + ) while True: try: return p.wait() diff --git a/bazel/config/generate_config_header.py b/bazel/config/generate_config_header.py index 569b69d59cc..de6b9f15020 100644 --- a/bazel/config/generate_config_header.py +++ b/bazel/config/generate_config_header.py @@ -13,7 +13,9 @@ import textwrap from typing import Dict -def write_config_header(input_path: str, output_path: str, definitions: Dict[str, str]) -> None: +def write_config_header( + input_path: str, output_path: str, definitions: Dict[str, str] +) -> None: with open(input_path) as in_file: content = in_file.read() @@ -27,18 +29,28 @@ def write_config_header(input_path: str, output_path: str, definitions: Dict[str if __name__ == "__main__": parser = argparse.ArgumentParser(description="Generate a config header file") - parser.add_argument("--compiler-path", help="Path to the compiler executable", required=True) - parser.add_argument("--compiler-args", help="Extra compiler arguments", required=True) + parser.add_argument( + "--compiler-path", help="Path to the compiler executable", required=True + ) + parser.add_argument( + "--compiler-args", help="Extra compiler arguments", required=True + ) parser.add_argument("--env-vars", help="Extra environment variables", required=True) parser.add_argument( "--output-path", help="Path to the output config header file", required=True ) parser.add_argument( - "--template-path", help="Path to the config header's template file", required=True + "--template-path", + help="Path to the config header's template file", + required=True, ) parser.add_argument("--extra-definitions", help="Extra header definitions") - parser.add_argument("--check-path", help="Path to the suppored configure checks", required=True) - parser.add_argument("--log-path", help="Path to the suppored configure checks", required=True) + parser.add_argument( + "--check-path", help="Path to the suppored configure checks", required=True + ) + parser.add_argument( + "--log-path", help="Path to the suppored configure checks", required=True + ) parser.add_argument("--additional-input", help="extra files", action="append") args = parser.parse_args() @@ -72,7 +84,11 @@ if __name__ == "__main__": %s """ - % (generate_config_module, called_args, list(generate_config_header_args.keys())) + % ( + generate_config_module, + called_args, + list(generate_config_header_args.keys()), + ) ) ) diff --git a/bazel/coverity/generate_coverity_targets.py b/bazel/coverity/generate_coverity_targets.py index 219308e5931..7c824b220fa 100644 --- a/bazel/coverity/generate_coverity_targets.py +++ b/bazel/coverity/generate_coverity_targets.py @@ -17,11 +17,15 @@ bazel_cache = os.path.expanduser(args.bazel_cache) # the cc_library and cc_binaries in our build. There is not a good way from # within the build to get all those targets, so we will generate the list via query # https://sig-product-docs.synopsys.com/bundle/coverity-docs/page/coverity-analysis/topics/building_with_bazel.html#build_with_bazel -cmd = [ +cmd = ( + [ bazel_executable, bazel_cache, "aquery", - ] + bazel_cmd_args + [args.bazel_query] + ] + + bazel_cmd_args + + [args.bazel_query] +) print(f"Running command: {cmd}") proc = subprocess.run( cmd, @@ -33,9 +37,7 @@ proc = subprocess.run( print(proc.stderr) targets = set() -with open('coverity_targets.list', 'w') as f: +with open("coverity_targets.list", "w") as f: for line in proc.stdout.splitlines(): if line.startswith(" Target: "): f.write(line.split()[-1] + "\n") - - diff --git a/bazel/format/BUILD.bazel b/bazel/format/BUILD.bazel index 2c3e2de4f80..e02d5e654b2 100644 --- a/bazel/format/BUILD.bazel +++ b/bazel/format/BUILD.bazel @@ -39,6 +39,7 @@ format_multirun( graphql = "//:prettier", html = "//:prettier", markdown = "//:prettier", + python = "@aspect_rules_lint//format:ruff", sql = "//:prettier", starlark = "@buildifier_prebuilt//:buildifier", visibility = ["//visibility:public"], diff --git a/bazel/format/rules_lint_format_wrapper.py b/bazel/format/rules_lint_format_wrapper.py index 0a5d6e4e084..5b11a4f117f 100644 --- a/bazel/format/rules_lint_format_wrapper.py +++ b/bazel/format/rules_lint_format_wrapper.py @@ -31,7 +31,9 @@ def _git_diff(args: list) -> str: return result.stdout.strip() + os.linesep -def _get_files_changed_since_fork_point(origin_branch: str = "origin/master") -> List[str]: +def _get_files_changed_since_fork_point( + origin_branch: str = "origin/master", +) -> List[str]: """Query git to get a list of files in the repo from a diff.""" # There are 3 diffs we run: # 1. List of commits between origin/master and HEAD of current branch @@ -68,7 +70,9 @@ def run_rules_lint( print("Running rules_lint formatter") if files_to_format != "all": command += files_to_format - repo_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + repo_path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) subprocess.run(command, check=True, env=os.environ, cwd=repo_path) except subprocess.CalledProcessError: return False @@ -83,7 +87,9 @@ def run_shellscripts_linters(shellscripts_linters: pathlib.Path, check: bool) -> command.append("fix") else: print("Running shellscripts linter") - repo_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + repo_path = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + ) subprocess.run(command, check=True, env=os.environ, cwd=repo_path) except subprocess.CalledProcessError: return False @@ -145,7 +151,9 @@ def main() -> int: # If we are running in bazel, default the directory to the workspace default_dir = os.environ.get("BUILD_WORKSPACE_DIRECTORY") if not default_dir: - print("This script must be run though bazel. Please run 'bazel run //:format' instead") + print( + "This script must be run though bazel. Please run 'bazel run //:format' instead" + ) print("*** IF BAZEL IS NOT INSTALLED, RUN THE FOLLOWING: ***\n") print("python buildscripts/install_bazel.py") return 1 @@ -154,7 +162,9 @@ def main() -> int: prog="Format", description="This script formats code in mongodb" ) - parser.add_argument("--check", help="Run in check mode", default=False, action="store_true") + parser.add_argument( + "--check", help="Run in check mode", default=False, action="store_true" + ) parser.add_argument( "--prettier", help="Set the path to prettier", required=True, type=pathlib.Path ) @@ -200,7 +210,9 @@ def main() -> int: print( f"The number of commits between current branch and origin branch ({args.origin_branch}) is too large: {distance} commits" ) - print("WARNING!!! Defaulting to formatting all files, this may take a while.") + print( + "WARNING!!! Defaulting to formatting all files, this may take a while." + ) print( "Please update your local branch with the latest changes from origin, or use `bazel run format -- --origin-branch other_branch` to select a different origin branch" ) @@ -217,12 +229,17 @@ def main() -> int: validate_bazel_groups(generate_report=True, fix=not args.check) if files_to_format != "all": - files_to_format = [str(file) for file in files_to_format if os.path.isfile(file)] + files_to_format = [ + str(file) for file in files_to_format if os.path.isfile(file) + ] return ( 0 if run_rules_lint( - args.rules_lint_format, args.rules_lint_format_check, args.check, files_to_format + args.rules_lint_format, + args.rules_lint_format_check, + args.check, + files_to_format, ) and run_shellscripts_linters(shellscripts_linters_path, args.check) and run_prettier(prettier_path, args.check, files_to_format) diff --git a/bazel/install_rules/install_rules.py b/bazel/install_rules/install_rules.py index 63e1446a5b3..7d36d6d9d28 100644 --- a/bazel/install_rules/install_rules.py +++ b/bazel/install_rules/install_rules.py @@ -7,7 +7,9 @@ parser = argparse.ArgumentParser() parser.add_argument("--depfile", action="append") parser.add_argument("--install-dir") -parser.add_argument("--install-mode", choices=["copy", "symlink", "hardlink"], default="hardlink") +parser.add_argument( + "--install-mode", choices=["copy", "symlink", "hardlink"], default="hardlink" +) args = parser.parse_args() if os.path.exists(args.install_dir): @@ -67,12 +69,15 @@ def install(src, install_type): if os.path.isdir(src): for root, _, files in os.walk(src): for name in files: - dest_dir = os.path.dirname(os.path.join(root, name)).replace( - src, dst - ) + dest_dir = os.path.dirname( + os.path.join(root, name) + ).replace(src, dst) if not os.path.exists(dest_dir): os.makedirs(dest_dir) - os.link(os.path.join(root, name), os.path.join(dest_dir, name)) + os.link( + os.path.join(root, name), + os.path.join(dest_dir, name), + ) else: try: os.link(src, dst) diff --git a/bazel/merge_tidy_configs.py b/bazel/merge_tidy_configs.py index 92458554b2c..73e07881947 100644 --- a/bazel/merge_tidy_configs.py +++ b/bazel/merge_tidy_configs.py @@ -71,7 +71,9 @@ def merge_check_options_into_config( override = check_options_list_to_map(incoming_config.get("CheckOptions")) if override: base.update(override) # later wins - target_config["CheckOptions"] = [{"key": k, "value": v} for k, v in sorted(base.items())] + target_config["CheckOptions"] = [ + {"key": k, "value": v} for k, v in sorted(base.items()) + ] def deep_merge_dicts(base: Any, override: Any) -> Any: @@ -176,7 +178,9 @@ def main() -> None: # then generic merge: merged_config = deep_merge_dicts(merged_config, incoming_config) - merged_config["Checks"] = ",".join(split_checks_to_list(merged_config.get("Checks"))) + merged_config["Checks"] = ",".join( + split_checks_to_list(merged_config.get("Checks")) + ) output_path = pathlib.Path(args.out) output_path.parent.mkdir(parents=True, exist_ok=True) with open(output_path, "w", encoding="utf-8") as f: @@ -184,4 +188,4 @@ def main() -> None: if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/bazel/platforms/remote_execution_containers_generator.py b/bazel/platforms/remote_execution_containers_generator.py index 31d563f2d96..703776c727a 100755 --- a/bazel/platforms/remote_execution_containers_generator.py +++ b/bazel/platforms/remote_execution_containers_generator.py @@ -10,13 +10,19 @@ from datetime import datetime def log_subprocess_run(*args, **kwargs): arg_list_or_string = kwargs["args"] if "args" in kwargs else args[0] - print(" ".join(arg_list_or_string) if type(arg_list_or_string) == list else arg_list_or_string) + print( + " ".join(arg_list_or_string) + if type(arg_list_or_string) == list + else arg_list_or_string + ) return subprocess.run(*args, **kwargs) def main(): parser = argparse.ArgumentParser() - parser.add_argument("--distro", type=str, help="Restrict to only update a single distro.") + parser.add_argument( + "--distro", type=str, help="Restrict to only update a single distro." + ) parser.add_argument( "--skip-cleanup", action="store_true", @@ -42,14 +48,18 @@ Your docker images, volumes and containers will be purged if you continue. Enter code = compile(f.read(), container_file_path, "exec") exec(code, {}, remote_execution_containers) - for distro, re_container in remote_execution_containers["REMOTE_EXECUTION_CONTAINERS"].items(): + for distro, re_container in remote_execution_containers[ + "REMOTE_EXECUTION_CONTAINERS" + ].items(): if args.distro is not None: if distro != args.distro: continue if not args.skip_cleanup: # Clean host system between container builds to avoid running into disk space issues. - print("Cleaning host system's docker images, containers, volumes, and networks...") + print( + "Cleaning host system's docker images, containers, volumes, and networks..." + ) for command in [ "docker stop $(docker ps -a -q)", # Stop all running containers "docker rm $(docker ps -a -q)", # Remove all containers @@ -65,7 +75,9 @@ Your docker images, volumes and containers will be purged if you continue. Enter print(f"Using dockerfile: {dockerfile}") print(f"Using tag: {tag}\n") - log_subprocess_run(["docker", "buildx", "create", "--use", "default"], check=True) + log_subprocess_run( + ["docker", "buildx", "create", "--use", "default"], check=True + ) log_subprocess_run( [ "docker", diff --git a/bazel/resmoke/resmoke_config_generator.py b/bazel/resmoke/resmoke_config_generator.py index 30a7788f06f..a76fdb5ccc3 100644 --- a/bazel/resmoke/resmoke_config_generator.py +++ b/bazel/resmoke/resmoke_config_generator.py @@ -18,7 +18,12 @@ def main( with open(base_config, "rt") as fh: base_config_content = yaml.safe_load(fh) if "selector" in base_config_content: - for x in ["roots", "exclude_files", "exclude_with_any_tags", "include_with_any_tags"]: + for x in [ + "roots", + "exclude_files", + "exclude_with_any_tags", + "include_with_any_tags", + ]: base_config_content["selector"].pop(x, None) content = base_config_content diff --git a/bazel/resmoke/resmoke_shim.py b/bazel/resmoke/resmoke_shim.py index 7adeb80fd6b..ad11d0873ec 100644 --- a/bazel/resmoke/resmoke_shim.py +++ b/bazel/resmoke/resmoke_shim.py @@ -62,9 +62,13 @@ if __name__ == "__main__": if os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR"): undeclared_output_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") - resmoke_args.append(f"--dbpathPrefix={os.path.join(undeclared_output_dir,'data')}") + resmoke_args.append( + f"--dbpathPrefix={os.path.join(undeclared_output_dir,'data')}" + ) resmoke_args.append(f"--taskWorkDir={undeclared_output_dir}") - resmoke_args.append(f"--reportFile={os.path.join(undeclared_output_dir,'report.json')}") + resmoke_args.append( + f"--reportFile={os.path.join(undeclared_output_dir,'report.json')}" + ) if os.environ.get("TEST_SRCDIR"): test_srcdir = os.environ.get("TEST_SRCDIR") diff --git a/bazel/toolchains/cc/mongo_linux/mongo_toolchain_version_generator.py b/bazel/toolchains/cc/mongo_linux/mongo_toolchain_version_generator.py index a671a910448..f10fa31ed5e 100644 --- a/bazel/toolchains/cc/mongo_linux/mongo_toolchain_version_generator.py +++ b/bazel/toolchains/cc/mongo_linux/mongo_toolchain_version_generator.py @@ -52,7 +52,9 @@ PLATFORM_NAME_MAP = { REQUESTS_SESSION = requests.Session() REQUESTS_SESSION.mount( "https://", - HTTPAdapter(max_retries=Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504])), + HTTPAdapter( + max_retries=Retry(total=5, backoff_factor=1, status_forcelist=[502, 503, 504]) + ), ) @@ -61,7 +63,9 @@ def download_toolchain(toolchain_url: str, local_path: str) -> bool: response = REQUESTS_SESSION.get(toolchain_url) if response.status_code != requests.codes.ok: - print(f"WARNING: HTTP {response.status_code} status downloading {toolchain_url}") + print( + f"WARNING: HTTP {response.status_code} status downloading {toolchain_url}" + ) return False with open(local_path, "wb") as f: @@ -88,10 +92,14 @@ def main(): help="The build id, this should be the toolchain revision (githash), or the evergreen task id (version and date) if it is a --patch toolchain.", ) parser.add_argument( - "toolchain_version", choices=toolchain_versions + ["all"], help="Toolchain version" + "toolchain_version", + choices=toolchain_versions + ["all"], + help="Toolchain version", ) parser.add_argument( - "toolchain_component", choices=toolchain_components + ["all"], help="Toolchain component" + "toolchain_component", + choices=toolchain_components + ["all"], + help="Toolchain component", ) parser.add_argument( "--patch_toolchain", @@ -162,9 +170,13 @@ def main(): ) print(f'TOOLCHAIN_ID = "{args.build_id}"', file=f) print(f"TOOLCHAIN_MAP_{version_str.upper()} = {{", file=f) - for key, value in sorted(mongo_toolchain_version.items(), key=lambda x: x[0]): + for key, value in sorted( + mongo_toolchain_version.items(), key=lambda x: x[0] + ): print(f' "{key}": {{', file=f) - for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]): + for subkey, subvalue in sorted( + value.items(), key=lambda x: x[0] + ): print(f' "{subkey}": "{subvalue}",', file=f) print(" },", file=f) print("}", file=f) diff --git a/bazel/toolchains/mongo_toolchain_version_generator.py b/bazel/toolchains/mongo_toolchain_version_generator.py index 035237a5f17..2bf1f40bd57 100644 --- a/bazel/toolchains/mongo_toolchain_version_generator.py +++ b/bazel/toolchains/mongo_toolchain_version_generator.py @@ -8,6 +8,7 @@ import pathlib import tempfile import urllib.request + def sha256_file(filename: str) -> str: sha256_hash = hashlib.sha256() with open(filename, "rb") as f: @@ -15,23 +16,29 @@ def sha256_file(filename: str) -> str: sha256_hash.update(block) return sha256_hash.hexdigest() + def main(): parser = argparse.ArgumentParser() - parser.add_argument("patch_build_id", - help="Patch build id from toolchain-builder project.") - parser.add_argument("patch_build_date_string", - help="Patch build date string from toolchain-builder project, get this at the task URL, ex the date is 24_01_09_16_10_07 for https://spruce.mongodb.com/task/toolchain_builder_amazon2023_compile_11bae3c145a48dd7be9ee8aa44e5591783f787aa_24_01_09_16_10_07/") + parser.add_argument( + "patch_build_id", help="Patch build id from toolchain-builder project." + ) + parser.add_argument( + "patch_build_date_string", + help="Patch build date string from toolchain-builder project, get this at the task URL, ex the date is 24_01_09_16_10_07 for https://spruce.mongodb.com/task/toolchain_builder_amazon2023_compile_11bae3c145a48dd7be9ee8aa44e5591783f787aa_24_01_09_16_10_07/", + ) args = parser.parse_args() mongo_toolchain_version = {} - version_file_path = os.path.join(pathlib.Path(__file__).parent.resolve(), "mongo_toolchain_version.bzl") + version_file_path = os.path.join( + pathlib.Path(__file__).parent.resolve(), "mongo_toolchain_version.bzl" + ) with open(version_file_path, "r") as f: code = compile(f.read(), version_file_path, "exec") exec(code, {}, mongo_toolchain_version) - + for toolchain_name, toolchain in mongo_toolchain_version["TOOLCHAIN_MAP"].items(): - underscore_platform_name = toolchain['platform_name'].replace("-", "_") + underscore_platform_name = toolchain["platform_name"].replace("-", "_") toolchain_url = mongo_toolchain_version["TOOLCHAIN_URL_FORMAT"].format( platform_name=toolchain["platform_name"], @@ -39,9 +46,12 @@ def main(): patch_build_id=args.patch_build_id, patch_build_date=args.patch_build_date_string, ) - + temp_dir = tempfile.gettempdir() - local_tarball_path = os.path.join(temp_dir, f"bazel_v4_toolchain_builder_{underscore_platform_name}_{args.patch_build_id}.tar.gz") + local_tarball_path = os.path.join( + temp_dir, + f"bazel_v4_toolchain_builder_{underscore_platform_name}_{args.patch_build_id}.tar.gz", + ) print(f"Downloading {toolchain_url}...") @@ -54,15 +64,23 @@ def main(): with open(version_file_path, "w") as f: print(f"Writing toolchain map to {version_file_path}...") - print("# Use mongo/bazel/toolchains/toolchain_generator.py to generate this mapping for a given patch build.\n", file=f) - print(f"TOOLCHAIN_URL_FORMAT = \"{mongo_toolchain_version['TOOLCHAIN_URL_FORMAT']}\"", file=f) - print(f"TOOLCHAIN_PATCH_BUILD_ID = \"{args.patch_build_id}\"", file=f) - print(f"TOOLCHAIN_PATCH_BUILD_DATE = \"{args.patch_build_date_string}\"", file=f) + print( + "# Use mongo/bazel/toolchains/toolchain_generator.py to generate this mapping for a given patch build.\n", + file=f, + ) + print( + f"TOOLCHAIN_URL_FORMAT = \"{mongo_toolchain_version['TOOLCHAIN_URL_FORMAT']}\"", + file=f, + ) + print(f'TOOLCHAIN_PATCH_BUILD_ID = "{args.patch_build_id}"', file=f) + print(f'TOOLCHAIN_PATCH_BUILD_DATE = "{args.patch_build_date_string}"', file=f) print("TOOLCHAIN_MAP = {", file=f) - for key, value in sorted(mongo_toolchain_version["TOOLCHAIN_MAP"].items(), key=lambda x: x[0]): - print(f" \"{key}\": {{", file=f) - for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]): - print(f" \"{subkey}\": \"{subvalue}\",", file=f) + for key, value in sorted( + mongo_toolchain_version["TOOLCHAIN_MAP"].items(), key=lambda x: x[0] + ): + print(f' "{key}": {{', file=f) + for subkey, subvalue in sorted(value.items(), key=lambda x: x[0]): + print(f' "{subkey}": "{subvalue}",', file=f) print(" },", file=f) print("}", file=f) @@ -70,5 +88,6 @@ def main(): print(f"Finished writing to {version_file_path}:") print(f.read()) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/bazel/wrapper_hook/autogenerated_targets.py b/bazel/wrapper_hook/autogenerated_targets.py index 039c4fe2827..e3428f53ab5 100644 --- a/bazel/wrapper_hook/autogenerated_targets.py +++ b/bazel/wrapper_hook/autogenerated_targets.py @@ -27,7 +27,13 @@ bazel_tags_to_autogenerate = [ def autogenerate_targets(args, bazel): bazel_autogenerate_flag = "include_autogenerated_targets" output_file = os.path.join( - "src", "mongo", "db", "modules", "enterprise", "autogenerated_targets", "BUILD.bazel" + "src", + "mongo", + "db", + "modules", + "enterprise", + "autogenerated_targets", + "BUILD.bazel", ) if not any(bazel_autogenerate_flag in arg for arg in args): if os.path.exists(output_file): @@ -40,7 +46,13 @@ def generate_targets( args=[], bazel="bazel", output_file=os.path.join( - "src", "mongo", "db", "modules", "enterprise", "autogenerated_targets", "BUILD.bazel" + "src", + "mongo", + "db", + "modules", + "enterprise", + "autogenerated_targets", + "BUILD.bazel", ), ): targets = [] diff --git a/bazel/wrapper_hook/compiledb.py b/bazel/wrapper_hook/compiledb.py index 92ca5c1b936..51df9b1c0eb 100644 --- a/bazel/wrapper_hook/compiledb.py +++ b/bazel/wrapper_hook/compiledb.py @@ -45,7 +45,6 @@ def run_pty_command(cmd): def generate_compiledb(bazel_bin, persistent_compdb, enterprise): - # compiledb ignores command line args so just make a version rc file in anycase write_mongo_variables_bazelrc([]) if persistent_compdb: @@ -165,7 +164,8 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise): if external_link.exists(): os.unlink(external_link) os.symlink( - pathlib.Path(os.readlink(REPO_ROOT / "bazel-out")).parent.parent.parent / "external", + pathlib.Path(os.readlink(REPO_ROOT / "bazel-out")).parent.parent.parent + / "external", external_link, target_is_directory=True, ) @@ -199,7 +199,9 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise): os.chmod(config, 0o744) with fileinput.FileInput(config, inplace=True) as file: for line in file: - print(line.replace("bazel-out/", f"{symlink_prefix}out/"), end="") + print( + line.replace("bazel-out/", f"{symlink_prefix}out/"), end="" + ) shutil.copyfile(configs[1], clang_tidy_file) with open(".mongo_checks_module_path", "w") as f: f.write( @@ -216,10 +218,14 @@ def generate_compiledb(bazel_bin, persistent_compdb, enterprise): shutil.copyfile(pathlib.Path("bazel-bin") / ".clang-tidy", clang_tidy_file) if persistent_compdb: shutdown_proc = subprocess.run( - [bazel_bin, f"--output_base={output_base}", "shutdown"], capture_output=True, text=True + [bazel_bin, f"--output_base={output_base}", "shutdown"], + capture_output=True, + text=True, ) if shutdown_proc.returncode != 0: - print(f"Failed to shutdown compiledb output_base: {shutdown_proc.returncode}") + print( + f"Failed to shutdown compiledb output_base: {shutdown_proc.returncode}" + ) print("--- stdout ---:") print(shutdown_proc.stdout) print("--- stderr ---:") diff --git a/bazel/wrapper_hook/engflow_check.py b/bazel/wrapper_hook/engflow_check.py index 12d19f64acb..1f6baf3255d 100644 --- a/bazel/wrapper_hook/engflow_check.py +++ b/bazel/wrapper_hook/engflow_check.py @@ -26,6 +26,9 @@ def engflow_auth(args): and "--config local" not in args_str and "--config public-release" not in args_str ): - if os.environ.get("CI") is None and platform.machine().lower() not in {"ppc64le", "s390x"}: + if os.environ.get("CI") is None and platform.machine().lower() not in { + "ppc64le", + "s390x", + }: setup_auth_wrapper() wrapper_debug(f"engflow auth time: {time.time() - start}") diff --git a/bazel/wrapper_hook/install_modules.py b/bazel/wrapper_hook/install_modules.py index b9f3661b3e2..7d8190e491f 100644 --- a/bazel/wrapper_hook/install_modules.py +++ b/bazel/wrapper_hook/install_modules.py @@ -14,7 +14,9 @@ from bazel.wrapper_hook.wrapper_debug import wrapper_debug def get_deps_dirs(deps): - tmp_dir = pathlib.Path(os.environ["Temp"] if platform.system() == "Windows" else "/tmp") + tmp_dir = pathlib.Path( + os.environ["Temp"] if platform.system() == "Windows" else "/tmp" + ) bazel_bin = REPO_ROOT / "bazel-bin" for dep in deps: try: @@ -41,7 +43,9 @@ def add_module_to_path(poetry_dir, modules_added): def setup_python_path(): - tmp_dir = pathlib.Path(os.environ["Temp"] if platform.system() == "Windows" else "/tmp") + tmp_dir = pathlib.Path( + os.environ["Temp"] if platform.system() == "Windows" else "/tmp" + ) modules_added = set() for out_dir in [ @@ -92,7 +96,9 @@ def search_for_modules(deps, deps_installed, lockfile_changed=False): def install_modules(bazel): need_to_install = False pwd_hash = hashlib.md5(str(REPO_ROOT).encode()).hexdigest() - lockfile_hash_file = pathlib.Path(tempfile.gettempdir()) / f"{pwd_hash}_lockfile_hash" + lockfile_hash_file = ( + pathlib.Path(tempfile.gettempdir()) / f"{pwd_hash}_lockfile_hash" + ) with open(REPO_ROOT / "poetry.lock", "rb") as f: current_hash = hashlib.md5(f.read()).hexdigest() @@ -132,7 +138,9 @@ def install_modules(bazel): ] ) if proc.returncode != 0: - print("Failed to install modules using remote exec/cache, falling back to local...") + print( + "Failed to install modules using remote exec/cache, falling back to local..." + ) proc = subprocess.run( cmd + [ diff --git a/bazel/wrapper_hook/lint.py b/bazel/wrapper_hook/lint.py index caa26287790..babf435d650 100644 --- a/bazel/wrapper_hook/lint.py +++ b/bazel/wrapper_hook/lint.py @@ -83,7 +83,9 @@ def list_files_without_targets( "src/mongo/util/processinfo_solaris.cpp", } - typed_files_in_targets = [line for line in files_with_targets if line.endswith(f".{ext}")] + typed_files_in_targets = [ + line for line in files_with_targets if line.endswith(f".{ext}") + ] print(f"Checking that all {type_name} files have BUILD.bazel targets...") @@ -184,7 +186,9 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: # so that the naive thing of pasting that flag to lint.sh will do what the user expects. if "--fix" in args: fix = "patch" - args.extend(["--@aspect_rules_lint//lint:fix", "--output_groups=rules_lint_patch"]) + args.extend( + ["--@aspect_rules_lint//lint:fix", "--output_groups=rules_lint_patch"] + ) args.remove("--fix") # the --dry-run flag must immediately follow the --fix flag @@ -208,7 +212,15 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: # jq on windows outputs CRLF which breaks this script. https://github.com/jqlang/jq/issues/92 valid_reports = ( subprocess.run( - ["jq", "--arg", "ext", ".out", "--raw-output", filter_expr, buildevents_path], + [ + "jq", + "--arg", + "ext", + ".out", + "--raw-output", + filter_expr, + buildevents_path, + ], capture_output=True, text=True, check=True, @@ -220,7 +232,11 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: failing_reports = 0 for report in valid_reports: # Exclude coverage reports, and check if the output is empty. - if "coverage.dat" in report or not os.path.exists(report) or not os.path.getsize(report): + if ( + "coverage.dat" in report + or not os.path.exists(report) + or not os.path.getsize(report) + ): # Report is empty. No linting errors. continue with open(report, "r", encoding="utf-8") as f: @@ -238,7 +254,15 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: if fix: valid_patches = ( subprocess.run( - ["jq", "--arg", "ext", ".patch", "--raw-output", filter_expr, buildevents_path], + [ + "jq", + "--arg", + "ext", + ".patch", + "--raw-output", + filter_expr, + buildevents_path, + ], capture_output=True, text=True, check=True, @@ -249,7 +273,11 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: for patch in valid_patches: # Exclude coverage, and check if the patch is empty. - if "coverage.dat" in patch or not os.path.exists(patch) or not os.path.getsize(patch): + if ( + "coverage.dat" in patch + or not os.path.exists(patch) + or not os.path.getsize(patch) + ): # Patch is empty. No linting errors. continue @@ -260,7 +288,9 @@ def run_rules_lint(bazel_bin: str, args: List[str]) -> bool: print() elif fix == "patch": subprocess.run( - ["patch", "-p1"], check=True, stdin=open(patch, "r", encoding="utf-8") + ["patch", "-p1"], + check=True, + stdin=open(patch, "r", encoding="utf-8"), ) else: print(f"ERROR: unknown fix type {fix}", file=sys.stderr) diff --git a/bazel/wrapper_hook/plus_interface.py b/bazel/wrapper_hook/plus_interface.py index 3aa0e044433..77d3cff5afe 100644 --- a/bazel/wrapper_hook/plus_interface.py +++ b/bazel/wrapper_hook/plus_interface.py @@ -25,7 +25,9 @@ class DuplicateSourceNames(Exception): def get_buildozer_output(autocomplete_query): from buildscripts.install_bazel import install_bazel - buildozer_name = "buildozer" if not platform.system() == "Windows" else "buildozer.exe" + buildozer_name = ( + "buildozer" if not platform.system() == "Windows" else "buildozer.exe" + ) buildozer = shutil.which(buildozer_name) if not buildozer: buildozer = str(pathlib.Path(f"~/.local/bin/{buildozer_name}").expanduser()) @@ -87,7 +89,12 @@ def test_runner_interface( if autocomplete_query: str_args = " ".join(args) - if "'//:*'" in str_args or "':*'" in str_args or "//:all" in str_args or ":all" in str_args: + if ( + "'//:*'" in str_args + or "':*'" in str_args + or "//:all" in str_args + or ":all" in str_args + ): plus_autocomplete_query = True if os.environ.get("CI") is not None: @@ -174,7 +181,9 @@ def test_runner_interface( if not real_target: for bin_target in set(sources_to_bin.values()): if ( - pathlib.Path(bin_target.replace("//", "").replace(":", "/")).name + pathlib.Path( + bin_target.replace("//", "").replace(":", "/") + ).name == test_name ): bin_targets.append(bin_target) diff --git a/bazel/wrapper_hook/set_mongo_variables.py b/bazel/wrapper_hook/set_mongo_variables.py index bc3abf5afe2..b3c22ce663a 100644 --- a/bazel/wrapper_hook/set_mongo_variables.py +++ b/bazel/wrapper_hook/set_mongo_variables.py @@ -13,6 +13,7 @@ ARCH_NORMALIZE_MAP = { "s390x": "s390x", } + def get_mongo_arch(args): arch = platform.machine().lower() if arch in ARCH_NORMALIZE_MAP: @@ -20,14 +21,18 @@ def get_mongo_arch(args): else: return arch + def get_mongo_version(args): - proc = subprocess.run(["git", "describe", "--abbrev=0"], capture_output=True, text=True) + proc = subprocess.run( + ["git", "describe", "--abbrev=0"], capture_output=True, text=True + ) return proc.stdout.strip()[1:] + def write_mongo_variables_bazelrc(args): mongo_version = get_mongo_version(args) mongo_arch = get_mongo_arch(args) - + repo_root = pathlib.Path(os.path.abspath(__file__)).parent.parent.parent version_file = os.path.join(repo_root, ".bazelrc.mongo_variables") existing_hash = "" @@ -42,4 +47,4 @@ common --define=MONGO_VERSION={mongo_version} current_hash = hashlib.md5(bazelrc_contents.encode()).hexdigest() if existing_hash != current_hash: with open(version_file, "w", encoding="utf-8") as f: - f.write(bazelrc_contents) + f.write(bazelrc_contents) diff --git a/bazel/wrapper_hook/wrapper_hook.py b/bazel/wrapper_hook/wrapper_hook.py index d3e66b1eeac..abb911b81ce 100644 --- a/bazel/wrapper_hook/wrapper_hook.py +++ b/bazel/wrapper_hook/wrapper_hook.py @@ -32,7 +32,12 @@ def main(): autogenerate_targets(sys.argv, sys.argv[1]) enterprise = True - if check_bazel_command_type(sys.argv[1:]) not in ["clean", "shutdown", "version", None]: + if check_bazel_command_type(sys.argv[1:]) not in [ + "clean", + "shutdown", + "version", + None, + ]: args = sys.argv enterprise_mod = REPO_ROOT / "src" / "mongo" / "db" / "modules" / "enterprise" if not enterprise_mod.exists(): diff --git a/buildscripts/aggregate_tracefiles.py b/buildscripts/aggregate_tracefiles.py index ebe3e3f16f1..8764a6c08b6 100644 --- a/buildscripts/aggregate_tracefiles.py +++ b/buildscripts/aggregate_tracefiles.py @@ -13,14 +13,14 @@ from optparse import OptionParser def aggregate(inputs, output): """Aggregate the tracefiles given in inputs to a tracefile given by output.""" - args = ['lcov'] + args = ["lcov"] for name in inputs: - args += ['-a', name] + args += ["-a", name] - args += ['-o', output] + args += ["-o", output] - print(' '.join(args)) + print(" ".join(args)) return subprocess.call(args) @@ -46,17 +46,19 @@ def main(): for path in args[:-1]: _, ext = os.path.splitext(path) - if ext == '.info': + if ext == ".info": if getfilesize(path) > 0: inputs.append(path) - elif ext == '.txt': - inputs += [line.strip() for line in open(path) if getfilesize(line.strip()) > 0] + elif ext == ".txt": + inputs += [ + line.strip() for line in open(path) if getfilesize(line.strip()) > 0 + ] else: return "unrecognized file type" return aggregate(inputs, args[-1]) -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/buildscripts/antithesis/topologies/sharded_cluster/scripts/mongos_init.py b/buildscripts/antithesis/topologies/sharded_cluster/scripts/mongos_init.py index 5a687bee0a5..dd5cf6a6091 100644 --- a/buildscripts/antithesis/topologies/sharded_cluster/scripts/mongos_init.py +++ b/buildscripts/antithesis/topologies/sharded_cluster/scripts/mongos_init.py @@ -1,4 +1,5 @@ """Script to configure a sharded cluster in Antithesis from the mongos container.""" + import json import subprocess from time import sleep @@ -59,7 +60,8 @@ retry_until_success(mongo_process_running, {"host": "configsvr1", "port": 27019} retry_until_success(mongo_process_running, {"host": "configsvr2", "port": 27019}) retry_until_success(mongo_process_running, {"host": "configsvr3", "port": 27019}) retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongo", "--host", @@ -70,14 +72,16 @@ retry_until_success( f"rs.initiate({json.dumps(CONFIGSVR_CONFIG)})", ], "check": True, - }) + }, +) # Create Shard1 once all nodes are up retry_until_success(mongo_process_running, {"host": "database1", "port": 27018}) retry_until_success(mongo_process_running, {"host": "database2", "port": 27018}) retry_until_success(mongo_process_running, {"host": "database3", "port": 27018}) retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongo", "--host", @@ -88,14 +92,16 @@ retry_until_success( f"rs.initiate({json.dumps(SHARD1_CONFIG)})", ], "check": True, - }) + }, +) # Create Shard2 once all nodes are up retry_until_success(mongo_process_running, {"host": "database4", "port": 27018}) retry_until_success(mongo_process_running, {"host": "database5", "port": 27018}) retry_until_success(mongo_process_running, {"host": "database6", "port": 27018}) retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongo", "--host", @@ -106,11 +112,13 @@ retry_until_success( f"rs.initiate({json.dumps(SHARD2_CONFIG)})", ], "check": True, - }) + }, +) # Start mongos retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongos", "--bind_ip", @@ -126,11 +134,13 @@ retry_until_success( "--fork", ], "check": True, - }) + }, +) # Add shards to cluster retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongo", "--host", @@ -141,9 +151,11 @@ retry_until_success( 'sh.addShard("Shard1/database1:27018,database2:27018,database3:27018")', ], "check": True, - }) + }, +) retry_until_success( - subprocess.run, { + subprocess.run, + { "args": [ "mongo", "--host", @@ -154,7 +166,8 @@ retry_until_success( 'sh.addShard("Shard2/database4:27018,database5:27018,database6:27018")', ], "check": True, - }) + }, +) while True: sleep(10) diff --git a/buildscripts/antithesis/topologies/sharded_cluster/scripts/utils.py b/buildscripts/antithesis/topologies/sharded_cluster/scripts/utils.py index 3338c68e7e0..13f97f03995 100644 --- a/buildscripts/antithesis/topologies/sharded_cluster/scripts/utils.py +++ b/buildscripts/antithesis/topologies/sharded_cluster/scripts/utils.py @@ -1,12 +1,15 @@ """Util functions to assist in setting up a sharded cluster topology in Antithesis.""" + import subprocess import time def mongo_process_running(host, port): """Check to see if the process at the given host & port is running.""" - return subprocess.run(['mongo', '--host', host, '--port', - str(port), '--eval', '"db.stats()"'], check=True) + return subprocess.run( + ["mongo", "--host", host, "--port", str(port), "--eval", '"db.stats()"'], + check=True, + ) def retry_until_success(func, kwargs=None, wait_time=1, timeout_period=30): @@ -16,10 +19,13 @@ def retry_until_success(func, kwargs=None, wait_time=1, timeout_period=30): while True: if time.time() > timeout: raise TimeoutError( - f"{func.__name__} called with {kwargs} timed out after {timeout_period} second(s).") + f"{func.__name__} called with {kwargs} timed out after {timeout_period} second(s)." + ) try: func(**kwargs) break except: # pylint: disable=bare-except - print(f"Retrying {func.__name__} called with {kwargs} after {wait_time} second(s).") + print( + f"Retrying {func.__name__} called with {kwargs} after {wait_time} second(s)." + ) time.sleep(wait_time) diff --git a/buildscripts/antithesis/topologies/sharded_cluster/scripts/workload_init.py b/buildscripts/antithesis/topologies/sharded_cluster/scripts/workload_init.py index 971cedea74a..0caa3764135 100644 --- a/buildscripts/antithesis/topologies/sharded_cluster/scripts/workload_init.py +++ b/buildscripts/antithesis/topologies/sharded_cluster/scripts/workload_init.py @@ -1,4 +1,5 @@ """Script to initialize a workload container in Antithesis.""" + from time import sleep while True: diff --git a/buildscripts/apply_clang_tidy_fixes.py b/buildscripts/apply_clang_tidy_fixes.py index 64170466cc9..f71480cb7a0 100755 --- a/buildscripts/apply_clang_tidy_fixes.py +++ b/buildscripts/apply_clang_tidy_fixes.py @@ -48,13 +48,17 @@ def get_replacements_to_apply(fixes_file) -> dict: fixes_data = json.load(fin) for clang_tidy_check in fixes_data: for main_source_file in fixes_data[clang_tidy_check]: - for violation_instance in fixes_data[clang_tidy_check][main_source_file]: + for violation_instance in fixes_data[clang_tidy_check][ + main_source_file + ]: replacements = fixes_data[clang_tidy_check][main_source_file][ violation_instance ]["replacements"] if can_replacements_be_applied(replacements): for replacement in replacements: - replacements_to_apply[replacement["FilePath"]].append(replacement) + replacements_to_apply[replacement["FilePath"]].append( + replacement + ) else: print( f"""WARNING: not applying replacements for {clang_tidy_check} in {main_source_file} at offset {violation_instance}, at least one file that is part of the automatic replacement has changed since clang-tidy was run, or is not writeable.""" @@ -88,7 +92,9 @@ def _combine_errors(dir: str) -> str: if "Notes" in fix: fix_msg = fix["Notes"][0] if len(fix["Notes"]) > 1: - print(f'Warning: this script may be missing values in [{fix["Notes"]}]') + print( + f'Warning: this script may be missing values in [{fix["Notes"]}]' + ) else: fix_msg = fix["DiagnosticMessage"] fix_data = ( @@ -105,7 +111,9 @@ def _combine_errors(dir: str) -> str: .setdefault( str(fix_msg.get("FileOffset", "FileOffset Not Found")), { - "replacements": fix_msg.get("Replacements", "Replacements not found"), + "replacements": fix_msg.get( + "Replacements", "Replacements not found" + ), "message": fix_msg.get("Message", "Message not found"), "count": 0, "source_files": [], @@ -171,7 +179,9 @@ def main(argv=sys.argv[1:]): ] = replacement["ReplacementText"].encode() if replacement["Length"] != len(replacement["ReplacementText"]): - adjustments += len(replacement["ReplacementText"]) - replacement["Length"] + adjustments += ( + len(replacement["ReplacementText"]) - replacement["Length"] + ) with open(file, "wb") as fout: fout.write(bytes(file_bytes)) diff --git a/buildscripts/bazel_rules_mongo/codeowners/codeowners_generate.py b/buildscripts/bazel_rules_mongo/codeowners/codeowners_generate.py index be0ee5e094b..658e1df44dc 100644 --- a/buildscripts/bazel_rules_mongo/codeowners/codeowners_generate.py +++ b/buildscripts/bazel_rules_mongo/codeowners/codeowners_generate.py @@ -31,13 +31,14 @@ def add_file_to_tree(root_node: FileNode, file_parts: List[str]): for i, dir in enumerate(file_parts[:-1]): node_dirs = current_node.dirs if dir not in node_dirs: - directory = "/".join(file_parts[:i + 1]) + directory = "/".join(file_parts[: i + 1]) node_dirs[dir] = FileNode(f"./{directory}") current_node = node_dirs[dir] - assert (current_node.owners_file is - None), f"{'/'.join(file_parts[:-1])} there are two OWNERS files in this directory" + assert ( + current_node.owners_file is None + ), f"{'/'.join(file_parts[:-1])} there are two OWNERS files in this directory" current_node.owners_file = file_parts[-1] @@ -65,7 +66,9 @@ def process_owners_file(output_lines: list[str], node: FileNode) -> None: with open(owners_file_path, "r") as file: contents = yaml.safe_load(file) assert "version" in contents, f"Version not found in {owners_file_path}" - assert contents["version"] in parsers, f"Unsupported version in {owners_file_path}" + assert ( + contents["version"] in parsers + ), f"Unsupported version in {owners_file_path}" parser = parsers[contents["version"]] owners_lines = parser.parse(directory, owners_file_path, contents) output_lines.extend(owners_lines) @@ -91,7 +94,9 @@ def print_diff_and_instructions(old_codeowners_contents, new_codeowners_contents ) sys.stdout.writelines(diff) - print("If you are seeing this message in CI you likely need to run `bazel run codeowners`") + print( + "If you are seeing this message in CI you likely need to run `bazel run codeowners`" + ) def validate_generated_codeowners(validator_path: str) -> int: @@ -114,7 +119,9 @@ def validate_generated_codeowners(validator_path: str) -> int: @cache -def get_unowned_files(codeowners_binary_path: str, codeowners_file: str = None) -> Set[str]: +def get_unowned_files( + codeowners_binary_path: str, codeowners_file: str = None +) -> Set[str]: temp_output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt") temp_output_file.close() codeowners_file_arg = "" @@ -140,7 +147,9 @@ def get_unowned_files(codeowners_binary_path: str, codeowners_file: str = None) return unowned_files -def check_new_files(codeowners_binary_path: str, expansions_file: str, branch: str) -> int: +def check_new_files( + codeowners_binary_path: str, expansions_file: str, branch: str +) -> int: new_files = evergreen_git.get_new_files(expansions_file, branch) if not new_files: print("No new files were detected.") @@ -168,27 +177,35 @@ def check_new_files(codeowners_binary_path: str, expansions_file: str, branch: s return 0 -def check_orphaned_files(codeowners_binary_path: str, expansions_file: str, branch: str, - codeowners_file: str) -> int: +def check_orphaned_files( + codeowners_binary_path: str, expansions_file: str, branch: str, codeowners_file: str +) -> int: # This compares the new codeowners file with the old codeowners file on the same working tree # This tells us which coverage is lost between codeowners file changes current_unowned_files = get_unowned_files(codeowners_binary_path) base_revision = evergreen_git.get_diff_revision(expansions_file, branch) previous_codeowners_file_contents = evergreen_git.get_file_at_revision( - codeowners_file, base_revision) + codeowners_file, base_revision + ) if previous_codeowners_file_contents is None: return 0 - temp_codeowners_file = tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".txt") + temp_codeowners_file = tempfile.NamedTemporaryFile( + mode="w", delete=False, suffix=".txt" + ) temp_codeowners_file.write(previous_codeowners_file_contents) temp_codeowners_file.close() - old_unowned_files = get_unowned_files(codeowners_binary_path, temp_codeowners_file.name) + old_unowned_files = get_unowned_files( + codeowners_binary_path, temp_codeowners_file.name + ) unowned_files_difference = current_unowned_files - old_unowned_files if not unowned_files_difference: print("No files have lost ownership with these changes.") return 0 - print("The following files lost ownership with CODEOWNERS changes:", file=sys.stderr) + print( + "The following files lost ownership with CODEOWNERS changes:", file=sys.stderr + ) for file in sorted(unowned_files_difference): print(f"- {file}", file=sys.stderr) @@ -196,21 +213,22 @@ def check_orphaned_files(codeowners_binary_path: str, expansions_file: str, bran def post_generation_checks( - validator_path: str, - should_run_validation: bool, - codeowners_binary_path: str, - should_check_new_files: bool, - expansions_file: str, - branch: str, - codeowners_file_path: str, + validator_path: str, + should_run_validation: bool, + codeowners_binary_path: str, + should_check_new_files: bool, + expansions_file: str, + branch: str, + codeowners_file_path: str, ) -> int: status = 0 if should_run_validation: status |= validate_generated_codeowners(validator_path) if should_check_new_files: status |= check_new_files(codeowners_binary_path, expansions_file, branch) - status |= check_orphaned_files(codeowners_binary_path, expansions_file, branch, - codeowners_file_path) + status |= check_orphaned_files( + codeowners_binary_path, expansions_file, branch, codeowners_file_path + ) return status @@ -219,8 +237,12 @@ def main(): # If we are running in bazel, default the directory to the workspace default_dir = os.environ.get("BUILD_WORKSPACE_DIRECTORY") if not default_dir: - process = subprocess.run(["git", "rev-parse", "--show-toplevel"], capture_output=True, - text=True, check=True) + process = subprocess.run( + ["git", "rev-parse", "--show-toplevel"], + capture_output=True, + text=True, + check=True, + ) default_dir = process.stdout.strip() codeowners_validator_path = os.environ.get("CODEOWNERS_VALIDATOR_PATH") @@ -247,12 +269,14 @@ def main(): help="Path of the CODEOWNERS file to be generated.", default=os.path.join(".github", "CODEOWNERS"), ) - parser.add_argument("--repo-dir", help="Root of the repo to scan for OWNER files.", - default=default_dir) + parser.add_argument( + "--repo-dir", + help="Root of the repo to scan for OWNER files.", + default=default_dir, + ) parser.add_argument( "--check", - help= - "When set, program exits 1 when the CODEOWNERS content changes. This will skip generation", + help="When set, program exits 1 when the CODEOWNERS content changes. This will skip generation", default=False, action="store_true", ) @@ -276,8 +300,7 @@ def main(): ) parser.add_argument( "--branch", - help= - "Helps the script understand what branch to compare against to see what new files are added when run locally. Defaults to master or main.", + help="Helps the script understand what branch to compare against to see what new files are added when run locally. Defaults to master or main.", default=None, action="store", ) @@ -301,9 +324,14 @@ def main(): root_node = build_tree(files) process_dir(output_lines, root_node) except Exception as ex: - print("An exception was found while generating the CODEOWNERS file.", file=sys.stderr) - print("Please refer to the docs to see the spec for OWNERS.yml files here :", - file=sys.stderr) + print( + "An exception was found while generating the CODEOWNERS file.", + file=sys.stderr, + ) + print( + "Please refer to the docs to see the spec for OWNERS.yml files here :", + file=sys.stderr, + ) print( "https://github.com/10gen/mongo/blob/master/docs/owners/owners_format.md", file=sys.stderr, @@ -327,7 +355,8 @@ def main(): should_check_new_files = True else: raise RuntimeError( - f"Invalid value for CODEOWNERS_CHECK_NEW_FILES: {should_check_new_files}") + f"Invalid value for CODEOWNERS_CHECK_NEW_FILES: {should_check_new_files}" + ) else: should_check_new_files = args.check_new_files @@ -350,7 +379,9 @@ def main(): with open(output_file, "w") as file: file.write(new_contents) - print(f"Successfully wrote to the CODEOWNERS file at: {os.path.abspath(output_file)}") + print( + f"Successfully wrote to the CODEOWNERS file at: {os.path.abspath(output_file)}" + ) # Add validation after generating CODEOWNERS file return post_generation_checks( diff --git a/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v1.py b/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v1.py index d460a70b672..229a4220e67 100644 --- a/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v1.py +++ b/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v1.py @@ -9,12 +9,16 @@ import yaml # Parser for OWNERS.yml files version 1.0.0 class OwnersParserV1: - def parse(self, directory: str, owners_file_path: str, contents: Dict[str, any]) -> List[str]: + def parse( + self, directory: str, owners_file_path: str, contents: Dict[str, any] + ) -> List[str]: lines = [] no_parent_owners = False if "options" in contents: options = contents["options"] - no_parent_owners = "no_parent_owners" in options and options["no_parent_owners"] + no_parent_owners = ( + "no_parent_owners" in options and options["no_parent_owners"] + ) if no_parent_owners: # Specfying no owners will ensure that no file in this directory has an owner unless it @@ -28,15 +32,18 @@ class OwnersParserV1: if "filters" in contents: filters = contents["filters"] for _filter in filters: - assert ("approvers" in _filter - ), f"Filter in {owners_file_path} does not have approvers." + assert ( + "approvers" in _filter + ), f"Filter in {owners_file_path} does not have approvers." approvers = _filter["approvers"] del _filter["approvers"] if "metadata" in _filter: del _filter["metadata"] # the last key remaining should be the pattern for the filter - assert len(_filter) == 1, f"Filter in {owners_file_path} has incorrect values." + assert ( + len(_filter) == 1 + ), f"Filter in {owners_file_path} has incorrect values." pattern = next(iter(_filter)) owners: set[str] = set() @@ -44,7 +51,9 @@ class OwnersParserV1: if "@" in owner: # approver is email, just add as is if not owner.endswith("@mongodb.com"): - raise RuntimeError("Any emails specified must be a mongodb.com email.") + raise RuntimeError( + "Any emails specified must be a mongodb.com email." + ) owners.add(owner) else: # approver is github username, need to prefix with @ @@ -52,8 +61,9 @@ class OwnersParserV1: NOOWNERS_NAME = "NOOWNERS-DO-NOT-USE-DEPRECATED-2024-07-01" if NOOWNERS_NAME in approvers: - assert (len(approvers) == 1 - ), f"{NOOWNERS_NAME} must be the only approver when it is used." + assert ( + len(approvers) == 1 + ), f"{NOOWNERS_NAME} must be the only approver when it is used." else: for approver in approvers: if approver in aliases: @@ -72,7 +82,8 @@ class OwnersParserV1: def process_alias_import(self, path: str) -> Dict[str, List[str]]: if not path.startswith("//"): raise RuntimeError( - f"Alias file paths must start with // and be relative to the repo root: {path}") + f"Alias file paths must start with // and be relative to the repo root: {path}" + ) # remove // from beginning of path parsed_path = path[2::] @@ -108,7 +119,9 @@ class OwnersParserV1: parsed_pattern = f"/{directory}/**/{pattern}" if not self.test_pattern(parsed_pattern): - raise (RuntimeError(f"Can not find any files that match pattern: `{pattern}`")) + raise ( + RuntimeError(f"Can not find any files that match pattern: `{pattern}`") + ) return self.get_line(parsed_pattern, owners) diff --git a/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v2.py b/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v2.py index 106b28a46dc..964614f5196 100644 --- a/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v2.py +++ b/buildscripts/bazel_rules_mongo/codeowners/parsers/owners_v2.py @@ -24,6 +24,8 @@ class OwnersParserV2(OwnersParserV1): parsed_pattern = f"/{directory}/{pattern}" if not self.test_pattern(parsed_pattern): - raise (RuntimeError(f"Can not find any files that match pattern: `{pattern}`")) + raise ( + RuntimeError(f"Can not find any files that match pattern: `{pattern}`") + ) return self.get_line(parsed_pattern, owners) diff --git a/buildscripts/bazel_rules_mongo/codeowners/validate_codeowners.py b/buildscripts/bazel_rules_mongo/codeowners/validate_codeowners.py index 37f7cc3c87c..a407cdf40dd 100755 --- a/buildscripts/bazel_rules_mongo/codeowners/validate_codeowners.py +++ b/buildscripts/bazel_rules_mongo/codeowners/validate_codeowners.py @@ -10,11 +10,13 @@ def get_validator_env() -> dict: """Prepare the environment for the codeowners-validator.""" env = os.environ.copy() - env.update({ - "REPOSITORY_PATH": ".", - "CHECKS": "duppatterns,syntax", - "EXPERIMENTAL_CHECKS": "avoid-shadowing", - }) + env.update( + { + "REPOSITORY_PATH": ".", + "CHECKS": "duppatterns,syntax", + "EXPERIMENTAL_CHECKS": "avoid-shadowing", + } + ) return env @@ -28,8 +30,9 @@ def run_validator(validator_path: str) -> int: env = get_validator_env() try: - result = subprocess.run([validator_path], env=env, check=True, capture_output=True, - text=True) + result = subprocess.run( + [validator_path], env=env, check=True, capture_output=True, text=True + ) if result.stdout: print(result.stdout) return 0 @@ -40,7 +43,10 @@ def run_validator(validator_path: str) -> int: print(exc.stderr, file=sys.stderr) return exc.returncode except FileNotFoundError: - print("Error: Failed to run codeowners-validator after installation", file=sys.stderr) + print( + "Error: Failed to run codeowners-validator after installation", + file=sys.stderr, + ) return 1 except Exception: raise diff --git a/buildscripts/bazel_rules_mongo/engflow_auth/engflow_auth.py b/buildscripts/bazel_rules_mongo/engflow_auth/engflow_auth.py index 8c6e62b7303..3dae603f9a3 100644 --- a/buildscripts/bazel_rules_mongo/engflow_auth/engflow_auth.py +++ b/buildscripts/bazel_rules_mongo/engflow_auth/engflow_auth.py @@ -15,21 +15,21 @@ from datetime import datetime from retry import retry -NORMALIZED_ARCH = {"x86_64": "x64", "aarch64": "arm64", "arm64": "arm64", "AMD64": "x64"} +NORMALIZED_ARCH = { + "x86_64": "x64", + "aarch64": "arm64", + "arm64": "arm64", + "AMD64": "x64", +} NORMALIZED_OS = {"Windows": "windows", "Darwin": "macos", "Linux": "linux"} CHECKSUMS = { - "engflow_auth_linux_arm64": - "ad5ffee1e6db926f5066aa40ee35517b1993851d0063ac121dbf5b407c81e2bf", - "engflow_auth_linux_x64": - "b731bae21628b2be321c24b342854c6ed1ed0326010e62a2ecf0b5650a56cf1a", - "engflow_auth_macos_arm64": - "69057929b4515d41b1af861c9bfdbc47427cc5ce5a80c94d4776c8bef672292e", - "engflow_auth_macos_x64": - "a322373e41faa7750c34348f357c5a4a670a66cfd988e80b4343c72822d91292", - "engflow_auth_windows_x64.exe": - "cb9590ffcc6731389ded173250f604b37778417450b1dc92c6bafadeef342826", + "engflow_auth_linux_arm64": "ad5ffee1e6db926f5066aa40ee35517b1993851d0063ac121dbf5b407c81e2bf", + "engflow_auth_linux_x64": "b731bae21628b2be321c24b342854c6ed1ed0326010e62a2ecf0b5650a56cf1a", + "engflow_auth_macos_arm64": "69057929b4515d41b1af861c9bfdbc47427cc5ce5a80c94d4776c8bef672292e", + "engflow_auth_macos_x64": "a322373e41faa7750c34348f357c5a4a670a66cfd988e80b4343c72822d91292", + "engflow_auth_windows_x64.exe": "cb9590ffcc6731389ded173250f604b37778417450b1dc92c6bafadeef342826", } GH_URL_PREFIX = "https://github.com/EngFlow/auth/releases/download/v0.0.13/" CLUSTER = "sodalite.cluster.engflow.com" @@ -73,7 +73,9 @@ def install(verbose: bool) -> str: binary_path += ".exe" if os.path.exists(binary_path): if verbose: - print(f"{binary_filename} already exists at {binary_path}, skipping download") + print( + f"{binary_filename} already exists at {binary_path}, skipping download" + ) else: url = GH_URL_PREFIX + tag print(f"Downloading {url}...") @@ -109,7 +111,9 @@ def update_bazelrc(binary_path: str, verbose: bool): def authenticate(binary_path: str, verbose: bool) -> bool: need_login = False - p = subprocess.run(f"{binary_path} export {CLUSTER}", shell=True, capture_output=True) + p = subprocess.run( + f"{binary_path} export {CLUSTER}", shell=True, capture_output=True + ) if p.returncode != 0: need_login = True else: diff --git a/buildscripts/bazel_rules_mongo/tests/test_changed_files.py b/buildscripts/bazel_rules_mongo/tests/test_changed_files.py index 9dc260d7e91..578f7447984 100644 --- a/buildscripts/bazel_rules_mongo/tests/test_changed_files.py +++ b/buildscripts/bazel_rules_mongo/tests/test_changed_files.py @@ -19,8 +19,10 @@ def write_file(repo: Repo, file_name: str) -> None: file.write("change\n") -@unittest.skipIf(sys.platform == "win32", - reason="This test breaks on windows and only needs to work on linux") +@unittest.skipIf( + sys.platform == "win32", + reason="This test breaks on windows and only needs to work on linux", +) class TestChangedFiles(unittest.TestCase): @classmethod def setUpClass(cls): @@ -37,7 +39,8 @@ class TestChangedFiles(unittest.TestCase): # get tracked files that have been changed that are tracked by git diff_output = root_repo.git.execute( - ["git", "diff", "--name-only", "--diff-filter=d", commit]) + ["git", "diff", "--name-only", "--diff-filter=d", commit] + ) files_to_copy.update(diff_output.split("\n")) # gets all the untracked changes in the current repo @@ -86,7 +89,9 @@ class TestChangedFiles(unittest.TestCase): # make a new file that has not been commited yet write_file(self.repo, new_file_name) - with tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) as tmp: + with tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", delete=False + ) as tmp: tmp.write("fake_expansion: true\n") self.expansions_file = tmp.name @@ -98,12 +103,18 @@ class TestChangedFiles(unittest.TestCase): def test_local_unchanged_files(self): evergreen_git.get_remote_branch_ref = MagicMock(return_value=self.base_revision) new_files = evergreen_git.get_new_files() - self.assertEqual(new_files, [], - msg="New files list was not empty when no new files were added to git.") + self.assertEqual( + new_files, + [], + msg="New files list was not empty when no new files were added to git.", + ) changed_files = evergreen_git.get_changed_files() - self.assertEqual(changed_files, [changed_file_name], - msg="Changed file list was not as expected.") + self.assertEqual( + changed_files, + [changed_file_name], + msg="Changed file list was not as expected.", + ) self.repo.git.execute(["git", "add", "."]) @@ -131,10 +142,15 @@ class TestChangedFiles(unittest.TestCase): tmp.write(f"revision: {self.base_revision}\n") self.repo.git.execute(["git", "add", "."]) new_files = evergreen_git.get_new_files(expansions_file=self.expansions_file) - self.assertEqual(new_files, [new_file_name], - msg="New file list did not contain the new file.") + self.assertEqual( + new_files, + [new_file_name], + msg="New file list did not contain the new file.", + ) - changed_files = evergreen_git.get_changed_files(expansions_file=self.expansions_file) + changed_files = evergreen_git.get_changed_files( + expansions_file=self.expansions_file + ) self.assertEqual( changed_files, [changed_file_name, new_file_name], @@ -146,10 +162,15 @@ class TestChangedFiles(unittest.TestCase): self.repo.git.execute(["git", "add", "."]) self.repo.git.execute(["git", "commit", "-m", "Fake waterfall changes"]) new_files = evergreen_git.get_new_files(expansions_file=self.expansions_file) - self.assertEqual(new_files, [new_file_name], - msg="New file list did not contain the new file.") + self.assertEqual( + new_files, + [new_file_name], + msg="New file list did not contain the new file.", + ) - changed_files = evergreen_git.get_changed_files(expansions_file=self.expansions_file) + changed_files = evergreen_git.get_changed_files( + expansions_file=self.expansions_file + ) self.assertEqual( changed_files, [changed_file_name, new_file_name], diff --git a/buildscripts/bazel_rules_mongo/utils/evergreen_git.py b/buildscripts/bazel_rules_mongo/utils/evergreen_git.py index e386cf20141..d2efa3941c8 100644 --- a/buildscripts/bazel_rules_mongo/utils/evergreen_git.py +++ b/buildscripts/bazel_rules_mongo/utils/evergreen_git.py @@ -36,7 +36,14 @@ def get_mongodb_remote(repo: Repo) -> Remote: assert len(parts) >= 2, f"Unexpected git remote url: {url}" owner = parts[-2].split(":")[-1] - if owner in ("10gen", "mongodb", "evergreen-ci", "mongodb-ets", "realm", "mongodb-js"): + if owner in ( + "10gen", + "mongodb", + "evergreen-ci", + "mongodb-ets", + "realm", + "mongodb-js", + ): picked_remote = remote print(f"Selected remote: {remote.url}") break @@ -106,13 +113,15 @@ def get_diff_revision(expansions_file: str = None, branch: str = None) -> str: return diff_commit -def get_changed_files(expansions_file: str = None, branch: str = None, - diff_filter: str = "d") -> List[str]: +def get_changed_files( + expansions_file: str = None, branch: str = None, diff_filter: str = "d" +) -> List[str]: diff_commit = get_diff_revision(expansions_file, branch) repo = Repo() output = repo.git.execute( - ["git", "diff", "--name-only", f"--diff-filter={diff_filter}", diff_commit]) + ["git", "diff", "--name-only", f"--diff-filter={diff_filter}", diff_commit] + ) files = output.split("\n") return [file for file in files if file] @@ -135,7 +144,10 @@ def get_files_to_lint() -> List[str]: tracked_files = repo.git.execute(["git", "ls-files"]).split("\n") # all unstaged files from git tracked_files.extend( - repo.git.execute(["git", "ls-files", "--others", "--exclude-standard"]).split("\n")) + repo.git.execute(["git", "ls-files", "--others", "--exclude-standard"]).split( + "\n" + ) + ) # remove any empty entries tracked_files = list(filter(bool, tracked_files)) return tracked_files diff --git a/buildscripts/build_system_options.py b/buildscripts/build_system_options.py index 140d1e4dffd..33b04f7b914 100644 --- a/buildscripts/build_system_options.py +++ b/buildscripts/build_system_options.py @@ -1,4 +1,5 @@ """Options used for building process.""" + import re @@ -18,7 +19,8 @@ class PathOptions: if not self._compiled_shared_library_file_patterns: self._compiled_shared_library_file_patterns = tuple( - map(re.compile, self._shared_library_file_patterns)) + map(re.compile, self._shared_library_file_patterns) + ) return self._compiled_shared_library_file_patterns def is_shared_library_file(self, path: str) -> bool: diff --git a/buildscripts/buildifier.py b/buildscripts/buildifier.py index b2b69f9aafa..b25bc75667e 100644 --- a/buildscripts/buildifier.py +++ b/buildscripts/buildifier.py @@ -19,16 +19,25 @@ def find_all_failed(bin_path: str) -> list[str]: if contents: ignored_paths.append(contents) - process = subprocess.run([bin_path, "--format=json", "--mode=check", "-r", "./"], check=True, - capture_output=True) + process = subprocess.run( + [bin_path, "--format=json", "--mode=check", "-r", "./"], + check=True, + capture_output=True, + ) buildifier_results = json.loads(process.stdout) if buildifier_results["success"]: return [] return [ - result["filename"] for result in buildifier_results["files"] - if (not result["formatted"] and \ - not any(result["filename"].startswith(ignored_path) for ignored_path in ignored_paths)) + result["filename"] + for result in buildifier_results["files"] + if ( + not result["formatted"] + and not any( + result["filename"].startswith(ignored_path) + for ignored_path in ignored_paths + ) + ) ] @@ -44,14 +53,18 @@ def fix_all(bin_path: str): def lint(bin_path: str, files: list[str], generate_report: bool): for file in files: - process = subprocess.run([bin_path, "--format=json", "--mode=check", file], check=True, - capture_output=True) + process = subprocess.run( + [bin_path, "--format=json", "--mode=check", file], + check=True, + capture_output=True, + ) result = json.loads(process.stdout) if result["success"]: continue # This purposefully gives a exit code of 4 when there is a diff - process = subprocess.run([bin_path, "--mode=diff", file], capture_output=True, - encoding='utf-8') + process = subprocess.run( + [bin_path, "--mode=diff", file], capture_output=True, encoding="utf-8" + ) if process.returncode not in (0, 4): raise RuntimeError() diff = process.stdout @@ -62,7 +75,8 @@ def lint(bin_path: str, files: list[str], generate_report: bool): header = ( "There are linting errors in this file, fix them with one of the following commands:\n" "python3 buildscripts/buildifier.py fix-all\n" - f"python3 buildscripts/buildifier.py fix {file}\n\n") + f"python3 buildscripts/buildifier.py fix {file}\n\n" + ) report = make_report(f"{file} warnings", json.dumps(result, indent=2), 1) try_combine_reports(report) put_report(report) @@ -79,15 +93,20 @@ def fix(bin_path: str, files: list[str]): def main(): - parser = argparse.ArgumentParser(description='buildifier wrapper') + parser = argparse.ArgumentParser(description="buildifier wrapper") parser.add_argument( - "--binary-dir", "-b", type=str, + "--binary-dir", + "-b", + type=str, help="Path to the buildifier binary, defaults to looking in the current directory.", - default="") + default="", + ) parser.add_argument( - "--generate-report", action="store_true", + "--generate-report", + action="store_true", help="Whether or not a report of the lint errors should be generated for evergreen.", - default=False) + default=False, + ) parser.set_defaults(subcommand=None) sub = parser.add_subparsers(title="buildifier subcommands", help="sub-command help") @@ -108,7 +127,8 @@ def main(): args = parser.parse_args() assert os.path.abspath(os.curdir) == str( - mongo_dir.absolute()), "buildifier.py must be run from the root of the mongo repo" + mongo_dir.absolute() + ), "buildifier.py must be run from the root of the mongo repo" binary_name = "buildifier.exe" if platform.system() == "Windows" else "buildifier" if args.binary_dir: binary_path = os.path.join(args.binary_dir, binary_name) @@ -127,7 +147,9 @@ def main(): else: # we purposefully do not use sub.choices.keys() so it does not print as a dict_keys object choices = [key for key in sub.choices] - raise RuntimeError(f"One of the following subcommands must be specified: {choices}") + raise RuntimeError( + f"One of the following subcommands must be specified: {choices}" + ) if __name__ == "__main__": diff --git a/buildscripts/burn_in_tests.py b/buildscripts/burn_in_tests.py index b34383e8d12..8f0191d24ee 100755 --- a/buildscripts/burn_in_tests.py +++ b/buildscripts/burn_in_tests.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Command line utility for determining what jstests have been added or modified.""" + import collections import copy import json @@ -61,18 +62,27 @@ SUITE_FILES = ["with_server"] BURN_IN_TEST_MEMBERSHIP_FILE = "burn_in_test_membership_map_file_for_ci.json" -SUPPORTED_TEST_KINDS = ("fsm_workload_test", "js_test", "json_schema_test", - "multi_stmt_txn_passthrough", "parallel_fsm_workload_test", - "all_versions_js_test") +SUPPORTED_TEST_KINDS = ( + "fsm_workload_test", + "js_test", + "json_schema_test", + "multi_stmt_txn_passthrough", + "parallel_fsm_workload_test", + "all_versions_js_test", +) RUN_ALL_FEATURE_FLAG_TESTS = "--runAllFeatureFlagTests" class RepeatConfig(object): """Configuration for how tests should be repeated.""" - def __init__(self, repeat_tests_secs: Optional[int] = None, - repeat_tests_min: Optional[int] = None, repeat_tests_max: Optional[int] = None, - repeat_tests_num: Optional[int] = None): + def __init__( + self, + repeat_tests_secs: Optional[int] = None, + repeat_tests_min: Optional[int] = None, + repeat_tests_max: Optional[int] = None, + repeat_tests_num: Optional[int] = None, + ): """ Create a Repeat Config. @@ -97,10 +107,14 @@ class RepeatConfig(object): if self.repeat_tests_max: if not self.repeat_tests_secs: - raise ValueError("Must specify --repeat-tests-secs with --repeat-tests-max") + raise ValueError( + "Must specify --repeat-tests-secs with --repeat-tests-max" + ) if self.repeat_tests_min and self.repeat_tests_min > self.repeat_tests_max: - raise ValueError("--repeat-tests-secs-min is greater than --repeat-tests-max") + raise ValueError( + "--repeat-tests-secs-min is greater than --repeat-tests-max" + ) if self.repeat_tests_min and not self.repeat_tests_secs: raise ValueError("Must specify --repeat-tests-secs with --repeat-tests-min") @@ -120,15 +134,19 @@ class RepeatConfig(object): repeat_options += f" --repeatTestsMax={self.repeat_tests_max} " return repeat_options - repeat_suites = self.repeat_tests_num if self.repeat_tests_num else REPEAT_SUITES + repeat_suites = ( + self.repeat_tests_num if self.repeat_tests_num else REPEAT_SUITES + ) return f" --repeatSuites={repeat_suites} " def __repr__(self): """Build string representation of object for debugging.""" - return "".join([ - f"RepeatConfig[num={self.repeat_tests_num}, secs={self.repeat_tests_secs}, ", - f"min={self.repeat_tests_min}, max={self.repeat_tests_max}]", - ]) + return "".join( + [ + f"RepeatConfig[num={self.repeat_tests_num}, secs={self.repeat_tests_secs}, ", + f"min={self.repeat_tests_min}, max={self.repeat_tests_max}]", + ] + ) def is_file_a_test_file(file_path: str) -> bool: @@ -161,11 +179,15 @@ def find_excludes(selector_file: str) -> Tuple[List, List, List]: try: js_test = yml["selector"]["js_test"] except KeyError: - raise Exception(f"The selector file {selector_file} is missing the 'selector.js_test' key") + raise Exception( + f"The selector file {selector_file} is missing the 'selector.js_test' key" + ) - return (default_if_none(js_test.get("exclude_suites"), []), - default_if_none(js_test.get("exclude_tasks"), []), - default_if_none(js_test.get("exclude_tests"), [])) + return ( + default_if_none(js_test.get("exclude_suites"), []), + default_if_none(js_test.get("exclude_tasks"), []), + default_if_none(js_test.get("exclude_tests"), []), + ) def filter_tests(tests: Set[str], exclude_tests: List[str]) -> Set[str]: @@ -199,7 +221,9 @@ def create_executor_list(suites, exclude_suites): try: with open(BURN_IN_TEST_MEMBERSHIP_FILE) as file: test_membership = collections.defaultdict(list, json.load(file)) - LOGGER.info(f"Using cached test membership file {BURN_IN_TEST_MEMBERSHIP_FILE}.") + LOGGER.info( + f"Using cached test membership file {BURN_IN_TEST_MEMBERSHIP_FILE}." + ) except FileNotFoundError: LOGGER.info("Getting test membership data.") test_membership = create_test_membership_map(test_kind=SUPPORTED_TEST_KINDS) @@ -208,7 +232,9 @@ def create_executor_list(suites, exclude_suites): for suite in suites: LOGGER.debug("Adding tests for suite", suite=suite, tests=suite.tests) for test in suite.tests: - LOGGER.debug("membership for test", test=test, membership=test_membership[test]) + LOGGER.debug( + "membership for test", test=test, membership=test_membership[test] + ) for executor in set(test_membership[test]) - set(exclude_suites): if test not in memberships[executor]: memberships[executor].append(test) @@ -255,9 +281,9 @@ class TaskToBurnInInfo(NamedTuple): @classmethod def from_task( - cls, - task: VariantTask, - tests_by_suite: Dict[str, List[str]], + cls, + task: VariantTask, + tests_by_suite: Dict[str, List[str]], ) -> "TaskToBurnInInfo": """ Gather the information needed to run the given task. @@ -271,7 +297,8 @@ class TaskToBurnInInfo(NamedTuple): name=suite_name, resmoke_args=resmoke_args, tests=tests_by_suite[suite_name], - ) for suite_name, resmoke_args in task.combined_suite_to_resmoke_args_map.items() + ) + for suite_name, resmoke_args in task.combined_suite_to_resmoke_args_map.items() if len(tests_by_suite[suite_name]) > 0 ] return cls( @@ -280,9 +307,12 @@ class TaskToBurnInInfo(NamedTuple): ) -def create_task_list(evergreen_conf: EvergreenProjectConfig, build_variant: str, - tests_by_suite: Dict[str, List[str]], - exclude_tasks: [str]) -> Dict[str, TaskToBurnInInfo]: +def create_task_list( + evergreen_conf: EvergreenProjectConfig, + build_variant: str, + tests_by_suite: Dict[str, List[str]], + exclude_tasks: [str], +) -> Dict[str, TaskToBurnInInfo]: """ Find associated tasks for the specified build_variant and suites. @@ -294,15 +324,20 @@ def create_task_list(evergreen_conf: EvergreenProjectConfig, build_variant: str, """ log = LOGGER.bind(build_variant=build_variant) - log.debug("creating task list for suites", suites=tests_by_suite, exclude_tasks=exclude_tasks) + log.debug( + "creating task list for suites", + suites=tests_by_suite, + exclude_tasks=exclude_tasks, + ) evg_build_variant = _get_evg_build_variant_by_name(evergreen_conf, build_variant) # Find all the build variant tasks. exclude_tasks_set = set(exclude_tasks) all_variant_tasks = { task.name: task - for task in evg_build_variant.tasks if task.name not in exclude_tasks_set and ( - task.is_run_tests_task or task.is_generate_resmoke_task) + for task in evg_build_variant.tasks + if task.name not in exclude_tasks_set + and (task.is_run_tests_task or task.is_generate_resmoke_task) } # Return the list of tasks to run for the specified suite. @@ -327,10 +362,13 @@ def _set_resmoke_cmd(repeat_config: RepeatConfig, resmoke_args: [str]) -> [str]: return new_args -def create_task_list_for_tests(changed_tests: Set[str], build_variant: str, - evg_conf: EvergreenProjectConfig, - exclude_suites: Optional[List] = None, - exclude_tasks: Optional[List] = None) -> Dict[str, TaskToBurnInInfo]: +def create_task_list_for_tests( + changed_tests: Set[str], + build_variant: str, + evg_conf: EvergreenProjectConfig, + exclude_suites: Optional[List] = None, + exclude_tasks: Optional[List] = None, +) -> Dict[str, TaskToBurnInInfo]: """ Create a list of tests by task for the given tests. @@ -355,9 +393,12 @@ def create_task_list_for_tests(changed_tests: Set[str], build_variant: str, return create_task_list(evg_conf, build_variant, tests_by_executor, exclude_tasks) -def create_tests_by_task(build_variant: str, evg_conf: EvergreenProjectConfig, - changed_tests: Set[str], - install_dir: Optional[str]) -> Dict[str, TaskToBurnInInfo]: +def create_tests_by_task( + build_variant: str, + evg_conf: EvergreenProjectConfig, + changed_tests: Set[str], + install_dir: Optional[str], +) -> Dict[str, TaskToBurnInInfo]: """ Create a list of tests by task. @@ -378,8 +419,9 @@ def create_tests_by_task(build_variant: str, evg_conf: EvergreenProjectConfig, buildscripts.resmokelib.parser.set_run_options(run_options) if changed_tests: - return create_task_list_for_tests(changed_tests, build_variant, evg_conf, exclude_suites, - exclude_tasks) + return create_task_list_for_tests( + changed_tests, build_variant, evg_conf, exclude_suites, exclude_tasks + ) LOGGER.info("No new or modified tests found.") return {} @@ -404,7 +446,9 @@ def run_tests(tests_by_task: Dict[str, TaskToBurnInInfo], resmoke_cmd: [str]) -> try: subprocess.check_call(new_resmoke_cmd, shell=False) except subprocess.CalledProcessError as err: - log.warning("Resmoke returned an error with suite", error=err.returncode) + log.warning( + "Resmoke returned an error with suite", error=err.returncode + ) sys.exit(err.returncode) @@ -424,7 +468,9 @@ def _configure_logging(verbose: bool): logging.getLogger(log_name).setLevel(logging.WARNING) -def _get_evg_build_variant_by_name(evergreen_conf: EvergreenProjectConfig, name: str) -> Variant: +def _get_evg_build_variant_by_name( + evergreen_conf: EvergreenProjectConfig, name: str +) -> Variant: """ Get the evergreen build variant by name from the evergreen config file. @@ -467,7 +513,11 @@ class FileChangeDetector(ABC): LOGGER.info("Calculated revision map", revision_map=revision_map) changed_files = find_changed_files_in_repos(repos, revision_map) - return {os.path.normpath(path) for path in changed_files if is_file_a_test_file(path)} + return { + os.path.normpath(path) + for path in changed_files + if is_file_a_test_file(path) + } class LocalFileChangeDetector(FileChangeDetector): @@ -588,23 +638,31 @@ class YamlBurnInExecutor(BurnInExecutor): :param tests_by_task: Dictionary of tasks to run with tests to run in each. """ - discovered_tasks = DiscoveredTaskList(discovered_tasks=[ - DiscoveredTask( - task_name=task_name, - suites=[ - DiscoveredSuite(suite_name=suite.name, test_list=suite.tests) - for suite in task_info.suites - ], - ) for task_name, task_info in tests_by_task.items() - ]) + discovered_tasks = DiscoveredTaskList( + discovered_tasks=[ + DiscoveredTask( + task_name=task_name, + suites=[ + DiscoveredSuite(suite_name=suite.name, test_list=suite.tests) + for suite in task_info.suites + ], + ) + for task_name, task_info in tests_by_task.items() + ] + ) print(yaml.safe_dump(discovered_tasks.dict())) class BurnInOrchestrator: """Orchestrate the execution of burn_in_tests.""" - def __init__(self, change_detector: FileChangeDetector, burn_in_executor: BurnInExecutor, - evg_conf: EvergreenProjectConfig, install_dir: Optional[str]) -> None: + def __init__( + self, + change_detector: FileChangeDetector, + burn_in_executor: BurnInExecutor, + evg_conf: EvergreenProjectConfig, + install_dir: Optional[str], + ) -> None: """ Create a new orchestrator. @@ -627,8 +685,9 @@ class BurnInOrchestrator: changed_tests = self.change_detector.find_changed_tests(repos) LOGGER.info("Found changed tests", files=changed_tests) - tests_by_task = create_tests_by_task(build_variant, self.evg_conf, changed_tests, - self.install_dir) + tests_by_task = create_tests_by_task( + build_variant, self.evg_conf, changed_tests, self.install_dir + ) LOGGER.debug("tests and tasks found", tests_by_task=tests_by_task) self.burn_in_executor.execute(tests_by_task) @@ -640,34 +699,92 @@ def cli(): @cli.command(context_settings=dict(ignore_unknown_options=True)) -@click.option("--no-exec", "no_exec", default=False, is_flag=True, - help="Do not execute the found tests.") -@click.option("--build-variant", "build_variant", default=DEFAULT_VARIANT, metavar='BUILD_VARIANT', - help="Tasks to run will be selected from this build variant.") -@click.option("--repeat-tests", "repeat_tests_num", default=None, type=int, - help="Number of times to repeat tests.") -@click.option("--repeat-tests-min", "repeat_tests_min", default=None, type=int, - help="The minimum number of times to repeat tests if time option is specified.") -@click.option("--repeat-tests-max", "repeat_tests_max", default=None, type=int, - help="The maximum number of times to repeat tests if time option is specified.") -@click.option("--repeat-tests-secs", "repeat_tests_secs", default=None, type=int, metavar="SECONDS", - help="Repeat tests for the given time (in secs).") -@click.option("--yaml", "use_yaml", is_flag=True, default=False, - help="Output discovered tasks in YAML. Tests will not be run.") -@click.option("--verbose", "verbose", default=False, is_flag=True, help="Enable extra logging.") @click.option( - "--origin-rev", "origin_rev", default=None, - help="The revision in the mongo repo that changes will be compared against if specified.") -@click.option("--install-dir", "install_dir", type=str, - help="Path to bin directory of a testable installation") -@click.option("--evg-project-file", "evg_project_file", default=DEFAULT_EVG_PROJECT_FILE, - help="Evergreen project config file") + "--no-exec", + "no_exec", + default=False, + is_flag=True, + help="Do not execute the found tests.", +) +@click.option( + "--build-variant", + "build_variant", + default=DEFAULT_VARIANT, + metavar="BUILD_VARIANT", + help="Tasks to run will be selected from this build variant.", +) +@click.option( + "--repeat-tests", + "repeat_tests_num", + default=None, + type=int, + help="Number of times to repeat tests.", +) +@click.option( + "--repeat-tests-min", + "repeat_tests_min", + default=None, + type=int, + help="The minimum number of times to repeat tests if time option is specified.", +) +@click.option( + "--repeat-tests-max", + "repeat_tests_max", + default=None, + type=int, + help="The maximum number of times to repeat tests if time option is specified.", +) +@click.option( + "--repeat-tests-secs", + "repeat_tests_secs", + default=None, + type=int, + metavar="SECONDS", + help="Repeat tests for the given time (in secs).", +) +@click.option( + "--yaml", + "use_yaml", + is_flag=True, + default=False, + help="Output discovered tasks in YAML. Tests will not be run.", +) +@click.option( + "--verbose", "verbose", default=False, is_flag=True, help="Enable extra logging." +) +@click.option( + "--origin-rev", + "origin_rev", + default=None, + help="The revision in the mongo repo that changes will be compared against if specified.", +) +@click.option( + "--install-dir", + "install_dir", + type=str, + help="Path to bin directory of a testable installation", +) +@click.option( + "--evg-project-file", + "evg_project_file", + default=DEFAULT_EVG_PROJECT_FILE, + help="Evergreen project config file", +) @click.argument("resmoke_args", nargs=-1, type=click.UNPROCESSED) -def run(build_variant: str, no_exec: bool, repeat_tests_num: Optional[int], - repeat_tests_min: Optional[int], repeat_tests_max: Optional[int], - repeat_tests_secs: Optional[int], resmoke_args: str, verbose: bool, - origin_rev: Optional[str], install_dir: Optional[str], use_yaml: bool, - evg_project_file: Optional[str]) -> None: +def run( + build_variant: str, + no_exec: bool, + repeat_tests_num: Optional[int], + repeat_tests_min: Optional[int], + repeat_tests_max: Optional[int], + repeat_tests_secs: Optional[int], + resmoke_args: str, + verbose: bool, + origin_rev: Optional[str], + install_dir: Optional[str], + use_yaml: bool, + evg_project_file: Optional[str], +) -> None: """ Run new or changed tests in repeated mode to validate their stability. @@ -719,7 +836,9 @@ def run(build_variant: str, no_exec: bool, repeat_tests_num: Optional[int], elif no_exec: executor = NopBurnInExecutor() - burn_in_orchestrator = BurnInOrchestrator(change_detector, executor, evg_conf, install_dir) + burn_in_orchestrator = BurnInOrchestrator( + change_detector, executor, evg_conf, install_dir + ) burn_in_orchestrator.burn_in(repos, build_variant) @@ -742,7 +861,8 @@ def generate_test_membership_map_file_for_ci(): with open(BURN_IN_TEST_MEMBERSHIP_FILE, "w") as file: json.dump(test_membership, file) LOGGER.info( - f"Finished writing burn_in test membership mapping to {BURN_IN_TEST_MEMBERSHIP_FILE}") + f"Finished writing burn_in test membership mapping to {BURN_IN_TEST_MEMBERSHIP_FILE}" + ) if __name__ == "__main__": diff --git a/buildscripts/cheetah_source_generator.py b/buildscripts/cheetah_source_generator.py index e15bc57ae4a..e7d6acfe078 100755 --- a/buildscripts/cheetah_source_generator.py +++ b/buildscripts/cheetah_source_generator.py @@ -42,14 +42,23 @@ def main(): """ parser = argparse.ArgumentParser(description=__doc__) - parser.add_argument('-o', nargs='?', type=argparse.FileType('w'), default=sys.stdout, - help='output file (default sys.stdout)') - parser.add_argument('template_file', help='Cheetah template file') - parser.add_argument('template_arg', nargs='*', default=[], help='Cheetah template args') + parser.add_argument( + "-o", + nargs="?", + type=argparse.FileType("w"), + default=sys.stdout, + help="output file (default sys.stdout)", + ) + parser.add_argument("template_file", help="Cheetah template file") + parser.add_argument( + "template_arg", nargs="*", default=[], help="Cheetah template args" + ) opts = parser.parse_args() - opts.o.write(str(Template(file=opts.template_file, namespaces=[{'args': opts.template_arg}]))) + opts.o.write( + str(Template(file=opts.template_file, namespaces=[{"args": opts.template_arg}])) + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/ciconfig/evergreen.py b/buildscripts/ciconfig/evergreen.py index 2e13fd50539..762c253e7ea 100644 --- a/buildscripts/ciconfig/evergreen.py +++ b/buildscripts/ciconfig/evergreen.py @@ -3,6 +3,7 @@ The API also provides methods to access specific fields present in the mongodb/mongo configuration file. """ + from __future__ import annotations import datetime @@ -33,15 +34,19 @@ def parse_evergreen_file(path, evergreen_binary="evergreen"): prev_environ = os.environ.copy() if sys.platform in ("win32", "cygwin"): - LOGGER.info(f"Previous os.environ={os.environ} before updating 'USERPROFILE'") - if 'HOME' in os.environ: - os.environ['USERPROFILE'] = os.environ['HOME'] + LOGGER.info( + f"Previous os.environ={os.environ} before updating 'USERPROFILE'" + ) + if "HOME" in os.environ: + os.environ["USERPROFILE"] = os.environ["HOME"] else: LOGGER.warn( "'HOME' enviorment variable unset. This will likely cause us to be unable to find evergreen binary." ) - default_evergreen_location = os.path.expanduser(os.path.join("~", "evergreen")) + default_evergreen_location = os.path.expanduser( + os.path.join("~", "evergreen") + ) # Restore enviorment if it was modified above on windows os.environ.clear() @@ -80,9 +85,12 @@ class EvergreenProjectConfig(object): self.tasks = [Task(task_dict) for task_dict in self._conf["tasks"]] self._tasks_by_name = {task.name: task for task in self.tasks} self.task_groups = [ - TaskGroup(task_group_dict) for task_group_dict in self._conf.get("task_groups", []) + TaskGroup(task_group_dict) + for task_group_dict in self._conf.get("task_groups", []) ] - self._task_groups_by_name = {task_group.name: task_group for task_group in self.task_groups} + self._task_groups_by_name = { + task_group.name: task_group for task_group in self.task_groups + } self.variants = [ Variant(variant_dict, self._tasks_by_name, self._task_groups_by_name) for variant_dict in self._conf["buildvariants"] @@ -214,12 +222,17 @@ class Task(object): if self.is_run_tests_task: return [command_vars.get("suite", self.name)] - if self.is_generate_resmoke_task and not self.is_initialize_multiversion_tasks_task: + if ( + self.is_generate_resmoke_task + and not self.is_initialize_multiversion_tasks_task + ): return [command_vars.get("suite", self.generated_task_name)] if self.is_initialize_multiversion_tasks_task: return [ suite - for suite in self.initialize_multiversion_tasks_command.get("vars", {}).keys() + for suite in self.initialize_multiversion_tasks_command.get( + "vars", {} + ).keys() ] raise ValueError(f"{self.name} task does not run a resmoke.py test suite") @@ -286,10 +299,18 @@ class Variant(object): # A task in conf_dict may be a task_group, containing a list of tasks. for task_in_group in task_group_map.get(task_name).tasks: self.tasks.append( - VariantTask(task_map.get(task_in_group), task.get("distros", run_on), self)) + VariantTask( + task_map.get(task_in_group), + task.get("distros", run_on), + self, + ) + ) else: self.tasks.append( - VariantTask(task_map.get(task["name"]), task.get("distros", run_on), self)) + VariantTask( + task_map.get(task["name"]), task.get("distros", run_on), self + ) + ) self.distro_names = set(run_on) for task in self.tasks: self.distro_names.update(task.run_on) diff --git a/buildscripts/clang_tidy_vscode.py b/buildscripts/clang_tidy_vscode.py index 20a1190df87..dc3b09f876d 100755 --- a/buildscripts/clang_tidy_vscode.py +++ b/buildscripts/clang_tidy_vscode.py @@ -67,7 +67,9 @@ def get_half_cpu_mask(): def count_running_clang_tidy(cmd_path): try: - output = subprocess.check_output(["ps", "-axww", "-o", "pid=,command="], text=True) + output = subprocess.check_output( + ["ps", "-axww", "-o", "pid=,command="], text=True + ) return sum(1 for line in output.splitlines() if cmd_path in line) except Exception as e: print(f"WARNING: failed to check running clang-tidy processes: {e}") @@ -110,7 +112,9 @@ def main(): if ( (arg.endswith(".cpp") or arg.endswith(".h")) and rel.startswith("src/mongo") - and not rel.startswith("src/mongo/db/modules/enterprise/src/streams/third_party") + and not rel.startswith( + "src/mongo/db/modules/enterprise/src/streams/third_party" + ) ): files_to_check.append(rel) else: @@ -125,7 +129,9 @@ def main(): ) if not os.path.exists("compile_commands.json"): - print("ERROR: failed to find compile_commands.json, run 'bazel build compiledb'") + print( + "ERROR: failed to find compile_commands.json, run 'bazel build compiledb'" + ) sys.exit(1) with open("compile_commands.json") as f: @@ -168,7 +174,9 @@ def main(): # probably a header, skip caching and let clang-tidy do its thing: else: - proc = subprocess.run(clang_tidy_cmd + files_to_check + other_args, capture_output=True) + proc = subprocess.run( + clang_tidy_cmd + files_to_check + other_args, capture_output=True + ) sys.stdout.buffer.write(proc.stdout) sys.stderr.buffer.write(proc.stderr) diff --git a/buildscripts/client/github.py b/buildscripts/client/github.py index 8b602de3fe3..1f4eca9d822 100644 --- a/buildscripts/client/github.py +++ b/buildscripts/client/github.py @@ -1,4 +1,5 @@ """Functions for working with github.""" + import logging import time @@ -38,7 +39,7 @@ class GithubApi(object): links = response.headers["Link"].split(",") for link in links: link_parts = link.split(";") - link_type = link_parts[1].replace("rel=", "").strip(" \"") + link_type = link_parts[1].replace("rel=", "").strip(' "') link_address = link_parts[0].strip("<> ") link_object[link_type] = link_address @@ -46,8 +47,9 @@ class GithubApi(object): def get_commits(self, owner, project, params): """Get the list of commits from a specified repository from github.""" - url = "{api_server}/repos/{owner}/{project}/commits".format(api_server=self.api_server, - owner=owner, project=project) + url = "{api_server}/repos/{owner}/{project}/commits".format( + api_server=self.api_server, owner=owner, project=project + ) LOGGER.debug("get_commits project=%s/%s, params: %s", owner, project, params) response = self._make_request(url, params) @@ -61,7 +63,11 @@ class GithubApi(object): links = self._parse_link(response) - LOGGER.debug("Commits from github (count=%d): [%s - %s]", len(commits), commits[-1]["sha"], - commits[0]["sha"]) + LOGGER.debug( + "Commits from github (count=%d): [%s - %s]", + len(commits), + commits[-1]["sha"], + commits[0]["sha"], + ) return commits diff --git a/buildscripts/client/jiraclient.py b/buildscripts/client/jiraclient.py index 6251d043e25..5b18078d7ee 100644 --- a/buildscripts/client/jiraclient.py +++ b/buildscripts/client/jiraclient.py @@ -1,4 +1,5 @@ """Module to access a JIRA server.""" + from enum import Enum import jira diff --git a/buildscripts/collect_resource_info.py b/buildscripts/collect_resource_info.py index 40e0e038059..6ae14610354 100755 --- a/buildscripts/collect_resource_info.py +++ b/buildscripts/collect_resource_info.py @@ -22,14 +22,23 @@ def main(): usage = "usage: %prog [options]" parser = optparse.OptionParser(description=__doc__, usage=usage) parser.add_option( - "-i", "--interval", dest="interval", default=5, type="int", + "-i", + "--interval", + dest="interval", + default=5, + type="int", help="Collect system resource information every seconds. " - "Default is every 5 seconds.") + "Default is every 5 seconds.", + ) parser.add_option( - "-o", "--output-file", dest="outfile", default="-", + "-o", + "--output-file", + dest="outfile", + default="-", help="If '-', then the file is written to stdout." " Any other value is treated as the output file name. By default," - " output is written to stdout.") + " output is written to stdout.", + ) (options, _) = parser.parse_args() @@ -40,8 +49,11 @@ def main(): response = requests.get("http://localhost:2285/status") if response.status_code != requests.codes.ok: print( - "Received a {} HTTP response: {}".format(response.status_code, - response.text), file=sys.stderr) + "Received a {} HTTP response: {}".format( + response.status_code, response.text + ), + file=sys.stderr, + ) time.sleep(options.interval) continue @@ -49,8 +61,12 @@ def main(): try: res_json = response.json() except ValueError: - print("Invalid JSON object returned with response: {}".format(response.text), - file=sys.stderr) + print( + "Invalid JSON object returned with response: {}".format( + response.text + ), + file=sys.stderr, + ) time.sleep(options.interval) continue diff --git a/buildscripts/combine_reports.py b/buildscripts/combine_reports.py index b8104959c74..96a7fcf1301 100755 --- a/buildscripts/combine_reports.py +++ b/buildscripts/combine_reports.py @@ -53,13 +53,24 @@ def main(): usage = "usage: %prog [options] report1.json report2.json ..." parser = OptionParser(description=__doc__, usage=usage) parser.add_option( - "-o", "--output-file", dest="outfile", default="-", - help=("If '-', then the combined report file is written to stdout." - " Any other value is treated as the output file name. By default," - " output is written to stdout.")) - parser.add_option("-x", "--no-report-exit", dest="report_exit", default=True, - action="store_false", - help="Do not exit with a non-zero code if any test in the report fails.") + "-o", + "--output-file", + dest="outfile", + default="-", + help=( + "If '-', then the combined report file is written to stdout." + " Any other value is treated as the output file name. By default," + " output is written to stdout." + ), + ) + parser.add_option( + "-x", + "--no-report-exit", + dest="report_exit", + default=True, + action="store_false", + help="Do not exit with a non-zero code if any test in the report fails.", + ) (options, args) = parser.parse_args() diff --git a/buildscripts/config_diff.py b/buildscripts/config_diff.py index a476e0bd0c0..cbfd11999f3 100755 --- a/buildscripts/config_diff.py +++ b/buildscripts/config_diff.py @@ -16,13 +16,13 @@ from enum import Enum import yaml -_COMPARE_FIELDS_SERVER_PARAMETERS = ['default', 'set_at', 'validator', 'test_only'] -_COMPARE_FIELDS_CONFIGS = ['arg_vartype', 'requires', 'hidden', 'redact'] +_COMPARE_FIELDS_SERVER_PARAMETERS = ["default", "set_at", "validator", "test_only"] +_COMPARE_FIELDS_CONFIGS = ["arg_vartype", "requires", "hidden", "redact"] class ComparisonType(str, Enum): - CONFIGS = 'configs' - SERVER_PARAMETERS = 'server_parameters' + CONFIGS = "configs" + SERVER_PARAMETERS = "server_parameters" class PropertyDiff: @@ -44,7 +44,8 @@ def build_diff_fn(compare_fields: list) -> callable: for field in compare_fields: if prop_base.get(field) != prop_inc.get(field): change_diffs[field] = PropertyDiff( - str(prop_base.get(field, "")), str(prop_inc.get(field, ""))) + str(prop_base.get(field, "")), str(prop_inc.get(field, "")) + ) return change_diffs return diff_fn @@ -76,13 +77,17 @@ class ComputeDiffsFromIncrementedVersionHandler: requires knowledge of a base dictionary of properties (base_properties) to execute. """ - def __init__(self, handler_type: ComparisonType, base_properties: dict, calc_diff_fn: callable): + def __init__( + self, + handler_type: ComparisonType, + base_properties: dict, + calc_diff_fn: callable, + ): self.calc_diff = calc_diff_fn self.handler_type = handler_type self.properties_diff = PropertiesDiffs(base_properties, {}, {}) def _compare_and_partition(self, yaml_props: dict, yaml_file_name: str) -> None: - for yaml_key, yaml_val in yaml_props.items(): compare_key = (yaml_key, yaml_file_name) @@ -124,17 +129,20 @@ def load_yaml(dirs: list, exclusions: list, idl_yaml_handlers: list) -> None: break for name in filenames: - if not name.endswith('.idl'): + if not name.endswith(".idl"): continue - with io.open(os.path.join(dirpath, name), 'r', encoding='utf-8') as idl_yaml_stream: + with io.open( + os.path.join(dirpath, name), "r", encoding="utf-8" + ) as idl_yaml_stream: idl_yaml = yaml.safe_load(idl_yaml_stream) for handler in idl_yaml_handlers: handler.handle(idl_yaml, name) -def get_properties_diffs(mode: ComparisonType, base_version_dirs: list, inc_version_dirs: list, - exclude: list) -> PropertiesDiffs: +def get_properties_diffs( + mode: ComparisonType, base_version_dirs: list, inc_version_dirs: list, exclude: list +) -> PropertiesDiffs: """Returns a PropertiesDiffs object containing the changes between properties in base_version_dirs and inc_version_dirs.""" compare_fields = [] @@ -143,22 +151,22 @@ def get_properties_diffs(mode: ComparisonType, base_version_dirs: list, inc_vers elif mode == ComparisonType.CONFIGS: compare_fields = _COMPARE_FIELDS_CONFIGS else: - raise Exception(f'Unknown option {mode}') + raise Exception(f"Unknown option {mode}") diff_fn = build_diff_fn(compare_fields) base_handler = BuildBasePropertiesForComparisonHandler(mode) load_yaml(base_version_dirs, exclude, [base_handler]) - increment_handler = ComputeDiffsFromIncrementedVersionHandler(mode, base_handler.properties, - diff_fn) + increment_handler = ComputeDiffsFromIncrementedVersionHandler( + mode, base_handler.properties, diff_fn + ) load_yaml(inc_version_dirs, exclude, [increment_handler]) return increment_handler.properties_diff def output_diffs(mode: ComparisonType, diff: PropertiesDiffs) -> None: - pp = pprint.PrettyPrinter() mode_format = "" @@ -167,56 +175,67 @@ def output_diffs(mode: ComparisonType, diff: PropertiesDiffs) -> None: elif mode == ComparisonType.SERVER_PARAMETERS: mode_format = "server parameter" else: - raise Exception(f'Unknown option {mode}') + raise Exception(f"Unknown option {mode}") for sp, val in diff.added.items(): - if not val.get('test_only'): - print(f'Added {mode_format} {str(sp)}') + if not val.get("test_only"): + print(f"Added {mode_format} {str(sp)}") pp.pprint(val) print() for sp, val in diff.removed.items(): - if not val.get('test_only'): - print(f'Removed {mode_format} {str(sp)}') + if not val.get("test_only"): + print(f"Removed {mode_format} {str(sp)}") pp.pprint(val) print() for sp, val in diff.modified.items(): - if not val.get('test_only'): - print(f'Modified {mode_format} {str(sp)}') + if not val.get("test_only"): + print(f"Modified {mode_format} {str(sp)}") for property_name, delta in val.items(): - print(f'<{property_name}> changed from [{delta.base}] to [{delta.inc}]') + print(f"<{property_name}> changed from [{delta.base}] to [{delta.inc}]") print() def main(): - arg_parser = argparse.ArgumentParser(prog="Core Server IDL Parameter/Config Diff") arg_parser.add_argument( - 'mode', choices=[ComparisonType.SERVER_PARAMETERS.value, ComparisonType.CONFIGS.value]) + "mode", + choices=[ComparisonType.SERVER_PARAMETERS.value, ComparisonType.CONFIGS.value], + ) arg_parser.add_argument( - '-b', '--base_version_dirs', - help='A colon-separated list of paths to the base version for comparison', required=True) + "-b", + "--base_version_dirs", + help="A colon-separated list of paths to the base version for comparison", + required=True, + ) arg_parser.add_argument( - '-i', '--incremented_version_dirs', - help='A colon-separated list of paths to the incremented version for comparison', - required=True) + "-i", + "--incremented_version_dirs", + help="A colon-separated list of paths to the incremented version for comparison", + required=True, + ) arg_parser.add_argument( - '-e', '--exclude_dirs', - help='A colon-separated list of directory path strings to exclude from comparison, ' + - 'e.g. a path /foo/bar/dir will be excluded by an argument of any of foo/bar/dir, bar/dir,' + - 'foo, or bar, or dir ', required=False) + "-e", + "--exclude_dirs", + help="A colon-separated list of directory path strings to exclude from comparison, " + + "e.g. a path /foo/bar/dir will be excluded by an argument of any of foo/bar/dir, bar/dir," + + "foo, or bar, or dir ", + required=False, + ) args = arg_parser.parse_args() - incremented_version_dirs = str.split(args.incremented_version_dirs, ':') - base_version_dirs = str.split(args.base_version_dirs, ':') - exclude = set(args.exclude_dirs.split(':')) if args.exclude_dirs else set() + incremented_version_dirs = str.split(args.incremented_version_dirs, ":") + base_version_dirs = str.split(args.base_version_dirs, ":") + exclude = set(args.exclude_dirs.split(":")) if args.exclude_dirs else set() mode = ComparisonType(args.mode) - diffs = get_properties_diffs(mode, base_version_dirs, incremented_version_dirs, exclude) + diffs = get_properties_diffs( + mode, base_version_dirs, incremented_version_dirs, exclude + ) output_diffs(mode, diffs) @@ -254,19 +273,25 @@ class TestBuildBasePropertiesForComparisonHandler(unittest.TestCase): """ yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.SERVER_PARAMETERS) + fixture = BuildBasePropertiesForComparisonHandler( + ComparisonType.SERVER_PARAMETERS + ) fixture.handle(yaml_obj, filename) - #should filter out configs, but parse server parameters - self.assertIsNone(fixture.properties.get(("net.compression.compressors", filename))) + # should filter out configs, but parse server parameters + self.assertIsNone( + fixture.properties.get(("net.compression.compressors", filename)) + ) self.assertIsNotNone(fixture.properties[("changeStreamOptions", filename)]) fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.CONFIGS) fixture.handle(yaml_obj, filename) - #should filter out server parameters, but parse configs + # should filter out server parameters, but parse configs self.assertIsNone(fixture.properties.get(("changeStreamOptions", filename))) - self.assertIsNotNone(fixture.properties.get(("net.compression.compressors", filename))) + self.assertIsNotNone( + fixture.properties.get(("net.compression.compressors", filename)) + ) def test_empty_yaml_obj_does_nothing(self): filename = "test.yml" @@ -277,7 +302,9 @@ class TestBuildBasePropertiesForComparisonHandler(unittest.TestCase): yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - fixture = BuildBasePropertiesForComparisonHandler(ComparisonType.SERVER_PARAMETERS) + fixture = BuildBasePropertiesForComparisonHandler( + ComparisonType.SERVER_PARAMETERS + ) fixture.handle(yaml_obj, filename) self.assertTrue(len(fixture.properties) == 0) @@ -328,8 +355,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {}, - self.config_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.CONFIGS, {}, self.config_diff_function + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -339,8 +367,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): self.assertIsNone(properties_diffs.added.get(("testOptions", filename))) self.assertIsNone(properties_diffs.added.get(("helloMorld", filename))) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS, - {}, self.parameter_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.SERVER_PARAMETERS, {}, self.parameter_diff_function + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -375,8 +404,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {}, - self.config_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.CONFIGS, {}, self.config_diff_function + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -386,8 +416,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): self.assertTrue(len(properties_diffs.removed) == 0) self.assertTrue(len(properties_diffs.modified) == 0) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS, - {}, self.parameter_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.SERVER_PARAMETERS, {}, self.parameter_diff_function + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -404,13 +435,16 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): """ def get_base_data(): - return {("ok", "test.yaml"): {"yes": "no"}, ("also_ok", "blah.yaml"): {"no": "yes"}} + return { + ("ok", "test.yaml"): {"yes": "no"}, + ("also_ok", "blah.yaml"): {"no": "yes"}, + } inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, - get_base_data(), - self.config_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.CONFIGS, get_base_data(), self.config_diff_function + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -422,9 +456,11 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): self.assertTrue(len(properties_diffs.added) == 0) self.assertTrue(len(properties_diffs.modified) == 0) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS, - get_base_data(), - self.parameter_diff_function) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.SERVER_PARAMETERS, + get_base_data(), + self.parameter_diff_function, + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -475,8 +511,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): """ inc_yaml_obj = yaml.load(document, Loader=yaml.FullLoader) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.CONFIGS, {}, - build_diff_fn(['default'])) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.CONFIGS, {}, build_diff_fn(["default"]) + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -485,8 +522,9 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): self.assertTrue(len(properties_diffs.removed) == 0) self.assertTrue(len(properties_diffs.modified) == 0) - inc_fixture = ComputeDiffsFromIncrementedVersionHandler(ComparisonType.SERVER_PARAMETERS, - {}, build_diff_fn(['set_at'])) + inc_fixture = ComputeDiffsFromIncrementedVersionHandler( + ComparisonType.SERVER_PARAMETERS, {}, build_diff_fn(["set_at"]) + ) inc_fixture.handle(inc_yaml_obj, filename) properties_diffs = inc_fixture.properties_diff @@ -541,11 +579,13 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): diff_fn = build_diff_fn(_COMPARE_FIELDS_CONFIGS) config_base_properties_handler = BuildBasePropertiesForComparisonHandler( - ComparisonType.CONFIGS) + ComparisonType.CONFIGS + ) config_base_properties_handler.handle(document_yaml, filename) config_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler( - ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn) + ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn + ) config_inc_properties_handler.handle(document_inc_yaml, filename) property_diff = config_inc_properties_handler.properties_diff @@ -555,11 +595,15 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): diff_fn = build_diff_fn(_COMPARE_FIELDS_SERVER_PARAMETERS) sp_base_properties_handler = BuildBasePropertiesForComparisonHandler( - ComparisonType.SERVER_PARAMETERS) + ComparisonType.SERVER_PARAMETERS + ) sp_base_properties_handler.handle(document_yaml, filename) sp_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler( - ComparisonType.SERVER_PARAMETERS, sp_base_properties_handler.properties, diff_fn) + ComparisonType.SERVER_PARAMETERS, + sp_base_properties_handler.properties, + diff_fn, + ) sp_inc_properties_handler.handle(document_inc_yaml, filename) property_diff = sp_inc_properties_handler.properties_diff @@ -645,37 +689,50 @@ class TestComputeDiffsFromIncrementedVersionHandler(unittest.TestCase): diff_fn = build_diff_fn(_COMPARE_FIELDS_CONFIGS) config_base_properties_handler = BuildBasePropertiesForComparisonHandler( - ComparisonType.CONFIGS) + ComparisonType.CONFIGS + ) config_base_properties_handler.handle(document_yaml, filename) config_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler( - ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn) + ComparisonType.CONFIGS, config_base_properties_handler.properties, diff_fn + ) config_inc_properties_handler.handle(document_inc_yaml, filename) property_diff = config_inc_properties_handler.properties_diff self.assertEqual( - property_diff.modified.get(("asdf", filename)).get("arg_vartype").base, "String") + property_diff.modified.get(("asdf", filename)).get("arg_vartype").base, + "String", + ) self.assertEqual( - property_diff.modified.get(("asdf", filename)).get("arg_vartype").inc, "int") + property_diff.modified.get(("asdf", filename)).get("arg_vartype").inc, "int" + ) self.assertIsNone(property_diff.modified.get(("zxcv", filename))) diff_fn = build_diff_fn(_COMPARE_FIELDS_SERVER_PARAMETERS) sp_base_properties_handler = BuildBasePropertiesForComparisonHandler( - ComparisonType.SERVER_PARAMETERS) + ComparisonType.SERVER_PARAMETERS + ) sp_base_properties_handler.handle(document_yaml, filename) sp_inc_properties_handler = ComputeDiffsFromIncrementedVersionHandler( - ComparisonType.SERVER_PARAMETERS, sp_base_properties_handler.properties, diff_fn) + ComparisonType.SERVER_PARAMETERS, + sp_base_properties_handler.properties, + diff_fn, + ) sp_inc_properties_handler.handle(document_inc_yaml, filename) property_diff = sp_inc_properties_handler.properties_diff self.assertEqual( - property_diff.modified.get(("testOptions", filename)).get("set_at").base, "cluster") + property_diff.modified.get(("testOptions", filename)).get("set_at").base, + "cluster", + ) self.assertEqual( - property_diff.modified.get(("testOptions", filename)).get("set_at").inc, "runtime") + property_diff.modified.get(("testOptions", filename)).get("set_at").inc, + "runtime", + ) self.assertIsNone(property_diff.modified.get(("testParameter", filename))) diff --git a/buildscripts/cost_model/abt_calibrator.py b/buildscripts/cost_model/abt_calibrator.py index 4755024e775..48a8138ffbe 100644 --- a/buildscripts/cost_model/abt_calibrator.py +++ b/buildscripts/cost_model/abt_calibrator.py @@ -37,7 +37,7 @@ from cost_estimator import estimate from database_instance import DatabaseInstance from sklearn.linear_model import LinearRegression -__all__ = ['calibrate'] +__all__ = ["calibrate"] async def calibrate(config: AbtCalibratorConfig, database: DatabaseInstance): @@ -55,18 +55,21 @@ async def calibrate(config: AbtCalibratorConfig, database: DatabaseInstance): return result -def calibrate_node(abt_df: pd.DataFrame, config: AbtCalibratorConfig, - node_config: AbtNodeCalibrationConfig): +def calibrate_node( + abt_df: pd.DataFrame, + config: AbtCalibratorConfig, + node_config: AbtNodeCalibrationConfig, +): abt_node_df = abt_df[abt_df.abt_type == node_config.type] if node_config.filter_function is not None: abt_node_df = node_config.filter_function(abt_node_df) # pylint: disable=invalid-name if node_config.variables_override is None: - variables = ['n_processed'] + variables = ["n_processed"] else: variables = node_config.variables_override - y = abt_node_df['execution_time'] + y = abt_node_df["execution_time"] X = abt_node_df[variables] X = sm.add_constant(X) diff --git a/buildscripts/cost_model/benchmark.py b/buildscripts/cost_model/benchmark.py index 2810398da2a..54e272d4c06 100644 --- a/buildscripts/cost_model/benchmark.py +++ b/buildscripts/cost_model/benchmark.py @@ -87,8 +87,8 @@ class CostModelCoefficients: def to_camel_case(string): """Convert a snake_case string to camelCase one.""" - words = string.split('_') - return words[0] + ''.join(w.capitalize() for w in words[1:]) + words = string.split("_") + return words[0] + "".join(w.capitalize() for w in words[1:]) @dataclass @@ -109,12 +109,12 @@ class BenchmarkTask: def print(self): """Prints the task.""" - print('Cost Model Coefficients Overrides:') - print(f'\tA: {self.cost_model_a.to_dict()}') - print(f'\tB: {self.cost_model_b.to_dict()}') - print(f'threshold: {self.threshold}') - print(f'collection: {self.collection_name}') - print(f'pipeline: {self.pipeline}') + print("Cost Model Coefficients Overrides:") + print(f"\tA: {self.cost_model_a.to_dict()}") + print(f"\tB: {self.cost_model_b.to_dict()}") + print(f"threshold: {self.threshold}") + print(f"collection: {self.collection_name}") + print(f"pipeline: {self.pipeline}") @dataclass @@ -131,20 +131,20 @@ class BenchmarkResult: def print(self): """Print the results.""" - print('## Benchmark Task') + print("## Benchmark Task") self.task.print() - print('\n## Result') - print(f'Means: A: {self.variant_a.mean:,.2f}, B: {self.variant_b.mean:,.2f}.') + print("\n## Result") + print(f"Means: A: {self.variant_a.mean:,.2f}, B: {self.variant_b.mean:,.2f}.") print(f"t-test's p-value: {self.pvalue}.") if self.pvalue < self.task.threshold: - print('The means are significantly different.') + print("The means are significantly different.") else: - print('The means are not significantly different.') + print("The means are not significantly different.") - print('\n### A\n') + print("\n### A\n") self.variant_a.print() - print('\n### B\n') + print("\n### B\n") self.variant_b.print() @@ -162,13 +162,15 @@ class ExperimentResult: if index is None: index = len(self.explain) // 2 - print('ABT Physical Tree') + print("ABT Physical Tree") self.physical_tree[index].print() - print('\nSBE Execution Tree') + print("\nSBE Execution Tree") self.execution_tree[index].print() -async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: BenchmarkTask): +async def benchmark( + config: BenchmarkConfig, database: DatabaseInstance, task: BenchmarkTask +): """Run the A/B performance task. It executes the given pipeline for both overrides of Cost Model Coefficients, @@ -177,7 +179,9 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B value (usually 0.05 or 0.01) we can say that the Null hypothesis is proven and there is no significant difference in the execution times. """ - async with get_database_parameter(database, 'internalCostModelCoefficients') as db_param: + async with get_database_parameter( + database, "internalCostModelCoefficients" + ) as db_param: await db_param.set(json.dumps(task.cost_model_a.to_dict())) result_a = await run(config, database, task.collection_name, task.pipeline) @@ -190,22 +194,34 @@ async def benchmark(config: BenchmarkConfig, database: DatabaseInstance, task: B execution_times_a = [et.total_execution_time for et in variant_a.execution_tree] execution_times_b = [et.total_execution_time for et in variant_b.execution_tree] - ttest_result = stats.ttest_ind(execution_times_a, execution_times_b, equal_var=False) + ttest_result = stats.ttest_ind( + execution_times_a, execution_times_b, equal_var=False + ) - return BenchmarkResult(task=task, variant_a=variant_a, variant_b=variant_b, - pvalue=ttest_result.pvalue) + return BenchmarkResult( + task=task, variant_a=variant_a, variant_b=variant_b, pvalue=ttest_result.pvalue + ) def make_variant(explain: Sequence[dict[str, any]]) -> ExperimentResult: """Make one variant of the A/B test.""" - pt = [physical_tree.build(e['queryPlanner']['winningPlan']['queryPlan']) for e in explain] - et = [execution_tree.build_execution_tree(e['executionStats']) for e in explain] + pt = [ + physical_tree.build(e["queryPlanner"]["winningPlan"]["queryPlan"]) + for e in explain + ] + et = [execution_tree.build_execution_tree(e["executionStats"]) for e in explain] mean = sum(et.total_execution_time for et in et) / len(et) - return ExperimentResult(explain=explain, physical_tree=pt, execution_tree=et, mean=mean) + return ExperimentResult( + explain=explain, physical_tree=pt, execution_tree=et, mean=mean + ) -async def run(config: BenchmarkConfig, database: DatabaseInstance, collection: str, - pipeline: Pipeline): +async def run( + config: BenchmarkConfig, + database: DatabaseInstance, + collection: str, + pipeline: Pipeline, +): """Run one variant of the A/B test.""" # warmup @@ -215,10 +231,10 @@ async def run(config: BenchmarkConfig, database: DatabaseInstance, collection: s result = [] for _ in range(config.runs): explain = await database.explain(collection, pipeline) - if explain['ok'] == 1: + if explain["ok"] == 1: result.append(explain) else: - logging.warn('Query execution failed: %s', explain) + logging.warn("Query execution failed: %s", explain) return result @@ -231,14 +247,18 @@ async def smoke_test(): await database.enable_cascades(True) cost_model_a = CostModelCoefficients(index_scan_incremental_cost=0.0001) cost_model_b = CostModelCoefficients(index_scan_incremental_cost=0.9) - task = BenchmarkTask(collection_name='c_str_05_45000', - pipeline=[{'$match': {'choice1': 'hello', 'choice2': 'gaussian'}}], - cost_model_a=cost_model_a, cost_model_b=cost_model_b, threshold=0.05) + task = BenchmarkTask( + collection_name="c_str_05_45000", + pipeline=[{"$match": {"choice1": "hello", "choice2": "gaussian"}}], + cost_model_a=cost_model_a, + cost_model_b=cost_model_b, + threshold=0.05, + ) res = await benchmark(config, database, task) res.print() -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.get_event_loop() loop.run_until_complete(smoke_test()) diff --git a/buildscripts/cost_model/calibration_settings.py b/buildscripts/cost_model/calibration_settings.py index ef907f36173..4d1a8cb424e 100644 --- a/buildscripts/cost_model/calibration_settings.py +++ b/buildscripts/cost_model/calibration_settings.py @@ -37,48 +37,52 @@ from random_generator import ( RangeGenerator, ) -__all__ = ['main_config', 'distributions'] +__all__ = ["main_config", "distributions"] # A string value to fill up collections and not used in queries. -HIDDEN_STRING_VALUE = '__hidden_string_value' +HIDDEN_STRING_VALUE = "__hidden_string_value" # Data distributions settings. distributions = {} string_choice_values = [ - 'h', - 'hi', - 'hi!', - 'hola', - 'hello', - 'square', - 'squared', - 'gaussian', - 'chisquare', - 'chisquared', - 'hello world', - 'distribution', + "h", + "hi", + "hi!", + "hola", + "hello", + "square", + "squared", + "gaussian", + "chisquare", + "chisquared", + "hello world", + "distribution", ] string_choice_weights = [10, 20, 5, 17, 30, 7, 9, 15, 40, 2, 12, 1] -distributions['string_choice'] = RandomDistribution.choice(string_choice_values, - string_choice_weights) +distributions["string_choice"] = RandomDistribution.choice( + string_choice_values, string_choice_weights +) small_query_weights = [i for i in range(10, 201, 10)] small_query_cardinality = sum(small_query_weights) int_choice_values = [i for i in range(1, 1000, 50)] random.shuffle(int_choice_values) -distributions['int_choice'] = RandomDistribution.choice(int_choice_values, small_query_weights) +distributions["int_choice"] = RandomDistribution.choice( + int_choice_values, small_query_weights +) -distributions['random_string'] = ArrayRandomDistribution( +distributions["random_string"] = ArrayRandomDistribution( RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 5, 10, 2)), - RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z"))) + RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")), +) def generate_random_str(num: int): - strs = distributions['random_string'].generate(num) + strs = distributions["random_string"].generate(num) str_list = [] for char_array in strs: str_res = "".join(char_array) @@ -90,45 +94,74 @@ def generate_random_str(num: int): def random_strings_distr(size: int, count: int): distr = ArrayRandomDistribution( RandomDistribution.uniform([size]), - RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z"))) + RandomDistribution.uniform(RangeGenerator(DataType.STRING, "a", "z")), + ) - return RandomDistribution.uniform([''.join(s) for s in distr.generate(count)]) + return RandomDistribution.uniform(["".join(s) for s in distr.generate(count)]) small_string_choice = generate_random_str(20) -distributions['string_choice_small'] = RandomDistribution.choice(small_string_choice, - small_query_weights) +distributions["string_choice_small"] = RandomDistribution.choice( + small_string_choice, small_query_weights +) -string_range_4 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "abca", "abc_")) -string_range_5 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "abcda", "abcd_")) -string_range_7 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "hello_a", "hello__")) +string_range_4 = RandomDistribution.normal( + RangeGenerator(DataType.STRING, "abca", "abc_") +) +string_range_5 = RandomDistribution.normal( + RangeGenerator(DataType.STRING, "abcda", "abcd_") +) +string_range_7 = RandomDistribution.normal( + RangeGenerator(DataType.STRING, "hello_a", "hello__") +) string_range_12 = RandomDistribution.normal( - RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_")) + RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_") +) -distributions['string_mixed'] = RandomDistribution.mixed( - [string_range_4, string_range_5, string_range_7, string_range_12], [0.1, 0.15, 0.25, 0.5]) +distributions["string_mixed"] = RandomDistribution.mixed( + [string_range_4, string_range_5, string_range_7, string_range_12], + [0.1, 0.15, 0.25, 0.5], +) -distributions['string_uniform'] = RandomDistribution.uniform( - RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_")) +distributions["string_uniform"] = RandomDistribution.uniform( + RangeGenerator(DataType.STRING, "helloworldaa", "helloworldd_") +) -distributions['int_normal'] = RandomDistribution.normal( - RangeGenerator(DataType.INTEGER, 0, 1000, 2)) +distributions["int_normal"] = RandomDistribution.normal( + RangeGenerator(DataType.INTEGER, 0, 1000, 2) +) lengths_distr = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 1, 10)) -distributions['array_small'] = ArrayRandomDistribution(lengths_distr, distributions['int_normal']) +distributions["array_small"] = ArrayRandomDistribution( + lengths_distr, distributions["int_normal"] +) # Database settings -database = config.DatabaseConfig(connection_string='mongodb://localhost', - database_name='abt_calibration', dump_path='~/data/dump', - restore_from_dump=config.RestoreMode.NEVER, dump_on_exit=False) +database = config.DatabaseConfig( + connection_string="mongodb://localhost", + database_name="abt_calibration", + dump_path="~/data/dump", + restore_from_dump=config.RestoreMode.NEVER, + dump_on_exit=False, +) # Collection template settings -def create_index_scan_collection_template(name: str, cardinality: int) -> config.CollectionTemplate: +def create_index_scan_collection_template( + name: str, cardinality: int +) -> config.CollectionTemplate: values = [ - 'iqtbr5b5is', 'vt5s3tf8o6', 'b0rgm58qsn', '9m59if353m', 'biw2l9ok17', 'b9ct0ue14d', - 'oxj0vxjsti', 'f3k8w9vb49', 'ec7v82k6nk', 'f49ufwaqx7' + "iqtbr5b5is", + "vt5s3tf8o6", + "b0rgm58qsn", + "9m59if353m", + "biw2l9ok17", + "b9ct0ue14d", + "oxj0vxjsti", + "f3k8w9vb49", + "ec7v82k6nk", + "f49ufwaqx7", ] start_weight = 10 @@ -143,80 +176,174 @@ def create_index_scan_collection_template(name: str, cardinality: int) -> config distr = RandomDistribution.choice(values, weights) return config.CollectionTemplate( - name=name, fields=[ - config.FieldTemplate(name="choice", data_type=config.DataType.STRING, - distribution=distr, indexed=True), - config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING, - distribution=distributions["string_uniform"], indexed=False), - config.FieldTemplate(name="choice2", data_type=config.DataType.STRING, - distribution=distributions["string_choice"], indexed=False), - config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - ], compound_indexes=[], cardinalities=[cardinality]) + name=name, + fields=[ + config.FieldTemplate( + name="choice", + data_type=config.DataType.STRING, + distribution=distr, + indexed=True, + ), + config.FieldTemplate( + name="mixed1", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + config.FieldTemplate( + name="uniform1", + data_type=config.DataType.STRING, + distribution=distributions["string_uniform"], + indexed=False, + ), + config.FieldTemplate( + name="choice2", + data_type=config.DataType.STRING, + distribution=distributions["string_choice"], + indexed=False, + ), + config.FieldTemplate( + name="mixed2", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + ], + compound_indexes=[], + cardinalities=[cardinality], + ) -def create_physical_scan_collection_template(name: str, - payload_size: int = 0) -> config.CollectionTemplate: +def create_physical_scan_collection_template( + name: str, payload_size: int = 0 +) -> config.CollectionTemplate: template = config.CollectionTemplate( - name=name, fields=[ - config.FieldTemplate(name="choice1", data_type=config.DataType.STRING, - distribution=distributions["string_choice"], indexed=False), - config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING, - distribution=distributions["string_uniform"], indexed=False), - config.FieldTemplate(name="choice", data_type=config.DataType.STRING, - distribution=distributions["string_choice"], indexed=False), - config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - ], compound_indexes=[], cardinalities=[1000, 5000, 10000]) + name=name, + fields=[ + config.FieldTemplate( + name="choice1", + data_type=config.DataType.STRING, + distribution=distributions["string_choice"], + indexed=False, + ), + config.FieldTemplate( + name="mixed1", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + config.FieldTemplate( + name="uniform1", + data_type=config.DataType.STRING, + distribution=distributions["string_uniform"], + indexed=False, + ), + config.FieldTemplate( + name="choice", + data_type=config.DataType.STRING, + distribution=distributions["string_choice"], + indexed=False, + ), + config.FieldTemplate( + name="mixed2", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + ], + compound_indexes=[], + cardinalities=[1000, 5000, 10000], + ) if payload_size > 0: payload_distr = random_strings_distr(payload_size, 1000) template.fields.append( - config.FieldTemplate(name="payload", data_type=config.DataType.STRING, - distribution=payload_distr, indexed=False)) + config.FieldTemplate( + name="payload", + data_type=config.DataType.STRING, + distribution=payload_distr, + indexed=False, + ) + ) return template collection_caridinalities = list(range(10000, 50001, 10000)) c_int_05 = config.CollectionTemplate( - name="c_int_05", fields=[ - config.FieldTemplate(name="in1", data_type=config.DataType.INTEGER, - distribution=distributions["int_normal"], indexed=True), - config.FieldTemplate(name="mixed1", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING, - distribution=distributions["string_uniform"], indexed=False), - config.FieldTemplate(name="in2", data_type=config.DataType.INTEGER, - distribution=distributions["int_normal"], indexed=True), - config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - ], compound_indexes=[], cardinalities=collection_caridinalities) + name="c_int_05", + fields=[ + config.FieldTemplate( + name="in1", + data_type=config.DataType.INTEGER, + distribution=distributions["int_normal"], + indexed=True, + ), + config.FieldTemplate( + name="mixed1", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + config.FieldTemplate( + name="uniform1", + data_type=config.DataType.STRING, + distribution=distributions["string_uniform"], + indexed=False, + ), + config.FieldTemplate( + name="in2", + data_type=config.DataType.INTEGER, + distribution=distributions["int_normal"], + indexed=True, + ), + config.FieldTemplate( + name="mixed2", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + ], + compound_indexes=[], + cardinalities=collection_caridinalities, +) c_arr_01 = config.CollectionTemplate( - name="c_arr_01", fields=[ - config.FieldTemplate(name="as", data_type=config.DataType.INTEGER, - distribution=distributions["array_small"], indexed=True) - ], compound_indexes=[], cardinalities=collection_caridinalities) + name="c_arr_01", + fields=[ + config.FieldTemplate( + name="as", + data_type=config.DataType.INTEGER, + distribution=distributions["array_small"], + indexed=True, + ) + ], + compound_indexes=[], + cardinalities=collection_caridinalities, +) -index_scan = create_index_scan_collection_template('index_scan', 1000000) +index_scan = create_index_scan_collection_template("index_scan", 1000000) -physical_scan = create_physical_scan_collection_template('physical_scan', 2000) +physical_scan = create_physical_scan_collection_template("physical_scan", 2000) # Data Generator settings data_generator = config.DataGeneratorConfig( - enabled=True, create_indexes=True, batch_size=10000, + enabled=True, + create_indexes=True, + batch_size=10000, collection_templates=[index_scan, physical_scan, c_int_05, c_arr_01], - write_mode=config.WriteMode.REPLACE, collection_name_with_card=True) + write_mode=config.WriteMode.REPLACE, + collection_name_with_card=True, +) # Workload Execution settings workload_execution = config.WorkloadExecutionConfig( - enabled=True, output_collection_name='calibrationData', write_mode=config.WriteMode.REPLACE, - warmup_runs=3, runs=30) + enabled=True, + output_collection_name="calibrationData", + write_mode=config.WriteMode.REPLACE, + warmup_runs=3, + runs=30, +) def make_filter_by_note(note_value: any): @@ -227,30 +354,45 @@ def make_filter_by_note(note_value: any): abt_nodes = [ - config.AbtNodeCalibrationConfig(type='PhysicalScan', - filter_function=make_filter_by_note('PhysicalScan')), - config.AbtNodeCalibrationConfig(type='IndexScan', - filter_function=make_filter_by_note('IndexScan')), - config.AbtNodeCalibrationConfig(type='Seek', filter_function=make_filter_by_note('IndexScan')), - config.AbtNodeCalibrationConfig(type='Filter', - filter_function=make_filter_by_note('PhysicalScan')), - config.AbtNodeCalibrationConfig(type='Evaluation', - filter_function=make_filter_by_note('Evaluation')), - config.AbtNodeCalibrationConfig(type='NestedLoopJoin'), - config.AbtNodeCalibrationConfig(type='HashJoin'), - config.AbtNodeCalibrationConfig(type='MergeJoin'), - config.AbtNodeCalibrationConfig(type='Union'), - config.AbtNodeCalibrationConfig(type='LimitSkip', - filter_function=make_filter_by_note('LimitSkip')), - config.AbtNodeCalibrationConfig(type='GroupBy'), - config.AbtNodeCalibrationConfig(type='Unwind'), - config.AbtNodeCalibrationConfig(type='Unique'), + config.AbtNodeCalibrationConfig( + type="PhysicalScan", filter_function=make_filter_by_note("PhysicalScan") + ), + config.AbtNodeCalibrationConfig( + type="IndexScan", filter_function=make_filter_by_note("IndexScan") + ), + config.AbtNodeCalibrationConfig( + type="Seek", filter_function=make_filter_by_note("IndexScan") + ), + config.AbtNodeCalibrationConfig( + type="Filter", filter_function=make_filter_by_note("PhysicalScan") + ), + config.AbtNodeCalibrationConfig( + type="Evaluation", filter_function=make_filter_by_note("Evaluation") + ), + config.AbtNodeCalibrationConfig(type="NestedLoopJoin"), + config.AbtNodeCalibrationConfig(type="HashJoin"), + config.AbtNodeCalibrationConfig(type="MergeJoin"), + config.AbtNodeCalibrationConfig(type="Union"), + config.AbtNodeCalibrationConfig( + type="LimitSkip", filter_function=make_filter_by_note("LimitSkip") + ), + config.AbtNodeCalibrationConfig(type="GroupBy"), + config.AbtNodeCalibrationConfig(type="Unwind"), + config.AbtNodeCalibrationConfig(type="Unique"), ] # Calibrator settings abt_calibrator = config.AbtCalibratorConfig( - enabled=True, test_size=0.2, input_collection_name=workload_execution.output_collection_name, - trace=False, nodes=abt_nodes) + enabled=True, + test_size=0.2, + input_collection_name=workload_execution.output_collection_name, + trace=False, + nodes=abt_nodes, +) -main_config = config.Config(database=database, data_generator=data_generator, - abt_calibrator=abt_calibrator, workload_execution=workload_execution) +main_config = config.Config( + database=database, + data_generator=data_generator, + abt_calibrator=abt_calibrator, + workload_execution=workload_execution, +) diff --git a/buildscripts/cost_model/ce_data_settings.py b/buildscripts/cost_model/ce_data_settings.py index 2f34cbfd3a9..a1fcd76e792 100644 --- a/buildscripts/cost_model/ce_data_settings.py +++ b/buildscripts/cost_model/ce_data_settings.py @@ -40,15 +40,18 @@ from random_generator import ( RangeGenerator, ) -__all__ = ['database_config', 'data_generator_config'] +__all__ = ["database_config", "data_generator_config"] ################################################################################ # Data distributions ################################################################################ -def add_distribution(distr_set: Sequence[RandomDistribution], distr_type: DistributionType, - rg: RangeGenerator): +def add_distribution( + distr_set: Sequence[RandomDistribution], + distr_type: DistributionType, + rg: RangeGenerator, +): distr = None if distr_type == DistributionType.UNIFORM: distr = RandomDistribution.uniform(rg) @@ -104,27 +107,48 @@ for range_gen in int_ranges_2: # Mixes of distributions with different NDV and value distances int_distributions.append( RandomDistribution.mixed( - children=[int_distributions[0], int_distributions_offset[0], int_distributions[4]], - weight=[1, 1, 1])) + children=[ + int_distributions[0], + int_distributions_offset[0], + int_distributions[4], + ], + weight=[1, 1, 1], + ) +) int_distributions.append( RandomDistribution.mixed( children=[int_distributions[1], int_distributions[4], int_distributions[7]], - weight=[1, 1, 1])) + weight=[1, 1, 1], + ) +) int_distributions.append( RandomDistribution.mixed( children=[ - int_distributions[1], int_distributions_offset[1], int_distributions[3], - int_distributions[2], int_distributions_offset[2] - ], weight=[1, 1, 1, 1, 1])) + int_distributions[1], + int_distributions_offset[1], + int_distributions[3], + int_distributions[2], + int_distributions_offset[2], + ], + weight=[1, 1, 1, 1, 1], + ) +) int_distributions.append( RandomDistribution.mixed( children=[ - int_distributions[2], int_distributions[3], int_distributions[6], - int_distributions_offset[1], int_distributions_offset[2], int_distributions_offset[5] - ], weight=[1, 1, 1, 1, 1, 1])) + int_distributions[2], + int_distributions[3], + int_distributions[6], + int_distributions_offset[1], + int_distributions_offset[2], + int_distributions_offset[5], + ], + weight=[1, 1, 1, 1, 1, 1], + ) +) ############################# # Double number distributions @@ -137,7 +161,7 @@ dbl_ranges = [ # 10K unique doubles with different distances RangeGenerator(DataType.DOUBLE, 0.0, 1000.0, 0.1), RangeGenerator(DataType.DOUBLE, 0.0, 100000.0, 10), - RangeGenerator(DataType.DOUBLE, 0.0, 10000000.0, 1000) + RangeGenerator(DataType.DOUBLE, 0.0, 10000000.0, 1000), ] dbl_distributions = [] @@ -149,16 +173,25 @@ for range_gen in dbl_ranges: dbl_distributions.append( RandomDistribution.mixed( children=[dbl_distributions[0], dbl_distributions[3], dbl_distributions[10]], - weight=[1, 1, 1])) + weight=[1, 1, 1], + ) +) dbl_distributions.append( RandomDistribution.mixed( children=[ dbl_distributions[0], dbl_distributions[4], - RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 500.0, 600.0, 0.1)), - RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 3000200.0, 5000100.0, 3030)), - ], weight=[1, 1, 1, 1])) + RandomDistribution.normal( + RangeGenerator(DataType.DOUBLE, 500.0, 600.0, 0.1) + ), + RandomDistribution.normal( + RangeGenerator(DataType.DOUBLE, 3000200.0, 5000100.0, 3030) + ), + ], + weight=[1, 1, 1, 1], + ) +) ############################# # Date distributions @@ -168,13 +201,27 @@ HOUR = MINUTE * 60 DAY = HOUR * 24 MONTH = DAY * 30 -range_dtt_1y = RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), HOUR) -range_dtt_1m_1 = RangeGenerator(DataType.DATE, datetime(2007, 2, 1), datetime(2008, 3, 1), HOUR) -range_dtt_1m_2 = RangeGenerator(DataType.DATE, datetime(2007, 6, 1), datetime(2008, 7, 1), HOUR) -range_dtt_1m_3 = RangeGenerator(DataType.DATE, datetime(2007, 10, 1), datetime(2008, 11, 1), HOUR) -range_dtt_10y_1 = RangeGenerator(DataType.DATE, datetime(2006, 1, 1), datetime(2016, 1, 1), DAY) -range_dtt_10y_2 = RangeGenerator(DataType.DATE, datetime(1995, 1, 1), datetime(2005, 1, 1), DAY) -range_dtt_20y = RangeGenerator(DataType.DATE, datetime(1997, 10, 1), datetime(2017, 11, 1), MONTH) +range_dtt_1y = RangeGenerator( + DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), HOUR +) +range_dtt_1m_1 = RangeGenerator( + DataType.DATE, datetime(2007, 2, 1), datetime(2008, 3, 1), HOUR +) +range_dtt_1m_2 = RangeGenerator( + DataType.DATE, datetime(2007, 6, 1), datetime(2008, 7, 1), HOUR +) +range_dtt_1m_3 = RangeGenerator( + DataType.DATE, datetime(2007, 10, 1), datetime(2008, 11, 1), HOUR +) +range_dtt_10y_1 = RangeGenerator( + DataType.DATE, datetime(2006, 1, 1), datetime(2016, 1, 1), DAY +) +range_dtt_10y_2 = RangeGenerator( + DataType.DATE, datetime(1995, 1, 1), datetime(2005, 1, 1), DAY +) +range_dtt_20y = RangeGenerator( + DataType.DATE, datetime(1997, 10, 1), datetime(2017, 11, 1), MONTH +) dt_distributions = [] @@ -182,25 +229,33 @@ add_distribution(dt_distributions, DistributionType.UNIFORM, range_dtt_1y) add_distribution(dt_distributions, DistributionType.NORMAL, range_dtt_10y_1) dt_distributions.append( - RandomDistribution.mixed([ - RandomDistribution.uniform(range_dtt_1y), - RandomDistribution.uniform(range_dtt_1m_1), - RandomDistribution.uniform(range_dtt_1m_2), - RandomDistribution.uniform(range_dtt_1m_3) - ], [1, 1, 1, 1])) + RandomDistribution.mixed( + [ + RandomDistribution.uniform(range_dtt_1y), + RandomDistribution.uniform(range_dtt_1m_1), + RandomDistribution.uniform(range_dtt_1m_2), + RandomDistribution.uniform(range_dtt_1m_3), + ], + [1, 1, 1, 1], + ) +) dt_distributions.append( - RandomDistribution.mixed([ - RandomDistribution.uniform(range_dtt_10y_1), - RandomDistribution.uniform(range_dtt_10y_2), - RandomDistribution.uniform(range_dtt_20y) - ], [1, 1, 1])) + RandomDistribution.mixed( + [ + RandomDistribution.uniform(range_dtt_10y_1), + RandomDistribution.uniform(range_dtt_10y_2), + RandomDistribution.uniform(range_dtt_20y), + ], + [1, 1, 1], + ) +) ####################### # String distributions -PRINTED_CHAR_MIN_CODE = ord('0') -PRINTED_CHAR_MAX_CODE = ord('~') +PRINTED_CHAR_MIN_CODE = ord("0") +PRINTED_CHAR_MAX_CODE = ord("~") ascii_printable_chars = [ chr(code) for code in range(PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + 1) @@ -209,19 +264,27 @@ ascii_printable_chars = [ def next_char(char: str, distance: int, min_char_code: int, max_char_code: int): char_code = ord(char) - assert (min_char_code <= char_code <= max_char_code - ), f'char_code "{char_code}" is out of range ({min_char_code}, {max_char_code})' + assert ( + min_char_code <= char_code <= max_char_code + ), f'char_code "{char_code}" is out of range ({min_char_code}, {max_char_code})' number_of_chars = max_char_code - min_char_code + 1 - new_char_code = ((char_code - min_char_code + distance) % number_of_chars) + min_char_code - assert (min_char_code <= new_char_code <= - max_char_code), f'new char code "{new_char_code}" is out of range' + new_char_code = ( + (char_code - min_char_code + distance) % number_of_chars + ) + min_char_code + assert ( + min_char_code <= new_char_code <= max_char_code + ), f'new char code "{new_char_code}" is out of range' return chr(new_char_code) -def generate_str_by_distance(num_strings: int, seed_str: str, distance_distr_0: RandomDistribution, - distance_distr_1: RandomDistribution, - distance_distr_2: RandomDistribution, - distance_distr_3: RandomDistribution): +def generate_str_by_distance( + num_strings: int, + seed_str: str, + distance_distr_0: RandomDistribution, + distance_distr_1: RandomDistribution, + distance_distr_2: RandomDistribution, + distance_distr_3: RandomDistribution, +): """ Generate a set of unique strings with different string distances. @@ -240,14 +303,18 @@ def generate_str_by_distance(num_strings: int, seed_str: str, distance_distr_0: cur_str = seed_str str_set.add(cur_str) for i in range(1, num_strings): - new_str = next_char(cur_str[0], distances_0[i], PRINTED_CHAR_MIN_CODE, - PRINTED_CHAR_MAX_CODE) - new_str += next_char(cur_str[1], distances_1[i], PRINTED_CHAR_MIN_CODE, - PRINTED_CHAR_MAX_CODE) - new_str += next_char(cur_str[2], distances_2[i], PRINTED_CHAR_MIN_CODE, - PRINTED_CHAR_MAX_CODE) - new_str += next_char(cur_str[3], distances_3[i], PRINTED_CHAR_MIN_CODE, - PRINTED_CHAR_MAX_CODE) + new_str = next_char( + cur_str[0], distances_0[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + ) + new_str += next_char( + cur_str[1], distances_1[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + ) + new_str += next_char( + cur_str[2], distances_2[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + ) + new_str += next_char( + cur_str[3], distances_3[i], PRINTED_CHAR_MIN_CODE, PRINTED_CHAR_MAX_CODE + ) str_set.add(new_str) cur_str = new_str return list(str_set) @@ -268,17 +335,17 @@ d4 = RandomDistribution.uniform(range_int_20_30) # Sets of strings where characters at different positions have different distances string_sets = {} # 250 unique strings -string_sets['set_1112_250'] = generate_str_by_distance(250, 'xxxx', d1, d1, d1, d2) -string_sets['set_2221_250'] = generate_str_by_distance(250, 'azay', d2, d2, d3, d1) -string_sets['set_5555_250'] = generate_str_by_distance(250, 'axbz', d4, d4, d4, d4) +string_sets["set_1112_250"] = generate_str_by_distance(250, "xxxx", d1, d1, d1, d2) +string_sets["set_2221_250"] = generate_str_by_distance(250, "azay", d2, d2, d3, d1) +string_sets["set_5555_250"] = generate_str_by_distance(250, "axbz", d4, d4, d4, d4) # 1000 unique strings -string_sets['set_1112_1000'] = generate_str_by_distance(1000, 'xxxx', d1, d1, d1, d2) -string_sets['set_2221_1000'] = generate_str_by_distance(1000, 'azay', d2, d2, d3, d1) -string_sets['set_5555_1000'] = generate_str_by_distance(1000, 'axbz', d4, d4, d4, d4) +string_sets["set_1112_1000"] = generate_str_by_distance(1000, "xxxx", d1, d1, d1, d2) +string_sets["set_2221_1000"] = generate_str_by_distance(1000, "azay", d2, d2, d3, d1) +string_sets["set_5555_1000"] = generate_str_by_distance(1000, "axbz", d4, d4, d4, d4) # 10000 unique strings -string_sets['set_1112_10000'] = generate_str_by_distance(10000, 'xxxx', d1, d1, d1, d2) -string_sets['set_2221_10000'] = generate_str_by_distance(10000, 'azay', d2, d2, d3, d1) -string_sets['set_5555_10000'] = generate_str_by_distance(10000, 'axbz', d4, d4, d4, d4) +string_sets["set_1112_10000"] = generate_str_by_distance(10000, "xxxx", d1, d1, d1, d2) +string_sets["set_2221_10000"] = generate_str_by_distance(10000, "azay", d2, d2, d3, d1) +string_sets["set_5555_10000"] = generate_str_by_distance(10000, "axbz", d4, d4, d4, d4) # Weights with different variance. For instance if the smallest weight is 1, and the biggest weight is 5 # then some values in a choice distribution will be picked with at most 5 times higher probability. @@ -291,19 +358,26 @@ weight_range_s = RangeGenerator(DataType.INTEGER, 95, 101, 1) weight_range_l = RangeGenerator(DataType.INTEGER, 25, 101, 2) weights = {} -weights['weight_unif_s'] = RandomDistribution.uniform(weight_range_s) -weights['weight_unif_l'] = RandomDistribution.uniform(weight_range_l) +weights["weight_unif_s"] = RandomDistribution.uniform(weight_range_s) +weights["weight_unif_l"] = RandomDistribution.uniform(weight_range_l) -#weights['weight_norm_s'] = RandomDistribution.normal(weight_range_s) -#weights['weight_norm_l'] = RandomDistribution.normal(weight_range_l) +# weights['weight_norm_s'] = RandomDistribution.normal(weight_range_s) +# weights['weight_norm_l'] = RandomDistribution.normal(weight_range_l) -#weights['chi2_s'] = RandomDistribution.noncentral_chisquare(weight_range_s) -#weights['chi2_l'] = RandomDistribution.noncentral_chisquare(weight_range_l) +# weights['chi2_s'] = RandomDistribution.noncentral_chisquare(weight_range_s) +# weights['chi2_l'] = RandomDistribution.noncentral_chisquare(weight_range_l) -def add_choice_distr(distr_set: Sequence[RandomDistribution], str_set: Sequence[str], - weight_distr: RandomDistribution, v_name: str, w_name: str): - distr = RandomDistribution.choice(str_set, weight_distr.generate(len(str_set)), v_name, w_name) +def add_choice_distr( + distr_set: Sequence[RandomDistribution], + str_set: Sequence[str], + weight_distr: RandomDistribution, + v_name: str, + w_name: str, +): + distr = RandomDistribution.choice( + str_set, weight_distr.generate(len(str_set)), v_name, w_name + ) distr_set.append(distr) @@ -320,12 +394,19 @@ for set_name, cur_set in string_sets.items(): # array lenght distributions - they are all uniform arr_len_dist_s = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 1, 6, 1)) -arr_len_dist_m = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 90, 110, 3)) -arr_len_dist_l = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 900, 1100, 10)) +arr_len_dist_m = RandomDistribution.uniform( + RangeGenerator(DataType.INTEGER, 90, 110, 3) +) +arr_len_dist_l = RandomDistribution.uniform( + RangeGenerator(DataType.INTEGER, 900, 1100, 10) +) -def add_array_distr(distr_set: Sequence[RandomDistribution], lengths_distr: RandomDistribution, - value_distr: RandomDistribution): +def add_array_distr( + distr_set: Sequence[RandomDistribution], + lengths_distr: RandomDistribution, + value_distr: RandomDistribution, +): distr_set.append(ArrayRandomDistribution(lengths_distr, value_distr)) @@ -349,24 +430,30 @@ add_array_distr(arr_distributions, arr_len_dist_l, str_distributions[-1]) # 30% scalars, 70% arrays arr_distributions.append( - RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.3, 0.7])) + RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.3, 0.7]) +) arr_distributions.append( - RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.3, 0.7])) + RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.3, 0.7]) +) # 70% scalars, 30% arrays arr_distributions.append( - RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.7, 0.3])) + RandomDistribution.mixed([int_distributions[0], arr_distributions[0]], [0.7, 0.3]) +) arr_distributions.append( - RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.7, 0.3])) + RandomDistribution.mixed([int_distributions[-1], arr_distributions[-1]], [0.7, 0.3]) +) arr_zero_size = RandomDistribution.uniform(RangeGenerator(DataType.INTEGER, 0, 1, 1)) arr_empty_distr = ArrayRandomDistribution(arr_zero_size, int_distributions[0]) # 20% empty arrays arr_distributions.append( - RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.2, 0.8])) + RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.2, 0.8]) +) # 80% empty arrays arr_distributions.append( - RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.8, 0.2])) + RandomDistribution.mixed([arr_empty_distr, arr_distributions[2]], [0.8, 0.2]) +) ############################### # Mixed data type distributions @@ -377,40 +464,68 @@ mix_distributions = [] int_str_mix_1 = [int_distributions[0], str_distributions[0]] int_str_mix_2 = [int_distributions_offset[7], str_distributions[-1]] -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.5, 0.5])) -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.5, 0.5])) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_1, weight=[0.5, 0.5]) +) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_2, weight=[0.5, 0.5]) +) -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.1, 0.9])) -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_1, weight=[0.9, 0.1])) -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.1, 0.9])) -mix_distributions.append(RandomDistribution.mixed(children=int_str_mix_2, weight=[0.9, 0.1])) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_1, weight=[0.1, 0.9]) +) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_1, weight=[0.9, 0.1]) +) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_2, weight=[0.1, 0.9]) +) +mix_distributions.append( + RandomDistribution.mixed(children=int_str_mix_2, weight=[0.9, 0.1]) +) # Doubles and strings -dbl_ascii_range = RangeGenerator(DataType.DOUBLE, float(PRINTED_CHAR_MIN_CODE), - float(PRINTED_CHAR_MAX_CODE), 0.01) +dbl_ascii_range = RangeGenerator( + DataType.DOUBLE, float(PRINTED_CHAR_MIN_CODE), float(PRINTED_CHAR_MAX_CODE), 0.01 +) ascii_double_range_distr = RandomDistribution.normal(dbl_ascii_range) dbl_str_mix_1 = [ascii_double_range_distr, str_distributions[1]] -mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.5, 0.5])) -mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.1, 0.9])) -mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.9, 0.1])) +mix_distributions.append( + RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.5, 0.5]) +) +mix_distributions.append( + RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.1, 0.9]) +) +mix_distributions.append( + RandomDistribution.mixed(children=dbl_str_mix_1, weight=[0.9, 0.1]) +) dbl_str_mix_2 = [dbl_distributions[5], str_distributions[0]] -mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_2, weight=[0.5, 0.5])) +mix_distributions.append( + RandomDistribution.mixed(children=dbl_str_mix_2, weight=[0.5, 0.5]) +) dbl_str_mix_3 = [dbl_distributions[5], str_distributions[5]] -mix_distributions.append(RandomDistribution.mixed(children=dbl_str_mix_3, weight=[0.5, 0.5])) +mix_distributions.append( + RandomDistribution.mixed(children=dbl_str_mix_3, weight=[0.5, 0.5]) +) # Doubles and/or strings and dates dbl_str_dt_mix_1 = [ascii_double_range_distr, str_distributions[4], dt_distributions[0]] mix_distributions.append( - RandomDistribution.mixed(children=dbl_str_dt_mix_1, weight=[0.5, 0.5, 0.5])) + RandomDistribution.mixed(children=dbl_str_dt_mix_1, weight=[0.5, 0.5, 0.5]) +) str_dt_mix_1 = [str_distributions[0], dt_distributions[-1]] -mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_1, weight=[0.5, 0.5])) +mix_distributions.append( + RandomDistribution.mixed(children=str_dt_mix_1, weight=[0.5, 0.5]) +) str_dt_mix_2 = [str_distributions[-1], dt_distributions[0]] -mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_2, weight=[0.5, 0.5])) +mix_distributions.append( + RandomDistribution.mixed(children=str_dt_mix_2, weight=[0.5, 0.5]) +) ################################################################################ # Collection templates @@ -424,46 +539,88 @@ mix_distributions.append(RandomDistribution.mixed(children=str_dt_mix_2, weight= collection_cardinalities = [500] field_templates = [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.INTEGER, distribution=dist, - indexed=False) for dist in int_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.INTEGER, + distribution=dist, + indexed=False, + ) + for dist in int_distributions ] field_templates += [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.STRING, distribution=dist, - indexed=False) for dist in str_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.STRING, + distribution=dist, + indexed=False, + ) + for dist in str_distributions ] field_templates += [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.ARRAY, distribution=dist, - indexed=False) for dist in arr_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.ARRAY, + distribution=dist, + indexed=False, + ) + for dist in arr_distributions ] field_templates += [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.DOUBLE, distribution=dist, - indexed=False) for dist in dbl_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.DOUBLE, + distribution=dist, + indexed=False, + ) + for dist in dbl_distributions ] field_templates += [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.DATE, distribution=dist, - indexed=False) for dist in dt_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.DATE, + distribution=dist, + indexed=False, + ) + for dist in dt_distributions ] field_templates += [ - config.FieldTemplate(name=f'{str(dist)}', data_type=config.DataType.MIXDATA, distribution=dist, - indexed=False) for dist in mix_distributions + config.FieldTemplate( + name=f"{str(dist)}", + data_type=config.DataType.MIXDATA, + distribution=dist, + indexed=False, + ) + for dist in mix_distributions ] -ce_data = config.CollectionTemplate(name="ce_data", fields=field_templates, compound_indexes=[], - cardinalities=collection_cardinalities) +ce_data = config.CollectionTemplate( + name="ce_data", + fields=field_templates, + compound_indexes=[], + cardinalities=collection_cardinalities, +) ################################################################################ # Database settings ################################################################################ database_config = config.DatabaseConfig( - connection_string='mongodb://localhost', database_name='ce_accuracy_test', dump_path=Path( - '..', '..', 'jstests', 'query_golden', 'libs', 'data'), - restore_from_dump=config.RestoreMode.NEVER, dump_on_exit=False) + connection_string="mongodb://localhost", + database_name="ce_accuracy_test", + dump_path=Path("..", "..", "jstests", "query_golden", "libs", "data"), + restore_from_dump=config.RestoreMode.NEVER, + dump_on_exit=False, +) ################################################################################ # Data Generator settings ################################################################################ data_generator_config = config.DataGeneratorConfig( - enabled=True, create_indexes=False, batch_size=10000, collection_templates=[ce_data], - write_mode=config.WriteMode.REPLACE, collection_name_with_card=True) + enabled=True, + create_indexes=False, + batch_size=10000, + collection_templates=[ce_data], + write_mode=config.WriteMode.REPLACE, + collection_name_with_card=True, +) diff --git a/buildscripts/cost_model/ce_generate_data.py b/buildscripts/cost_model/ce_generate_data.py index 78ea063b901..f589b88cdfa 100644 --- a/buildscripts/cost_model/ce_generate_data.py +++ b/buildscripts/cost_model/ce_generate_data.py @@ -52,10 +52,15 @@ class CollectionTemplateEncoder(json.JSONEncoder): if isinstance(o, CollectionTemplate): collections = [] for card in o.cardinalities: - name = f'{o.name}_{card}' + name = f"{o.name}_{card}" collections.append( - dict(collectionName=name, fields=o.fields, compoundIndexes=o.compound_indexes, - cardinality=card)) + dict( + collectionName=name, + fields=o.fields, + compoundIndexes=o.compound_indexes, + cardinality=card, + ) + ) return collections elif isinstance(o, FieldTemplate): return dict(fieldName=o.name, dataType=o.data_type, indexed=o.indexed) @@ -82,11 +87,11 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size): """Dump a collection into separate files each containing at most chunk_size documents.""" def open_chunk_file(chunk_id): - chunk_name = f'{coll_name}_{chunk_id}' - chunk_file_path = Path(dump_path) / f'{chunk_name}' - print(f'Writing chunk: {chunk_file_path}') - chunk_file = open(chunk_file_path, 'w', encoding="utf-8") - chunk_file.write('// This is a generated file.\n') + chunk_name = f"{coll_name}_{chunk_id}" + chunk_file_path = Path(dump_path) / f"{chunk_name}" + print(f"Writing chunk: {chunk_file_path}") + chunk_file = open(chunk_file_path, "w", encoding="utf-8") + chunk_file.write("// This is a generated file.\n") chunk_file.write(f'{chunk_name} = {{collName: "{coll_name}", collData: [\n') return chunk_file, "'" + chunk_name + "'" @@ -115,7 +120,7 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size): chunk_file.write(json.dumps(doc, cls=OidEncoder)) doc_pos += 1 - chunk_file.write(',') + chunk_file.write(",") chunk_file.write("\n") close_chunk_file(chunk_file) return chunk_names @@ -123,16 +128,17 @@ async def dump_collection(db, dump_path, database_name, coll_name, chunk_size): async def dump_collections_to_json(db, dump_path, database_name, collections): chunk_size = 100 # number of documents per chunk file - print(f'Dumping all collections into chunks of size {chunk_size}.') + print(f"Dumping all collections into chunks of size {chunk_size}.") all_chunk_names = [] for coll_name in collections: - coll_chunk_names = await dump_collection(db, dump_path, database_name, coll_name, - chunk_size) + coll_chunk_names = await dump_collection( + db, dump_path, database_name, coll_name, chunk_size + ) all_chunk_names.extend(coll_chunk_names) # Generate a JS file that loads all chunk files - load_file = open(Path(dump_path) / f'{database_name}.data', 'w') - load_file.write('// This is a generated file.\n') + load_file = open(Path(dump_path) / f"{database_name}.data", "w") + load_file.write("// This is a generated file.\n") # Create an array named 'chunkNames' with all chunk file names to be loaded. load_file.write(f'const chunkNames = [{",".join(all_chunk_names)}];') @@ -146,9 +152,11 @@ async def generate_histograms(coll_template, coll, dump_path): doc_count = await coll.count_documents({}) for field in coll_template.fields: field_data = [] - if re.match('^mixeddata_.*', field.name): + if re.match("^mixeddata_.*", field.name): continue - async for doc in coll.find({field.name: {"$exists": True}}, {"_id": 0, field.name: 1}): + async for doc in coll.find( + {field.name: {"$exists": True}}, {"_id": 0, field.name: 1} + ): field_val = doc[field.name] if isinstance(field_val, str): field_val = re.escape(field_val) @@ -157,10 +165,11 @@ async def generate_histograms(coll_template, coll, dump_path): continue field_data.append(field_val) if len(field_data) > 0: - fig_file_name = f'{dump_path}/{coll.name}_{field.name}.png' - print(f'Generating histogram {fig_file_name}') - hist = sns.displot(data=field_data, kind="hist", - bins=round(math.sqrt(doc_count))).figure + fig_file_name = f"{dump_path}/{coll.name}_{field.name}.png" + print(f"Generating histogram {fig_file_name}") + hist = sns.displot( + data=field_data, kind="hist", bins=round(math.sqrt(doc_count)) + ).figure hist.savefig(fig_file_name) plt.close(hist) @@ -173,7 +182,6 @@ async def main(): # 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally # from the dump on creating and stores data optionally to the dump on closing. with DatabaseInstance(database_config) as database_instance: - # 2. Generate random data and populate collections with it. old_db_collections = await database_instance.database.list_collection_names() for coll_name in old_db_collections: @@ -188,27 +196,40 @@ async def main(): # TODO: This is an alternative way to export the data. It is better than what is implemented, # but cannot be used until we find a way to call 'mongoimport' from the corresponding JS test. # - #for coll_name in db_collections: + # for coll_name in db_collections: # subprocess.run([ # 'mongoexport', f'--db={database_config.database_name}', f'--collection={coll_name}', # f'--out={coll_name}.dat' # ], cwd=database_config.dump_path, check=True) - await dump_collections_to_json(database_instance.database, database_config.dump_path, - database_config.database_name, db_collections) + await dump_collections_to_json( + database_instance.database, + database_config.dump_path, + database_config.database_name, + db_collections, + ) # 4. Export the collection templates used to create the test collections into JSON file - with open(Path(database_config.dump_path) / f'{database_config.database_name}.schema', - "w") as metadata_file: + with open( + Path(database_config.dump_path) / f"{database_config.database_name}.schema", + "w", + ) as metadata_file: collections = [] for coll_template in data_generator_config.collection_templates: for card in coll_template.cardinalities: - name = f'{coll_template.name}_{card}' + name = f"{coll_template.name}_{card}" collections.append( - dict(collectionName=name, fields=coll_template.fields, - compound_indexes=coll_template.compound_indexes, cardinality=card)) + dict( + collectionName=name, + fields=coll_template.fields, + compound_indexes=coll_template.compound_indexes, + cardinality=card, + ) + ) # Uncomment this to generate histograms in PNG format # await generate_histograms(coll_template, database_instance.database[name], database_config.dump_path) - json_metadata = json.dumps(collections, indent=4, cls=CollectionTemplateEncoder) + json_metadata = json.dumps( + collections, indent=4, cls=CollectionTemplateEncoder + ) metadata_file.write("// This is a generated file.\nconst dbMetadata = ") metadata_file.write(json_metadata) metadata_file.write(";") @@ -216,7 +237,7 @@ async def main(): print("DONE!") -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) asyncio.run(main()) diff --git a/buildscripts/cost_model/common.py b/buildscripts/cost_model/common.py index 340fafc5a59..8820715e7ff 100644 --- a/buildscripts/cost_model/common.py +++ b/buildscripts/cost_model/common.py @@ -39,7 +39,7 @@ def timer_decorator(func): t0 = time.perf_counter() result = func(*args, **kwargs) t1 = time.perf_counter() - print(f'{func.__name__} took {t1-t0}s.') + print(f"{func.__name__} took {t1-t0}s.") return result return wrapper diff --git a/buildscripts/cost_model/cost_estimator.py b/buildscripts/cost_model/cost_estimator.py index 26e96735c11..e087dc32bba 100644 --- a/buildscripts/cost_model/cost_estimator.py +++ b/buildscripts/cost_model/cost_estimator.py @@ -68,8 +68,9 @@ class LinearModel: # pylint: disable=invalid-name -def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float, - trace: bool = False) -> LinearModel: +def estimate( + fit, X: np.ndarray, y: np.ndarray, test_size: float, trace: bool = False +) -> LinearModel: """Estimate cost model parameters.""" if len(X) == 0: @@ -80,7 +81,7 @@ def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float, X_training, X_test, y_training, y_test = train_test_split(X, y, test_size=test_size) if trace: - print(f'Training size: {len(X_training)}, test size: {len(X_test)}') + print(f"Training size: {len(X_training)}, test size: {len(X_test)}") print(X_training) print(y_training) @@ -96,5 +97,11 @@ def estimate(fit, X: np.ndarray, y: np.ndarray, test_size: float, evs = explained_variance_score(y_test, y_predict) corrcoef = np.corrcoef(np.transpose(X[:, 1:]), y) - return LinearModel(coef=coef[1:], intercept=coef[0], mse=mse, r2=r2, evs=evs, - corrcoef=corrcoef[0, 1:]) + return LinearModel( + coef=coef[1:], + intercept=coef[0], + mse=mse, + r2=r2, + evs=evs, + corrcoef=corrcoef[0, 1:], + ) diff --git a/buildscripts/cost_model/data_generator.py b/buildscripts/cost_model/data_generator.py index fd04c0e6c9c..ca2167b5fc0 100644 --- a/buildscripts/cost_model/data_generator.py +++ b/buildscripts/cost_model/data_generator.py @@ -41,7 +41,7 @@ from motor.motor_asyncio import AsyncIOMotorCollection from pymongo import IndexModel from random_generator import DataType, RandomDistribution -__all__ = ['DataGenerator'] +__all__ = ["DataGenerator"] @dataclass @@ -96,52 +96,78 @@ class DataGenerator: coll = self.database.database[coll_info.name] if self.config.write_mode == WriteMode.REPLACE: await coll.drop() - tasks.append(asyncio.create_task(self._populate_collection(coll, coll_info))) + tasks.append( + asyncio.create_task(self._populate_collection(coll, coll_info)) + ) if self.config.create_indexes: tasks.append( - asyncio.create_task(create_single_field_indexes(coll, coll_info.fields))) - tasks.append(asyncio.create_task(create_compound_indexes(coll, coll_info))) + asyncio.create_task( + create_single_field_indexes(coll, coll_info.fields) + ) + ) + tasks.append( + asyncio.create_task(create_compound_indexes(coll, coll_info)) + ) for task in tasks: await task t1 = time.time() - print(f'\npopulate Collections took {t1-t0} s.') + print(f"\npopulate Collections took {t1-t0} s.") def _generate_collection_infos(self): for coll_template in self.config.collection_templates: fields = [ - FieldInfo(name=ft.name, type=ft.data_type, distribution=ft.distribution, - indexed=ft.indexed) for ft in coll_template.fields + FieldInfo( + name=ft.name, + type=ft.data_type, + distribution=ft.distribution, + indexed=ft.indexed, + ) + for ft in coll_template.fields ] for doc_count in coll_template.cardinalities: - name = f'{coll_template.name}' + name = f"{coll_template.name}" if self.config.collection_name_with_card is True: - name = f'{coll_template.name}_{doc_count}' - yield CollectionInfo(name=name, fields=fields, documents_count=doc_count, - compound_indexes=coll_template.compound_indexes) + name = f"{coll_template.name}_{doc_count}" + yield CollectionInfo( + name=name, + fields=fields, + documents_count=doc_count, + compound_indexes=coll_template.compound_indexes, + ) - async def _populate_collection(self, coll: AsyncIOMotorCollection, - coll_info: CollectionInfo) -> None: - print(f'\nGenerating ${coll_info.name} ...') + async def _populate_collection( + self, coll: AsyncIOMotorCollection, coll_info: CollectionInfo + ) -> None: + print(f"\nGenerating ${coll_info.name} ...") batch_size = self.config.batch_size tasks = [] for _ in range(coll_info.documents_count // batch_size): - tasks.append(asyncio.create_task(populate_batch(coll, batch_size, coll_info.fields))) + tasks.append( + asyncio.create_task(populate_batch(coll, batch_size, coll_info.fields)) + ) if coll_info.documents_count % batch_size > 0: tasks.append( asyncio.create_task( - populate_batch(coll, coll_info.documents_count % batch_size, coll_info.fields))) + populate_batch( + coll, coll_info.documents_count % batch_size, coll_info.fields + ) + ) + ) for task in tasks: await task -async def populate_batch(coll: AsyncIOMotorCollection, documents_count: int, - fields: Sequence[FieldInfo]) -> None: +async def populate_batch( + coll: AsyncIOMotorCollection, documents_count: int, fields: Sequence[FieldInfo] +) -> None: """Generate collection data and write it to the collection.""" - await coll.insert_many(generate_collection_data(documents_count, fields), ordered=False) + await coll.insert_many( + generate_collection_data(documents_count, fields), ordered=False + ) def generate_collection_data(documents_count: int, fields: Sequence[FieldInfo]): @@ -149,30 +175,43 @@ def generate_collection_data(documents_count: int, fields: Sequence[FieldInfo]): documents = [{} for _ in range(documents_count)] for field in fields: - for field_index, field_data in enumerate(field.distribution.generate(documents_count)): + for field_index, field_data in enumerate( + field.distribution.generate(documents_count) + ): documents[field_index][field.name] = field_data return documents -async def create_single_field_indexes(coll: AsyncIOMotorCollection, - fields: Sequence[FieldInfo]) -> None: +async def create_single_field_indexes( + coll: AsyncIOMotorCollection, fields: Sequence[FieldInfo] +) -> None: """Create single-fields indexes on the given collection.""" - indexes = [IndexModel([(field.name, pymongo.ASCENDING)]) for field in fields if field.indexed] + indexes = [ + IndexModel([(field.name, pymongo.ASCENDING)]) + for field in fields + if field.indexed + ] if len(indexes) > 0: await coll.create_indexes(indexes) - print(f'create_single_field_indexes done. {[index.document for index in indexes]}') + print( + f"create_single_field_indexes done. {[index.document for index in indexes]}" + ) -async def create_compound_indexes(coll: AsyncIOMotorCollection, coll_info: CollectionInfo) -> None: +async def create_compound_indexes( + coll: AsyncIOMotorCollection, coll_info: CollectionInfo +) -> None: """Create a coumpound indexes on the given collection.""" indexes_spec = [] index_specs = [] for compound_index in coll_info.compound_indexes: - index_spec = IndexModel([(field, pymongo.ASCENDING) for field in compound_index]) + index_spec = IndexModel( + [(field, pymongo.ASCENDING) for field in compound_index] + ) indexes_spec.append(index_spec) index_specs.append([(field, pymongo.ASCENDING) for field in compound_index]) if len(indexes_spec) > 0: await coll.create_indexes(indexes_spec) - print(f'create_compound_indexes done. {index_specs}') + print(f"create_compound_indexes done. {index_specs}") diff --git a/buildscripts/cost_model/database_instance.py b/buildscripts/cost_model/database_instance.py index 5f833378602..e49cb79cca8 100644 --- a/buildscripts/cost_model/database_instance.py +++ b/buildscripts/cost_model/database_instance.py @@ -36,9 +36,9 @@ from typing import Any, Mapping, NewType, Sequence from config import DatabaseConfig, RestoreMode from motor.motor_asyncio import AsyncIOMotorClient -__all__ = ['DatabaseInstance', 'Pipeline'] +__all__ = ["DatabaseInstance", "Pipeline"] """MongoDB Aggregate's Pipeline""" -Pipeline = NewType('Pipeline', Sequence[Mapping[str, Any]]) +Pipeline = NewType("Pipeline", Sequence[Mapping[str, Any]]) class DatabaseInstance: @@ -52,8 +52,9 @@ class DatabaseInstance: def __enter__(self): if self.config.restore_from_dump == RestoreMode.ALWAYS or ( - self.config.restore_from_dump == RestoreMode.ONLY_NEW - and self.config.database_name not in self.client.list_database_names()): + self.config.restore_from_dump == RestoreMode.ONLY_NEW + and self.config.database_name not in self.client.list_database_names() + ): self.restore() return self @@ -68,74 +69,94 @@ class DatabaseInstance: def restore(self): """Restore the database from the 'self.dump_directory'.""" - subprocess.run(['mongorestore', '--nsInclude', f'{self.config.database_name}.*', '--drop'], - shell=True, check=True, cwd=self.config.dump_path) + subprocess.run( + ["mongorestore", "--nsInclude", f"{self.config.database_name}.*", "--drop"], + shell=True, + check=True, + cwd=self.config.dump_path, + ) def dump(self): """Dump the database into 'self.dump_directory'.""" - subprocess.run(['mongodump', '--db', self.config.database_name], cwd=self.config.dump_path, - check=True) + subprocess.run( + ["mongodump", "--db", self.config.database_name], + cwd=self.config.dump_path, + check=True, + ) async def set_parameter(self, name: str, value: any) -> None: """Set MongoDB Parameter.""" - await self.client.admin.command({'setParameter': 1, name: value}) + await self.client.admin.command({"setParameter": 1, name: value}) async def get_parameter(self, name: str) -> any: - return (await self.client.admin.command({'getParameter': 1, name: 1}))[name] + return (await self.client.admin.command({"getParameter": 1, name: 1}))[name] async def enable_sbe(self, state: bool) -> None: """Enable new query execution engine. Throw pymongo.errors.OperationFailure in case of failure.""" - await self.set_parameter('internalQueryFrameworkControl', - 'trySbeEngine' if state else 'forceClassicEngine') + await self.set_parameter( + "internalQueryFrameworkControl", + "trySbeEngine" if state else "forceClassicEngine", + ) async def enable_cascades(self, state: bool) -> None: """Enable new query optimizer. Requires featureFlagCommonQueryFramework set to True.""" # Set FeatureCompatibilityVersion compatible with featureFlagCommonQueryFramework. - version = (await self.client.admin.command( - {'getParameter': 1, - 'featureFlagCommonQueryFramework': 1}))['featureFlagCommonQueryFramework']['version'] + version = ( + await self.client.admin.command( + {"getParameter": 1, "featureFlagCommonQueryFramework": 1} + ) + )["featureFlagCommonQueryFramework"]["version"] await self.client.admin.command( - {'setFeatureCompatibilityVersion': version, 'confirm': True}) + {"setFeatureCompatibilityVersion": version, "confirm": True} + ) await self.client.admin.command( - {'configureFailPoint': 'enableExplainInBonsai', 'mode': 'alwaysOn'}) - await self.set_parameter('internalQueryFrameworkControl', - 'forceBonsai' if state else 'trySbeEngine') + {"configureFailPoint": "enableExplainInBonsai", "mode": "alwaysOn"} + ) + await self.set_parameter( + "internalQueryFrameworkControl", "forceBonsai" if state else "trySbeEngine" + ) async def explain(self, collection_name: str, pipeline: Pipeline) -> dict[str, any]: """Return explain for the given pipeline.""" return await self.database.command( - 'explain', {'aggregate': collection_name, 'pipeline': pipeline, 'cursor': {}}, - verbosity='executionStats') + "explain", + {"aggregate": collection_name, "pipeline": pipeline, "cursor": {}}, + verbosity="executionStats", + ) async def hide_index(self, collection_name: str, index_name: str) -> None: """Hide the given index from the query optimizer.""" await self.database.command( - {'collMod': collection_name, 'index': {'name': index_name, 'hidden': True}}) + {"collMod": collection_name, "index": {"name": index_name, "hidden": True}} + ) async def unhide_index(self, collection_name: str, index_name: str) -> None: """Make the given index visible for the query optimizer.""" await self.database.command( - {'collMod': collection_name, 'index': {'name': index_name, 'hidden': False}}) + {"collMod": collection_name, "index": {"name": index_name, "hidden": False}} + ) async def hide_all_indexes(self, collection_name: str) -> None: """Hide all indexes of the given collection from the query optimizer.""" for index in self.database[collection_name].list_indexes(): - if index['name'] != '_id_': - await self.hide_index(collection_name, index['name']) + if index["name"] != "_id_": + await self.hide_index(collection_name, index["name"]) async def unhide_all_indexes(self, collection_name: str) -> None: """Make all indexes of the given collection visible fpr the query optimizer.""" for index in self.database[collection_name].list_indexes(): - if index['name'] != '_id_': - await self.unhide_index(collection_name, index['name']) + if index["name"] != "_id_": + await self.unhide_index(collection_name, index["name"]) async def drop_collection(self, collection_name: str) -> None: """Drop collection.""" await self.database[collection_name].drop() - async def insert_many(self, collection_name: str, docs: Sequence[Mapping[str, any]]) -> None: + async def insert_many( + self, collection_name: str, docs: Sequence[Mapping[str, any]] + ) -> None: """Insert documents into the collection with the given name.""" if len(docs) > 0: await self.database[collection_name].insert_many(docs, ordered=False) @@ -146,12 +167,12 @@ class DatabaseInstance: async def get_stats(self, collection_name: str): """Get collection statistics.""" - return await self.database.command('collstats', collection_name) + return await self.database.command("collstats", collection_name) async def get_average_document_size(self, collection_name: str) -> float: """Get average document size for the given collection.""" stats = await self.get_stats(collection_name) - avg_size = stats.get('avgObjSize') + avg_size = stats.get("avgObjSize") return avg_size if avg_size is not None else 0 @@ -177,7 +198,9 @@ class DatabaseParameter: if self.original_value is not None: await self.set(self.original_value) else: - raise ValueError(f'The parameter "{self.parameter_name}" has not been remembered.') + raise ValueError( + f'The parameter "{self.parameter_name}" has not been remembered.' + ) @asynccontextmanager diff --git a/buildscripts/cost_model/end_to_end.py b/buildscripts/cost_model/end_to_end.py index fc057b3fd3c..83c3b04daa4 100644 --- a/buildscripts/cost_model/end_to_end.py +++ b/buildscripts/cost_model/end_to_end.py @@ -65,20 +65,20 @@ class CostEstimator: self.cost_model = cost_model self.estimators = { - 'PhysicalScan': self.physical_scan, - 'IndexScan': self.index_scan, - 'Seek': self.seek, - 'Filter': self.filter, - 'Evaluation': self.evaluation, - 'GroupBy': self.group_by, - 'Unwind': self.unwind, - 'NestedLoopJoin': self.nested_loop_join, - 'HashJoin': self.hash_join, - 'MergeJoin': self.merge_join, - 'Unique': self.unique, - 'Union': self.union, - 'LimitSkip': self.limit_skip, - 'Root': self.root, + "PhysicalScan": self.physical_scan, + "IndexScan": self.index_scan, + "Seek": self.seek, + "Filter": self.filter, + "Evaluation": self.evaluation, + "GroupBy": self.group_by, + "Unwind": self.unwind, + "NestedLoopJoin": self.nested_loop_join, + "HashJoin": self.hash_join, + "MergeJoin": self.merge_join, + "Unique": self.unique, + "Union": self.union, + "LimitSkip": self.limit_skip, + "Root": self.root, } def estimate(self, abt_node_name: str, cardinality: int) -> float: @@ -88,55 +88,93 @@ class CostEstimator: def physical_scan(self, cardinality: int) -> float: """Estinamate PhysicalScan ABT node.""" - return self.cost_model.scan_startup_cost + cardinality * self.cost_model.scan_incremental_cost + return ( + self.cost_model.scan_startup_cost + + cardinality * self.cost_model.scan_incremental_cost + ) def index_scan(self, cardinality: int) -> float: """Estinamate IndexScan ABT node.""" - return self.cost_model.index_scan_startup_cost + cardinality * self.cost_model.index_scan_incremental_cost + return ( + self.cost_model.index_scan_startup_cost + + cardinality * self.cost_model.index_scan_incremental_cost + ) def seek(self, cardinality: int) -> float: """Estinamate Seek ABT node.""" - return self.cost_model.seek_startup_cost + cardinality * self.cost_model.seek_cost + return ( + self.cost_model.seek_startup_cost + cardinality * self.cost_model.seek_cost + ) def filter(self, cardinality: int) -> float: """Estinamate Filter ABT node.""" - return self.cost_model.filter_startup_cost + cardinality * self.cost_model.filter_incremental_cost + return ( + self.cost_model.filter_startup_cost + + cardinality * self.cost_model.filter_incremental_cost + ) def evaluation(self, cardinality: int) -> float: """Estinamate Evaluation ABT node.""" - return self.cost_model.eval_startup_cost + cardinality * self.cost_model.eval_incremental_cost + return ( + self.cost_model.eval_startup_cost + + cardinality * self.cost_model.eval_incremental_cost + ) def group_by(self, cardinality: int) -> float: """Estinamate GroupBy ABT node.""" - return self.cost_model.group_by_startup_cost + cardinality * self.cost_model.group_by_incremental_cost + return ( + self.cost_model.group_by_startup_cost + + cardinality * self.cost_model.group_by_incremental_cost + ) def unwind(self, cardinality: int) -> float: """Estinamate Unwind ABT node.""" - return self.cost_model.unwind_startup_cost + cardinality * self.cost_model.unwind_incremental_cost + return ( + self.cost_model.unwind_startup_cost + + cardinality * self.cost_model.unwind_incremental_cost + ) def nested_loop_join(self, cardinality: int) -> float: """Estinamate NestedLoopJoin ABT node.""" - return self.cost_model.nested_loop_join_startup_cost + cardinality * self.cost_model.nested_loop_join_incremental_cost + return ( + self.cost_model.nested_loop_join_startup_cost + + cardinality * self.cost_model.nested_loop_join_incremental_cost + ) def hash_join(self, cardinality: int) -> float: """Estinamate HashJoin ABT node.""" - return self.cost_model.hash_join_startup_cost + cardinality * self.cost_model.hash_join_incremental_cost + return ( + self.cost_model.hash_join_startup_cost + + cardinality * self.cost_model.hash_join_incremental_cost + ) def merge_join(self, cardinality: int) -> float: """Estinamate MergeJoin ABT node.""" - return self.cost_model.merge_join_startup_cost + cardinality * self.cost_model.merge_join_incremental_cost + return ( + self.cost_model.merge_join_startup_cost + + cardinality * self.cost_model.merge_join_incremental_cost + ) def unique(self, cardinality: int) -> float: """Estinamate Unique ABT node.""" - return self.cost_model.unique_startup_cost + cardinality * self.cost_model.unique_incremental_cost + return ( + self.cost_model.unique_startup_cost + + cardinality * self.cost_model.unique_incremental_cost + ) def union(self, cardinality: int) -> float: """Estinamate Union ABT node.""" - return self.cost_model.union_startup_cost + cardinality * self.cost_model.union_incremental_cost + return ( + self.cost_model.union_startup_cost + + cardinality * self.cost_model.union_incremental_cost + ) def limit_skip(self, cardinality: int) -> float: """Estinamate LimitSkip ABT node.""" - return self.cost_model.limit_skip_startup_cost + cardinality * self.cost_model.limit_skip_incremental_cost + return ( + self.cost_model.limit_skip_startup_cost + + cardinality * self.cost_model.limit_skip_incremental_cost + ) def root(self, _: int) -> float: """Root ABT node is always 0.""" @@ -153,13 +191,23 @@ class AbtCostEstimator: def __init__(self, estimate_node: Callable[[str, int], float]): self.estimate_node = estimate_node - def estimate(self, abt: pt.Node, sbe: et.Node, - estimations: Sequence[Tuple[str, ExecutionStats, float]], level=0): + def estimate( + self, + abt: pt.Node, + sbe: et.Node, + estimations: Sequence[Tuple[str, ExecutionStats, float]], + level=0, + ): stats = get_excution_stats(sbe, abt.plan_node_id) local_cost = self.estimate_node(abt.node_type, stats.n_processed) estimations.append((abt.node_type, stats, local_cost)) - child_cost = sum((self.estimate(child, sbe, estimations, level + 1) - for child in abt.children), start=0.0) + child_cost = sum( + ( + self.estimate(child, sbe, estimations, level + 1) + for child in abt.children + ), + start=0.0, + ) return local_cost + child_cost @@ -167,16 +215,29 @@ class AbtCostEstimator: class EndToEndStatisticsRow: """Represents a row with descriptive statistics of one query execution.""" - def __init__(self, pipeline: str = None, abt_type: str = None, abt_type_id: int = 0, - execution_time: float = 0.0, estimated_cost: float = 0.0, n_documents: int = 0): - self.pipeline = pipeline if pipeline is not None else '' - self.abt_type = abt_type if abt_type is not None else '' + def __init__( + self, + pipeline: str = None, + abt_type: str = None, + abt_type_id: int = 0, + execution_time: float = 0.0, + estimated_cost: float = 0.0, + n_documents: int = 0, + ): + self.pipeline = pipeline if pipeline is not None else "" + self.abt_type = abt_type if abt_type is not None else "" self.abt_type_id = abt_type_id self.execution_time = execution_time self.estimated_cost = estimated_cost self.estimation_error = execution_time - estimated_cost - self.estimation_error_per_doc = self.estimation_error / n_documents if n_documents != 0 else 0 - self.relative_error = self.estimation_error / self.execution_time if self.execution_time != 0 else 0 + self.estimation_error_per_doc = ( + self.estimation_error / n_documents if n_documents != 0 else 0 + ) + self.relative_error = ( + self.estimation_error / self.execution_time + if self.execution_time != 0 + else 0 + ) pipeline: str abt_type: str @@ -189,11 +250,20 @@ class EndToEndStatisticsRow: def make_config(): - def create_end2end_collection_template(name: str, - cardinality: int) -> config.CollectionTemplate: + def create_end2end_collection_template( + name: str, cardinality: int + ) -> config.CollectionTemplate: values = [ - 'iqtbr5b5is', 'vt5s3tf8o6', 'b0rgm58qsn', '9m59if353m', 'biw2l9ok17', 'b9ct0ue14d', - 'oxj0vxjsti', 'f3k8w9vb49', 'ec7v82k6nk', 'f49ufwaqx7' + "iqtbr5b5is", + "vt5s3tf8o6", + "b0rgm58qsn", + "9m59if353m", + "biw2l9ok17", + "b9ct0ue14d", + "oxj0vxjsti", + "f3k8w9vb49", + "ec7v82k6nk", + "f49ufwaqx7", ] start_weight = 30 @@ -208,108 +278,198 @@ def make_config(): distr = RandomDistribution.choice(values, weights) return config.CollectionTemplate( - name=name, fields=[ - config.FieldTemplate(name="indexed_choice", data_type=config.DataType.STRING, - distribution=distr, indexed=True), - config.FieldTemplate(name="int1", data_type=config.DataType.INTEGER, - distribution=distributions["int_normal"], indexed=True), - config.FieldTemplate(name="non_indexed_choice", data_type=config.DataType.STRING, - distribution=distributions['string_choice'], indexed=False), - config.FieldTemplate(name="uniform1", data_type=config.DataType.STRING, - distribution=distributions["string_uniform"], indexed=False), - config.FieldTemplate(name="int2", data_type=config.DataType.INTEGER, - distribution=distributions["int_normal"], indexed=True), - config.FieldTemplate(name="choice2", data_type=config.DataType.STRING, - distribution=distributions["string_choice"], indexed=False), - config.FieldTemplate(name="mixed2", data_type=config.DataType.STRING, - distribution=distributions["string_mixed"], indexed=False), - ], compound_indexes=[], cardinalities=[cardinality]) + name=name, + fields=[ + config.FieldTemplate( + name="indexed_choice", + data_type=config.DataType.STRING, + distribution=distr, + indexed=True, + ), + config.FieldTemplate( + name="int1", + data_type=config.DataType.INTEGER, + distribution=distributions["int_normal"], + indexed=True, + ), + config.FieldTemplate( + name="non_indexed_choice", + data_type=config.DataType.STRING, + distribution=distributions["string_choice"], + indexed=False, + ), + config.FieldTemplate( + name="uniform1", + data_type=config.DataType.STRING, + distribution=distributions["string_uniform"], + indexed=False, + ), + config.FieldTemplate( + name="int2", + data_type=config.DataType.INTEGER, + distribution=distributions["int_normal"], + indexed=True, + ), + config.FieldTemplate( + name="choice2", + data_type=config.DataType.STRING, + distribution=distributions["string_choice"], + indexed=False, + ), + config.FieldTemplate( + name="mixed2", + data_type=config.DataType.STRING, + distribution=distributions["string_mixed"], + indexed=False, + ), + ], + compound_indexes=[], + cardinalities=[cardinality], + ) - col_end2end = create_end2end_collection_template('end2end', 2000000) + col_end2end = create_end2end_collection_template("end2end", 2000000) data_generator_config = config.DataGeneratorConfig( - enabled=True, create_indexes=True, batch_size=10000, collection_templates=[col_end2end], - write_mode=config.WriteMode.REPLACE, collection_name_with_card=True) + enabled=True, + create_indexes=True, + batch_size=10000, + collection_templates=[col_end2end], + write_mode=config.WriteMode.REPLACE, + collection_name_with_card=True, + ) workload_execution_config = config.WorkloadExecutionConfig( - enabled=True, output_collection_name='end2endData', write_mode=config.WriteMode.APPEND, - warmup_runs=3, runs=30) + enabled=True, + output_collection_name="end2endData", + write_mode=config.WriteMode.APPEND, + warmup_runs=3, + runs=30, + ) # The cost model to test. cost_model = CostModelCoefficients( - scan_incremental_cost=422.31145989, scan_startup_cost=6175.527218993269, - index_scan_incremental_cost=403.68075869, index_scan_startup_cost=14054.983953111061, - seek_cost=1223.35513997, seek_startup_cost=7488.662376624863, - filter_incremental_cost=83.7274685, filter_startup_cost=1461.3148783443378, - eval_incremental_cost=430.6176946, eval_startup_cost=1103.4048573163343, - group_by_incremental_cost=413.07932374, group_by_startup_cost=1199.8878012735659, - unwind_incremental_cost=586.57200195, unwind_startup_cost=1.0, + scan_incremental_cost=422.31145989, + scan_startup_cost=6175.527218993269, + index_scan_incremental_cost=403.68075869, + index_scan_startup_cost=14054.983953111061, + seek_cost=1223.35513997, + seek_startup_cost=7488.662376624863, + filter_incremental_cost=83.7274685, + filter_startup_cost=1461.3148783443378, + eval_incremental_cost=430.6176946, + eval_startup_cost=1103.4048573163343, + group_by_incremental_cost=413.07932374, + group_by_startup_cost=1199.8878012735659, + unwind_incremental_cost=586.57200195, + unwind_startup_cost=1.0, nested_loop_join_incremental_cost=161.62301944, - nested_loop_join_startup_cost=402.8455479458652, hash_join_incremental_cost=250.61365634, - hash_join_startup_cost=1.0, merge_join_incremental_cost=111.23423304, - merge_join_startup_cost=1517.7970800404169, unique_incremental_cost=269.71368614, - unique_startup_cost=1.0, union_incremental_cost=111.94945268, - union_startup_cost=69.88096657391543, limit_skip_incremental_cost=62.42111111, - limit_skip_startup_cost=655.1342592592522) + nested_loop_join_startup_cost=402.8455479458652, + hash_join_incremental_cost=250.61365634, + hash_join_startup_cost=1.0, + merge_join_incremental_cost=111.23423304, + merge_join_startup_cost=1517.7970800404169, + unique_incremental_cost=269.71368614, + unique_startup_cost=1.0, + union_incremental_cost=111.94945268, + union_startup_cost=69.88096657391543, + limit_skip_incremental_cost=62.42111111, + limit_skip_startup_cost=655.1342592592522, + ) cost_estimator = CostEstimator(cost_model) processor_config = config.End2EndProcessorConfig( - enabled=True, estimator=cost_estimator.estimate, - input_collection_name=workload_execution_config.output_collection_name) + enabled=True, + estimator=cost_estimator.estimate, + input_collection_name=workload_execution_config.output_collection_name, + ) return config.EntToEndTestingConfig( - database=main_config.database, data_generator=data_generator_config, - workload_execution=workload_execution_config, processor=processor_config, - result_csv_filepath="end2end.csv") + database=main_config.database, + data_generator=data_generator_config, + workload_execution=workload_execution_config, + processor=processor_config, + result_csv_filepath="end2end.csv", + ) -async def execute_queries(database: DatabaseInstance, we_config: config.WorkloadExecutionConfig, - collections: Sequence[CollectionInfo]): - collection = [ci for ci in collections if ci.name.startswith('end2end')][0] +async def execute_queries( + database: DatabaseInstance, + we_config: config.WorkloadExecutionConfig, + collections: Sequence[CollectionInfo], +): + collection = [ci for ci in collections if ci.name.startswith("end2end")][0] requests = [] limits = [5, 10, 15, 20, 25, 50] skips = [15, 10, 5] - for field in [f for f in collection.fields if f.name == 'indexed_choice']: + for field in [f for f in collection.fields if f.name == "indexed_choice"]: for val in field.distribution.get_values(): - if val.startswith('_'): + if val.startswith("_"): continue limit = limits[len(requests) % len(limits)] skip = skips[len(requests) % len(skips)] requests.append( - Query(pipeline=[{'$match': {field.name: val}}, {"$skip": skip}, {"$limit": limit}, - {"$project": {"int1": 1}}])) + Query( + pipeline=[ + {"$match": {field.name: val}}, + {"$skip": skip}, + {"$limit": limit}, + {"$project": {"int1": 1}}, + ] + ) + ) - for field in [f for f in collection.fields if f.name == 'non_indexed_choice']: - for val in ['chisquare', 'hi']: + for field in [f for f in collection.fields if f.name == "non_indexed_choice"]: + for val in ["chisquare", "hi"]: limit = limits[len(requests) % len(limits)] skip = skips[len(requests) % len(skips)] requests.append( - Query(pipeline=[{'$match': {field.name: val}}, {"$skip": skip}, {"$limit": limit}, - {"$project": {"int1": 1}}])) + Query( + pipeline=[ + {"$match": {field.name: val}}, + {"$skip": skip}, + {"$limit": limit}, + {"$project": {"int1": 1}}, + ] + ) + ) for i in range(100, 1000, 250): limit = limits[len(requests) % len(limits)] skip = skips[len(requests) % len(skips)] requests.append( - Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 - - i}}, {"$skip": skip}, {"$limit": limit}])) + Query( + pipeline=[ + {"$match": {"in1": i, "in2": 1000 - i}}, + {"$skip": skip}, + {"$limit": limit}, + ] + ) + ) requests.append( - Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}, {"$skip": skip}, - {"$limit": limit}])) + Query( + pipeline=[ + {"$match": {"in1": {"$lte": i}, "in2": 1000 - i}}, + {"$skip": skip}, + {"$limit": limit}, + ] + ) + ) await workload_execution.execute(database, we_config, [collection], requests) -async def execute_index_intersect_queries(database: DatabaseInstance, - we_config: config.WorkloadExecutionConfig, - collections: Sequence[CollectionInfo]): - collection = [ci for ci in collections if ci.name.startswith('end2end')][0] +async def execute_index_intersect_queries( + database: DatabaseInstance, + we_config: config.WorkloadExecutionConfig, + collections: Sequence[CollectionInfo], +): + collection = [ci for ci in collections if ci.name.startswith("end2end")][0] requests = [] @@ -321,19 +481,36 @@ async def execute_index_intersect_queries(database: DatabaseInstance, skip = skips[len(requests) % len(skips)] requests.append( - Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 - - i}}, {"$skip": skip}, {"$limit": limit}])) + Query( + pipeline=[ + {"$match": {"in1": i, "in2": 1000 - i}}, + {"$skip": skip}, + {"$limit": limit}, + ] + ) + ) requests.append( - Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}, {"$skip": skip}, - {"$limit": limit}])) + Query( + pipeline=[ + {"$match": {"in1": {"$lte": i}, "in2": 1000 - i}}, + {"$skip": skip}, + {"$limit": limit}, + ] + ) + ) - async with get_database_parameter( - database, 'internalCostModelCoefficients') as cost_model_param, get_database_parameter( - database, 'internalCascadesOptimizerDisableMergeJoinRIDIntersect' - ) as merge_join_param, get_database_parameter( - database, - 'internalCascadesOptimizerDisableHashJoinRIDIntersect') as hash_join_param: + async with ( + get_database_parameter( + database, "internalCostModelCoefficients" + ) as cost_model_param, + get_database_parameter( + database, "internalCascadesOptimizerDisableMergeJoinRIDIntersect" + ) as merge_join_param, + get_database_parameter( + database, "internalCascadesOptimizerDisableHashJoinRIDIntersect" + ) as hash_join_param, + ): await cost_model_param.set('{"filterIncrementalCost": 10000.0}') await merge_join_param.set(False) await hash_join_param.set(False) @@ -350,7 +527,7 @@ def extract_abt_nodes(df: pd.DataFrame, estimate_cost) -> pd.DataFrame: """Extract ABT Nodes and execution statistics from calibration DataFrame.""" def extract(df_seq): - es_dict = exp.extract_execution_stats(df_seq['sbe'], df_seq['abt'], []) + es_dict = exp.extract_execution_stats(df_seq["sbe"], df_seq["abt"], []) rows = [] for abt_type, es in es_dict.items(): @@ -359,48 +536,70 @@ def extract_abt_nodes(df: pd.DataFrame, estimate_cost) -> pd.DataFrame: continue estimated_cost = estimate_cost(abt_type, stat.n_processed) rows.append( - EndToEndStatisticsRow(abt_type=abt_type, execution_time=stat.execution_time, - estimated_cost=estimated_cost, - n_documents=stat.n_processed)) + EndToEndStatisticsRow( + abt_type=abt_type, + execution_time=stat.execution_time, + estimated_cost=estimated_cost, + n_documents=stat.n_processed, + ) + ) return rows return pd.DataFrame(list(df.apply(extract, axis=1).explode())) -def build_abt_nodes_report(df: pd.DataFrame, processor_config: config.End2EndProcessorConfig): +def build_abt_nodes_report( + df: pd.DataFrame, processor_config: config.End2EndProcessorConfig +): return extract_abt_nodes(df, processor_config.estimator) -def build_queries_report(df: pd.DataFrame, processor_config: config.End2EndProcessorConfig): +def build_queries_report( + df: pd.DataFrame, processor_config: config.End2EndProcessorConfig +): abt_estimator = AbtCostEstimator(processor_config.estimator) def calculate_cost(row): rows = [] estimations = [] - total_estimated_cost = abt_estimator.estimate(row['abt'], row['sbe'], estimations) + total_estimated_cost = abt_estimator.estimate( + row["abt"], row["sbe"], estimations + ) local_id = 0 rows.append( - EndToEndStatisticsRow(pipeline=row['pipeline'], abt_type_id=local_id, - execution_time=row['total_execution_time'], - estimated_cost=total_estimated_cost)) - for (abt_type, stats, local_cost) in estimations: + EndToEndStatisticsRow( + pipeline=row["pipeline"], + abt_type_id=local_id, + execution_time=row["total_execution_time"], + estimated_cost=total_estimated_cost, + ) + ) + for abt_type, stats, local_cost in estimations: local_id += 1 rows.append( - EndToEndStatisticsRow(pipeline=row['pipeline'], abt_type=abt_type, - abt_type_id=local_id, - execution_time=row['total_execution_time'], - estimated_cost=local_cost, n_documents=stats.n_processed)) + EndToEndStatisticsRow( + pipeline=row["pipeline"], + abt_type=abt_type, + abt_type_id=local_id, + execution_time=row["total_execution_time"], + estimated_cost=local_cost, + n_documents=stats.n_processed, + ) + ) return rows return pd.DataFrame(list(df.apply(calculate_cost, axis=1).explode())) -async def conduct_end2end(database: DatabaseInstance, - processor_config: config.End2EndProcessorConfig): +async def conduct_end2end( + database: DatabaseInstance, processor_config: config.End2EndProcessorConfig +): if not processor_config.enabled: return {} - df = await exp.load_calibration_data(database, processor_config.input_collection_name) + df = await exp.load_calibration_data( + database, processor_config.input_collection_name + ) noout_df = exp.remove_outliers(df, 0.0, 0.90) abt_report = build_abt_nodes_report(noout_df, processor_config) @@ -409,21 +608,26 @@ async def conduct_end2end(database: DatabaseInstance, report = pd.concat([abt_report, queries_report], axis=0) - group_columns = ['pipeline', 'abt_type', 'abt_type_id'] + group_columns = ["pipeline", "abt_type", "abt_type_id"] def calc_r2(group): - return r2_score(group['execution_time'], group['estimated_cost']) + return r2_score(group["execution_time"], group["estimated_cost"]) r2_scores = report.groupby(group_columns).apply(calc_r2).reset_index() - r2_scores.columns = [group_columns + ['r2'], [''] * (len(group_columns) + 1)] + r2_scores.columns = [group_columns + ["r2"], [""] * (len(group_columns) + 1)] - agg_stats = report.groupby(group_columns)[[ - 'execution_time', 'estimated_cost', 'estimation_error', 'estimation_error_per_doc', - 'relative_error' - ]].agg([np.mean, np.std, np.min, np.max]) + agg_stats = report.groupby(group_columns)[ + [ + "execution_time", + "estimated_cost", + "estimation_error", + "estimation_error_per_doc", + "relative_error", + ] + ].agg([np.mean, np.std, np.min, np.max]) report = pd.merge(r2_scores, agg_stats, on=group_columns) - del report['abt_type_id'] + del report["abt_type_id"] return report @@ -434,7 +638,6 @@ async def end2end(e2e_config: config.EntToEndTestingConfig): # 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally # from the dump on creating and stores data optionally to the dump on closing. with DatabaseInstance(e2e_config.database) as database: - # 2. Data generation (optional), generates random data and populates collections with it. generator = DataGenerator(database, e2e_config.data_generator) await generator.populate_collections() @@ -443,10 +646,12 @@ async def end2end(e2e_config: config.EntToEndTestingConfig): # It runs the pipelines and stores explains to the database. execution_query_functions = [execute_queries, execute_index_intersect_queries] for execute_query in execution_query_functions: - await execute_query(database, e2e_config.workload_execution, generator.collection_infos) + await execute_query( + database, e2e_config.workload_execution, generator.collection_infos + ) e2e_config.workload_execution.write_mode = config.WriteMode.APPEND - #4. Process end to end testing. Compare the estimated and actual costs and return results. + # 4. Process end to end testing. Compare the estimated and actual costs and return results. report = await conduct_end2end(database, e2e_config.processor) if e2e_config.result_csv_filepath is not None: report.to_csv(e2e_config.result_csv_filepath, index=False) @@ -457,7 +662,7 @@ async def main(): await end2end(e2e_config) -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: diff --git a/buildscripts/cost_model/execution_tree.py b/buildscripts/cost_model/execution_tree.py index 5ad55130c9d..fa7bc4c06eb 100644 --- a/buildscripts/cost_model/execution_tree.py +++ b/buildscripts/cost_model/execution_tree.py @@ -34,7 +34,7 @@ from typing import Optional import bson.json_util as json -__all__ = ['Node', 'build_execution_tree'] +__all__ = ["Node", "build_execution_tree"] @dataclass @@ -51,7 +51,9 @@ class Node: def get_execution_time(self): """Execution time of the SBE node without execuion time of its children.""" - return self.total_execution_time - sum(n.total_execution_time for n in self.children) + return self.total_execution_time - sum( + n.total_execution_time for n in self.children + ) def print(self, level=0): """Pretty print of the SBE tree.""" @@ -64,126 +66,152 @@ class Node: def build_execution_tree(execution_stats: dict[str, any]) -> Node: """Build SBE executioon tree from 'executionStats' field of query explain.""" - assert execution_stats['executionSuccess'] - return process_stage(execution_stats['executionStages']) + assert execution_stats["executionSuccess"] + return process_stage(execution_stats["executionStages"]) def process_stage(stage: dict[str, any]) -> Node: """Parse the given SBE stage.""" processors = { - 'filter': process_filter, - 'cfilter': process_filter, - 'traverse': process_traverse, - 'project': process_inner_node, - 'limit': process_inner_node, - 'ixscan_generic': process_seek, - 'scan': process_seek, - 'coscan': process_leaf_node, - 'nlj': process_nlj, - 'hj': process_hash_join_node, - 'mj': process_hash_join_node, - 'seek': process_seek, - 'ixseek': process_seek, - 'limitskip': process_inner_node, - 'group': process_inner_node, - 'union': process_union_node, - 'unique': process_unique_node, - 'unwind': process_unwind_node, - 'branch': process_branch_node, + "filter": process_filter, + "cfilter": process_filter, + "traverse": process_traverse, + "project": process_inner_node, + "limit": process_inner_node, + "ixscan_generic": process_seek, + "scan": process_seek, + "coscan": process_leaf_node, + "nlj": process_nlj, + "hj": process_hash_join_node, + "mj": process_hash_join_node, + "seek": process_seek, + "ixseek": process_seek, + "limitskip": process_inner_node, + "group": process_inner_node, + "union": process_union_node, + "unique": process_unique_node, + "unwind": process_unwind_node, + "branch": process_branch_node, } - processor = processors.get(stage['stage']) + processor = processors.get(stage["stage"]) if processor is None: print(json.dumps(stage, indent=4)) - raise ValueError(f'Unknown stage: {stage}') + raise ValueError(f"Unknown stage: {stage}") return processor(stage) def process_filter(stage: dict[str, any]) -> Node: """Process filter stage.""" - input_stage = process_stage(stage['inputStage']) - return Node(**get_common_fields(stage), n_processed=stage['numTested'], children=[input_stage]) + input_stage = process_stage(stage["inputStage"]) + return Node( + **get_common_fields(stage), + n_processed=stage["numTested"], + children=[input_stage], + ) def process_traverse(stage: dict[str, any]) -> Node: """Process traverse, not used by Bonsai.""" - outer_stage = process_stage(stage['outerStage']) - inner_stage = process_stage(stage['innerStage']) - return Node(**get_common_fields(stage), n_processed=stage['nReturned'], - children=[outer_stage, inner_stage]) + outer_stage = process_stage(stage["outerStage"]) + inner_stage = process_stage(stage["innerStage"]) + return Node( + **get_common_fields(stage), + n_processed=stage["nReturned"], + children=[outer_stage, inner_stage], + ) def process_hash_join_node(stage: dict[str, any]) -> Node: """Process hj node.""" - outer_stage = process_stage(stage['outerStage']) - inner_stage = process_stage(stage['innerStage']) + outer_stage = process_stage(stage["outerStage"]) + inner_stage = process_stage(stage["innerStage"]) n_processed = outer_stage.n_returned + inner_stage.n_returned - return Node(**get_common_fields(stage), n_processed=n_processed, - children=[outer_stage, inner_stage]) + return Node( + **get_common_fields(stage), + n_processed=n_processed, + children=[outer_stage, inner_stage], + ) def process_nlj(stage: dict[str, any]) -> Node: """Process nlj stage.""" - outer_stage = process_stage(stage['outerStage']) - inner_stage = process_stage(stage['innerStage']) - n_processed = stage['totalDocsExamined'] - return Node(**get_common_fields(stage), n_processed=n_processed, - children=[outer_stage, inner_stage]) + outer_stage = process_stage(stage["outerStage"]) + inner_stage = process_stage(stage["innerStage"]) + n_processed = stage["totalDocsExamined"] + return Node( + **get_common_fields(stage), + n_processed=n_processed, + children=[outer_stage, inner_stage], + ) def process_inner_node(stage: dict[str, any]) -> Node: """Process SBE stage with one input stage.""" - input_stage = process_stage(stage['inputStage']) - return Node(**get_common_fields(stage), n_processed=input_stage.n_returned, - children=[input_stage]) + input_stage = process_stage(stage["inputStage"]) + return Node( + **get_common_fields(stage), + n_processed=input_stage.n_returned, + children=[input_stage], + ) def process_leaf_node(stage: dict[str, any]) -> Node: """Process SBE stage without input stages.""" - return Node(**get_common_fields(stage), n_processed=stage['nReturned'], children=[]) + return Node(**get_common_fields(stage), n_processed=stage["nReturned"], children=[]) def process_seek(stage: dict[str, any]) -> Node: """Process seek stage.""" - return Node(**get_common_fields(stage), n_processed=stage['numReads'], children=[]) + return Node(**get_common_fields(stage), n_processed=stage["numReads"], children=[]) def process_union_node(stage: dict[str, any]) -> Node: """Process union stage.""" - children = [process_stage(child) for child in stage['inputStages']] - return Node(**get_common_fields(stage), n_processed=stage['nReturned'], children=children) + children = [process_stage(child) for child in stage["inputStages"]] + return Node( + **get_common_fields(stage), n_processed=stage["nReturned"], children=children + ) def process_unwind_node(stage: dict[str, any]) -> Node: """Process unwind stage.""" - input_stage = process_stage(stage['inputStage']) - return Node(**get_common_fields(stage), n_processed=input_stage.n_returned, - children=[input_stage]) + input_stage = process_stage(stage["inputStage"]) + return Node( + **get_common_fields(stage), + n_processed=input_stage.n_returned, + children=[input_stage], + ) def process_unique_node(stage: dict[str, any]) -> Node: """Process unique stage.""" - input_stage = process_stage(stage['inputStage']) - n_processed = stage['dupsTested'] - return Node(**get_common_fields(stage), n_processed=n_processed, children=[input_stage]) + input_stage = process_stage(stage["inputStage"]) + n_processed = stage["dupsTested"] + return Node( + **get_common_fields(stage), n_processed=n_processed, children=[input_stage] + ) def process_branch_node(stage: dict[str, any]) -> Node: """Process unique stage.""" - then_stage = process_stage(stage['thenStage']) - else_stage = process_stage(stage['elseStage']) + then_stage = process_stage(stage["thenStage"]) + else_stage = process_stage(stage["elseStage"]) n_processed = then_stage.n_returned + else_stage.n_returned - return Node(**get_common_fields(stage), n_processed=n_processed, - children=[then_stage, else_stage]) + return Node( + **get_common_fields(stage), + n_processed=n_processed, + children=[then_stage, else_stage], + ) def get_common_fields(json_stage: dict[str, any]) -> dict[str, any]: """Exctract common field from json representation of SBE stage.""" return { - 'stage': json_stage['stage'], - 'plan_node_id': json_stage['planNodeId'], - 'total_execution_time': json_stage['executionTimeNanos'], - 'n_returned': json_stage['nReturned'], - 'seeks': json_stage.get('seeks'), + "stage": json_stage["stage"], + "plan_node_id": json_stage["planNodeId"], + "total_execution_time": json_stage["executionTimeNanos"], + "n_returned": json_stage["nReturned"], + "seeks": json_stage.get("seeks"), } diff --git a/buildscripts/cost_model/experiment.py b/buildscripts/cost_model/experiment.py index f29346a6d62..1a6a6d9dbf7 100644 --- a/buildscripts/cost_model/experiment.py +++ b/buildscripts/cost_model/experiment.py @@ -109,21 +109,26 @@ from sklearn.linear_model import LinearRegression from sklearn.metrics import r2_score -async def load_calibration_data(database: DatabaseInstance, collection_name: str) -> pd.DataFrame: +async def load_calibration_data( + database: DatabaseInstance, collection_name: str +) -> pd.DataFrame: """Load workflow data containing explain output from database and parse it. Retuned calibration DataFrame with parsed SBE and ABT.""" data = await database.get_all_documents(collection_name) df = pd.DataFrame(data) - df['sbe'] = df.explain.apply(lambda e: sbe.build_execution_tree( - json.loads(e)['executionStats'])) - df['abt'] = df.explain.apply(lambda e: abt.build( - json.loads(e)['queryPlanner']['winningPlan']['queryPlan'])) - df['total_execution_time'] = df.sbe.apply(lambda t: t.total_execution_time) + df["sbe"] = df.explain.apply( + lambda e: sbe.build_execution_tree(json.loads(e)["executionStats"]) + ) + df["abt"] = df.explain.apply( + lambda e: abt.build(json.loads(e)["queryPlanner"]["winningPlan"]["queryPlan"]) + ) + df["total_execution_time"] = df.sbe.apply(lambda t: t.total_execution_time) return df -def remove_outliers(df: pd.DataFrame, lower_percentile: float = 0.1, - upper_percentile: float = 0.9) -> pd.DataFrame: +def remove_outliers( + df: pd.DataFrame, lower_percentile: float = 0.1, upper_percentile: float = 0.9 +) -> pd.DataFrame: """Remove the outliers from the parsed calibration DataFrame.""" def is_not_outlier(df_seq): @@ -131,8 +136,11 @@ def remove_outliers(df: pd.DataFrame, lower_percentile: float = 0.1, high = df_seq.quantile(upper_percentile) return (df_seq >= low) & (df_seq <= high) - return df[df.groupby(['run_id', 'collection', - 'pipeline']).total_execution_time.transform(is_not_outlier).eq(1)] + return df[ + df.groupby(["run_id", "collection", "pipeline"]) + .total_execution_time.transform(is_not_outlier) + .eq(1) + ] def extract_sbe_stages(df: pd.DataFrame) -> pd.DataFrame: @@ -140,18 +148,24 @@ def extract_sbe_stages(df: pd.DataFrame) -> pd.DataFrame: def flatten_sbe_stages(explain): def traverse(node, stages): - execution_time = node['executionTimeNanos'] - children_fields = ['innerStage', 'outerStage', 'inputStage', 'thenStage', 'elseStage'] + execution_time = node["executionTimeNanos"] + children_fields = [ + "innerStage", + "outerStage", + "inputStage", + "thenStage", + "elseStage", + ] for field in children_fields: if field in node and node[field]: child = node[field] - execution_time -= child['executionTimeNanos'] + execution_time -= child["executionTimeNanos"] traverse(child, stages) del node[field] - node['executionTime'] = execution_time + node["executionTime"] = execution_time stages.append(node) - sbe_tree = json.loads(explain)['executionStats']['executionStages'] + sbe_tree = json.loads(explain)["executionStats"]["executionStages"] result = [] traverse(sbe_tree, result) return result @@ -168,15 +182,18 @@ def extract_abt_nodes(df: pd.DataFrame) -> pd.DataFrame: """Extract ABT Nodes and execution statistics from calibration DataFrame.""" def extract(df_seq): - es_dict = extract_execution_stats(df_seq['sbe'], df_seq['abt'], []) + es_dict = extract_execution_stats(df_seq["sbe"], df_seq["abt"], []) rows = [] for abt_type, es in es_dict.items(): for stat in es: row = { - 'abt_type': abt_type, **dataclasses.asdict(stat), - **json.loads(df_seq['query_parameters']), 'run_id': df_seq.run_id, - 'pipeline': df_seq.pipeline, 'source': df_seq.name + "abt_type": abt_type, + **dataclasses.asdict(stat), + **json.loads(df_seq["query_parameters"]), + "run_id": df_seq.run_id, + "pipeline": df_seq.pipeline, + "source": df_seq.name, } rows.append(row) return rows @@ -187,13 +204,15 @@ def extract_abt_nodes(df: pd.DataFrame) -> pd.DataFrame: def print_trees(calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0): """Print SBE and ABT Trees.""" row = calibration_df.loc[abt_df.iloc[row_index].source] - print('SBE') + print("SBE") row.sbe.print() - print('\nABT') + print("\nABT") row.abt.print() -def print_explain(calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0): +def print_explain( + calibration_df: pd.DataFrame, abt_df: pd.DataFrame, row_index: int = 0 +): """Print explain.""" row = calibration_df.loc[abt_df.iloc[row_index].source] explain = json.loads(row.explain) @@ -205,34 +224,38 @@ def calibrate(abt_node_df: pd.DataFrame, variables: list[str] = None): """Calibrate the ABT node given in abd_node_df with the given model input variables.""" # pylint: disable=invalid-name if variables is None: - variables = ['n_processed'] - y = abt_node_df['execution_time'] + variables = ["n_processed"] + y = abt_node_df["execution_time"] X = abt_node_df[variables] X = sm.add_constant(X) nnls = LinearRegression(positive=True, fit_intercept=False) model = nnls.fit(X, y) y_pred = model.predict(X) - print(f'R2: {r2_score(y, y_pred)}') - print(f'Coefficients: {model.coef_}') + print(f"R2: {r2_score(y, y_pred)}") + print(f"Coefficients: {model.coef_}") - sns.scatterplot(x=abt_node_df['n_processed'], y=abt_node_df['execution_time']) - sns.lineplot(x=abt_node_df['n_processed'], y=y_pred, color='red') + sns.scatterplot(x=abt_node_df["n_processed"], y=abt_node_df["execution_time"]) + sns.lineplot(x=abt_node_df["n_processed"], y=y_pred, color="red") -if __name__ == '__main__': +if __name__ == "__main__": import asyncio from config import DatabaseConfig async def test(): """Smoke tests.""" - database_config = DatabaseConfig(connection_string='mongodb://localhost', - database_name='abt_calibration', dump_path='', - restore_from_dump=False, dump_on_exit=False) + database_config = DatabaseConfig( + connection_string="mongodb://localhost", + database_name="abt_calibration", + dump_path="", + restore_from_dump=False, + dump_on_exit=False, + ) database = DatabaseInstance(database_config) - raw_df = await load_calibration_data(database, 'calibrationData') + raw_df = await load_calibration_data(database, "calibrationData") print(raw_df.head()) cleaned_df = remove_outliers(raw_df, 0.0, 0.9) @@ -241,7 +264,7 @@ if __name__ == '__main__': sbe_stages_df = extract_sbe_stages(cleaned_df) print(sbe_stages_df.head()) - seek_df = get_sbe_stage(sbe_stages_df, 'seek') + seek_df = get_sbe_stage(sbe_stages_df, "seek") print(seek_df.head()) abt_nodes_df = extract_abt_nodes(cleaned_df) diff --git a/buildscripts/cost_model/parameters_extractor.py b/buildscripts/cost_model/parameters_extractor.py index 5b4be46ab15..9beafad4bde 100644 --- a/buildscripts/cost_model/parameters_extractor.py +++ b/buildscripts/cost_model/parameters_extractor.py @@ -40,20 +40,20 @@ from cost_estimator import CostModelParameters, ExecutionStats from database_instance import DatabaseInstance from workload_execution import QueryParameters -__all__ = ['extract_parameters', 'extract_execution_stats'] +__all__ = ["extract_parameters", "extract_execution_stats"] async def extract_parameters( - config: AbtCalibratorConfig, database: DatabaseInstance, - abt_types: Sequence[str]) -> Mapping[str, Sequence[CostModelParameters]]: + config: AbtCalibratorConfig, database: DatabaseInstance, abt_types: Sequence[str] +) -> Mapping[str, Sequence[CostModelParameters]]: """Read measurements from database and extract cost model parameters for the given ABT types.""" stats = defaultdict(list) docs = await database.get_all_documents(config.input_collection_name) for result in docs: - explain = json.loads(result['explain']) - query_parameters = QueryParameters.from_json(result['query_parameters']) + explain = json.loads(result["explain"]) + query_parameters = QueryParameters.from_json(result["query_parameters"]) res = parse_explain(explain, abt_types) for abt_type, es in res.items(): stats[abt_type] += [ @@ -65,10 +65,12 @@ async def extract_parameters( return stats -Node = TypeVar('Node') +Node = TypeVar("Node") -def find_abt_node_by_type(root: physical_tree.Node, abt_type: str) -> Sequence[physical_tree.Node]: +def find_abt_node_by_type( + root: physical_tree.Node, abt_type: str +) -> Sequence[physical_tree.Node]: """Find ABT node by its type.""" return find_nodes(root, lambda node: node.node_type == abt_type) @@ -111,26 +113,28 @@ def get_excution_stats(root: execution_tree.Node, node_id: int) -> ExecutionStat assert n_sbe_nodes <= 1 - return ExecutionStats(execution_time=execution_time, n_returned=n_returned, - n_processed=n_processed) + return ExecutionStats( + execution_time=execution_time, n_returned=n_returned, n_processed=n_processed + ) def parse_explain(explain: Mapping[str, any], abt_types: Sequence[str]): """Extract ExecutionStats from the given explain for the given ABT types.""" try: - et = execution_tree.build_execution_tree(explain['executionStats']) - pt = physical_tree.build(explain['queryPlanner']['winningPlan']['queryPlan']) + et = execution_tree.build_execution_tree(explain["executionStats"]) + pt = physical_tree.build(explain["queryPlanner"]["winningPlan"]["queryPlan"]) except Exception as exception: - print(f'*** Failed to parse explain with the followinf error: {exception}') + print(f"*** Failed to parse explain with the followinf error: {exception}") print(explain) raise exception return extract_execution_stats(et, pt, abt_types) -def extract_execution_stats(et: execution_tree.Node, pt: physical_tree.Node, - abt_types: Sequence[str]) -> Mapping[str, Sequence[ExecutionStats]]: +def extract_execution_stats( + et: execution_tree.Node, pt: physical_tree.Node, abt_types: Sequence[str] +) -> Mapping[str, Sequence[ExecutionStats]]: """Extract ExecutionStats from the given SBE and ABT trees for the given ABT types.""" if len(abt_types) == 0: @@ -144,7 +148,7 @@ def extract_execution_stats(et: execution_tree.Node, pt: physical_tree.Node, result[abt_type].append(execution_stats) return result except AssertionError as ae: - print(f'{pt.node_type} {ae} {pt}') + print(f"{pt.node_type} {ae} {pt}") raise ae diff --git a/buildscripts/cost_model/physical_tree.py b/buildscripts/cost_model/physical_tree.py index e8e31d2ea67..727996d01bc 100644 --- a/buildscripts/cost_model/physical_tree.py +++ b/buildscripts/cost_model/physical_tree.py @@ -31,7 +31,7 @@ from __future__ import annotations from dataclasses import dataclass -__all__ = ['Node', 'build'] +__all__ = ["Node", "build"] @dataclass @@ -62,20 +62,25 @@ def build(optimizer_plan: dict[str, any]) -> Node: def parse_optimizer_node(explain_node: dict[str, any]) -> Node: """Recursively parse ABT node from query explain's node.""" children = get_children(explain_node) - properties = explain_node['properties'] - return Node(node_type=explain_node['nodeType'], plan_node_id=properties['planNodeID'], - cost=properties['cost'], local_cost=properties['localCost'], - adjusted_ce=properties['adjustedCE'], children=children) + properties = explain_node["properties"] + return Node( + node_type=explain_node["nodeType"], + plan_node_id=properties["planNodeID"], + cost=properties["cost"], + local_cost=properties["localCost"], + adjusted_ce=properties["adjustedCE"], + children=children, + ) def get_children(explain_node: dict[str, any]) -> list[Node]: """Get children nodes of the ABT node.""" - if 'children' in explain_node: - children = [parse_optimizer_node(child) for child in explain_node['children']] + if "children" in explain_node: + children = [parse_optimizer_node(child) for child in explain_node["children"]] else: children = [] - children_refs = ['child', 'leftChild', 'rightChild'] + children_refs = ["child", "leftChild", "rightChild"] for ref in children_refs: if ref in explain_node: diff --git a/buildscripts/cost_model/qsn_costing_parameters.py b/buildscripts/cost_model/qsn_costing_parameters.py index 33cf36b882e..8b8c0fcb766 100644 --- a/buildscripts/cost_model/qsn_costing_parameters.py +++ b/buildscripts/cost_model/qsn_costing_parameters.py @@ -34,7 +34,7 @@ import pandas as pd import query_solution_tree as qsn from workload_execution import QueryParameters -Node = TypeVar('Node') +Node = TypeVar("Node") def parse_explain(explain: dict[str, any]) -> (qsn.Node, sbe.Node): @@ -83,41 +83,69 @@ class ParametersBuilder: def buildDataFrame(self) -> pd.DataFrame: return pd.DataFrame( - self.rows, columns=[ - 'stage', 'execution_time', 'n_processed', 'seeks', 'note', 'keys_length_in_bytes', - 'average_document_size_in_bytes', 'number_of_fields' - ]) + self.rows, + columns=[ + "stage", + "execution_time", + "n_processed", + "seeks", + "note", + "keys_length_in_bytes", + "average_document_size_in_bytes", + "number_of_fields", + ], + ) def _process(self, qsn_node: qsn.Node, sbe_tree: sbe.Node, params: QueryParameters): processor = self._get_processor(qsn_node.node_type) - self.rows.append(processor(qsn_node.node_type, qsn_node.plan_node_id, sbe_tree, params)) + self.rows.append( + processor(qsn_node.node_type, qsn_node.plan_node_id, sbe_tree, params) + ) for child in qsn_node.children: self._process(child, sbe_tree, params) def _get_processor(self, stage: str): return self.processors.get(stage, self.default_processor) - def _process_generic(self, stage: str, node_id: int, sbe_tree: sbe.Node, - params: QueryParameters): - nodes: list[sbe.Node] = find_nodes(sbe_tree, lambda node: node.plan_node_id == node_id) + def _process_generic( + self, stage: str, node_id: int, sbe_tree: sbe.Node, params: QueryParameters + ): + nodes: list[sbe.Node] = find_nodes( + sbe_tree, lambda node: node.plan_node_id == node_id + ) if len(nodes) == 0: raise ValueError(f"Cannot find sbe nodes of {stage}") return ParametersBuilder._build_row( - stage, params, execution_time=sum([node.get_execution_time() for node in nodes]), - n_processed=max([node.n_processed for node in nodes]), seeks=sum( - [node.seeks for node in nodes if node.seeks])) + stage, + params, + execution_time=sum([node.get_execution_time() for node in nodes]), + n_processed=max([node.n_processed for node in nodes]), + seeks=sum([node.seeks for node in nodes if node.seeks]), + ) @staticmethod - def _build_row(stage: str, params: QueryParameters, execution_time: int = None, - n_processed: int = None, seeks: int = None): + def _build_row( + stage: str, + params: QueryParameters, + execution_time: int = None, + n_processed: int = None, + seeks: int = None, + ): return [ - stage, execution_time, n_processed, seeks, params.note, params.keys_length_in_bytes, - params.average_document_size_in_bytes, params.number_of_fields + stage, + execution_time, + n_processed, + seeks, + params.note, + params.keys_length_in_bytes, + params.average_document_size_in_bytes, + params.number_of_fields, ] if __name__ == "__main__": import json + explain = """ { "explainVersion": "2", @@ -746,7 +774,7 @@ if __name__ == "__main__": qsn_tree.print() sbe_tree.print() - params = QueryParameters(10, 2000, 'rooted-or') + params = QueryParameters(10, 2000, "rooted-or") builder = ParametersBuilder() builder.process(explainJson, params) df = builder.buildDataFrame() diff --git a/buildscripts/cost_model/query_solution_tree.py b/buildscripts/cost_model/query_solution_tree.py index 7415b5dd4bf..d99d3713a19 100644 --- a/buildscripts/cost_model/query_solution_tree.py +++ b/buildscripts/cost_model/query_solution_tree.py @@ -31,7 +31,7 @@ from __future__ import annotations from dataclasses import dataclass -__all__ = ['Node', 'build'] +__all__ = ["Node", "build"] @dataclass @@ -59,17 +59,22 @@ def parse_optimizer_node(explain_node: dict[str, any]) -> Node: """Recursively parse QSN from query explain's node.""" children = get_children(explain_node) - return Node(node_type=explain_node['stage'], plan_node_id=explain_node['planNodeId'], - children=children) + return Node( + node_type=explain_node["stage"], + plan_node_id=explain_node["planNodeId"], + children=children, + ) def get_children(explain_node: dict[str, any]) -> list[Node]: """Get children nodes of the QSN.""" - if 'inputStage' in explain_node: - children = [parse_optimizer_node(explain_node['inputStage'])] - elif 'inputStages' in explain_node: - children = [parse_optimizer_node(child) for child in explain_node['inputStages']] + if "inputStage" in explain_node: + children = [parse_optimizer_node(explain_node["inputStage"])] + elif "inputStages" in explain_node: + children = [ + parse_optimizer_node(child) for child in explain_node["inputStages"] + ] else: children = [] return children @@ -77,6 +82,7 @@ def get_children(explain_node: dict[str, any]) -> list[Node]: if __name__ == "__main__": import json + explain = """ { "explainVersion": "2", diff --git a/buildscripts/cost_model/random_generator.py b/buildscripts/cost_model/random_generator.py index 4ebaab68d09..07d351c7512 100644 --- a/buildscripts/cost_model/random_generator.py +++ b/buildscripts/cost_model/random_generator.py @@ -38,9 +38,9 @@ from typing import Generic, Sequence, TypeVar import numpy as np -__all__ = ['RangeGenerator', 'DataType', 'RandomDistribution'] +__all__ = ["RangeGenerator", "DataType", "RandomDistribution"] -TVar = TypeVar('TVar', str, int, float, datetime) +TVar = TypeVar("TVar", str, int, float, datetime) class DataType(Enum): @@ -61,18 +61,18 @@ class DataType(Enum): def __str__(self): typenames = { - DataType.DOUBLE: 'dbl', - DataType.STRING: 'str', - DataType.OBJECT: 'obj', - DataType.ARRAY: 'arr', - DataType.OBJECTID: 'oid', - DataType.BOOLEAN: 'bool', - DataType.DATE: 'dt', - DataType.NULL: 'null', - DataType.INTEGER: 'int', - DataType.TIMESTAMP: 'ts', - DataType.DECIMAL128: 'dec', - DataType.MIXDATA: 'mixdata', + DataType.DOUBLE: "dbl", + DataType.STRING: "str", + DataType.OBJECT: "obj", + DataType.ARRAY: "arr", + DataType.OBJECTID: "oid", + DataType.BOOLEAN: "bool", + DataType.DATE: "dt", + DataType.NULL: "null", + DataType.INTEGER: "int", + DataType.TIMESTAMP: "ts", + DataType.DECIMAL128: "dec", + DataType.MIXDATA: "mixdata", } return typenames[self] @@ -89,7 +89,8 @@ class RangeGenerator(Generic[TVar]): def __post_init__(self): assert type(self.interval_begin) == type( - self.interval_end), 'Interval ends must of the same type.' + self.interval_end + ), "Interval ends must of the same type." if type(self.interval_begin) == int or type(self.interval_begin) == float: self.ndv = round((self.interval_end - self.interval_begin) / self.step) elif type(self.interval_begin) == datetime: @@ -101,37 +102,39 @@ class RangeGenerator(Generic[TVar]): """Generate the range.""" gen_range_dict = { - DataType.STRING: - ansi_range, - DataType.INTEGER: - range, + DataType.STRING: ansi_range, + DataType.INTEGER: range, # The arange function produces equi-distant values which is too regular for CE testing. # It is left here as a possible way of generating doubles. # DataType.DOUBLE: np.arange - DataType.DOUBLE: - double_range, - DataType.DATE: - datetime_range, + DataType.DOUBLE: double_range, + DataType.DATE: datetime_range, } gen_range = gen_range_dict.get(self.data_type) if gen_range is None: - raise ValueError(f'Unsupported data type: {self.data_type}') + raise ValueError(f"Unsupported data type: {self.data_type}") return list(gen_range(self.interval_begin, self.interval_end, self.step)) def __str__(self): # TODO: for now skip NDV from the name to make it shorter. - #ndv_str = "_" if self.ndv <= 0 else f'_{self.ndv}_' - begin_str = str(self.interval_begin.date()) if isinstance( - self.interval_begin, datetime) else str(self.interval_begin) - end_str = str(self.interval_end.date()) if isinstance(self.interval_end, datetime) else str( - self.interval_end) + # ndv_str = "_" if self.ndv <= 0 else f'_{self.ndv}_' + begin_str = ( + str(self.interval_begin.date()) + if isinstance(self.interval_begin, datetime) + else str(self.interval_begin) + ) + end_str = ( + str(self.interval_end.date()) + if isinstance(self.interval_end, datetime) + else str(self.interval_end) + ) - str_rep = f'{str(self.data_type)}_{begin_str}-{end_str}-{self.step}' + str_rep = f"{str(self.data_type)}_{begin_str}-{end_str}-{self.step}" # Remove dots and spaces from field names. - str_rep = str_rep.replace('.', ',') - str_rep = str_rep.replace(' ', '_') + str_rep = str_rep.replace(".", ",") + str_rep = str_rep.replace(" ", "_") return str_rep @@ -145,14 +148,14 @@ def ansi_range(begin: str, end: str, step: int = 1): """Produces a sequence of string from begin to end.""" alphabet_size = 28 - non_alpha_char = '_' + non_alpha_char = "_" def ansi_to_int(data: str) -> int: res = 0 for char in data.lower(): res = res * alphabet_size - if 'a' <= char <= 'z': - res += ord(char) - ord('a') + 1 + if "a" <= char <= "z": + res += ord(char) - ord("a") + 1 else: res += alphabet_size - 1 return res @@ -164,10 +167,10 @@ def ansi_range(begin: str, end: str, step: int = 1): if remainder == alphabet_size - 1: char = non_alpha_char else: - char = chr(remainder + ord('a') - 1) + char = chr(remainder + ord("a") - 1) result.append(char) result.reverse() - return ''.join(result) + return "".join(result) def get_common_prefix_len(s1: str, s2: str): index = 0 @@ -188,7 +191,7 @@ def ansi_range(begin: str, end: str, step: int = 1): if prefix_len == 0: yield int_to_ansi(number) else: - yield f'{prefix}{int_to_ansi(number)}' + yield f"{prefix}{int_to_ansi(number)}" def datetime_range(begin: datetime, end: datetime, step: int = 60): @@ -199,8 +202,8 @@ def datetime_range(begin: datetime, end: datetime, step: int = 60): for _ in range(0, num_values): random_ts = np.random.randint(begin_ts, end_ts) yield datetime.fromtimestamp(random_ts) - #random_dates = [datetime.fromtimestamp(random_ts) for random_ts in random.sample(range(int(begin_ts), int(end_ts)), num_values)] - #return random_dates + # random_dates = [datetime.fromtimestamp(random_ts) for random_ts in random.sample(range(int(begin_ts), int(end_ts)), num_values)] + # return random_dates class DistributionType(Enum): @@ -226,8 +229,8 @@ class RandomDistribution: distribution_type: DistributionType values: Union[Sequence[TVar], RangeGenerator] weights: Union[Sequence[float], None] - values_name: str = '' - weights_name: str = '' + values_name: str = "" + weights_name: str = "" def __str__(self): def print_values(vals): @@ -235,56 +238,71 @@ class RandomDistribution: return str(vals) elif isinstance(vals[0], RandomDistribution): # Must be a mixed distribution - res = '' + res = "" for distr in vals: - res += f'{str(distr)}_' + res += f"{str(distr)}_" return res else: # All values are of the same type because of how RangeGenerator works - return f'{type(vals[0]).__name__}_{min(vals)}_{max(vals)}_{len(vals)}' + return f"{type(vals[0]).__name__}_{min(vals)}_{max(vals)}_{len(vals)}" - range_str = '' - if hasattr(self, 'values'): + range_str = "" + if hasattr(self, "values"): range_str = print_values(self.values) - if self.values_name != '': - range_str += f'_{self.values_name}' - if self.weights_name != '': - range_str += f'_{self.weights_name}' + if self.values_name != "": + range_str += f"_{self.values_name}" + if self.weights_name != "": + range_str += f"_{self.weights_name}" - distr_str = f'{str(self.distribution_type)}_{range_str}' + distr_str = f"{str(self.distribution_type)}_{range_str}" return distr_str @staticmethod - def choice(values: Sequence[TVar], weights: Union[Sequence[float], RangeGenerator], - v_name: str = '', w_name: str = ''): + def choice( + values: Sequence[TVar], + weights: Union[Sequence[float], RangeGenerator], + v_name: str = "", + w_name: str = "", + ): """Create choice distribution.""" - return RandomDistribution(distribution_type=DistributionType.CHOICE, values=values, - weights=weights, values_name=v_name, weights_name=w_name) + return RandomDistribution( + distribution_type=DistributionType.CHOICE, + values=values, + weights=weights, + values_name=v_name, + weights_name=w_name, + ) @staticmethod def normal(values: Union[Sequence[TVar], RangeGenerator]): """Create normal distribution.""" - return RandomDistribution(distribution_type=DistributionType.NORMAL, values=values, - weights=None) + return RandomDistribution( + distribution_type=DistributionType.NORMAL, values=values, weights=None + ) @staticmethod def noncentral_chisquare(values: Union[Sequence[TVar], RangeGenerator]): """Create Non Central Chi2 distribution.""" - return RandomDistribution(distribution_type=DistributionType.CHI2, values=values, - weights=None) + return RandomDistribution( + distribution_type=DistributionType.CHI2, values=values, weights=None + ) @staticmethod def uniform(values: Union[Sequence[TVar], RangeGenerator]): """Create uniform distribution.""" - return RandomDistribution(distribution_type=DistributionType.UNIFORM, values=values, - weights=None) + return RandomDistribution( + distribution_type=DistributionType.UNIFORM, values=values, weights=None + ) @staticmethod - def mixed(children: Sequence[RandomDistribution], - weight: Union[Sequence[float], RangeGenerator]): + def mixed( + children: Sequence[RandomDistribution], + weight: Union[Sequence[float], RangeGenerator], + ): """Create mixed distribution.""" - return RandomDistribution(distribution_type=DistributionType.MIXDIST, values=children, - weights=weight) + return RandomDistribution( + distribution_type=DistributionType.MIXDIST, values=children, weights=weight + ) def generate(self, size: int) -> Sequence[TVar]: """Generate random data sequence of the given size.""" @@ -306,7 +324,9 @@ class RandomDistribution: probs = None if probs is not None and len(probs) != len(values): - raise ValueError(f'values and probs must be the same size: {probs} !! {values}') + raise ValueError( + f"values and probs must be the same size: {probs} !! {values}" + ) if len(values) == 0: raise ValueError(f"Values cannot be empty: {self.values}") @@ -382,7 +402,9 @@ class RandomDistribution: index = len(values) - 1 return values[index] - return [get_value(n) for n in _rng.noncentral_chisquare(df=df, nonc=nonc, size=size)] + return [ + get_value(n) for n in _rng.noncentral_chisquare(df=df, nonc=nonc, size=size) + ] @staticmethod def _uniform(size: int, values: Sequence[TVar], _: Sequence[float]): @@ -393,15 +415,19 @@ class RandomDistribution: return [get_value(n) for n in _rng.uniform(low=0, high=len(values), size=size)] @staticmethod - def _mixed(size: int, children: Sequence[RandomDistribution], probs: Sequence[float]): + def _mixed( + size: int, children: Sequence[RandomDistribution], probs: Sequence[float] + ): if probs is None: - raise ValueError(f'probs must be specified for mixed distributions: {str(children)}') + raise ValueError( + f"probs must be specified for mixed distributions: {str(children)}" + ) result = [] for child_distr, prob in zip(children, probs): if not isinstance(child_distr, RandomDistribution): raise ValueError( - f'children must be of type RandomDistribution for mixed distribution, child_distr: {child_distr}' + f"children must be of type RandomDistribution for mixed distribution, child_distr: {child_distr}" ) child_size = int(size * prob) result.append(child_distr.generate(child_size)) @@ -419,14 +445,16 @@ class ArrayRandomDistribution(RandomDistribution): lengths_distr: RandomDistribution = _NO_DEFAULT value_distr: RandomDistribution = _NO_DEFAULT - def __init__(self, lengths_distr: RandomDistribution, value_distr: RandomDistribution): + def __init__( + self, lengths_distr: RandomDistribution, value_distr: RandomDistribution + ): self.lengths_distr = lengths_distr self.value_distr = value_distr self.distribution_type = value_distr.distribution_type def __str__(self): - distr_str = f'{super().__str__()}' - distr_str += f'array_{str(self.value_distr)}_{str(self.lengths_distr)}' + distr_str = f"{super().__str__()}" + distr_str += f"array_{str(self.value_distr)}_{str(self.lengths_distr)}" return distr_str def generate(self, size: int): @@ -450,8 +478,12 @@ class DocumentRandomDistribution(RandomDistribution): fields_distr: RandomDistribution = _NO_DEFAULT field_to_distribution: dict = _NO_DEFAULT - def __init__(self, number_of_fields_distr: RandomDistribution, fields_distr: RandomDistribution, - field_to_distribution: dict): + def __init__( + self, + number_of_fields_distr: RandomDistribution, + fields_distr: RandomDistribution, + field_to_distribution: dict, + ): self.number_of_fields_distr = number_of_fields_distr self.fields_distr = fields_distr self.field_to_distribution = field_to_distribution @@ -462,7 +494,7 @@ class DocumentRandomDistribution(RandomDistribution): raise ValueError("Must provide a RandomDistribution for each field") def __str__(self): - return f'{super().__str__()}' + return f"{super().__str__()}" def generate(self, size: int): """Generate random document sequence of the given size.""" @@ -479,7 +511,9 @@ class DocumentRandomDistribution(RandomDistribution): for idx, num in enumerate(nums): doc = {} if not isinstance(num, int): - raise ValueError("the number of fields must be an int for document generation") + raise ValueError( + "the number of fields must be an int for document generation" + ) field_names = self.fields_distr.generate(num) for field in field_names: @@ -494,12 +528,12 @@ class DocumentRandomDistribution(RandomDistribution): return self.fields_distr.get_values() -if __name__ == '__main__': +if __name__ == "__main__": from collections import Counter def print_distr(title, distr, size=10000): """Print distribution.""" - print(f'\n{title}: {str(distr)}\n') + print(f"\n{title}: {str(distr)}\n") rs = distr.generate(size) has_arrays = any(isinstance(elem, list) for elem in rs) has_dict = any(isinstance(elem, dict) for elem in rs) @@ -516,40 +550,61 @@ if __name__ == '__main__': for elem in rs: print(elem) - choice = RandomDistribution.choice(values=['pooh', 'rabbit', 'piglet', 'Chris'], - weights=[0.5, 0.1, 0.1, 0.3]) + choice = RandomDistribution.choice( + values=["pooh", "rabbit", "piglet", "Chris"], weights=[0.5, 0.1, 0.1, 0.3] + ) print_distr("Choice", choice, 1000) - string_generator = RangeGenerator(data_type=DataType.STRING, interval_begin='hello_a', - interval_end='hello__') + string_generator = RangeGenerator( + data_type=DataType.STRING, interval_begin="hello_a", interval_end="hello__" + ) str_normal = RandomDistribution.normal(string_generator) print_distr("Normal for strings", str_normal) - int_noncentral_chisquare = RandomDistribution.noncentral_chisquare(list(range(1, 30))) + int_noncentral_chisquare = RandomDistribution.noncentral_chisquare( + list(range(1, 30)) + ) print_distr("Noncentral Chisquare for integers", int_noncentral_chisquare) - float_uniform = RandomDistribution.uniform(RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37)) + float_uniform = RandomDistribution.uniform( + RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37) + ) print_distr("Uniform for floats", float_uniform) - float_normal = RandomDistribution.normal(RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37)) + float_normal = RandomDistribution.normal( + RangeGenerator(DataType.DOUBLE, 0.1, 10.0, 0.37) + ) print_distr("Normal for floats", float_normal) FOUR_DAYS_IN_SECONDS = 60 * 20 * 24 * 12 date_uniform = RandomDistribution.uniform( - RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), - FOUR_DAYS_IN_SECONDS)) + RangeGenerator( + DataType.DATE, + datetime(2007, 1, 1), + datetime(2008, 1, 1), + FOUR_DAYS_IN_SECONDS, + ) + ) print_distr("Uniform for dates", date_uniform, size=1000) date_normal = RandomDistribution.normal( - RangeGenerator(DataType.DATE, datetime(2007, 1, 1), datetime(2008, 1, 1), - FOUR_DAYS_IN_SECONDS)) + RangeGenerator( + DataType.DATE, + datetime(2007, 1, 1), + datetime(2008, 1, 1), + FOUR_DAYS_IN_SECONDS, + ) + ) print_distr("Normal for dates", date_normal, size=1000) - str_chisquare2 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "aa", "ba")) + str_chisquare2 = RandomDistribution.normal( + RangeGenerator(DataType.STRING, "aa", "ba") + ) str_normal2 = RandomDistribution.normal(RangeGenerator(DataType.STRING, "ap", "bp")) - mixed = RandomDistribution.mixed(children=[float_uniform, str_chisquare2, str_normal2], - weight=[0.3, 0.5, 0.2]) + mixed = RandomDistribution.mixed( + children=[float_uniform, str_chisquare2, str_normal2], weight=[0.3, 0.5, 0.2] + ) print_distr("Mixed", mixed, 20_000) int_normal = RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 2, 10)) @@ -557,24 +612,33 @@ if __name__ == '__main__': arr_distr = ArrayRandomDistribution(int_normal, mixed) print_distr("Mixed Arrays", arr_distr, 100) - mixed_with_arrays = RandomDistribution.mixed(children=[float_uniform, str_normal2, arr_distr], - weight=[0.3, 0.2, 0.5]) + mixed_with_arrays = RandomDistribution.mixed( + children=[float_uniform, str_normal2, arr_distr], weight=[0.3, 0.2, 0.5] + ) nested_arr_distr = ArrayRandomDistribution(int_normal, mixed_with_arrays) print_distr("Mixed Nested Arrays", nested_arr_distr, 100) simple_doc_distr = DocumentRandomDistribution( RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 1, 2)), - RandomDistribution.uniform(["obj"]), {"obj": int_normal}) + RandomDistribution.uniform(["obj"]), + {"obj": int_normal}, + ) - field_name_choice = RandomDistribution.uniform(['a', 'b', 'c', 'd', 'e', 'f']) + field_name_choice = RandomDistribution.uniform(["a", "b", "c", "d", "e", "f"]) field_to_distr = { - 'a': int_normal, 'b': str_normal, 'c': mixed, 'd': arr_distr, 'e': nested_arr_distr, - 'f': simple_doc_distr + "a": int_normal, + "b": str_normal, + "c": mixed, + "d": arr_distr, + "e": nested_arr_distr, + "f": simple_doc_distr, } nested_doc_distr = DocumentRandomDistribution( - RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 0, 7)), field_name_choice, - field_to_distr) + RandomDistribution.normal(RangeGenerator(DataType.INTEGER, 0, 7)), + field_name_choice, + field_to_distr, + ) print_distr("Nested Document generation", nested_doc_distr, 100) diff --git a/buildscripts/cost_model/start.py b/buildscripts/cost_model/start.py index f15a67a90d7..bdc65126541 100644 --- a/buildscripts/cost_model/start.py +++ b/buildscripts/cost_model/start.py @@ -46,160 +46,228 @@ from workload_execution import Query, QueryParameters __all__ = [] -def save_to_csv(parameters: Mapping[str, Sequence[CostModelParameters]], filepath: str) -> None: +def save_to_csv( + parameters: Mapping[str, Sequence[CostModelParameters]], filepath: str +) -> None: """Save model input parameters to a csv file.""" - abt_type_name = 'abt_type' + abt_type_name = "abt_type" fieldnames = [ - abt_type_name, *[f.name for f in dataclasses.fields(ExecutionStats)], - *[f.name for f in dataclasses.fields(QueryParameters)] + abt_type_name, + *[f.name for f in dataclasses.fields(ExecutionStats)], + *[f.name for f in dataclasses.fields(QueryParameters)], ] - with open(filepath, 'w', newline='') as csvfile: + with open(filepath, "w", newline="") as csvfile: writer = csv.DictWriter(csvfile, fieldnames=fieldnames) writer.writeheader() for abt_type, type_params_list in parameters.items(): for type_params in type_params_list: - fields = dataclasses.asdict(type_params.execution_stats) | dataclasses.asdict( - type_params.query_params) + fields = dataclasses.asdict( + type_params.execution_stats + ) | dataclasses.asdict(type_params.query_params) fields[abt_type_name] = abt_type writer.writerow(fields) -async def execute_index_scan_queries(database: DatabaseInstance, - collections: Sequence[CollectionInfo]): - collection = [ci for ci in collections if ci.name.startswith('index_scan')][0] - fields = [f for f in collection.fields if f.name == 'choice'] +async def execute_index_scan_queries( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collection = [ci for ci in collections if ci.name.startswith("index_scan")][0] + fields = [f for f in collection.fields if f.name == "choice"] requests = [] for field in fields: for val in field.distribution.get_values(): - if val.startswith('_'): + if val.startswith("_"): continue keys_length = len(val) + 2 requests.append( - Query(pipeline=[{'$match': {field.name: val}}], keys_length_in_bytes=keys_length, - note='IndexScan')) + Query( + pipeline=[{"$match": {field.name: val}}], + keys_length_in_bytes=keys_length, + note="IndexScan", + ) + ) - await workload_execution.execute(database, main_config.workload_execution, [collection], - requests) + await workload_execution.execute( + database, main_config.workload_execution, [collection], requests + ) -async def execute_physical_scan_queries(database: DatabaseInstance, - collections: Sequence[CollectionInfo]): - collections = [ci for ci in collections if ci.name.startswith('physical_scan')] - fields = [f for f in collections[0].fields if f.name == 'choice'] +async def execute_physical_scan_queries( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collections = [ci for ci in collections if ci.name.startswith("physical_scan")] + fields = [f for f in collections[0].fields if f.name == "choice"] requests = [] for field in fields: for val in field.distribution.get_values()[::3]: - if val.startswith('_'): + if val.startswith("_"): continue keys_length = len(val) + 2 requests.append( - Query(pipeline=[{'$match': {field.name: val}}, {"$limit": 10}], - keys_length_in_bytes=keys_length, note='PhysicalScan')) + Query( + pipeline=[{"$match": {field.name: val}}, {"$limit": 10}], + keys_length_in_bytes=keys_length, + note="PhysicalScan", + ) + ) - await workload_execution.execute(database, main_config.workload_execution, collections, - requests) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests + ) -async def execute_index_intersections_with_requests(database: DatabaseInstance, - collections: Sequence[CollectionInfo], - requests: Sequence[Query]): +async def execute_index_intersections_with_requests( + database: DatabaseInstance, + collections: Sequence[CollectionInfo], + requests: Sequence[Query], +): try: - await database.set_parameter('internalCostModelCoefficients', - '{"filterIncrementalCost": 10000.0}') - await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', False) - await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', False) + await database.set_parameter( + "internalCostModelCoefficients", '{"filterIncrementalCost": 10000.0}' + ) + await database.set_parameter( + "internalCascadesOptimizerDisableMergeJoinRIDIntersect", False + ) + await database.set_parameter( + "internalCascadesOptimizerDisableHashJoinRIDIntersect", False + ) - await workload_execution.execute(database, main_config.workload_execution, collections, - requests) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests + ) - await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', True) - await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', True) + await database.set_parameter( + "internalCascadesOptimizerDisableMergeJoinRIDIntersect", True + ) + await database.set_parameter( + "internalCascadesOptimizerDisableHashJoinRIDIntersect", True + ) main_config.workload_execution.write_mode = WriteMode.APPEND - await workload_execution.execute(database, main_config.workload_execution, collections, - requests[::4]) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests[::4] + ) finally: - await database.set_parameter('internalCascadesOptimizerDisableMergeJoinRIDIntersect', False) - await database.set_parameter('internalCascadesOptimizerDisableHashJoinRIDIntersect', False) - await database.set_parameter('internalCostModelCoefficients', '') + await database.set_parameter( + "internalCascadesOptimizerDisableMergeJoinRIDIntersect", False + ) + await database.set_parameter( + "internalCascadesOptimizerDisableHashJoinRIDIntersect", False + ) + await database.set_parameter("internalCostModelCoefficients", "") -async def execute_index_intersections(database: DatabaseInstance, - collections: Sequence[CollectionInfo]): - collections = [ci for ci in collections if ci.name.startswith('c_int')] +async def execute_index_intersections( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collections = [ci for ci in collections if ci.name.startswith("c_int")] requests = [] for i in range(0, 1000, 100): - requests.append(Query(pipeline=[{'$match': {'in1': i, 'in2': i}}], keys_length_in_bytes=1)) + requests.append( + Query(pipeline=[{"$match": {"in1": i, "in2": i}}], keys_length_in_bytes=1) + ) requests.append( - Query(pipeline=[{'$match': {'in1': i, 'in2': 1000 - i}}], keys_length_in_bytes=1)) + Query( + pipeline=[{"$match": {"in1": i, "in2": 1000 - i}}], + keys_length_in_bytes=1, + ) + ) requests.append( - Query(pipeline=[{'$match': {'in1': {'$lte': i}, 'in2': 1000 - i}}], - keys_length_in_bytes=1)) + Query( + pipeline=[{"$match": {"in1": {"$lte": i}, "in2": 1000 - i}}], + keys_length_in_bytes=1, + ) + ) requests.append( - Query(pipeline=[{'$match': {'in1': i, 'in2': {'$gt': 1000 - i}}}], - keys_length_in_bytes=1)) + Query( + pipeline=[{"$match": {"in1": i, "in2": {"$gt": 1000 - i}}}], + keys_length_in_bytes=1, + ) + ) await execute_index_intersections_with_requests(database, collections, requests) -async def execute_evaluation(database: DatabaseInstance, collections: Sequence[CollectionInfo]): - collections = [ci for ci in collections if ci.name.startswith('c_int_05')] +async def execute_evaluation( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collections = [ci for ci in collections if ci.name.startswith("c_int_05")] requests = [] for i in [100, 500, 1000]: requests.append( - Query(pipeline=[{'$project': {'uniform1': 1, 'mixed2': 1}}, {"$limit": i}], - keys_length_in_bytes=1, number_of_fields=1, note='Evaluation')) + Query( + pipeline=[{"$project": {"uniform1": 1, "mixed2": 1}}, {"$limit": i}], + keys_length_in_bytes=1, + number_of_fields=1, + note="Evaluation", + ) + ) - await workload_execution.execute(database, main_config.workload_execution, collections, - requests) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests + ) -async def execute_unwind(database: DatabaseInstance, collections: Sequence[CollectionInfo]): - collections = [ci for ci in collections if ci.name.startswith('c_arr_01')] +async def execute_unwind( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collections = [ci for ci in collections if ci.name.startswith("c_arr_01")] requests = [] # average size of arrays in the collection average_size_of_arrays = 10 for _ in range(500, 1000, 100): requests.append( - Query(pipeline=[{"$unwind": "$as"}], number_of_fields=average_size_of_arrays)) + Query( + pipeline=[{"$unwind": "$as"}], number_of_fields=average_size_of_arrays + ) + ) - await workload_execution.execute(database, main_config.workload_execution, collections, - requests) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests + ) -async def execute_unique(database: DatabaseInstance, collections: Sequence[CollectionInfo]): - collections = [ci for ci in collections if ci.name.startswith('c_arr_01')] +async def execute_unique( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collections = [ci for ci in collections if ci.name.startswith("c_arr_01")] requests = [] for i in range(500, 1000, 200): requests.append(Query(pipeline=[{"$match": {"as": {"$gt": i}}}])) - await workload_execution.execute(database, main_config.workload_execution, collections, - requests) + await workload_execution.execute( + database, main_config.workload_execution, collections, requests + ) -async def execute_limitskip(database: DatabaseInstance, collections: Sequence[CollectionInfo]): - collection = [ci for ci in collections if ci.name.startswith('index_scan')][0] +async def execute_limitskip( + database: DatabaseInstance, collections: Sequence[CollectionInfo] +): + collection = [ci for ci in collections if ci.name.startswith("index_scan")][0] limits = [5, 10, 15, 20] skips = [5, 10, 15, 20] requests = [] for limit in limits: for skip in skips: - requests.append(Query(pipeline=[{"$skip": skip}, {"$limit": limit}], note="LimitSkip")) + requests.append( + Query(pipeline=[{"$skip": skip}, {"$limit": limit}], note="LimitSkip") + ) - await workload_execution.execute(database, main_config.workload_execution, [collection], - requests) + await workload_execution.execute( + database, main_config.workload_execution, [collection], requests + ) async def main(): @@ -210,7 +278,6 @@ async def main(): # 1. Database Instance provides connectivity to a MongoDB instance, it loads data optionally # from the dump on creating and stores data optionally to the dump on closing. with DatabaseInstance(main_config.database) as database: - # 2. Data generation (optional), generates random data and populates collections with it. generator = DataGenerator(database, main_config.data_generator) await generator.populate_collections() @@ -235,16 +302,17 @@ async def main(): # aparses the explains, nd calibrates the cost model for the ABT nodes. models = await abt_calibrator.calibrate(main_config.abt_calibrator, database) for abt, model in models.items(): - print(f'{abt}\t\t{model}') + print(f"{abt}\t\t{model}") - parameters = await parameters_extractor.extract_parameters(main_config.abt_calibrator, - database, []) - save_to_csv(parameters, 'parameters.csv') + parameters = await parameters_extractor.extract_parameters( + main_config.abt_calibrator, database, [] + ) + save_to_csv(parameters, "parameters.csv") print("DONE!") -if __name__ == '__main__': +if __name__ == "__main__": loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: diff --git a/buildscripts/cost_model/workload_execution.py b/buildscripts/cost_model/workload_execution.py index d8c4cd9138e..325bcffe21b 100644 --- a/buildscripts/cost_model/workload_execution.py +++ b/buildscripts/cost_model/workload_execution.py @@ -38,7 +38,7 @@ from config import WorkloadExecutionConfig, WriteMode from data_generator import CollectionInfo from database_instance import DatabaseInstance, Pipeline -__all__ = ['execute'] +__all__ = ["execute"] @dataclass @@ -70,15 +70,19 @@ class QueryParameters: return QueryParameters(**json.loads(json_str)) -async def execute(database: DatabaseInstance, config: WorkloadExecutionConfig, - collection_infos: Sequence[CollectionInfo], queries: Sequence[Query]): +async def execute( + database: DatabaseInstance, + config: WorkloadExecutionConfig, + collection_infos: Sequence[CollectionInfo], + queries: Sequence[Query], +): """Run the given queries and write the collected explain into collection.""" if not config.enabled: return collector = WorkloadExecution(database, config) await collector.async_init() - print('>>> running queries') + print(">>> running queries") await collector.collect(collection_infos, queries) @@ -97,33 +101,46 @@ class WorkloadExecution: if self.config.write_mode == WriteMode.REPLACE: await self.database.drop_collection(self.config.output_collection_name) - async def collect(self, collection_infos: Sequence[CollectionInfo], queries: Sequence[Query]): + async def collect( + self, collection_infos: Sequence[CollectionInfo], queries: Sequence[Query] + ): """Run the given piplelines on the given collection to generate and collect execution statistics.""" measurements = [] for coll_info in collection_infos: - print(f'\n>>>>> running queries on collection {coll_info.name}') + print(f"\n>>>>> running queries on collection {coll_info.name}") for query in queries: - print(f'>>>>>>> running query {query.pipeline}') + print(f">>>>>>> running query {query.pipeline}") await self._run_query(coll_info, query, measurements) - await self.database.insert_many(self.config.output_collection_name, measurements) + await self.database.insert_many( + self.config.output_collection_name, measurements + ) - async def _run_query(self, coll_info: CollectionInfo, query: Query, result: Sequence): + async def _run_query( + self, coll_info: CollectionInfo, query: Query, result: Sequence + ): # warm up for _ in range(self.config.warmup_runs): await self.database.explain(coll_info.name, query.pipeline) run_id = ObjectId() avg_doc_size = await self.database.get_average_document_size(coll_info.name) - parameters = QueryParameters(keys_length_in_bytes=query.keys_length_in_bytes, - number_of_fields=query.number_of_fields, - average_document_size_in_bytes=avg_doc_size, note=query.note) + parameters = QueryParameters( + keys_length_in_bytes=query.keys_length_in_bytes, + number_of_fields=query.number_of_fields, + average_document_size_in_bytes=avg_doc_size, + note=query.note, + ) for _ in range(self.config.runs): explain = await self.database.explain(coll_info.name, query.pipeline) - if explain['ok'] == 1: - result.append({ - 'run_id': run_id, 'collection': coll_info.name, - 'pipeline': json.dumps(query.pipeline), 'explain': json.dumps(explain), - 'query_parameters': parameters.to_json() - }) + if explain["ok"] == 1: + result.append( + { + "run_id": run_id, + "collection": coll_info.name, + "pipeline": json.dumps(query.pipeline), + "explain": json.dumps(explain), + "query_parameters": parameters.to_json(), + } + ) diff --git a/buildscripts/create_bazel_test_report.py b/buildscripts/create_bazel_test_report.py index 7537b3fc598..7a18a30d330 100644 --- a/buildscripts/create_bazel_test_report.py +++ b/buildscripts/create_bazel_test_report.py @@ -55,13 +55,16 @@ def main(testlog_dir: str, silent_fail: bool = False): end = start + int(duration) report["results"].append( - Result({ - "test_file": test_file, - "status": status, - "start": start, - "end": end, - "log_raw": log_raw, - })) + Result( + { + "test_file": test_file, + "status": status, + "start": start, + "end": end, + "log_raw": log_raw, + } + ) + ) if report["results"]: with open("report.json", "wt") as fh: diff --git a/buildscripts/daily_task_scan.py b/buildscripts/daily_task_scan.py index 455cbbbaec4..4873e8c6028 100644 --- a/buildscripts/daily_task_scan.py +++ b/buildscripts/daily_task_scan.py @@ -24,8 +24,10 @@ def process_version(version: evergreen.Version) -> None: build_variant_name = build_variant.build_variant unsupported_variants = ["tsan", "asan", "aubsan", "macos"] - if any(unsupported_variant in build_variant_name - for unsupported_variant in unsupported_variants): + if any( + unsupported_variant in build_variant_name + for unsupported_variant in unsupported_variants + ): continue for task_id in build_variant.tasks: @@ -53,7 +55,9 @@ def main(expansions_file: str, output_file: str) -> int: is_patch = expansions.get("is_patch", False) today = datetime.datetime.utcnow().date() - start_of_today = datetime.datetime(today.year, today.month, today.day, tzinfo=tz.UTC) + start_of_today = datetime.datetime( + today.year, today.month, today.day, tzinfo=tz.UTC + ) # STM daily cron runs everyday at 4 AM # We scan the day before yesterday so we do not have to worry about in progress tasks assuming @@ -61,7 +65,7 @@ def main(expansions_file: str, output_file: str) -> int: start_of_window = start_of_today - datetime.timedelta(days=2) end_of_window = start_of_today - datetime.timedelta(days=1) - #We only care maintaining alerts as of v7.2 + # We only care maintaining alerts as of v7.2 mongo_projects = ["mongodb-mongo-master", "mongodb-mongo-master-nightly"] for project in evg_api.all_projects(): if not project.enabled: @@ -79,8 +83,9 @@ def main(expansions_file: str, output_file: str) -> int: with concurrent.futures.ThreadPoolExecutor(max_workers=cores) as executor: futures = [] for project in mongo_projects: - patches = evg_api.patches_by_project_time_window(project, end_of_window, - start_of_window) + patches = evg_api.patches_by_project_time_window( + project, end_of_window, start_of_window + ) # This covers all user patches generated with `evergreen patch` for patch in patches: if not patch.activated: @@ -91,20 +96,33 @@ def main(expansions_file: str, output_file: str) -> int: # This covers all automated versions that are run by crons and other stuff for requester in ["adhoc", "gitter_request", "github_pull_request"]: for version in evg_api.versions_by_project_time_window( - project, end_of_window, start_of_window, requester): + project, end_of_window, start_of_window, requester + ): futures.append(executor.submit(process_version, version=version)) concurrent.futures.wait(futures) errors = [] if timeouts_without_dumps: - errors.append("ERROR: The following tasks timed out without core dumps uploaded:") + errors.append( + "ERROR: The following tasks timed out without core dumps uploaded:" + ) errors.extend( - [f"https://spruce.mongodb.com/task/{task_id}" for task_id in timeouts_without_dumps]) + [ + f"https://spruce.mongodb.com/task/{task_id}" + for task_id in timeouts_without_dumps + ] + ) if passed_with_dumps: - errors.append("ERROR: The following tasks had core dumps uploaded while being successful:") + errors.append( + "ERROR: The following tasks had core dumps uploaded while being successful:" + ) errors.extend( - [f"https://spruce.mongodb.com/task/{task_id}" for task_id in passed_with_dumps]) + [ + f"https://spruce.mongodb.com/task/{task_id}" + for task_id in passed_with_dumps + ] + ) if not errors: return 0 @@ -124,11 +142,13 @@ def main(expansions_file: str, output_file: str) -> int: if timeouts_without_dumps: msg.append( - f"- {len(timeouts_without_dumps)} task(s) timed out without core dumps being uploaded") + f"- {len(timeouts_without_dumps)} task(s) timed out without core dumps being uploaded" + ) if passed_with_dumps: msg.append( - f"- {len(passed_with_dumps)} task(s) had core dumps uploaded while being successful") + f"- {len(passed_with_dumps)} task(s) had core dumps uploaded while being successful" + ) msg.append( f"For more details view the `Task Errors` file ." @@ -136,24 +156,33 @@ def main(expansions_file: str, output_file: str) -> int: if not is_patch: evg_api.send_slack_message( - target="#sdp-test-alerts", #TODO SERVER-83205: change to #sdp-triager + target="#sdp-test-alerts", # TODO SERVER-83205: change to #sdp-triager msg="\n".join(msg), ) - #TODO SERVER-83205: change to return 1 + # TODO SERVER-83205: change to return 1 return 0 if __name__ == "__main__": parser = argparse.ArgumentParser( - prog='DailyTaskScanner', - description='Iterates over all of the tasks in mongodb-mongo-master and ' - 'mongodb-mongo-master-nightly over a day and looks for certain errors to send ' - 'alerts of.') + prog="DailyTaskScanner", + description="Iterates over all of the tasks in mongodb-mongo-master and " + "mongodb-mongo-master-nightly over a day and looks for certain errors to send " + "alerts of.", + ) - parser.add_argument("--expansions-file", "-e", help="Expansions file to read task info from.", - default="../expansions.yml") - parser.add_argument("--output-file", "-f", help="File to output errors to.", - default="task_errors.txt") + parser.add_argument( + "--expansions-file", + "-e", + help="Expansions file to read task info from.", + default="../expansions.yml", + ) + parser.add_argument( + "--output-file", + "-f", + help="File to output errors to.", + default="task_errors.txt", + ) args = parser.parse_args() sys.exit(main(args.expansions_file, args.output_file)) diff --git a/buildscripts/debugsymb_mapper.py b/buildscripts/debugsymb_mapper.py index f3b3fddbca2..71137aa3a8e 100644 --- a/buildscripts/debugsymb_mapper.py +++ b/buildscripts/debugsymb_mapper.py @@ -1,4 +1,5 @@ """Script to generate & upload 'buildId -> debug symbols URL' mappings to symbolizer service.""" + import argparse import json import logging @@ -41,8 +42,13 @@ class CmdClient: :return: Command output. """ - out = subprocess.run(args, close_fds=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - check=False) + out = subprocess.run( + args, + close_fds=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + check=False, + ) return out.stdout.strip().decode() @@ -73,8 +79,11 @@ class BinVersionOutput(NamedTuple): class CmdOutputExtractor: """Data extractor from command output.""" - def __init__(self, cmd_client: Optional[CmdClient] = None, - json_decoder: Optional[JSONDecoder] = None) -> None: + def __init__( + self, + cmd_client: Optional[CmdClient] = None, + json_decoder: Optional[JSONDecoder] = None, + ) -> None: """ Initialize. @@ -117,10 +126,10 @@ class CmdOutputExtractor: build_id = None for line in out.splitlines(): line = line.strip() - if line.startswith('Build ID'): + if line.startswith("Build ID"): if build_id is not None: raise ValueError("Found multiple Build ID values.") - build_id = line.split(': ')[1] + build_id = line.split(": ")[1] return build_id def _get_mongodb_version(self, out: str) -> Optional[str]: @@ -143,8 +152,13 @@ class CmdOutputExtractor: class DownloadOptions(object): """A class to collect download option configurations.""" - def __init__(self, download_binaries=False, download_symbols=False, download_artifacts=False, - download_python_venv=False): + def __init__( + self, + download_binaries=False, + download_symbols=False, + download_artifacts=False, + download_python_venv=False, + ): """Initialize instance.""" self.download_binaries = download_binaries @@ -158,16 +172,26 @@ class Mapper: # This amount of attributes are necessary. - default_web_service_base_url: str = "https://symbolizer-service.server-tig.prod.corp.mongodb.com" - default_cache_dir = os.path.join(os.getcwd(), 'build', 'symbols_cache') - selected_binaries = ('mongos', 'mongod', 'mongo') + default_web_service_base_url: str = ( + "https://symbolizer-service.server-tig.prod.corp.mongodb.com" + ) + default_cache_dir = os.path.join(os.getcwd(), "build", "symbols_cache") + selected_binaries = ("mongos", "mongod", "mongo") default_client_credentials_scope = "servertig-symbolizer-fullaccess" default_client_credentials_user_name = "client-user" - default_creds_file_path = os.path.join(os.getcwd(), '.symbolizer_credentials.json') + default_creds_file_path = os.path.join(os.getcwd(), ".symbolizer_credentials.json") - def __init__(self, evg_version: str, evg_variant: str, is_san_variant: bool, client_id: str, - client_secret: str, cache_dir: str = None, web_service_base_url: str = None, - logger: logging.Logger = None): + def __init__( + self, + evg_version: str, + evg_variant: str, + is_san_variant: bool, + client_id: str, + client_secret: str, + cache_dir: str = None, + web_service_base_url: str = None, + logger: logging.Logger = None, + ): """ Initialize instance. @@ -184,11 +208,13 @@ class Mapper: self.evg_variant = evg_variant self.is_san_variant = is_san_variant self.cache_dir = cache_dir or self.default_cache_dir - self.web_service_base_url = web_service_base_url or self.default_web_service_base_url + self.web_service_base_url = ( + web_service_base_url or self.default_web_service_base_url + ) if not logger: logging.basicConfig() - logger = logging.getLogger('symbolizer') + logger = logging.getLogger("symbolizer") logger.setLevel(logging.INFO) self.logger = logger @@ -196,12 +222,15 @@ class Mapper: self.multiversion_setup = SetupMultiversion( DownloadOptions(download_symbols=True, download_binaries=True), - variant=self.evg_variant, ignore_failed_push=True) + variant=self.evg_variant, + ignore_failed_push=True, + ) self.debug_symbols_url = None self.url = None self.configs = Configs( client_credentials_scope=self.default_client_credentials_scope, - client_credentials_user_name=self.default_client_credentials_user_name) + client_credentials_user_name=self.default_client_credentials_user_name, + ) self.client_id = client_id self.client_secret = client_secret self.path_options = PathOptions() @@ -219,23 +248,34 @@ class Mapper: if os.path.exists(self.default_creds_file_path): with open(self.default_creds_file_path) as cfile: data = json.loads(cfile.read()) - access_token, expire_time = data.get("access_token"), data.get("expire_time") + access_token, expire_time = ( + data.get("access_token"), + data.get("expire_time"), + ) if time.time() < expire_time: # credentials haven't expired yet - self.http_client.headers.update({"Authorization": f"Bearer {access_token}"}) + self.http_client.headers.update( + {"Authorization": f"Bearer {access_token}"} + ) return - credentials = get_client_cred_oauth_credentials(self.client_id, self.client_secret, - configs=self.configs) - self.http_client.headers.update({"Authorization": f"Bearer {credentials.access_token}"}) + credentials = get_client_cred_oauth_credentials( + self.client_id, self.client_secret, configs=self.configs + ) + self.http_client.headers.update( + {"Authorization": f"Bearer {credentials.access_token}"} + ) # write credentials to local file for further usage with open(self.default_creds_file_path, "w") as cfile: cfile.write( - json.dumps({ - "access_token": credentials.access_token, - "expire_time": time.time() + credentials.expires_in - })) + json.dumps( + { + "access_token": credentials.access_token, + "expire_time": time.time() + credentials.expires_in, + } + ) + ) def __enter__(self): """Return instance when used as a context manager.""" @@ -261,7 +301,7 @@ class Mapper: :param url: download URL :return: full name for local file """ - return url.split('/')[-1] + return url.split("/")[-1] def setup_urls(self): """Set up URLs using multiversion.""" @@ -273,12 +313,16 @@ class Mapper: # Sanitizer builds are not stripped and contain debug symbols download_symbols_url = binaries_url else: - download_symbols_url = urlinfo.urls.get("mongo-debugsymbols.tgz") or urlinfo.urls.get( - "mongo-debugsymbols.zip") + download_symbols_url = urlinfo.urls.get( + "mongo-debugsymbols.tgz" + ) or urlinfo.urls.get("mongo-debugsymbols.zip") if not download_symbols_url: - self.logger.error("Couldn't find URL for debug symbols. Version: %s, URLs dict: %s", - self.evg_version, urlinfo.urls) + self.logger.error( + "Couldn't find URL for debug symbols. Version: %s, URLs dict: %s", + self.evg_version, + urlinfo.urls, + ) raise ValueError(f"Debug symbols URL not found. URLs dict: {urlinfo.urls}") self.debug_symbols_url = download_symbols_url @@ -291,7 +335,7 @@ class Mapper: :param path: full path of file :return: full path of directory of unpacked file """ - foldername = path.replace('.tgz', '', 1).split('/')[-1] + foldername = path.replace(".tgz", "", 1).split("/")[-1] out_dir = os.path.join(self.cache_dir, foldername) if not os.path.exists(out_dir): @@ -335,26 +379,36 @@ class Mapper: # shared libraries folder holds shared libraries, tons of them. # some build variants do not contain shared libraries. - binaries_unpacked_path = os.path.join(binaries_unpacked_path, 'dist-test') + binaries_unpacked_path = os.path.join(binaries_unpacked_path, "dist-test") - self.logger.info("INSIDE unpacked binaries/dist-test: %s", - os.listdir(binaries_unpacked_path)) + self.logger.info( + "INSIDE unpacked binaries/dist-test: %s", os.listdir(binaries_unpacked_path) + ) - mongod_bin = os.path.join(binaries_unpacked_path, self.path_options.main_binary_folder_name, - MONGOD) + mongod_bin = os.path.join( + binaries_unpacked_path, self.path_options.main_binary_folder_name, MONGOD + ) bin_version_output = extractor.get_bin_version(mongod_bin) if bin_version_output.mongodb_version is None: - self.logger.error("mongodb version could not be extracted. \n`%s --version` output: %s", - mongod_bin, bin_version_output.cmd_output) + self.logger.error( + "mongodb version could not be extracted. \n`%s --version` output: %s", + mongod_bin, + bin_version_output.cmd_output, + ) return else: - self.logger.info("Extracted mongodb version: %s", bin_version_output.mongodb_version) + self.logger.info( + "Extracted mongodb version: %s", bin_version_output.mongodb_version + ) # start with main binary folder for binary in self.selected_binaries: - full_bin_path = os.path.join(binaries_unpacked_path, - self.path_options.main_binary_folder_name, binary) + full_bin_path = os.path.join( + binaries_unpacked_path, + self.path_options.main_binary_folder_name, + binary, + ) if not os.path.exists(full_bin_path): self.logger.error("Could not find binary at %s", full_bin_path) @@ -363,35 +417,41 @@ class Mapper: build_id_output = extractor.get_build_id(full_bin_path) if not build_id_output.build_id: - self.logger.error("Build ID couldn't be extracted. \nReadELF output %s", - build_id_output.cmd_output) + self.logger.error( + "Build ID couldn't be extracted. \nReadELF output %s", + build_id_output.cmd_output, + ) continue else: self.logger.info("Extracted build ID: %s", build_id_output.build_id) yield { - 'url': self.url, - 'debug_symbols_url': self.debug_symbols_url, - 'build_id': build_id_output.build_id, - 'file_name': binary, - 'version': bin_version_output.mongodb_version, + "url": self.url, + "debug_symbols_url": self.debug_symbols_url, + "build_id": build_id_output.build_id, + "file_name": binary, + "version": bin_version_output.mongodb_version, } # move to shared libraries folder. # it contains all shared library binary files, # we run readelf on each of them. - lib_folder_path = os.path.join(binaries_unpacked_path, - self.path_options.shared_library_folder_name) + lib_folder_path = os.path.join( + binaries_unpacked_path, self.path_options.shared_library_folder_name + ) if not os.path.exists(lib_folder_path): # sometimes we don't get lib folder, which means there is no shared libraries for current build variant. - self.logger.info("'%s' folder does not exist.", - self.path_options.shared_library_folder_name) + self.logger.info( + "'%s' folder does not exist.", + self.path_options.shared_library_folder_name, + ) sofiles = [] else: sofiles = os.listdir(lib_folder_path) - self.logger.info("'%s' folder: %s", self.path_options.shared_library_folder_name, - sofiles) + self.logger.info( + "'%s' folder: %s", self.path_options.shared_library_folder_name, sofiles + ) for sofile in sofiles: sofile_path = os.path.join(lib_folder_path, sofile) @@ -403,18 +463,20 @@ class Mapper: build_id_output = extractor.get_build_id(sofile_path) if not build_id_output.build_id: - self.logger.error("Build ID couldn't be extracted. \nReadELF out %s", - build_id_output.cmd_output) + self.logger.error( + "Build ID couldn't be extracted. \nReadELF out %s", + build_id_output.cmd_output, + ) continue else: self.logger.info("Extracted build ID: %s", build_id_output.build_id) yield { - 'url': self.url, - 'debug_symbols_url': self.debug_symbols_url, - 'build_id': build_id_output.build_id, - 'file_name': sofile, - 'version': bin_version_output.mongodb_version, + "url": self.url, + "debug_symbols_url": self.debug_symbols_url, + "build_id": build_id_output.build_id, + "file_name": sofile, + "version": bin_version_output.mongodb_version, } def run(self): @@ -428,12 +490,17 @@ class Mapper: # mappings is a generator, we iterate over to generate mappings on the go for mapping in mappings: self.logger.info("Creating mapping %s", mapping) - response = self.http_client.post('/'.join((self.web_service_base_url, 'add')), - json=mapping) + response = self.http_client.post( + "/".join((self.web_service_base_url, "add")), json=mapping + ) if response.status_code != 200: self.logger.error( "Could not store mapping, web service returned status code %s from URL %s. " - "Response: %s", response.status_code, response.url, response.text) + "Response: %s", + response.status_code, + response.url, + response.text, + ) def make_argument_parser(parser=None, **kwargs): @@ -454,10 +521,14 @@ def make_argument_parser(parser=None, **kwargs): def main(options): """Execute mapper here. Main entry point.""" - mapper = Mapper(evg_version=options.version, evg_variant=options.variant, - is_san_variant=options.is_san_variant, client_id=options.client_id, - client_secret=options.client_secret, - web_service_base_url=options.web_service_base_url) + mapper = Mapper( + evg_version=options.version, + evg_variant=options.variant, + is_san_variant=options.is_san_variant, + client_id=options.client_id, + client_secret=options.client_secret, + web_service_base_url=options.web_service_base_url, + ) # when used as a context manager, mapper instance automatically cleans files/folders after finishing its job. # in other cases, mapper.cleanup() method should be called manually. @@ -465,6 +536,6 @@ def main(options): mapper.run() -if __name__ == '__main__': +if __name__ == "__main__": mapper_options = make_argument_parser(description=__doc__).parse_args() main(mapper_options) diff --git a/buildscripts/deflakinator.py b/buildscripts/deflakinator.py index ad16cb94acd..09b8aef2474 100755 --- a/buildscripts/deflakinator.py +++ b/buildscripts/deflakinator.py @@ -9,8 +9,14 @@ _MAX_RUNS = 1000 def print_pb_version_ids(run_id: str): - result = subprocess.run(f"evergreen list-patches -j -n {_MAX_RUNS * 2}", shell=True, check=True, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + result = subprocess.run( + f"evergreen list-patches -j -n {_MAX_RUNS * 2}", + shell=True, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) data = json.loads(result.stdout) filtered_objects = [ @@ -26,7 +32,7 @@ def print_pb_version_ids(run_id: str): "For analysis, visit", f"https://ui.honeycomb.io/mongodb-4b/environments/production?query=%7B%22time_range%22%3A86400%2C%22granularity%22%3A0%2C%22breakdowns%22%3A%5B%22evergreen.task.name%22%5D%2C%22calculations%22%3A%5B%7B%22op%22%3A%22COUNT%22%7D%5D%2C%22filters%22%3A%5B%7B%22column%22%3A%22evergreen.task.name%22%2C%22op%22%3A%22!%3D%22%2C%22value%22%3A%22lint_repo%22%7D%2C%7B%22column%22%3A%22evergreen.project.id%22%2C%22op%22%3A%22%3D%22%2C%22value%22%3A%22mongodb-mongo-master%22%7D%2C%7B%22column%22%3A%22evergreen.task.status%22%2C%22op%22%3A%22%3D%22%2C%22value%22%3A%22failed%22%7D%2C%7B%22column%22%3A%22evergreen.version.id%22%2C%22op%22%3A%22in%22%2C%22value%22%3A%5B%22\ {'%22%2C%22'.join(versions)}\ -%22%5D%7D%5D%2C%22filter_combination%22%3A%22AND%22%2C%22orders%22%3A%5B%7B%22op%22%3A%22COUNT%22%2C%22order%22%3A%22descending%22%7D%5D%2C%22havings%22%3A%5B%5D%2C%22limit%22%3A1000%7D" +%22%5D%7D%5D%2C%22filter_combination%22%3A%22AND%22%2C%22orders%22%3A%5B%7B%22op%22%3A%22COUNT%22%2C%22order%22%3A%22descending%22%7D%5D%2C%22havings%22%3A%5B%5D%2C%22limit%22%3A1000%7D", ) print("---") @@ -38,11 +44,17 @@ def deflakinator(runs: int, evergreen_args: str): print("Kicking off evergreen patch builds:") print("---") - command = f"evergreen patch {evergreen_args} --yes --finalize --description \"flakinator run id: {run_id}\"" + command = f'evergreen patch {evergreen_args} --yes --finalize --description "flakinator run id: {run_id}"' for _ in range(runs): print(command) - subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, text=True) + subprocess.run( + command, + shell=True, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) print_pb_version_ids(run_id) @@ -50,9 +62,16 @@ def deflakinator(runs: int, evergreen_args: str): def main(): """Run the main function.""" parser = argparse.ArgumentParser() - parser.add_argument("--runs", type=int, choices=range(1, _MAX_RUNS), metavar=f"[1-{_MAX_RUNS}]", - help="Number of times to run each patch build") - parser.add_argument("--evergreen-args", type=str, help="Arguments to pass to evergreen patch") + parser.add_argument( + "--runs", + type=int, + choices=range(1, _MAX_RUNS), + metavar=f"[1-{_MAX_RUNS}]", + help="Number of times to run each patch build", + ) + parser.add_argument( + "--evergreen-args", type=str, help="Arguments to pass to evergreen patch" + ) args = parser.parse_args() deflakinator(args.runs, args.evergreen_args) diff --git a/buildscripts/download_buildifier.py b/buildscripts/download_buildifier.py index 780de33cc73..49c022a60df 100644 --- a/buildscripts/download_buildifier.py +++ b/buildscripts/download_buildifier.py @@ -39,17 +39,23 @@ def main(): operating_system = determine_platform() architechture = determine_architecture() if operating_system == "windows" and architechture == "arm64": - raise RuntimeError("There are no published arm windows releases for buildifier.") + raise RuntimeError( + "There are no published arm windows releases for buildifier." + ) parser = argparse.ArgumentParser( - prog='DownloadBuildifier', - description='This downloads buildifier, it is intended for use in evergreen.' - 'This is our temperary solution to get bazel linting while we determine a ' - 'long-term solution for getting buildifier on evergreen/development hosts.') + prog="DownloadBuildifier", + description="This downloads buildifier, it is intended for use in evergreen." + "This is our temperary solution to get bazel linting while we determine a " + "long-term solution for getting buildifier on evergreen/development hosts.", + ) - parser.add_argument("--download-location", "-f", - help="Name of directory to download the buildifier binary to.", - default="./") + parser.add_argument( + "--download-location", + "-f", + help="Name of directory to download the buildifier binary to.", + default="./", + ) args = parser.parse_args() diff --git a/buildscripts/download_sys_perf_binaries.py b/buildscripts/download_sys_perf_binaries.py index c4ce78a63f2..2bce343d243 100644 --- a/buildscripts/download_sys_perf_binaries.py +++ b/buildscripts/download_sys_perf_binaries.py @@ -5,63 +5,67 @@ import argparse import requests -BASE_URI = 'https://evergreen.mongodb.com/rest/v2/' +BASE_URI = "https://evergreen.mongodb.com/rest/v2/" def _get_auth_headers(evergreen_api_user, evergreen_api_key): return { - 'Api-User': evergreen_api_user, - 'Api-Key': evergreen_api_key, + "Api-User": evergreen_api_user, + "Api-Key": evergreen_api_key, } def _get_build_id(build_variant_name, version_id, auth_headers): - url = BASE_URI + 'versions/' + version_id + url = BASE_URI + "versions/" + version_id response = requests.get(url, headers=auth_headers) if response.status_code != 200: - raise ValueError('Invalid version_id:', version_id) + raise ValueError("Invalid version_id:", version_id) version_json = response.json() - build_variants = version_json['build_variants_status'] + build_variants = version_json["build_variants_status"] for build_variant in build_variants: - if build_variant['build_variant'] == build_variant_name: - return build_variant['build_id'] + if build_variant["build_variant"] == build_variant_name: + return build_variant["build_id"] - raise RuntimeError('The compile-variant ' + build_variant_name + - ' does not exist for the build with version_id ' + version_id) + raise RuntimeError( + "The compile-variant " + + build_variant_name + + " does not exist for the build with version_id " + + version_id + ) # All of our sys-perf compile variants have exactly one task, # so we can safely select the first one from the build variants tasks. def _get_task_id(build_id, auth_headers): - url = BASE_URI + 'builds/' + build_id + url = BASE_URI + "builds/" + build_id response = requests.get(url, headers=auth_headers) if response.status_code != 200: - raise RuntimeError('Unexpected error when trying to reach' + url) + raise RuntimeError("Unexpected error when trying to reach" + url) build_json = response.json() - task_list = build_json['tasks'] + task_list = build_json["tasks"] if len(task_list) != 1: - raise RuntimeError('Recieved unexpected tasklist:', task_list) + raise RuntimeError("Recieved unexpected tasklist:", task_list) return task_list[0] # The API used here always grabs the latest execution def _get_binary_details(task_id, auth_headers): - url = BASE_URI + 'tasks/' + task_id + url = BASE_URI + "tasks/" + task_id response = requests.get(url, headers=auth_headers) if response.status_code != 200: - raise RuntimeError('Unexpected error when trying to reach' + url) + raise RuntimeError("Unexpected error when trying to reach" + url) task_json = response.json() - if task_json['status'] != 'success': - raise RuntimeError('The task ' + task_id + ' did not run sucessfully') + if task_json["status"] != "success": + raise RuntimeError("The task " + task_id + " did not run sucessfully") # The binary will always be the first artifact, unless we make large changes to system_perf.yml - artifacts = task_json['artifacts'] - if len(artifacts) > 0 and artifacts[0]['name'].startswith('mongo'): + artifacts = task_json["artifacts"] + if len(artifacts) > 0 and artifacts[0]["name"].startswith("mongo"): return artifacts[0] - raise RuntimeError('Unexpected list of artifacts:' + artifacts) + raise RuntimeError("Unexpected list of artifacts:" + artifacts) def _get_binary_url(version_id, build_variant, evergreen_api_user, evergreen_api_key): @@ -69,36 +73,55 @@ def _get_binary_url(version_id, build_variant, evergreen_api_user, evergreen_api build_id = _get_build_id(build_variant, version_id, auth_headers) task_id = _get_task_id(build_id, auth_headers) binary_json = _get_binary_details(task_id, auth_headers) - return binary_json['url'] + return binary_json["url"] def _download_binary_file(url, save_path): response = requests.get(url, stream=True) if response.status_code == 200: - with open(save_path, 'wb') as file: + with open(save_path, "wb") as file: for chunk in response.iter_content(chunk_size=1024): file.write(chunk) else: - raise RuntimeError('Failed to download the file ' + url) + raise RuntimeError("Failed to download the file " + url) -def _download_sys_perf_binaries(version_id, build_variant, evergreen_api_user, evergreen_api_key): - url = _get_binary_url(version_id, build_variant, evergreen_api_user, evergreen_api_key) - _download_binary_file(url, 'binary.tar.gz') +def _download_sys_perf_binaries( + version_id, build_variant, evergreen_api_user, evergreen_api_key +): + url = _get_binary_url( + version_id, build_variant, evergreen_api_user, evergreen_api_key + ) + _download_binary_file(url, "binary.tar.gz") -if __name__ == '__main__': +if __name__ == "__main__": argParser = argparse.ArgumentParser() - argParser.add_argument("-v", "--version_id", - help="Evergreen version_id from which binaries will be downloaded") - argParser.add_argument("-b", "--build_variant", - help="Build variant for which binaries will be downloaded") argParser.add_argument( - "-u", "--evergreen_api_user", - help="Evergreen API user, see https://spruce.mongodb.com/preferences/cli") - argParser.add_argument("-k", "--evergreen_api_key", - help="Evergreen API key, see https://spruce.mongodb.com/preferences/cli") + "-v", + "--version_id", + help="Evergreen version_id from which binaries will be downloaded", + ) + argParser.add_argument( + "-b", + "--build_variant", + help="Build variant for which binaries will be downloaded", + ) + argParser.add_argument( + "-u", + "--evergreen_api_user", + help="Evergreen API user, see https://spruce.mongodb.com/preferences/cli", + ) + argParser.add_argument( + "-k", + "--evergreen_api_key", + help="Evergreen API key, see https://spruce.mongodb.com/preferences/cli", + ) args = argParser.parse_args() - _download_sys_perf_binaries(args.version_id, args.build_variant, args.evergreen_api_user, - args.evergreen_api_key) + _download_sys_perf_binaries( + args.version_id, + args.build_variant, + args.evergreen_api_user, + args.evergreen_api_key, + ) diff --git a/buildscripts/errorcodes.py b/buildscripts/errorcodes.py index 77db926726c..954f38a0590 100755 --- a/buildscripts/errorcodes.py +++ b/buildscripts/errorcodes.py @@ -28,12 +28,15 @@ MAXIMUM_CODE = 99999999 # JIRA Ticket + XX codes = [] # type: ignore # Each AssertLocation identifies the C++ source location of an assertion -AssertLocation = namedtuple("AssertLocation", ['sourceFile', 'byteOffset', 'lines', 'code']) +AssertLocation = namedtuple( + "AssertLocation", ["sourceFile", "byteOffset", "lines", "code"] +) list_files = False # pylint: disable=invalid-name _CODE_PATTERNS = [ - re.compile(p + r'\s*(?P\d+)', re.MULTILINE) for p in [ + re.compile(p + r"\s*(?P\d+)", re.MULTILINE) + for p in [ # All the asserts and their optional variant suffixes r"(?:f|i|m|msg|t|u)(?:assert)" r"(?:ed)?" @@ -56,20 +59,22 @@ _CODE_PATTERNS = [ ] ] -_DIR_EXCLUDE_RE = re.compile(r'(\..*' - r'|pcre2.*' - r'|32bit.*' - r'|mongodb-.*' - r'|debian.*' - r'|mongo-cxx-driver.*' - r'|.*gotools.*' - r'|.*mozjs.*' - r')') +_DIR_EXCLUDE_RE = re.compile( + r"(\..*" + r"|pcre2.*" + r"|32bit.*" + r"|mongodb-.*" + r"|debian.*" + r"|mongo-cxx-driver.*" + r"|.*gotools.*" + r"|.*mozjs.*" + r")" +) -_FILE_INCLUDE_RE = re.compile(r'.*\.(cpp|c|h|py|idl)') +_FILE_INCLUDE_RE = re.compile(r".*\.(cpp|c|h|py|idl)") -def get_all_source_files(prefix='.'): +def get_all_source_files(prefix="."): """Return source files.""" def walk(path): @@ -92,8 +97,8 @@ def foreach_source_file(callback, src_root): """Invoke a callback on the text of each source file.""" for source_file in get_all_source_files(prefix=src_root): if list_files: - print('scanning file: ' + source_file) - with open(source_file, 'r', encoding='utf-8') as fh: + print("scanning file: " + source_file) + with open(source_file, "r", encoding="utf-8") as fh: callback(source_file, fh.read()) @@ -106,8 +111,12 @@ def parse_source_files(callback, src_root): # Note that this will include the text of the full match but will report the # position of the beginning of the code portion rather than the beginning of the # match. This is to position editors on the spot that needs to change. - loc = AssertLocation(source_file, match.start('code'), match.group(0), - match.group('code')) + loc = AssertLocation( + source_file, + match.start("code"), + match.group(0), + match.group("code"), + ) callback(loc) foreach_source_file(scan_for_codes, src_root) @@ -134,7 +143,7 @@ def get_line_and_column_for_position(loc, _file_cache=None): def is_terminated(lines): """Determine if assert is terminated, from .cpp/.h source lines as text.""" code_block = " ".join(lines) - return ';' in code_block or code_block.count('(') - code_block.count(')') <= 0 + return ";" in code_block or code_block.count("(") - code_block.count(")") <= 0 def get_next_code(seen, server_ticket=0): @@ -167,7 +176,7 @@ def check_error_codes(): return len(errors) == 0 -def read_error_codes(src_root='src/mongo'): +def read_error_codes(src_root="src/mongo"): """Define callback, call parse_source_files() with callback, save matches to global codes list.""" seen = {} errors = [] @@ -247,7 +256,9 @@ def replace_bad_codes(errors, next_code_generator): for loc in skip_errors: line, col = get_line_and_column_for_position(loc) - print("SKIPPING NONZERO code=%s: %s:%d:%d" % (loc.code, loc.sourceFile, line, col)) + print( + "SKIPPING NONZERO code=%s: %s:%d:%d" % (loc.code, loc.sourceFile, line, col) + ) # Dedupe, sort, and reverse so we don't have to update offsets as we go. for assert_loc in reversed(sorted(set(zero_errors))): @@ -257,16 +268,16 @@ def replace_bad_codes(errors, next_code_generator): ln = line_num - 1 - with open(source_file, 'r+') as fh: + with open(source_file, "r+") as fh: print("LINE_%d_BEFORE:%s" % (line_num, fh.readlines()[ln].rstrip())) fh.seek(0) text = fh.read() - assert text[byte_offset] == '0' + assert text[byte_offset] == "0" fh.seek(0) fh.write(text[:byte_offset]) fh.write(str(next(next_code_generator))) - fh.write(text[byte_offset + 1:]) + fh.write(text[byte_offset + 1 :]) fh.seek(0) print("LINE_%d_AFTER :%s" % (line_num, fh.readlines()[ln].rstrip())) @@ -281,7 +292,7 @@ def coerce_to_number(ticket_value): if isinstance(ticket_value, int): return ticket_value - ticket_re = re.compile(r'(?:SERVER-)?(\d+)', re.IGNORECASE) + ticket_re = re.compile(r"(?:SERVER-)?(\d+)", re.IGNORECASE) matches = ticket_re.fullmatch(ticket_value) if not matches: print("Unknown ticket number. Input: " + ticket_value) @@ -293,16 +304,37 @@ def coerce_to_number(ticket_value): def main(): """Validate error codes.""" parser = OptionParser(description=__doc__.strip()) - parser.add_option("--fix", dest="replace", action="store_true", default=False, - help="Fix zero codes in source files [default: %default]") - parser.add_option("-q", "--quiet", dest="quiet", action="store_true", default=False, - help="Suppress output on success [default: %default]") - parser.add_option("--list-files", dest="list_files", action="store_true", default=False, - help="Print the name of each file as it is scanned [default: %default]") parser.add_option( - "--ticket", dest="ticket", type="str", action="store", default=None, + "--fix", + dest="replace", + action="store_true", + default=False, + help="Fix zero codes in source files [default: %default]", + ) + parser.add_option( + "-q", + "--quiet", + dest="quiet", + action="store_true", + default=False, + help="Suppress output on success [default: %default]", + ) + parser.add_option( + "--list-files", + dest="list_files", + action="store_true", + default=False, + help="Print the name of each file as it is scanned [default: %default]", + ) + parser.add_option( + "--ticket", + dest="ticket", + type="str", + action="store", + default=None, help="Generate error codes for a given SERVER ticket number. Inputs can be of" - " the form: `--ticket=12345` or `--ticket=SERVER-12345`.") + " the form: `--ticket=12345` or `--ticket=SERVER-12345`.", + ) options, extra = parser.parse_args() if extra: parser.error(f"Unrecognized arguments: {' '.join(extra)}") diff --git a/buildscripts/evergreen_activate_gen_tasks.py b/buildscripts/evergreen_activate_gen_tasks.py index f8611d3f3ed..e492bbb9229 100755 --- a/buildscripts/evergreen_activate_gen_tasks.py +++ b/buildscripts/evergreen_activate_gen_tasks.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Activate an evergreen task in the existing build.""" + import os import sys @@ -63,7 +64,8 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None: if expansions.task == BURN_IN_TAGS: version = evg_api.version_by_id(expansions.version_id) burn_in_build_variants = [ - variant for variant in version.build_variants_map.keys() + variant + for variant in version.build_variants_map.keys() if variant.endswith(BURN_IN_VARIANT_SUFFIX) ] for build_variant in burn_in_build_variants: @@ -72,29 +74,42 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None: for task in task_list: if task.display_name == BURN_IN_TESTS: - LOGGER.info("Activating task", task_id=task.task_id, - task_name=task.display_name) + LOGGER.info( + "Activating task", + task_id=task.task_id, + task_name=task.display_name, + ) try: evg_api.configure_task(task.task_id, activated=True) except Exception: - LOGGER.warning("Could not activate task", task_id=task.task_id, - task_name=task.display_name) + LOGGER.warning( + "Could not activate task", + task_id=task.task_id, + task_name=task.display_name, + ) tasks_not_activated.append(task.task_id) else: task_list = evg_api.tasks_by_build(expansions.build_id) for task in task_list: if task.display_name == expansions.task: - LOGGER.info("Activating task", task_id=task.task_id, task_name=task.display_name) + LOGGER.info( + "Activating task", task_id=task.task_id, task_name=task.display_name + ) try: evg_api.configure_task(task.task_id, activated=True) except Exception: - LOGGER.warning("Could not activate task", task_id=task.task_id, - task_name=task.display_name) + LOGGER.warning( + "Could not activate task", + task_id=task.task_id, + task_name=task.display_name, + ) tasks_not_activated.append(task.task_id) if len(tasks_not_activated) > 0: - LOGGER.error("Some tasks were unable to be activated", - unactivated_tasks=len(tasks_not_activated)) + LOGGER.error( + "Some tasks were unable to be activated", + unactivated_tasks=len(tasks_not_activated), + ) raise ValueError( "Some tasks were unable to be activated, failing the task to let the author know. " "This should not be a blocking issue but may mean that some tasks are missing from your patch." @@ -102,10 +117,18 @@ def activate_task(expansions: EvgExpansions, evg_api: EvergreenApi) -> None: @click.command() -@click.option("--expansion-file", type=str, required=True, - help="Location of expansions file generated by evergreen.") -@click.option("--evergreen-config", type=str, default=EVG_CONFIG_FILE, - help="Location of evergreen configuration file.") +@click.option( + "--expansion-file", + type=str, + required=True, + help="Location of expansions file generated by evergreen.", +) +@click.option( + "--evergreen-config", + type=str, + default=EVG_CONFIG_FILE, + help="Location of evergreen configuration file.", +) @click.option("--verbose", is_flag=True, default=False, help="Enable verbose logging.") def main(expansion_file: str, evergreen_config: str, verbose: bool) -> None: """ diff --git a/buildscripts/evergreen_expansions2bash.py b/buildscripts/evergreen_expansions2bash.py index 858b6d648e0..8fe9a1163c1 100644 --- a/buildscripts/evergreen_expansions2bash.py +++ b/buildscripts/evergreen_expansions2bash.py @@ -1,4 +1,5 @@ """Convert Evergreen's expansions.yml to an eval-able shell script.""" + import sys from shlex import quote @@ -12,16 +13,20 @@ try: import click import yaml except ModuleNotFoundError: - _error("ERROR: Failed to import a dependency. This is almost certainly because " - "the task did not initialize the venv immediately after cloning the repository.") + _error( + "ERROR: Failed to import a dependency. This is almost certainly because " + "the task did not initialize the venv immediately after cloning the repository." + ) def _load_defaults(defaults_file: str) -> dict: with open(defaults_file) as fh: defaults = yaml.safe_load(fh) if not isinstance(defaults, dict): - _error("ERROR: expected to read a dictionary. expansions.defaults.yml" - "must be a dictionary. Check the indentation.") + _error( + "ERROR: expected to read a dictionary. expansions.defaults.yml" + "must be a dictionary. Check the indentation." + ) # expansions MUST be strings. Reject any that are not bad_expansions = set() @@ -30,11 +35,13 @@ def _load_defaults(defaults_file: str) -> dict: bad_expansions.add(key) if bad_expansions: - _error("ERROR: all default expansions must be strings. You can " - " fix this error by quoting the values in expansions.defaults.yml. " - "Integers, floating points, 'true', 'false', and 'null' " - "must be quoted. The following keys were interpreted as " - f"other types: {bad_expansions}") + _error( + "ERROR: all default expansions must be strings. You can " + " fix this error by quoting the values in expansions.defaults.yml. " + "Integers, floating points, 'true', 'false', and 'null' " + "must be quoted. The following keys were interpreted as " + f"other types: {bad_expansions}" + ) # These values show up if 1. Python's str is used to naively convert # a boolean to str, 2. A human manually entered one of those strings. @@ -48,11 +55,13 @@ def _load_defaults(defaults_file: str) -> dict: risky_boolean_keys.add(key) if risky_boolean_keys: - _error("ERROR: Found keys which had 'True' or 'False' as values. " - "Shell scripts assume that booleans are represented as 'true'" - " or 'false' (leading lowercase). If you added a new boolean, " - "ensure that it's represented in lowercase. If not, please report this in " - f"#server-testing. Keys with bad values: {risky_boolean_keys}") + _error( + "ERROR: Found keys which had 'True' or 'False' as values. " + "Shell scripts assume that booleans are represented as 'true'" + " or 'false' (leading lowercase). If you added a new boolean, " + "ensure that it's represented in lowercase. If not, please report this in " + f"#server-testing. Keys with bad values: {risky_boolean_keys}" + ) return defaults @@ -62,8 +71,10 @@ def _load_expansions(expansions_file) -> dict: expansions = yaml.safe_load(fh) if not isinstance(expansions, dict): - _error("ERROR: expected to read a dictionary. Has the output format " - "of expansions.write changed?") + _error( + "ERROR: expected to read a dictionary. Has the output format " + "of expansions.write changed?" + ) if not expansions: _error("ERROR: found 0 expansions. This is almost certainly wrong.") diff --git a/buildscripts/evergreen_gen_powercycle_tasks.py b/buildscripts/evergreen_gen_powercycle_tasks.py index 0d329b19ba5..26d053ef631 100755 --- a/buildscripts/evergreen_gen_powercycle_tasks.py +++ b/buildscripts/evergreen_gen_powercycle_tasks.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Generate multiple powercycle tasks to run in evergreen.""" + from collections import namedtuple from typing import Any, List, Set, Tuple @@ -18,17 +19,20 @@ from buildscripts.util.fileops import write_file from buildscripts.util.read_config import read_config_file from buildscripts.util.taskname import name_generated_task -Config = namedtuple("config", [ - "current_task_name", - "task_names", - "num_tasks", - "timeout_params", - "remote_credentials_vars", - "set_up_ec2_instance_vars", - "run_powercycle_vars", - "build_variant", - "distro", -]) +Config = namedtuple( + "config", + [ + "current_task_name", + "task_names", + "num_tasks", + "timeout_params", + "remote_credentials_vars", + "set_up_ec2_instance_vars", + "run_powercycle_vars", + "build_variant", + "distro", + ], +) def make_config(expansions_file: Any) -> Config: @@ -58,8 +62,17 @@ def make_config(expansions_file: Any) -> Config: build_variant = expansions.get("build_variant") distro = expansions.get("distro_id") - return Config(current_task_name, task_names, num_tasks, timeout_params, remote_credentials_vars, - set_up_ec2_instance_vars, run_powercycle_vars, build_variant, distro) + return Config( + current_task_name, + task_names, + num_tasks, + timeout_params, + remote_credentials_vars, + set_up_ec2_instance_vars, + run_powercycle_vars, + build_variant, + distro, + ) def get_setup_commands() -> Tuple[List[FunctionCall], Set[TaskDependency]]: @@ -87,8 +100,9 @@ def get_skip_compile_setup_commands() -> Tuple[List[FunctionCall], set]: @click.command() @click.argument("expansions_file", type=str, default="expansions.yml") @click.argument("output_file", type=str, default="powercycle_tasks.json") -def main(expansions_file: str = "expansions.yml", - output_file: str = "powercycle_tasks.json") -> None: +def main( + expansions_file: str = "expansions.yml", output_file: str = "powercycle_tasks.json" +) -> None: """Generate multiple powercycle tasks to run in evergreen.""" config = make_config(expansions_file) @@ -101,28 +115,41 @@ def main(expansions_file: str = "expansions.yml", else: commands, task_dependency = get_setup_commands() - commands.extend([ - FunctionCall("set up remote credentials", config.remote_credentials_vars), - BuiltInCommand("timeout.update", config.timeout_params), - FunctionCall("set up EC2 instance", config.set_up_ec2_instance_vars), - FunctionCall("run powercycle test", config.run_powercycle_vars), - ]) + commands.extend( + [ + FunctionCall( + "set up remote credentials", config.remote_credentials_vars + ), + BuiltInCommand("timeout.update", config.timeout_params), + FunctionCall("set up EC2 instance", config.set_up_ec2_instance_vars), + FunctionCall("run powercycle test", config.run_powercycle_vars), + ] + ) - sub_tasks.update({ - Task( - name_generated_task(task_name, index, config.num_tasks, config.build_variant), - commands, task_dependency) - for index in range(config.num_tasks) - }) + sub_tasks.update( + { + Task( + name_generated_task( + task_name, index, config.num_tasks, config.build_variant + ), + commands, + task_dependency, + ) + for index in range(config.num_tasks) + } + ) build_variant.display_task( - config.current_task_name.replace("_gen", ""), sub_tasks, distros=[config.distro], - execution_existing_tasks={ExistingTask(config.current_task_name)}) + config.current_task_name.replace("_gen", ""), + sub_tasks, + distros=[config.distro], + execution_existing_tasks={ExistingTask(config.current_task_name)}, + ) shrub_project = ShrubProject.empty() shrub_project.add_build_variant(build_variant) write_file(output_file, shrub_project.json()) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/evergreen_resmoke_job_count.py b/buildscripts/evergreen_resmoke_job_count.py index f4da198a52d..75ba78e7176 100644 --- a/buildscripts/evergreen_resmoke_job_count.py +++ b/buildscripts/evergreen_resmoke_job_count.py @@ -30,40 +30,54 @@ SYS_PLATFORM = sys.platform # Apply factor for a task based on the build variant it is running on. VARIANT_TASK_FACTOR_OVERRIDES = { - "enterprise-rhel-8-64-bit": [{"task": r"logical_session_cache_replication.*", "factor": 0.75}], + "enterprise-rhel-8-64-bit": [ + {"task": r"logical_session_cache_replication.*", "factor": 0.75} + ], "enterprise-rhel-8-64-bit-inmem": [ {"task": "secondary_reads_passthrough", "factor": 0.3}, {"task": "multi_stmt_txn_jscore_passthrough_with_migration", "factor": 0.3}, ], - "enterprise-rhel8-debug-tsan": [{ - "task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.125 - }], - "rhel8-debug-aubsan-classic-engine": [{ - "task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25 - }], - "rhel8-debug-aubsan-all-feature-flags": [{ - "task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25 - }], + "enterprise-rhel8-debug-tsan": [ + {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.125} + ], + "rhel8-debug-aubsan-classic-engine": [ + {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25} + ], + "rhel8-debug-aubsan-all-feature-flags": [ + {"task": r"shard.*uninitialized_fcv_jscore_passthrough.*", "factor": 0.25} + ], # TODO(SERVER-91466): figure out why noPassthrough tests are taking up more memory after switching # from Windows Server 2019 to Windows Server 2022 - "enterprise-windows-all-feature-flags-required": [{"task": "noPassthrough", "factor": 0.5}], + "enterprise-windows-all-feature-flags-required": [ + {"task": "noPassthrough", "factor": 0.5} + ], "enterprise-windows": [{"task": "noPassthrough", "factor": 0.5}], "windows-debug-suggested": [{"task": "noPassthrough", "factor": 0.5}], "windows": [{"task": "noPassthrough", "factor": 0.5}], } -TASKS_FACTORS = [{"task": r"replica_sets.*", "factor": 0.5}, {"task": r"sharding.*", "factor": 0.5}] +TASKS_FACTORS = [ + {"task": r"replica_sets.*", "factor": 0.5}, + {"task": r"sharding.*", "factor": 0.5}, +] DISTRO_MULTIPLIERS = {"rhel8.8-large": 1.618} # Apply factor for a task based on the machine type it is running on. MACHINE_TASK_FACTOR_OVERRIDES = { - "aarch64": - TASKS_FACTORS, + "aarch64": TASKS_FACTORS, "ppc64le": [ - dict(task=r"causally_consistent_hedged_reads_jscore_passthrough.*", factor=0.125), - dict(task=r"causally_consistent_read_concern_snapshot_passthrough.*", factor=0.125), - dict(task=r"sharded_causally_consistent_read_concern_snapshot_passthrough.*", factor=0.125), + dict( + task=r"causally_consistent_hedged_reads_jscore_passthrough.*", factor=0.125 + ), + dict( + task=r"causally_consistent_read_concern_snapshot_passthrough.*", + factor=0.125, + ), + dict( + task=r"sharded_causally_consistent_read_concern_snapshot_passthrough.*", + factor=0.125, + ), ], } @@ -117,8 +131,12 @@ def determine_final_multiplier(distro): def determine_factor(task_name, variant, distro, factor): """Determine the job factor.""" factors = [ - get_task_factor(task_name, MACHINE_TASK_FACTOR_OVERRIDES, PLATFORM_MACHINE, factor), - get_task_factor(task_name, PLATFORM_TASK_FACTOR_OVERRIDES, SYS_PLATFORM, factor), + get_task_factor( + task_name, MACHINE_TASK_FACTOR_OVERRIDES, PLATFORM_MACHINE, factor + ), + get_task_factor( + task_name, PLATFORM_TASK_FACTOR_OVERRIDES, SYS_PLATFORM, factor + ), get_task_factor(task_name, VARIANT_TASK_FACTOR_OVERRIDES, variant, factor), global_task_factor(task_name, GLOBAL_TASK_FACTOR_OVERRIDES, factor), ] @@ -153,34 +171,71 @@ def main(): """Determine the resmoke jobs value a task should use in Evergreen.""" parser = argparse.ArgumentParser(description=main.__doc__) - parser.add_argument("--taskName", dest="task", required=True, help="Task being executed.") - parser.add_argument("--buildVariant", dest="variant", required=True, - help="Build variant task is being executed on.") - parser.add_argument("--distro", dest="distro", required=True, - help="Distro task is being executed on.") parser.add_argument( - "--jobFactor", dest="jobs_factor", type=float, default=1.0, - help=("Job factor to use as a mulitplier with the number of CPUs. Defaults" - " to %(default)s.")) + "--taskName", dest="task", required=True, help="Task being executed." + ) parser.add_argument( - "--jobsMax", dest="jobs_max", type=int, default=0, - help=("Maximum number of jobs to use. Specify 0 to indicate the number of" - " jobs is determined by --jobFactor and the number of CPUs. Defaults" - " to %(default)s.")) + "--buildVariant", + dest="variant", + required=True, + help="Build variant task is being executed on.", + ) parser.add_argument( - "--outFile", dest="outfile", help=("File to write configuration to. If" - " unspecified no file is generated.")) + "--distro", + dest="distro", + required=True, + help="Distro task is being executed on.", + ) + parser.add_argument( + "--jobFactor", + dest="jobs_factor", + type=float, + default=1.0, + help=( + "Job factor to use as a mulitplier with the number of CPUs. Defaults" + " to %(default)s." + ), + ) + parser.add_argument( + "--jobsMax", + dest="jobs_max", + type=int, + default=0, + help=( + "Maximum number of jobs to use. Specify 0 to indicate the number of" + " jobs is determined by --jobFactor and the number of CPUs. Defaults" + " to %(default)s." + ), + ) + parser.add_argument( + "--outFile", + dest="outfile", + help=( + "File to write configuration to. If" " unspecified no file is generated." + ), + ) options = parser.parse_args() logging.basicConfig(stream=sys.stdout, level=logging.INFO) structlog.configure(logger_factory=structlog.stdlib.LoggerFactory()) - LOGGER.info("Finding job count", task=options.task, variant=options.variant, - platform=PLATFORM_MACHINE, sys=SYS_PLATFORM, cpu_count=CPU_COUNT) + LOGGER.info( + "Finding job count", + task=options.task, + variant=options.variant, + platform=PLATFORM_MACHINE, + sys=SYS_PLATFORM, + cpu_count=CPU_COUNT, + ) - jobs = determine_jobs(options.task, options.variant, options.distro, options.jobs_max, - options.jobs_factor) + jobs = determine_jobs( + options.task, + options.variant, + options.distro, + options.jobs_max, + options.jobs_factor, + ) if jobs < CPU_COUNT: print("Reducing number of jobs to run from {} to {}".format(CPU_COUNT, jobs)) output_jobs(jobs, options.outfile) diff --git a/buildscripts/evergreen_task_tags.py b/buildscripts/evergreen_task_tags.py index 5964bc3bd40..d621c315fb8 100755 --- a/buildscripts/evergreen_task_tags.py +++ b/buildscripts/evergreen_task_tags.py @@ -20,17 +20,42 @@ def parse_command_line(): """Parse command line options.""" parser = argparse.ArgumentParser(description=main.__doc__) - parser.add_argument("--list-tags", action="store_true", default=False, - help="List all tags used by tasks in evergreen yml.") - parser.add_argument("--list-tasks", type=str, help="List all tasks for the given buildvariant.") - parser.add_argument("--list-variants-and-tasks", action="store_true", - help="List all tasks for every buildvariant.") - parser.add_argument("-t", "--tasks-for-tag", type=str, default=None, action="append", - help="List all tasks that use the given tag.") - parser.add_argument("-x", "--remove-tasks-for-tag-filter", type=str, default=None, - action="append", help="Remove tasks tagged with given tag.") - parser.add_argument("--evergreen-file", type=str, default=DEFAULT_EVERGREEN_FILE, - help="Location of evergreen file.") + parser.add_argument( + "--list-tags", + action="store_true", + default=False, + help="List all tags used by tasks in evergreen yml.", + ) + parser.add_argument( + "--list-tasks", type=str, help="List all tasks for the given buildvariant." + ) + parser.add_argument( + "--list-variants-and-tasks", + action="store_true", + help="List all tasks for every buildvariant.", + ) + parser.add_argument( + "-t", + "--tasks-for-tag", + type=str, + default=None, + action="append", + help="List all tasks that use the given tag.", + ) + parser.add_argument( + "-x", + "--remove-tasks-for-tag-filter", + type=str, + default=None, + action="append", + help="Remove tasks tagged with given tag.", + ) + parser.add_argument( + "--evergreen-file", + type=str, + default=DEFAULT_EVERGREEN_FILE, + help="Location of evergreen file.", + ) options = parser.parse_args() @@ -116,7 +141,9 @@ def get_tasks_with_tag(evg_config, tags, filters): :param filters: lst of tags to filter out. :return: list of tasks marked with the given tag. """ - return sorted([task.name for task in evg_config.tasks if is_task_tagged(task, tags, filters)]) + return sorted( + [task.name for task in evg_config.tasks if is_task_tagged(task, tags, filters)] + ) def list_tasks_with_tag(evg_config, tags, filters): @@ -148,7 +175,9 @@ def main(): list_all_tasks(evg_config, options.list_tasks) if options.tasks_for_tag: - list_tasks_with_tag(evg_config, options.tasks_for_tag, options.remove_tasks_for_tag_filter) + list_tasks_with_tag( + evg_config, options.tasks_for_tag, options.remove_tasks_for_tag_filter + ) if __name__ == "__main__": diff --git a/buildscripts/evergreen_task_timeout.py b/buildscripts/evergreen_task_timeout.py index f33079f6308..d8d6aacdb92 100755 --- a/buildscripts/evergreen_task_timeout.py +++ b/buildscripts/evergreen_task_timeout.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Determine the timeout value a task should use in evergreen.""" + from __future__ import annotations import argparse @@ -59,8 +60,12 @@ class TimeoutOverride(BaseModel): idle_timeout: Optional[int] = None @classmethod - def from_seconds(cls, task: str, exec_timeout_secs: Optional[float], - idle_timeout_secs: Optional[float]) -> TimeoutOverride: + def from_seconds( + cls, + task: str, + exec_timeout_secs: Optional[float], + idle_timeout_secs: Optional[float], + ) -> TimeoutOverride: """Create an instance of an override from seconds.""" exec_timeout = exec_timeout_secs / 60 if exec_timeout_secs else None idle_timeout = idle_timeout_secs / 60 if idle_timeout_secs else None @@ -94,7 +99,9 @@ class TimeoutOverrides(BaseModel): with open(file_path) as file_handler: return cls(**yaml.safe_load(file_handler)) - def _lookup_override(self, build_variant: str, task_name: str) -> Optional[TimeoutOverride]: + def _lookup_override( + self, build_variant: str, task_name: str + ) -> Optional[TimeoutOverride]: """ Check if the given task on the given build variant has an override defined. @@ -105,19 +112,27 @@ class TimeoutOverrides(BaseModel): :return: Timeout override if found. """ overrides = [ - override for override in self.overrides.get(build_variant, []) + override + for override in self.overrides.get(build_variant, []) if override.task == task_name ] if overrides: if len(overrides) > 1: - LOGGER.error("Found multiple overrides for the same task", - build_variant=build_variant, task=task_name, - overrides=[override.dict() for override in overrides]) - raise ValueError(f"Found multiple overrides for '{task_name}' on '{build_variant}'") + LOGGER.error( + "Found multiple overrides for the same task", + build_variant=build_variant, + task=task_name, + overrides=[override.dict() for override in overrides], + ) + raise ValueError( + f"Found multiple overrides for '{task_name}' on '{build_variant}'" + ) return overrides[0] return None - def lookup_exec_override(self, build_variant: str, task_name: str) -> Optional[timedelta]: + def lookup_exec_override( + self, build_variant: str, task_name: str + ) -> Optional[timedelta]: """ Look up the exec timeout override of the given build variant/task. @@ -130,7 +145,9 @@ class TimeoutOverrides(BaseModel): return override.get_exec_timeout() return None - def lookup_idle_override(self, build_variant: str, task_name: str) -> Optional[timedelta]: + def lookup_idle_override( + self, build_variant: str, task_name: str + ) -> Optional[timedelta]: """ Look up the idle timeout override of the given build variant/task. @@ -144,8 +161,11 @@ class TimeoutOverrides(BaseModel): return None -def output_timeout(exec_timeout: timedelta, idle_timeout: Optional[timedelta], - output_file: Optional[str]) -> None: +def output_timeout( + exec_timeout: timedelta, + idle_timeout: Optional[timedelta], + output_file: Optional[str], +) -> None: """ Output timeout configuration to the specified location. @@ -170,8 +190,12 @@ class TaskTimeoutOrchestrator: """An orchestrator for determining task timeouts.""" @inject.autoparams() - def __init__(self, timeout_service: TimeoutService, timeout_overrides: TimeoutOverrides, - evg_project_config: EvergreenProjectConfig) -> None: + def __init__( + self, + timeout_service: TimeoutService, + timeout_overrides: TimeoutOverrides, + evg_project_config: EvergreenProjectConfig, + ) -> None: """ Initialize the orchestrator. @@ -183,10 +207,15 @@ class TaskTimeoutOrchestrator: self.timeout_overrides = timeout_overrides self.evg_project_config = evg_project_config - def determine_exec_timeout(self, task_name: str, variant: str, - idle_timeout: Optional[timedelta] = None, - exec_timeout: Optional[timedelta] = None, evg_alias: str = "", - historic_timeout: Optional[timedelta] = None) -> timedelta: + def determine_exec_timeout( + self, + task_name: str, + variant: str, + idle_timeout: Optional[timedelta] = None, + exec_timeout: Optional[timedelta] = None, + evg_alias: str = "", + historic_timeout: Optional[timedelta] = None, + ) -> timedelta: """ Determine what exec timeout should be used. @@ -205,36 +234,56 @@ class TaskTimeoutOrchestrator: override = self.timeout_overrides.lookup_exec_override(variant, task_name) if exec_timeout and exec_timeout.total_seconds() != 0: - LOGGER.info("Using timeout from cmd line", - exec_timeout_secs=exec_timeout.total_seconds()) + LOGGER.info( + "Using timeout from cmd line", + exec_timeout_secs=exec_timeout.total_seconds(), + ) determined_timeout = exec_timeout elif override is not None: - LOGGER.info("Overriding configured timeout", exec_timeout_secs=override.total_seconds()) + LOGGER.info( + "Overriding configured timeout", + exec_timeout_secs=override.total_seconds(), + ) determined_timeout = override - elif self._is_required_build_variant( - variant) and determined_timeout > DEFAULT_REQUIRED_BUILD_TIMEOUT: - LOGGER.info("Overriding required-builder timeout", - exec_timeout_secs=DEFAULT_REQUIRED_BUILD_TIMEOUT.total_seconds()) + elif ( + self._is_required_build_variant(variant) + and determined_timeout > DEFAULT_REQUIRED_BUILD_TIMEOUT + ): + LOGGER.info( + "Overriding required-builder timeout", + exec_timeout_secs=DEFAULT_REQUIRED_BUILD_TIMEOUT.total_seconds(), + ) determined_timeout = DEFAULT_REQUIRED_BUILD_TIMEOUT elif evg_alias == COMMIT_QUEUE_ALIAS: - LOGGER.info("Overriding commit-queue timeout", - exec_timeout_secs=COMMIT_QUEUE_TIMEOUT.total_seconds()) + LOGGER.info( + "Overriding commit-queue timeout", + exec_timeout_secs=COMMIT_QUEUE_TIMEOUT.total_seconds(), + ) determined_timeout = COMMIT_QUEUE_TIMEOUT # The timeout needs to be at least as large as the idle timeout. - if idle_timeout and determined_timeout.total_seconds() < idle_timeout.total_seconds(): - LOGGER.info("Making exec timeout as large as idle timeout", - exec_timeout_secs=idle_timeout.total_seconds()) + if ( + idle_timeout + and determined_timeout.total_seconds() < idle_timeout.total_seconds() + ): + LOGGER.info( + "Making exec timeout as large as idle timeout", + exec_timeout_secs=idle_timeout.total_seconds(), + ) return idle_timeout return determined_timeout - def determine_idle_timeout(self, task_name: str, variant: str, - idle_timeout: Optional[timedelta] = None, - historic_timeout: Optional[timedelta] = None) -> Optional[timedelta]: + def determine_idle_timeout( + self, + task_name: str, + variant: str, + idle_timeout: Optional[timedelta] = None, + historic_timeout: Optional[timedelta] = None, + ) -> Optional[timedelta]: """ Determine what idle timeout should be used. @@ -249,18 +298,29 @@ class TaskTimeoutOrchestrator: override = self.timeout_overrides.lookup_idle_override(variant, task_name) if idle_timeout and idle_timeout.total_seconds() != 0: - LOGGER.info("Using timeout from cmd line", - idle_timeout_secs=idle_timeout.total_seconds()) + LOGGER.info( + "Using timeout from cmd line", + idle_timeout_secs=idle_timeout.total_seconds(), + ) determined_timeout = idle_timeout elif override is not None: - LOGGER.info("Overriding configured timeout", idle_timeout_secs=override.total_seconds()) + LOGGER.info( + "Overriding configured timeout", + idle_timeout_secs=override.total_seconds(), + ) determined_timeout = override return determined_timeout - def determine_historic_timeout(self, project: str, task: str, variant: str, suite_name: str, - exec_timeout_factor: Optional[float]) -> TimeoutOverride: + def determine_historic_timeout( + self, + project: str, + task: str, + variant: str, + suite_name: str, + exec_timeout_factor: Optional[float], + ) -> TimeoutOverride: """ Calculate the timeout based on historic test results. @@ -283,11 +343,15 @@ class TaskTimeoutOrchestrator: timeout_estimate = self.timeout_service.get_timeout_estimate(timeout_params) if timeout_estimate and timeout_estimate.is_specified(): exec_timeout = timeout_estimate.calculate_task_timeout( - repeat_factor=1, scaling_factor=exec_timeout_factor) + repeat_factor=1, scaling_factor=exec_timeout_factor + ) idle_timeout = timeout_estimate.calculate_test_timeout(repeat_factor=1) if exec_timeout is not None or idle_timeout is not None: - LOGGER.info("Getting historic based timeout", exec_timeout_secs=exec_timeout, - idle_timeout_secs=idle_timeout) + LOGGER.info( + "Getting historic based timeout", + exec_timeout_secs=exec_timeout, + idle_timeout_secs=idle_timeout, + ) return TimeoutOverride.from_seconds(task, exec_timeout, idle_timeout) return TimeoutOverride(task=task, exec_timeout=None, idle_timeout=None) @@ -312,10 +376,18 @@ class TaskTimeoutOrchestrator: bv = self.evg_project_config.get_variant(build_variant) return "!" in bv.display_name - def determine_timeouts(self, cli_idle_timeout: Optional[timedelta], - cli_exec_timeout: Optional[timedelta], outfile: Optional[str], - project: str, task: str, variant: str, evg_alias: str, suite_name: str, - exec_timeout_factor: Optional[float]) -> None: + def determine_timeouts( + self, + cli_idle_timeout: Optional[timedelta], + cli_exec_timeout: Optional[timedelta], + outfile: Optional[str], + project: str, + task: str, + variant: str, + evg_alias: str, + suite_name: str, + exec_timeout_factor: Optional[float], + ) -> None: """ Determine the timeouts to use for the given task and write timeouts to expansion file. @@ -329,13 +401,21 @@ class TaskTimeoutOrchestrator: :param suite_name: Name of evergreen suite being run. :param exec_timeout_factor: Scaling factor to use when determining timeout. """ - historic_timeout = self.determine_historic_timeout(project, task, variant, suite_name, - exec_timeout_factor) + historic_timeout = self.determine_historic_timeout( + project, task, variant, suite_name, exec_timeout_factor + ) - idle_timeout = self.determine_idle_timeout(task, variant, cli_idle_timeout, - historic_timeout.get_idle_timeout()) - exec_timeout = self.determine_exec_timeout(task, variant, idle_timeout, cli_exec_timeout, - evg_alias, historic_timeout.get_exec_timeout()) + idle_timeout = self.determine_idle_timeout( + task, variant, cli_idle_timeout, historic_timeout.get_idle_timeout() + ) + exec_timeout = self.determine_exec_timeout( + task, + variant, + idle_timeout, + cli_exec_timeout, + evg_alias, + historic_timeout.get_exec_timeout(), + ) output_timeout(exec_timeout, idle_timeout, outfile) @@ -344,42 +424,92 @@ def main(): """Determine the timeout value a task should use in evergreen.""" parser = argparse.ArgumentParser(description=main.__doc__) - parser.add_argument("--install-dir", dest="install_dir", required=True, - help="Path to bin directory of testable installation") - parser.add_argument("--task-name", dest="task", required=True, help="Task being executed.") - parser.add_argument("--suite-name", dest="suite_name", required=True, - help="Resmoke suite being run against.") - parser.add_argument("--build-variant", dest="variant", required=True, - help="Build variant task is being executed on.") - parser.add_argument("--project", dest="project", required=True, - help="Evergreen project task is being executed on.") - parser.add_argument("--evg-alias", dest="evg_alias", required=True, - help="Evergreen alias used to trigger build.") - parser.add_argument("--test-flags", dest="test_flags", - help="Test flags that are used for `resmoke.py run` command call.") - parser.add_argument("--timeout", dest="timeout", type=int, help="Timeout to use (in sec).") - parser.add_argument("--exec-timeout", dest="exec_timeout", type=int, - help="Exec timeout to use (in sec).") - parser.add_argument("--exec-timeout-factor", dest="exec_timeout_factor", type=float, - help="Exec timeout factor to use (in sec).") - parser.add_argument("--out-file", dest="outfile", help="File to write configuration to.") - parser.add_argument("--timeout-overrides", dest="timeout_overrides_file", - default=DEFAULT_TIMEOUT_OVERRIDES, - help="File containing timeout overrides to use.") - parser.add_argument("--evg-api-config", dest="evg_api_config", - default=DEFAULT_EVERGREEN_AUTH_CONFIG, help="Evergreen API config file.") - parser.add_argument("--evg-project-config", dest="evg_project_config", - default=DEFAULT_EVERGREEN_CONFIG, help="Evergreen project config file.") + parser.add_argument( + "--install-dir", + dest="install_dir", + required=True, + help="Path to bin directory of testable installation", + ) + parser.add_argument( + "--task-name", dest="task", required=True, help="Task being executed." + ) + parser.add_argument( + "--suite-name", + dest="suite_name", + required=True, + help="Resmoke suite being run against.", + ) + parser.add_argument( + "--build-variant", + dest="variant", + required=True, + help="Build variant task is being executed on.", + ) + parser.add_argument( + "--project", + dest="project", + required=True, + help="Evergreen project task is being executed on.", + ) + parser.add_argument( + "--evg-alias", + dest="evg_alias", + required=True, + help="Evergreen alias used to trigger build.", + ) + parser.add_argument( + "--test-flags", + dest="test_flags", + help="Test flags that are used for `resmoke.py run` command call.", + ) + parser.add_argument( + "--timeout", dest="timeout", type=int, help="Timeout to use (in sec)." + ) + parser.add_argument( + "--exec-timeout", + dest="exec_timeout", + type=int, + help="Exec timeout to use (in sec).", + ) + parser.add_argument( + "--exec-timeout-factor", + dest="exec_timeout_factor", + type=float, + help="Exec timeout factor to use (in sec).", + ) + parser.add_argument( + "--out-file", dest="outfile", help="File to write configuration to." + ) + parser.add_argument( + "--timeout-overrides", + dest="timeout_overrides_file", + default=DEFAULT_TIMEOUT_OVERRIDES, + help="File containing timeout overrides to use.", + ) + parser.add_argument( + "--evg-api-config", + dest="evg_api_config", + default=DEFAULT_EVERGREEN_AUTH_CONFIG, + help="Evergreen API config file.", + ) + parser.add_argument( + "--evg-project-config", + dest="evg_project_config", + default=DEFAULT_EVERGREEN_CONFIG, + help="Evergreen project config file.", + ) options = parser.parse_args() timeout_override = timedelta(seconds=options.timeout) if options.timeout else None - exec_timeout_override = timedelta( - seconds=options.exec_timeout) if options.exec_timeout else None + exec_timeout_override = ( + timedelta(seconds=options.exec_timeout) if options.exec_timeout else None + ) task_name = determine_task_base_name(options.task, options.variant) timeout_overrides = TimeoutOverrides.from_yaml_file( - os.path.expanduser(options.timeout_overrides_file)) + os.path.expanduser(options.timeout_overrides_file) + ) enable_logging(verbose=False) LOGGER.info("Determining timeouts", cli_args=options) @@ -387,22 +517,36 @@ def main(): def dependencies(binder: inject.Binder) -> None: binder.bind( EvergreenApi, - RetryingEvergreenApi.get_api(config_file=os.path.expanduser(options.evg_api_config))) + RetryingEvergreenApi.get_api( + config_file=os.path.expanduser(options.evg_api_config) + ), + ) binder.bind(TimeoutOverrides, timeout_overrides) - binder.bind(EvergreenProjectConfig, - parse_evergreen_file(os.path.expanduser(options.evg_project_config))) + binder.bind( + EvergreenProjectConfig, + parse_evergreen_file(os.path.expanduser(options.evg_project_config)), + ) binder.bind( ResmokeProxyService, ResmokeProxyService( - run_options=f"--installDir={shlex.quote(options.install_dir)} {options.test_flags}") + run_options=f"--installDir={shlex.quote(options.install_dir)} {options.test_flags}" + ), ) inject.configure(dependencies) task_timeout_orchestrator = inject.instance(TaskTimeoutOrchestrator) task_timeout_orchestrator.determine_timeouts( - timeout_override, exec_timeout_override, options.outfile, options.project, task_name, - options.variant, options.evg_alias, options.suite_name, options.exec_timeout_factor) + timeout_override, + exec_timeout_override, + options.outfile, + options.project, + task_name, + options.variant, + options.evg_alias, + options.suite_name, + options.exec_timeout_factor, + ) if __name__ == "__main__": diff --git a/buildscripts/fast_archive.py b/buildscripts/fast_archive.py index 841d65c8e99..c01b6ccdc3b 100644 --- a/buildscripts/fast_archive.py +++ b/buildscripts/fast_archive.py @@ -15,9 +15,19 @@ import requests from buildscripts.util.read_config import read_config_file -def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant: str, - version_id: str, revision: int, task_name: str, file_number: int, upload_name: str, - start_time: int) -> Optional[Dict[str, str]]: +def process_file( + file: str, + aws_secret: str, + aws_key: str, + project: str, + variant: str, + version_id: str, + revision: int, + task_name: str, + file_number: int, + upload_name: str, + start_time: int, +) -> Optional[Dict[str, str]]: print(f"{file} started compressing at {time.time() - start_time}") compressed_file = f"{file}.gz" with open(file, "rb") as f_in: @@ -26,12 +36,16 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant print(f"{file} finished compressing at {time.time() - start_time}") - s3_client = boto3.client('s3', aws_access_key_id=aws_key, aws_secret_access_key=aws_secret) + s3_client = boto3.client( + "s3", aws_access_key_id=aws_key, aws_secret_access_key=aws_secret + ) basename = os.path.basename(compressed_file) object_path = f"{project}/{variant}/{version_id}/{task_name}-{revision}-{file_number}/{basename}" extra_args = {"ContentType": "application/gzip", "ACL": "public-read"} try: - s3_client.upload_file(compressed_file, "mciuploads", object_path, ExtraArgs=extra_args) + s3_client.upload_file( + compressed_file, "mciuploads", object_path, ExtraArgs=extra_args + ) except Exception as ex: print(f"ERROR: failed to upload file to s3 {file}", file=sys.stderr) print(ex, file=sys.stderr) @@ -43,8 +57,10 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant # Sanity check to ensure the url exists r = requests.head(url) if r.status_code != 200: - print(f"ERROR: Could not verify that {compressed_file} was uploaded to {url}", - file=sys.stderr) + print( + f"ERROR: Could not verify that {compressed_file} was uploaded to {url}", + file=sys.stderr, + ) return None print(f"{compressed_file} uploaded at {time.time() - start_time} to {url}") @@ -59,8 +75,9 @@ def process_file(file: str, aws_secret: str, aws_key: str, project: str, variant return task_artifact -def main(output_file: str, patterns: List[str], display_name: str, expansions_file: str) -> int: - +def main( + output_file: str, patterns: List[str], display_name: str, expansions_file: str +) -> int: if not output_file.endswith(".json"): print("WARN: filename input should end with `.json`", file=sys.stderr) @@ -98,11 +115,21 @@ def main(output_file: str, patterns: List[str], display_name: str, expansions_fi for i, path in enumerate(files): file_number = i + 1 futures.append( - executor.submit(process_file, file=path, aws_secret=aws_secret_key, - aws_key=aws_access_key, project=project, variant=build_variant, - version_id=version_id, revision=revision, task_name=task_name, - file_number=file_number, upload_name=display_name, - start_time=start_time)) + executor.submit( + process_file, + file=path, + aws_secret=aws_secret_key, + aws_key=aws_access_key, + project=project, + variant=build_variant, + version_id=version_id, + revision=revision, + task_name=task_name, + file_number=file_number, + upload_name=display_name, + start_time=start_time, + ) + ) for future in concurrent.futures.as_completed(futures): result = future.result() @@ -118,21 +145,40 @@ def main(output_file: str, patterns: List[str], display_name: str, expansions_fi if __name__ == "__main__": parser = argparse.ArgumentParser( - prog='FastArchiver', - description='This improves archiving times of a large amount of big files in evergreen ' - 'by compressing and uploading them asynchronously. ' - 'This also uses pigz, which is a multithreaded implementation of gzip, ' - 'to improve gzipping times.') + prog="FastArchiver", + description="This improves archiving times of a large amount of big files in evergreen " + "by compressing and uploading them asynchronously. " + "This also uses pigz, which is a multithreaded implementation of gzip, " + "to improve gzipping times.", + ) - parser.add_argument("--output-file", "-f", help="Name of output attach.artifacts file.", - required=True) - parser.add_argument("--pattern", "-p", help="glob patterns of files to be archived.", - dest="patterns", action="append", default=[], required=True) - parser.add_argument("--display-name", "-n", help="The display name of the file in evergreen", - required=True) - parser.add_argument("--expansions-file", "-e", - help="Expansions file to read task info and aws credentials from.", - default="../expansions.yml") + parser.add_argument( + "--output-file", + "-f", + help="Name of output attach.artifacts file.", + required=True, + ) + parser.add_argument( + "--pattern", + "-p", + help="glob patterns of files to be archived.", + dest="patterns", + action="append", + default=[], + required=True, + ) + parser.add_argument( + "--display-name", + "-n", + help="The display name of the file in evergreen", + required=True, + ) + parser.add_argument( + "--expansions-file", + "-e", + help="Expansions file to read task info and aws credentials from.", + default="../expansions.yml", + ) args = parser.parse_args() exit(main(args.output_file, args.patterns, args.display_name, args.expansions_file)) diff --git a/buildscripts/feature_flag_tags_check.py b/buildscripts/feature_flag_tags_check.py index fe3bed82cb1..46d510a0b24 100755 --- a/buildscripts/feature_flag_tags_check.py +++ b/buildscripts/feature_flag_tags_check.py @@ -61,7 +61,9 @@ def main(diff_file, ent_path): base_feature_flags = fh.read().split() with open("patch_all_feature_flags.txt", "r") as fh: patch_feature_flags = fh.read().split() - enabled_feature_flags = [flag for flag in base_feature_flags if flag not in patch_feature_flags] + enabled_feature_flags = [ + flag for flag in base_feature_flags if flag not in patch_feature_flags + ] if not enabled_feature_flags: print( @@ -69,23 +71,30 @@ def main(diff_file, ent_path): ) sys.exit(0) - tests_with_feature_flag_tag = get_tests_with_feature_flag_tags(enabled_feature_flags, ent_path) + tests_with_feature_flag_tag = get_tests_with_feature_flag_tags( + enabled_feature_flags, ent_path + ) _run_git_cmd(["apply", diff_file]) _run_git_cmd(["apply", diff_file], cwd=ent_path) tests_missing_fcv_tag = get_tests_missing_fcv_tag(tests_with_feature_flag_tag) if tests_missing_fcv_tag: - print(f"Found tests missing `{REQUIRES_FCV_TAG_LATEST}` tag:\n" + - "\n".join(tests_missing_fcv_tag)) + print( + f"Found tests missing `{REQUIRES_FCV_TAG_LATEST}` tag:\n" + + "\n".join(tests_missing_fcv_tag) + ) sys.exit(1) sys.exit(0) if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--diff-file-name", type=str, - help="Name of the file containing the git diff") - parser.add_argument("--enterprise-path", type=str, help="Path to the enterprise module") + parser.add_argument( + "--diff-file-name", type=str, help="Name of the file containing the git diff" + ) + parser.add_argument( + "--enterprise-path", type=str, help="Path to the enterprise module" + ) args = parser.parse_args() main(args.diff_file_name, args.enterprise_path) diff --git a/buildscripts/gather_failed_unittests.py b/buildscripts/gather_failed_unittests.py index 963289b11de..9a45fc7a996 100644 --- a/buildscripts/gather_failed_unittests.py +++ b/buildscripts/gather_failed_unittests.py @@ -29,7 +29,9 @@ def _relink_binaries_with_symbols(failed_tests: List[str]): bazel_build_flags.replace("--config=strip-debug-during-link", "") relink_command = [ - arg for arg in ["bazel", "build", *bazel_build_flags.split(" "), *failed_tests] if arg + arg + for arg in ["bazel", "build", *bazel_build_flags.split(" "), *failed_tests] + if arg ] # Enable this when/if we want to strip debug symbols during linking unit tests @@ -39,13 +41,17 @@ def _relink_binaries_with_symbols(failed_tests: List[str]): # check=True, # ) - repro_test_command = " ".join(["test" if arg == "build" else arg for arg in relink_command]) + repro_test_command = " ".join( + ["test" if arg == "build" else arg for arg in relink_command] + ) with open(".failed_unittest_repro.txt", "w", encoding="utf-8") as f: f.write(repro_test_command) print(f"Repro command written to .failed_unittest_repro.txt: {repro_test_command}") -def _copy_bins_to_upload(failed_tests: List[str], upload_bin_dir: str, upload_lib_dir: str) -> bool: +def _copy_bins_to_upload( + failed_tests: List[str], upload_bin_dir: str, upload_lib_dir: str +) -> bool: success = True bazel_bin_dir = Path("./bazel-bin") for failed_test in failed_tests: @@ -80,7 +86,9 @@ def _copy_bins_to_upload(failed_tests: List[str], upload_bin_dir: str, upload_li dsym_dir = full_binary_path.with_suffix(".dSYM") if dsym_dir.is_dir(): print(f"Copying dsym {dsym_dir} to {upload_bin_dir}") - shutil.copytree(dsym_dir, upload_bin_dir / dsym_dir.name, dirs_exist_ok=True) + shutil.copytree( + dsym_dir, upload_bin_dir / dsym_dir.name, dirs_exist_ok=True + ) # Copy debug symbols for dynamic builds lib_dir = Path("bazel-bin/install/lib") @@ -105,7 +113,9 @@ def main(testlog_dir: str = "bazel-testlogs"): print("No failed tests found.") exit(0) - print(f"Found {len(failed_tests)} failed tests. Gathering binaries and debug symbols.") + print( + f"Found {len(failed_tests)} failed tests. Gathering binaries and debug symbols." + ) _relink_binaries_with_symbols(failed_tests) print("Copying binaries and debug symbols to upload directories.") diff --git a/buildscripts/gdb/mongo.py b/buildscripts/gdb/mongo.py index b9c5220691b..b4ddd06a29f 100644 --- a/buildscripts/gdb/mongo.py +++ b/buildscripts/gdb/mongo.py @@ -22,25 +22,24 @@ if not gdb: def detect_toolchain(progspace): - - readelf_bin = '/opt/mongodbtoolchain/v4/bin/readelf' + readelf_bin = "/opt/mongodbtoolchain/v4/bin/readelf" if not os.path.exists(readelf_bin): - readelf_bin = 'readelf' + readelf_bin = "readelf" - gcc_version_regex = re.compile(r'.*\]\s*GCC: \(GNU\) (\d+\.\d+\.\d+)\s*$') - clang_version_regex = re.compile(r'.*\]\s*MongoDB clang version (\d+\.\d+\.\d+).*') + gcc_version_regex = re.compile(r".*\]\s*GCC: \(GNU\) (\d+\.\d+\.\d+)\s*$") + clang_version_regex = re.compile(r".*\]\s*MongoDB clang version (\d+\.\d+\.\d+).*") - readelf_cmd = [readelf_bin, '-p', '.comment', progspace.filename] + readelf_cmd = [readelf_bin, "-p", ".comment", progspace.filename] # take an educated guess as to where we could find the c++ printers, better than hardcoding result = subprocess.run(readelf_cmd, capture_output=True, text=True) gcc_version = None - for line in result.stdout.split('\n'): + for line in result.stdout.split("\n"): if match := re.search(gcc_version_regex, line): gcc_version = match.group(1) break clang_version = None - for line in result.stdout.split('\n'): + for line in result.stdout.split("\n"): if match := re.search(clang_version_regex, line): clang_version = match.group(1) break @@ -55,13 +54,13 @@ def detect_toolchain(progspace): # default is v4 if we can't find a version toolchain_ver = None if gcc_version: - if int(gcc_version.split('.')[0]) == 8: - toolchain_ver = 'v3' - elif int(gcc_version.split('.')[0]) == 11: - toolchain_ver = 'v4' + if int(gcc_version.split(".")[0]) == 8: + toolchain_ver = "v3" + elif int(gcc_version.split(".")[0]) == 11: + toolchain_ver = "v4" if not toolchain_ver: - toolchain_ver = 'v4' + toolchain_ver = "v4" print(f""" WARNING: could not detect a MongoDB toolchain to load matching stdcxx printers: ----------------- @@ -75,7 +74,8 @@ STDERR: Assuming {toolchain_ver} as a default, this could cause issues with the printers.""") pp = glob.glob( - f"/opt/mongodbtoolchain/{toolchain_ver}/share/gcc-*/python/libstdcxx/v6/printers.py") + f"/opt/mongodbtoolchain/{toolchain_ver}/share/gcc-*/python/libstdcxx/v6/printers.py" + ) printers = pp[0] return os.path.dirname(os.path.dirname(os.path.dirname(printers))) @@ -91,6 +91,7 @@ def load_libstdcxx_printers(progspace): global stdlib_printers # pylint: disable=invalid-name,global-variable-not-assigned from libstdcxx.v6 import printers as stdlib_printers from libstdcxx.v6 import register_libstdcxx_printers + register_libstdcxx_printers(progspace) print( f"Loaded libstdc++ pretty printers from '{stdcxx_printer_toolchain_paths[progspace]}'" @@ -112,7 +113,8 @@ def on_new_object_file(objfile) -> None: warnings.warn( "Unable to locate the libstdc++ GDB pretty printers without an executable" " file. Try running the `file` command with the path to the executable file" - " and reloading the core dump with the `core-file` command") + " and reloading the core dump with the `core-file` command" + ) return load_libstdcxx_printers(objfile.new_objfile.progspace) @@ -132,7 +134,8 @@ except ImportError: if sys.version_info[0] < 3: raise gdb.GdbError( - "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2") + "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2" + ) def get_process_name(): @@ -190,12 +193,14 @@ def lookup_type(gdb_type_str: str) -> gdb.Type: except Exception as exc: exceptions.append(exc) - raise gdb.error("Failed to get type, tried:\n%s" % '\n'.join([str(exc) for exc in exceptions])) + raise gdb.error( + "Failed to get type, tried:\n%s" % "\n".join([str(exc) for exc in exceptions]) + ) def get_current_thread_name(): """Return the name of the current GDB thread.""" - fallback_name = '"%s"' % (gdb.selected_thread().name or '') + fallback_name = '"%s"' % (gdb.selected_thread().name or "") try: # This goes through the pretty printer for StringData which adds "" around the name. name = str(gdb.parse_and_eval("mongo::getThreadName()")) @@ -208,7 +213,9 @@ def get_current_thread_name(): def get_global_service_context(): """Return the global ServiceContext object.""" - return gdb.parse_and_eval("'mongo::(anonymous namespace)::globalServiceContext'").dereference() + return gdb.parse_and_eval( + "'mongo::(anonymous namespace)::globalServiceContext'" + ).dereference() def get_session_catalog(): @@ -217,7 +224,9 @@ def get_session_catalog(): Returns None if no SessionCatalog could be found. """ # The SessionCatalog is a decoration on the ServiceContext. - session_catalog_dec = get_decoration(get_global_service_context(), "mongo::SessionCatalog") + session_catalog_dec = get_decoration( + get_global_service_context(), "mongo::SessionCatalog" + ) if session_catalog_dec is None: return None return session_catalog_dec[1] @@ -251,7 +260,8 @@ def get_wt_session(recovery_unit, recovery_unit_impl_type): if not wt_session_handle.dereference().address: return None wt_session = wt_session_handle.dereference().cast( - lookup_type("mongo::WiredTigerSession"))["_session"] + lookup_type("mongo::WiredTigerSession") + )["_session"] return wt_session @@ -269,28 +279,32 @@ def get_decorations(obj): try: yield (deco_type_name, obj) except Exception as err: - print("Failed to look up decoration type: " + deco_type_name + ": " + str(err)) + print( + "Failed to look up decoration type: " + deco_type_name + ": " + str(err) + ) def get_object_decoration(decorable, start, index): decoration_data = get_unique_ptr_bytes(decorable["_decorations"]["_data"]) entry = start[index] deco_type_info = str(entry["typeInfo"]) - deco_type_name = re.sub(r'.* ', r'\1', deco_type_info) + deco_type_name = re.sub(r".* ", r"\1", deco_type_info) offset = int(entry["offset"]) obj = decoration_data[offset] - obj_addr = re.sub(r'^(.*) .*', r'\1', str(obj.address)) + obj_addr = re.sub(r"^(.*) .*", r"\1", str(obj.address)) obj = _cast_decoration_value(deco_type_name, int(obj.address)) return (deco_type_name, obj, obj_addr) def get_decorable_info(decorable): decorable_t = decorable.type.template_argument(0) - reg_sym, _ = gdb.lookup_symbol("mongo::decorable_detail::gdbRegistry<{}>".format(decorable_t)) + reg_sym, _ = gdb.lookup_symbol( + "mongo::decorable_detail::gdbRegistry<{}>".format(decorable_t) + ) decl_vector = reg_sym.value()["_entries"] start = decl_vector["_M_impl"]["_M_start"] finish = decl_vector["_M_impl"]["_M_finish"] - decinfo_t = lookup_type('mongo::decorable_detail::Registry::Entry') + decinfo_t = lookup_type("mongo::decorable_detail::Registry::Entry") count = int((int(finish) - int(start)) / decinfo_t.sizeof) return start, count @@ -327,10 +341,10 @@ def get_boost_optional(optional): TODO: Import the boost pretty printers instead of using this custom function. """ - if not optional['m_initialized']: + if not optional["m_initialized"]: return None value_ref_type = optional.type.template_argument(0).pointer() - storage = optional['m_storage']['dummy_']['data'] + storage = optional["m_storage"]["dummy_"]["data"] return storage.cast(value_ref_type).dereference() @@ -409,7 +423,11 @@ class GetMongoDecoration(gdb.Command): (type_name, obj) = dec print(type_name, obj) else: - print("No decoration found whose type name contains '" + type_name_substr + "'.") + print( + "No decoration found whose type name contains '" + + type_name_substr + + "'." + ) # Register command @@ -448,21 +466,25 @@ class DumpMongoDSessionCatalog(gdb.Command): ) return session_kv_pairs = get_session_kv_pairs() - print("Dumping %d Session objects from the SessionCatalog" % len(session_kv_pairs)) + print( + "Dumping %d Session objects from the SessionCatalog" % len(session_kv_pairs) + ) # Optionally search for a specified session, based on its id. if lsid_to_find: - print("Only printing information for session " + lsid_to_find + ", if found.") + print( + "Only printing information for session " + lsid_to_find + ", if found." + ) lsids_to_print = [lsid_to_find] else: - lsids_to_print = [str(s['first']['_id']) for s in session_kv_pairs] + lsids_to_print = [str(s["first"]["_id"]) for s in session_kv_pairs] for session_kv in session_kv_pairs: # The Session objects are stored inside the SessionRuntimeInfo object. - session_runtime_info = get_unique_ptr(session_kv['second']).dereference() - parent_session = session_runtime_info['parentSession'] - child_sessions = absl_get_nodes(session_runtime_info['childSessions']) - lsid = str(parent_session['_sessionId']['_id']) + session_runtime_info = get_unique_ptr(session_kv["second"]).dereference() + parent_session = session_runtime_info["parentSession"] + child_sessions = absl_get_nodes(session_runtime_info["childSessions"]) + lsid = str(parent_session["_sessionId"]["_id"]) # If we are only interested in a specific session, then we print out the entire Session # objects, to aid more detailed debugging. @@ -470,7 +492,7 @@ class DumpMongoDSessionCatalog(gdb.Command): self.dump_session_runtime_info(session_runtime_info) print(parent_session) for child_session_kv in child_sessions: - child_session = child_session_kv['second'] + child_session = child_session_kv["second"] print(child_session) # Terminate if this is the only session we care about. break @@ -484,24 +506,27 @@ class DumpMongoDSessionCatalog(gdb.Command): self.dump_session_runtime_info(session_runtime_info) self.dump_session(parent_session) for child_session_kv in child_sessions: - child_session = child_session_kv['second'] + child_session = child_session_kv["second"] self.dump_session(child_session) @staticmethod def dump_session_runtime_info(session_runtime_info): """Dump the session runtime info.""" - parent_session = session_runtime_info['parentSession'] + parent_session = session_runtime_info["parentSession"] # TODO: Add a custom pretty printer for LogicalSessionId. - lsid = str(parent_session['_sessionId']['_id'])[1:-1] + lsid = str(parent_session["_sessionId"]["_id"])[1:-1] print("SessionId =", lsid) - fields_to_print = ['checkoutOpCtx', 'killsRequested'] + fields_to_print = ["checkoutOpCtx", "killsRequested"] for field in fields_to_print: # Skip fields that aren't found on the object. if field in get_field_names(session_runtime_info): print(field, "=", session_runtime_info[field]) else: - print("Could not find field '%s' on the SessionRuntimeInfo object." % field) + print( + "Could not find field '%s' on the SessionRuntimeInfo object." + % field + ) print("") @staticmethod @@ -509,13 +534,16 @@ class DumpMongoDSessionCatalog(gdb.Command): """Dump the session.""" print("Session (" + str(session.address) + "):") - fields_to_print = ['_sessionId', '_numWaitingToCheckOut'] + fields_to_print = ["_sessionId", "_numWaitingToCheckOut"] for field in fields_to_print: # Skip fields that aren't found on the object. if field in get_field_names(session): print(field, "=", session[field]) else: - print("Could not find field '%s' on the SessionRuntimeInfo object." % field) + print( + "Could not find field '%s' on the SessionRuntimeInfo object." + % field + ) # Print the information from a TransactionParticipant if a session has one. txn_part_dec = get_decoration(session, "TransactionParticipant") @@ -528,28 +556,33 @@ class DumpMongoDSessionCatalog(gdb.Command): # from that object. If, in the future, we want to print fields from the # 'PrivateState' object, we can inspect the TransactionParticipant's '_p' field. txn_part = txn_part_dec[1] - txn_part_observable_state = txn_part['_o'] - fields_to_print = ['txnState', 'activeTxnNumberAndRetryCounter'] + txn_part_observable_state = txn_part["_o"] + fields_to_print = ["txnState", "activeTxnNumberAndRetryCounter"] print("TransactionParticipant (" + str(txn_part.address) + "):") for field in fields_to_print: # Skip fields that aren't found on the object. if field in get_field_names(txn_part_observable_state): print(field, "=", txn_part_observable_state[field]) else: - print("Could not find field '%s' on the TransactionParticipant" % field) + print( + "Could not find field '%s' on the TransactionParticipant" + % field + ) # The 'txnResourceStash' field is a boost::optional so we unpack it manually if it # is non-empty. We are only interested in its Locker object for now. TODO: Load the # boost pretty printers so the object will be printed clearly by default, without # the need for special unpacking. - val = get_boost_optional(txn_part_observable_state['txnResourceStash']) + val = get_boost_optional(txn_part_observable_state["txnResourceStash"]) if val: locker_addr = get_unique_ptr(val["_locker"]) - locker_obj = locker_addr.dereference().cast(lookup_type("mongo::Locker")) - print('txnResourceStash._locker', "@", locker_addr) + locker_obj = locker_addr.dereference().cast( + lookup_type("mongo::Locker") + ) + print("txnResourceStash._locker", "@", locker_addr) print("txnResourceStash._locker._id", "=", locker_obj["_id"]) else: - print('txnResourceStash', "=", None) + print("txnResourceStash", "=", None) print("") @@ -581,7 +614,8 @@ class DumpMongoDBMutexes(gdb.Command): # Use the STL pretty-printer to iterate over the list printer = stdlib_printers.StdForwardListPrinter( # pylint: disable=undefined-variable - str(diagnostic_info_list.type), diagnostic_info_list) + str(diagnostic_info_list.type), diagnostic_info_list + ) # Prepare structured output doc client_name = str(client["_desc"]) @@ -595,7 +629,9 @@ class DumpMongoDBMutexes(gdb.Command): output_doc["mutex"] = str(diagnostic_info["_captureName"])[1:-1] millis = int(diagnostic_info["_timestamp"]["millis"]) - dt = datetime.datetime.fromtimestamp(millis / 1000, tz=datetime.timezone.utc) + dt = datetime.datetime.fromtimestamp( + millis / 1000, tz=datetime.timezone.utc + ) output_doc["since"] = dt.isoformat() print(json.dumps(output_doc)) @@ -616,7 +652,7 @@ class MongoDBDumpLocks(gdb.Command): print("Running Hang Analyzer Supplement - MongoDBDumpLocks") main_binary_name = get_process_name() - if main_binary_name == 'mongod': + if main_binary_name == "mongod": self.dump_mongod_locks() else: print("Not invoking mongod lock dump for: %s" % (main_binary_name)) @@ -628,7 +664,9 @@ class MongoDBDumpLocks(gdb.Command): try: # Call into mongod, and dump the state of lock manager # Note that output will go to mongod's standard output, not the debugger output window - gdb.execute("call mongo::dumpLockManager()", from_tty=False, to_string=False) + gdb.execute( + "call mongo::dumpLockManager()", from_tty=False, to_string=False + ) except gdb.error as gdberr: print("Ignoring error '%s' in dump_mongod_locks" % str(gdberr)) @@ -642,7 +680,9 @@ class MongoDBDumpRecoveryUnits(gdb.Command): def __init__(self): """Initialize MongoDBDumpRecoveryUnits.""" - RegisterMongoCommand.register(self, "mongodb-dump-recovery-units", gdb.COMMAND_DATA) + RegisterMongoCommand.register( + self, "mongodb-dump-recovery-units", gdb.COMMAND_DATA + ) def invoke(self, arg, _from_tty): """Invoke MongoDBDumpRecoveryUnits.""" @@ -684,12 +724,17 @@ class MongoDBDumpRecoveryUnits(gdb.Command): recovery_unit = None if operation_context_handle: operation_context = operation_context_handle.dereference() - recovery_unit_handle = get_unique_ptr(operation_context["_recoveryUnit"]) + recovery_unit_handle = get_unique_ptr( + operation_context["_recoveryUnit"] + ) # By default, cast the recovery unit as "mongo::WiredTigerRecoveryUnit" recovery_unit = recovery_unit_handle.dereference().cast( - lookup_type(recovery_unit_impl_type)) + lookup_type(recovery_unit_impl_type) + ) - output_doc["recoveryUnit"] = hex(recovery_unit_handle) if recovery_unit else "0x0" + output_doc["recoveryUnit"] = ( + hex(recovery_unit_handle) if recovery_unit else "0x0" + ) wt_session = get_wt_session(recovery_unit, recovery_unit_impl_type) if wt_session: output_doc["WT_SESSION"] = hex(wt_session) @@ -700,14 +745,18 @@ class MongoDBDumpRecoveryUnits(gdb.Command): # Dump stashed recovery unit info for each session in a mongod process for session_kv in get_session_kv_pairs(): # The Session objects are stored inside the SessionRuntimeInfo object. - session_runtime_info = get_unique_ptr(session_kv['second']).dereference() - parent_session = session_runtime_info['parentSession'] - child_sessions = absl_get_nodes(session_runtime_info['childSessions']) + session_runtime_info = get_unique_ptr(session_kv["second"]).dereference() + parent_session = session_runtime_info["parentSession"] + child_sessions = absl_get_nodes(session_runtime_info["childSessions"]) - MongoDBDumpRecoveryUnits.dump_session(parent_session, recovery_unit_impl_type) + MongoDBDumpRecoveryUnits.dump_session( + parent_session, recovery_unit_impl_type + ) for child_session_kv in child_sessions: - child_session = child_session_kv['second'] - MongoDBDumpRecoveryUnits.dump_session(child_session, recovery_unit_impl_type) + child_session = child_session_kv["second"] + MongoDBDumpRecoveryUnits.dump_session( + child_session, recovery_unit_impl_type + ) if enabled_at_start: gdb.execute("set print static-members on") @@ -726,15 +775,21 @@ class MongoDBDumpRecoveryUnits(gdb.Command): if txn_participant_dec: txn_participant_observable_state = txn_participant_dec[1]["_o"] txn_resource_stash = get_boost_optional( - txn_participant_observable_state["txnResourceStash"]) + txn_participant_observable_state["txnResourceStash"] + ) if txn_resource_stash: output_doc["txnResourceStash"] = str(txn_resource_stash.address) - recovery_unit_handle = get_unique_ptr(txn_resource_stash["_recoveryUnit"]) + recovery_unit_handle = get_unique_ptr( + txn_resource_stash["_recoveryUnit"] + ) # By default, cast the recovery unit as "mongo::WiredTigerRecoveryUnit" recovery_unit = recovery_unit_handle.dereference().cast( - lookup_type(recovery_unit_impl_type)) + lookup_type(recovery_unit_impl_type) + ) - output_doc["recoveryUnit"] = hex(recovery_unit_handle) if recovery_unit else "0x0" + output_doc["recoveryUnit"] = ( + hex(recovery_unit_handle) if recovery_unit else "0x0" + ) wt_session = get_wt_session(recovery_unit, recovery_unit_impl_type) if wt_session: output_doc["WT_SESSION"] = hex(wt_session) @@ -754,17 +809,22 @@ class MongoDBDumpStorageEngineInfo(gdb.Command): def __init__(self): """Initialize MongoDBDumpStorageEngineInfo.""" - RegisterMongoCommand.register(self, "mongodb-dump-storage-engine-info", gdb.COMMAND_DATA) + RegisterMongoCommand.register( + self, "mongodb-dump-storage-engine-info", gdb.COMMAND_DATA + ) def invoke(self, arg, _from_tty): # pylint: disable=unused-argument """Invoke MongoDBDumpStorageEngineInfo.""" print("Running Hang Analyzer Supplement - MongoDBDumpStorageEngineInfo") main_binary_name = get_process_name() - if main_binary_name == 'mongod': + if main_binary_name == "mongod": self.dump_mongod_storage_engine_info() else: - print("Not invoking mongod storage engine info dump for: %s" % (main_binary_name)) + print( + "Not invoking mongod storage engine info dump for: %s" + % (main_binary_name) + ) @staticmethod def dump_mongod_storage_engine_info(): @@ -775,9 +835,13 @@ class MongoDBDumpStorageEngineInfo(gdb.Command): # Note that output will go to mongod's standard output, not the debugger output window gdb.execute( "call mongo::getGlobalServiceContext()->_storageEngine._ptr._value._M_b._M_p->dump()", - from_tty=False, to_string=False) + from_tty=False, + to_string=False, + ) except gdb.error as gdberr: - print("Ignoring error '%s' in dump_mongod_storage_engine_info" % str(gdberr)) + print( + "Ignoring error '%s' in dump_mongod_storage_engine_info" % str(gdberr) + ) # Register command @@ -794,7 +858,9 @@ class BtIfActive(gdb.Command): def invoke(self, arg, _from_tty): # pylint: disable=unused-argument """Invoke GDB to print stack trace.""" try: - idle_location = gdb.parse_and_eval("mongo::for_debuggers::idleThreadLocation") + idle_location = gdb.parse_and_eval( + "mongo::for_debuggers::idleThreadLocation" + ) except gdb.error: idle_location = None # If unsure, print a stack trace. @@ -821,7 +887,7 @@ class MongoDBUniqueStack(gdb.Command): """Invoke GDB to dump stacks.""" stacks = {} if not arg: - arg = 'bt' # default to 'bt' + arg = "bt" # default to 'bt' current_thread = gdb.selected_thread() try: @@ -839,24 +905,28 @@ class MongoDBUniqueStack(gdb.Command): def _process_thread_stack(arg, stacks, thread): """Process the thread stack.""" thread_info = {} # thread dict to hold per thread data - thread_info['pthread'] = get_thread_id() - thread_info['gdb_thread_num'] = thread.num - thread_info['lwpid'] = thread.ptid[1] - thread_info['name'] = get_current_thread_name() + thread_info["pthread"] = get_thread_id() + thread_info["gdb_thread_num"] = thread.num + thread_info["lwpid"] = thread.ptid[1] + thread_info["name"] = get_current_thread_name() if sys.platform.startswith("linux"): - header_format = "Thread {gdb_thread_num}: {name} (Thread 0x{pthread:x} (LWP {lwpid}))" + header_format = ( + "Thread {gdb_thread_num}: {name} (Thread 0x{pthread:x} (LWP {lwpid}))" + ) elif sys.platform.startswith("sunos"): (_, _, thread_tid) = thread.ptid - if thread_tid != 0 and thread_info['lwpid'] != 0: - header_format = "Thread {gdb_thread_num}: {name} (Thread {pthread} (LWP {lwpid}))" - elif thread_info['lwpid'] != 0: + if thread_tid != 0 and thread_info["lwpid"] != 0: + header_format = ( + "Thread {gdb_thread_num}: {name} (Thread {pthread} (LWP {lwpid}))" + ) + elif thread_info["lwpid"] != 0: header_format = "Thread {gdb_thread_num}: {name} (LWP {lwpid})" else: header_format = "Thread {gdb_thread_num}: {name} (Thread {pthread})" else: raise ValueError("Unsupported platform: {}".format(sys.platform)) - thread_info['header'] = header_format.format(**thread_info) + thread_info["header"] = header_format.format(**thread_info) addrs = [] # list of return addresses from frames frame = gdb.newest_frame() @@ -865,17 +935,17 @@ class MongoDBUniqueStack(gdb.Command): try: frame = frame.older() except gdb.error as err: - print("{} {}".format(thread_info['header'], err)) + print("{} {}".format(thread_info["header"], err)) break addrs_tuple = tuple(addrs) # tuples are hashable, lists aren't. - unique = stacks.setdefault(addrs_tuple, {'threads': []}) - unique['threads'].append(thread_info) - if 'output' not in unique: + unique = stacks.setdefault(addrs_tuple, {"threads": []}) + unique["threads"].append(thread_info) + if "output" not in unique: try: - unique['output'] = gdb.execute(arg, to_string=True).rstrip() + unique["output"] = gdb.execute(arg, to_string=True).rstrip() except gdb.error as err: - print("{} {}".format(thread_info['header'], err)) + print("{} {}".format(thread_info["header"], err)) @staticmethod def _dump_unique_stacks(stacks): @@ -883,13 +953,13 @@ class MongoDBUniqueStack(gdb.Command): def first_tid(stack): """Return the first tid.""" - return stack['threads'][0]['gdb_thread_num'] + return stack["threads"][0]["gdb_thread_num"] for stack in sorted(list(stacks.values()), key=first_tid, reverse=True): - for i, thread in enumerate(stack['threads']): - prefix = '' if i == 0 else 'Duplicate ' - print(prefix + thread['header']) - print(stack['output']) + for i, thread in enumerate(stack["threads"]): + prefix = "" if i == 0 else "Duplicate " + print(prefix + thread["header"]) + print(stack["output"]) print() # leave extra blank line after each thread stack @@ -910,14 +980,16 @@ class MongoDBJavaScriptStack(gdb.Command): def __init__(self): """Initialize MongoDBJavaScriptStack.""" - RegisterMongoCommand.register(self, "mongodb-javascript-stack", gdb.COMMAND_STATUS) + RegisterMongoCommand.register( + self, "mongodb-javascript-stack", gdb.COMMAND_STATUS + ) def invoke(self, arg, _from_tty): # pylint: disable=unused-argument """Invoke GDB to dump JS stacks.""" print("Running Print JavaScript Stack Supplement") main_binary_name = get_process_name() - if main_binary_name.endswith('mongod') or main_binary_name.endswith('mongo'): + if main_binary_name.endswith("mongod") or main_binary_name.endswith("mongo"): self.javascript_stack() else: print("No JavaScript stack print done for: %s" % (main_binary_name)) @@ -930,13 +1002,14 @@ class MongoDBJavaScriptStack(gdb.Command): # `'_M_b' in atomic_scope`, so exceptions for flow control it is. :| try: # reach into std::atomic and grab the pointer. This is for libc++ - return atomic_scope['_M_b']['_M_p'] + return atomic_scope["_M_b"]["_M_p"] except gdb.error: # Worst case scenario: try and use .load(), but it's probably # inlined. parse_and_eval required because you can't call methods # in gdb on the Python API return gdb.parse_and_eval( - f"((std::atomic *)({atomic_scope.address}))->load()") + f"((std::atomic *)({atomic_scope.address}))->load()" + ) return None @@ -973,16 +1046,18 @@ class MongoDBJavaScriptStack(gdb.Command): continue scope = ptr.dereference() - if scope['_inOp'] == 0: + if scope["_inOp"] == 0: continue - gdb.execute('thread', from_tty=False, to_string=False) + gdb.execute("thread", from_tty=False, to_string=False) # gdb continues to not support calling methods through Python, # so work around it by casting the raw ptr back to its type, # and calling the method through execute darkness gdb.execute( f'printf "%s\\n", ((mongo::mozjs::MozJSImplScope*)({ptr}))->buildStackString().c_str()', - from_tty=False, to_string=False) + from_tty=False, + to_string=False, + ) except gdb.error as err: print("Ignoring GDB error '%s' in javascript_stack" % str(err)) @@ -1002,7 +1077,7 @@ class MongoDBPPrintBsonAtPointer(gdb.Command): def invoke(self, args, _from_tty): """Invoke.""" - args = args.split(' ') + args = args.split(" ") if len(args) == 0 or (len(args) == 1 and len(args[0]) == 0): print("Usage: mongodb-pprint-bson ") return @@ -1017,6 +1092,7 @@ class MongoDBPPrintBsonAtPointer(gdb.Command): bsonobj = next(bson.decode_iter(memory)) from pprint import pprint + pprint(bsonobj) diff --git a/buildscripts/gdb/mongo_lock.py b/buildscripts/gdb/mongo_lock.py index 8e165461f3d..7c4c35cd177 100644 --- a/buildscripts/gdb/mongo_lock.py +++ b/buildscripts/gdb/mongo_lock.py @@ -19,7 +19,8 @@ if not gdb: if sys.version_info[0] < 3: raise gdb.GdbError( - "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2") + "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2" + ) class NonExecutingThread(object): @@ -69,7 +70,9 @@ class Thread(object): return not self == other def __str__(self): - return "{} (Thread 0x{:012x} (LWP {}))".format(self.name, self.thread_id, self.lwpid) + return "{} (Thread 0x{:012x} (LWP {}))".format( + self.name, self.thread_id, self.lwpid + ) def key(self): """Return thread key.""" @@ -125,7 +128,7 @@ class Graph(object): def add_node(self, node): """Add node to graph.""" if not self.find_node(node): - self.nodes[node.key()] = {'node': node, 'next_nodes': []} + self.nodes[node.key()] = {"node": node, "next_nodes": []} def find_node(self, node): """Find node in graph.""" @@ -137,8 +140,8 @@ class Graph(object): """Find from node.""" for node_key in self.nodes: node = self.nodes[node_key] - for next_node in node['next_nodes']: - if next_node == from_node['node'].key(): + for next_node in node["next_nodes"]: + if next_node == from_node["node"].key(): return node return None @@ -148,7 +151,7 @@ class Graph(object): temp_nodes = {} for node_key in self.nodes: node = self.nodes[node_key] - if node['next_nodes'] or self.find_from_node(node) is not None: + if node["next_nodes"] or self.find_from_node(node) is not None: temp_nodes[node_key] = self.nodes[node_key] self.nodes = temp_nodes @@ -164,50 +167,62 @@ class Graph(object): self.add_node(to_node) t_node = self.nodes[to_node.key()] - for n_node in f_node['next_nodes']: + for n_node in f_node["next_nodes"]: if n_node == to_node.key(): return - self.nodes[from_node.key()]['next_nodes'].append(to_node.key()) + self.nodes[from_node.key()]["next_nodes"].append(to_node.key()) def print(self): """Print graph.""" for node_key in self.nodes: - print("Node", self.nodes[node_key]['node']) - for to_node in self.nodes[node_key]['next_nodes']: - print(" ->", self.nodes[to_node]['node']) + print("Node", self.nodes[node_key]["node"]) + for to_node in self.nodes[node_key]["next_nodes"]: + print(" ->", self.nodes[to_node]["node"]) def _get_node_escaped(self, node_key): """Return the name of the node with any double quotes escaped. The DOT language requires that literal double quotes be escaped using a backslash character. """ - return str(self.nodes[node_key]['node']).replace('"', '\\"') + return str(self.nodes[node_key]["node"]).replace('"', '\\"') def to_graph(self, nodes=None, message=None): """Return the 'to_graph'.""" sb = [] - sb.append('# Legend:') - sb.append('# Thread 1 -> Lock C (MODE_IX) indicates Thread 1 is waiting on Lock C and' - ' Lock C is currently held in MODE_IX') - sb.append('# Lock C (MODE_IX) -> Thread 2 indicates Lock C is held by Thread 2 in' - ' MODE_IX') + sb.append("# Legend:") + sb.append( + "# Thread 1 -> Lock C (MODE_IX) indicates Thread 1 is waiting on Lock C and" + " Lock C is currently held in MODE_IX" + ) + sb.append( + "# Lock C (MODE_IX) -> Thread 2 indicates Lock C is held by Thread 2 in" + " MODE_IX" + ) if message is not None: sb.append(message) sb.append('digraph "mongod+lock-status" {') # Draw the graph from left to right. There can be hundreds of threads blocked by the same # resource, but only a few resources involved in a deadlock, so we prefer a long graph # than a super wide one. Long resource / thread names would make a wide graph even wider. - sb.append(' rankdir=LR;') + sb.append(" rankdir=LR;") for node_key in self.nodes: - for next_node_key in self.nodes[node_key]['next_nodes']: - sb.append(' "{}" -> "{}";'.format( - self._get_node_escaped(node_key), self._get_node_escaped(next_node_key))) + for next_node_key in self.nodes[node_key]["next_nodes"]: + sb.append( + ' "{}" -> "{}";'.format( + self._get_node_escaped(node_key), + self._get_node_escaped(next_node_key), + ) + ) for node_key in self.nodes: color = "" if nodes and node_key in nodes: color = "color = red" - sb.append(' "{0}" [label="{0}" {1}]'.format(self._get_node_escaped(node_key), color)) + sb.append( + ' "{0}" [label="{0}" {1}]'.format( + self._get_node_escaped(node_key), color + ) + ) sb.append("}") return "\n".join(sb) @@ -221,10 +236,10 @@ class Graph(object): nodes_in_cycle = [] nodes_visited.add(node_key) nodes_in_cycle.append(node_key) - for node in self.nodes[node_key]['next_nodes']: + for node in self.nodes[node_key]["next_nodes"]: if node in nodes_in_cycle: # The graph cycle starts at the index of node in nodes_in_cycle. - return nodes_in_cycle[nodes_in_cycle.index(node):] + return nodes_in_cycle[nodes_in_cycle.index(node) :] if node not in nodes_visited: dfs_nodes = self.depth_first_search(node, nodes_visited, nodes_in_cycle) if dfs_nodes: @@ -241,13 +256,15 @@ class Graph(object): if node not in nodes_visited: cycle_path = self.depth_first_search(node, nodes_visited) if cycle_path: - return [str(self.nodes[node_key]['node']) for node_key in cycle_path] + return [ + str(self.nodes[node_key]["node"]) for node_key in cycle_path + ] return None def find_thread(thread_dict, search_thread_id): """Find thread.""" - for (_, thread) in list(thread_dict.items()): + for _, thread in list(thread_dict.items()): if thread.thread_id == search_thread_id: return thread return None @@ -286,7 +303,7 @@ def find_frame(function_name_pattern): def find_mutex_holder(graph, thread_dict, show): """Find mutex holder.""" - frame = find_frame(r'std::mutex::lock\(\)') + frame = find_frame(r"std::mutex::lock\(\)") if frame is None: return @@ -301,9 +318,12 @@ def find_mutex_holder(graph, thread_dict, show): # At time thread_dict was initialized, the mutex holder may not have been found. # Use the thread LWP as a substitute for showing output or generating the graph. if mutex_holder_lwpid not in thread_dict: - print("Warning: Mutex at {} held by thread with LWP {}" - " not found in thread_dict. Using LWP to track thread.".format( - mutex_value, mutex_holder_lwpid)) + print( + "Warning: Mutex at {} held by thread with LWP {}" + " not found in thread_dict. Using LWP to track thread.".format( + mutex_value, mutex_holder_lwpid + ) + ) mutex_holder = Thread(mutex_holder_lwpid, mutex_holder_lwpid, '"[unknown]"') else: mutex_holder = thread_dict[mutex_holder_lwpid] @@ -311,8 +331,11 @@ def find_mutex_holder(graph, thread_dict, show): (_, mutex_waiter_lwpid, _) = gdb.selected_thread().ptid mutex_waiter = thread_dict[mutex_waiter_lwpid] if show: - print("Mutex at {} held by {} waited on by {}".format(mutex_value, mutex_holder, - mutex_waiter)) + print( + "Mutex at {} held by {} waited on by {}".format( + mutex_value, mutex_holder, mutex_waiter + ) + ) if graph: graph.add_edge(mutex_waiter, Lock(int(mutex_value), "Mutex")) graph.add_edge(Lock(int(mutex_value), "Mutex"), mutex_holder) @@ -320,7 +343,7 @@ def find_mutex_holder(graph, thread_dict, show): def find_lock_manager_holders(graph, thread_dict, show): """Find lock manager holders.""" - frame = find_frame(r'mongo::Locker::') + frame = find_frame(r"mongo::Locker::") if not frame: return @@ -349,8 +372,11 @@ def find_lock_manager_holders(graph, thread_dict, show): else: lock_holder = find_thread(thread_dict, lock_holder_id) if show: - print("MongoDB Lock at {} held by {} ({}) waited on by {}".format( - lock_head, lock_holder, lock_request["mode"], lock_waiter)) + print( + "MongoDB Lock at {} held by {} ({}) waited on by {}".format( + lock_head, lock_holder, lock_request["mode"], lock_waiter + ) + ) if graph: graph.add_edge(lock_waiter, Lock(int(lock_head), lock_request["mode"])) graph.add_edge(Lock(int(lock_head), lock_request["mode"]), lock_holder) @@ -446,7 +472,7 @@ class MongoDBWaitsForGraph(gdb.Command): cycle_message = "# Cycle detected in the graph nodes %s" % cycle_nodes if graph_file: print("Saving digraph to %s" % graph_file) - with open(graph_file, 'w') as fh: + with open(graph_file, "w") as fh: fh.write(graph.to_graph(nodes=cycle_nodes, message=cycle_message)) print(cycle_message.split("# ")[1]) else: diff --git a/buildscripts/gdb/mongo_printers.py b/buildscripts/gdb/mongo_printers.py index 34d47f645a5..18d5d5133c1 100644 --- a/buildscripts/gdb/mongo_printers.py +++ b/buildscripts/gdb/mongo_printers.py @@ -39,7 +39,8 @@ except ImportError: if sys.version_info[0] < 3: raise gdb.GdbError( - "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2") + "MongoDB gdb extensions only support Python 3. Your GDB was compiled against Python 2" + ) def get_unique_ptr_bytes(obj): @@ -50,7 +51,9 @@ def get_unique_ptr_bytes(obj): mongo::Decorable<> types which store the decorations as a slab of memory with std::unique_ptr. In all other cases get_unique_ptr() can be preferred. """ - return obj.cast(gdb.lookup_type('std::_Head_base<0, unsigned char*, false>'))['_M_head_impl'] + return obj.cast(gdb.lookup_type("std::_Head_base<0, unsigned char*, false>"))[ + "_M_head_impl" + ] def get_unique_ptr(obj): @@ -71,20 +74,20 @@ class StatusPrinter(object): @staticmethod def extract_error(val): """Extract the error object (if any) from a Status/StatusWith.""" - error = val['_error'] - if 'px' in error.type.iterkeys(): - return error['px'] + error = val["_error"] + if "px" in error.type.iterkeys(): + return error["px"] return error @staticmethod def generate_error_details(error): """Generate a (code,reason) tuple from a Status/StatusWith error object.""" info = error.dereference() - code = info['code'] + code = info["code"] # Remove the mongo::ErrorCodes:: prefix. Does nothing if not a real ErrorCode. - code = str(code).split('::')[-1] + code = str(code).split("::")[-1] - return (code, info['reason']) + return (code, info["reason"]) def __init__(self, val): """Initialize StatusPrinter.""" @@ -94,8 +97,8 @@ class StatusPrinter(object): """Return status for printing.""" error = StatusPrinter.extract_error(self.val) if not error: - return 'Status::OK()' - return 'Status(%s, %s)' % StatusPrinter.generate_error_details(error) + return "Status::OK()" + return "Status(%s, %s)" % StatusPrinter.generate_error_details(error) class StatusWithPrinter(object): @@ -107,10 +110,10 @@ class StatusWithPrinter(object): def to_string(self): """Return status for printing.""" - error = StatusPrinter.extract_error(self.val['_status']) + error = StatusPrinter.extract_error(self.val["_status"]) if not error: - return 'StatusWith(OK, %s)' % (self.val['_t']) - return 'StatusWith(%s, %s)' % StatusPrinter.generate_error_details(error) + return "StatusWith(OK, %s)" % (self.val["_t"]) + return "StatusWith(%s, %s)" % StatusPrinter.generate_error_details(error) class StringDataPrinter(object): @@ -123,20 +126,20 @@ class StringDataPrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'string' + return "string" def to_string(self): """Return data for printing.""" # As of SERVER-82604, StringData is based on std::string_view, so try with that first - sv = self.val['_sv'] + sv = self.val["_sv"] if sv is not None: return sv # ... back-off to the legacy format otherwise - size = self.val['_size'] + size = self.val["_size"] if size == -1: - return self.val['_data'].lazy_string() - return self.val['_data'].lazy_string(length=size) + return self.val["_data"].lazy_string() + return self.val["_data"].lazy_string(length=size) class BoostOptionalPrinter(object): @@ -157,7 +160,7 @@ class BSONObjPrinter(object): def __init__(self, val): """Initialize BSONObjPrinter.""" self.val = val - self.ptr = self.val['_objdata'].cast(lookup_type('void').pointer()) + self.ptr = self.val["_objdata"].cast(lookup_type("void").pointer()) self.is_valid = False # Handle the endianness of the BSON object size, which is represented as a 32-bit integer @@ -168,29 +171,36 @@ class BSONObjPrinter(object): self.size = -1 self.raw_memory = None else: - self.size = struct.unpack(' 17 * 1024 * 1024: + if ( + not bson + or not self.is_valid + or self.size < 5 + or self.size > 17 * 1024 * 1024 + ): return options = CodecOptions(document_class=collections.OrderedDict) bsondoc = bson.decode(self.raw_memory, codec_options=options) for key, val in list(bsondoc.items()): - yield 'key', key - yield 'value', bson.json_util.dumps(val) + yield "key", key + yield "value", bson.json_util.dumps(val) def to_string(self): """Return BSONObj for printing.""" @@ -198,7 +208,11 @@ class BSONObjPrinter(object): if self.size == -1: return "BSONObj @ %s - optimized out" % (self.ptr) - ownership = "owned" if self.val['_ownedBuffer']['_buffer']['_holder']['px'] else "unowned" + ownership = ( + "owned" + if self.val["_ownedBuffer"]["_buffer"]["_holder"]["px"] + else "unowned" + ) size = self.size # Print an invalid BSONObj size in hex. @@ -232,12 +246,17 @@ class OplogEntryPrinter(object): def to_string(self): """Return OplogEntry for printing.""" - optime = self.val['_entry']['_opTimeBase'] - optime_str = "ts(%s, %s)" % (optime['_timestamp']['secs'], optime['_timestamp']['i']) + optime = self.val["_entry"]["_opTimeBase"] + optime_str = "ts(%s, %s)" % ( + optime["_timestamp"]["secs"], + optime["_timestamp"]["i"], + ) return "OplogEntry(%s, %s, %s, %s)" % ( - str(self.val['_entry']['_durableReplOperation']['_opType']).split('::')[-1], - str(self.val['_entry']['_commandType']).split('::')[-1], - self.val['_entry']['_durableReplOperation']['_nss'], optime_str) + str(self.val["_entry"]["_durableReplOperation"]["_opType"]).split("::")[-1], + str(self.val["_entry"]["_commandType"]).split("::")[-1], + self.val["_entry"]["_durableReplOperation"]["_nss"], + optime_str, + ) class UUIDPrinter(object): @@ -250,11 +269,11 @@ class UUIDPrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'string' + return "string" def to_string(self): """Return UUID for printing.""" - raw_bytes = [self.val['_uuid']['_M_elems'][i] for i in range(16)] + raw_bytes = [self.val["_uuid"]["_M_elems"][i] for i in range(16)] uuid_hex_bytes = [hex(int(b))[2:].zfill(2) for b in raw_bytes] return str(uuid.UUID("".join(uuid_hex_bytes))) @@ -269,11 +288,11 @@ class OIDPrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'string' + return "string" def to_string(self): """Return OID for printing.""" - raw_bytes = [int(self.val['_data'][i]) for i in range(OBJECT_ID_WIDTH)] + raw_bytes = [int(self.val["_data"][i]) for i in range(OBJECT_ID_WIDTH)] oid_hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes] return "ObjectID('%s')" % "".join(oid_hex_bytes) @@ -288,12 +307,12 @@ class RecordIdPrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'string' + return "string" ## Get the address at given offset of data as the selected pointer type def __get_data_address(self, ptr, offset): ptr_type = gdb.lookup_type(ptr).pointer() - return self.val['_data']['_M_elems'][offset].address.cast(ptr_type) + return self.val["_data"]["_M_elems"][offset].address.cast(ptr_type) def to_string(self): """Return RecordId for printing.""" @@ -301,27 +320,39 @@ class RecordIdPrinter(object): if rid_format == 0: return "null RecordId" elif rid_format == 1: - koffset = 8 - 1 ## std::alignment_of_v - sizeof(Format); (see record_id.h) - rid_address = self.__get_data_address('int64_t', koffset) + koffset = ( + 8 - 1 + ) ## std::alignment_of_v - sizeof(Format); (see record_id.h) + rid_address = self.__get_data_address("int64_t", koffset) return "RecordId long: %d" % int(rid_address.dereference()) elif rid_format == 2: - str_len = self.__get_data_address('int8_t', 0).dereference() - array_address = self.__get_data_address('int8_t', 1) + str_len = self.__get_data_address("int8_t", 0).dereference() + array_address = self.__get_data_address("int8_t", 1) raw_bytes = [array_address[i] for i in range(0, str_len)] hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes] - return "RecordId small string %d hex bytes: %s" % (str_len, str("".join(hex_bytes))) + return "RecordId small string %d hex bytes: %s" % ( + str_len, + str("".join(hex_bytes)), + ) elif rid_format == 3: - koffset = 8 - 1 ## std::alignment_of_v - sizeof(Format); (see record_id.h) - buffer = self.__get_data_address('mongo::ConstSharedBuffer', koffset).dereference() - holder_ptr = holder = buffer['_buffer']["_holder"]["px"] + koffset = ( + 8 - 1 + ) ## std::alignment_of_v - sizeof(Format); (see record_id.h) + buffer = self.__get_data_address( + "mongo::ConstSharedBuffer", koffset + ).dereference() + holder_ptr = holder = buffer["_buffer"]["_holder"]["px"] holder = holder.dereference() str_len = int(holder["_capacity"]) # Start of data is immediately after pointer for holder start_ptr = (holder_ptr + 1).dereference().cast(lookup_type("char")).address raw_bytes = [start_ptr[i] for i in range(0, str_len)] hex_bytes = [hex(b & 0xFF)[2:].zfill(2) for b in raw_bytes] - return "RecordId big string %d hex bytes @ %s: %s" % (str_len, holder_ptr + 1, - str("".join(hex_bytes))) + return "RecordId big string %d hex bytes @ %s: %s" % ( + str_len, + holder_ptr + 1, + str("".join(hex_bytes)), + ) else: return "unknown RecordId format: %d" % rid_format @@ -353,22 +384,22 @@ class DatabaseNamePrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'string' + return "string" def _get_storage_data(self): """Return the data pointer from the _data Storage class.""" - data = self.val['_data'] - footer = data['_footer'] + data = self.val["_data"] + footer = data["_footer"] f_size = footer.type.sizeof # The last byte of _footer contain the flags (and the size when using small string). - flags = footer.cast(gdb.lookup_type('char').array(f_size))[f_size - 1] + flags = footer.cast(gdb.lookup_type("char").array(f_size))[f_size - 1] - data_ptr = data['_data'] + data_ptr = data["_data"] if is_small_string(flags): return data_ptr.address, small_string_size(flags) else: - return data_ptr, data['_length'] + return data_ptr, data["_length"] def _get_string(self, address, size): data = gdb.selected_inferior().read_memory(address, size).tobytes() @@ -396,21 +427,30 @@ class DecorablePrinter(object): @staticmethod def display_hint(): """Display hint.""" - return 'map' + return "map" def to_string(self): """Return Decorable for printing.""" - return "Decorable<{}> with {} elems ".format(self.val.type.template_argument(0), self.count) + return "Decorable<{}> with {} elems ".format( + self.val.type.template_argument(0), self.count + ) def children(self): """Children.""" for index in range(self.count): try: - deco_type_name, obj, obj_addr = get_object_decoration(self.val, self.start, index) - yield ('key', "{}:{}:{}".format(index, obj_addr, deco_type_name)) - yield ('value', obj) + deco_type_name, obj, obj_addr = get_object_decoration( + self.val, self.start, index + ) + yield ("key", "{}:{}:{}".format(index, obj_addr, deco_type_name)) + yield ("value", obj) except Exception as err: - print("Failed to look up decoration type: " + deco_type_name + ": " + str(err)) + print( + "Failed to look up decoration type: " + + deco_type_name + + ": " + + str(err) + ) def _get_flags(flag_val, flags): @@ -443,7 +483,9 @@ class WtCursorPrinter(object): """ try: - with open("./src/third_party/wiredtiger/src/include/wiredtiger.in") as wiredtiger_header: + with open( + "./src/third_party/wiredtiger/src/include/wiredtiger.in" + ) as wiredtiger_header: file_contents = wiredtiger_header.read() cursor_flags_re = re.compile(r"#define\s+WT_CURSTD_(\w+)\s+0x(\d+)u") cursor_flags = cursor_flags_re.findall(file_contents)[::-1] @@ -463,8 +505,12 @@ class WtCursorPrinter(object): for field in self.val.type.fields(): field_val = self.val[field.name] if field.name == "flags": - yield ("flags", "{} ({})".format(field_val, - str(_get_flags(field_val, self.cursor_flags)))) + yield ( + "flags", + "{} ({})".format( + field_val, str(_get_flags(field_val, self.cursor_flags)) + ), + ) else: yield (field.name, field_val) @@ -477,7 +523,9 @@ class WtSessionImplPrinter(object): """ try: - with open("./src/third_party/wiredtiger/src/include/session.h") as session_header: + with open( + "./src/third_party/wiredtiger/src/include/session.h" + ) as session_header: file_contents = session_header.read() session_flags_re = re.compile(r"#define\s+WT_SESSION_(\w+)\s+0x(\d+)u") session_flags = session_flags_re.findall(file_contents)[::-1] @@ -497,8 +545,12 @@ class WtSessionImplPrinter(object): for field in self.val.type.fields(): field_val = self.val[field.name] if field.name == "flags": - yield ("flags", "{} ({})".format(field_val, - str(_get_flags(field_val, self.session_flags)))) + yield ( + "flags", + "{} ({})".format( + field_val, str(_get_flags(field_val, self.session_flags)) + ), + ) else: yield (field.name, field_val) @@ -531,8 +583,12 @@ class WtTxnPrinter(object): for field in self.val.type.fields(): field_val = self.val[field.name] if field.name == "flags": - yield ("flags", "{} ({})".format(field_val, - str(_get_flags(field_val, self.txn_flags)))) + yield ( + "flags", + "{} ({})".format( + field_val, str(_get_flags(field_val, self.txn_flags)) + ), + ) else: yield (field.name, field_val) @@ -551,7 +607,11 @@ def absl_insert_version_after_absl(cpp_name): absl_ns_end = absl_ns_start + len(absl_ns_str) return ( - cpp_name[:absl_ns_end] + ABSL_OPTION_INLINE_NAMESPACE_NAME + "::" + cpp_name[absl_ns_end:]) + cpp_name[:absl_ns_end] + + ABSL_OPTION_INLINE_NAMESPACE_NAME + + "::" + + cpp_name[absl_ns_end:] + ) def absl_get_settings(val): @@ -559,8 +619,12 @@ def absl_get_settings(val): try: common_fields_storage_type = gdb.lookup_type( absl_insert_version_after_absl( - "absl::container_internal::internal_compressed_tuple::Storage") + - absl_insert_version_after_absl("")) + "absl::container_internal::internal_compressed_tuple::Storage" + ) + + absl_insert_version_after_absl( + "" + ) + ) except gdb.error as err: if not err.args[0].startswith("No type named "): raise @@ -571,7 +635,9 @@ def absl_get_settings(val): common_fields_storage_type = gdb.lookup_type( absl_insert_version_after_absl( "absl::container_internal::internal_compressed_tuple::Storage" - "")) + "" + ) + ) # The Hash, Eq, or Alloc functors may not be zero-sized objects. # mongo::LogicalSessionIdHash is one such example. An explicit cast is needed to @@ -595,7 +661,9 @@ def absl_get_nodes(val): ctrl = settings["control_"] # Derive the underlying type stored in the container. - slot_type = lookup_type(str(val.type.strip_typedefs()) + "::slot_type").strip_typedefs() + slot_type = lookup_type( + str(val.type.strip_typedefs()) + "::slot_type" + ).strip_typedefs() # Using the array of ctrl bytes, search for in-use slots and return them # https://github.com/abseil/abseil-cpp/blob/8a3caf7dea955b513a6c1b572a2423c6b4213402/absl/container/internal/raw_hash_set.h#L2108-L2113 @@ -616,7 +684,7 @@ class AbslHashSetPrinterBase(object): @staticmethod def display_hint(): """Display hint.""" - return 'array' + return "array" def to_string(self): """Return absl::[node/flat]_hash_set for printing.""" @@ -668,7 +736,7 @@ class AbslHashMapPrinterBase(object): @staticmethod def display_hint(): """Display hint.""" - return 'map' + return "map" def to_string(self): """Return absl::[node/flat]_hash_map for printing.""" @@ -690,8 +758,8 @@ class AbslNodeHashMapPrinter(AbslHashMapPrinterBase): def children(self): """Children.""" for kvp in absl_get_nodes(self.val): - yield ('key', kvp['first']) - yield ('value', kvp['second']) + yield ("key", kvp["first"]) + yield ("value", kvp["second"]) class AbslFlatHashMapPrinter(AbslHashMapPrinterBase): @@ -704,8 +772,8 @@ class AbslFlatHashMapPrinter(AbslHashMapPrinterBase): def children(self): """Children.""" for kvp in absl_get_nodes(self.val): - yield ('key', kvp['key']) - yield ('value', kvp['value']["second"]) + yield ("key", kvp["key"]) + yield ("value", kvp["value"]["second"]) class ImmutableMapIter(ImmerListIter): @@ -717,7 +785,7 @@ class ImmutableMapIter(ImmerListIter): def __next__(self): if self.pair: - result = ('value', self.pair['second']) + result = ("value", self.pair["second"]) self.pair = None self.i += 1 return result @@ -726,8 +794,9 @@ class ImmutableMapIter(ImmerListIter): if self.i < self.curr[1] or self.i >= self.curr[2]: self.curr = self.region() self.pair = self.curr[0][self.i - self.curr[1]].cast( - gdb.lookup_type(self.v.type.template_argument(0).name)) - result = ('key', self.pair['first']) + gdb.lookup_type(self.v.type.template_argument(0).name) + ) + result = ("key", self.pair["first"]) return result @@ -738,13 +807,16 @@ class ImmutableMapPrinter: self.val = val def to_string(self): - return '%s of size %d' % (self.val.type, int(self.val['_storage']['impl_']['size'])) + return "%s of size %d" % ( + self.val.type, + int(self.val["_storage"]["impl_"]["size"]), + ) def children(self): - return ImmutableMapIter(self.val['_storage']) + return ImmutableMapIter(self.val["_storage"]) def display_hint(self): - return 'map' + return "map" class ImmutableSetPrinter: @@ -754,16 +826,19 @@ class ImmutableSetPrinter: self.val = val def to_string(self): - return '%s of size %d' % (self.val.type, int(self.val['_storage']['impl_']['size'])) + return "%s of size %d" % ( + self.val.type, + int(self.val["_storage"]["impl_"]["size"]), + ) def children(self): - return ImmerListIter(self.val['_storage']) + return ImmerListIter(self.val["_storage"]) def display_hint(self): - return 'array' + return "array" -def find_match_brackets(search, opening='<', closing='>'): +def find_match_brackets(search, opening="<", closing=">"): """Return the index of the closing bracket that matches the first opening bracket. Return -1 if no last matching bracket is found, i.e. not a template. @@ -817,7 +892,9 @@ class MongoPrettyPrinterCollection(gdb.printing.PrettyPrinter): def add(self, name, prefix, is_template, printer): """Add a subprinter.""" - self.subprinters.append(MongoSubPrettyPrinter(name, prefix, is_template, printer)) + self.subprinters.append( + MongoSubPrettyPrinter(name, prefix, is_template, printer) + ) def __call__(self, val): """Return matched printer type.""" @@ -837,7 +914,10 @@ class MongoPrettyPrinterCollection(gdb.printing.PrettyPrinter): # Ignore subtypes of templated classes. # We do not want HashTable::iterator as an example, just HashTable if printer.is_template: - if (index + 1 == len(lookup_tag) and lookup_tag.find(printer.prefix) == 0): + if ( + index + 1 == len(lookup_tag) + and lookup_tag.find(printer.prefix) == 0 + ): return printer.printer(val) elif lookup_tag == printer.prefix: return printer.printer(val) @@ -851,13 +931,13 @@ class WtUpdateToBsonPrinter(object): def __init__(self, val): """Initializer.""" self.val = val - self.size = self.val['size'] - self.ptr = self.val['data'] + self.size = self.val["size"] + self.ptr = self.val["data"] @staticmethod def display_hint(): """DisplayHint.""" - return 'map' + return "map" def to_string(self): """ToString.""" @@ -866,11 +946,11 @@ class WtUpdateToBsonPrinter(object): fld = self.val.type.fields()[idx] val = self.val[fld.name] elems.append(str((fld.name, str(val)))) - return "WT_UPDATE: \n %s" % ('\n '.join(elems)) + return "WT_UPDATE: \n %s" % ("\n ".join(elems)) def children(self): """children.""" - if self.val['type'] != 3: + if self.val["type"] != 3: # Type 3 is a "normal" update. Notably type 4 is a deletion and type 1 represents a # delta relative to the previous committed version in the update chain. Only attempt # to parse type 3 as bson. @@ -884,8 +964,8 @@ class WtUpdateToBsonPrinter(object): return for key, value in list(bsonobj.items()): - yield 'key', key - yield 'value', bson.json_util.dumps(value) + yield "key", key + yield "value", bson.json_util.dumps(value) def make_inverse_enum_dict(enum_type_name): @@ -898,7 +978,7 @@ def make_inverse_enum_dict(enum_type_name): enum_dict = gdb.types.make_enum_dict(lookup_type(enum_type_name)) enum_inverse_dic = dict() for key, value in enum_dict.items(): - enum_inverse_dic[int(value)] = key.split('::')[-1] # take last element + enum_inverse_dic[int(value)] = key.split("::")[-1] # take last element return enum_inverse_dic @@ -941,19 +1021,22 @@ class SbeCodeFragmentPrinter(object): # The instructions stream is stored using 'absl::InlinedVector' type, which can # either use an inline buffer or an allocated one. The choice of storage is decoded in the # last bit of the 'metadata_' field. - storage = self.val['_instrs']['storage_'] - meta = storage['metadata_'].cast(lookup_type('size_t')) - self.is_inlined = (meta % 2 == 0) - self.size = (meta >> 1) - self.pdata = \ - storage['data_']['inlined']['inlined_data'].cast(lookup_type('uint8_t').pointer()) \ - if self.is_inlined \ - else storage['data_']['allocated']['allocated_data'] + storage = self.val["_instrs"]["storage_"] + meta = storage["metadata_"].cast(lookup_type("size_t")) + self.is_inlined = meta % 2 == 0 + self.size = meta >> 1 + self.pdata = ( + storage["data_"]["inlined"]["inlined_data"].cast( + lookup_type("uint8_t").pointer() + ) + if self.is_inlined + else storage["data_"]["allocated"]["allocated_data"] + ) # Precompute lookup tables for Instructions and Builtins. - self.optags_lookup = make_inverse_enum_dict('mongo::sbe::vm::Instruction::Tags') - self.builtins_lookup = make_inverse_enum_dict('mongo::sbe::vm::Builtin') - self.valuetags_lookup = make_inverse_enum_dict('mongo::sbe::value::TypeTags') + self.optags_lookup = make_inverse_enum_dict("mongo::sbe::vm::Instruction::Tags") + self.builtins_lookup = make_inverse_enum_dict("mongo::sbe::vm::Builtin") + self.valuetags_lookup = make_inverse_enum_dict("mongo::sbe::value::TypeTags") def to_string(self): """Return sbe::vm::CodeFragment for printing.""" @@ -961,26 +1044,29 @@ class SbeCodeFragmentPrinter(object): def children(self): """children.""" - yield '_instrs', '{... (to see raw output, run "disable pretty-printer")}' - yield '_fixUps', self.val['_fixUps'] - yield '_stackSize', self.val['_stackSize'] + yield "_instrs", '{... (to see raw output, run "disable pretty-printer")}' + yield "_fixUps", self.val["_fixUps"] + yield "_stackSize", self.val["_stackSize"] - yield 'inlined', self.is_inlined - yield 'instrs data at', '[{} - {}]'.format(hex(self.pdata), hex(self.pdata + self.size)) - yield 'instrs total size', self.size + yield "inlined", self.is_inlined + yield ( + "instrs data at", + "[{} - {}]".format(hex(self.pdata), hex(self.pdata + self.size)), + ) + yield "instrs total size", self.size # Sizes for types we'll use when parsing the insructions stream. - int_size = lookup_type('int').sizeof - ptr_size = lookup_type('void').pointer().sizeof - tag_size = lookup_type('mongo::sbe::value::TypeTags').sizeof - value_size = lookup_type('mongo::sbe::value::Value').sizeof - uint8_size = lookup_type('uint8_t').sizeof - uint32_size = lookup_type('uint32_t').sizeof - uint64_size = lookup_type('uint64_t').sizeof - builtin_size = lookup_type('mongo::sbe::vm::Builtin').sizeof - time_unit_size = lookup_type('mongo::TimeUnit').sizeof - timezone_size = lookup_type('mongo::TimeZone').sizeof - day_of_week_size = lookup_type('mongo::DayOfWeek').sizeof + int_size = lookup_type("int").sizeof + ptr_size = lookup_type("void").pointer().sizeof + tag_size = lookup_type("mongo::sbe::value::TypeTags").sizeof + value_size = lookup_type("mongo::sbe::value::Value").sizeof + uint8_size = lookup_type("uint8_t").sizeof + uint32_size = lookup_type("uint32_t").sizeof + uint64_size = lookup_type("uint64_t").sizeof + builtin_size = lookup_type("mongo::sbe::vm::Builtin").sizeof + time_unit_size = lookup_type("mongo::TimeUnit").sizeof + timezone_size = lookup_type("mongo::TimeZone").sizeof + day_of_week_size = lookup_type("mongo::DayOfWeek").sizeof cur_op = self.pdata end_op = self.pdata + self.size @@ -991,7 +1077,7 @@ class SbeCodeFragmentPrinter(object): op_tag = read_as_integer(op_addr, 1) if op_tag not in self.optags_lookup: - yield hex(op_addr), 'unknown op tag: {}'.format(op_tag) + yield hex(op_addr), "unknown op tag: {}".format(op_tag) error = True break op_name = self.optags_lookup[op_tag] @@ -1000,70 +1086,93 @@ class SbeCodeFragmentPrinter(object): instr_count += 1 # Some instructions have extra arguments, embedded into the ops stream. - args = '' - if op_name in ['pushLocalVal', 'pushMoveLocalVal', 'pushLocalLambda']: - args = 'arg: ' + str(read_as_integer(cur_op, int_size)) + args = "" + if op_name in ["pushLocalVal", "pushMoveLocalVal", "pushLocalLambda"]: + args = "arg: " + str(read_as_integer(cur_op, int_size)) cur_op += int_size - elif op_name in ['jmp', 'jmpTrue', 'jmpFalse', 'jmpNothing', 'jmpNotNothing']: + elif op_name in [ + "jmp", + "jmpTrue", + "jmpFalse", + "jmpNothing", + "jmpNotNothing", + ]: offset = read_as_integer_signed(cur_op, int_size) cur_op += int_size - args = 'offset: ' + str(offset) + ', target: ' + hex(cur_op + offset) - elif op_name in ['pushConstVal', 'getFieldImm']: + args = "offset: " + str(offset) + ", target: " + hex(cur_op + offset) + elif op_name in ["pushConstVal", "getFieldImm"]: tag = read_as_integer(cur_op, tag_size) - args = 'tag: ' + self.valuetags_lookup.get(tag, "unknown") + \ - ', value: ' + hex(read_as_integer(cur_op + tag_size, value_size)) - cur_op += (tag_size + value_size) - elif op_name in ['pushAccessVal', 'pushMoveVal']: - args = 'accessor: ' + hex(read_as_integer(cur_op, ptr_size)) + args = ( + "tag: " + + self.valuetags_lookup.get(tag, "unknown") + + ", value: " + + hex(read_as_integer(cur_op + tag_size, value_size)) + ) + cur_op += tag_size + value_size + elif op_name in ["pushAccessVal", "pushMoveVal"]: + args = "accessor: " + hex(read_as_integer(cur_op, ptr_size)) cur_op += ptr_size - elif op_name in ['numConvert']: - args = 'convert to: ' + \ - self.valuetags_lookup.get(read_as_integer(cur_op, tag_size), "unknown") + elif op_name in ["numConvert"]: + args = "convert to: " + self.valuetags_lookup.get( + read_as_integer(cur_op, tag_size), "unknown" + ) cur_op += tag_size - elif op_name in ['typeMatchImm']: - args = 'mask: ' + hex(read_as_integer(cur_op, uint32_size)) + elif op_name in ["typeMatchImm"]: + args = "mask: " + hex(read_as_integer(cur_op, uint32_size)) cur_op += uint32_size - elif op_name in ['function', 'functionSmall']: - arity_size = \ - lookup_type('mongo::sbe::vm::ArityType').sizeof \ - if op_name == 'function' \ - else lookup_type('mongo::sbe::vm::SmallArityType').sizeof + elif op_name in ["function", "functionSmall"]: + arity_size = ( + lookup_type("mongo::sbe::vm::ArityType").sizeof + if op_name == "function" + else lookup_type("mongo::sbe::vm::SmallArityType").sizeof + ) builtin_id = read_as_integer(cur_op, builtin_size) - args = 'builtin: ' + self.builtins_lookup.get(builtin_id, "unknown") - args += ' arity: ' + str(read_as_integer(cur_op + builtin_size, arity_size)) - cur_op += (builtin_size + arity_size) - elif op_name in ['fillEmptyImm']: - args = 'Instruction::Constants: ' + str(read_as_integer(cur_op, uint8_size)) + args = "builtin: " + self.builtins_lookup.get(builtin_id, "unknown") + args += " arity: " + str( + read_as_integer(cur_op + builtin_size, arity_size) + ) + cur_op += builtin_size + arity_size + elif op_name in ["fillEmptyImm"]: + args = "Instruction::Constants: " + str( + read_as_integer(cur_op, uint8_size) + ) cur_op += uint8_size - elif op_name in ['traverseFImm', 'traversePImm']: + elif op_name in ["traverseFImm", "traversePImm"]: const_enum = read_as_integer(cur_op, uint8_size) cur_op += uint8_size - args = \ - 'Instruction::Constants: ' + str(const_enum) + \ - ", offset: " + str(read_as_integer_signed(cur_op, int_size)) + args = ( + "Instruction::Constants: " + + str(const_enum) + + ", offset: " + + str(read_as_integer_signed(cur_op, int_size)) + ) cur_op += int_size - elif op_name in ['dateTruncImm']: + elif op_name in ["dateTruncImm"]: unit = read_as_integer(cur_op, time_unit_size) cur_op += time_unit_size - args = 'unit: ' + str(unit) + args = "unit: " + str(unit) bin_size = read_as_integer(cur_op, uint64_size) cur_op += uint64_size - args += ', binSize: ' + str(bin_size) + args += ", binSize: " + str(bin_size) timezone = read_as_integer(cur_op, timezone_size) cur_op += timezone_size - args += ', timezone: ' + hex(timezone) + args += ", timezone: " + hex(timezone) day_of_week = read_as_integer(cur_op, day_of_week_size) cur_op += day_of_week_size - args += ', dayOfWeek: ' + str(day_of_week) - elif op_name in ['traverseCsiCellValues', 'traverseCsiCellTypes']: + args += ", dayOfWeek: " + str(day_of_week) + elif op_name in ["traverseCsiCellValues", "traverseCsiCellTypes"]: offset = read_as_integer_signed(cur_op, int_size) cur_op += int_size - args = 'lambda at: ' + hex(cur_op + offset) + args = "lambda at: " + hex(cur_op + offset) - yield hex(op_addr), '{} ({})'.format(op_name, args) + yield hex(op_addr), "{} ({})".format(op_name, args) - yield 'instructions count', \ - instr_count if not error else '? (successfully parsed {})'.format(instr_count) + yield ( + "instructions count", + instr_count + if not error + else "? (successfully parsed {})".format(instr_count), + ) def build_pretty_printer(): @@ -1108,7 +1217,9 @@ def build_pretty_printer(): pp.add("__wt_session_impl", "__wt_session_impl", False, WtSessionImplPrinter) pp.add("__wt_txn", "__wt_txn", False, WtTxnPrinter) pp.add("__wt_update", "__wt_update", False, WtUpdateToBsonPrinter) - pp.add("CodeFragment", "mongo::sbe::vm::CodeFragment", False, SbeCodeFragmentPrinter) + pp.add( + "CodeFragment", "mongo::sbe::vm::CodeFragment", False, SbeCodeFragmentPrinter + ) pp.add("boost::optional", "boost::optional", True, BoostOptionalPrinter) pp.add("immutable::map", "mongo::immutable::map", True, ImmutableMapPrinter) pp.add("immutable::set", "mongo::immutable::set", True, ImmutableSetPrinter) @@ -1126,6 +1237,8 @@ def build_pretty_printer(): ################################################################################################### # Register pretty-printers, replace existing mongo printers -gdb.printing.register_pretty_printer(gdb.current_objfile(), build_pretty_printer(), True) +gdb.printing.register_pretty_printer( + gdb.current_objfile(), build_pretty_printer(), True +) print("MongoDB GDB pretty-printers loaded") diff --git a/buildscripts/gdb/optimizer_printers.py b/buildscripts/gdb/optimizer_printers.py index 4a6e38bdbc8..eb5878aabe2 100644 --- a/buildscripts/gdb/optimizer_printers.py +++ b/buildscripts/gdb/optimizer_printers.py @@ -28,7 +28,7 @@ def eval_print_fn(val, print_fn): # replace them with a single EOL character so that GDB prints multi-line # explains nicely. pp_result = print_fn(val) - pp_str = str(pp_result).replace("\"", "").replace("\\n", "\n") + pp_str = str(pp_result).replace('"', "").replace("\\n", "\n") return pp_str @@ -145,7 +145,9 @@ class FixedArityNodePrinter(object): global operator_indent_level prior_indent = operator_indent_level - current_indent = operator_indent_level + self.arity + len(self.custom_children) - 1 + current_indent = ( + operator_indent_level + self.arity + len(self.custom_children) - 1 + ) for child in self.custom_children: lhs = "\n" for _ in range(current_indent): @@ -178,8 +180,8 @@ class FixedArityNodePrinter(object): class Vector(object): def __init__(self, vec): self.vec = vec - self.start = vec['_M_impl']['_M_start'] - self.finish = vec['_M_impl']['_M_finish'] + self.start = vec["_M_impl"]["_M_start"] + self.finish = vec["_M_impl"]["_M_finish"] def __iter__(self): item = self.start @@ -194,8 +196,11 @@ class Vector(object): def get(self, index): if index > self.count() - 1: - raise gdb.GdbError("Invalid Vector access at index {} with size {}".format( - index, self.count())) + raise gdb.GdbError( + "Invalid Vector access at index {} with size {}".format( + index, self.count() + ) + ) item = self.start + index return item.dereference() @@ -284,7 +289,9 @@ class ScanNodePrinter(object): return str(bound_projections.get(0)) def to_string(self): - return "Scan[{}, {}]".format(self.val["_scanDefName"], self.get_bound_projection()) + return "Scan[{}, {}]".format( + self.val["_scanDefName"], self.get_bound_projection() + ) class FilterNodePrinter(FixedArityNodePrinter): @@ -320,13 +327,16 @@ class ConstantPrinter(object): value_print_fn = "mongo::sbe::value::print" (print_fn_symbol, _) = gdb.lookup_symbol(value_print_fn) if print_fn_symbol is None: - raise gdb.GdbError("Could not find pretty print function: " + value_print_fn) + raise gdb.GdbError( + "Could not find pretty print function: " + value_print_fn + ) print_fn = print_fn_symbol.value() return print_fn(tag, value) def to_string(self): return "Constant[{}]".format( - ConstantPrinter.print_sbe_value(self.val["_tag"], self.val["_val"])) + ConstantPrinter.print_sbe_value(self.val["_tag"], self.val["_val"]) + ) class VariablePrinter(object): @@ -609,8 +619,12 @@ class FieldProjectionMapPrinter(object): res += ": " + str(root_proj) + ", " # Python reformats the string with embedded "=" characters, avoid that by replacing here. - res += str(self.val["_fieldProjections"]).replace("=", ":").replace("{", "(").replace( - "}", ")") + res += ( + str(self.val["_fieldProjections"]) + .replace("=", ":") + .replace("{", "(") + .replace("}", ")") + ) res += "}" return res @@ -624,7 +638,8 @@ class PhysicalScanNodePrinter(FixedArityNodePrinter): def to_string(self): return "PhysicalScan[{}, {}]".format( - str(self.val["_fieldProjectionMap"]), str(self.val["_scanDefName"])) + str(self.val["_fieldProjectionMap"]), str(self.val["_scanDefName"]) + ) class ValueScanNodePrinter(FixedArityNodePrinter): @@ -635,8 +650,9 @@ class ValueScanNodePrinter(FixedArityNodePrinter): super().__init__(val, 1, "ValueScan") def to_string(self): - return "ValueScan[hasRID={},arraySize={}]".format(self.val["_hasRID"], - self.val["_arraySize"]) + return "ValueScan[hasRID={},arraySize={}]".format( + self.val["_hasRID"], self.val["_arraySize"] + ) class CoScanNodePrinter(FixedArityNodePrinter): @@ -656,8 +672,11 @@ class IndexScanNodePrinter(FixedArityNodePrinter): def to_string(self): return "IndexScan[{{{}}}, scanDef={}, indexDef={}, interval={}]".format( - self.val["_fieldProjectionMap"], self.val["_scanDefName"], self.val["_indexDefName"], - self.val["_indexInterval"]).replace("\n", "") + self.val["_fieldProjectionMap"], + self.val["_scanDefName"], + self.val["_indexDefName"], + self.val["_indexInterval"], + ).replace("\n", "") class SeekNodePrinter(FixedArityNodePrinter): @@ -668,9 +687,11 @@ class SeekNodePrinter(FixedArityNodePrinter): super().__init__(val, 2, "Seek") def to_string(self): - return "Seek[rid_projection: {}, {}, scanDef: {}]".format(self.val["_ridProjectionName"], - self.val["_fieldProjectionMap"], - self.val["_scanDefName"]) + return "Seek[rid_projection: {}, {}, scanDef: {}]".format( + self.val["_ridProjectionName"], + self.val["_fieldProjectionMap"], + self.val["_scanDefName"], + ) class MemoLogicalDelegatorNodePrinter(FixedArityNodePrinter): @@ -692,8 +713,9 @@ class MemoPhysicalDelegatorNodePrinter(FixedArityNodePrinter): super().__init__(val, 0, "MemoPhysicalDelegator") def to_string(self): - return "MemoPhysicalDelegator[group: {}, index: {}]".format(self.val["_nodeId"]["_groupId"], - self.val["_nodeId"]["_index"]) + return "MemoPhysicalDelegator[group: {}, index: {}]".format( + self.val["_nodeId"]["_groupId"], self.val["_nodeId"]["_index"] + ) class ResidualRequirementPrinter(object): @@ -710,10 +732,18 @@ class ResidualRequirementPrinter(object): if get_boost_optional(key["_projectionName"]) is not None: res += "refProj: " + str(get_boost_optional(key["_projectionName"])) + ", " - res += "path: '" + str(key["_path"]).replace("| ", "").replace("\n", " -> ") + "'" + res += ( + "path: '" + + str(key["_path"]).replace("| ", "").replace("\n", " -> ") + + "'" + ) if get_boost_optional(req["_boundProjectionName"]) is not None: - res += "boundProj: " + str(get_boost_optional(req["_boundProjectionName"])) + ", " + res += ( + "boundProj: " + + str(get_boost_optional(req["_boundProjectionName"])) + + ", " + ) res += ">" return res @@ -805,9 +835,16 @@ class BinaryJoinNodePrinter(FixedArityNodePrinter): super().__init__(val, 3, "BinaryJoin") def to_string(self): - correlated = print_correlated_projections(self.val["_correlatedProjectionNames"]) - return "BinaryJoin[type=" + str(strip_namespace( - self.val["_joinType"])) + ", " + correlated + "]" + correlated = print_correlated_projections( + self.val["_correlatedProjectionNames"] + ) + return ( + "BinaryJoin[type=" + + str(strip_namespace(self.val["_joinType"])) + + ", " + + correlated + + "]" + ) def print_eq_join_condition(leftKeys, rightKeys): @@ -843,8 +880,9 @@ class MergeJoinNodePrinter(FixedArityNodePrinter): # Manually add the collation ops. collationOps = Vector(self.val["_collation"]) - collationChild = "Collation[" + ", ".join(str(collation) - for collation in collationOps) + "]" + collationChild = ( + "Collation[" + ", ".join(str(collation) for collation in collationOps) + "]" + ) self.add_child(collationChild) # Manually add the child which prints the sets of keys. @@ -859,7 +897,8 @@ class MergeJoinNodePrinter(FixedArityNodePrinter): def print_collation_req(req): spec = Vector(req["_spec"]) return ", ".join( - str(entry["first"]) + ": " + strip_namespace(entry["second"]) for entry in spec) + str(entry["first"]) + ": " + strip_namespace(entry["second"]) for entry in spec + ) class SortedMergeNodePrinter(DynamicArityNodePrinter): @@ -869,7 +908,9 @@ class SortedMergeNodePrinter(DynamicArityNodePrinter): """Initialize SortedMergeNodePrinter.""" super().__init__(val, 2, "MergeJoin") - self.add_child("collation[" + print_collation_req(self.val["_collationReq"]) + "]") + self.add_child( + "collation[" + print_collation_req(self.val["_collationReq"]) + "]" + ) def to_string(self): return "SortedMerge" @@ -883,9 +924,16 @@ class NestedLoopJoinNodePrinter(FixedArityNodePrinter): super().__init__(val, 3, "NestedLoopJoin") def to_string(self): - correlated = print_correlated_projections(self.val["_correlatedProjectionNames"]) - return "NestedLoopJoin[type=" + strip_namespace( - self.val["_joinType"]) + ", " + correlated + "]" + correlated = print_correlated_projections( + self.val["_correlatedProjectionNames"] + ) + return ( + "NestedLoopJoin[type=" + + strip_namespace(self.val["_joinType"]) + + ", " + + correlated + + "]" + ) class UnwindNodePrinter(FixedArityNodePrinter): @@ -912,8 +960,13 @@ class SpoolProducerNodePrinter(FixedArityNodePrinter): super().__init__(val, 4, "SpoolProducer") def to_string(self): - return "SpoolProducer[" + strip_namespace(self.val["_type"]) + ", id:" + str( - self.val["_spoolId"]) + "]" + return ( + "SpoolProducer[" + + strip_namespace(self.val["_type"]) + + ", id:" + + str(self.val["_spoolId"]) + + "]" + ) class SpoolConsumerNodePrinter(FixedArityNodePrinter): @@ -924,8 +977,13 @@ class SpoolConsumerNodePrinter(FixedArityNodePrinter): super().__init__(val, 1, "SpoolConsumer") def to_string(self): - return "SpoolConsumer[" + strip_namespace(self.val["_type"]) + ", id:" + str( - self.val["_spoolId"]) + "]" + return ( + "SpoolConsumer[" + + strip_namespace(self.val["_type"]) + + ", id:" + + str(self.val["_spoolId"]) + + "]" + ) class CollationNodePrinter(FixedArityNodePrinter): @@ -949,8 +1007,13 @@ class LimitSkipNodePrinter(FixedArityNodePrinter): super().__init__(val, 1, "LimitSkip") def to_string(self): - return "LimitSkip[limit: " + str(self.val["_property"]["_limit"]) + ", skip: " + str( - self.val["_property"]["_skip"]) + "]" + return ( + "LimitSkip[limit: " + + str(self.val["_property"]["_limit"]) + + ", skip: " + + str(self.val["_property"]["_skip"]) + + "]" + ) class ExchangeNodePrinter(FixedArityNodePrinter): @@ -961,10 +1024,17 @@ class ExchangeNodePrinter(FixedArityNodePrinter): super().__init__(val, 2, "Exchange") def to_string(self): - return "Exchange[type: " + str( - self.val["_distribution"]["_distributionAndProjections"] - ["_type"]) + ", projections: " + str( - self.val["_distribution"]["_distributionAndProjections"]["_projectionNames"]) + "]" + return ( + "Exchange[type: " + + str(self.val["_distribution"]["_distributionAndProjections"]["_type"]) + + ", projections: " + + str( + self.val["_distribution"]["_distributionAndProjections"][ + "_projectionNames" + ] + ) + + "]" + ) class ReferencesPrinter(DynamicArityNodePrinter): @@ -991,12 +1061,17 @@ class PolyValuePrinter(object): self.type_set = str(self.poly_type).split("<", 1)[1] if self.tag < 0: - raise gdb.GdbError("Invalid PolyValue tag: {}, must be at least 0".format(self.tag)) + raise gdb.GdbError( + "Invalid PolyValue tag: {}, must be at least 0".format(self.tag) + ) # Check if the tag is out of range for the set of types that we know about. if self.tag > len(self.type_set.split(",")): - raise gdb.GdbError("Unknown PolyValue tag: {} (max: {}), did you add a new one?".format( - self.tag, str(self.type_set))) + raise gdb.GdbError( + "Unknown PolyValue tag: {} (max: {}), did you add a new one?".format( + self.tag, str(self.type_set) + ) + ) @staticmethod def display_hint(): @@ -1004,8 +1079,11 @@ class PolyValuePrinter(object): return None def cast_control_block(self, target_type): - return self.control_block.dereference().address.cast( - target_type.pointer()).dereference()["_t"] + return ( + self.control_block.dereference() + .address.cast(target_type.pointer()) + .dereference()["_t"] + ) def get_dynamic_type(self): # Build up the dynamic type for the particular variant of this PolyValue instance. This is @@ -1026,9 +1104,12 @@ class PolyValuePrinter(object): return "Unknown PolyValue tag: {}, did you add a new one?".format(self.tag) # GDB automatically formats types with children, remove the extra characters to get the # output that we want. - return str(self.cast_control_block(dynamic_type)).replace(" = ", "").replace("{", - "").replace( - "}", "") + return ( + str(self.cast_control_block(dynamic_type)) + .replace(" = ", "") + .replace("{", "") + .replace("}", "") + ) class AtomPrinter(object): @@ -1076,9 +1157,12 @@ class DisjunctionPrinter(ConjunctionPrinter): def bool_expr_type(T): - return (f"{OPTIMIZER_NS}::algebra::PolyValue<" + f"{OPTIMIZER_NS}::BoolExpr<{T}>::Atom, " + - f"{OPTIMIZER_NS}::BoolExpr<{T}>::Conjunction, " + - f"{OPTIMIZER_NS}::BoolExpr<{T}>::Disjunction>") + return ( + f"{OPTIMIZER_NS}::algebra::PolyValue<" + + f"{OPTIMIZER_NS}::BoolExpr<{T}>::Atom, " + + f"{OPTIMIZER_NS}::BoolExpr<{T}>::Conjunction, " + + f"{OPTIMIZER_NS}::BoolExpr<{T}>::Disjunction>" + ) def register_optimizer_printers(pp): @@ -1086,8 +1170,12 @@ def register_optimizer_printers(pp): # IntervalRequirement printer. pp.add("Interval", f"{OPTIMIZER_NS}::IntervalRequirement", False, IntervalPrinter) - pp.add("CompoundInterval", f"{OPTIMIZER_NS}::CompoundIntervalRequirement", False, - CompoundIntervalPrinter) + pp.add( + "CompoundInterval", + f"{OPTIMIZER_NS}::CompoundIntervalRequirement", + False, + CompoundIntervalPrinter, + ) # IntervalReqExpr::Node printer. pp.add( @@ -1111,12 +1199,20 @@ def register_optimizer_printers(pp): pp.add("Memo", f"{OPTIMIZER_NS}::cascades::Memo", False, MemoPrinter) # ResidualRequirement printer. - pp.add("ResidualRequirement", f"{OPTIMIZER_NS}::ResidualRequirement", False, - ResidualRequirementPrinter) + pp.add( + "ResidualRequirement", + f"{OPTIMIZER_NS}::ResidualRequirement", + False, + ResidualRequirementPrinter, + ) # CandidateIndexEntry printer. - pp.add("CandidateIndexEntry", f"{OPTIMIZER_NS}::CandidateIndexEntry", False, - CandidateIndexEntryPrinter) + pp.add( + "CandidateIndexEntry", + f"{OPTIMIZER_NS}::CandidateIndexEntry", + False, + CandidateIndexEntryPrinter, + ) # BoolExpr is handled by the PolyValue printer, but still need to add # printers for each of the possible bool expr types. @@ -1124,14 +1220,26 @@ def register_optimizer_printers(pp): bool_exprs = ["ResidualRequirement"] for bool_type in bool_expr_types: for expr in bool_exprs: - pp.add(bool_type, f"{OPTIMIZER_NS}::BoolExpr<{OPTIMIZER_NS}::{expr}>::{bool_type}", - False, getattr(sys.modules[__name__], bool_type + "Printer")) + pp.add( + bool_type, + f"{OPTIMIZER_NS}::BoolExpr<{OPTIMIZER_NS}::{expr}>::{bool_type}", + False, + getattr(sys.modules[__name__], bool_type + "Printer"), + ) # Utility types within the optimizer. - pp.add("StrongStringAlias", f"{OPTIMIZER_NS}::StrongStringAlias", True, - StrongStringAliasPrinter) - pp.add("FieldProjectionMap", f"{OPTIMIZER_NS}::FieldProjectionMap", False, - FieldProjectionMapPrinter) + pp.add( + "StrongStringAlias", + f"{OPTIMIZER_NS}::StrongStringAlias", + True, + StrongStringAliasPrinter, + ) + pp.add( + "FieldProjectionMap", + f"{OPTIMIZER_NS}::FieldProjectionMap", + False, + FieldProjectionMapPrinter, + ) pp.add("ScanParams", f"{OPTIMIZER_NS}::ScanParams", False, ScanParamsPrinter) @@ -1202,9 +1310,13 @@ def register_optimizer_printers(pp): "ExpressionBinder", ] for abt_type in abt_type_set: - pp.add(abt_type, f"{OPTIMIZER_NS}::{abt_type}", False, - getattr(sys.modules[__name__], abt_type + "Printer")) + pp.add( + abt_type, + f"{OPTIMIZER_NS}::{abt_type}", + False, + getattr(sys.modules[__name__], abt_type + "Printer"), + ) # Add the generic PolyValue printer which determines the exact type at runtime and attempts to # invoke the printer for that type. - pp.add('PolyValue', OPTIMIZER_NS + "::algebra::PolyValue", True, PolyValuePrinter) + pp.add("PolyValue", OPTIMIZER_NS + "::algebra::PolyValue", True, PolyValuePrinter) diff --git a/buildscripts/gdb/udb.py b/buildscripts/gdb/udb.py index b51c517fd40..b72483b971e 100644 --- a/buildscripts/gdb/udb.py +++ b/buildscripts/gdb/udb.py @@ -7,12 +7,13 @@ import gdb # Pattern to match output of 'info files' PATTERN_ELF_SECTIONS = re.compile( - r'(?P[0x0-9a-fA-F]+)\s-\s(?P[0x0-9a-fA-F]+)\s\bis\b\s(?P
\.[a-z]+$)') + r"(?P[0x0-9a-fA-F]+)\s-\s(?P[0x0-9a-fA-F]+)\s\bis\b\s(?P
\.[a-z]+$)" +) def parse_sections(): """Find addresses for .text, .data, and .bss sections.""" - file_info = gdb.execute('info files', to_string=True) + file_info = gdb.execute("info files", to_string=True) section_map = {} for line in file_info.splitlines(): line = line.strip() @@ -20,10 +21,10 @@ def parse_sections(): if match is None: continue - section = match.group('section') - if section not in ('.text', '.data', '.bss'): + section = match.group("section") + if section not in (".text", ".data", ".bss"): continue - begin = match.group('begin') + begin = match.group("begin") section_map[section] = begin return section_map @@ -31,8 +32,9 @@ def parse_sections(): def load_sym_file_at_addrs(dbg_file, smap): """Invoke add-symbol-file with addresses.""" - cmd = 'add-symbol-file {} {} -s .data {} -s .bss {}'.format(dbg_file, smap['.text'], - smap['.data'], smap['.bss']) + cmd = "add-symbol-file {} {} -s .data {} -s .bss {}".format( + dbg_file, smap[".text"], smap[".data"], smap[".bss"] + ) gdb.execute(cmd, to_string=True) @@ -41,18 +43,20 @@ class LoadDebugFile(gdb.Command): def __init__(self): """GDB Command API init.""" - super(LoadDebugFile, self).__init__('load-debug-symbols', gdb.COMPLETE_EXPRESSION) + super(LoadDebugFile, self).__init__( + "load-debug-symbols", gdb.COMPLETE_EXPRESSION + ) def invoke(self, args, from_tty): """GDB Command API invoke.""" arglist = args.split() if len(arglist) != 1: - print('Usage: load-debug-symbols ') + print("Usage: load-debug-symbols ") return dbg_file = arglist[0] if not os.path.exists(dbg_file): - print('{} is not a valid file path'.format(dbg_file)) + print("{} is not a valid file path".format(dbg_file)) return try: @@ -65,13 +69,13 @@ class LoadDebugFile(gdb.Command): LoadDebugFile() PATTERN_ELF_SOLIB_SECTIONS = re.compile( - r'(?P[0x0-9a-fA-F]+)\s-\s(?P[0x0-9a-fA-F]+)\s\bis\b\s(?P
\.[a-z]+)\s\bin\b\s(?P.*$)' + r"(?P[0x0-9a-fA-F]+)\s-\s(?P[0x0-9a-fA-F]+)\s\bis\b\s(?P
\.[a-z]+)\s\bin\b\s(?P.*$)" ) def parse_solib_sections(): """Find addresses for .text, .data, and .bss sections.""" - file_info = gdb.execute('info files', to_string=True) + file_info = gdb.execute("info files", to_string=True) section_map = {} for line in file_info.splitlines(): line = line.strip() @@ -79,15 +83,18 @@ def parse_solib_sections(): if match is None: continue - section = match.group('section') - if section not in ('.text', '.data', '.bss'): + section = match.group("section") + if section not in (".text", ".data", ".bss"): continue - begin = match.group('begin') + begin = match.group("begin") # TODO duplicate fnames? - fname = os.path.basename(match.group('file')) + fname = os.path.basename(match.group("file")) - if fname.startswith("system-supplied DSO") or match.group('file').startswith( - "/lib") or match.group('file').startswith("/usr/lib"): + if ( + fname.startswith("system-supplied DSO") + or match.group("file").startswith("/lib") + or match.group("file").startswith("/usr/lib") + ): continue fname = f"{fname}.debug" section_map.setdefault(fname, {}) @@ -111,7 +118,9 @@ def find_dwarf_files(path): return out -SOLIB_SEARCH_PATH_PREFIX = "The search path for loading non-absolute shared library symbol files is " +SOLIB_SEARCH_PATH_PREFIX = ( + "The search path for loading non-absolute shared library symbol files is " +) def extend_solib_search_path(new_path: str): @@ -119,7 +128,7 @@ def extend_solib_search_path(new_path: str): solib_search_path = gdb.execute("show solib-search-path", to_string=True) # remove the prefix and suffix (which is a period and space) from the # search path - solib_search_path = solib_search_path[len(SOLIB_SEARCH_PATH_PREFIX):-2] + solib_search_path = solib_search_path[len(SOLIB_SEARCH_PATH_PREFIX) : -2] solib_search_path = f"{new_path}:{solib_search_path}" if solib_search_path.endswith(":"): solib_search_path = solib_search_path[:-1] @@ -127,7 +136,9 @@ def extend_solib_search_path(new_path: str): gdb.execute(f"set solib-search-path {solib_search_path}", to_string=True) -DEBUG_FILE_DIRECTORY_PREFIX = 'The directory where separate debug symbols are searched for is "' +DEBUG_FILE_DIRECTORY_PREFIX = ( + 'The directory where separate debug symbols are searched for is "' +) def extend_debug_file_directory(new_path: str): @@ -135,7 +146,7 @@ def extend_debug_file_directory(new_path: str): debug_file_directory = gdb.execute("show debug-file-directory", to_string=True) # remove the prefix and suffix (which is a period and space) from the # search path - debug_file_directory = debug_file_directory[len(DEBUG_FILE_DIRECTORY_PREFIX):-3] + debug_file_directory = debug_file_directory[len(DEBUG_FILE_DIRECTORY_PREFIX) : -3] debug_file_directory = f"{new_path}:{debug_file_directory}" if debug_file_directory.endswith(":"): debug_file_directory = debug_file_directory[:-1] @@ -152,7 +163,7 @@ class LoadDistTest(gdb.Command): def __init__(self): """GDB Command API init.""" - super(LoadDistTest, self).__init__('load-dist-test', gdb.COMPLETE_EXPRESSION) + super(LoadDistTest, self).__init__("load-dist-test", gdb.COMPLETE_EXPRESSION) try: # test if we're running udb @@ -166,11 +177,11 @@ class LoadDistTest(gdb.Command): """Fetch the name of the binary gdb is attached to.""" main_binary_name = gdb.objfiles()[0].filename main_binary_name = os.path.splitext(os.path.basename(main_binary_name))[0] - if main_binary_name.endswith('mongod'): + if main_binary_name.endswith("mongod"): return "mongod" - if main_binary_name.endswith('mongo'): + if main_binary_name.endswith("mongo"): return "mongo" - if main_binary_name.endswith('mongos'): + if main_binary_name.endswith("mongos"): return "mongos" return None @@ -183,7 +194,7 @@ class LoadDistTest(gdb.Command): print(f"No path provided, assuming '{arglist[0]}'") if len(arglist) != 1: - print('Usage: load-dist-test ') + print("Usage: load-dist-test ") return dist_test = arglist[0] diff --git a/buildscripts/gdb/wt_dump_table.py b/buildscripts/gdb/wt_dump_table.py index b3badab4a7c..21bd57ad35c 100644 --- a/buildscripts/gdb/wt_dump_table.py +++ b/buildscripts/gdb/wt_dump_table.py @@ -11,7 +11,7 @@ if not gdb: from buildscripts.gdb.mongo import lookup_type DEBUGGING = False -''' +""" Public API to be called by users. The input `ident` is a string of the form: 'collection-2--4547167393143767234'. From within gdb type: @@ -25,18 +25,20 @@ Some behaviors/limitations: more raw output. * Any `file:*.wt` can be output, e.g: `_mdb_catalog` or `WiredTiger`. Though the output may be less supported/of lower quality. -''' +""" def dump_pages_for_table(ident): conn_impl_type = lookup_type("WT_CONNECTION_IMPL") if not conn_impl_type: - print('WT_CONNECTION_IMPL type not found. Try invoking this function from a different \ -thread and frame.') + print( + "WT_CONNECTION_IMPL type not found. Try invoking this function from a different \ +thread and frame." + ) return conn_impl_ptr_type = conn_impl_type.pointer() - dbg('impl', conn_impl_ptr_type) + dbg("impl", conn_impl_ptr_type) conn_ptr = None try: @@ -46,18 +48,19 @@ thread and frame.') if not conn_ptr or not conn_ptr.address: print( - 'Failed to find a suitable `WT_SESSION session` object to extract a connection object \ + "Failed to find a suitable `WT_SESSION session` object to extract a connection object \ from. Try finding an eviction thread and frame, e.g: `__wt_evict_thread_run`. If the session is \ optimized out, try going up stack frames until the variable is in a local scope rather than a \ -function input.') +function input." + ) return conn = conn_ptr.reinterpret_cast(conn_impl_ptr_type).dereference() - dbg('conn', conn) - data_handle, all_dhs = get_data_handle(conn, 'file:{}.wt'.format(ident)) + dbg("conn", conn) + data_handle, all_dhs = get_data_handle(conn, "file:{}.wt".format(ident)) if not data_handle: - print('Data handle not found for ident. Ident: `{}`'.format(ident)) - print('All known data handles:') + print("Data handle not found for ident. Ident: `{}`".format(ident)) + print("All known data handles:") pprint(all_dhs) return @@ -69,153 +72,161 @@ def dbg(ident, var): if not DEBUGGING: return - print('----------') + print("----------") if type(var) == gdb.Value: - print('{}: ({}*){}'.format(ident, var.type, var.address)) + print("{}: ({}*){}".format(ident, var.type, var.address)) else: print(ident) - print(' ' + str(type(var))) + print(" " + str(type(var))) methods = dir(var) out = [name for name in methods if not name.startswith("__")] for item in out: - print(' ' + item) + print(" " + item) if type(var) == gdb.Value: - print('\n Fields:') - print('\t' + '\n\t'.join(str(var).split('\n'))) + print("\n Fields:") + print("\t" + "\n\t".join(str(var).split("\n"))) def walk_wt_list(lst): ret = [] - node = lst['tqh_first'] - dbg('node', node) + node = lst["tqh_first"] + dbg("node", node) while True: if not node: break ret.append(node.dereference()) - node = node['q']['tqe_next'] + node = node["q"]["tqe_next"] return ret def get_data_handle(conn, handle_name): - dbg('datahandles', conn['dhqh']) + dbg("datahandles", conn["dhqh"]) ret = None all_file_dhs = [] - for handle in walk_wt_list(conn['dhqh']): - if handle['name'].string().startswith('file:'): - all_file_dhs.append(handle['name'].string()[5:-3]) - if handle['name'].string() == handle_name: + for handle in walk_wt_list(conn["dhqh"]): + if handle["name"].string().startswith("file:"): + all_file_dhs.append(handle["name"].string()[5:-3]) + if handle["name"].string() == handle_name: ret = handle return ret, all_file_dhs def get_btree_handle(dhandle): - btree = lookup_type('WT_BTREE').pointer() - return dhandle['handle'].reinterpret_cast(btree).dereference() + btree = lookup_type("WT_BTREE").pointer() + return dhandle["handle"].reinterpret_cast(btree).dereference() def dump_update_chain(update_chain): while True: if not update_chain: - print(' λ (End of update chain)') + print(" λ (End of update chain)") break - dbg('update', update_chain) + dbg("update", update_chain) wt_val = update_chain.dereference() obj = None - dbg('wt_val', wt_val) - val_bytes = gdb.selected_inferior().read_memory(wt_val['data'], wt_val['size']) - can_bson = wt_val['type'] == 3 + dbg("wt_val", wt_val) + val_bytes = gdb.selected_inferior().read_memory(wt_val["data"], wt_val["size"]) + can_bson = wt_val["type"] == 3 if can_bson: try: obj = bson.decode_all(val_bytes)[0] except: pass - print(' ' + '\n '.join(str(wt_val).split('\n')) + " " + str(obj) + " =>") + print(" " + "\n ".join(str(wt_val).split("\n")) + " " + str(obj) + " =>") - update_chain = update_chain['next'] + update_chain = update_chain["next"] def dump_insert_list(wt_insert): - key_struct = wt_insert['u']['key'] - key = gdb.selected_inferior().read_memory( - int(wt_insert.address) + key_struct['offset'], key_struct['size']).tobytes() - print('Key: ' + str(key)) - print('Value:') - update_chain = wt_insert['upd'] + key_struct = wt_insert["u"]["key"] + key = ( + gdb.selected_inferior() + .read_memory(int(wt_insert.address) + key_struct["offset"], key_struct["size"]) + .tobytes() + ) + print("Key: " + str(key)) + print("Value:") + update_chain = wt_insert["upd"] dump_update_chain(update_chain) def dump_skip_list(wt_insert_head): - if not wt_insert_head['head'].address: + if not wt_insert_head["head"].address: return - wt_insert = wt_insert_head['head'][0] + wt_insert = wt_insert_head["head"][0] idx = 0 while True: if not wt_insert: break dump_insert_list(wt_insert.dereference()) - dbg('insert' + str(idx), wt_insert.dereference()) + dbg("insert" + str(idx), wt_insert.dereference()) idx += 1 - wt_insert = wt_insert['next'][0] + wt_insert = wt_insert["next"][0] def dump_modified(leaf_page): print("Modify:") - if not leaf_page['modify']: + if not leaf_page["modify"]: print("No modifies") return - leaf_modify = leaf_page['modify'].dereference() - dbg('modify', leaf_modify) - row_leaf_insert = leaf_modify['u2']['row_leaf']['insert'] - dbg('row store', row_leaf_insert) + leaf_modify = leaf_page["modify"].dereference() + dbg("modify", leaf_modify) + row_leaf_insert = leaf_modify["u2"]["row_leaf"]["insert"] + dbg("row store", row_leaf_insert) if not row_leaf_insert: print("No insert list") else: print("Insert list:") dump_skip_list(row_leaf_insert.dereference().dereference()) - row_leaf_update = leaf_modify['u2']['row_leaf']['update'] + row_leaf_update = leaf_modify["u2"]["row_leaf"]["update"] if not row_leaf_update: print("No update list") else: print("Update list:") - leaf_num_entries = int(leaf_page['entries']) + leaf_num_entries = int(leaf_page["entries"]) for i in range(0, leaf_num_entries): dump_update_chain(row_leaf_update[i]) def dump_disk(leaf_page): - dbg('in-memory page:', leaf_page) - dsk = leaf_page['dsk'].dereference() + dbg("in-memory page:", leaf_page) + dsk = leaf_page["dsk"].dereference() if int(dsk.address) == 0: print("No page loaded from disk.") return - dbg('on-disk page:', dsk) + dbg("on-disk page:", dsk) wt_page_header_size = 28 wt_block_header_size = 12 - page_bytes = gdb.selected_inferior().read_memory( - int(dsk.address) + wt_page_header_size + wt_block_header_size, - int(dsk['mem_size'])).tobytes() + page_bytes = ( + gdb.selected_inferior() + .read_memory( + int(dsk.address) + wt_page_header_size + wt_block_header_size, + int(dsk["mem_size"]), + ) + .tobytes() + ) print("Dsk:\n" + str(page_bytes)) def dump_handle(dhandle): - print("Dumping: " + dhandle['name'].string()) + print("Dumping: " + dhandle["name"].string()) btree = get_btree_handle(dhandle) - root = btree['root'] - root_page = root['page'].dereference() - dbg('btree', btree) - dbg('root', btree['root']) - dbg('root page', root_page) - rpindex = root_page['u']['intl']['__index'].dereference() - leaf_num_entries = int(rpindex['entries']) + root = btree["root"] + root_page = root["page"].dereference() + dbg("btree", btree) + dbg("root", btree["root"]) + dbg("root page", root_page) + rpindex = root_page["u"]["intl"]["__index"].dereference() + leaf_num_entries = int(rpindex["entries"]) for idx in range(0, leaf_num_entries): - dbg('rpindex', rpindex) - dbg('rp-pre-index', rpindex['index'].dereference().dereference()) - leaf_page = rpindex['index'][idx].dereference()['page'].dereference() - dbg('leaf', leaf_page) + dbg("rpindex", rpindex) + dbg("rp-pre-index", rpindex["index"].dereference().dereference()) + leaf_page = rpindex["index"][idx].dereference()["page"].dereference() + dbg("leaf", leaf_page) dump_disk(leaf_page) dump_modified(leaf_page) diff --git a/buildscripts/generate_version_expansions.py b/buildscripts/generate_version_expansions.py index 65d1c103802..90721329705 100755 --- a/buildscripts/generate_version_expansions.py +++ b/buildscripts/generate_version_expansions.py @@ -46,7 +46,7 @@ def generate_version_expansions(): with open(VERSION_JSON, "r") as fh: data = fh.read() version_data = json.loads(data) - version_line = version_data['version'] + version_line = version_data["version"] version_parts = match_verstr(version_line) if not version_parts: raise ValueError("Unable to parse version.json") @@ -58,7 +58,9 @@ def generate_version_expansions(): version_line = version_line.lstrip("r") version_parts = match_verstr(version_line) if not version_parts: - raise ValueError("Unable to parse version from stdin and no version.json provided") + raise ValueError( + "Unable to parse version from stdin and no version.json provided" + ) if version_parts[0]: expansions["suffix"] = "v8.0-latest" @@ -83,7 +85,7 @@ def match_verstr(verstr): r2.3.4-git234, r2.3.4-rc0-234-githash If the version is invalid (i.e. doesn't start with "2.3.4" or "2.3.4-rc0", this will return False. """ - res = re.match(r'^r?(?:\d+\.\d+\.\d+(?:-rc\d+|-alpha\d+)?)(-.*)?', verstr) + res = re.match(r"^r?(?:\d+\.\d+\.\d+(?:-rc\d+|-alpha\d+)?)(-.*)?", verstr) if not res: return False return res.groups() diff --git a/buildscripts/golden_test.py b/buildscripts/golden_test.py index 3f4d17632a7..2a074e25179 100755 --- a/buildscripts/golden_test.py +++ b/buildscripts/golden_test.py @@ -71,7 +71,7 @@ def replace_variables(pattern: str, variables: dict) -> str: def get_path_name_regex(pattern: str) -> str: """Return the regex pattern for output names.""" - return '[0-9a-f]'.join([re.escape(part) for part in pattern.split('%')]) + return "[0-9a-f]".join([re.escape(part) for part in pattern.split("%")]) # For compatibility with version<3.8 that does not support shutil.copytree with dirs_exist_ok=True @@ -81,7 +81,9 @@ def copytree_dirs_exist_ok_compatibility(src, dest): os.makedirs(dest) files = os.listdir(src) for file in files: - copytree_dirs_exist_ok_compatibility(os.path.join(src, file), os.path.join(dest, file)) + copytree_dirs_exist_ok_compatibility( + os.path.join(src, file), os.path.join(dest, file) + ) else: shutil.copyfile(src, dest) @@ -94,10 +96,13 @@ def copytree_dirs_exist_ok(src, dest): @click.group() -@click.option('-n', '--dry-run', is_flag=True) -@click.option('-v', '--verbose', is_flag=True) -@click.option('--config', envvar='GOLDEN_TEST_CONFIG_PATH', - help='Config file path. Also GOLDEN_TEST_CONFIG_PATH environment variable.') +@click.option("-n", "--dry-run", is_flag=True) +@click.option("-v", "--verbose", is_flag=True) +@click.option( + "--config", + envvar="GOLDEN_TEST_CONFIG_PATH", + help="Config file path. Also GOLDEN_TEST_CONFIG_PATH environment variable.", +) @click.pass_context def cli(ctx, dry_run, verbose, config): """Manage test results from golden data test framework. @@ -150,31 +155,40 @@ class GoldenTestApp(object): def get_git_root(self): """Return the root for git repo.""" self.vprint("Querying git repo root") - repo_root = check_output("git rev-parse --show-toplevel", shell=True, text=True).strip() + repo_root = check_output( + "git rev-parse --show-toplevel", shell=True, text=True + ).strip() self.vprint(f"Found git repo root: '{repo_root}'") return repo_root def load_config(self, config_path): """Load configuration file.""" if config_path is None: - raise AppError(( - "Can't load config. GOLDEN_TEST_CONFIG_PATH envrionment variable is not set. Golden test CLI must be configured before use.\n" - "To configure it, follow the instructions in https://github.com/mongodb/mongo/blob/master/docs/golden_data_test_framework.md#how-to-diff-and-accept-new-test-outputs-on-a-workstation\n" - "Note: After setup you may need to rerun the tests for this utility to find them.")) + raise AppError( + ( + "Can't load config. GOLDEN_TEST_CONFIG_PATH envrionment variable is not set. Golden test CLI must be configured before use.\n" + "To configure it, follow the instructions in https://github.com/mongodb/mongo/blob/master/docs/golden_data_test_framework.md#how-to-diff-and-accept-new-test-outputs-on-a-workstation\n" + "Note: After setup you may need to rerun the tests for this utility to find them." + ) + ) self.vprint(f"Loading config from path: '{config_path}'") config = GoldenTestConfig.from_yaml_file(config_path) if config.outputRootPattern is None: - raise AppError("Invalid config. outputRootPattern config parameter is not set") + raise AppError( + "Invalid config. outputRootPattern config parameter is not set" + ) return config def get_output_path(self, output_name): """Return the path for given output name.""" if not re.match(self.output_name_regex, output_name): - raise AppError(f"Invalid name: '{output_name}'. " + - f"Does not match configured pattern: {self.output_name_pattern}") + raise AppError( + f"Invalid name: '{output_name}'. " + + f"Does not match configured pattern: {self.output_name_pattern}" + ) output_path = os.path.join(self.output_parent_path, output_name) if not os.path.isdir(output_path): raise AppError(f"No such directory: '{output_path}'") @@ -182,13 +196,17 @@ class GoldenTestApp(object): def list_outputs(self): """Return names of all available outputs.""" - self.vprint(f"Listing outputs in path: '{self.output_parent_path}' " + - f"matching '{self.output_name_pattern}'") + self.vprint( + f"Listing outputs in path: '{self.output_parent_path}' " + + f"matching '{self.output_name_pattern}'" + ) if not os.path.isdir(self.output_parent_path): return [] return [ - o for o in os.listdir(self.output_parent_path) if re.match(self.output_name_regex, o) + o + for o in os.listdir(self.output_parent_path) + if re.match(self.output_name_regex, o) and os.path.isdir(os.path.join(self.output_parent_path, o)) ] @@ -206,8 +224,10 @@ class GoldenTestApp(object): if latest_name is None: raise AppError("No outputs found") - self.vprint(f"Found output with latest creation time: {latest_name} " + - f"created at {latest_ctime}") + self.vprint( + f"Found output with latest creation time: {latest_name} " + + f"created at {latest_ctime}" + ) return latest_name @@ -215,20 +235,22 @@ class GoldenTestApp(object): """Return actual and expected paths for given output name.""" output_path = self.get_output_path(output_name) return OutputPaths( - actual=os.path.join(output_path, "actual"), expected=os.path.join( - output_path, "expected")) + actual=os.path.join(output_path, "actual"), + expected=os.path.join(output_path, "expected"), + ) def setup_linux(self): # Create config file - config_path = os.path.join(os.path.expanduser('~'), ".golden_test_config.yml") + config_path = os.path.join(os.path.expanduser("~"), ".golden_test_config.yml") if not os.path.isfile(config_path): print(f"Creating {config_path}") config_contents = ( r"""outputRootPattern: '/var/tmp/test_output/out-%%%%-%%%%-%%%%-%%%%'""" "\n" r"""diffCmd: 'git diff --no-index "{{expected}}" "{{actual}}"'""" - "\n") - with open(config_path, 'w') as file: + "\n" + ) + with open(config_path, "w") as file: file.write(config_contents) else: print(f"Skipping creating {config_path}, file exists.") @@ -237,14 +259,21 @@ class GoldenTestApp(object): etc_environment_path = "/etc/environment" env_var_defined = False if os.path.isfile(etc_environment_path): - with open(etc_environment_path, 'r') as file: + with open(etc_environment_path, "r") as file: for line in file.readlines(): if line.startswith("GOLDEN_TEST_CONFIG_PATH="): env_var_defined = True if not env_var_defined: print(f"Adding GOLDEN_TEST_CONFIG_PATH to {etc_environment_path}") - env_var_contents = (f"GOLDEN_TEST_CONFIG_PATH=\"{config_path}\"") - call(["sudo", "/bin/sh", "-c", f"echo '{env_var_contents}' >> {etc_environment_path}"]) + env_var_contents = f'GOLDEN_TEST_CONFIG_PATH="{config_path}"' + call( + [ + "sudo", + "/bin/sh", + "-c", + f"echo '{env_var_contents}' >> {etc_environment_path}", + ] + ) else: print( f"Skipping adding GOLDEN_TEST_CONFIG_PATH to {etc_environment_path}, variable already defined." @@ -252,15 +281,18 @@ class GoldenTestApp(object): def setup_windows(self): # Create config file - config_path = os.path.join(os.path.expandvars('%LocalAppData%'), ".golden_test_config.yml") + config_path = os.path.join( + os.path.expandvars("%LocalAppData%"), ".golden_test_config.yml" + ) if not os.path.isfile(config_path): print(f"Creating {config_path}") config_contents = ( r"outputRootPattern: 'C:\Users\Administrator\AppData\Local\Temp\test_output\out-%%%%-%%%%-%%%%-%%%%'" "\n" r"""diffCmd: 'git diff --no-index "{{expected}}" "{{actual}}"'""" - "\n") - with open(config_path, 'w') as file: + "\n" + ) + with open(config_path, "w") as file: file.write(config_contents) else: print(f"Skipping creating {config_path}, file exists.") @@ -268,17 +300,21 @@ class GoldenTestApp(object): # Add global GOLDEN_TEST_CONFIG_PATH environment variable if os.environ.get("GOLDEN_TEST_CONFIG_PATH") is None: print("Setting GOLDEN_TEST_CONFIG_PATH global variable") - call([ - "runas", "/profile", "/user:administrator", - f"setx GOLDEN_TEST_CONFIG_PATH \"{config_path}\"" - ]) + call( + [ + "runas", + "/profile", + "/user:administrator", + f'setx GOLDEN_TEST_CONFIG_PATH "{config_path}"', + ] + ) else: print( "Skipping setting GOLDEN_TEST_CONFIG_PATH global variable, variable already defined." ) - @cli.command('diff', help='Diff the expected and actual folders of the test output') - @click.argument('output_name', required=False) + @cli.command("diff", help="Diff the expected and actual folders of the test output") + @click.argument("output_name", required=False) @click.pass_obj def command_diff(self, output_name): """Diff the expected and actual folders of the test output.""" @@ -290,13 +326,14 @@ class GoldenTestApp(object): self.vprint(f"Diffing results from output '{output_name}'") paths = self.get_paths(output_name) - diff_cmd = replace_variables(self.config.diffCmd, - {'actual': paths.actual, 'expected': paths.expected}) + diff_cmd = replace_variables( + self.config.diffCmd, {"actual": paths.actual, "expected": paths.expected} + ) self.vprint(f"Running command: '{diff_cmd}'") self.call_shell(diff_cmd) - @cli.command('get-path', help='Get the root folder path of the test output.') - @click.argument('output_name', required=False) + @cli.command("get-path", help="Get the root folder path of the test output.") + @click.argument("output_name", required=False) @click.pass_obj def command_get_path(self, output_name): """Get the root folder path of the test output.""" @@ -308,9 +345,10 @@ class GoldenTestApp(object): print(self.get_output_path(output_name)) @cli.command( - 'accept', - help='Accept the actual test output and copy it as new golden data to the source repo.') - @click.argument('output_name', required=False) + "accept", + help="Accept the actual test output and copy it as new golden data to the source repo.", + ) + @click.argument("output_name", required=False) @click.pass_obj def command_accept(self, output_name): """Accept the actual test output and copy it as new golden data to the source repo.""" @@ -328,7 +366,7 @@ class GoldenTestApp(object): if not self.dry_run: copytree_dirs_exist_ok(paths.actual, repo_root) - @cli.command('clean', help='Remove all test outputs') + @cli.command("clean", help="Remove all test outputs") @click.pass_obj def command_clean(self): """Remove all test outputs.""" @@ -342,7 +380,7 @@ class GoldenTestApp(object): if not self.dry_run: shutil.rmtree(output_path) - @cli.command('latest', help='Get the name of the most recent test output') + @cli.command("latest", help="Get the name of the most recent test output") @click.pass_obj def command_latest(self): """Get the name of the most recent test output.""" @@ -351,7 +389,7 @@ class GoldenTestApp(object): output_name = self.get_latest_output() print(output_name) - @cli.command('list', help='List all names of the available test outputs') + @cli.command("list", help="List all names of the available test outputs") @click.pass_obj def command_list(self): """List all names of the available test outputs.""" @@ -360,7 +398,7 @@ class GoldenTestApp(object): for output_name in self.list_outputs(): print(output_name) - @cli.command('setup', help='Performs default setup based on current platform') + @cli.command("setup", help="Performs default setup based on current platform") @click.pass_obj def command_setup(self): """Performs default setup based on current platform.""" @@ -369,7 +407,9 @@ class GoldenTestApp(object): elif platform.platform().startswith("Windows"): self.setup_windows() else: - raise AppError(f"Platform not supported by this setup utility: {platform.platform()}") + raise AppError( + f"Platform not supported by this setup utility: {platform.platform()}" + ) def main(): diff --git a/buildscripts/hang_analyzer.py b/buildscripts/hang_analyzer.py index 30eba2206e4..9fdfe7b186a 100755 --- a/buildscripts/hang_analyzer.py +++ b/buildscripts/hang_analyzer.py @@ -2,6 +2,8 @@ """Stub file pointing users to the new invocation.""" if __name__ == "__main__": - print("Hello! It seems you've executed 'buildscripts/hang_analyzer.py'. We have recently\n" - "repackaged the hang analyzer as a subcommand of resmoke. It can now be invoked as\n" - "'./buildscripts/resmoke.py hang-analyzer' with all of the same arguments as before.") + print( + "Hello! It seems you've executed 'buildscripts/hang_analyzer.py'. We have recently\n" + "repackaged the hang analyzer as a subcommand of resmoke. It can now be invoked as\n" + "'./buildscripts/resmoke.py hang-analyzer' with all of the same arguments as before." + ) diff --git a/buildscripts/idl/check_stable_api_commands_have_idl_definitions.py b/buildscripts/idl/check_stable_api_commands_have_idl_definitions.py index 275a879dab4..5b26f76cbfd 100644 --- a/buildscripts/idl/check_stable_api_commands_have_idl_definitions.py +++ b/buildscripts/idl/check_stable_api_commands_have_idl_definitions.py @@ -40,7 +40,7 @@ from typing import Dict, List, Set from pymongo import MongoClient # Permit imports from "buildscripts". -sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..'))) +sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../.."))) # pylint: disable=wrong-import-position from buildscripts.idl.lib import list_idls, parse_idl @@ -52,7 +52,7 @@ from idl import syntax # pylint: enable=wrong-import-position -LOGGER_NAME = 'check-idl-definitions' +LOGGER_NAME = "check-idl-definitions" LOGGER = logging.getLogger(LOGGER_NAME) @@ -67,8 +67,9 @@ def is_test_or_third_party_idl(idl_path: str) -> bool: return False -def get_command_definitions(api_version: str, directory: str, - import_directories: List[str]) -> Dict[str, syntax.Command]: +def get_command_definitions( + api_version: str, directory: str, import_directories: List[str] +) -> Dict[str, syntax.Command]: """Get parsed IDL definitions of commands in a given API version.""" LOGGER.info("Searching for command definitions in %s", directory) @@ -76,16 +77,22 @@ def get_command_definitions(api_version: str, directory: str, def gen(): for idl_path in sorted(list_idls(directory)): if not is_test_or_third_party_idl(idl_path): - for command in parse_idl(idl_path, import_directories).spec.symbols.commands: + for command in parse_idl( + idl_path, import_directories + ).spec.symbols.commands: if command.api_version == api_version: yield command.command_name, command idl_commands = dict(gen()) - LOGGER.debug("Found %s IDL commands in API Version %s", len(idl_commands), api_version) + LOGGER.debug( + "Found %s IDL commands in API Version %s", len(idl_commands), api_version + ) return idl_commands -def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir: str) -> Set[str]: +def list_commands_for_api( + api_version: str, mongod_or_mongos: str, install_dir: str +) -> Set[str]: """Get a list of commands in a given API version by calling listCommands.""" assert mongod_or_mongos in ("mongod", "mongos") logging.info("Calling listCommands on %s", mongod_or_mongos) @@ -97,28 +104,43 @@ def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir: logger = loggers.new_fixture_logger("MongoDFixture", 0) logger.parent = LOGGER fixture: interface.Fixture = fixturelib.make_fixture( - "MongoDFixture", logger, 0, dbpath_prefix=dbpath.name, - mongod_executable=mongod_executable, mongod_options={"set_parameters": {}}) + "MongoDFixture", + logger, + 0, + dbpath_prefix=dbpath.name, + mongod_executable=mongod_executable, + mongod_options={"set_parameters": {}}, + ) else: logger = loggers.new_fixture_logger("ShardedClusterFixture", 0) logger.parent = LOGGER fixture = fixturelib.make_fixture( - "ShardedClusterFixture", logger, 0, dbpath_prefix=dbpath.name, - mongos_executable=mongos_executable, mongod_executable=mongod_executable, - mongod_options={"set_parameters": {}}) + "ShardedClusterFixture", + logger, + 0, + dbpath_prefix=dbpath.name, + mongos_executable=mongos_executable, + mongod_executable=mongod_executable, + mongod_options={"set_parameters": {}}, + ) fixture.setup() fixture.await_ready() try: client = MongoClient(fixture.get_driver_connection_url()) # type: MongoClient - reply = client.admin.command('listCommands') # type: Mapping[str, Any] + reply = client.admin.command("listCommands") # type: Mapping[str, Any] commands = { name - for name, info in reply['commands'].items() if api_version in info['apiVersions'] + for name, info in reply["commands"].items() + if api_version in info["apiVersions"] } - logging.info("Found %s commands in API Version %s on %s", len(commands), api_version, - mongod_or_mongos) + logging.info( + "Found %s commands in API Version %s on %s", + len(commands), + api_version, + mongod_or_mongos, + ) return commands finally: fixture.teardown() @@ -128,7 +150,9 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str] """Check that all sources have the same set of commands for a given API version.""" LOGGER.info("Comparing %s command sets", len(command_sets)) for name, commands in command_sets.items(): - LOGGER.info("--------- %s API Version %s commands --------------", name, api_version) + LOGGER.info( + "--------- %s API Version %s commands --------------", name, api_version + ) for command in sorted(commands): LOGGER.info("%s", command) @@ -138,13 +162,22 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str] for other_name, other_commands in it: if commands != other_commands: if commands - other_commands: - LOGGER.error("%s has commands not in %s: %s", name, other_name, - commands - other_commands) + LOGGER.error( + "%s has commands not in %s: %s", + name, + other_name, + commands - other_commands, + ) if other_commands - commands: - LOGGER.error("%s has commands not in %s: %s", other_name, name, - other_commands - commands) + LOGGER.error( + "%s has commands not in %s: %s", + other_name, + name, + other_commands - commands, + ) raise AssertionError( - f"{name} and {other_name} have different commands in API Version {api_version}") + f"{name} and {other_name} have different commands in API Version {api_version}" + ) def remove_skipped_commands(command_sets: Dict[str, Set[str]]): @@ -167,12 +200,24 @@ def remove_skipped_commands(command_sets: Dict[str, Set[str]]): def main(): """Run the script.""" arg_parser = argparse.ArgumentParser(description=__doc__) - arg_parser.add_argument("--include", type=str, action="append", - help="Directory to search for IDL import files") - arg_parser.add_argument("--install-dir", dest="install_dir", required=True, - help="Directory to search for MongoDB binaries") - arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging") - arg_parser.add_argument("api_version", metavar="API_VERSION", help="API Version to check") + arg_parser.add_argument( + "--include", + type=str, + action="append", + help="Directory to search for IDL import files", + ) + arg_parser.add_argument( + "--install-dir", + dest="install_dir", + required=True, + help="Directory to search for MongoDB binaries", + ) + arg_parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose logging" + ) + arg_parser.add_argument( + "api_version", metavar="API_VERSION", help="API Version to check" + ) args = arg_parser.parse_args() class FakeArgs: @@ -190,12 +235,20 @@ def main(): loggers.configure_loggers() loggers.new_job_logger(sys.argv[0], 0) logging.basicConfig(level=logging.WARNING) - logging.getLogger(LOGGER_NAME).setLevel(logging.DEBUG if args.verbose else logging.INFO) + logging.getLogger(LOGGER_NAME).setLevel( + logging.DEBUG if args.verbose else logging.INFO + ) command_sets = {} - command_sets["mongod"] = list_commands_for_api(args.api_version, "mongod", args.install_dir) - command_sets["mongos"] = list_commands_for_api(args.api_version, "mongos", args.install_dir) - command_sets["idl"] = set(get_command_definitions(args.api_version, os.getcwd(), args.include)) + command_sets["mongod"] = list_commands_for_api( + args.api_version, "mongod", args.install_dir + ) + command_sets["mongos"] = list_commands_for_api( + args.api_version, "mongos", args.install_dir + ) + command_sets["idl"] = set( + get_command_definitions(args.api_version, os.getcwd(), args.include) + ) remove_skipped_commands(command_sets) assert_command_sets_equal(args.api_version, command_sets) diff --git a/buildscripts/idl/checkout_idl_files_from_past_releases.py b/buildscripts/idl/checkout_idl_files_from_past_releases.py index 91622d29606..abf357ea785 100755 --- a/buildscripts/idl/checkout_idl_files_from_past_releases.py +++ b/buildscripts/idl/checkout_idl_files_from_past_releases.py @@ -38,7 +38,9 @@ from packaging.version import Version # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: - sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + sys.path.append( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) # pylint: disable=wrong-import-position from buildscripts.resmokelib.multiversionconstants import ( @@ -49,7 +51,7 @@ from buildscripts.resmokelib.multiversionconstants import ( # pylint: enable=wrong-import-position -LOGGER_NAME = 'checkout-idl' +LOGGER_NAME = "checkout-idl" LOGGER = logging.getLogger(LOGGER_NAME) @@ -57,9 +59,9 @@ def get_tags() -> List[str]: """Get a list of git tags that the IDL compatibility script should check against.""" def gen_versions_and_tags(): - for tag in check_output(['git', 'tag']).decode().split(): + for tag in check_output(["git", "tag"]).decode().split(): # Releases are like "r5.6.7". Older ones aren't r-prefixed but we don't care about them. - if not tag.startswith('r'): + if not tag.startswith("r"): continue try: @@ -90,7 +92,9 @@ def get_tags() -> List[str]: for version, tag in sorted(gen_versions_and_tags(), reverse=True): major_minor_version = Version(f"{version.major}.{version.minor}") if major_minor_version in results: - candidate_tag, candidate_is_prerelease_version = results[major_minor_version] + candidate_tag, candidate_is_prerelease_version = results[ + major_minor_version + ] if candidate_tag is None: # This is the first tag we have seen for this version. Set our first # candidate tag and if this tag is a prerelease version. @@ -115,27 +119,36 @@ def make_idl_directories(tags: List[str], destination: str) -> None: for tag in tags: LOGGER.info("Checking out IDL files in %s", tag) directory = os.path.join(destination, tag) - for path in check_output(['git', 'ls-tree', '--name-only', '-r', tag]).decode().split(): - if not path.endswith('.idl'): + for path in ( + check_output(["git", "ls-tree", "--name-only", "-r", tag]).decode().split() + ): + if not path.endswith(".idl"): continue - contents = check_output(['git', 'show', f'{tag}:{path}']).decode() + contents = check_output(["git", "show", f"{tag}:{path}"]).decode() output_path = os.path.join(directory, path) os.makedirs(os.path.dirname(output_path), exist_ok=True) - with open(output_path, 'w+') as fd: + with open(output_path, "w+") as fd: fd.write(contents) def main(): """Run the script.""" arg_parser = argparse.ArgumentParser(description=__doc__) - arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging") - arg_parser.add_argument("destination", metavar="DESTINATION", - help="Directory to check out past IDL file versions") + arg_parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose logging" + ) + arg_parser.add_argument( + "destination", + metavar="DESTINATION", + help="Directory to check out past IDL file versions", + ) args = arg_parser.parse_args() logging.basicConfig(level=logging.WARNING) - logging.getLogger(LOGGER_NAME).setLevel(logging.DEBUG if args.verbose else logging.INFO) + logging.getLogger(LOGGER_NAME).setLevel( + logging.DEBUG if args.verbose else logging.INFO + ) tags = get_tags() LOGGER.info("Fetching IDL files for past tags: %s", tags) diff --git a/buildscripts/idl/gen_all_feature_flag_list.py b/buildscripts/idl/gen_all_feature_flag_list.py index 1862ccbe9a5..c114d090db6 100644 --- a/buildscripts/idl/gen_all_feature_flag_list.py +++ b/buildscripts/idl/gen_all_feature_flag_list.py @@ -37,7 +37,7 @@ from typing import List import yaml # Permit imports from "buildscripts". -sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..'))) +sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../.."))) # pylint: disable=wrong-import-position from buildscripts.idl import lib @@ -59,7 +59,7 @@ def get_all_feature_flags(idl_dirs: List[str] = None): # Most IDL files do not contain feature flags. # We can discard these quickly without expensive YAML parsing. with open(idl_path) as idl_file: - if 'feature_flags' not in idl_file.read(): + if "feature_flags" not in idl_file.read(): continue with open(idl_path) as idl_file: doc = parser.parse_file(idl_file, idl_path) @@ -81,7 +81,9 @@ def get_all_feature_flags_turned_off_by_default(idl_dirs: List[str] = None): all_flags = get_all_feature_flags(idl_dirs) all_default_false_flags = [flag for flag in all_flags if all_flags[flag] != "true"] - with open("buildscripts/resmokeconfig/fully_disabled_feature_flags.yml") as fully_disabled_ffs: + with open( + "buildscripts/resmokeconfig/fully_disabled_feature_flags.yml" + ) as fully_disabled_ffs: force_disabled_flags = yaml.safe_load(fully_disabled_ffs) return list(set(all_default_false_flags) - set(force_disabled_flags)) @@ -99,5 +101,5 @@ def main(): gen_all_feature_flags_file() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/idl/gen_all_server_params_list.py b/buildscripts/idl/gen_all_server_params_list.py index 558dafcae55..7332f041cf4 100644 --- a/buildscripts/idl/gen_all_server_params_list.py +++ b/buildscripts/idl/gen_all_server_params_list.py @@ -35,7 +35,7 @@ import sys from typing import List # Permit imports from "buildscripts". -sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..'))) +sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../.."))) # pylint: disable=wrong-import-position from buildscripts.idl import lib @@ -57,7 +57,7 @@ def gen_all_server_params(idl_dirs: List[str] = None): # Most IDL files do not contain server parameters. # We can discard these quickly without expensive YAML parsing. with open(idl_path) as idl_file: - if 'server_parameters' not in idl_file.read(): + if "server_parameters" not in idl_file.read(): continue with open(idl_path) as idl_file: doc = parser.parse_file(idl_file, idl_path) @@ -79,5 +79,5 @@ def main(): gen_all_server_params_file() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/idl/idl/ast.py b/buildscripts/idl/idl/ast.py index b8f00ba6ae3..b26f6cb6b63 100644 --- a/buildscripts/idl/idl/ast.py +++ b/buildscripts/idl/idl/ast.py @@ -33,6 +33,7 @@ Represents the derived IDL specification after type resolution in the binding pa This is a lossy translation from the IDL Syntax tree as the IDL AST only contains information about the enums and structs that need code generated for them, and just enough information to do that. """ + import enum from . import common @@ -44,8 +45,9 @@ class IDLBoundSpec(object): def __init__(self, spec, error_collection): # type: (IDLAST, errors.ParserErrorCollection) -> None """Must specify either an IDL document or errors, not both.""" - assert (spec is None and error_collection is not None) or (spec is not None - and error_collection is None) + assert (spec is None and error_collection is not None) or ( + spec is not None and error_collection is None + ) self.spec = spec self.errors = error_collection @@ -290,12 +292,16 @@ class Field(common.SourceLocation): # type: () -> bool """Returns true if the IDL compiler should add a call to serialization options for this field.""" return self.query_shape is not None and self.query_shape in [ - QueryShapeFieldType.LITERAL, QueryShapeFieldType.ANONYMIZE + QueryShapeFieldType.LITERAL, + QueryShapeFieldType.ANONYMIZE, ] @property def should_shapify(self): - return self.query_shape is not None and self.query_shape != QueryShapeFieldType.PARAMETER + return ( + self.query_shape is not None + and self.query_shape != QueryShapeFieldType.PARAMETER + ) class Privilege(common.SourceLocation): diff --git a/buildscripts/idl/idl/binder.py b/buildscripts/idl/idl/binder.py index 0d228833aa5..ca08a94e1da 100644 --- a/buildscripts/idl/idl/binder.py +++ b/buildscripts/idl/idl/binder.py @@ -56,10 +56,13 @@ def _validate_single_bson_type(ctxt, idl_type, syntax_type): subtype = "" if not bson.is_valid_bindata_subtype(subtype): - ctxt.add_bad_bson_bindata_subtype_value_error(idl_type, syntax_type, idl_type.name, - subtype) + ctxt.add_bad_bson_bindata_subtype_value_error( + idl_type, syntax_type, idl_type.name, subtype + ) elif idl_type.bindata_subtype is not None: - ctxt.add_bad_bson_bindata_subtype_error(idl_type, syntax_type, idl_type.name, bson_type) + ctxt.add_bad_bson_bindata_subtype_error( + idl_type, syntax_type, idl_type.name, bson_type + ) return True @@ -74,21 +77,29 @@ def _validate_bson_types_list(ctxt, idl_type, syntax_type): for bson_type in bson_types: if bson_type in ["any", "chain"]: - ctxt.add_bad_any_type_use_error(idl_type, bson_type, syntax_type, idl_type.name) + ctxt.add_bad_any_type_use_error( + idl_type, bson_type, syntax_type, idl_type.name + ) return False if not bson.is_valid_bson_type(bson_type): - ctxt.add_bad_bson_type_error(idl_type, syntax_type, idl_type.name, bson_type) + ctxt.add_bad_bson_type_error( + idl_type, syntax_type, idl_type.name, bson_type + ) return False if not isinstance(idl_type, syntax.VariantType): if bson_type == "bindata": - ctxt.add_bad_bson_type_error(idl_type, syntax_type, idl_type.name, bson_type) + ctxt.add_bad_bson_type_error( + idl_type, syntax_type, idl_type.name, bson_type + ) return False # Cannot mix non-scalar types into the list of types if not bson.is_scalar_bson_type(bson_type): - ctxt.add_bad_bson_scalar_type_error(idl_type, syntax_type, idl_type.name, bson_type) + ctxt.add_bad_bson_scalar_type_error( + idl_type, syntax_type, idl_type.name, bson_type + ) return False return True @@ -102,7 +113,7 @@ def _validate_type(ctxt, idl_type): if idl_type.name.startswith("array<"): ctxt.add_array_not_valid_error(idl_type, "type", idl_type.name) - _validate_type_properties(ctxt, idl_type, 'type') + _validate_type_properties(ctxt, idl_type, "type") def _validate_cpp_type(ctxt, idl_type, syntax_type): @@ -115,20 +126,27 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type): ctxt.add_no_string_data_error(idl_type, syntax_type, idl_type.name) # We do not support C++ char and float types for style reasons - if idl_type.cpp_type in ['char', 'wchar_t', 'char16_t', 'char32_t', 'float']: - ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name, - idl_type.cpp_type) + if idl_type.cpp_type in ["char", "wchar_t", "char16_t", "char32_t", "float"]: + ctxt.add_bad_cpp_numeric_type_use_error( + idl_type, syntax_type, idl_type.name, idl_type.cpp_type + ) # We do not support C++ builtin integer for style reasons - for numeric_word in ['signed', "unsigned", "int", "long", "short"]: - if re.search(r'\b%s\b' % (numeric_word), idl_type.cpp_type): - ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name, - idl_type.cpp_type) + for numeric_word in ["signed", "unsigned", "int", "long", "short"]: + if re.search(r"\b%s\b" % (numeric_word), idl_type.cpp_type): + ctxt.add_bad_cpp_numeric_type_use_error( + idl_type, syntax_type, idl_type.name, idl_type.cpp_type + ) # Return early so we only throw one error for types like "signed short int" return # Check for std fixed integer types which are allowed - if idl_type.cpp_type in ["std::int32_t", "std::int64_t", "std::uint32_t", "std::uint64_t"]: + if idl_type.cpp_type in [ + "std::int32_t", + "std::int64_t", + "std::uint32_t", + "std::uint64_t", + ]: return # Only allow 16-byte arrays since they are for MD5 and UUID @@ -146,27 +164,39 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type): # Check for std fixed integer types which are not allowed. These are not allowed even if they # have the "std::" prefix. for std_numeric_type in [ - "int8_t", "int16_t", "int32_t", "int64_t", "uint8_t", "uint16_t", "uint32_t", "uint64_t" + "int8_t", + "int16_t", + "int32_t", + "int64_t", + "uint8_t", + "uint16_t", + "uint32_t", + "uint64_t", ]: if std_numeric_type in idl_type.cpp_type: - ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name, - idl_type.cpp_type) + ctxt.add_bad_cpp_numeric_type_use_error( + idl_type, syntax_type, idl_type.name, idl_type.cpp_type + ) return def _validate_chain_type_properties(ctxt, idl_type, syntax_type): # type: (errors.ParserContext, Union[syntax.Type, ast.Type], str) -> None """Validate a chained type has both a deserializer and serializer.""" - assert len( - idl_type.bson_serialization_type) == 1 and idl_type.bson_serialization_type[0] == 'chain' + assert ( + len(idl_type.bson_serialization_type) == 1 + and idl_type.bson_serialization_type[0] == "chain" + ) if idl_type.deserializer is None: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "deserializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "deserializer" + ) if idl_type.serializer is None: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "serializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "serializer" + ) def _validate_type_properties(ctxt, idl_type, syntax_type): @@ -184,38 +214,51 @@ def _validate_type_properties(ctxt, idl_type, syntax_type): # serialization for their C++ type. An internal_only type is not associated with BSON # and thus should not have a deserializer defined. if idl_type.deserializer is None and not idl_type.internal_only: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "deserializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "deserializer" + ) elif bson_type == "chain": _validate_chain_type_properties(ctxt, idl_type, syntax_type) elif bson_type == "string": # Strings support custom serialization unlike other non-object scalar types if idl_type.deserializer is None: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "deserializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "deserializer" + ) elif bson_type not in ["array", "object", "bindata"]: if idl_type.deserializer is None: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "deserializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "deserializer" + ) - if idl_type.deserializer is not None and "BSONElement" not in idl_type.deserializer: + if ( + idl_type.deserializer is not None + and "BSONElement" not in idl_type.deserializer + ): ctxt.add_not_custom_scalar_serialization_not_supported_error( - idl_type, syntax_type, idl_type.name, bson_type) + idl_type, syntax_type, idl_type.name, bson_type + ) if idl_type.serializer is not None: ctxt.add_not_custom_scalar_serialization_not_supported_error( - idl_type, syntax_type, idl_type.name, bson_type) + idl_type, syntax_type, idl_type.name, bson_type + ) - if bson_type == "bindata" and isinstance(idl_type, syntax.Type) and idl_type.default: + if ( + bson_type == "bindata" + and isinstance(idl_type, syntax.Type) + and idl_type.default + ): ctxt.add_bindata_no_default(idl_type, syntax_type, idl_type.name) else: # Now, this is a list of scalar types if idl_type.deserializer is None: - ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name, - "deserializer") + ctxt.add_missing_ast_required_field_error( + idl_type, syntax_type, idl_type.name, "deserializer" + ) _validate_cpp_type(ctxt, idl_type, syntax_type) @@ -238,7 +281,9 @@ def _is_duplicate_field(ctxt, field_container, fields, ast_field): if field.name == ast_field.name: duplicate_field = field - ctxt.add_duplicate_field_error(ast_field, field_container, ast_field.name, duplicate_field) + ctxt.add_duplicate_field_error( + ast_field, field_container, ast_field.name, duplicate_field + ) return True return False @@ -246,26 +291,27 @@ def _is_duplicate_field(ctxt, field_container, fields, ast_field): def _get_struct_qualified_cpp_name(struct): # type: (syntax.Struct) -> str - return common.qualify_cpp_name(struct.cpp_namespace, - common.title_case(struct.cpp_name or struct.name)) + return common.qualify_cpp_name( + struct.cpp_namespace, common.title_case(struct.cpp_name or struct.name) + ) def _compute_field_is_view(resolved_field, ctxt, symbols): # type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool """Compute is_view for a symbol referenced by a field.""" # Resolved field is an array. - if (isinstance(resolved_field, syntax.ArrayType)): + if isinstance(resolved_field, syntax.ArrayType): # Inner type needs to be resolved. return _compute_field_is_view(resolved_field.element_type, ctxt, symbols) # Resolved field is a variant. - elif (isinstance(resolved_field, syntax.VariantType)): + elif isinstance(resolved_field, syntax.VariantType): for variant_type in resolved_field.variant_types: # Inner type needs to be resolved. - if (_compute_field_is_view(variant_type, ctxt, symbols)): + if _compute_field_is_view(variant_type, ctxt, symbols): return True for variant_struct_type in resolved_field.variant_struct_types: - if (_compute_is_view(variant_struct_type, ctxt, symbols)): + if _compute_is_view(variant_struct_type, ctxt, symbols): return True return False @@ -277,12 +323,13 @@ def _compute_field_is_view(resolved_field, ctxt, symbols): def _compute_chained_item_is_view(struct, ctxt, symbols, chained_item): # type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, Union[syntax.ChainedType, syntax.ChainedStruct]) -> bool """Helper to compute is_view of chained types or structs.""" - resolved_chained_item = symbols.resolve_type_from_name(ctxt, struct, chained_item.name, - chained_item.name) + resolved_chained_item = symbols.resolve_type_from_name( + ctxt, struct, chained_item.name, chained_item.name + ) # If symbols.resolve_field_type returns None, we can assume an error occured during the function. # We can rely on symbols.resolve_field_type to add errors. - if (resolved_chained_item is None): - assert (ctxt.errors.has_errors()) + if resolved_chained_item is None: + assert ctxt.errors.has_errors() return True return _compute_is_view(resolved_chained_item, ctxt, symbols) @@ -291,25 +338,29 @@ def _compute_command_type_is_view(struct, ctxt, symbols, field_type): # type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, syntax.FieldType) -> bool """ Compute is_view for the command parameter type. - + This function is similar to _compute_field_is_view, but because command parameter types are syntax.FieldType instead of syntax.Type, separate logic must exist to resolve the command parameter types. """ - if (isinstance(field_type, syntax.FieldTypeVariant)): + if isinstance(field_type, syntax.FieldTypeVariant): for variant_type in field_type.variant: - if (_compute_command_type_is_view(struct, ctxt, symbols, variant_type)): + if _compute_command_type_is_view(struct, ctxt, symbols, variant_type): return True - elif (isinstance(field_type, syntax.FieldTypeArray)): - return _compute_command_type_is_view(struct, ctxt, symbols, field_type.element_type) - elif (isinstance(field_type, syntax.FieldTypeSingle)): - resolved_type = symbols.resolve_field_type(ctxt, struct, field_type.type_name, field_type) + elif isinstance(field_type, syntax.FieldTypeArray): + return _compute_command_type_is_view( + struct, ctxt, symbols, field_type.element_type + ) + elif isinstance(field_type, syntax.FieldTypeSingle): + resolved_type = symbols.resolve_field_type( + ctxt, struct, field_type.type_name, field_type + ) # If symbols.resolve_field_type returns None, we can assume an error occured during the function. # We can rely on symbols.resolve_field_type to add errors. - if (resolved_type is None): - assert (ctxt.errors.has_errors()) + if resolved_type is None: + assert ctxt.errors.has_errors() return True - if (_compute_field_is_view(resolved_type, ctxt, symbols)): + if _compute_field_is_view(resolved_type, ctxt, symbols): return True else: ctxt.add_unknown_command_type_error(struct, struct.name) @@ -320,37 +371,38 @@ def _compute_struct_is_view(struct, ctxt, symbols): # type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable) -> bool """Compute is_view for structs. A struct is a view if any of its fields are views.""" # Empty structs are non view types. - if (not struct.fields): + if not struct.fields: return False for field in struct.fields: - if (field.ignore): + if field.ignore: continue # Get the resolved field from the global symbol table. resolved_field = symbols.resolve_field_type(ctxt, field, field.name, field.type) # If symbols.resolve_field_type returns None, we can assume an error occured during the function. # We can rely on symbols.resolve_field_type to add errors. - if (resolved_field is None): - assert (ctxt.errors.has_errors()) + if resolved_field is None: + assert ctxt.errors.has_errors() return True # If any field is a view type, then the struct is also a view type. - if (_compute_field_is_view(resolved_field, ctxt, symbols)): + if _compute_field_is_view(resolved_field, ctxt, symbols): return True - if (struct.chained_types): + if struct.chained_types: for chained_type in struct.chained_types: - if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_type)): + if _compute_chained_item_is_view(struct, ctxt, symbols, chained_type): return True - if (struct.chained_structs): + if struct.chained_structs: for chained_struct in struct.chained_structs: - if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_struct)): + if _compute_chained_item_is_view(struct, ctxt, symbols, chained_struct): return True # Check command parameter. - if (isinstance(struct, syntax.Command)): - if (struct.type is not None - and _compute_command_type_is_view(struct, ctxt, symbols, struct.type)): + if isinstance(struct, syntax.Command): + if struct.type is not None and _compute_command_type_is_view( + struct, ctxt, symbols, struct.type + ): return True return False @@ -359,11 +411,11 @@ def _compute_struct_is_view(struct, ctxt, symbols): def _compute_is_view(symbol, ctxt, symbols): # type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool """Compute is_view for any symbol.""" - if (isinstance(symbol, syntax.Type)): + if isinstance(symbol, syntax.Type): return symbol.is_view - elif (isinstance(symbol, syntax.Enum)): + elif isinstance(symbol, syntax.Enum): return False - elif (isinstance(symbol, syntax.Struct)): + elif isinstance(symbol, syntax.Struct): return _compute_struct_is_view(symbol, ctxt, symbols) else: ctxt.add_unknown_symbol_error(symbol, symbol.name) @@ -389,11 +441,16 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct): ast_struct.is_command_reply = struct.is_command_reply ast_struct.is_catalog_ctxt = struct.is_catalog_ctxt ast_struct.query_shape_component = struct.query_shape_component - ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = struct.unsafe_dangerous_disable_extra_field_duplicate_checks + ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = ( + struct.unsafe_dangerous_disable_extra_field_duplicate_checks + ) ast_struct.is_view = _compute_is_view(struct, ctxt, parsed_spec.symbols) # Check that unsafe_dangerous_disable_extra_field_duplicate_checks is used correctly - if ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks and ast_struct.strict is True: + if ( + ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks + and ast_struct.strict is True + ): ctxt.add_strict_and_disable_check_not_allowed(ast_struct) if struct.is_generic_cmd_list: @@ -414,8 +471,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct): for chained_type in struct.chained_types: ast_field = _bind_chained_type(ctxt, parsed_spec, ast_struct, chained_type) - if ast_field and not _is_duplicate_field(ctxt, chained_type.name, ast_struct.fields, - ast_field): + if ast_field and not _is_duplicate_field( + ctxt, chained_type.name, ast_struct.fields, ast_field + ): ast_struct.fields.append(ast_field) # Merge chained structs as a chained struct and ignored fields @@ -427,7 +485,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct): ast_field = _bind_field(ctxt, parsed_spec, field) if ast_field: if ast_struct.generic_list_type: - gen_field_info = ast.GenericFieldInfo(struct.file_name, struct.line, struct.column) + gen_field_info = ast.GenericFieldInfo( + struct.file_name, struct.line, struct.column + ) if ast_struct.generic_list_type == ast.GenericListType.ARG: gen_field_info.forward_to_shards = field.forward_to_shards elif ast_struct.generic_list_type == ast.GenericListType.REPLY: @@ -436,31 +496,47 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct): else: assert False ast_field.generic_field_info = gen_field_info - if ast_field.supports_doc_sequence and not isinstance(ast_struct, ast.Command): + if ast_field.supports_doc_sequence and not isinstance( + ast_struct, ast.Command + ): # Doc sequences are only supported in commands at the moment - ctxt.add_bad_struct_field_as_doc_sequence_error(ast_struct, ast_struct.name, - ast_field.name) + ctxt.add_bad_struct_field_as_doc_sequence_error( + ast_struct, ast_struct.name, ast_field.name + ) if ast_field.non_const_getter and struct.immutable: ctxt.add_bad_field_non_const_getter_in_immutable_struct_error( - ast_struct, ast_struct.name, ast_field.name) + ast_struct, ast_struct.name, ast_field.name + ) - if not _is_duplicate_field(ctxt, ast_struct.name, ast_struct.fields, ast_field): + if not _is_duplicate_field( + ctxt, ast_struct.name, ast_struct.fields, ast_field + ): ast_struct.fields.append(ast_field) # Verify that each field on the struct defines a query shape type on the field if and only if # query_shape_component is defined on the struct. - if not field.hidden and struct.query_shape_component and ast_field.query_shape is None: - ctxt.add_must_declare_shape_type(ast_field, ast_struct.name, ast_field.name) + if ( + not field.hidden + and struct.query_shape_component + and ast_field.query_shape is None + ): + ctxt.add_must_declare_shape_type( + ast_field, ast_struct.name, ast_field.name + ) if not struct.query_shape_component and ast_field.query_shape is not None: - ctxt.add_must_be_query_shape_component(ast_field, ast_struct.name, ast_field.name) + ctxt.add_must_be_query_shape_component( + ast_field, ast_struct.name, ast_field.name + ) if ast_field.query_shape == ast.QueryShapeFieldType.ANONYMIZE and not ( - ast_field.type.cpp_type in ["std::string", "std::vector"] - or 'string' in ast_field.type.bson_serialization_type): - ctxt.add_query_shape_anonymize_must_be_string(ast_field, ast_field.name, - ast_field.type.cpp_type) + ast_field.type.cpp_type in ["std::string", "std::vector"] + or "string" in ast_field.type.bson_serialization_type + ): + ctxt.add_query_shape_anonymize_must_be_string( + ast_field, ast_field.name, ast_field.type.cpp_type + ) # Fill out the field comparison_order property as needed if ast_struct.generate_comparison_operators and ast_struct.fields: @@ -473,8 +549,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct): if not ast_field.comparison_order == -1: use_default_order = False if ast_field.comparison_order in comparison_orders: - ctxt.add_duplicate_comparison_order_field_error(ast_struct, ast_struct.name, - ast_field.comparison_order) + ctxt.add_duplicate_comparison_order_field_error( + ast_struct, ast_struct.name, ast_field.comparison_order + ) comparison_orders.add(ast_field.comparison_order) @@ -493,10 +570,15 @@ def _inject_hidden_fields(struct): if struct.fields is None: struct.fields = [] - serialization_context_field = syntax.Field(struct.file_name, struct.line, struct.column) - serialization_context_field.name = "serialization_context" # This comes from basic_types.idl - serialization_context_field.type = syntax.FieldTypeSingle(struct.file_name, struct.line, - struct.column) + serialization_context_field = syntax.Field( + struct.file_name, struct.line, struct.column + ) + serialization_context_field.name = ( + "serialization_context" # This comes from basic_types.idl + ) + serialization_context_field.type = syntax.FieldTypeSingle( + struct.file_name, struct.line, struct.column + ) serialization_context_field.type.type_name = "serialization_context" serialization_context_field.cpp_name = "serializationContext" serialization_context_field.optional = False @@ -529,7 +611,9 @@ def _inject_hidden_command_fields(command): # Inject a "$db" which we can decode during command parsing db_field = syntax.Field(command.file_name, command.line, command.column) db_field.name = "$db" - db_field.type = syntax.FieldTypeSingle(command.file_name, command.line, command.column) + db_field.type = syntax.FieldTypeSingle( + command.file_name, command.line, command.column + ) db_field.type.type_name = "database_name" # This comes from basic_types.idl db_field.cpp_name = "dbName" db_field.serialize_op_msg_request_only = True @@ -601,7 +685,7 @@ def _bind_variant_field(ctxt, ast_field, idl_type): def gen_cpp_types(): for alternative in ast_field.type.variant_types: if alternative.is_array: - yield f'std::vector<{alternative.cpp_type}>' + yield f"std::vector<{alternative.cpp_type}>" else: yield alternative.cpp_type @@ -633,8 +717,9 @@ def _bind_command_type(ctxt, parsed_spec, command): ctxt.add_array_not_valid_error(ast_field, "field", ast_field.name) # Resolve the command type as a field - syntax_symbol = parsed_spec.symbols.resolve_field_type(ctxt, command, command.name, - command.type) + syntax_symbol = parsed_spec.symbols.resolve_field_type( + ctxt, command, command.name, command.type + ) if syntax_symbol is None: return None @@ -644,8 +729,11 @@ def _bind_command_type(ctxt, parsed_spec, command): assert not isinstance(syntax_symbol, syntax.Enum) - base_type = (syntax_symbol.element_type - if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol) + base_type = ( + syntax_symbol.element_type + if isinstance(syntax_symbol, syntax.ArrayType) + else syntax_symbol + ) # Copy over only the needed information if this is a struct or a type. if isinstance(base_type, syntax.Struct): @@ -678,8 +766,9 @@ def _bind_command_reply_type(ctxt, parsed_spec, command): ast_field.description = f"{command.name} reply type" # Resolve the command type as a field - syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, command, command.name, - command.reply_type) + syntax_symbol = parsed_spec.symbols.resolve_type_from_name( + ctxt, command, command.name, command.reply_type + ) if syntax_symbol is None: # Resolution failed, we've recorded an error. @@ -709,8 +798,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value): # type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, str, str) -> str # Look up the enum for "enum_name" in the symbol table - access_check_enum = parsed_spec.symbols.resolve_type_from_name(ctxt, location, "access_check", - enum_name) + access_check_enum = parsed_spec.symbols.resolve_type_from_name( + ctxt, location, "access_check", enum_name + ) if access_check_enum is None: # Resolution failed, we've recorded an error. @@ -720,8 +810,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value): ctxt.add_unknown_type_error(location, enum_name, "enum") return None - syntax_enum = resolve_enum_value(ctxt, location, cast(syntax.Enum, access_check_enum), - enum_value) + syntax_enum = resolve_enum_value( + ctxt, location, cast(syntax.Enum, access_check_enum), enum_value + ) if not syntax_enum: return None @@ -732,22 +823,27 @@ def _bind_single_check(ctxt, parsed_spec, access_check): # type: (errors.ParserContext, syntax.IDLSpec, syntax.AccessCheck) -> ast.AccessCheck """Bind a single access_check.""" - ast_access_check = ast.AccessCheck(access_check.file_name, access_check.line, - access_check.column) + ast_access_check = ast.AccessCheck( + access_check.file_name, access_check.line, access_check.column + ) assert bool(access_check.check) != bool(access_check.privilege) if access_check.check: - ast_access_check.check = _bind_enum_value(ctxt, parsed_spec, access_check, "AccessCheck", - access_check.check) + ast_access_check.check = _bind_enum_value( + ctxt, parsed_spec, access_check, "AccessCheck", access_check.check + ) if not ast_access_check.check: return None else: privilege = access_check.privilege - ast_privilege = ast.Privilege(privilege.file_name, privilege.line, privilege.column) + ast_privilege = ast.Privilege( + privilege.file_name, privilege.line, privilege.column + ) - ast_privilege.resource_pattern = _bind_enum_value(ctxt, parsed_spec, privilege, "MatchType", - privilege.resource_pattern) + ast_privilege.resource_pattern = _bind_enum_value( + ctxt, parsed_spec, privilege, "MatchType", privilege.resource_pattern + ) if not ast_privilege.resource_pattern: return None @@ -930,7 +1026,9 @@ def _validate_variant_type(ctxt, syntax_symbol, field): for type_name, count in array_type_count.items(): if count > 1: - ctxt.add_variant_duplicate_types_error(syntax_symbol, field.name, f'array<{type_name}>') + ctxt.add_variant_duplicate_types_error( + syntax_symbol, field.name, f"array<{type_name}>" + ) types = len(syntax_symbol.variant_types) + len(syntax_symbol.variant_struct_types) if types < 2: @@ -953,14 +1051,14 @@ def _validate_field_properties(ctxt, ast_field): if ast_field.optional: ctxt.add_bad_field_default_and_optional(ast_field, ast_field.name) - if ast_field.type.bson_serialization_type == ['bindata']: + if ast_field.type.bson_serialization_type == ["bindata"]: ctxt.add_bindata_no_default(ast_field, ast_field.type.name, ast_field.name) if ast_field.always_serialize and not ast_field.optional: ctxt.add_bad_field_always_serialize_not_optional(ast_field, ast_field.name) # A "chain" type should never appear as a field. - if ast_field.type.bson_serialization_type == ['chain']: + if ast_field.type.bson_serialization_type == ["chain"]: ctxt.add_bad_array_of_chain(ast_field, ast_field.name) @@ -974,7 +1072,7 @@ def _validate_doc_sequence_field(ctxt, ast_field): # The only allowed BSON type for a doc_sequence field is "object" for serialization_type in ast_field.type.bson_serialization_type: - if serialization_type != 'object': + if serialization_type != "object": ctxt.add_bad_non_object_as_doc_sequence_error(ast_field, ast_field.name) @@ -986,7 +1084,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name): return cpp_method_name # Global function - if cpp_method_name.startswith('::'): + if cpp_method_name.startswith("::"): return cpp_method_name # Method is full qualified already @@ -998,7 +1096,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name): # Method is prefixed with just the type name if cpp_method_name.startswith(type_name): - return '::'.join(cpp_type_name.split('::')[0:-1]) + "::" + cpp_method_name + return "::".join(cpp_type_name.split("::")[0:-1]) + "::" + cpp_method_name return cpp_method_name @@ -1039,9 +1137,9 @@ def _bind_expression(expr, allow_literal_string=True): # std::string if allow_literal_string: strval = expr.literal - for i in ['\\', '"', "'"]: + for i in ["\\", '"', "'"]: if i in strval: - strval = strval.replace(i, '\\' + i) + strval = strval.replace(i, "\\" + i) node.expr = '"' + strval + '"' return node @@ -1086,11 +1184,11 @@ def _bind_condition(condition, condition_for): ast_condition.preprocessor = condition.preprocessor if condition.feature_flag: - assert condition_for == 'server_parameter' + assert condition_for == "server_parameter" ast_condition.feature_flag = condition.feature_flag if condition.min_fcv: - assert condition_for == 'server_parameter' + assert condition_for == "server_parameter" ast_condition.min_fcv = condition.min_fcv return ast_condition @@ -1115,7 +1213,9 @@ def _bind_type(idltype): ast_type.bson_serialization_type = idltype.bson_serialization_type ast_type.bindata_subtype = idltype.bindata_subtype ast_type.serializer = _normalize_method_name(idltype.cpp_type, idltype.serializer) - ast_type.deserializer = _normalize_method_name(idltype.cpp_type, idltype.deserializer) + ast_type.deserializer = _normalize_method_name( + idltype.cpp_type, idltype.deserializer + ) ast_type.deserialize_with_tenant = idltype.deserialize_with_tenant ast_type.internal_only = idltype.internal_only ast_type.is_query_shape_component = True @@ -1162,7 +1262,9 @@ def _bind_field(ctxt, parsed_spec, field): _validate_ignored_field(ctxt, field) return ast_field - syntax_symbol = parsed_spec.symbols.resolve_field_type(ctxt, field, field.name, field.type) + syntax_symbol = parsed_spec.symbols.resolve_field_type( + ctxt, field, field.name, field.type + ) if syntax_symbol is None: return None @@ -1179,12 +1281,16 @@ def _bind_field(ctxt, parsed_spec, field): _validate_array_type(ctxt, cast(syntax.ArrayType, syntax_symbol), field) elif field.supports_doc_sequence: # Doc sequences are only supported for arrays - ctxt.add_bad_non_array_as_doc_sequence_error(syntax_symbol, syntax_symbol.name, - ast_field.name) + ctxt.add_bad_non_array_as_doc_sequence_error( + syntax_symbol, syntax_symbol.name, ast_field.name + ) return None - base_type = (syntax_symbol.element_type - if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol) + base_type = ( + syntax_symbol.element_type + if isinstance(syntax_symbol, syntax.ArrayType) + else syntax_symbol + ) # Copy over only the needed information if this is a struct or a type. @@ -1230,15 +1336,18 @@ def _bind_field(ctxt, parsed_spec, field): return None if ast_field.should_shapify and not ast_field.type.is_query_shape_component: - ctxt.add_must_be_query_shape_component(ast_field, ast_field.type.name, ast_field.name) + ctxt.add_must_be_query_shape_component( + ast_field, ast_field.type.name, ast_field.name + ) return ast_field def _bind_chained_type(ctxt, parsed_spec, location, chained_type): # type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, syntax.ChainedType) -> ast.Field """Bind the specified chained type.""" - syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, location, chained_type.name, - chained_type.name) + syntax_symbol = parsed_spec.symbols.resolve_type_from_name( + ctxt, location, chained_type.name, chained_type.name + ) if not syntax_symbol: return None @@ -1248,9 +1357,13 @@ def _bind_chained_type(ctxt, parsed_spec, location, chained_type): idltype = cast(syntax.Type, syntax_symbol) - if len(idltype.bson_serialization_type) != 1 or idltype.bson_serialization_type[0] != 'chain': - ctxt.add_chained_type_wrong_type_error(location, chained_type.name, - idltype.bson_serialization_type[0]) + if ( + len(idltype.bson_serialization_type) != 1 + or idltype.bson_serialization_type[0] != "chain" + ): + ctxt.add_chained_type_wrong_type_error( + location, chained_type.name, idltype.bson_serialization_type[0] + ) return None ast_field = ast.Field(location.file_name, location.line, location.column) @@ -1267,12 +1380,15 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): # type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct) -> None """Bind the specified chained struct.""" syntax_symbol = parsed_spec.symbols.resolve_type_from_name( - ctxt, ast_struct, chained_struct.name, chained_struct.name) + ctxt, ast_struct, chained_struct.name, chained_struct.name + ) if not syntax_symbol: return - if not isinstance(syntax_symbol, syntax.Struct) or isinstance(syntax_symbol, syntax.Command): + if not isinstance(syntax_symbol, syntax.Struct) or isinstance( + syntax_symbol, syntax.Command + ): ctxt.add_chained_struct_not_found_error(ast_struct, chained_struct.name) return @@ -1280,22 +1396,28 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): # chained struct cannot be strict unless it is inlined if struct.strict and not ast_struct.inline_chained_structs: - ctxt.add_chained_nested_struct_no_strict_error(ast_struct, ast_struct.name, - chained_struct.name) + ctxt.add_chained_nested_struct_no_strict_error( + ast_struct, ast_struct.name, chained_struct.name + ) if struct.chained_types or struct.chained_structs: - ctxt.add_chained_nested_struct_no_nested_error(ast_struct, ast_struct.name, - chained_struct.name) + ctxt.add_chained_nested_struct_no_nested_error( + ast_struct, ast_struct.name, chained_struct.name + ) # Configure a field for the chained struct. - ast_chained_field = ast.Field(ast_struct.file_name, ast_struct.line, ast_struct.column) + ast_chained_field = ast.Field( + ast_struct.file_name, ast_struct.line, ast_struct.column + ) ast_chained_field.name = struct.name ast_chained_field.type = _bind_struct_type(struct) ast_chained_field.cpp_name = chained_struct.cpp_name ast_chained_field.description = struct.description ast_chained_field.chained = True - if not _is_duplicate_field(ctxt, chained_struct.name, ast_struct.fields, ast_chained_field): + if not _is_duplicate_field( + ctxt, chained_struct.name, ast_struct.fields, ast_chained_field + ): ast_struct.fields.append(ast_chained_field) else: return @@ -1306,9 +1428,9 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct): # Don't use internal fields in chained types, stick to local access only if ast_field.type.internal_only: continue - if ast_field and not _is_duplicate_field(ctxt, chained_struct.name, ast_struct.fields, - ast_field): - + if ast_field and not _is_duplicate_field( + ctxt, chained_struct.name, ast_struct.fields, ast_field + ): if ast_struct.inline_chained_structs: ast_field.chained_struct_field = ast_chained_field else: @@ -1322,8 +1444,11 @@ def _bind_globals(ctxt, parsed_spec): # type: (errors.ParserContext, syntax.IDLSpec) -> ast.Global """Bind the globals object from the idl.syntax tree into the idl.ast tree by doing a deep copy.""" if parsed_spec.globals: - ast_global = ast.Global(parsed_spec.globals.file_name, parsed_spec.globals.line, - parsed_spec.globals.column) + ast_global = ast.Global( + parsed_spec.globals.file_name, + parsed_spec.globals.line, + parsed_spec.globals.column, + ) ast_global.cpp_namespace = parsed_spec.globals.cpp_namespace ast_global.cpp_includes = parsed_spec.globals.cpp_includes @@ -1332,13 +1457,16 @@ def _bind_globals(ctxt, parsed_spec): configs = parsed_spec.globals.configs if configs: - ast_global.configs = ast.ConfigGlobal(configs.file_name, configs.line, configs.column) + ast_global.configs = ast.ConfigGlobal( + configs.file_name, configs.line, configs.column + ) if configs.initializer: init = configs.initializer ast_global.configs.initializer = ast.GlobalInitializer( - init.file_name, init.line, init.column) + init.file_name, init.line, init.column + ) # Parser rule makes it impossible to have both name and register/store. ast_global.configs.initializer.name = init.name ast_global.configs.initializer.register = init.register @@ -1364,8 +1492,9 @@ def _validate_enum_int(ctxt, idl_enum): try: int_values_set.add(int(enum_value.value)) except ValueError as value_error: - ctxt.add_enum_value_not_int_error(idl_enum, idl_enum.name, enum_value.value, - str(value_error)) + ctxt.add_enum_value_not_int_error( + idl_enum, idl_enum.name, enum_value.value, str(value_error) + ) return @@ -1390,7 +1519,9 @@ def _bind_enum(ctxt, idl_enum): return None for enum_value in idl_enum.values: - ast_enum_value = ast.EnumValue(enum_value.file_name, enum_value.line, enum_value.column) + ast_enum_value = ast.EnumValue( + enum_value.file_name, enum_value.line, enum_value.column + ) ast_enum_value.name = enum_value.name ast_enum_value.description = enum_value.description ast_enum_value.value = enum_value.value @@ -1405,7 +1536,7 @@ def _bind_enum(ctxt, idl_enum): if len(idl_enum.values) != len(values_set): ctxt.add_enum_value_not_unique_error(idl_enum, idl_enum.name) - if ast_enum.type == 'int': + if ast_enum.type == "int": _validate_enum_int(ctxt, idl_enum) return ast_enum @@ -1416,9 +1547,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param): """Bind and validate ServerParameter attributes specific to specialized ServerParameters.""" # Fields specific to bound and unbound standard params. - for field in ['cpp_vartype', 'cpp_varname', 'on_update', 'validator']: + for field in ["cpp_vartype", "cpp_varname", "on_update", "validator"]: if getattr(param, field) is not None: - ctxt.add_server_parameter_invalid_attr(param, field, 'specialized') + ctxt.add_server_parameter_invalid_attr(param, field, "specialized") return None # Fields specific to specialized stroage. @@ -1426,8 +1557,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param): if param.default is not None: if not param.default.is_constexpr: - ctxt.add_server_parameter_invalid_attr(param, 'default.is_constexpr=false', - 'specialized') + ctxt.add_server_parameter_invalid_attr( + param, "default.is_constexpr=false", "specialized" + ) return None ast_param.default = _bind_expression(param.default) @@ -1441,7 +1573,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param): ast_param.cpp_class.override_validate = cls.override_validate # If set_at is cluster, then set must be overridden. Otherwise, use the parsed value. - ast_param.cpp_class.override_set = True if param.set_at == ['cluster'] else cls.override_set + ast_param.cpp_class.override_set = ( + True if param.set_at == ["cluster"] else cls.override_set + ) return ast_param @@ -1451,13 +1585,13 @@ def _bind_server_parameter_with_storage(ctxt, ast_param, param): """Bind and validate ServerParameter attributes specific to bound ServerParameters.""" # Fields specific to specialized and unbound standard params. - for field in ['cpp_class']: + for field in ["cpp_class"]: if getattr(param, field) is not None: - ctxt.add_server_parameter_invalid_attr(param, field, 'bound') + ctxt.add_server_parameter_invalid_attr(param, field, "bound") return None - if param.set_at == ['cluster']: - ast_param.cpp_vartype = f'TenantIdMap<{param.cpp_vartype}>' + if param.set_at == ["cluster"]: + ast_param.cpp_vartype = f"TenantIdMap<{param.cpp_vartype}>" else: ast_param.cpp_vartype = param.cpp_vartype ast_param.cpp_varname = param.cpp_varname @@ -1480,20 +1614,20 @@ def _bind_server_parameter_set_at(ctxt, param): # type: (errors.ParserContext, syntax.ServerParameter) -> str """Translate set_at options to C++ enum value.""" - if param.set_at == ['readonly']: + if param.set_at == ["readonly"]: # Readonly may not be mixed with startup or runtime return "ServerParameterType::kReadOnly" - if param.set_at == ['cluster']: + if param.set_at == ["cluster"]: # Cluster-wide parameters may not be mixed with startup or runtime. # They are implicitly runtime-only. return "ServerParameterType::kClusterWide" set_at = 0 for psa in param.set_at: - if psa.lower() == 'startup': + if psa.lower() == "startup": set_at |= 1 - elif psa.lower() == 'runtime': + elif psa.lower() == "runtime": set_at |= 2 else: ctxt.add_bad_setat_specifier(param, psa) @@ -1509,7 +1643,7 @@ def _bind_server_parameter_set_at(ctxt, param): return mask_to_text[set_at] # Can't happen based on above logic. - ctxt.add_bad_setat_specifier(param, ','.join(param.set_at)) + ctxt.add_bad_setat_specifier(param, ",".join(param.set_at)) return None @@ -1519,19 +1653,23 @@ def _bind_server_parameter(ctxt, param): ast_param = ast.ServerParameter(param.file_name, param.line, param.column) ast_param.name = param.name ast_param.description = param.description - ast_param.condition = _bind_condition(param.condition, condition_for='server_parameter') + ast_param.condition = _bind_condition( + param.condition, condition_for="server_parameter" + ) ast_param.redact = param.redact ast_param.test_only = param.test_only ast_param.deprecated_name = param.deprecated_name # The omit_in_ftdc flag can only be enabled for cluster parameters. - if param.omit_in_ftdc is not None and param.set_at != ['cluster']: - ctxt.add_server_parameter_invalid_attr(param, 'omit_in_ftdc=True', ''.join(param.set_at)) + if param.omit_in_ftdc is not None and param.set_at != ["cluster"]: + ctxt.add_server_parameter_invalid_attr( + param, "omit_in_ftdc=True", "".join(param.set_at) + ) return None # If omit_in_ftdc is None (it has not been set) for a cluster parameter, then emit an error. - if param.omit_in_ftdc is None and param.set_at == ['cluster']: - ctxt.add_server_parameter_required_attr(param, 'omit_in_ftdc', 'cluster') + if param.omit_in_ftdc is None and param.set_at == ["cluster"]: + ctxt.add_server_parameter_required_attr(param, "omit_in_ftdc", "cluster") ast_param.omit_in_ftdc = param.omit_in_ftdc @@ -1544,7 +1682,9 @@ def _bind_server_parameter(ctxt, param): elif param.cpp_varname: return _bind_server_parameter_with_storage(ctxt, ast_param, param) else: - ctxt.add_server_parameter_required_attr(param, 'cpp_varname', 'server_parameter') + ctxt.add_server_parameter_required_attr( + param, "cpp_varname", "server_parameter" + ) return None @@ -1565,7 +1705,11 @@ def _bind_feature_flags(ctxt, param): return None # Feature flags that default to true and should be FCV gated are required to have a version - if param.default.literal == "true" and param.shouldBeFCVGated.literal == "true" and not param.version: + if ( + param.default.literal == "true" + and param.shouldBeFCVGated.literal == "true" + and not param.version + ): ctxt.add_feature_flag_default_true_missing_version(param) return None @@ -1574,10 +1718,16 @@ def _bind_feature_flags(ctxt, param): ctxt.add_feature_flag_fcv_gated_false_has_version(param) return None - expr = syntax.Expression(param.default.file_name, param.default.line, param.default.column) - expr.expr = '%s, "%s"_sd, %s' % (param.default.literal, param.version if - (param.shouldBeFCVGated.literal == "true" - and param.version) else '', param.shouldBeFCVGated.literal) + expr = syntax.Expression( + param.default.file_name, param.default.line, param.default.column + ) + expr.expr = '%s, "%s"_sd, %s' % ( + param.default.literal, + param.version + if (param.shouldBeFCVGated.literal == "true" and param.version) + else "", + param.shouldBeFCVGated.literal, + ) ast_param.default = _bind_expression(expr) ast_param.default.export = False @@ -1590,7 +1740,7 @@ def _bind_feature_flags(ctxt, param): def _is_invalid_config_short_name(name): # type: (str) -> bool """Check if a given name is valid as a short name.""" - return ('.' in name) or (',' in name) + return ("." in name) or ("," in name) def _parse_config_option_sources(source_list): @@ -1628,7 +1778,7 @@ def _bind_config_option(ctxt, globals_spec, option): node = ast.ConfigOption(option.file_name, option.line, option.column) - if _is_invalid_config_short_name(option.short_name or ''): + if _is_invalid_config_short_name(option.short_name or ""): ctxt.add_invalid_short_name(option, option.short_name) return None @@ -1657,13 +1807,13 @@ def _bind_config_option(ctxt, globals_spec, option): ctxt.add_missing_short_name_with_single_name(option, option.single_name) return None - node.short_name = node.short_name + ',' + option.single_name + node.short_name = node.short_name + "," + option.single_name node.description = _bind_expression(option.description) node.arg_vartype = option.arg_vartype node.cpp_vartype = option.cpp_vartype node.cpp_varname = option.cpp_varname - node.condition = _bind_condition(option.condition, condition_for='config') + node.condition = _bind_condition(option.condition, condition_for="config") node.requires = option.requires node.conflicts = option.conflicts @@ -1687,7 +1837,7 @@ def _bind_config_option(ctxt, globals_spec, option): node.source = _parse_config_option_sources(source_list) if node.source is None: - ctxt.add_bad_source_specifier(option, ', '.join(source_list)) + ctxt.add_bad_source_specifier(option, ", ".join(source_list)) return None if option.duplicate_behavior: @@ -1703,17 +1853,17 @@ def _bind_config_option(ctxt, globals_spec, option): return None # Parse single digit, closed range, or open range of digits. - spread = option.positional.split('-') + spread = option.positional.split("-") if len(spread) == 1: # Make a single number behave like a range of that number, (e.g. "2" -> "2-2"). spread.append(spread[0]) if (len(spread) != 2) or ((spread[0] == "") and (spread[1] == "")): - ctxt.add_bad_numeric_range(option, 'positional', option.positional) + ctxt.add_bad_numeric_range(option, "positional", option.positional) try: node.positional_start = int(spread[0] or "-1") node.positional_end = int(spread[1] or "-1") except ValueError: - ctxt.add_bad_numeric_range(option, 'positional', option.positional) + ctxt.add_bad_numeric_range(option, "positional", option.positional) return None if option.validator is not None: @@ -1753,10 +1903,14 @@ def bind(parsed_spec): bound_spec.server_parameters.append(_bind_feature_flags(ctxt, feature_flag)) for server_parameter in parsed_spec.server_parameters: - bound_spec.server_parameters.append(_bind_server_parameter(ctxt, server_parameter)) + bound_spec.server_parameters.append( + _bind_server_parameter(ctxt, server_parameter) + ) for option in parsed_spec.configs: - bound_spec.configs.append(_bind_config_option(ctxt, parsed_spec.globals, option)) + bound_spec.configs.append( + _bind_config_option(ctxt, parsed_spec.globals, option) + ) if ctxt.errors.has_errors(): return ast.IDLBoundSpec(None, ctxt.errors) diff --git a/buildscripts/idl/idl/bson.py b/buildscripts/idl/idl/bson.py index 3400b27248c..d10f735a988 100644 --- a/buildscripts/idl/idl/bson.py +++ b/buildscripts/idl/idl/bson.py @@ -35,34 +35,34 @@ Utilities for validating bson types, etc. # scalar: True if the type is not an array or object # bson_type_enum: The BSONType enum value for the given type _BSON_TYPE_INFORMATION = { - "double": {'scalar': True, 'bson_type_enum': 'NumberDouble'}, - "string": {'scalar': True, 'bson_type_enum': 'String'}, - "object": {'scalar': False, 'bson_type_enum': 'Object'}, - "array": {'scalar': False, 'bson_type_enum': 'Array'}, - "bindata": {'scalar': True, 'bson_type_enum': 'BinData'}, - "undefined": {'scalar': True, 'bson_type_enum': 'Undefined'}, - "objectid": {'scalar': True, 'bson_type_enum': 'jstOID'}, - "bool": {'scalar': True, 'bson_type_enum': 'Bool'}, - "date": {'scalar': True, 'bson_type_enum': 'Date'}, - "null": {'scalar': True, 'bson_type_enum': 'jstNULL'}, - "regex": {'scalar': True, 'bson_type_enum': 'RegEx'}, - "int": {'scalar': True, 'bson_type_enum': 'NumberInt'}, - "timestamp": {'scalar': True, 'bson_type_enum': 'bsonTimestamp'}, - "long": {'scalar': True, 'bson_type_enum': 'NumberLong'}, - "decimal": {'scalar': True, 'bson_type_enum': 'NumberDecimal'}, + "double": {"scalar": True, "bson_type_enum": "NumberDouble"}, + "string": {"scalar": True, "bson_type_enum": "String"}, + "object": {"scalar": False, "bson_type_enum": "Object"}, + "array": {"scalar": False, "bson_type_enum": "Array"}, + "bindata": {"scalar": True, "bson_type_enum": "BinData"}, + "undefined": {"scalar": True, "bson_type_enum": "Undefined"}, + "objectid": {"scalar": True, "bson_type_enum": "jstOID"}, + "bool": {"scalar": True, "bson_type_enum": "Bool"}, + "date": {"scalar": True, "bson_type_enum": "Date"}, + "null": {"scalar": True, "bson_type_enum": "jstNULL"}, + "regex": {"scalar": True, "bson_type_enum": "RegEx"}, + "int": {"scalar": True, "bson_type_enum": "NumberInt"}, + "timestamp": {"scalar": True, "bson_type_enum": "bsonTimestamp"}, + "long": {"scalar": True, "bson_type_enum": "NumberLong"}, + "decimal": {"scalar": True, "bson_type_enum": "NumberDecimal"}, } # Dictionary of BinData subtype type Information # scalar: True if the type is not an array or object # bindata_enum: The BinDataType enum value for the given type _BINDATA_SUBTYPE = { - "generic": {'scalar': True, 'bindata_enum': 'BinDataGeneral'}, - "function": {'scalar': True, 'bindata_enum': 'Function'}, + "generic": {"scalar": True, "bindata_enum": "BinDataGeneral"}, + "function": {"scalar": True, "bindata_enum": "Function"}, # Also simply known as type 2, deprecated, and requires special handling - #"binary": { + # "binary": { # 'scalar': False, # 'bindata_enum': 'ByteArrayDeprecated' - #}, + # }, # Deprecated # "uuid_old": { # 'scalar': False, @@ -86,14 +86,14 @@ def is_scalar_bson_type(name): # type: (str) -> bool """Return True if this bson type is a scalar.""" assert is_valid_bson_type(name) - return _BSON_TYPE_INFORMATION[name]['scalar'] # type: ignore + return _BSON_TYPE_INFORMATION[name]["scalar"] # type: ignore def cpp_bson_type_name(name): # type: (str) -> str """Return the C++ type name for a bson type.""" assert is_valid_bson_type(name) - return _BSON_TYPE_INFORMATION[name]['bson_type_enum'] # type: ignore + return _BSON_TYPE_INFORMATION[name]["bson_type_enum"] # type: ignore def list_valid_types(): @@ -112,4 +112,4 @@ def cpp_bindata_subtype_type_name(name): # type: (str) -> str """Return the C++ type name for a bindata subtype.""" assert is_valid_bindata_subtype(name) - return _BINDATA_SUBTYPE[name]['bindata_enum'] # type: ignore + return _BINDATA_SUBTYPE[name]["bindata_enum"] # type: ignore diff --git a/buildscripts/idl/idl/common.py b/buildscripts/idl/idl/common.py index 97d9e787f0c..24d8f00ec6a 100644 --- a/buildscripts/idl/idl/common.py +++ b/buildscripts/idl/idl/common.py @@ -47,7 +47,7 @@ def title_case(name): # Only capitalize the last part of a fully-qualified name pos = name.rfind("::") if pos > -1: - return name[:pos + 2] + name[pos + 2:pos + 3].upper() + name[pos + 3:] + return name[: pos + 2] + name[pos + 2 : pos + 3].upper() + name[pos + 3 :] return name[0:1].upper() + name[1:] @@ -71,9 +71,9 @@ def _escape_template_string(template): # type: (str) -> str """Escape the '$' in template strings unless followed by '{'.""" # See https://docs.python.org/2/library/string.html#template-strings - template = template.replace('${', '#{') - template = template.replace('$', '$$') - return template.replace('#{', '${') + template = template.replace("${", "#{") + template = template.replace("$", "$$") + return template.replace("#{", "${") def template_format(template, template_params=None): @@ -114,5 +114,9 @@ class SourceLocation(object): Example location message: test.idl: (17, 4) """ - msg = "%s: (%d, %d)" % (os.path.basename(self.file_name), self.line, self.column) + msg = "%s: (%d, %d)" % ( + os.path.basename(self.file_name), + self.line, + self.column, + ) return msg # type: ignore diff --git a/buildscripts/idl/idl/compiler.py b/buildscripts/idl/idl/compiler.py index b16b31c9824..ef1c16d52da 100644 --- a/buildscripts/idl/idl/compiler.py +++ b/buildscripts/idl/idl/compiler.py @@ -72,29 +72,52 @@ class CompilerImportResolver(parser.ImportResolverBase): # type: (str, str) -> str """Return the complete path to an imported file name.""" - logging.debug("Resolving imported file '%s' for file '%s'", imported_file_name, base_file) + logging.debug( + "Resolving imported file '%s' for file '%s'", imported_file_name, base_file + ) # Check for fully-qualified paths - logging.debug("Checking for imported file '%s' for file '%s' at '%s'", imported_file_name, - base_file, imported_file_name) + logging.debug( + "Checking for imported file '%s' for file '%s' at '%s'", + imported_file_name, + base_file, + imported_file_name, + ) if os.path.isabs(imported_file_name) and os.path.exists(imported_file_name): - logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name, - base_file, imported_file_name) + logging.debug( + "Found imported file '%s' for file '%s' at '%s'", + imported_file_name, + base_file, + imported_file_name, + ) return imported_file_name for candidate_dir in self._import_directories or []: base_dir = os.path.abspath(candidate_dir) - resolved_file_name = os.path.normpath(os.path.join(base_dir, imported_file_name)) + resolved_file_name = os.path.normpath( + os.path.join(base_dir, imported_file_name) + ) - logging.debug("Checking for imported file '%s' for file '%s' at '%s'", - imported_file_name, base_file, resolved_file_name) + logging.debug( + "Checking for imported file '%s' for file '%s' at '%s'", + imported_file_name, + base_file, + resolved_file_name, + ) if os.path.exists(resolved_file_name): - logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name, - base_file, resolved_file_name) + logging.debug( + "Found imported file '%s' for file '%s' at '%s'", + imported_file_name, + base_file, + resolved_file_name, + ) return resolved_file_name - msg = ("Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file)) + msg = "Cannot find imported file '%s' for file '%s'" % ( + imported_file_name, + base_file, + ) logging.error(msg) raise errors.IDLError(msg) @@ -102,7 +125,7 @@ class CompilerImportResolver(parser.ImportResolverBase): def open(self, resolved_file_name): # type: (str) -> Any """Return an io.Stream for the requested file.""" - return io.open(resolved_file_name, encoding='utf-8') + return io.open(resolved_file_name, encoding="utf-8") def _write_dependencies(spec, write_dependencies_inline): @@ -133,7 +156,8 @@ def _update_import_includes(args, spec, header_file_name): if args.output_base_dir: base_include_h_file_name = os.path.relpath( - os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir)) + os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir) + ) else: base_include_h_file_name = os.path.abspath(header_file_name) @@ -144,21 +168,29 @@ def _update_import_includes(args, spec, header_file_name): if not spec.globals: spec.globals = syntax.Global(args.input_file, -1, -1) - first_dir = base_include_h_file_name.split('/')[0] + first_dir = base_include_h_file_name.split("/")[0] for resolved_file_name in spec.imports.resolved_imports: # Guess: the file naming rules are consistent across IDL invocations - include_h_file_name = os.path.splitext(resolved_file_name)[0] + args.output_suffix + ".h" + include_h_file_name = ( + os.path.splitext(resolved_file_name)[0] + args.output_suffix + ".h" + ) if args.output_base_dir: base_dir = os.path.normpath(args.output_base_dir) - include_h_file_name = os.path.relpath(os.path.normpath(include_h_file_name), base_dir) + include_h_file_name = os.path.relpath( + os.path.normpath(include_h_file_name), base_dir + ) if os.path.isabs(base_dir): include_h_file_name = os.path.join( - base_dir, include_h_file_name[include_h_file_name.rfind(first_dir):]) + base_dir, + include_h_file_name[include_h_file_name.rfind(first_dir) :], + ) else: - include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir):] + include_h_file_name = include_h_file_name[ + include_h_file_name.find(first_dir) : + ] else: include_h_file_name = os.path.abspath(include_h_file_name) @@ -176,9 +208,12 @@ def compile_idl(args): logging.error("File '%s' not found", args.input_file) if args.output_source is None: - if '.' not in args.input_file: - logging.error("File name '%s' must be end with a filename extension, such as '%s.idl'", - args.input_file, args.input_file) + if "." not in args.input_file: + logging.error( + "File name '%s' must be end with a filename extension, such as '%s.idl'", + args.input_file, + args.input_file, + ) return False file_name_prefix = os.path.splitext(args.input_file)[0] @@ -194,9 +229,12 @@ def compile_idl(args): args.target_arch = platform.machine() # Compile the IDL through the 3 passes - with io.open(args.input_file, encoding='utf-8') as file_stream: - parsed_doc = parser.parse(file_stream, args.input_file, - CompilerImportResolver(args.import_directories)) + with io.open(args.input_file, encoding="utf-8") as file_stream: + parsed_doc = parser.parse( + file_stream, + args.input_file, + CompilerImportResolver(args.import_directories), + ) if not parsed_doc.errors: if args.write_dependencies or args.write_dependencies_inline: @@ -210,8 +248,13 @@ def compile_idl(args): bound_doc = binder.bind(parsed_doc.spec) if not bound_doc.errors: - generator.generate_code(bound_doc.spec, args.target_arch, args.output_base_dir, - header_file_name, source_file_name) + generator.generate_code( + bound_doc.spec, + args.target_arch, + args.output_base_dir, + header_file_name, + source_file_name, + ) return True else: diff --git a/buildscripts/idl/idl/cpp_types.py b/buildscripts/idl/idl/cpp_types.py index b7e3caeebae..b0ca14b4aba 100644 --- a/buildscripts/idl/idl/cpp_types.py +++ b/buildscripts/idl/idl/cpp_types.py @@ -32,7 +32,7 @@ from abc import ABCMeta, abstractmethod from . import bson, common, writer -_STD_ARRAY_UINT8_16 = 'std::array' +_STD_ARRAY_UINT8_16 = "std::array" def is_primitive_scalar_type(cpp_type): @@ -42,26 +42,31 @@ def is_primitive_scalar_type(cpp_type): Primitive scalar types need to have a default value to prevent warnings from Coverity. """ - cpp_type = cpp_type.replace(' ', '') + cpp_type = cpp_type.replace(" ", "") # TODO (SERVER-50101): Remove 'multiversion::FeatureCompatibilityVersion' once IDL supports # a commmand cpp_type of C++ enum. return cpp_type in [ - 'bool', 'double', 'std::int32_t', 'std::uint32_t', 'std::uint64_t', 'std::int64_t', - 'multiversion::FeatureCompatibilityVersion' + "bool", + "double", + "std::int32_t", + "std::uint32_t", + "std::uint64_t", + "std::int64_t", + "multiversion::FeatureCompatibilityVersion", ] def is_primitive_type(cpp_type): # type: (str) -> bool """Return True if a cpp_type is a primitive type and should not be returned as reference.""" - cpp_type = cpp_type.replace(' ', '') + cpp_type = cpp_type.replace(" ", "") return is_primitive_scalar_type(cpp_type) or cpp_type == _STD_ARRAY_UINT8_16 def _qualify_optional_type(cpp_type): # type: (str) -> str """Qualify the type as optional.""" - return 'boost::optional<%s>' % (cpp_type) + return "boost::optional<%s>" % (cpp_type) def _qualify_array_type(cpp_type): @@ -74,7 +79,7 @@ def _optionally_make_call(method_name, param): # type: (str, str) -> str """Return a call to method_name if it is not None, otherwise return an empty string.""" if not method_name: - return '' + return "" return "%s(%s);" % (method_name, param) @@ -142,7 +147,9 @@ class CppTypeBase(metaclass=ABCMeta): return common.template_args( "${optionally_call_validator} ${member_name} = std::move(value);", member_name=member_name, - optionally_call_validator=_optionally_make_call(validator_method_name, "value"), + optionally_call_validator=_optionally_make_call( + validator_method_name, "value" + ), ) @abstractmethod @@ -175,7 +182,9 @@ class _CppTypeBasic(CppTypeBase): def return_by_reference(self): # type: () -> bool - return not is_primitive_type(self.get_type_name()) and not self._field.type.is_enum + return ( + not is_primitive_type(self.get_type_name()) and not self._field.type.is_enum + ) def is_view_type(self): # type: () -> bool @@ -187,14 +196,17 @@ class _CppTypeBasic(CppTypeBase): def get_getter_body(self, member_name): # type: (str) -> str - return common.template_args('return ${member_name};', member_name=member_name) + return common.template_args("return ${member_name};", member_name=member_name) def get_setter_body(self, member_name, validator_method_name): # type: (str, str) -> str return common.template_args( - '${optionally_call_validator} ${member_name} = std::move(value);', - optionally_call_validator=_optionally_make_call(validator_method_name, - 'value'), member_name=member_name) + "${optionally_call_validator} ${member_name} = std::move(value);", + optionally_call_validator=_optionally_make_call( + validator_method_name, "value" + ), + member_name=member_name, + ) def get_transform_to_getter_type(self, expression): # type: (str) -> Optional[str] @@ -240,15 +252,18 @@ class _CppTypeView(CppTypeBase): def get_getter_body(self, member_name): # type: (str) -> str - return common.template_args('return ${member_name};', member_name=member_name) + return common.template_args("return ${member_name};", member_name=member_name) def get_setter_body(self, member_name, validator_method_name): # type: (str, str) -> str return common.template_args( - 'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);', - member_name=member_name, optionally_call_validator=_optionally_make_call( - validator_method_name, - '_tmpValue'), value=self.get_transform_to_storage_type("value")) + "auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);", + member_name=member_name, + optionally_call_validator=_optionally_make_call( + validator_method_name, "_tmpValue" + ), + value=self.get_transform_to_storage_type("value"), + ) def get_transform_to_getter_type(self, expression): # type: (str) -> Optional[str] @@ -257,7 +272,7 @@ class _CppTypeView(CppTypeBase): def get_transform_to_storage_type(self, expression): # type: (str) -> Optional[str] return common.template_args( - '${expression}.toString()', + "${expression}.toString()", expression=expression, ) @@ -267,7 +282,7 @@ class _CppTypeVector(CppTypeBase): def __init__(self, field): # type: (ast.Field) -> None - super(_CppTypeVector, self).__init__(field, 'std::vector') + super(_CppTypeVector, self).__init__(field, "std::vector") def get_type_name(self): # type: () -> str @@ -279,7 +294,7 @@ class _CppTypeVector(CppTypeBase): def get_getter_setter_type(self): # type: () -> str - return 'ConstDataRange' + return "ConstDataRange" def return_by_reference(self): # type: () -> bool @@ -295,27 +310,34 @@ class _CppTypeVector(CppTypeBase): def get_getter_body(self, member_name): # type: (str) -> str - return common.template_args('return ConstDataRange(${member_name});', - member_name=member_name) + return common.template_args( + "return ConstDataRange(${member_name});", member_name=member_name + ) def get_setter_body(self, member_name, validator_method_name): # type: (str, str) -> str return common.template_args( - 'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);', - member_name=member_name, optionally_call_validator=_optionally_make_call( - validator_method_name, - '_tmpValue'), value=self.get_transform_to_storage_type("value")) + "auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);", + member_name=member_name, + optionally_call_validator=_optionally_make_call( + validator_method_name, "_tmpValue" + ), + value=self.get_transform_to_storage_type("value"), + ) def get_transform_to_getter_type(self, expression): # type: (str) -> Optional[str] - return common.template_args('ConstDataRange(${expression});', expression=expression) + return common.template_args( + "ConstDataRange(${expression});", expression=expression + ) def get_transform_to_storage_type(self, expression): # type: (str) -> Optional[str] return common.template_args( - 'std::vector(reinterpret_cast(${expression}.data()), ' + - 'reinterpret_cast(${expression}.data()) + ${expression}.length())', - expression=expression) + "std::vector(reinterpret_cast(${expression}.data()), " + + "reinterpret_cast(${expression}.data()) + ${expression}.length())", + expression=expression, + ) class _CppTypeDelegating(CppTypeBase): @@ -360,7 +382,9 @@ class _CppTypeDelegating(CppTypeBase): def get_storage_type_setter_body(self, member_name, validator_method_name): # type: (str, str) -> Optional[str] - return self._base.get_storage_type_setter_body(member_name, validator_method_name) + return self._base.get_storage_type_setter_body( + member_name, validator_method_name + ) def get_transform_to_getter_type(self, expression): # type: (str) -> Optional[str] @@ -392,7 +416,7 @@ class _CppTypeArray(_CppTypeDelegating): # type: (str) -> str convert = self.get_transform_to_getter_type(member_name) if convert: - return common.template_args('return ${convert};', convert=convert) + return common.template_args("return ${convert};", convert=convert) return self._base.get_getter_body(member_name) def get_setter_body(self, member_name, validator_method_name): @@ -400,16 +424,20 @@ class _CppTypeArray(_CppTypeDelegating): convert = self.get_transform_to_storage_type("value") if convert: return common.template_args( - 'auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);', - member_name=member_name, optionally_call_validator=_optionally_make_call( - validator_method_name, '_tmpValue'), convert=convert) + "auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);", + member_name=member_name, + optionally_call_validator=_optionally_make_call( + validator_method_name, "_tmpValue" + ), + convert=convert, + ) return self._base.get_setter_body(member_name, validator_method_name) def get_transform_to_getter_type(self, expression): # type: (str) -> Optional[str] if self._base.get_storage_type() != self._base.get_getter_setter_type(): return common.template_args( - 'transformVector(${expression})', + "transformVector(${expression})", expression=expression, ) return None @@ -418,7 +446,7 @@ class _CppTypeArray(_CppTypeDelegating): # type: (str) -> Optional[str] if self._base.get_storage_type() != self._base.get_getter_setter_type(): return common.template_args( - 'transformVector(${expression})', + "transformVector(${expression})", expression=expression, ) return None @@ -443,7 +471,9 @@ class _CppTypeOptional(_CppTypeDelegating): def get_getter_body(self, member_name): # type: (str) -> str - base_expression = common.template_args("*${member_name}", member_name=member_name) + base_expression = common.template_args( + "*${member_name}", member_name=member_name + ) convert = self._base.get_transform_to_getter_type(base_expression) if convert: @@ -457,13 +487,18 @@ class _CppTypeOptional(_CppTypeDelegating): } else { return boost::none; } - """), member_name=member_name, convert=convert) + """), + member_name=member_name, + convert=convert, + ) elif self.is_view_type(): # For optionals around view types, do an explicit construction - return common.template_args('return ${param_type}{${member_name}};', - param_type=self.get_getter_setter_type(), - member_name=member_name) - return common.template_args('return ${member_name};', member_name=member_name) + return common.template_args( + "return ${param_type}{${member_name}};", + param_type=self.get_getter_setter_type(), + member_name=member_name, + ) + return common.template_args("return ${member_name};", member_name=member_name) def _get_setter_body(self, member_name, validator_method_name, convert): # type: (str, str, str) -> str @@ -479,8 +514,13 @@ class _CppTypeOptional(_CppTypeDelegating): } else { ${member_name} = boost::none; } - """), member_name=member_name, convert=convert, - optionally_call_validator=_optionally_make_call(validator_method_name, '_tmpValue')) + """), + member_name=member_name, + convert=convert, + optionally_call_validator=_optionally_make_call( + validator_method_name, "_tmpValue" + ), + ) return self._base.get_setter_body(member_name, validator_method_name) def get_setter_body(self, member_name, validator_method_name): @@ -498,9 +538,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array): # type: (ast.Field, str, bool) -> CppTypeBase """Get the C++ Type information for the given C++ type name, e.g. std::string.""" cpp_type_info: CppTypeBase - if cpp_type_name == 'std::string': - cpp_type_info = _CppTypeView(field, 'std::string', 'std::string', 'StringData') - elif cpp_type_name == 'std::vector': + if cpp_type_name == "std::string": + cpp_type_info = _CppTypeView(field, "std::string", "std::string", "StringData") + elif cpp_type_name == "std::vector": cpp_type_info = _CppTypeVector(field) else: cpp_type_info = _CppTypeBasic(field, cpp_type_name) @@ -514,7 +554,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array): def get_cpp_type_without_optional(field): # type: (ast.Field) -> CppTypeBase """Get the C++ Type information for the given field but ignore optional.""" - return get_cpp_type_from_cpp_type_name(field, field.type.cpp_type, field.type.is_array) + return get_cpp_type_from_cpp_type_name( + field, field.type.cpp_type, field.type.is_array + ) def get_cpp_type(field): @@ -550,15 +592,17 @@ class BsonCppTypeBase(object, metaclass=ABCMeta): pass @abstractmethod - def gen_serializer_expression(self, indented_writer, expression, should_shapify=False, - is_catalog_ctxt=False): + def gen_serializer_expression( + self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False + ): # type: (writer.IndentedTextWriter, str, bool, bool) -> str """Generate code with the text writer and return an expression to serialize the type.""" pass -def _call_method_or_global_function(expression, ast_type, should_shapify=False, - is_catalog_ctxt=False): +def _call_method_or_global_function( + expression, ast_type, should_shapify=False, is_catalog_ctxt=False +): # type: (str, ast.Type, bool, bool) -> str """ Given a fully-qualified method name, call it correctly. @@ -568,25 +612,27 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False, enum deserializers/serializers which are not methods. """ method_name = ast_type.serializer - serialization_context = 'getSerializationContext()' if ast_type.deserialize_with_tenant else '' - shape_options = '' + serialization_context = ( + "getSerializationContext()" if ast_type.deserialize_with_tenant else "" + ) + shape_options = "" if should_shapify: - shape_options = 'options' + shape_options = "options" short_method_name = writer.get_method_name(method_name) if writer.is_function(method_name): if ast_type.deserialize_with_tenant: if is_catalog_ctxt: # serializeForCatalog doesn't need a serializationContext - serialization_context = '' + serialization_context = "" method_name = method_name.replace("serialize", "serializeForCatalog") else: - serialization_context = ', ' + serialization_context + serialization_context = ", " + serialization_context if should_shapify: - shape_options = ', ' + shape_options + shape_options = ", " + shape_options return common.template_args( - '${method_name}(${expression}${shape_options}${serialization_context})', + "${method_name}(${expression}${shape_options}${serialization_context})", expression=expression, method_name=method_name, shape_options=shape_options, @@ -594,7 +640,7 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False, ) return common.template_args( - '${expression}.${method_name}(${shape_options}${serialization_context})', + "${expression}.${method_name}(${shape_options}${serialization_context})", expression=expression, method_name=short_method_name, shape_options=shape_options, @@ -612,19 +658,23 @@ class _CommonBsonCppTypeBase(BsonCppTypeBase): def gen_deserializer_expression(self, indented_writer, object_instance): # type: (writer.IndentedTextWriter, str) -> str - return common.template_args('${object_instance}.${method_name}()', - object_instance=object_instance, - method_name=self._deserialize_method_name) + return common.template_args( + "${object_instance}.${method_name}()", + object_instance=object_instance, + method_name=self._deserialize_method_name, + ) def has_serializer(self): # type: () -> bool return self._ast_type.serializer is not None - def gen_serializer_expression(self, indented_writer, expression, should_shapify=False, - is_catalog_ctxt=False): + def gen_serializer_expression( + self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False + ): # type: (writer.IndentedTextWriter, str, bool, bool) -> str - return _call_method_or_global_function(expression, self._ast_type, should_shapify, - is_catalog_ctxt) + return _call_method_or_global_function( + expression, self._ast_type, should_shapify, is_catalog_ctxt + ) class _ObjectBsonCppTypeBase(BsonCppTypeBase): @@ -635,34 +685,43 @@ class _ObjectBsonCppTypeBase(BsonCppTypeBase): if self._ast_type.deserializer: # Call a method like: Class::method(const BSONObj& value) indented_writer.write_line( - common.template_args('const BSONObj localObject = ${object_instance}.Obj();', - object_instance=object_instance)) + common.template_args( + "const BSONObj localObject = ${object_instance}.Obj();", + object_instance=object_instance, + ) + ) return "localObject" # Just pass the BSONObj through without trying to parse it. - return common.template_args('${object_instance}.Obj()', object_instance=object_instance) + return common.template_args( + "${object_instance}.Obj()", object_instance=object_instance + ) def has_serializer(self): # type: () -> bool return self._ast_type.serializer is not None - def gen_serializer_expression(self, indented_writer, expression, should_shapify=False, - is_catalog_ctxt=False): + def gen_serializer_expression( + self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False + ): # type: (writer.IndentedTextWriter, str, bool, bool) -> str method_name = writer.get_method_name(self._ast_type.serializer) function_arguments = [] # SerializationContext is tied to tenant deserialization if self._ast_type.deserialize_with_tenant: - function_arguments.append('getSerializationContext()') + function_arguments.append("getSerializationContext()") # Provide options if custom shapification required. if should_shapify: - function_arguments.append('options') + function_arguments.append("options") indented_writer.write_line( common.template_args( - 'const BSONObj localObject = ${expression}.${method_name}(${function_arguments});', - expression=expression, method_name=method_name, - function_arguments=', '.join(function_arguments))) + "const BSONObj localObject = ${expression}.${method_name}(${function_arguments});", + expression=expression, + method_name=method_name, + function_arguments=", ".join(function_arguments), + ) + ) return "localObject" @@ -673,25 +732,34 @@ class _ArrayBsonCppTypeBase(BsonCppTypeBase): # type: (writer.IndentedTextWriter, str) -> str if self._ast_type.deserializer: indented_writer.write_line( - common.template_args('BSONArray localArray(${object_instance}.Obj());', - object_instance=object_instance)) + common.template_args( + "BSONArray localArray(${object_instance}.Obj());", + object_instance=object_instance, + ) + ) return "localArray" # Just pass the BSONObj through without trying to parse it. - return common.template_args('BSONArray(${object_instance}.Obj())', - object_instance=object_instance) + return common.template_args( + "BSONArray(${object_instance}.Obj())", object_instance=object_instance + ) def has_serializer(self): # type: () -> bool return self._ast_type.serializer is not None - def gen_serializer_expression(self, indented_writer, expression, should_shapify=False, - is_catalog_ctxt=False): + def gen_serializer_expression( + self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False + ): # type: (writer.IndentedTextWriter, str, bool, bool) -> str method_name = writer.get_method_name(self._ast_type.serializer) indented_writer.write_line( - common.template_args('BSONArray localArray(${expression}.${method_name}());', - expression=expression, method_name=method_name)) + common.template_args( + "BSONArray localArray(${expression}.${method_name}());", + expression=expression, + method_name=method_name, + ) + ) return "localArray" @@ -700,32 +768,45 @@ class _BinDataBsonCppTypeBase(BsonCppTypeBase): def gen_deserializer_expression(self, indented_writer, object_instance): # type: (writer.IndentedTextWriter, str) -> str - if self._ast_type.bindata_subtype == 'uuid': - return common.template_args('uassertStatusOK(UUID::parse(${object_instance}))', - object_instance=object_instance) - return common.template_args('${object_instance}._binDataVector()', - object_instance=object_instance) + if self._ast_type.bindata_subtype == "uuid": + return common.template_args( + "uassertStatusOK(UUID::parse(${object_instance}))", + object_instance=object_instance, + ) + return common.template_args( + "${object_instance}._binDataVector()", object_instance=object_instance + ) def has_serializer(self): # type: () -> bool return True - def gen_serializer_expression(self, indented_writer, expression, should_shapify=False, - is_catalog_ctxt=False): + def gen_serializer_expression( + self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False + ): # type: (writer.IndentedTextWriter, str, bool, bool) -> str if self._ast_type.serializer: method_name = writer.get_method_name(self._ast_type.serializer) indented_writer.write_line( - common.template_args('ConstDataRange tempCDR = ${expression}.${method_name}();', - expression=expression, method_name=method_name)) + common.template_args( + "ConstDataRange tempCDR = ${expression}.${method_name}();", + expression=expression, + method_name=method_name, + ) + ) else: indented_writer.write_line( - common.template_args('ConstDataRange tempCDR(${expression});', - expression=expression)) + common.template_args( + "ConstDataRange tempCDR(${expression});", expression=expression + ) + ) return common.template_args( - 'BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})', - bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype)) + "BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})", + bindata_subtype=bson.cpp_bindata_subtype_type_name( + self._ast_type.bindata_subtype + ), + ) # For some types, we want to support custom serialization but defer most of the serialization to @@ -739,19 +820,19 @@ def get_bson_cpp_type(ast_type): if len(ast_type.bson_serialization_type) > 1: return None - if ast_type.bson_serialization_type[0] == 'string': + if ast_type.bson_serialization_type[0] == "string": return _CommonBsonCppTypeBase(ast_type, "valueStringData") - if ast_type.bson_serialization_type[0] == 'object': + if ast_type.bson_serialization_type[0] == "object": return _ObjectBsonCppTypeBase(ast_type) - if ast_type.bson_serialization_type[0] == 'array': + if ast_type.bson_serialization_type[0] == "array": return _ArrayBsonCppTypeBase(ast_type) - if ast_type.bson_serialization_type[0] == 'bindata': + if ast_type.bson_serialization_type[0] == "bindata": return _BinDataBsonCppTypeBase(ast_type) - if ast_type.bson_serialization_type[0] == 'int': + if ast_type.bson_serialization_type[0] == "int": return _CommonBsonCppTypeBase(ast_type, "_numberInt") # Unsupported type diff --git a/buildscripts/idl/idl/enum_types.py b/buildscripts/idl/idl/enum_types.py index dbf676d0a50..ce0d7304705 100644 --- a/buildscripts/idl/idl/enum_types.py +++ b/buildscripts/idl/idl/enum_types.py @@ -52,7 +52,9 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta): def get_qualified_cpp_type_name(self): # type: () -> str """Get the fully qualified C++ type name for an enum.""" - return common.qualify_cpp_name(self._enum.cpp_namespace, self.get_cpp_type_name()) + return common.qualify_cpp_name( + self._enum.cpp_namespace, self.get_cpp_type_name() + ) @abstractmethod def get_cpp_type_name(self): @@ -69,32 +71,37 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta): def _get_enum_deserializer_name(self): # type: () -> str """Return the name of deserializer function without prefix.""" - return common.template_args("${enum_name}_parse", - enum_name=common.title_case(self._enum.name)) + return common.template_args( + "${enum_name}_parse", enum_name=common.title_case(self._enum.name) + ) def get_enum_deserializer_name(self): # type: () -> str """Return the name of deserializer function with non-method prefix.""" - return "::" + common.qualify_cpp_name(self._enum.cpp_namespace, - self._get_enum_deserializer_name()) + return "::" + common.qualify_cpp_name( + self._enum.cpp_namespace, self._get_enum_deserializer_name() + ) def _get_enum_serializer_name(self): # type: () -> str """Return the name of serializer function without prefix.""" - return common.template_args("${enum_name}_serializer", - enum_name=common.title_case(self._enum.name)) + return common.template_args( + "${enum_name}_serializer", enum_name=common.title_case(self._enum.name) + ) def get_enum_serializer_name(self): # type: () -> str """Return the name of serializer function with non-method prefix.""" - return "::" + common.qualify_cpp_name(self._enum.cpp_namespace, - self._get_enum_serializer_name()) + return "::" + common.qualify_cpp_name( + self._enum.cpp_namespace, self._get_enum_serializer_name() + ) def _get_enum_extra_data_name(self): # type: () -> str """Return the name of the get_extra_data function without prefix.""" - return common.template_args("${enum_name}_get_extra_data", - enum_name=common.title_case(self._enum.name)) + return common.template_args( + "${enum_name}_get_extra_data", enum_name=common.title_case(self._enum.name) + ) @abstractmethod def get_cpp_value_assignment(self, enum_value): @@ -137,9 +144,11 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta): if len(self._get_populated_extra_values()) == 0: return None - return common.template_args("BSONObj ${function_name}(${enum_name} value)", - enum_name=self.get_cpp_type_name(), - function_name=self._get_enum_extra_data_name()) + return common.template_args( + "BSONObj ${function_name}(${enum_name} value)", + enum_name=self.get_cpp_type_name(), + function_name=self._get_enum_extra_data_name(), + ) def gen_extra_data_definition(self, indented_writer): # type: (writer.IndentedTextWriter) -> None @@ -152,43 +161,56 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta): # Generate an anonymous namespace full of BSON constants. # - with writer.NamespaceScopeBlock(indented_writer, ['']): + with writer.NamespaceScopeBlock(indented_writer, [""]): for enum_value in extra_values: indented_writer.write_line( - common.template_args('// %s' % json.dumps(enum_value.extra_data))) + common.template_args("// %s" % json.dumps(enum_value.extra_data)) + ) - bson_value = ''.join( - [('\\x%02x' % (b)) for b in bson.BSON.encode(enum_value.extra_data)]) + bson_value = "".join( + [("\\x%02x" % (b)) for b in bson.BSON.encode(enum_value.extra_data)] + ) indented_writer.write_line( common.template_args( 'const BSONObj ${const_name}("${bson_value}");', - const_name=_get_constant_enum_extra_data_name(self._enum, enum_value), - bson_value=bson_value)) + const_name=_get_constant_enum_extra_data_name( + self._enum, enum_value + ), + bson_value=bson_value, + ) + ) indented_writer.write_empty_line() # Generate implementation of get_extra_data function. # template_params = { - 'enum_name': self.get_cpp_type_name(), - 'function_name': self.get_extra_data_declaration(), + "enum_name": self.get_cpp_type_name(), + "function_name": self.get_extra_data_declaration(), } with writer.TemplateContext(indented_writer, template_params): with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"): - with writer.IndentedScopedBlock(indented_writer, "switch (value) {", "}"): + with writer.IndentedScopedBlock( + indented_writer, "switch (value) {", "}" + ): for enum_value in extra_values: indented_writer.write_template( - 'case ${enum_name}::%s: return %s;' % - (enum_value.name, - _get_constant_enum_extra_data_name(self._enum, enum_value))) + "case ${enum_name}::%s: return %s;" + % ( + enum_value.name, + _get_constant_enum_extra_data_name( + self._enum, enum_value + ), + ) + ) if len(extra_values) != len(self._enum.values): # One or more enums does not have associated extra data. - indented_writer.write_line('default: return BSONObj();') + indented_writer.write_line("default: return BSONObj();") if len(extra_values) == len(self._enum.values): # All enum cases handled, the compiler should know this. - indented_writer.write_line('MONGO_UNREACHABLE;') + indented_writer.write_line("MONGO_UNREACHABLE;") class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta): @@ -210,17 +232,21 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta): # type: () -> str return common.template_args( "${enum_name} ${function_name}(const IDLParserContext& ctxt, std::int32_t value)", - enum_name=self.get_cpp_type_name(), function_name=self._get_enum_deserializer_name()) + enum_name=self.get_cpp_type_name(), + function_name=self._get_enum_deserializer_name(), + ) def gen_deserializer_definition(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - enum_values = sorted(cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value)) + enum_values = sorted( + cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value) + ) template_params = { - 'enum_name': self.get_cpp_type_name(), - 'function_name': self.get_deserializer_declaration(), - 'min_value': enum_values[0].name, - 'max_value': enum_values[-1].name, + "enum_name": self.get_cpp_type_name(), + "function_name": self.get_deserializer_declaration(), + "min_value": enum_values[0].name, + "max_value": enum_values[-1].name, } with writer.TemplateContext(indented_writer, template_params): @@ -235,33 +261,41 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta): } else { return static_cast<${enum_name}>(value); } - """)) + """) + ) def get_serializer_declaration(self): # type: () -> str """Get the serializer function declaration minus trailing semicolon.""" - return common.template_args("std::int32_t ${function_name}(${enum_name} value)", - enum_name=self.get_cpp_type_name(), - function_name=self._get_enum_serializer_name()) + return common.template_args( + "std::int32_t ${function_name}(${enum_name} value)", + enum_name=self.get_cpp_type_name(), + function_name=self._get_enum_serializer_name(), + ) def gen_serializer_definition(self, indented_writer): # type: (writer.IndentedTextWriter) -> None """Generate the serializer function definition.""" template_params = { - 'enum_name': self.get_cpp_type_name(), - 'function_name': self.get_serializer_declaration(), + "enum_name": self.get_cpp_type_name(), + "function_name": self.get_serializer_declaration(), } with writer.TemplateContext(indented_writer, template_params): with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"): - indented_writer.write_template('return static_cast(value);') + indented_writer.write_template( + "return static_cast(value);" + ) def _get_constant_enum_extra_data_name(idl_enum, enum_value): # type: (Union[syntax.Enum,ast.Enum], Union[syntax.EnumValue,ast.EnumValue]) -> str """Return the C++ name for a string constant of enum extra data value.""" - return common.template_args('k${enum_name}_${name}_extra_data', - enum_name=common.title_case(idl_enum.name), name=enum_value.name) + return common.template_args( + "k${enum_name}_${name}_extra_data", + enum_name=common.title_case(idl_enum.name), + name=enum_value.name, + ) class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta): @@ -269,8 +303,9 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta): def get_cpp_type_name(self): # type: () -> str - return common.template_args("${enum_name}Enum", - enum_name=common.title_case(self._enum.name)) + return common.template_args( + "${enum_name}Enum", enum_name=common.title_case(self._enum.name) + ) def get_bson_types(self): # type: () -> List[str] @@ -278,65 +313,76 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta): def get_cpp_value_assignment(self, enum_value): # type: (ast.EnumValue) -> str - return '' + return "" def get_deserializer_declaration(self): # type: () -> str cpp_type = self.get_cpp_type_name() func = self._get_enum_deserializer_name() - return f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)' + return f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)" def gen_deserializer_definition(self, indented_writer): # type: (writer.IndentedTextWriter) -> None cpp_type = self.get_cpp_type_name() func = self._get_enum_deserializer_name() - with writer.NamespaceScopeBlock(indented_writer, ['']): - with writer.IndentedScopedBlock(indented_writer, - f'constexpr std::array {cpp_type}_values{{', '};'): + with writer.NamespaceScopeBlock(indented_writer, [""]): + with writer.IndentedScopedBlock( + indented_writer, f"constexpr std::array {cpp_type}_values{{", "};" + ): for e in self._enum.values: - indented_writer.write_line(f'{cpp_type}::{e.name},') - with writer.IndentedScopedBlock(indented_writer, - f'constexpr std::array {cpp_type}_names{{', '};'): + indented_writer.write_line(f"{cpp_type}::{e.name},") + with writer.IndentedScopedBlock( + indented_writer, f"constexpr std::array {cpp_type}_names{{", "};" + ): for e in self._enum.values: indented_writer.write_line(f'"{e.value}"_sd,') indented_writer.write_empty_line() with writer.IndentedScopedBlock( + indented_writer, + f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{", + "}", + ): + indented_writer.write_line( + f"static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};" + ) + indented_writer.write_line( + f"auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};" + ) + writer.gen_string_table_find_function_block( indented_writer, - f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{', '}'): - indented_writer.write_line( - f'static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};') - indented_writer.write_line( - f'auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};') - writer.gen_string_table_find_function_block(indented_writer, 'value', 'onMatch({})', - 'onFail()', - [e.value for e in self._enum.values]) + "value", + "onMatch({})", + "onFail()", + [e.value for e in self._enum.values], + ) def get_serializer_declaration(self): # type: () -> str """Get the serializer function declaration minus trailing semicolon.""" cpp_type = self.get_cpp_type_name() func = self._get_enum_serializer_name() - return f'StringData {func}({cpp_type} value)' + return f"StringData {func}({cpp_type} value)" def gen_serializer_definition(self, indented_writer): # type: (writer.IndentedTextWriter) -> None """Generate the serializer function definition.""" func = self._get_enum_serializer_name() cpp_type = self.get_cpp_type_name() - with writer.IndentedScopedBlock(indented_writer, f'StringData {func}({cpp_type} value) {{', - '}'): - indented_writer.write_line('auto idx = static_cast(value);') - indented_writer.write_line(f'invariant(idx < {cpp_type}_names.size());') - indented_writer.write_line(f'return {cpp_type}_names[idx];') + with writer.IndentedScopedBlock( + indented_writer, f"StringData {func}({cpp_type} value) {{", "}" + ): + indented_writer.write_line("auto idx = static_cast(value);") + indented_writer.write_line(f"invariant(idx < {cpp_type}_names.size());") + indented_writer.write_line(f"return {cpp_type}_names[idx];") def get_type_info(idl_enum): # type: (Union[syntax.Enum,ast.Enum]) -> Optional[EnumTypeInfoBase] """Get the type information for a given enumeration type, return None if not supported.""" - if idl_enum.type == 'int': + if idl_enum.type == "int": return _EnumTypeInt(idl_enum) - elif idl_enum.type == 'string': + elif idl_enum.type == "string": return _EnumTypeString(idl_enum) return None diff --git a/buildscripts/idl/idl/errors.py b/buildscripts/idl/idl/errors.py index 9b2fa218538..318a8ec084c 100644 --- a/buildscripts/idl/idl/errors.py +++ b/buildscripts/idl/idl/errors.py @@ -173,8 +173,13 @@ class ParserError(common.SourceLocation): Example error message: test.idl: (17, 4): ID0008: Unknown IDL node 'cpp_namespac' for YAML entity 'global'. """ - msg = "%s: (%d, %d): %s: %s" % (os.path.basename(self.file_name), self.line, self.column, - self.error_id, self.msg) + msg = "%s: (%d, %d): %s: %s" % ( + os.path.basename(self.file_name), + self.line, + self.column, + self.error_id, + self.msg, + ) return msg # type: ignore @@ -190,7 +195,10 @@ class ParserErrorCollection(object): # type: (common.SourceLocation, str, str) -> None """Add an error message with file (line, column) information.""" self._errors.append( - ParserError(error_id, msg, location.file_name, location.line, location.column)) + ParserError( + error_id, msg, location.file_name, location.line, location.column + ) + ) def has_errors(self): # type: () -> bool @@ -223,7 +231,7 @@ class ParserErrorCollection(object): def __str__(self): # type: () -> str """Return a list of errors.""" - return ', '.join(self.to_list()) # type: ignore + return ", ".join(self.to_list()) # type: ignore class ParserContext(object): @@ -254,56 +262,82 @@ class ParserContext(object): # type: (yaml.nodes.Node, str, str) -> None """Add an error with source location information based on a YAML node.""" self.errors.add( - common.SourceLocation(self.file_name, node.start_mark.line, node.start_mark.column), - error_id, msg) + common.SourceLocation( + self.file_name, node.start_mark.line, node.start_mark.column + ), + error_id, + msg, + ) def add_unknown_root_node_error(self, node): # type: (yaml.nodes.Node) -> None """Add an error about an unknown YAML root node.""" self._add_node_error( - node, ERROR_ID_UNKNOWN_ROOT, - ("Unrecognized IDL specification root level node '%s', only " + - " (global, import, types, commands, and structs) are accepted") % (node.value)) + node, + ERROR_ID_UNKNOWN_ROOT, + ( + "Unrecognized IDL specification root level node '%s', only " + + " (global, import, types, commands, and structs) are accepted" + ) + % (node.value), + ) def add_unknown_node_error(self, node, name): # type: (yaml.nodes.Node, str) -> None """Add an error about an unknown node.""" - self._add_node_error(node, ERROR_ID_UNKNOWN_NODE, - "Unknown IDL node '%s' for YAML entity '%s'" % (node.value, name)) + self._add_node_error( + node, + ERROR_ID_UNKNOWN_NODE, + "Unknown IDL node '%s' for YAML entity '%s'" % (node.value, name), + ) - def add_duplicate_symbol_error(self, location, name, duplicate_class_name, original_class_name): + def add_duplicate_symbol_error( + self, location, name, duplicate_class_name, original_class_name + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a duplicate symbol.""" self._add_error( - location, ERROR_ID_DUPLICATE_SYMBOL, "%s '%s' is a duplicate symbol of an existing %s" % - (duplicate_class_name, name, original_class_name)) + location, + ERROR_ID_DUPLICATE_SYMBOL, + "%s '%s' is a duplicate symbol of an existing %s" + % (duplicate_class_name, name, original_class_name), + ) def add_unknown_type_error(self, location, field_name, type_name): # type: (common.SourceLocation, str, str) -> None """Add an error about an unknown type.""" - self._add_error(location, ERROR_ID_UNKNOWN_TYPE, - "'%s' is an unknown type for field '%s'" % (type_name, field_name)) + self._add_error( + location, + ERROR_ID_UNKNOWN_TYPE, + "'%s' is an unknown type for field '%s'" % (type_name, field_name), + ) def add_unknown_command_type_error(self, location, struct_name): # type: (common.SourceLocation, str) -> None """Add an error about an unknown command type.""" - self._add_error(location, ERROR_ID_UNKNOWN_TYPE, - "'%s' is an unknown command type." % (struct_name)) + self._add_error( + location, + ERROR_ID_UNKNOWN_TYPE, + "'%s' is an unknown command type." % (struct_name), + ) def add_unknown_symbol_error(self, location, symbol_name): # type: (common.SourceLocation, str) -> None """Add an error about an unknown symbol.""" - self._add_error(location, ERROR_ID_UNKNOWN_TYPE, - "'%s' is an unknown symbol" % (symbol_name)) + self._add_error( + location, ERROR_ID_UNKNOWN_TYPE, "'%s' is an unknown symbol" % (symbol_name) + ) def _is_node_type(self, node, node_name, expected_node_type): # type: (Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode], str, str) -> bool """Return True if the yaml node type is expected, otherwise returns False and logs an error.""" if not node.id == expected_node_type: self._add_node_error( - node, ERROR_ID_IS_NODE_TYPE, - "Illegal YAML node type '%s' for '%s', expected YAML node type '%s'" % - (node.id, node_name, expected_node_type)) + node, + ERROR_ID_IS_NODE_TYPE, + "Illegal YAML node type '%s' for '%s', expected YAML node type '%s'" + % (node.id, node_name, expected_node_type), + ) return False return True @@ -342,9 +376,11 @@ class ParserContext(object): """Return True if the YAML node is a Scalar or Sequence.""" if not node.id == "scalar" and not node.id == "sequence": self._add_node_error( - node, ERROR_ID_IS_NODE_TYPE_SCALAR_OR_SEQUENCE, + node, + ERROR_ID_IS_NODE_TYPE_SCALAR_OR_SEQUENCE, "Illegal node type '%s' for '%s', expected either node type 'scalar' or 'sequence'" - % (node.id, node_name)) + % (node.id, node_name), + ) return False if node.id == "sequence": @@ -357,9 +393,11 @@ class ParserContext(object): """Return True if the YAML node is a Scalar or Mapping.""" if not node.id == "scalar" and not node.id == "mapping": self._add_node_error( - node, ERROR_ID_IS_NODE_TYPE_SCALAR_OR_MAPPING, - "Illegal node type '%s' for '%s', expected either node type 'scalar' or 'mapping'" % - (node.id, node_name)) + node, + ERROR_ID_IS_NODE_TYPE_SCALAR_OR_MAPPING, + "Illegal node type '%s' for '%s', expected either node type 'scalar' or 'mapping'" + % (node.id, node_name), + ) return False return True @@ -372,8 +410,11 @@ class ParserContext(object): if node.value not in ["true", "false"]: self._add_node_error( - node, ERROR_ID_IS_NODE_VALID_BOOL, - "Illegal bool value for '%s', expected either 'true' or 'false'." % node_name) + node, + ERROR_ID_IS_NODE_VALID_BOOL, + "Illegal bool value for '%s', expected either 'true' or 'false'." + % node_name, + ) return False return True @@ -391,8 +432,11 @@ class ParserContext(object): # type: (Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> bool boolean_value = yaml.safe_load(node.value) if not isinstance(boolean_value, bool): - self._add_node_error(node, ERROR_ID_IS_NODE_VALID_BOOL, - "Illegal bool value, expected either 'true' or 'false'.") + self._add_node_error( + node, + ERROR_ID_IS_NODE_VALID_BOOL, + "Illegal bool value, expected either 'true' or 'false'.", + ) return boolean_value def get_list(self, node): @@ -407,330 +451,532 @@ class ParserContext(object): def add_duplicate_error(self, node, node_name): # type: (yaml.nodes.Node, str) -> None """Add an error about a duplicate node.""" - self._add_node_error(node, ERROR_ID_DUPLICATE_NODE, - "Duplicate node found for '%s'" % (node_name)) + self._add_node_error( + node, ERROR_ID_DUPLICATE_NODE, "Duplicate node found for '%s'" % (node_name) + ) def add_missing_required_field_error(self, node, node_parent, node_name): # type: (yaml.nodes.Node, str, str) -> None """Add an error about a YAML node missing a required child.""" self._add_node_error( - node, ERROR_ID_MISSING_REQUIRED_FIELD, - "IDL node '%s' is missing required scalar '%s'" % (node_parent, node_name)) + node, + ERROR_ID_MISSING_REQUIRED_FIELD, + "IDL node '%s' is missing required scalar '%s'" % (node_parent, node_name), + ) - def add_missing_ast_required_field_error(self, location, ast_type, ast_parent, ast_name): + def add_missing_ast_required_field_error( + self, location, ast_type, ast_parent, ast_name + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a AST node missing a required child.""" self._add_error( - location, ERROR_ID_MISSING_AST_REQUIRED_FIELD, - "%s '%s' is missing required scalar '%s'" % (ast_type, ast_parent, ast_name)) + location, + ERROR_ID_MISSING_AST_REQUIRED_FIELD, + "%s '%s' is missing required scalar '%s'" + % (ast_type, ast_parent, ast_name), + ) def add_array_not_valid_error(self, location, ast_type, name): # type: (common.SourceLocation, str, str) -> None """Add an error about a 'array' not being a valid type name.""" - self._add_error(location, ERROR_ID_ARRAY_NOT_VALID_TYPE, - "The %s '%s' cannot be named 'array'" % (ast_type, name)) + self._add_error( + location, + ERROR_ID_ARRAY_NOT_VALID_TYPE, + "The %s '%s' cannot be named 'array'" % (ast_type, name), + ) def add_bad_bson_type_error(self, location, ast_type, ast_parent, bson_type_name): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a bad bson type.""" self._add_error( - location, ERROR_ID_BAD_BSON_TYPE, "BSON Type '%s' is not recognized for %s '%s'." % - (bson_type_name, ast_type, ast_parent)) + location, + ERROR_ID_BAD_BSON_TYPE, + "BSON Type '%s' is not recognized for %s '%s'." + % (bson_type_name, ast_type, ast_parent), + ) - def add_bad_bson_scalar_type_error(self, location, ast_type, ast_parent, bson_type_name): + def add_bad_bson_scalar_type_error( + self, location, ast_type, ast_parent, bson_type_name + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a bad list of bson types.""" - self._add_error(location, ERROR_ID_BAD_BSON_TYPE_LIST, - ("BSON Type '%s' is not a scalar bson type for %s '%s'" + - " and cannot be used in a list of bson serialization types.") % - (bson_type_name, ast_type, ast_parent)) + self._add_error( + location, + ERROR_ID_BAD_BSON_TYPE_LIST, + ( + "BSON Type '%s' is not a scalar bson type for %s '%s'" + + " and cannot be used in a list of bson serialization types." + ) + % (bson_type_name, ast_type, ast_parent), + ) - def add_bad_bson_bindata_subtype_error(self, location, ast_type, ast_parent, bson_type_name): + def add_bad_bson_bindata_subtype_error( + self, location, ast_type, ast_parent, bson_type_name + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a bindata_subtype associated with a type that is not bindata.""" - self._add_error(location, ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_TYPE, - ("The bindata_subtype field for %s '%s' is not valid for bson type '%s'") % - (ast_type, ast_parent, bson_type_name)) + self._add_error( + location, + ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_TYPE, + ("The bindata_subtype field for %s '%s' is not valid for bson type '%s'") + % (ast_type, ast_parent, bson_type_name), + ) - def add_bad_bson_bindata_subtype_value_error(self, location, ast_type, ast_parent, value): + def add_bad_bson_bindata_subtype_value_error( + self, location, ast_type, ast_parent, value + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about a bad value for bindata_subtype.""" - self._add_error(location, ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, - ("The bindata_subtype field's value '%s' for %s '%s' is not valid") % - (value, ast_type, ast_parent)) + self._add_error( + location, + ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, + ("The bindata_subtype field's value '%s' for %s '%s' is not valid") + % (value, ast_type, ast_parent), + ) def add_bad_setat_specifier(self, location, specifier): # type: (common.SourceLocation, str) -> None """Add an error about a bad set_at specifier.""" self._add_error( - location, ERROR_ID_BAD_SETAT_SPECIFIER, - ("Unexpected set_at specifier: '%s', expected 'startup' or 'runtime'") % (specifier)) + location, + ERROR_ID_BAD_SETAT_SPECIFIER, + ("Unexpected set_at specifier: '%s', expected 'startup' or 'runtime'") + % (specifier), + ) def add_no_string_data_error(self, location, ast_type, ast_parent): # type: (common.SourceLocation, str, str) -> None """Add an error about using StringData for cpp_type.""" - self._add_error(location, ERROR_ID_NO_STRINGDATA, - ("Do not use mongo::StringData for %s '%s', use std::string instead") % - (ast_type, ast_parent)) + self._add_error( + location, + ERROR_ID_NO_STRINGDATA, + ("Do not use mongo::StringData for %s '%s', use std::string instead") + % (ast_type, ast_parent), + ) def add_ignored_field_must_be_empty_error(self, location, name, field_name): # type: (common.SourceLocation, str, str) -> None """Add an error about field must be empty for ignored fields.""" self._add_error( - location, ERROR_ID_FIELD_MUST_BE_EMPTY_FOR_IGNORED, - ("Field '%s' cannot contain a value for property '%s' when a field is marked as ignored" - ) % (name, field_name)) + location, + ERROR_ID_FIELD_MUST_BE_EMPTY_FOR_IGNORED, + ( + "Field '%s' cannot contain a value for property '%s' when a field is marked as ignored" + ) + % (name, field_name), + ) def add_struct_default_must_be_true_or_empty_error(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about default must be True or empty for fields of type struct.""" - self._add_error(location, ERROR_ID_DEFAULT_MUST_BE_TRUE_OR_EMPTY_FOR_STRUCT, ( - "Field '%s' can only contain value 'true' for property 'default' when a field's type is a struct" - ) % (name)) + self._add_error( + location, + ERROR_ID_DEFAULT_MUST_BE_TRUE_OR_EMPTY_FOR_STRUCT, + ( + "Field '%s' can only contain value 'true' for property 'default' when a field's type is a struct" + ) + % (name), + ) def add_not_custom_scalar_serialization_not_supported_error( # pylint: disable=invalid-name - self, location, ast_type, ast_parent, bson_type_name): + self, location, ast_type, ast_parent, bson_type_name + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about field must be empty for fields of type struct.""" self._add_error( - location, ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED, - ("Custom serialization for a scalar is only supported for 'string'. The %s '%s' cannot" - + " use bson type '%s', use a bson_serialization_type of 'any' instead.") % - (ast_type, ast_parent, bson_type_name)) + location, + ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED, + ( + "Custom serialization for a scalar is only supported for 'string'. The %s '%s' cannot" + + " use bson type '%s', use a bson_serialization_type of 'any' instead." + ) + % (ast_type, ast_parent, bson_type_name), + ) def add_bad_any_type_use_error(self, location, bson_type, ast_type, ast_parent): # type: (common.SourceLocation, str, str, str) -> None """Add an error about any being used in a list of bson types.""" self._add_error( - location, ERROR_ID_BAD_ANY_TYPE_USE, - ("The BSON Type '%s' is not allowed in a list of bson serialization types for" + - "%s '%s'. It must be only a single bson type.") % (bson_type, ast_type, ast_parent)) + location, + ERROR_ID_BAD_ANY_TYPE_USE, + ( + "The BSON Type '%s' is not allowed in a list of bson serialization types for" + + "%s '%s'. It must be only a single bson type." + ) + % (bson_type, ast_type, ast_parent), + ) - def add_bad_cpp_numeric_type_use_error(self, location, ast_type, ast_parent, cpp_type): + def add_bad_cpp_numeric_type_use_error( + self, location, ast_type, ast_parent, cpp_type + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about any being used in a list of bson types.""" self._add_error( - location, ERROR_ID_BAD_NUMERIC_CPP_TYPE, - ("The C++ numeric type '%s' is not allowed for %s '%s'. Only 'std::int32_t'," + - " 'std::uint32_t', 'std::uint64_t', and 'std::int64_t' are supported.") % - (cpp_type, ast_type, ast_parent)) + location, + ERROR_ID_BAD_NUMERIC_CPP_TYPE, + ( + "The C++ numeric type '%s' is not allowed for %s '%s'. Only 'std::int32_t'," + + " 'std::uint32_t', 'std::uint64_t', and 'std::int64_t' are supported." + ) + % (cpp_type, ast_type, ast_parent), + ) def add_bad_array_type_name_error(self, location, field_name, type_name): # type: (common.SourceLocation, str, str) -> None """Add an error about a field type having a malformed type name.""" - self._add_error(location, ERROR_ID_BAD_ARRAY_TYPE_NAME, - ("'%s' is not a valid array type for field '%s'. A valid array type" + - " is in the form 'array'.") % (type_name, field_name)) + self._add_error( + location, + ERROR_ID_BAD_ARRAY_TYPE_NAME, + ( + "'%s' is not a valid array type for field '%s'. A valid array type" + + " is in the form 'array'." + ) + % (type_name, field_name), + ) def add_array_no_default_error(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about an array having a type with a default value.""" self._add_error( - location, ERROR_ID_ARRAY_NO_DEFAULT, - "Field '%s' is not allowed to have both a default value and be an array type" % - (field_name)) + location, + ERROR_ID_ARRAY_NO_DEFAULT, + "Field '%s' is not allowed to have both a default value and be an array type" + % (field_name), + ) def add_cannot_find_import(self, location, imported_file_name): # type: (common.SourceLocation, str) -> None """Add an error about not being able to find an import.""" - self._add_error(location, ERROR_ID_BAD_IMPORT, - "Could not resolve import '%s', file not found" % (imported_file_name)) + self._add_error( + location, + ERROR_ID_BAD_IMPORT, + "Could not resolve import '%s', file not found" % (imported_file_name), + ) def add_bindata_no_default(self, location, ast_type, ast_parent): # type: (common.SourceLocation, str, str) -> None """Add an error about a bindata type with a default value.""" - self._add_error(location, ERROR_ID_BAD_BINDATA_DEFAULT, - ("Default values are not allowed for %s '%s'") % (ast_type, ast_parent)) + self._add_error( + location, + ERROR_ID_BAD_BINDATA_DEFAULT, + ("Default values are not allowed for %s '%s'") % (ast_type, ast_parent), + ) def add_chained_type_not_found_error(self, location, type_name): # type: (common.SourceLocation, str) -> None """Add an error about a chained_type not found.""" - self._add_error(location, ERROR_ID_CHAINED_TYPE_NOT_FOUND, - ("Type '%s' is not a valid chained type") % (type_name)) + self._add_error( + location, + ERROR_ID_CHAINED_TYPE_NOT_FOUND, + ("Type '%s' is not a valid chained type") % (type_name), + ) def add_chained_type_wrong_type_error(self, location, type_name, bson_type_name): # type: (common.SourceLocation, str, str) -> None """Add an error about a chained_type being the wrong type.""" - self._add_error(location, ERROR_ID_CHAINED_TYPE_WRONG_BSON_TYPE, - ("Chained Type '%s' has the wrong bson serialization type '%s', only" + - "'chain' is supported for chained types.") % (type_name, bson_type_name)) + self._add_error( + location, + ERROR_ID_CHAINED_TYPE_WRONG_BSON_TYPE, + ( + "Chained Type '%s' has the wrong bson serialization type '%s', only" + + "'chain' is supported for chained types." + ) + % (type_name, bson_type_name), + ) - def add_duplicate_field_error(self, location, field_container, field_name, duplicate_location): + def add_duplicate_field_error( + self, location, field_container, field_name, duplicate_location + ): # type: (common.SourceLocation, str, str, common.SourceLocation) -> None """Add an error about duplicate fields as a result of chained structs/types.""" self._add_error( - location, ERROR_ID_CHAINED_DUPLICATE_FIELD, - ("Chained Struct or Type '%s' duplicates an existing field '%s' at location" + "'%s'.") - % (field_container, field_name, duplicate_location)) + location, + ERROR_ID_CHAINED_DUPLICATE_FIELD, + ( + "Chained Struct or Type '%s' duplicates an existing field '%s' at location" + + "'%s'." + ) + % (field_container, field_name, duplicate_location), + ) def add_chained_type_no_strict_error(self, location, struct_name): # type: (common.SourceLocation, str) -> None """Add an error about strict parser validate and chained types.""" - self._add_error(location, ERROR_ID_CHAINED_NO_TYPE_STRICT, - ("Strict IDL parser validation is not supported with chained types for " + - "struct '%s'. Specify 'strict: false' for this struct.") % (struct_name)) + self._add_error( + location, + ERROR_ID_CHAINED_NO_TYPE_STRICT, + ( + "Strict IDL parser validation is not supported with chained types for " + + "struct '%s'. Specify 'strict: false' for this struct." + ) + % (struct_name), + ) def add_chained_struct_not_found_error(self, location, struct_name): # type: (common.SourceLocation, str) -> None """Add an error about a chained_struct not found.""" - self._add_error(location, ERROR_ID_CHAINED_STRUCT_NOT_FOUND, - ("Type '%s' is not a valid chained struct") % (struct_name)) + self._add_error( + location, + ERROR_ID_CHAINED_STRUCT_NOT_FOUND, + ("Type '%s' is not a valid chained struct") % (struct_name), + ) - def add_chained_nested_struct_no_strict_error(self, location, struct_name, nested_struct_name): + def add_chained_nested_struct_no_strict_error( + self, location, struct_name, nested_struct_name + ): # type: (common.SourceLocation, str, str) -> None """Add an error about strict parser validate and chained types.""" - self._add_error(location, ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT, - ("Strict IDL parser validation is not supported for a chained struct '%s'" + - " contained by struct '%s'. Specify 'strict: false' for this struct.") % - (nested_struct_name, struct_name)) + self._add_error( + location, + ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT, + ( + "Strict IDL parser validation is not supported for a chained struct '%s'" + + " contained by struct '%s'. Specify 'strict: false' for this struct." + ) + % (nested_struct_name, struct_name), + ) - def add_chained_nested_struct_no_nested_error(self, location, struct_name, chained_name): + def add_chained_nested_struct_no_nested_error( + self, location, struct_name, chained_name + ): # type: (common.SourceLocation, str, str) -> None """Add an error about struct's chaining being a struct with nested chaining.""" - self._add_error(location, ERROR_ID_CHAINED_NO_NESTED_CHAINED, - ("Struct '%s' is not allowed to nest struct '%s' since it has chained" + - " structs and/or types.") % (struct_name, chained_name)) + self._add_error( + location, + ERROR_ID_CHAINED_NO_NESTED_CHAINED, + ( + "Struct '%s' is not allowed to nest struct '%s' since it has chained" + + " structs and/or types." + ) + % (struct_name, chained_name), + ) def add_empty_enum_error(self, node, name): # type: (yaml.nodes.Node, str) -> None """Add an error about an enum without values.""" self._add_node_error( - node, ERROR_ID_BAD_EMPTY_ENUM, - "Enum '%s' must have values specified but no values were found" % (name)) + node, + ERROR_ID_BAD_EMPTY_ENUM, + "Enum '%s' must have values specified but no values were found" % (name), + ) def add_array_enum_error(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error for a field being an array of enums.""" - self._add_error(location, ERROR_ID_NO_ARRAY_ENUM, - "Field '%s' cannot be an array of enums" % (field_name)) + self._add_error( + location, + ERROR_ID_NO_ARRAY_ENUM, + "Field '%s' cannot be an array of enums" % (field_name), + ) def add_enum_bad_type_error(self, location, enum_name, enum_type): # type: (common.SourceLocation, str, str) -> None """Add an error for an enum having the wrong type.""" - self._add_error(location, ERROR_ID_ENUM_BAD_TYPE, - "Enum '%s' type '%s' is not a supported enum type" % (enum_name, enum_type)) + self._add_error( + location, + ERROR_ID_ENUM_BAD_TYPE, + "Enum '%s' type '%s' is not a supported enum type" % (enum_name, enum_type), + ) def add_enum_value_not_int_error(self, location, enum_name, enum_value, err_msg): # type: (common.SourceLocation, str, str, str) -> None """Add an error for an enum value not being an integer.""" self._add_error( - location, ERROR_ID_ENUM_BAD_INT_VAUE, - "Enum '%s' value '%s' is not an integer, exception '%s'" % (enum_name, enum_value, - err_msg)) + location, + ERROR_ID_ENUM_BAD_INT_VAUE, + "Enum '%s' value '%s' is not an integer, exception '%s'" + % (enum_name, enum_value, err_msg), + ) def add_enum_value_not_unique_error(self, location, enum_name): # type: (common.SourceLocation, str) -> None """Add an error for an enum having duplicate values.""" - self._add_error(location, ERROR_ID_ENUM_NON_UNIQUE_VALUES, - "Enum '%s' has duplicate values, all values must be unique" % (enum_name)) + self._add_error( + location, + ERROR_ID_ENUM_NON_UNIQUE_VALUES, + "Enum '%s' has duplicate values, all values must be unique" % (enum_name), + ) - def add_bad_command_namespace_error(self, location, command_name, command_namespace, - valid_commands): + def add_bad_command_namespace_error( + self, location, command_name, command_namespace, valid_commands + ): # type: (common.SourceLocation, str, str, List[str]) -> None """Add an error about the namespace value not being a valid choice.""" self._add_error( - location, ERROR_ID_BAD_COMMAND_NAMESPACE, + location, + ERROR_ID_BAD_COMMAND_NAMESPACE, "Command namespace '%s' for command '%s' is not a valid choice. Valid options are '%s'." - % (command_namespace, command_name, valid_commands)) + % (command_namespace, command_name, valid_commands), + ) def add_bad_command_as_field_error(self, location, command_name): # type: (common.SourceLocation, str) -> None """Add an error about using a command for a field.""" - self._add_error(location, ERROR_ID_FIELD_NO_COMMAND, - ("Command '%s' cannot be used as a field type'. Commands must be top-level" - + " types due to their serialization rules.") % (command_name)) + self._add_error( + location, + ERROR_ID_FIELD_NO_COMMAND, + ( + "Command '%s' cannot be used as a field type'. Commands must be top-level" + + " types due to their serialization rules." + ) + % (command_name), + ) def add_bad_array_of_chain(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about a field being an array of chain_types.""" - self._add_error(location, ERROR_ID_NO_ARRAY_OF_CHAIN, - "Field '%s' cannot be an array of chained types" % (field_name)) + self._add_error( + location, + ERROR_ID_NO_ARRAY_OF_CHAIN, + "Field '%s' cannot be an array of chained types" % (field_name), + ) def add_bad_field_default_and_optional(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about a field being optional and having a default value.""" self._add_error( - location, ERROR_ID_ILLEGAL_FIELD_DEFAULT_AND_OPTIONAL, - ("Field '%s' can only be marked as optional or have a default value," + " not both.") % - (field_name)) + location, + ERROR_ID_ILLEGAL_FIELD_DEFAULT_AND_OPTIONAL, + ( + "Field '%s' can only be marked as optional or have a default value," + + " not both." + ) + % (field_name), + ) - def add_bad_struct_field_as_doc_sequence_error(self, location, struct_name, field_name): + def add_bad_struct_field_as_doc_sequence_error( + self, location, struct_name, field_name + ): # type: (common.SourceLocation, str, str) -> None """Add an error about using a field in a struct being marked with supports_doc_sequence.""" - self._add_error(location, ERROR_ID_STRUCT_NO_DOC_SEQUENCE, - ("Field '%s' in struct '%s' cannot be used as a Command Document Sequence" - " type. They are only supported in commands.") % (field_name, struct_name)) + self._add_error( + location, + ERROR_ID_STRUCT_NO_DOC_SEQUENCE, + ( + "Field '%s' in struct '%s' cannot be used as a Command Document Sequence" + " type. They are only supported in commands." + ) + % (field_name, struct_name), + ) - def add_bad_non_array_as_doc_sequence_error(self, location, struct_name, field_name): + def add_bad_non_array_as_doc_sequence_error( + self, location, struct_name, field_name + ): # type: (common.SourceLocation, str, str) -> None """Add an error about using a non-array type field being marked with supports_doc_sequence.""" - self._add_error(location, ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_ARRAY, - ("Field '%s' in command '%s' cannot be used as a Command Document Sequence" - " type since it is not an array.") % (field_name, struct_name)) + self._add_error( + location, + ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_ARRAY, + ( + "Field '%s' in command '%s' cannot be used as a Command Document Sequence" + " type since it is not an array." + ) + % (field_name, struct_name), + ) def add_bad_non_object_as_doc_sequence_error(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about using a non-struct or BSON object for a doc sequence.""" - self._add_error(location, ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT, - ("Field '%s' cannot be used as a Command Document Sequence" - " type since it is not a BSON object or struct.") % (field_name)) + self._add_error( + location, + ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT, + ( + "Field '%s' cannot be used as a Command Document Sequence" + " type since it is not a BSON object or struct." + ) + % (field_name), + ) def add_bad_command_name_duplicates_field(self, location, command_name): # type: (common.SourceLocation, str) -> None """Add an error about a command and field having the same name.""" - self._add_error(location, ERROR_ID_COMMAND_DUPLICATES_FIELD, - ("Command '%s' cannot have the same name as a field.") % (command_name)) + self._add_error( + location, + ERROR_ID_COMMAND_DUPLICATES_FIELD, + ("Command '%s' cannot have the same name as a field.") % (command_name), + ) def add_bad_field_non_const_getter_in_immutable_struct_error( # pylint: disable=invalid-name - self, location, struct_name, field_name): + self, location, struct_name, field_name + ): # type: (common.SourceLocation, str, str) -> None """Add an error about marking a field with non_const_getter in an immutable struct.""" self._add_error( - location, ERROR_ID_NON_CONST_GETTER_IN_IMMUTABLE_STRUCT, - ("Cannot generate a non-const getter for field '%s' in struct '%s' since" - " struct '%s' is marked as immutable.") % (field_name, struct_name, struct_name)) + location, + ERROR_ID_NON_CONST_GETTER_IN_IMMUTABLE_STRUCT, + ( + "Cannot generate a non-const getter for field '%s' in struct '%s' since" + " struct '%s' is marked as immutable." + ) + % (field_name, struct_name, struct_name), + ) def add_useless_variant_error(self, location): # type: (common.SourceLocation) -> None """Add an error about a variant with 0 or 1 variant types.""" - self._add_error(location, ERROR_ID_USELESS_VARIANT, - ("Cannot declare a variant with only 0 or 1 variant types")) + self._add_error( + location, + ERROR_ID_USELESS_VARIANT, + ("Cannot declare a variant with only 0 or 1 variant types"), + ) def add_variant_comparison_error(self, location): # type: (common.SourceLocation) -> None """Add an error about a struct with generate_comparison_operators and a variant field.""" - self._add_error(location, ERROR_ID_VARIANT_COMPARISON, - ("generate_comparison_operators is not supported with variant types")) + self._add_error( + location, + ERROR_ID_VARIANT_COMPARISON, + ("generate_comparison_operators is not supported with variant types"), + ) def add_variant_duplicate_types_error(self, location, field_name, type_name): # type: (common.SourceLocation, str, str) -> None """Add an error about a variant having more than one alternative of the same BSON type.""" self._add_error( - location, ERROR_ID_VARIANT_DUPLICATE_TYPES, - ("Variant field '%s' has multiple alternatives with BSON type '%s', this is prohibited" - " to avoid ambiguity while parsing BSON.") % (field_name, type_name)) + location, + ERROR_ID_VARIANT_DUPLICATE_TYPES, + ( + "Variant field '%s' has multiple alternatives with BSON type '%s', this is prohibited" + " to avoid ambiguity while parsing BSON." + ) + % (field_name, type_name), + ) def add_variant_structs_error(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about a variant having more than one struct alternative.""" - self._add_error(location, ERROR_ID_VARIANT_STRUCTS, - ("Variant field '%s' has multiple struct alternatives, this is prohibited" - " to avoid ambiguity while parsing BSON subdocuments.") % (field_name, )) + self._add_error( + location, + ERROR_ID_VARIANT_STRUCTS, + ( + "Variant field '%s' has multiple struct alternatives, this is prohibited" + " to avoid ambiguity while parsing BSON subdocuments." + ) + % (field_name,), + ) def add_variant_enum_error(self, location, field_name, type_name): # type: (common.SourceLocation, str, str) -> None """Add an error for a variant that can be an enum.""" self._add_error( - location, ERROR_ID_NO_VARIANT_ENUM, - "Field '%s' cannot be a variant with an enum alternative type '%s'" % (field_name, - type_name)) + location, + ERROR_ID_NO_VARIANT_ENUM, + "Field '%s' cannot be a variant with an enum alternative type '%s'" + % (field_name, type_name), + ) def add_bad_array_variant_types_error(self, location, type_name): # type: (common.SourceLocation, str) -> None """Add an error about a field type having a malformed type name.""" - self._add_error(location, ERROR_ID_INVALID_ARRAY_VARIANT, - ("'%s' is not a valid array variant type. A valid array variant type" + - " is in the form 'array>'.") % (type_name)) + self._add_error( + location, + ERROR_ID_INVALID_ARRAY_VARIANT, + ( + "'%s' is not a valid array variant type. A valid array variant type" + + " is in the form 'array>'." + ) + % (type_name), + ) def is_scalar_non_negative_int_node(self, node, node_name): # type: (Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode], str) -> bool @@ -742,15 +988,20 @@ class ParserContext(object): value = int(node.value) if value < 0: self._add_node_error( - node, ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT, - "Illegal negative integer value for '%s', expected 0 or positive integer." % - (node_name)) + node, + ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT, + "Illegal negative integer value for '%s', expected 0 or positive integer." + % (node_name), + ) return False except ValueError as value_error: self._add_node_error( - node, ERROR_ID_IS_NODE_VALID_INT, - "Illegal integer value for '%s', message '%s'." % (node_name, value_error)) + node, + ERROR_ID_IS_NODE_VALID_INT, + "Illegal integer value for '%s', message '%s'." + % (node_name, value_error), + ) return False return True @@ -762,251 +1013,363 @@ class ParserContext(object): return int(node.value) - def add_duplicate_comparison_order_field_error(self, location, struct_name, comparison_order): + def add_duplicate_comparison_order_field_error( + self, location, struct_name, comparison_order + ): # type: (common.SourceLocation, str, int) -> None """Add an error about fields having duplicate comparison_orders.""" self._add_error( - location, ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER, - ("Struct '%s' cannot have two fields with the same comparison_order value '%d'.") % - (struct_name, comparison_order)) + location, + ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER, + ( + "Struct '%s' cannot have two fields with the same comparison_order value '%d'." + ) + % (struct_name, comparison_order), + ) def add_extranous_command_type(self, location, command_name): # type: (common.SourceLocation, str) -> None """Add an error about commands having type when not needed.""" self._add_error( - location, ERROR_ID_IS_COMMAND_TYPE_EXTRANEOUS, - ("Command '%s' cannot have a 'type' property unless namespace equals 'type'.") % - (command_name)) + location, + ERROR_ID_IS_COMMAND_TYPE_EXTRANEOUS, + ( + "Command '%s' cannot have a 'type' property unless namespace equals 'type'." + ) + % (command_name), + ) def add_value_not_numeric_error(self, location, attrname, value): # type: (common.SourceLocation, str, str) -> None """Add an error about non-numeric value where number expected.""" self._add_error( - location, ERROR_ID_VALUE_NOT_NUMERIC, - ("'%s' requires a numeric value, but %s can not be cast") % (attrname, value)) + location, + ERROR_ID_VALUE_NOT_NUMERIC, + ("'%s' requires a numeric value, but %s can not be cast") + % (attrname, value), + ) def add_server_parameter_invalid_attr(self, location, attrname, conflicts): # type: (common.SourceLocation, str, str) -> None """Add an error about invalid fields in a server parameter definition.""" self._add_error( - location, ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, - ("'%s' attribute not permitted with '%s' server parameter") % (attrname, conflicts)) + location, + ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ("'%s' attribute not permitted with '%s' server parameter") + % (attrname, conflicts), + ) - def add_server_parameter_required_attr(self, location, attrname, required, dependant=None): + def add_server_parameter_required_attr( + self, location, attrname, required, dependant=None + ): # type: (common.SourceLocation, str, str, str) -> None """Add an error about missing fields in a server parameter definition.""" - qualifier = '' if dependant is None else (" when using '%s' attribute" % (dependant)) - self._add_error(location, ERROR_ID_SERVER_PARAMETER_REQUIRED_ATTR, - ("'%s' attribute required%s with '%s' server parameter") % - (attrname, qualifier, required)) + qualifier = ( + "" if dependant is None else (" when using '%s' attribute" % (dependant)) + ) + self._add_error( + location, + ERROR_ID_SERVER_PARAMETER_REQUIRED_ATTR, + ("'%s' attribute required%s with '%s' server parameter") + % (attrname, qualifier, required), + ) def add_server_parameter_invalid_method_override(self, location, method): # type: (common.SourceLocation, str) -> None """Add an error about invalid method override in SCP method override.""" - self._add_error(location, ERROR_ID_SERVER_PARAMETER_INVALID_METHOD_OVERRIDE, - ("No such method to override in server parameter class: '%s'") % (method)) + self._add_error( + location, + ERROR_ID_SERVER_PARAMETER_INVALID_METHOD_OVERRIDE, + ("No such method to override in server parameter class: '%s'") % (method), + ) def add_bad_source_specifier(self, location, value): # type: (common.SourceLocation, str) -> None """Add an error about invalid source specifier.""" - self._add_error(location, ERROR_ID_BAD_SOURCE_SPECIFIER, - ("'%s' is not a valid source specifier") % (value)) + self._add_error( + location, + ERROR_ID_BAD_SOURCE_SPECIFIER, + ("'%s' is not a valid source specifier") % (value), + ) def add_bad_duplicate_behavior(self, location, value): # type: (common.SourceLocation, str) -> None """Add an error about invalid duplicate behavior specifier.""" - self._add_error(location, ERROR_ID_BAD_DUPLICATE_BEHAVIOR_SPECIFIER, - ("'%s' is not a valid duplicate behavior specifier") % (value)) + self._add_error( + location, + ERROR_ID_BAD_DUPLICATE_BEHAVIOR_SPECIFIER, + ("'%s' is not a valid duplicate behavior specifier") % (value), + ) def add_bad_numeric_range(self, location, attrname, value): # type: (common.SourceLocation, str, str) -> None """Add an error about invalid range specifier.""" - self._add_error(location, ERROR_ID_BAD_NUMERIC_RANGE, - ("'%s' is not a valid numeric range for '%s'") % (value, attrname)) + self._add_error( + location, + ERROR_ID_BAD_NUMERIC_RANGE, + ("'%s' is not a valid numeric range for '%s'") % (value, attrname), + ) def add_missing_shortname_for_positional_arg(self, location): # type: (common.SourceLocation) -> None """Add an error about required short_name for positional args.""" - self._add_error(location, ERROR_ID_MISSING_SHORTNAME_FOR_POSITIONAL, - "Missing 'short_name' for positional arg") + self._add_error( + location, + ERROR_ID_MISSING_SHORTNAME_FOR_POSITIONAL, + "Missing 'short_name' for positional arg", + ) def add_invalid_short_name(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about invalid short names.""" - self._add_error(location, ERROR_ID_INVALID_SHORT_NAME, - ("Invalid 'short_name' value '%s'") % (name)) + self._add_error( + location, + ERROR_ID_INVALID_SHORT_NAME, + ("Invalid 'short_name' value '%s'") % (name), + ) def add_invalid_single_name(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about invalid single names.""" - self._add_error(location, ERROR_ID_INVALID_SINGLE_NAME, - ("Invalid 'single_name' value '%s'") % (name)) + self._add_error( + location, + ERROR_ID_INVALID_SINGLE_NAME, + ("Invalid 'single_name' value '%s'") % (name), + ) def add_missing_short_name_with_single_name(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about missing required short name when using single name.""" - self._add_error(location, ERROR_ID_MISSING_SHORT_NAME_WITH_SINGLE_NAME, - ("Missing 'short_name' required with 'single_name' value '%s'") % (name)) + self._add_error( + location, + ERROR_ID_MISSING_SHORT_NAME_WITH_SINGLE_NAME, + ("Missing 'short_name' required with 'single_name' value '%s'") % (name), + ) def add_feature_flag_default_true_missing_version(self, location): # type: (common.SourceLocation) -> None """Add an error about a default flag with a default value of true and should be FCV gated but no version.""" - self._add_error(location, ERROR_ID_FEATURE_FLAG_DEFAULT_TRUE_MISSING_VERSION, ( - "Missing 'version' required for feature flag that defaults to true and should be FCV gated" - )) + self._add_error( + location, + ERROR_ID_FEATURE_FLAG_DEFAULT_TRUE_MISSING_VERSION, + ( + "Missing 'version' required for feature flag that defaults to true and should be FCV gated" + ), + ) def add_feature_flag_default_false_has_version(self, location): # type: (common.SourceLocation) -> None """Add an error about a default flag with a default value of false but has a version.""" self._add_error( - location, ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION, - ("The 'version' attribute is not allowed for feature flag that defaults to false")) + location, + ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION, + ( + "The 'version' attribute is not allowed for feature flag that defaults to false" + ), + ) def add_feature_flag_fcv_gated_false_has_version(self, location): # type: (common.SourceLocation) -> None """Add an error about a feature flag that should not be FCV gated but has a version.""" - self._add_error(location, ERROR_ID_FEATURE_FLAG_SHOULD_BE_FCV_GATED_FALSE_HAS_VERSION, ( - "The 'version' attribute is not allowed for feature flag that should not be FCV gated")) + self._add_error( + location, + ERROR_ID_FEATURE_FLAG_SHOULD_BE_FCV_GATED_FALSE_HAS_VERSION, + ( + "The 'version' attribute is not allowed for feature flag that should not be FCV gated" + ), + ) def add_reply_type_invalid_type(self, location, command_name, reply_type_name): # type: (common.SourceLocation, str, str) -> None """Add an error about a command whose reply_type refers to an unknown type.""" self._add_error( - location, ERROR_ID_INVALID_REPLY_TYPE, - ("Command '%s' has invalid reply_type '%s'" % (command_name, reply_type_name))) + location, + ERROR_ID_INVALID_REPLY_TYPE, + ( + "Command '%s' has invalid reply_type '%s'" + % (command_name, reply_type_name) + ), + ) def add_stability_no_api_version(self, location, command_name): # type: (common.SourceLocation, str) -> None """Add an error about a command with 'stability' but no 'api_version'.""" self._add_error( - location, ERROR_ID_STABILITY_NO_API_VERSION, - ("Command '%s' specifies 'stability' but has no 'api_version'" % (command_name, ))) + location, + ERROR_ID_STABILITY_NO_API_VERSION, + ( + "Command '%s' specifies 'stability' but has no 'api_version'" + % (command_name,) + ), + ) def add_missing_reply_type(self, location, command_name): # type: (common.SourceLocation, str) -> None """Add an error about a command with 'api_version' but no 'reply_type'.""" self._add_error( - location, ERROR_ID_MISSING_REPLY_TYPE, - ("Command '%s' has an 'api_version' but no 'reply_type'" % (command_name, ))) + location, + ERROR_ID_MISSING_REPLY_TYPE, + ("Command '%s' has an 'api_version' but no 'reply_type'" % (command_name,)), + ) def add_bad_field_always_serialize_not_optional(self, location, field_name): # type: (common.SourceLocation, str) -> None """Add an error about a field with 'always_serialize' but 'optional' isn't set to true.""" self._add_error( - location, ERROR_ID_ILLEGAL_FIELD_ALWAYS_SERIALIZE_NOT_OPTIONAL, - ("Field '%s' specifies 'always_serialize' but 'optional' isn't true.") % (field_name)) + location, + ERROR_ID_ILLEGAL_FIELD_ALWAYS_SERIALIZE_NOT_OPTIONAL, + ("Field '%s' specifies 'always_serialize' but 'optional' isn't true.") + % (field_name), + ) def add_duplicate_command_name_and_alias(self, node): # type: (yaml.nodes.Node) -> None """Add an error about a command name and command alias having the same name.""" - self._add_node_error(node, ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS, - "Duplicate command_name and command_alias found.") + self._add_node_error( + node, + ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS, + "Duplicate command_name and command_alias found.", + ) def add_unknown_enum_value(self, location, enum_name, enum_value): # type: (common.SourceLocation, str, str) -> None """Add an error about an unknown enum value.""" - self._add_error(location, ERROR_ID_UNKOWN_ENUM_VALUE, - "Cannot find enum value '%s' in enum '%s'." % (enum_value, enum_name)) + self._add_error( + location, + ERROR_ID_UNKOWN_ENUM_VALUE, + "Cannot find enum value '%s' in enum '%s'." % (enum_value, enum_name), + ) def add_either_check_or_privilege(self, location): # type: (common.SourceLocation) -> None """Add an error about specifing both a check and a privilege or neither.""" self._add_error( - location, ERROR_ID_EITHER_CHECK_OR_PRIVILEGE, - "Must specify either a 'check' and a 'privilege' in an access_check, not both.") + location, + ERROR_ID_EITHER_CHECK_OR_PRIVILEGE, + "Must specify either a 'check' and a 'privilege' in an access_check, not both.", + ) def add_duplicate_action_types(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about specifying an action type twice in the same list.""" - self._add_error(location, ERROR_ID_DUPLICATE_ACTION_TYPE, - "Cannot specify an action_type '%s' more then once" % (name)) + self._add_error( + location, + ERROR_ID_DUPLICATE_ACTION_TYPE, + "Cannot specify an action_type '%s' more then once" % (name), + ) def add_duplicate_access_check(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about specifying an access check twice in the same list.""" - self._add_error(location, ERROR_ID_DUPLICATE_ACCESS_CHECK, - "Cannot specify an access_check '%s' more then once" % (name)) + self._add_error( + location, + ERROR_ID_DUPLICATE_ACCESS_CHECK, + "Cannot specify an access_check '%s' more then once" % (name), + ) def add_duplicate_privilege(self, location, resource_pattern, action_type): # type: (common.SourceLocation, str, str) -> None """Add an error about specifying a privilege twice in the same list.""" self._add_error( - location, ERROR_ID_DUPLICATE_PRIVILEGE, - "Cannot specify the pair of resource_pattern '%s' and action_type '%s' more then once" % - (resource_pattern, action_type)) + location, + ERROR_ID_DUPLICATE_PRIVILEGE, + "Cannot specify the pair of resource_pattern '%s' and action_type '%s' more then once" + % (resource_pattern, action_type), + ) def add_empty_access_check(self, location): # type: (common.SourceLocation) -> None """Add an error about specifying one of ignore, none, simple or complex in an access check.""" self._add_error( - location, ERROR_ID_EMPTY_ACCESS_CHECK, - "Must one and only one of either a 'ignore', 'none', 'simple', or 'complex' in an access_check." + location, + ERROR_ID_EMPTY_ACCESS_CHECK, + "Must one and only one of either a 'ignore', 'none', 'simple', or 'complex' in an access_check.", ) def add_missing_access_check(self, location, name): # type: (common.SourceLocation, str) -> None """Add an error about a missing access_check when api_version != "".""" - self._add_error(location, ERROR_ID_MISSING_ACCESS_CHECK, - 'Command "%s" has api_version != "" but is missing access_check.' % (name)) + self._add_error( + location, + ERROR_ID_MISSING_ACCESS_CHECK, + 'Command "%s" has api_version != "" but is missing access_check.' % (name), + ) def add_stability_unknown_value(self, location): # type: (common.SourceLocation) -> None """Add an error about a field with unknown value set to 'stability' option.""" self._add_error( - location, ERROR_ID_STABILITY_UNKNOWN_VALUE, - "Field option 'stability' has unknown value, should be one of 'stable', 'unstable' or 'internal.'" + location, + ERROR_ID_STABILITY_UNKNOWN_VALUE, + "Field option 'stability' has unknown value, should be one of 'stable', 'unstable' or 'internal.'", ) def add_duplicate_unstable_stability(self, location): # type: (common.SourceLocation) -> None """Add an error about a field specifying both 'unstable' and 'stability'.""" - self._add_error(location, ERROR_ID_DUPLICATE_UNSTABLE_STABILITY, ( - "Field specifies both 'unstable' and 'stability' options, should use 'stability: [stable|unstable|internal]' instead and remove the deprecated 'unstable' option." - )) + self._add_error( + location, + ERROR_ID_DUPLICATE_UNSTABLE_STABILITY, + ( + "Field specifies both 'unstable' and 'stability' options, should use 'stability: [stable|unstable|internal]' instead and remove the deprecated 'unstable' option." + ), + ) def add_must_declare_shape_type(self, location, struct_name, field_name): # type: (common.SourceLocation, str, str) -> None """Add an error about a field not specifying either query_shape_literal or query_shape_anonymize if the struct is query_shape_component.""" self._add_error( - location, ERROR_ID_FIELD_MUST_DECLARE_SHAPE_LITERAL, - f"Field '{field_name}' must specify either 'query_shape_literal' or 'query_shape_anonymize' since struct '{struct_name}' is a query shape component." + location, + ERROR_ID_FIELD_MUST_DECLARE_SHAPE_LITERAL, + f"Field '{field_name}' must specify either 'query_shape_literal' or 'query_shape_anonymize' since struct '{struct_name}' is a query shape component.", ) def add_must_be_query_shape_component(self, location, struct_name, field_name): # type: (common.SourceLocation, str, str) -> None self._add_error( - location, ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL, - f"Field '{field_name}' cannot specify 'query_shape_literal' property since struct '{struct_name}' is not a query shape component." + location, + ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL, + f"Field '{field_name}' cannot specify 'query_shape_literal' property since struct '{struct_name}' is not a query shape component.", ) - def add_query_shape_anonymize_must_be_string(self, location, field_name, field_type): + def add_query_shape_anonymize_must_be_string( + self, location, field_name, field_type + ): self._add_error( - location, ERROR_ID_INVALID_TYPE_FOR_SHAPIFY, - f"In order for {field_name} to be marked as a query shape fieldpath, it must have a string type, not {field_type}." + location, + ERROR_ID_INVALID_TYPE_FOR_SHAPIFY, + f"In order for {field_name} to be marked as a query shape fieldpath, it must have a string type, not {field_type}.", ) def add_invalid_query_shape_value(self, location, query_shape_value): - self._add_error(location, ERROR_ID_QUERY_SHAPE_INVALID_VALUE, - f"'{query_shape_value}' is not a valid value for 'query_shape'.") + self._add_error( + location, + ERROR_ID_QUERY_SHAPE_INVALID_VALUE, + f"'{query_shape_value}' is not a valid value for 'query_shape'.", + ) def add_strict_and_disable_check_not_allowed(self, location): self._add_error( - location, ERROR_ID_STRICT_AND_DISABLE_CHECK_NOT_ALLOWED, - "Cannot set strict = true and unsafe_dangerous_disable_extra_field_duplicate_checks = true on a struct. unsafe_dangerous_disable_extra_field_duplicate_checks is only permitted on strict = false" + location, + ERROR_ID_STRICT_AND_DISABLE_CHECK_NOT_ALLOWED, + "Cannot set strict = true and unsafe_dangerous_disable_extra_field_duplicate_checks = true on a struct. unsafe_dangerous_disable_extra_field_duplicate_checks is only permitted on strict = false", ) def add_inheritance_and_disable_check_not_allowed(self, location): self._add_error( - location, ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED, - "Fields cannot have unsafe_dangerous_disable_extra_field_duplicate_checks = true. unsafe_dangerous_disable_extra_field_duplicate_checks on non field structs" + location, + ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED, + "Fields cannot have unsafe_dangerous_disable_extra_field_duplicate_checks = true. unsafe_dangerous_disable_extra_field_duplicate_checks on non field structs", ) def add_bad_cpp_namespace(self, location, namespace): # type: (common.SourceLocation, str) -> None self._add_error( - location, ERROR_ID_BAD_CPP_NAMESPACE, + location, + ERROR_ID_BAD_CPP_NAMESPACE, "cpp_namespace must start with 'mongo::' or be just 'mongo', namespace '%s' is not supported" - % (namespace)) + % (namespace), + ) def _assert_unique_error_messages(): diff --git a/buildscripts/idl/idl/generator.py b/buildscripts/idl/idl/generator.py index 46867f4f1b5..723e0c00fe7 100644 --- a/buildscripts/idl/idl/generator.py +++ b/buildscripts/idl/idl/generator.py @@ -61,7 +61,7 @@ class _StructDataOwnership(Enum): def _get_field_member_name(field): # type: (ast.Field) -> str """Get the C++ class member name for a field.""" - return '_%s' % (common.camel_case(field.cpp_name)) + return "_%s" % (common.camel_case(field.cpp_name)) def _get_field_member_setter_name(field): @@ -79,7 +79,7 @@ def _get_field_member_getter_name(field): def _get_has_field_member_name(field): # type: (ast.Field) -> str """Get the C++ class member name for bool 'has' member field.""" - return '_has%s' % (common.title_case(field.cpp_name)) + return "_has%s" % (common.title_case(field.cpp_name)) def _is_required_serializer_field(field): @@ -90,7 +90,13 @@ def _is_required_serializer_field(field): Fields that must be set before serialization are fields without default values, that are not optional, and are not chained. """ - return not field.ignore and not field.optional and not field.default and not field.chained and not field.chained_struct_field + return ( + not field.ignore + and not field.optional + and not field.default + and not field.chained + and not field.chained_struct_field + ) def _get_field_kname(*args, **kwargs): @@ -99,25 +105,26 @@ def _get_field_kname(*args, **kwargs): field_name = args[0].cpp_name else: field_name = kwargs["name"] - return f'k{common.title_case(field_name)}' + return f"k{common.title_case(field_name)}" def _get_field_enum(field): # type: (ast.Field) -> str - return f'Field::{_get_field_kname(field)}' + return f"Field::{_get_field_kname(field)}" def _get_field_constant_name(field): # type: (ast.Field) -> str """Get the C++ string constant name for a field.""" - return common.template_args('k${constant_name}FieldName', - constant_name=common.title_case(field.cpp_name)) + return common.template_args( + "k${constant_name}FieldName", constant_name=common.title_case(field.cpp_name) + ) def _get_field_member_validator_name(field): # type: (ast.Field) -> str """Get the name of the validator method for this field.""" - return 'validate%s' % common.title_case(field.cpp_name) + return "validate%s" % common.title_case(field.cpp_name) def _access_member(field): @@ -125,15 +132,15 @@ def _access_member(field): """Get the declaration to access a member for a field.""" member_name = _get_field_member_name(field) if field.optional: - member_name = '(*%s)' % (member_name) + member_name = "(*%s)" % (member_name) return member_name def _std_array_expr(value_type, elems): # type: (str, List[str]) -> str """Return a std::array{elems} expression.""" - elem_str = ', '.join(elems) - return f'std::array<{value_type}, {len(elems)}>{{{elem_str}}}' + elem_str = ", ".join(elems) + return f"std::array<{value_type}, {len(elems)}>{{{elem_str}}}" def _get_bson_type_check(bson_element, ctxt_name, ast_type): @@ -142,22 +149,29 @@ def _get_bson_type_check(bson_element, ctxt_name, ast_type): # Deduplicate the types in the array. bson_types = list(set(ast_type.bson_serialization_type)) if len(bson_types) == 1: - if bson_types[0] in ['any', 'chain']: + if bson_types[0] in ["any", "chain"]: # Skip BSON validation for 'any' types since they are required to validate the # BSONElement. # Skip BSON validation for 'chain' types since they process the raw BSONObject the # encapsulating IDL struct parser is passed. return None - if not bson_types[0] == 'bindata': - return 'MONGO_likely(%s.checkAndAssertType(%s, %s))' % ( - ctxt_name, bson_element, bson.cpp_bson_type_name(bson_types[0])) - return 'MONGO_likely(%s.checkAndAssertBinDataType(%s, %s))' % ( - ctxt_name, bson_element, bson.cpp_bindata_subtype_type_name(ast_type.bindata_subtype)) + if not bson_types[0] == "bindata": + return "MONGO_likely(%s.checkAndAssertType(%s, %s))" % ( + ctxt_name, + bson_element, + bson.cpp_bson_type_name(bson_types[0]), + ) + return "MONGO_likely(%s.checkAndAssertBinDataType(%s, %s))" % ( + ctxt_name, + bson_element, + bson.cpp_bindata_subtype_type_name(ast_type.bindata_subtype), + ) else: return ( f'MONGO_likely({ctxt_name}.checkAndAssertTypes({bson_element}, ' - f'{_std_array_expr("BSONType", [bson.cpp_bson_type_name(b) for b in bson_types])}))') + f'{_std_array_expr("BSONType", [bson.cpp_bson_type_name(b) for b in bson_types])}))' + ) def _get_required_fields(struct): @@ -219,7 +233,9 @@ def _gen_field_element_name(field): def _gen_mark_present(field_name): # type: (str) -> str - return f'_hasMembers.markPresent(static_cast(RequiredFields::{field_name}));' + return ( + f"_hasMembers.markPresent(static_cast(RequiredFields::{field_name}));" + ) def _is_parse(field): @@ -244,7 +260,7 @@ def _is_forwarding_disabled(field): def _get_constant(name): # type: (str) -> str """Transform an arbitrary label to a constant name.""" - return 'k' + re.sub(r'([^a-zA-Z0-9_]+)', '_', common.title_case(name)) + return "k" + re.sub(r"([^a-zA-Z0-9_]+)", "_", common.title_case(name)) class _FastFieldUsageChecker(_FieldUsageCheckerBase): @@ -261,7 +277,8 @@ class _FastFieldUsageChecker(_FieldUsageCheckerBase): super(_FastFieldUsageChecker, self).__init__(indented_writer) num_internal_only = len( - [field.name for field in fields if field.type and field.type.internal_only]) + [field.name for field in fields if field.type and field.type.internal_only] + ) self.field_count = len(fields) - num_internal_only bit_id = 0 @@ -270,11 +287,12 @@ class _FastFieldUsageChecker(_FieldUsageCheckerBase): continue self._writer.write_line( - 'const size_t %s = %d;' % (_gen_field_usage_constant(field), bit_id)) + "const size_t %s = %d;" % (_gen_field_usage_constant(field), bit_id) + ) bit_id += 1 if bit_id != 0: - self._writer.write_line('std::bitset<%d> usedFields;' % (self.field_count)) + self._writer.write_line("std::bitset<%d> usedFields;" % (self.field_count)) def add_store(self, field_name): # type: (str) -> None @@ -288,24 +306,34 @@ class _FastFieldUsageChecker(_FieldUsageCheckerBase): self._fields.append(field) with writer.IndentedScopedBlock( - self._writer, - 'if (MONGO_unlikely(usedFields[%s])) {' % (_gen_field_usage_constant(field)), '}'): - self._writer.write_line('ctxt.throwDuplicateField(%s);' % (bson_element_variable)) - self._writer.write_empty_line() - - self._writer.write_line('usedFields.set(%s);' % (_gen_field_usage_constant(field))) - self._writer.write_empty_line() - - if field.stability == 'unstable': + self._writer, + "if (MONGO_unlikely(usedFields[%s])) {" + % (_gen_field_usage_constant(field)), + "}", + ): self._writer.write_line( - 'ctxt.checkAndthrowAPIStrictErrorIfApplicable(%s);' % (bson_element_variable)) + "ctxt.throwDuplicateField(%s);" % (bson_element_variable) + ) + self._writer.write_empty_line() + + self._writer.write_line( + "usedFields.set(%s);" % (_gen_field_usage_constant(field)) + ) + self._writer.write_empty_line() + + if field.stability == "unstable": + self._writer.write_line( + "ctxt.checkAndthrowAPIStrictErrorIfApplicable(%s);" + % (bson_element_variable) + ) self._writer.write_empty_line() def add_final_checks(self): # type: () -> None """Output the code to check for missing fields.""" required_fields = [ - field for field in self._fields + field + for field in self._fields if (not field.optional) and (not field.ignore) and (not field.default) ] @@ -318,27 +346,38 @@ class _FastFieldUsageChecker(_FieldUsageCheckerBase): required_fields = sorted(required_fields, key=lambda f: f.cpp_name) - bitmask = ' | '.join( - ['(1ULL << %s)' % (_gen_field_usage_constant(rf)) for rf in required_fields]) - - self._writer.write_line(f'constexpr std::uint64_t requiredFieldBitMask = {bitmask};') + bitmask = " | ".join( + ["(1ULL << %s)" % (_gen_field_usage_constant(rf)) for rf in required_fields] + ) self._writer.write_line( - 'std::bitset<%d> requiredFields(requiredFieldBitMask);' % (self.field_count)) + f"constexpr std::uint64_t requiredFieldBitMask = {bitmask};" + ) self._writer.write_line( - 'bool hasMissingRequiredFields = (requiredFields & usedFields) != requiredFields;') + "std::bitset<%d> requiredFields(requiredFieldBitMask);" % (self.field_count) + ) - with writer.IndentedScopedBlock(self._writer, 'if (hasMissingRequiredFields) {', '}'): + self._writer.write_line( + "bool hasMissingRequiredFields = (requiredFields & usedFields) != requiredFields;" + ) + + with writer.IndentedScopedBlock( + self._writer, "if (hasMissingRequiredFields) {", "}" + ): for field in required_fields: # If 'field.default' is true, the fields(members) gets initialized with the default # value in the class definition. So, it's ok to skip setting the field to # default value here. with writer.IndentedScopedBlock( - self._writer, 'if (!usedFields[%s]) {' % (_gen_field_usage_constant(field)), - '}'): + self._writer, + "if (!usedFields[%s]) {" % (_gen_field_usage_constant(field)), + "}", + ): self._writer.write_line( - 'ctxt.throwMissingField(%s);' % (_get_field_constant_name(field))) + "ctxt.throwMissingField(%s);" + % (_get_field_constant_name(field)) + ) class _SlowFieldUsageChecker(_FastFieldUsageChecker): @@ -357,7 +396,7 @@ class _SlowFieldUsageChecker(_FastFieldUsageChecker): # type: (writer.IndentedTextWriter, List[ast.Field]) -> None super(_SlowFieldUsageChecker, self).__init__(indented_writer, fields) - self._writer.write_line('std::set usedFieldSet;') + self._writer.write_line("std::set usedFieldSet;") def _get_field_usage_checker(indented_writer, struct): @@ -388,9 +427,9 @@ def _encaps(val): def _encaps_list(vals): # type: (List[str]) -> str if vals is None: - return '{}' + return "{}" - return '{' + ', '.join([_encaps(v) for v in vals]) + '}' + return "{" + ", ".join([_encaps(v) for v in vals]) + "}" # Translate an ast.Expression into C++ code. @@ -401,7 +440,7 @@ def _get_expression(expr): # Wrap in a lambda to let the compiler enforce constexprness for us. # The optimization pass should end up inlining it. - return '([]{ constexpr auto value = %s; return value; })()' % expr.expr + return "([]{ constexpr auto value = %s; return value; })()" % expr.expr class _CppFileWriterBase(object): @@ -431,18 +470,22 @@ class _CppFileWriterBase(object): # type: () -> None """Generate a file header saying the file is generated.""" self._writer.write_unindented_line( - textwrap.dedent("""\ + textwrap.dedent( + """\ /** * WARNING: This is a generated file. Do not modify. * * Source: %s */ - """ % (" ".join(sys.argv)))) + """ + % (" ".join(sys.argv)) + ) + ) def gen_system_include(self, include): # type: (str) -> None """Generate a system C++ include line.""" - self._writer.write_unindented_line('#include <%s>' % (include)) + self._writer.write_unindented_line("#include <%s>" % (include)) def gen_include(self, include): # type: (str) -> None @@ -459,11 +502,11 @@ class _CppFileWriterBase(object): def get_initializer_lambda(self, decl, unused=False, return_type=None): # type: (str, bool, str) -> writer.IndentedScopedBlock """Generate an indented block lambda initializing an outer scope variable.""" - prefix = '[[maybe_unused]] ' if unused else '' - prefix = prefix + decl + ' = ([]' + prefix = "[[maybe_unused]] " if unused else "" + prefix = prefix + decl + " = ([]" if return_type: - prefix = prefix + '() -> ' + return_type - return writer.IndentedScopedBlock(self._writer, prefix + ' {', '})();') + prefix = prefix + "() -> " + return_type + return writer.IndentedScopedBlock(self._writer, prefix + " {", "})();") def gen_description_comment(self, description): # type: (str) -> None @@ -496,14 +539,16 @@ class _CppFileWriterBase(object): if not check_str: return writer.EmptyBlock() - conditional = 'if' + conditional = "if" if use_else_if: - conditional = 'else if' + conditional = "else if" if constexpr: - conditional = conditional + ' constexpr' + conditional = conditional + " constexpr" - return writer.IndentedScopedBlock(self._writer, '%s (%s) {' % (conditional, check_str), '}') + return writer.IndentedScopedBlock( + self._writer, "%s (%s) {" % (conditional, check_str), "}" + ) def _else(self, check_bool): # type: (bool) -> Union[writer.IndentedScopedBlock,writer.EmptyBlock] @@ -511,7 +556,7 @@ class _CppFileWriterBase(object): if not check_bool: return writer.EmptyBlock() - return writer.IndentedScopedBlock(self._writer, 'else {', '}') + return writer.IndentedScopedBlock(self._writer, "else {", "}") def _condition(self, condition, preprocessor_only=False): # type: (ast.Condition, bool) -> writer.WriterBlock @@ -523,7 +568,10 @@ class _CppFileWriterBase(object): blocks = [] # type: List[writer.WriterBlock] if condition.preprocessor: blocks.append( - writer.UnindentedBlock(self._writer, '#if ' + condition.preprocessor, '#endif')) + writer.UnindentedBlock( + self._writer, "#if " + condition.preprocessor, "#endif" + ) + ) if not preprocessor_only: if condition.constexpr: @@ -546,8 +594,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase): def gen_class_declaration_block(self, class_name): # type: (str) -> writer.IndentedScopedBlock """Generate a class declaration block.""" - return writer.IndentedScopedBlock(self._writer, - 'class %s {' % common.title_case(class_name), '};') + return writer.IndentedScopedBlock( + self._writer, "class %s {" % common.title_case(class_name), "};" + ) def gen_class_constructors(self, struct): # type: (ast.Struct) -> None @@ -557,7 +606,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase): constructor = struct_type_info.get_constructor_method(gen_header=True) self._writer.write_line(constructor.get_declaration()) - required_constructor = struct_type_info.get_required_constructor_method(gen_header=True) + required_constructor = struct_type_info.get_required_constructor_method( + gen_header=True + ) if len(required_constructor.args) != len(constructor.args): self._writer.write_line(required_constructor.get_declaration()) @@ -572,9 +623,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # type: (ast.Struct) -> None """Generate serializer method declarations.""" struct_type_info = struct_types.get_struct_info(struct) - self._writer.write_line(struct_type_info.get_serializer_method().get_declaration()) + self._writer.write_line( + struct_type_info.get_serializer_method().get_declaration() + ) - maybe_op_msg_serializer = struct_type_info.get_op_msg_request_serializer_method() + maybe_op_msg_serializer = ( + struct_type_info.get_op_msg_request_serializer_method() + ) if maybe_op_msg_serializer: self._writer.write_line(maybe_op_msg_serializer.get_declaration()) @@ -590,7 +645,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): struct_type_info.get_deserializer_static_method(), struct_type_info.get_owned_deserializer_static_method(), struct_type_info.get_sharing_deserializer_static_method(), - struct_type_info.get_op_msg_request_deserializer_static_method() + struct_type_info.get_op_msg_request_deserializer_static_method(), ] for maybe_parse_method in possible_deserializer_methods: if maybe_parse_method: @@ -615,13 +670,17 @@ class _CppHeaderFileWriter(_CppFileWriterBase): generate a BSON representation of an IDL struct, use its `serialize` member functions. Participating in ownership of the underlying data merely allows the struct to ensure that struct members that are pointers-into-BSON (i.e. BSONElement and BSONObject) are valid for - the lifetime of the struct itself.""")) - self._writer.write_line("bool isOwned() const { return _anchorObj.isOwned(); }") + the lifetime of the struct itself.""") + ) + self._writer.write_line( + "bool isOwned() const { return _anchorObj.isOwned(); }" + ) else: self.gen_description_comment( textwrap.dedent("""\ This function will return true every time because there is no underlying BSONObj anchor. - The object owns the data of all of its members.""")) + The object owns the data of all of its members.""") + ) self._writer.write_line("bool isOwner() const { return true; }") def gen_protected_ownership_setters(self, struct): @@ -630,12 +689,12 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if struct.is_view: # If the struct is not a view type, then a BSONObj anchor is not needed because we # know the struct owns all of its data. - with self._block('void setAnchor(const BSONObj& obj) {', '}'): + with self._block("void setAnchor(const BSONObj& obj) {", "}"): self._writer.write_line("invariant(obj.isOwned());") self._writer.write_line("_anchorObj = obj;") self._writer.write_empty_line() - with self._block('void setAnchor(BSONObj&& obj) {', '}'): + with self._block("void setAnchor(BSONObj&& obj) {", "}"): self._writer.write_line("invariant(obj.isOwned());") self._writer.write_line("_anchorObj = std::move(obj);") self._writer.write_empty_line() @@ -671,35 +730,45 @@ class _CppHeaderFileWriter(_CppFileWriterBase): param_type += "&" template_params = { - 'method_name': _get_field_member_getter_name(field), - 'param_type': param_type, - 'body': cpp_type_info.get_getter_body(member_name), - 'const_type': 'const ' if cpp_type_info.return_by_reference() else '', + "method_name": _get_field_member_getter_name(field), + "param_type": param_type, + "body": cpp_type_info.get_getter_body(member_name), + "const_type": "const " if cpp_type_info.return_by_reference() else "", } # Generate a getter that disables xvalue for view types (i.e. StringData), constructed # optional types, and non-primitive types. with self._with_template(template_params): - if field.chained_struct_field: self._writer.write_template( - '${const_type} ${param_type} ${method_name}() const { return %s.%s(); }' % ( - (_get_field_member_name(field.chained_struct_field), - _get_field_member_getter_name(field)))) + "${const_type} ${param_type} ${method_name}() const { return %s.%s(); }" + % ( + ( + _get_field_member_name(field.chained_struct_field), + _get_field_member_getter_name(field), + ) + ) + ) elif field.type.is_struct: # Support mutable accessors self._writer.write_template( - 'const ${param_type} ${method_name}() const { ${body} }') + "const ${param_type} ${method_name}() const { ${body} }" + ) if not struct.immutable: - self._writer.write_template('${param_type} ${method_name}() { ${body} }') + self._writer.write_template( + "${param_type} ${method_name}() { ${body} }" + ) else: self._writer.write_template( - '${const_type}${param_type} ${method_name}() const { ${body} }') + "${const_type}${param_type} ${method_name}() const { ${body} }" + ) if field.non_const_getter: - self._writer.write_template('${param_type} ${method_name}() { ${body} }') + self._writer.write_template( + "${param_type} ${method_name}() { ${body} }" + ) def gen_validators(self, field): # type: (ast.Field) -> None @@ -710,18 +779,19 @@ class _CppHeaderFileWriter(_CppFileWriterBase): param_type = cpp_type_info.get_storage_type() if not cpp_types.is_primitive_type(param_type): - param_type = 'const ' + param_type + '&' + param_type = "const " + param_type + "&" template_params = { - 'method_name': _get_field_member_validator_name(field), - 'param_type': param_type, + "method_name": _get_field_member_validator_name(field), + "param_type": param_type, } with self._with_template(template_params): # Declare method implemented in C++ file. - self._writer.write_template('void ${method_name}(${param_type} value);') + self._writer.write_template("void ${method_name}(${param_type} value);") self._writer.write_template( - 'void ${method_name}(IDLParserContext& ctxt, ${param_type} value);') + "void ${method_name}(IDLParserContext& ctxt, ${param_type} value);" + ) self._writer.write_empty_line() @@ -733,15 +803,22 @@ class _CppHeaderFileWriter(_CppFileWriterBase): storage_type = cpp_type_info.get_storage_type() is_serial = _is_required_serializer_field(field) memfn = _get_field_member_setter_name(field) - validator = _get_field_member_validator_name(field) if field.validator is not None else "" + validator = ( + _get_field_member_validator_name(field) + if field.validator is not None + else "" + ) # Generate the setter for instances of the "getter/setter type", which may not be the same # as the storage type. if field.chained_struct_field: body = "{}.{}(std::move(value));".format( - _get_field_member_name(field.chained_struct_field), memfn) + _get_field_member_name(field.chained_struct_field), memfn + ) else: - body = cpp_type_info.get_setter_body(_get_field_member_name(field), validator) + body = cpp_type_info.get_setter_body( + _get_field_member_name(field), validator + ) set_has = _gen_mark_present(field.cpp_name) if is_serial else "" with self._block(f"void {memfn}({setter_type} value) {{", "}"): @@ -755,10 +832,12 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if storage_type != setter_type and cpp_type_info.has_storage_type_setter(): if field.chained_struct_field: storage_setter_body = "{}.{}(std::move(value));".format( - _get_field_member_name(field.chained_struct_field), memfn) + _get_field_member_name(field.chained_struct_field), memfn + ) else: storage_setter_body = cpp_type_info.get_storage_type_setter_body( - _get_field_member_name(field), validator) + _get_field_member_name(field), validator + ) with self._block(f"void {memfn}({storage_type} value) {{", "}"): self._writer.write_line(storage_setter_body) @@ -773,7 +852,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # storage type setter. with self._block("template"): self._writer.write_line( - f"std::enable_if_t, int> = 0") + f"std::enable_if_t, int> = 0" + ) with self._block(f"void {memfn}(const T& value) {{", "}"): self._writer.write_line(f"{memfn}({storage_type}{{value}});") @@ -781,7 +861,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # type: () -> None """Generate the getters for constexpr data.""" self._writer.write_line( - 'constexpr bool getIsCommandReply() const { return _isCommandReply; }') + "constexpr bool getIsCommandReply() const { return _isCommandReply; }" + ) def gen_member(self, field): # type: (ast.Field) -> None @@ -794,20 +875,26 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # constructed. if field.default and not field.constructed: if field.type.is_enum: - self._writer.write_line('%s %s{%s::%s};' % (member_type, member_name, - field.type.cpp_type, field.default)) + self._writer.write_line( + "%s %s{%s::%s};" + % (member_type, member_name, field.type.cpp_type, field.default) + ) elif field.type.is_struct: - self._writer.write_line('%s %s;' % (member_type, member_name)) + self._writer.write_line("%s %s;" % (member_type, member_name)) else: - self._writer.write_line('%s %s{%s};' % (member_type, member_name, field.default)) + self._writer.write_line( + "%s %s{%s};" % (member_type, member_name, field.default) + ) else: - self._writer.write_line('%s %s;' % (member_type, member_name)) + self._writer.write_line("%s %s;" % (member_type, member_name)) def gen_constexpr_members(self, struct): # type: (ast.Struct) -> None """Generate the C++ class member definition for constexpr data.""" cpp_string_val = "true" if struct.is_command_reply else "false" - self._writer.write_line(f'static constexpr bool _isCommandReply{{{cpp_string_val}}};') + self._writer.write_line( + f"static constexpr bool _isCommandReply{{{cpp_string_val}}};" + ) def gen_serializer_member(self, field): # type: (ast.Field) -> None @@ -815,53 +902,66 @@ class _CppHeaderFileWriter(_CppFileWriterBase): has_member_name = _get_has_field_member_name(field) # Use a bitfield to save space - self._writer.write_line('bool %s : 1;' % (has_member_name)) + self._writer.write_line("bool %s : 1;" % (has_member_name)) def gen_string_constants_declarations(self, struct): # type: (ast.Struct) -> None """Generate a StringData constant for field name.""" fields = [ - field for field in _get_all_fields(struct) + field + for field in _get_all_fields(struct) if not field.type or (field.type and not field.type.internal_only) ] for field in fields: self._writer.write_line( - common.template_args('static constexpr auto ${constant_name} = "${field_name}"_sd;', - constant_name=_get_field_constant_name(field), - field_name=field.name)) + common.template_args( + 'static constexpr auto ${constant_name} = "${field_name}"_sd;', + constant_name=_get_field_constant_name(field), + field_name=field.name, + ) + ) if isinstance(struct, ast.Command): self._writer.write_line( common.template_args( - 'static constexpr auto kCommandDescription = ${description}_sd;', - description=_encaps(struct.description))) + "static constexpr auto kCommandDescription = ${description}_sd;", + description=_encaps(struct.description), + ) + ) self._writer.write_line( - common.template_args('static constexpr auto kCommandName = "${command_name}"_sd;', - command_name=struct.command_name)) + common.template_args( + 'static constexpr auto kCommandName = "${command_name}"_sd;', + command_name=struct.command_name, + ) + ) # Initialize constexpr for command alias if specified in the IDL spec. if struct.command_alias: self._writer.write_line( common.template_args( 'static constexpr auto kCommandAlias = "${command_alias}"_sd;', - command_alias=struct.command_alias)) + command_alias=struct.command_alias, + ) + ) def gen_field_enum(self, struct): # type: (ast.Struct) -> None """Declare the public enum and string constants for struct fields.""" - with self._block('enum class Field {', '};'): + with self._block("enum class Field {", "};"): for f in struct.fields: - self._writer.write_line(f'{_get_field_kname(f)},') - with self._block('static constexpr std::array fieldNames{', '};'): + self._writer.write_line(f"{_get_field_kname(f)},") + with self._block("static constexpr std::array fieldNames{", "};"): for f in struct.fields: self._writer.write_line(f'"{f.name}"_sd,') self._writer.write_empty_line() def gen_required_field_enum(self, struct): - self._writer.write_line('enum class RequiredFields : size_t { %s };' % ', '.join( - [f.cpp_name for f in _get_required_fields(struct)])) + self._writer.write_line( + "enum class RequiredFields : size_t { %s };" + % ", ".join([f.cpp_name for f in _get_required_fields(struct)]) + ) def gen_authorization_contract_declaration(self, struct): # type: (ast.Struct) -> None @@ -873,7 +973,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if struct.access_checks is None: return - self._writer.write_line('static AuthorizationContract kAuthorizationContract;') + self._writer.write_line("static AuthorizationContract kAuthorizationContract;") self.write_empty_line() def gen_enum_functions(self, idl_enum): @@ -894,14 +994,20 @@ class _CppHeaderFileWriter(_CppFileWriterBase): """Generate the declaration for an enum.""" enum_type_info = enum_types.get_type_info(idl_enum) - with self._block('enum class %s : std::int32_t {' % (enum_type_info.get_cpp_type_name()), - '};'): + with self._block( + "enum class %s : std::int32_t {" % (enum_type_info.get_cpp_type_name()), + "};", + ): for enum_value in idl_enum.values: if enum_value.description is not None: self.gen_description_comment(enum_value.description) self._writer.write_line( - common.template_args('${name}${value},', name=enum_value.name, - value=enum_type_info.get_cpp_value_assignment(enum_value))) + common.template_args( + "${name}${value},", + name=enum_value.name, + value=enum_type_info.get_cpp_value_assignment(enum_value), + ) + ) def gen_op_msg_request_methods(self, command): # type: (ast.Command) -> None @@ -930,16 +1036,25 @@ class _CppHeaderFileWriter(_CppFileWriterBase): """Generate the field list entries map for a generic argument or reply field list.""" field_list_info = generic_field_list_types.get_field_list_info(struct) self._writer.write_line( - common.template_args('// Map: fieldName -> ${should_forward_name}', - should_forward_name=field_list_info.get_should_forward_name())) - self._writer.write_line("static const StaticImmortal> _genericFields;") + common.template_args( + "// Map: fieldName -> ${should_forward_name}", + should_forward_name=field_list_info.get_should_forward_name(), + ) + ) + self._writer.write_line( + "static const StaticImmortal> _genericFields;" + ) self.write_empty_line() def gen_known_fields_declaration(self): # type: () -> None """Generate all the known fields vectors for a command.""" - self._writer.write_line("static const std::vector _knownBSONFields;") - self._writer.write_line("static const std::vector _knownOP_MSGFields;") + self._writer.write_line( + "static const std::vector _knownBSONFields;" + ) + self._writer.write_line( + "static const std::vector _knownOP_MSGFields;" + ) self.write_empty_line() def gen_comparison_operators_declarations(self, struct): @@ -947,22 +1062,45 @@ class _CppHeaderFileWriter(_CppFileWriterBase): """Generate comparison operators declarations for the type.""" with self._block("auto _relopTuple() const {", "}"): - sorted_fields = sorted([ - field for field in struct.fields - if (not field.ignore and not (field.type and field.type.internal_only)) - and field.comparison_order != -1 - ], key=lambda f: f.comparison_order) - self._writer.write_line("return std::tuple({});".format(", ".join( - map(lambda f: "idl::relop::Ordering{{{}}}".format(_get_field_member_name(f)), - sorted_fields)))) + sorted_fields = sorted( + [ + field + for field in struct.fields + if ( + not field.ignore + and not (field.type and field.type.internal_only) + ) + and field.comparison_order != -1 + ], + key=lambda f: f.comparison_order, + ) + self._writer.write_line( + "return std::tuple({});".format( + ", ".join( + map( + lambda f: "idl::relop::Ordering{{{}}}".format( + _get_field_member_name(f) + ), + sorted_fields, + ) + ) + ) + ) - for op in ['==', '!=', '<', '>', '<=', '>=']: + for op in ["==", "!=", "<", ">", "<=", ">="]: with self._block( - common.template_args( - "friend bool operator${op}(const ${cls}& a, const ${cls}& b) {", op=op, - cls=common.title_case(struct.name)), "}"): + common.template_args( + "friend bool operator${op}(const ${cls}& a, const ${cls}& b) {", + op=op, + cls=common.title_case(struct.name), + ), + "}", + ): self._writer.write_line( - common.template_args('return a._relopTuple() ${op} b._relopTuple();', op=op)) + common.template_args( + "return a._relopTuple() ${op} b._relopTuple();", op=op + ) + ) self.write_empty_line() @@ -974,7 +1112,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): with self._condition(condition, preprocessor_only=True): self._writer.write_line( - 'constexpr auto %s%s = %s;' % (_get_constant(name), suffix, expr.expr)) + "constexpr auto %s%s = %s;" % (_get_constant(name), suffix, expr.expr) + ) self.write_empty_line() @@ -985,15 +1124,15 @@ class _CppHeaderFileWriter(_CppFileWriterBase): return with self._condition(condition, preprocessor_only=True): - idents = varname.split('::') + idents = varname.split("::") decl = idents.pop() for ns in idents: - self._writer.write_line('namespace %s {' % (ns)) + self._writer.write_line("namespace %s {" % (ns)) - self._writer.write_line('extern %s %s;' % (vartype, decl)) + self._writer.write_line("extern %s %s;" % (vartype, decl)) for ns in reversed(idents): - self._writer.write_line('} // namespace ' + ns) + self._writer.write_line("} // namespace " + ns) if idents: self.write_empty_line() @@ -1008,10 +1147,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if initializer.register: self._writer.write_line( - 'Status %s(optionenvironment::OptionSection*);' % (initializer.register)) + "Status %s(optionenvironment::OptionSection*);" % (initializer.register) + ) if initializer.store: self._writer.write_line( - 'Status %s(const optionenvironment::Environment&);' % (initializer.store)) + "Status %s(const optionenvironment::Environment&);" + % (initializer.store) + ) if initializer.register or initializer.store: self.write_empty_line() @@ -1024,124 +1166,146 @@ class _CppHeaderFileWriter(_CppFileWriterBase): cls = scp.cpp_class - with self._block('class %s : public ServerParameter {' % (cls.name), '};'): - self._writer.write_unindented_line('public:') + with self._block("class %s : public ServerParameter {" % (cls.name), "};"): + self._writer.write_unindented_line("public:") if scp.default is not None: self._writer.write_line( - 'static constexpr auto kDataDefault = %s;' % (scp.default.expr)) + "static constexpr auto kDataDefault = %s;" % (scp.default.expr) + ) if cls.override_ctor: # Explicit custom constructor. - self._writer.write_line(cls.name + '(StringData name, ServerParameterType spt);') + self._writer.write_line( + cls.name + "(StringData name, ServerParameterType spt);" + ) else: # Inherit base constructor. - self._writer.write_line('using ServerParameter::ServerParameter;') + self._writer.write_line("using ServerParameter::ServerParameter;") self.write_empty_line() self._writer.write_line( - 'void append(OperationContext*, BSONObjBuilder*, StringData, const boost::optional&) final;' + "void append(OperationContext*, BSONObjBuilder*, StringData, const boost::optional&) final;" ) if cls.override_set: self._writer.write_line( - 'Status set(const BSONElement&, const boost::optional&) final;') + "Status set(const BSONElement&, const boost::optional&) final;" + ) self._writer.write_line( - 'Status setFromString(StringData, const boost::optional&) final;') + "Status setFromString(StringData, const boost::optional&) final;" + ) # If override_validate is set, provide an override definition. Otherwise, it will inherit # from the base ServerParameter implementation. if cls.override_validate: self._writer.write_line( - 'Status validate(const BSONElement&, const boost::optional& tenantId) const final;' + "Status validate(const BSONElement&, const boost::optional& tenantId) const final;" ) # The reset() and getClusterParameterTime() methods must be custom implemented for # specialized cluster server parameters. Provide the declarations here. - if scp.set_at == 'ServerParameterType::kClusterWide': - self._writer.write_line('Status reset(const boost::optional&) final;') + if scp.set_at == "ServerParameterType::kClusterWide": self._writer.write_line( - 'LogicalTime getClusterParameterTime(const boost::optional&) const final;' + "Status reset(const boost::optional&) final;" + ) + self._writer.write_line( + "LogicalTime getClusterParameterTime(const boost::optional&) const final;" ) if cls.data is not None: self.write_empty_line() if scp.default is not None: - self._writer.write_line('%s _data{kDataDefault};' % (cls.data)) + self._writer.write_line("%s _data{kDataDefault};" % (cls.data)) else: - self._writer.write_line('%s _data;' % (cls.data)) + self._writer.write_line("%s _data;" % (cls.data)) self.write_empty_line() def gen_template_declaration(self): # type: () -> None """Generate a template declaration for a command's base class.""" - self._writer.write_line('template ') + self._writer.write_line("template ") def gen_derived_class_declaration_block(self, class_name): # type: (str) -> writer.IndentedScopedBlock """Generate a command's base class declaration block.""" return writer.IndentedScopedBlock( - self._writer, 'class %s : public TypedCommand {' % class_name, '};') + self._writer, "class %s : public TypedCommand {" % class_name, "};" + ) def gen_type_alias_declaration(self, new_type_name, old_type_name): # type: (str, str) -> None """Generate a type alias declaration.""" self._writer.write_line( - 'using %s = %s;' % (new_type_name, common.title_case(old_type_name))) + "using %s = %s;" % (new_type_name, common.title_case(old_type_name)) + ) - def gen_derived_class_constructor(self, command_name, api_version, base_class, - *base_class_args): + def gen_derived_class_constructor( + self, command_name, api_version, base_class, *base_class_args + ): # type: (str, str, str, *str) -> None """Generate a derived class constructor.""" - class_name = common.title_case(command_name) + "CmdVersion" + api_version + "Gen" + class_name = ( + common.title_case(command_name) + "CmdVersion" + api_version + "Gen" + ) args = ", ".join(base_class_args) - self._writer.write_line('%s(): %s(%s) {}' % (class_name, base_class, args)) + self._writer.write_line("%s(): %s(%s) {}" % (class_name, base_class, args)) def gen_api_version_fn(self, is_api_versions, api_version): # type: (bool, Union[str, bool]) -> None """Generate an apiVersions or deprecatedApiVersions function for a command's base class.""" fn_name = "apiVersions" if is_api_versions else "deprecatedApiVersions" - fn_def = 'const std::set& %s() const final' % fn_name + fn_def = "const std::set& %s() const final" % fn_name value = "kApiVersions1" if api_version else "kNoApiVersions" - with self._block('%s {' % (fn_def), '}'): - self._writer.write_line('return %s;' % value) + with self._block("%s {" % (fn_def), "}"): + self._writer.write_line("return %s;" % value) def gen_invocation_base_class_declaration(self, command): # type: (ast.Command) -> None """Generate the InvocationBaseGen class for a command's base class.""" - class_declaration = 'class InvocationBaseGen : public _TypedCommandInvocationBase {' - with writer.IndentedScopedBlock(self._writer, class_declaration, '};'): + class_declaration = ( + "class InvocationBaseGen : public _TypedCommandInvocationBase {" + ) + with writer.IndentedScopedBlock(self._writer, class_declaration, "};"): # public requires special indentation that aligns with the class definition. self._writer.unindent() - self._writer.write_line('public:') + self._writer.write_line("public:") self._writer.indent() # Inherit base constructor. self._writer.write_line( - 'using _TypedCommandInvocationBase::_TypedCommandInvocationBase;') + "using _TypedCommandInvocationBase::_TypedCommandInvocationBase;" + ) - self._writer.write_line('virtual Reply typedRun(OperationContext* opCtx) = 0;') + self._writer.write_line( + "virtual Reply typedRun(OperationContext* opCtx) = 0;" + ) if command.access_checks == []: self._writer.write_line( - 'void doCheckAuthorization(OperationContext* opCtx) const final {}') + "void doCheckAuthorization(OperationContext* opCtx) const final {}" + ) def generate_versioned_command_base_class(self, command): # type: (ast.Command) -> None """Generate a command's C++ base class to a stream.""" - class_name = "%sCmdVersion%sGen" % (common.title_case(command.command_name), - command.api_version) + class_name = "%sCmdVersion%sGen" % ( + common.title_case(command.command_name), + command.api_version, + ) self.write_empty_line() self.gen_template_declaration() with self.gen_derived_class_declaration_block(class_name): # Write type alias for InvocationBase. - self.gen_type_alias_declaration('_TypedCommandInvocationBase', - 'typename TypedCommand::InvocationBase') + self.gen_type_alias_declaration( + "_TypedCommandInvocationBase", + "typename TypedCommand::InvocationBase", + ) self.write_empty_line() - self.write_unindented_line('public:') + self.write_unindented_line("public:") # Write type aliases for Request and Reply. self.gen_type_alias_declaration("Request", command.cpp_name) @@ -1150,9 +1314,13 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Generate a constructor for generated derived class if command alias is specified. if command.command_alias: self.write_empty_line() - self.gen_derived_class_constructor(command.command_name, command.api_version, - 'TypedCommand', 'Request::kCommandName', - 'Request::kCommandAlias') + self.gen_derived_class_constructor( + command.command_name, + command.api_version, + "TypedCommand", + "Request::kCommandName", + "Request::kCommandAlias", + ) self.write_empty_line() @@ -1163,7 +1331,7 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Wrte authorization contract code if command.access_checks is not None: self._writer.write_line( - 'const AuthorizationContract* getAuthorizationContract() const final { return &Request::kAuthorizationContract; } ' + "const AuthorizationContract* getAuthorizationContract() const final { return &Request::kAuthorizationContract; } " ) self.write_empty_line() @@ -1175,17 +1343,17 @@ class _CppHeaderFileWriter(_CppFileWriterBase): """Generate the C++ header to a stream.""" self.gen_file_header() - self._writer.write_unindented_line('#pragma once') + self._writer.write_unindented_line("#pragma once") self.write_empty_line() # Generate system includes first header_list = [ - 'algorithm', - 'boost/optional.hpp', - 'cstdint', - 'string', - 'tuple', - 'vector', + "algorithm", + "boost/optional.hpp", + "cstdint", + "string", + "tuple", + "vector", ] header_list.sort() @@ -1197,47 +1365,50 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Generate user includes second header_list = [ - 'mongo/base/string_data.h', - 'mongo/base/data_range.h', - 'mongo/bson/bsonobj.h', - 'mongo/bson/bsonobjbuilder.h', - 'mongo/bson/simple_bsonobj_comparator.h', - 'mongo/idl/idl_parser.h', - 'mongo/rpc/op_msg.h', - 'mongo/stdx/unordered_map.h', - 'mongo/util/decimal_counter.h', - 'mongo/util/serialization_context.h', + "mongo/base/string_data.h", + "mongo/base/data_range.h", + "mongo/bson/bsonobj.h", + "mongo/bson/bsonobjbuilder.h", + "mongo/bson/simple_bsonobj_comparator.h", + "mongo/idl/idl_parser.h", + "mongo/rpc/op_msg.h", + "mongo/stdx/unordered_map.h", + "mongo/util/decimal_counter.h", + "mongo/util/serialization_context.h", ] + spec.globals.cpp_includes if spec.configs: - header_list.append('mongo/util/options_parser/option_description.h') + header_list.append("mongo/util/options_parser/option_description.h") config_init = spec.globals.configs and spec.globals.configs.initializer if config_init and (config_init.register or config_init.store): - header_list.append('mongo/util/options_parser/option_section.h') - header_list.append('mongo/util/options_parser/environment.h') + header_list.append("mongo/util/options_parser/option_section.h") + header_list.append("mongo/util/options_parser/environment.h") if spec.server_parameters: if [ - param for param in spec.server_parameters - if param.feature_flag or (param.condition and param.condition.feature_flag) + param + for param in spec.server_parameters + if param.feature_flag + or (param.condition and param.condition.feature_flag) ]: - header_list.append('mongo/db/feature_flag.h') + header_list.append("mongo/db/feature_flag.h") if [ - param for param in spec.server_parameters - if param.condition and param.condition.min_fcv + param + for param in spec.server_parameters + if param.condition and param.condition.min_fcv ]: - header_list.append('mongo/db/feature_compatibility_version_parser.h') - header_list.append('mongo/db/server_parameter.h') - header_list.append('mongo/db/server_parameter_with_storage.h') + header_list.append("mongo/db/feature_compatibility_version_parser.h") + header_list.append("mongo/db/server_parameter.h") + header_list.append("mongo/db/server_parameter_with_storage.h") # Include this for TypedCommand only if a base class will be generated for a command in this # file. if any(command.api_version for command in spec.commands): - header_list.append('mongo/db/commands.h') + header_list.append("mongo/db/commands.h") # Include serialization options only if there is a struct which is part of a query shape. if any(struct.query_shape_component for struct in spec.structs): - header_list.append('mongo/db/query/query_shape/serialization_options.h') + header_list.append("mongo/db/query/query_shape/serialization_options.h") header_list.sort() @@ -1266,16 +1437,17 @@ class _CppHeaderFileWriter(_CppFileWriterBase): for struct in all_structs: self.gen_description_comment(struct.description) with self.gen_class_declaration_block(struct.cpp_name): - self.write_unindented_line('public:') + self.write_unindented_line("public:") self.gen_field_enum(struct) if isinstance(struct, ast.Command): if struct.reply_type: # Alias the reply type as a named type for commands - self.gen_type_alias_declaration("Reply", - struct.reply_type.type.cpp_type) + self.gen_type_alias_declaration( + "Reply", struct.reply_type.type.cpp_type + ) else: - self._writer.write_line('using Reply = void;') + self._writer.write_line("using Reply = void;") # Generate a sorted list of string constants self.gen_string_constants_declarations(struct) @@ -1304,7 +1476,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if field.description: self.gen_description_comment(field.description) self.gen_getter(struct, field) - if not struct.immutable or (field.type and field.type.internal_only): + if not struct.immutable or ( + field.type and field.type.internal_only + ): self.gen_setters(field) # Generate getters for any constexpr/compile-time struct data @@ -1314,8 +1488,8 @@ class _CppHeaderFileWriter(_CppFileWriterBase): if struct.generic_list_type: self.gen_field_list_entry_lookup_methods_struct(struct) - self.write_unindented_line('protected:') - if (struct.is_view): + self.write_unindented_line("protected:") + if struct.is_view: # If the struct is not a view type, then a BSONObj anchor is not needed because we # know the struct owns all of its data. self.gen_protected_ownership_setters(struct) @@ -1325,13 +1499,17 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # Write private validators if [field for field in struct.fields if field.validator]: - self.write_unindented_line('private:') + self.write_unindented_line("private:") for field in struct.fields: - if not field.ignore and not struct.immutable and field.validator: + if ( + not field.ignore + and not struct.immutable + and field.validator + ): self.gen_validators(field) - self.write_unindented_line('private:') - self._writer.write_line('struct FieldInfo;') + self.write_unindented_line("private:") + self._writer.write_line("struct FieldInfo;") self.gen_required_field_enum(struct) self.write_empty_line() @@ -1361,8 +1539,9 @@ class _CppHeaderFileWriter(_CppFileWriterBase): # non-debug builds, and is marked MONGO_COMPILER_NO_UNIQUE_ADDRESS so that the # compiler knows that it should be optimized to take up no space when possible. self._writer.write_line( - 'MONGO_COMPILER_NO_UNIQUE_ADDRESS mongo::idl::HasMembers<%s> _hasMembers;' % - len(_get_required_fields(struct))) + "MONGO_COMPILER_NO_UNIQUE_ADDRESS mongo::idl::HasMembers<%s> _hasMembers;" + % len(_get_required_fields(struct)) + ) # Write constexpr struct data self.gen_constexpr_members(struct) @@ -1370,14 +1549,22 @@ class _CppHeaderFileWriter(_CppFileWriterBase): for scp in spec.server_parameters: if scp.cpp_class is None: - self._gen_exported_constexpr(scp.name, 'Default', scp.default, scp.condition) - self._gen_extern_declaration(scp.cpp_vartype, scp.cpp_varname, scp.condition) + self._gen_exported_constexpr( + scp.name, "Default", scp.default, scp.condition + ) + self._gen_extern_declaration( + scp.cpp_vartype, scp.cpp_varname, scp.condition + ) self.gen_server_parameter_class(scp) if spec.configs: for opt in spec.configs: - self._gen_exported_constexpr(opt.name, 'Default', opt.default, opt.condition) - self._gen_extern_declaration(opt.cpp_vartype, opt.cpp_varname, opt.condition) + self._gen_exported_constexpr( + opt.name, "Default", opt.default, opt.condition + ) + self._gen_extern_declaration( + opt.cpp_vartype, opt.cpp_varname, opt.condition + ) self._gen_config_function_declaration(spec) # Write a base class for each command in API Version 1. @@ -1392,8 +1579,10 @@ class _CppHeaderFileWriter(_CppFileWriterBase): cpp_namespace = idl_enum.cpp_namespace cpp_name = enum_type_info.get_cpp_type_name() full_cpp_name = "::{}::{}".format(cpp_namespace, cpp_name) - self._writer.write_line("template<> constexpr inline size_t idlEnumCount<%s> = %d;" - % (full_cpp_name, len(idl_enum.values))) + self._writer.write_line( + "template<> constexpr inline size_t idlEnumCount<%s> = %d;" + % (full_cpp_name, len(idl_enum.values)) + ) class _CppSourceFileWriter(_CppFileWriterBase): @@ -1414,33 +1603,43 @@ class _CppSourceFileWriter(_CppFileWriterBase): It contains static metadata of struct's members. """ cls = common.title_case(struct.cpp_name) - with self._block(f'struct {cls}::FieldInfo {{', '};'): + with self._block(f"struct {cls}::FieldInfo {{", "};"): for func, pred in ( - ('findStructField', lambda f: True), - ('findRequiredField', _is_required_serializer_field), - ('findForwardingDisabledField', _is_forwarding_disabled), - ('findParsedField', _is_parse), + ("findStructField", lambda f: True), + ("findRequiredField", _is_required_serializer_field), + ("findForwardingDisabledField", _is_forwarding_disabled), + ("findParsedField", _is_parse), ): selected_fields = [x for x in filter(pred, struct.fields)] - self._writer.write_line('template ') + self._writer.write_line("template ") with self._block( - f'static auto {func}(StringData s, const OnMatch& onMatch, const OnFail& onFail) {{', - '};'): + f"static auto {func}(StringData s, const OnMatch& onMatch, const OnFail& onFail) {{", + "};", + ): if len(selected_fields) == 0: - self._writer.write_line('return onFail();') + self._writer.write_line("return onFail();") else: - with self._block('static constexpr auto adaptMatch = [](int i) {', '};'): - with self._block('static constexpr auto arr = std::to_array({', - '});'): + with self._block( + "static constexpr auto adaptMatch = [](int i) {", "};" + ): + with self._block( + "static constexpr auto arr = std::to_array({", + "});", + ): for f in selected_fields: - self._writer.write_line(f'{_get_field_enum(f)},') - self._writer.write_line('return arr[i];') + self._writer.write_line(f"{_get_field_enum(f)},") + self._writer.write_line("return arr[i];") writer.gen_string_table_find_function_block( - self._writer, 's', 'onMatch(adaptMatch({}))', 'onFail()', - [f.name for f in selected_fields]) + self._writer, + "s", + "onMatch(adaptMatch({}))", + "onFail()", + [f.name for f in selected_fields], + ) - def _gen_field_deserializer_expression(self, element_name, field, ast_type, tenant, - is_catalog_ctxt): + def _gen_field_deserializer_expression( + self, element_name, field, ast_type, tenant, is_catalog_ctxt + ): # type: (str, ast.Field, ast.Type, str, bool) -> str """ Generate the C++ deserializer piece for a field. @@ -1450,17 +1649,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): """ serialization_context = "getSerializationContext()" if ast_type.is_struct: - validated_tenancy_scope = 'ctxt.getValidatedTenancyScope()' - if 'request' in tenant: - validated_tenancy_scope = 'request.validatedTenancyScope' - self._writer.write_line('IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);' % - (_get_field_constant_name(field), validated_tenancy_scope, - serialization_context, tenant)) - self._writer.write_line('const auto localObject = %s.Obj();' % (element_name)) - return '%s::parse(tempContext, localObject)' % (ast_type.cpp_type, ) - elif ast_type.deserializer and 'BSONElement::' in ast_type.deserializer: + validated_tenancy_scope = "ctxt.getValidatedTenancyScope()" + if "request" in tenant: + validated_tenancy_scope = "request.validatedTenancyScope" + self._writer.write_line( + "IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);" + % ( + _get_field_constant_name(field), + validated_tenancy_scope, + serialization_context, + tenant, + ) + ) + self._writer.write_line( + "const auto localObject = %s.Obj();" % (element_name) + ) + return "%s::parse(tempContext, localObject)" % (ast_type.cpp_type,) + elif ast_type.deserializer and "BSONElement::" in ast_type.deserializer: method_name = writer.get_method_name(ast_type.deserializer) - return '%s.%s()' % (element_name, method_name) + return "%s.%s()" % (element_name, method_name) assert not ast_type.is_variant @@ -1472,166 +1679,242 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Class Class::method(StringData value) # or # Class::method(const BSONObj& value) - expression = bson_cpp_type.gen_deserializer_expression(self._writer, element_name) + expression = bson_cpp_type.gen_deserializer_expression( + self._writer, element_name + ) if ast_type.deserializer: method_name = writer.get_method_name_from_qualified_method_name( - ast_type.deserializer) + ast_type.deserializer + ) # For fields which are enums, pass a IDLParserContext if ast_type.is_enum: - validated_tenancy_scope = 'ctxt.getValidatedTenancyScope()' - if 'request' in tenant: - validated_tenancy_scope = 'request.validatedTenancyScope' + validated_tenancy_scope = "ctxt.getValidatedTenancyScope()" + if "request" in tenant: + validated_tenancy_scope = "request.validatedTenancyScope" self._writer.write_line( - 'IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);' % - (_get_field_constant_name(field), validated_tenancy_scope, - serialization_context, tenant)) - return common.template_args("${method_name}(tempContext, ${expression})", - method_name=method_name, expression=expression) + "IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);" + % ( + _get_field_constant_name(field), + validated_tenancy_scope, + serialization_context, + tenant, + ) + ) + return common.template_args( + "${method_name}(tempContext, ${expression})", + method_name=method_name, + expression=expression, + ) if ast_type.deserialize_with_tenant: arguments = "${method_name}(${tenant}, ${expression}, ${context})" - if is_catalog_ctxt: # serializeForCatalog doesn't need a serializationContext + if ( + is_catalog_ctxt + ): # serializeForCatalog doesn't need a serializationContext arguments = "${method_name}(${tenant}, ${expression})" - method_name = method_name.replace("serialize", "serializeForCatalog") - return common.template_args(arguments, method_name=method_name, tenant=tenant, - expression=expression, - context=serialization_context) + method_name = method_name.replace( + "serialize", "serializeForCatalog" + ) + return common.template_args( + arguments, + method_name=method_name, + tenant=tenant, + expression=expression, + context=serialization_context, + ) else: - return common.template_args("${method_name}(${expression})", - method_name=method_name, expression=expression) + return common.template_args( + "${method_name}(${expression})", + method_name=method_name, + expression=expression, + ) # BSONObjects are allowed to be pass through without deserialization - assert ast_type.bson_serialization_type in [['object'], ['array']] + assert ast_type.bson_serialization_type in [["object"], ["array"]] return expression # Call a static class method with the signature: # Class Class::method(const BSONElement& value) - method_name = writer.get_method_name_from_qualified_method_name(ast_type.deserializer) + method_name = writer.get_method_name_from_qualified_method_name( + ast_type.deserializer + ) if ast_type.deserialize_with_tenant: - return '%s(%s, %s, %s)' % (method_name, tenant, element_name, serialization_context) + return "%s(%s, %s, %s)" % ( + method_name, + tenant, + element_name, + serialization_context, + ) else: - return '%s(%s)' % (method_name, element_name) + return "%s(%s)" % (method_name, element_name) - def _gen_array_deserializer(self, field, bson_element, ast_type, tenant, is_catalog_ctxt): + def _gen_array_deserializer( + self, field, bson_element, ast_type, tenant, is_catalog_ctxt + ): # type: (ast.Field, str, ast.Type, str, bool) -> None """Generate the C++ deserializer piece for an array field.""" assert ast_type.is_array - cpp_type_info = cpp_types.get_cpp_type_from_cpp_type_name(field, ast_type.cpp_type, True) + cpp_type_info = cpp_types.get_cpp_type_from_cpp_type_name( + field, ast_type.cpp_type, True + ) cpp_type = cpp_type_info.get_type_name() - self._writer.write_line('DecimalCounter expectedFieldNumber{0};') - validated_tenancy_scope = 'ctxt.getValidatedTenancyScope()' - if 'request' in tenant: - validated_tenancy_scope = 'request.validatedTenancyScope' - self._writer.write_line('const IDLParserContext arrayCtxt(%s, &ctxt, %s, %s, %s);' % - (_get_field_constant_name(field), validated_tenancy_scope, - "getSerializationContext()", tenant)) - self._writer.write_line('std::vector<%s> values;' % (cpp_type)) + self._writer.write_line("DecimalCounter expectedFieldNumber{0};") + validated_tenancy_scope = "ctxt.getValidatedTenancyScope()" + if "request" in tenant: + validated_tenancy_scope = "request.validatedTenancyScope" + self._writer.write_line( + "const IDLParserContext arrayCtxt(%s, &ctxt, %s, %s, %s);" + % ( + _get_field_constant_name(field), + validated_tenancy_scope, + "getSerializationContext()", + tenant, + ) + ) + self._writer.write_line("std::vector<%s> values;" % (cpp_type)) self._writer.write_empty_line() - self._writer.write_line('const BSONObj arrayObject = %s.Obj();' % (bson_element)) - - with self._block('for (const auto& arrayElement : arrayObject) {', '}'): + self._writer.write_line( + "const BSONObj arrayObject = %s.Obj();" % (bson_element) + ) + with self._block("for (const auto& arrayElement : arrayObject) {", "}"): self._writer.write_line( - 'const auto arrayFieldName = arrayElement.fieldNameStringData();') + "const auto arrayFieldName = arrayElement.fieldNameStringData();" + ) - with self._predicate('MONGO_likely(arrayFieldName == expectedFieldNumber)'): - check = _get_bson_type_check('arrayElement', 'arrayCtxt', ast_type) + with self._predicate("MONGO_likely(arrayFieldName == expectedFieldNumber)"): + check = _get_bson_type_check("arrayElement", "arrayCtxt", ast_type) with self._predicate(check): if ast_type.is_variant: # _gen_variant_deserializer generates code to parse the variant into the variable "_" + field.cpp_name, # so we create a local variable '_tmp'. self._writer.write_line("%s _tmp;" % ast_type.cpp_type) - self._gen_variant_deserializer(field, "_tmp", "arrayElement", tenant, - is_catalog_ctxt) + self._gen_variant_deserializer( + field, "_tmp", "arrayElement", tenant, is_catalog_ctxt + ) self._writer.write_line("values.push_back(std::move(_tmp));") else: array_value = self._gen_field_deserializer_expression( - 'arrayElement', field, ast_type, tenant, is_catalog_ctxt) - self._writer.write_line('values.push_back(%s);' % (array_value)) + "arrayElement", field, ast_type, tenant, is_catalog_ctxt + ) + self._writer.write_line("values.push_back(%s);" % (array_value)) - with self._block('else {', '}'): + with self._block("else {", "}"): self._writer.write_line( - 'arrayCtxt.throwBadArrayFieldNumberSequence(arrayFieldName, expectedFieldNumber);' + "arrayCtxt.throwBadArrayFieldNumberSequence(arrayFieldName, expectedFieldNumber);" ) - self._writer.write_line('++expectedFieldNumber;') + self._writer.write_line("++expectedFieldNumber;") if field.validator: - self._writer.write_line('%s(values);' % (_get_field_member_validator_name(field))) + self._writer.write_line( + "%s(values);" % (_get_field_member_validator_name(field)) + ) if field.chained_struct_field: if field.type.is_variant: - self._writer.write_line('%s.%s(%s(std::move(values)));' % - (_get_field_member_name(field.chained_struct_field), - _get_field_member_setter_name(field), field.type.cpp_type)) + self._writer.write_line( + "%s.%s(%s(std::move(values)));" + % ( + _get_field_member_name(field.chained_struct_field), + _get_field_member_setter_name(field), + field.type.cpp_type, + ) + ) else: - self._writer.write_line('%s.%s(std::move(values));' % (_get_field_member_name( - field.chained_struct_field), _get_field_member_setter_name(field))) + self._writer.write_line( + "%s.%s(std::move(values));" + % ( + _get_field_member_name(field.chained_struct_field), + _get_field_member_setter_name(field), + ) + ) else: - self._writer.write_line('%s = std::move(values);' % (_get_field_member_name(field))) + self._writer.write_line( + "%s = std::move(values);" % (_get_field_member_name(field)) + ) - def _gen_variant_deserializer(self, field, field_name, bson_element, tenant, is_catalog_ctxt): + def _gen_variant_deserializer( + self, field, field_name, bson_element, tenant, is_catalog_ctxt + ): # type: (ast.Field, str, str, str, bool) -> None """Generate the C++ deserializer piece for a variant field.""" self._writer.write_empty_line() - self._writer.write_line('const BSONType variantType = %s.type();' % (bson_element, )) + self._writer.write_line( + "const BSONType variantType = %s.type();" % (bson_element,) + ) array_types = [v for v in field.type.variant_types if v.is_array] scalar_types = [v for v in field.type.variant_types if not v.is_array] - self._writer.write_line('switch (variantType) {') + self._writer.write_line("switch (variantType) {") if array_types: - self._writer.write_line('case Array:') - with self._block('{', '}'): + self._writer.write_line("case Array:") + with self._block("{", "}"): # If the array is empty, we can't infer its element type. Use the first # array type as a fallback to cover that case. fallback_type = array_types[0].bson_serialization_type[0] fallback_bson_cpp_type = bson.cpp_bson_type_name(fallback_type) - condition = '%s.Obj().isEmpty()' % (bson_element, ) + condition = "%s.Obj().isEmpty()" % (bson_element,) self._writer.write_line( - 'const BSONType elemType = %s ? %s : %s.Obj().firstElement().type();' % - (condition, fallback_bson_cpp_type, bson_element)) + "const BSONType elemType = %s ? %s : %s.Obj().firstElement().type();" + % (condition, fallback_bson_cpp_type, bson_element) + ) # Start inner switch statement, for each type the first element could be. - self._writer.write_line('switch (elemType) {') + self._writer.write_line("switch (elemType) {") for array_type in array_types: for bson_type in array_type.bson_serialization_type: - self._writer.write_line('case %s:' % (bson.cpp_bson_type_name(bson_type), )) + self._writer.write_line( + "case %s:" % (bson.cpp_bson_type_name(bson_type),) + ) # Each copy of the array deserialization code gets an anonymous block. - with self._block('{', '}'): - self._gen_array_deserializer(field, bson_element, array_type, tenant, - is_catalog_ctxt) - self._writer.write_line('break;') + with self._block("{", "}"): + self._gen_array_deserializer( + field, bson_element, array_type, tenant, is_catalog_ctxt + ) + self._writer.write_line("break;") - self._writer.write_line('default:') + self._writer.write_line("default:") self._writer.indent() expected_types = [ - bson.cpp_bson_type_name(t.bson_serialization_type[0]) for t in array_types + bson.cpp_bson_type_name(t.bson_serialization_type[0]) + for t in array_types ] self._writer.write_line( f'ctxt.throwBadType({bson_element}, {_std_array_expr("BSONType", expected_types)});' ) - self._writer.write_line('break;') + self._writer.write_line("break;") self._writer.unindent() # End of inner switch. - self._writer.write_line('}') + self._writer.write_line("}") # End of "case Array:". - self._writer.write_line('break;') + self._writer.write_line("break;") for scalar_type in scalar_types: for bson_type in scalar_type.bson_serialization_type: - self._writer.write_line('case %s:' % (bson.cpp_bson_type_name(bson_type), )) - with self._block('{', '}'): - self.gen_field_deserializer(field, scalar_type, "bsonObject", bson_element, - None, tenant, False, check_type=False, - is_catalog_ctxt=is_catalog_ctxt) - self._writer.write_line('break;') + self._writer.write_line( + "case %s:" % (bson.cpp_bson_type_name(bson_type),) + ) + with self._block("{", "}"): + self.gen_field_deserializer( + field, + scalar_type, + "bsonObject", + bson_element, + None, + tenant, + False, + check_type=False, + is_catalog_ctxt=is_catalog_ctxt, + ) + self._writer.write_line("break;") if field.type.variant_struct_types: with self._block("case Object: {", "} break;"): @@ -1643,46 +1926,56 @@ class _CppSourceFileWriter(_CppFileWriterBase): from_doc_seq=False, ) - self._writer.write_line('default:') + self._writer.write_line("default:") self._writer.indent() expected_types = [ bson.cpp_bson_type_name(t.bson_serialization_type[0]) for t in array_types ] - self._writer.write_line(f'ctxt.throwBadType({bson_element}, ' - f'{_std_array_expr("BSONType", expected_types)});') - self._writer.write_line('break;') + self._writer.write_line( + f'ctxt.throwBadType({bson_element}, ' + f'{_std_array_expr("BSONType", expected_types)});' + ) + self._writer.write_line("break;") self._writer.unindent() # End of outer switch statement. - self._writer.write_line('}') + self._writer.write_line("}") - def _gen_variant_deserializer_from_obj(self, field, field_name, bson_element, tenant, - from_doc_seq): + def _gen_variant_deserializer_from_obj( + self, field, field_name, bson_element, tenant, from_doc_seq + ): def on_variant_alternative_match(variant_type): assert variant_type.is_struct validated_tenancy_scope = "ctxt.getValidatedTenancyScope()" if "request" in tenant: validated_tenancy_scope = "request.validatedTenancyScope" - self._writer.write_line("IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);" % ( - _get_field_constant_name(field), - validated_tenancy_scope, - "getSerializationContext()", - tenant, - )) + self._writer.write_line( + "IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);" + % ( + _get_field_constant_name(field), + validated_tenancy_scope, + "getSerializationContext()", + tenant, + ) + ) if from_doc_seq: - value_expr = f"{variant_type.cpp_type}::parse(tempContext, {bson_element})" + value_expr = ( + f"{variant_type.cpp_type}::parse(tempContext, {bson_element})" + ) else: - self._writer.write_line("const auto localObject = %s.Obj();" % (bson_element)) + self._writer.write_line( + "const auto localObject = %s.Obj();" % (bson_element) + ) value_expr = f"{variant_type.cpp_type}::parse(tempContext, localObject)" if field.optional: cpp_type_info = cpp_types.get_cpp_type(field) - value_expr = f'{cpp_type_info.get_getter_setter_type()}({value_expr})' + value_expr = f"{cpp_type_info.get_getter_setter_type()}({value_expr})" if field.chained_struct_field: chain_source = _get_field_member_name(field.chained_struct_field) setter = _get_field_member_setter_name(field) - self._writer.write_line(f'{chain_source}.{setter}({value_expr});') + self._writer.write_line(f"{chain_source}.{setter}({value_expr});") else: - self._writer.write_line(f'{field_name} = {value_expr};') + self._writer.write_line(f"{field_name} = {value_expr};") struct_types_list = field.type.variant_struct_types if len(struct_types_list) == 1: @@ -1696,13 +1989,17 @@ class _CppSourceFileWriter(_CppFileWriterBase): with self._block("auto onMatch = [&](int found) {", "};"): with self._block("switch (found) {", "}"): for idx, variant_type in enumerate(struct_types_list): - with self._block(f'case {idx}: {{', '} break;'): + with self._block(f"case {idx}: {{", "} break;"): on_variant_alternative_match(variant_type) - with self._block('auto onFail = [&] {', '};'): - self._writer.write_line('ctxt.throwUnknownField(s);') + with self._block("auto onFail = [&] {", "};"): + self._writer.write_line("ctxt.throwUnknownField(s);") writer.gen_string_table_find_function_block( - self._writer, 's', 'onMatch({})', 'onFail()', - [f.first_element_field_name for f in struct_types_list]) + self._writer, + "s", + "onMatch({})", + "onFail()", + [f.first_element_field_name for f in struct_types_list], + ) def _gen_usage_check(self, field, bson_element, field_usage_check): # type: (ast.Field, str, _FieldUsageCheckerBase) -> None @@ -1713,9 +2010,18 @@ class _CppSourceFileWriter(_CppFileWriterBase): if _is_required_serializer_field(field): self._writer.write_line(_gen_mark_present(field.cpp_name)) - def gen_field_deserializer(self, field, field_type, bson_object, bson_element, - field_usage_check, tenant, is_command_field=False, check_type=True, - is_catalog_ctxt=False): + def gen_field_deserializer( + self, + field, + field_type, + bson_object, + bson_element, + field_usage_check, + tenant, + is_command_field=False, + check_type=True, + is_catalog_ctxt=False, + ): # type: (ast.Field, ast.Type, str, str, _FieldUsageCheckerBase, str, bool, bool, bool) -> None """Generate the C++ deserializer piece for a field. @@ -1728,17 +2034,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): return if field_type.is_array: - predicate = "MONGO_likely(ctxt.checkAndAssertType(%s, Array))" % (bson_element) + predicate = "MONGO_likely(ctxt.checkAndAssertType(%s, Array))" % ( + bson_element + ) with self._predicate(predicate): self._gen_usage_check(field, bson_element, field_usage_check) - self._gen_array_deserializer(field, bson_element, field_type, tenant, - is_catalog_ctxt) + self._gen_array_deserializer( + field, bson_element, field_type, tenant, is_catalog_ctxt + ) return elif field_type.is_variant: self._gen_usage_check(field, bson_element, field_usage_check) - self._gen_variant_deserializer(field, _get_field_member_name(field), bson_element, - tenant, is_catalog_ctxt) + self._gen_variant_deserializer( + field, + _get_field_member_name(field), + bson_element, + tenant, + is_catalog_ctxt, + ) return def validate_and_assign_or_uassert(field, expression): @@ -1746,13 +2060,15 @@ class _CppSourceFileWriter(_CppFileWriterBase): """Perform field value validation post-assignment.""" field_name = _get_field_member_name(field) if field.validator is None: - self._writer.write_line('%s = %s;' % (field_name, expression)) + self._writer.write_line("%s = %s;" % (field_name, expression)) return - with self._block('{', '}'): - self._writer.write_line('auto value = %s;' % (expression)) - self._writer.write_line('%s(value);' % (_get_field_member_validator_name(field))) - self._writer.write_line('%s = std::move(value);' % (field_name)) + with self._block("{", "}"): + self._writer.write_line("auto value = %s;" % (expression)) + self._writer.write_line( + "%s(value);" % (_get_field_member_validator_name(field)) + ) + self._writer.write_line("%s = std::move(value);" % (field_name)) if field.chained: # Do not generate a predicate check since we always call these deserializers. @@ -1760,10 +2076,11 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field_type.is_struct: # Do not generate a new parser context, reuse the current one since we are not # entering a nested document. - expression = '%s::parse(ctxt, %s)' % (field_type.cpp_type, bson_object) + expression = "%s::parse(ctxt, %s)" % (field_type.cpp_type, bson_object) else: method_name = writer.get_method_name_from_qualified_method_name( - field_type.deserializer) + field_type.deserializer + ) expression = "%s(%s)" % (method_name, bson_object) self._gen_usage_check(field, bson_element, field_usage_check) @@ -1772,33 +2089,42 @@ class _CppSourceFileWriter(_CppFileWriterBase): else: predicate = None if check_type: - predicate = _get_bson_type_check(bson_element, 'ctxt', field_type) + predicate = _get_bson_type_check(bson_element, "ctxt", field_type) with self._predicate(predicate): - self._gen_usage_check(field, bson_element, field_usage_check) object_value = self._gen_field_deserializer_expression( - bson_element, field, field_type, tenant, is_catalog_ctxt) + bson_element, field, field_type, tenant, is_catalog_ctxt + ) if field.chained_struct_field: if field.optional: # We must invoke the boost::optional constructor when setting optional view # types cpp_type_info = cpp_types.get_cpp_type(field) - object_value = '%s(%s)' % (cpp_type_info.get_getter_setter_type(), - object_value) + object_value = "%s(%s)" % ( + cpp_type_info.get_getter_setter_type(), + object_value, + ) # No need for explicit validation as setter will throw for us. self._writer.write_line( - '%s.%s(%s);' % (_get_field_member_name(field.chained_struct_field), - _get_field_member_setter_name(field), object_value)) + "%s.%s(%s);" + % ( + _get_field_member_name(field.chained_struct_field), + _get_field_member_setter_name(field), + object_value, + ) + ) else: validate_and_assign_or_uassert(field, object_value) if is_command_field and predicate: - with self._block('else {', '}'): + with self._block("else {", "}"): self._writer.write_line( - 'ctxt.throwMissingField(%s);' % (_get_field_constant_name(field))) + "ctxt.throwMissingField(%s);" + % (_get_field_constant_name(field)) + ) def gen_doc_sequence_deserializer(self, field, tenant): # type: (ast.Field, str) -> None @@ -1808,28 +2134,37 @@ class _CppSourceFileWriter(_CppFileWriterBase): # If field (cpp_type) is the same type as sequence.objs, just copy and skip loop if cpp_type == "mongo::BSONObj" and not field.type.deserializer: - self._writer.write_line('%s = sequence.objs;' % (_get_field_member_name(field))) + self._writer.write_line( + "%s = sequence.objs;" % (_get_field_member_name(field)) + ) return - self._writer.write_line('std::vector<%s> values;' % (cpp_type)) - self._writer.write_line('values.reserve(sequence.objs.size());') + self._writer.write_line("std::vector<%s> values;" % (cpp_type)) + self._writer.write_line("values.reserve(sequence.objs.size());") self._writer.write_empty_line() # TODO: add support for sequence length checks, today we allow an empty document sequence # because we do not give a way for IDL specifications to specify if they allow empty # sequences or require non-empty sequences. - with self._block('for (const BSONObj& sequenceObject : sequence.objs) {', '}'): - + with self._block("for (const BSONObj& sequenceObject : sequence.objs) {", "}"): # Either we are deserializing BSON Objects or IDL structs if field.type.is_struct: - validated_tenancy_scope = 'ctxt.getValidatedTenancyScope()' - if 'request' in tenant: - validated_tenancy_scope = 'request.validatedTenancyScope' - self._writer.write_line('IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);' % - (_get_field_constant_name(field), validated_tenancy_scope, - "getSerializationContext()", tenant)) - array_value = '%s::parse(tempContext, sequenceObject)' % (field.type.cpp_type, ) + validated_tenancy_scope = "ctxt.getValidatedTenancyScope()" + if "request" in tenant: + validated_tenancy_scope = "request.validatedTenancyScope" + self._writer.write_line( + "IDLParserContext tempContext(%s, &ctxt, %s, %s, %s);" + % ( + _get_field_constant_name(field), + validated_tenancy_scope, + "getSerializationContext()", + tenant, + ) + ) + array_value = "%s::parse(tempContext, sequenceObject)" % ( + field.type.cpp_type, + ) elif field.type.is_variant: self._writer.write_line("%s _tmp;" % field.type.cpp_type) self._gen_variant_deserializer_from_obj( @@ -1842,15 +2177,17 @@ class _CppSourceFileWriter(_CppFileWriterBase): array_value = "_tmp" else: for serialization_type in field.type.bson_serialization_type: - assert serialization_type == 'object' + assert serialization_type == "object" if field.type.deserializer: - array_value = '%s(sequenceObject)' % (field.type.deserializer) + array_value = "%s(sequenceObject)" % (field.type.deserializer) else: array_value = "sequenceObject" - self._writer.write_line('values.emplace_back(%s);' % (array_value)) + self._writer.write_line("values.emplace_back(%s);" % (array_value)) - self._writer.write_line('%s = std::move(values);' % (_get_field_member_name(field))) + self._writer.write_line( + "%s = std::move(values);" % (_get_field_member_name(field)) + ) def gen_op_msg_request_namespace_check(self, struct): # type: (ast.Struct) -> None @@ -1862,10 +2199,10 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Get the Command element if we need it for later in the deserializer to get the # namespace if struct.namespace != common.COMMAND_NAMESPACE_IGNORED: - self._writer.write_line('commandElement = element;') + self._writer.write_line("commandElement = element;") - self._writer.write_line('firstFieldFound = true;') - self._writer.write_line('continue;') + self._writer.write_line("firstFieldFound = true;") + self._writer.write_line("continue;") def _gen_initializer_vars(self, constructor, is_command, is_catalog_ctxt): # type: (struct_types.MethodInfo, bool, bool) -> List[str] @@ -1895,20 +2232,20 @@ class _CppSourceFileWriter(_CppFileWriterBase): # init list is subject to the order in the hardcoded list passed to the constructor. # For now, _serializationContext is the only internal_only type, so additional work # around this will be deferred. - if arg.name == 'serializationContext': + if arg.name == "serializationContext": if is_command: - sc_conditional = 'SerializationContext::stateCommandRequest()' + sc_conditional = "SerializationContext::stateCommandRequest()" else: - sc_conditional = '_isCommandReply ? SerializationContext::stateCommandReply() : SerializationContext::stateDefault()' + sc_conditional = "_isCommandReply ? SerializationContext::stateCommandReply() : SerializationContext::stateDefault()" # this obj is passed in as a boost::optional, so we set the default if no value - initializer_var = arg.name + '.value_or(%s)' % sc_conditional + initializer_var = arg.name + ".value_or(%s)" % sc_conditional # local _serializationContext obj used to init other structs, so it needs to be # initialized first; don't move in the event a boost::none is supplied - initializer_vars.insert(0, '_%s(%s)' % (arg.name, initializer_var)) + initializer_vars.insert(0, "_%s(%s)" % (arg.name, initializer_var)) else: - initializer_vars.append('_%s(std::move(%s))' % (arg.name, arg.name)) + initializer_vars.append("_%s(std::move(%s))" % (arg.name, arg.name)) return initializer_vars @@ -1916,16 +2253,21 @@ class _CppSourceFileWriter(_CppFileWriterBase): # type: (ast.Struct, struct_types.MethodInfo, bool) -> None """Generate the C++ constructor definition.""" - initializers = self._gen_initializer_vars(constructor, isinstance(struct, ast.Command), - struct.is_catalog_ctxt) + initializers = self._gen_initializer_vars( + constructor, isinstance(struct, ast.Command), struct.is_catalog_ctxt + ) # Serialize non-has fields first # Initialize int and other primitive fields to -1 to prevent Coverity warnings. if default_init: for field in struct.fields: - needs_init = (field.type and field.type.cpp_type and not field.type.is_array - and _is_required_serializer_field(field) - and field.cpp_name != 'dbName') + needs_init = ( + field.type + and field.type.cpp_type + and not field.type.is_array + and _is_required_serializer_field(field) + and field.cpp_name != "dbName" + ) if needs_init: # As per _gen_initializer_vars(), we initialize _serializationContext first, # before anything else in the initializer list because it can be consumed by @@ -1933,36 +2275,54 @@ class _CppSourceFileWriter(_CppFileWriterBase): # If the current field is a nested struct, we need to pass the initialized # _serializationContext into the nested struct. - serialization_ctx_arg = '_serializationContext' if field.type and field.type.is_struct else '' + serialization_ctx_arg = ( + "_serializationContext" + if field.type and field.type.is_struct + else "" + ) - initializers.append('%s(mongo::idl::preparsedValue(%s))' - % (_get_field_member_name(field), - _get_field_member_name(field), serialization_ctx_arg)) + initializers.append( + "%s(mongo::idl::preparsedValue(%s))" + % ( + _get_field_member_name(field), + _get_field_member_name(field), + serialization_ctx_arg, + ) + ) # Serialize the _dbName field second. # Use the class member to perform the initialization instead of the constructor # argument, since we're guaranteed to have already written the initializer for the member earlier # in the list via _gen_initializer_vars above. initializes_db_name = False - if [arg for arg in constructor.args if arg.name == 'nss']: - if [field for field in struct.fields if field.serialize_op_msg_request_only]: - initializers.append('_dbName(_nss.dbName())') + if [arg for arg in constructor.args if arg.name == "nss"]: + if [ + field for field in struct.fields if field.serialize_op_msg_request_only + ]: + initializers.append("_dbName(_nss.dbName())") initializes_db_name = True - elif [arg for arg in constructor.args if arg.name == 'nssOrUUID']: - if [field for field in struct.fields if field.serialize_op_msg_request_only]: - initializers.append('_dbName(_nssOrUUID.dbName())') + elif [arg for arg in constructor.args if arg.name == "nssOrUUID"]: + if [ + field for field in struct.fields if field.serialize_op_msg_request_only + ]: + initializers.append("_dbName(_nssOrUUID.dbName())") initializes_db_name = True - initializers_str = '' + initializers_str = "" if initializers: - initializers_str = ': ' + ', '.join(initializers) + initializers_str = ": " + ", ".join(initializers) - with self._block('%s %s {' % (constructor.get_definition(), initializers_str), '}'): + with self._block( + "%s %s {" % (constructor.get_definition(), initializers_str), "}" + ): for field in _get_required_fields(struct): - if not (field.name == "$db" and initializes_db_name) and not default_init: + if ( + not (field.name == "$db" and initializes_db_name) + and not default_init + ): self._writer.write_line(_gen_mark_present(field.cpp_name)) if initializes_db_name: - self._writer.write_line(_gen_mark_present('dbName')) + self._writer.write_line(_gen_mark_present("dbName")) self._writer.write_empty_line() def gen_constructors(self, struct): @@ -1984,17 +2344,17 @@ class _CppSourceFileWriter(_CppFileWriterBase): """Generate the definitions for generic argument or reply field lookup methods.""" field_list_info = generic_field_list_types.get_field_list_info(struct) has_field_defn = field_list_info.get_has_field_method().get_definition() - with self._block(f'{has_field_defn} {{', '}'): + with self._block(f"{has_field_defn} {{", "}"): self._writer.write_line( - 'return FieldInfo::findStructField(fieldName, [](auto) { return true; }, []{ return false; });' + "return FieldInfo::findStructField(fieldName, [](auto) { return true; }, []{ return false; });" ) self._writer.write_empty_line() should_fwd_defn = field_list_info.get_should_forward_method().get_definition() - with self._block(f'{should_fwd_defn} {{', '}'): + with self._block(f"{should_fwd_defn} {{", "}"): self._writer.write_line( - 'return FieldInfo::findForwardingDisabledField(fieldName, [](auto) { return false; }, []{ return true; });' + "return FieldInfo::findForwardingDisabledField(fieldName, [](auto) { return false; }, []{ return true; });" ) self._writer.write_empty_line() @@ -2004,16 +2364,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): """Generate the command field deserializer.""" if isinstance(struct, ast.Command) and struct.command_field: - with self._block('{', '}'): - self.gen_field_deserializer(struct.command_field, struct.command_field.type, - bson_object, "commandElement", None, tenant, - is_command_field=True, check_type=True, - is_catalog_ctxt=struct.is_catalog_ctxt) + with self._block("{", "}"): + self.gen_field_deserializer( + struct.command_field, + struct.command_field.type, + bson_object, + "commandElement", + None, + tenant, + is_command_field=True, + check_type=True, + is_catalog_ctxt=struct.is_catalog_ctxt, + ) else: struct_type_info = struct_types.get_struct_info(struct) # Generate namespace check now that "$db" has been read or defaulted - struct_type_info.gen_namespace_check(self._writer, "_dbName", "commandElement") + struct_type_info.gen_namespace_check( + self._writer, "_dbName", "commandElement" + ) def _gen_fields_deserializer_common(self, struct, bson_object, tenant): # type: (ast.Struct, str, str) -> _FieldUsageCheckerBase @@ -2022,35 +2391,42 @@ class _CppSourceFileWriter(_CppFileWriterBase): field_usage_check = _get_field_usage_checker(self._writer, struct) if isinstance(struct, ast.Command): if struct.namespace != common.COMMAND_NAMESPACE_IGNORED: - self._writer.write_line('BSONElement commandElement;') - self._writer.write_line('bool firstFieldFound = false;') + self._writer.write_line("BSONElement commandElement;") + self._writer.write_line("bool firstFieldFound = false;") self._writer.write_empty_line() # Update the serialization context whether or not we received a tenantId object - if tenant == 'request.getValidatedTenantId()': + if tenant == "request.getValidatedTenantId()": # inject a context into the IDLParserContext that tags the class as a command request self._writer.write_line( - 'setSerializationContext(SerializationContext::stateCommandRequest());') + "setSerializationContext(SerializationContext::stateCommandRequest());" + ) with self._block( - 'if (request.validatedTenancyScope != boost::none && request.validatedTenancyScope->isFromAtlasProxy()) {', - '}'): - self._writer.write_line('_serializationContext.setPrefixState(true);') + "if (request.validatedTenancyScope != boost::none && request.validatedTenancyScope->isFromAtlasProxy()) {", + "}", + ): + self._writer.write_line( + "_serializationContext.setPrefixState(true);" + ) else: # if a non-default serialization context was passed in via the IDLParserContext, # use that to set the local serialization context, otherwise set it to a command # request with self._block( - 'if (ctxt.getSerializationContext() != SerializationContext::stateDefault()) {', - '}'): + "if (ctxt.getSerializationContext() != SerializationContext::stateDefault()) {", + "}", + ): self._writer.write_line( - 'setSerializationContext(ctxt.getSerializationContext());') - with self._block('else {', '}'): + "setSerializationContext(ctxt.getSerializationContext());" + ) + with self._block("else {", "}"): self._writer.write_line( - 'setSerializationContext(SerializationContext::stateCommandRequest());') + "setSerializationContext(SerializationContext::stateCommandRequest());" + ) else: self._writer.write_line( - 'setSerializationContext(_isCommandReply ? SerializationContext::stateCommandReply() : ctxt.getSerializationContext());' + "setSerializationContext(_isCommandReply ? SerializationContext::stateCommandReply() : ctxt.getSerializationContext());" ) self._writer.write_empty_line() @@ -2062,16 +2438,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.ignore: field_usage_check.add(field, "element") - self._writer.write_line('// ignore field') + self._writer.write_line("// ignore field") else: - self.gen_field_deserializer(field, field.type, bson_object, "element", - field_usage_check, tenant, False, True, - struct.is_catalog_ctxt) - self._writer.write_line('return true;') + self.gen_field_deserializer( + field, + field.type, + bson_object, + "element", + field_usage_check, + tenant, + False, + True, + struct.is_catalog_ctxt, + ) + self._writer.write_line("return true;") - with self._block('for (const auto& element :%s) {' % (bson_object), '}'): - - self._writer.write_line('const auto fieldName = element.fieldNameStringData();') + with self._block("for (const auto& element :%s) {" % (bson_object), "}"): + self._writer.write_line( + "const auto fieldName = element.fieldNameStringData();" + ) self._writer.write_empty_line() if isinstance(struct, ast.Command): @@ -2079,24 +2464,27 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Get the Command element if we need it for later in the deserializer to get the # namespace if struct.namespace != common.COMMAND_NAMESPACE_IGNORED: - self._writer.write_line('commandElement = element;') + self._writer.write_line("commandElement = element;") - self._writer.write_line('firstFieldFound = true;') - self._writer.write_line('continue;') + self._writer.write_line("firstFieldFound = true;") + self._writer.write_line("continue;") parsed_fields = [f for f in struct.fields if _is_parse(f)] if parsed_fields: - with self._block('auto onMatch = [&](Field f) {', '};'): - with self._block('switch (f) {', '}'): + with self._block("auto onMatch = [&](Field f) {", "};"): + with self._block("switch (f) {", "}"): for f in parsed_fields: - with self._block(f'case {_get_field_enum(f)}: {{', '} break;'): + with self._block( + f"case {_get_field_enum(f)}: {{", "} break;" + ): map_field(f.name) - self._writer.write_line('default: return false;') - self._writer.write_line('return false;') - with self._block('auto onFail = [] {', '};'): - self._writer.write_line('return false;') + self._writer.write_line("default: return false;") + self._writer.write_line("return false;") + with self._block("auto onFail = [] {", "};"): + self._writer.write_line("return false;") self._writer.write_line( - 'if (FieldInfo::findParsedField(fieldName, onMatch, onFail)) continue;') + "if (FieldInfo::findParsedField(fieldName, onMatch, onFail)) continue;" + ) # End of for fields # Generate strict check for extranous fields @@ -2112,73 +2500,104 @@ class _CppSourceFileWriter(_CppFileWriterBase): command_predicate = "!mongo::isGenericReply(fieldName)" with self._predicate(command_predicate): - self._writer.write_line('ctxt.throwUnknownField(fieldName);') + self._writer.write_line("ctxt.throwUnknownField(fieldName);") elif not struct.unsafe_dangerous_disable_extra_field_duplicate_checks: - self._writer.write_line('auto push_result = usedFieldSet.insert(fieldName);') + self._writer.write_line( + "auto push_result = usedFieldSet.insert(fieldName);" + ) with writer.IndentedScopedBlock( - self._writer, 'if (MONGO_unlikely(push_result.second == false)) {', '}'): - self._writer.write_line('ctxt.throwDuplicateField(fieldName);') + self._writer, + "if (MONGO_unlikely(push_result.second == false)) {", + "}", + ): + self._writer.write_line("ctxt.throwDuplicateField(fieldName);") # Parse chained structs if not inlined # Parse chained types always here for field in struct.fields: - if not field.chained or \ - (field.chained and field.type.is_struct and struct.inline_chained_structs): + if not field.chained or ( + field.chained and field.type.is_struct and struct.inline_chained_structs + ): continue # Simply generate deserializers since these are all 'any' types - self.gen_field_deserializer(field, field.type, bson_object, "element", None, tenant, - False, True, struct.is_catalog_ctxt) + self.gen_field_deserializer( + field, + field.type, + bson_object, + "element", + None, + tenant, + False, + True, + struct.is_catalog_ctxt, + ) self._writer.write_empty_line() self._writer.write_empty_line() return field_usage_check - def get_bson_deserializer_static_common(self, struct, static_method_info, method_info, - ownership): + def get_bson_deserializer_static_common( + self, struct, static_method_info, method_info, ownership + ): # type: (ast.Struct, struct_types.MethodInfo, struct_types.MethodInfo, _StructDataOwnership) -> None """Generate the C++ deserializer static method.""" func_def = static_method_info.get_definition() - with self._block('%s {' % (func_def), '}'): - if isinstance(struct, - ast.Command) and struct.namespace != common.COMMAND_NAMESPACE_IGNORED: + with self._block("%s {" % (func_def), "}"): + if ( + isinstance(struct, ast.Command) + and struct.namespace != common.COMMAND_NAMESPACE_IGNORED + ): if struct.namespace == common.COMMAND_NAMESPACE_TYPE: cpp_type_info = cpp_types.get_cpp_type(struct.command_field) - if struct.command_field.type.cpp_type and cpp_types.is_primitive_scalar_type( - struct.command_field.type.cpp_type): + if ( + struct.command_field.type.cpp_type + and cpp_types.is_primitive_scalar_type( + struct.command_field.type.cpp_type + ) + ): self._writer.write_line( - 'auto localCmdType = mongo::idl::preparsedValue<%s>();' % - (cpp_type_info.get_storage_type())) + "auto localCmdType = mongo::idl::preparsedValue<%s>();" + % (cpp_type_info.get_storage_type()) + ) else: self._writer.write_line( - 'auto localCmdType = mongo::idl::preparsedValue<%s>();' % - (cpp_type_info.get_storage_type())) + "auto localCmdType = mongo::idl::preparsedValue<%s>();" + % (cpp_type_info.get_storage_type()) + ) self._writer.write_line( - '%s object(localCmdType);' % (common.title_case(struct.cpp_name))) - elif struct.namespace in (common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB, - common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID): - self._writer.write_line('NamespaceString localNS;') + "%s object(localCmdType);" + % (common.title_case(struct.cpp_name)) + ) + elif struct.namespace in ( + common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB, + common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID, + ): + self._writer.write_line("NamespaceString localNS;") self._writer.write_line( - '%s object(localNS);' % (common.title_case(struct.cpp_name))) + "%s object(localNS);" % (common.title_case(struct.cpp_name)) + ) else: assert False, "Missing case" else: - self._writer.write_line('auto object = mongo::idl::preparsedValue<%s>();' % - common.title_case(struct.cpp_name)) + self._writer.write_line( + "auto object = mongo::idl::preparsedValue<%s>();" + % common.title_case(struct.cpp_name) + ) - self._writer.write_line(method_info.get_call('object')) + self._writer.write_line(method_info.get_call("object")) if struct.is_view: if ownership == _StructDataOwnership.OWNER: - self._writer.write_line('object.setAnchor(std::move(bsonObject));') + self._writer.write_line("object.setAnchor(std::move(bsonObject));") elif ownership == _StructDataOwnership.SHARED: - self._writer.write_line('object.setAnchor(bsonObject);') + self._writer.write_line("object.setAnchor(bsonObject);") - self._writer.write_line('return object;') + self._writer.write_line("return object;") self.write_empty_line() @@ -2189,10 +2608,14 @@ class _CppSourceFileWriter(_CppFileWriterBase): param_type = cpp_type_info.get_storage_type() with self._block("{", "}"): - self._writer.write_line(f"static const {param_type} rhs{{{_get_expression(limit)}}};") + self._writer.write_line( + f"static const {param_type} rhs{{{_get_expression(limit)}}};" + ) with self._block("if (!(value %s rhs)) {" % (op), "}"): - self._writer.write_line('throwComparisonError<%s>(%s"%s", "%s"_sd, value, rhs);' % - (field.type.cpp_type, optional_param, field.name, op)) + self._writer.write_line( + 'throwComparisonError<%s>(%s"%s", "%s"_sd, value, rhs);' + % (field.type.cpp_type, optional_param, field.name, op) + ) def _gen_field_validator(self, struct, field, optional_params): # type: (ast.Struct, ast.Field, Tuple[str, str]) -> None @@ -2203,30 +2626,41 @@ class _CppSourceFileWriter(_CppFileWriterBase): param_type = cpp_type_info.get_storage_type() if not cpp_types.is_primitive_type(param_type): - param_type += '&' + param_type += "&" method_template = { - 'class_name': common.title_case(struct.cpp_name), - 'method_name': _get_field_member_validator_name(field), - 'param_type': param_type, - 'optional_param': optional_params[0], + "class_name": common.title_case(struct.cpp_name), + "method_name": _get_field_member_validator_name(field), + "param_type": param_type, + "optional_param": optional_params[0], } with self._with_template(method_template): self._writer.write_template( - 'void ${class_name}::${method_name}(${optional_param}const ${param_type} value)') - with self._block('{', '}'): + "void ${class_name}::${method_name}(${optional_param}const ${param_type} value)" + ) + with self._block("{", "}"): if validator.gt is not None: - self._compare_and_return_status('>', validator.gt, field, optional_params[1]) + self._compare_and_return_status( + ">", validator.gt, field, optional_params[1] + ) if validator.gte is not None: - self._compare_and_return_status('>=', validator.gte, field, optional_params[1]) + self._compare_and_return_status( + ">=", validator.gte, field, optional_params[1] + ) if validator.lt is not None: - self._compare_and_return_status('<', validator.lt, field, optional_params[1]) + self._compare_and_return_status( + "<", validator.lt, field, optional_params[1] + ) if validator.lte is not None: - self._compare_and_return_status('<=', validator.lte, field, optional_params[1]) + self._compare_and_return_status( + "<=", validator.lte, field, optional_params[1] + ) if validator.callback is not None: - self._writer.write_line('uassertStatusOK(%s(value));' % (validator.callback)) + self._writer.write_line( + "uassertStatusOK(%s(value));" % (validator.callback) + ) self._writer.write_empty_line() @@ -2238,7 +2672,7 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Fields without validators are implemented in the header. continue - for optional_params in [('IDLParserContext& ctxt, ', 'ctxt, '), ('', '')]: + for optional_params in [("IDLParserContext& ctxt, ", "ctxt, "), ("", "")]: self._gen_field_validator(struct, field, optional_params) def gen_bson_deserializer_methods(self, struct): @@ -2247,32 +2681,44 @@ class _CppSourceFileWriter(_CppFileWriterBase): struct_type_info = struct_types.get_struct_info(struct) method = struct_type_info.get_deserializer_method() - self.get_bson_deserializer_static_common(struct, - struct_type_info.get_deserializer_static_method(), - method, _StructDataOwnership.VIEW) self.get_bson_deserializer_static_common( - struct, struct_type_info.get_sharing_deserializer_static_method(), method, - _StructDataOwnership.SHARED) + struct, + struct_type_info.get_deserializer_static_method(), + method, + _StructDataOwnership.VIEW, + ) self.get_bson_deserializer_static_common( - struct, struct_type_info.get_owned_deserializer_static_method(), method, - _StructDataOwnership.OWNER) + struct, + struct_type_info.get_sharing_deserializer_static_method(), + method, + _StructDataOwnership.SHARED, + ) + self.get_bson_deserializer_static_common( + struct, + struct_type_info.get_owned_deserializer_static_method(), + method, + _StructDataOwnership.OWNER, + ) func_def = method.get_definition() # Name of the variable that we are deserialzing from variable_name = "bsonObject" - with self._block('%s {' % (func_def), '}'): + with self._block("%s {" % (func_def), "}"): # If the struct contains no fields, there's nothing to deserialize, so we write an empty function stub. if not struct.fields: return # if the only field is an internal only field, there's also nothing to deserialize - if len(struct.fields) == 1 and any(field.type.internal_only for field in struct.fields): + if len(struct.fields) == 1 and any( + field.type.internal_only for field in struct.fields + ): return # Deserialize all the fields - field_usage_check = self._gen_fields_deserializer_common(struct, variable_name, - "ctxt.getTenantId()") + field_usage_check = self._gen_fields_deserializer_common( + struct, variable_name, "ctxt.getTenantId()" + ) # Check for required fields field_usage_check.add_final_checks() @@ -2294,21 +2740,27 @@ class _CppSourceFileWriter(_CppFileWriterBase): struct_type_info = struct_types.get_struct_info(struct) self.get_bson_deserializer_static_common( - struct, struct_type_info.get_op_msg_request_deserializer_static_method(), - struct_type_info.get_op_msg_request_deserializer_method(), _StructDataOwnership.VIEW) - - func_def = struct_type_info.get_op_msg_request_deserializer_method().get_definition() - with self._block('%s {' % (func_def), '}'): + struct, + struct_type_info.get_op_msg_request_deserializer_static_method(), + struct_type_info.get_op_msg_request_deserializer_method(), + _StructDataOwnership.VIEW, + ) + func_def = ( + struct_type_info.get_op_msg_request_deserializer_method().get_definition() + ) + with self._block("%s {" % (func_def), "}"): # Deserialize all the fields field_usage_check = self._gen_fields_deserializer_common( - struct, "request.body", "request.getValidatedTenantId()") + struct, "request.body", "request.getValidatedTenantId()" + ) # Iterate through the document sequences if we have any has_doc_sequence = len( - [field for field in struct.fields if field.supports_doc_sequence]) + [field for field in struct.fields if field.supports_doc_sequence] + ) if has_doc_sequence: - with self._block('for (auto&& sequence : request.sequences) {', '}'): + with self._block("for (auto&& sequence : request.sequences) {", "}"): field_usage_check.add_store("sequence.name") self._writer.write_empty_line() @@ -2318,32 +2770,44 @@ class _CppSourceFileWriter(_CppFileWriterBase): if not field.supports_doc_sequence: continue - field_predicate = 'sequence.name == %s' % (_get_field_constant_name(field)) + field_predicate = "sequence.name == %s" % ( + _get_field_constant_name(field) + ) with self._predicate(field_predicate, not first_field): field_usage_check.add(field, "sequence.name") if _is_required_serializer_field(field): - self._writer.write_line(_gen_mark_present(field.cpp_name)) + self._writer.write_line( + _gen_mark_present(field.cpp_name) + ) - self.gen_doc_sequence_deserializer(field, - "request.getValidatedTenantId()") + self.gen_doc_sequence_deserializer( + field, "request.getValidatedTenantId()" + ) if first_field: first_field = False # End of for fields # Generate strict check for extranous fields - with self._block('else {', '}'): + with self._block("else {", "}"): if struct.strict: - self._writer.write_line('ctxt.throwUnknownField(sequence.name);') + self._writer.write_line( + "ctxt.throwUnknownField(sequence.name);" + ) else: self._writer.write_line( - 'auto push_result = usedFieldSet.insert(sequence.name);') + "auto push_result = usedFieldSet.insert(sequence.name);" + ) with writer.IndentedScopedBlock( - self._writer, - 'if (MONGO_unlikely(push_result.second == false)) {', '}'): - self._writer.write_line('ctxt.throwDuplicateField(sequence.name);') + self._writer, + "if (MONGO_unlikely(push_result.second == false)) {", + "}", + ): + self._writer.write_line( + "ctxt.throwDuplicateField(sequence.name);" + ) self._writer.write_empty_line() @@ -2351,7 +2815,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): field_usage_check.add_final_checks() self._writer.write_empty_line() - self._gen_command_deserializer(struct, "request.body", "request.getValidatedTenantId()") + self._gen_command_deserializer( + struct, "request.body", "request.getValidatedTenantId()" + ) def _gen_serializer_method_custom(self, field, is_catalog_ctxt): # type: (ast.Field, bool) -> None @@ -2359,10 +2825,10 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Generate custom serialization template_params = { - 'field_name': _get_field_constant_name(field), - 'access_member': _access_member(field), - 'serialization_context': '_serializationContext', - 'serialization_options': 'options', + "field_name": _get_field_constant_name(field), + "access_member": _access_member(field), + "serialization_context": "_serializationContext", + "serialization_options": "options", } with self._with_template(template_params): @@ -2373,44 +2839,56 @@ class _CppSourceFileWriter(_CppFileWriterBase): if bson_cpp_type and bson_cpp_type.has_serializer(): if field.type.is_array: self._writer.write_template( - 'BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));') - with self._block('for (const auto& item : ${access_member}) {', '}'): + "BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));" + ) + with self._block( + "for (const auto& item : ${access_member}) {", "}" + ): expression = bson_cpp_type.gen_serializer_expression( - self._writer, 'item', - field.query_shape == ast.QueryShapeFieldType.CUSTOM, is_catalog_ctxt) - template_params['expression'] = expression - self._writer.write_template('arrayBuilder.append(${expression});') + self._writer, + "item", + field.query_shape == ast.QueryShapeFieldType.CUSTOM, + is_catalog_ctxt, + ) + template_params["expression"] = expression + self._writer.write_template( + "arrayBuilder.append(${expression});" + ) else: expression = bson_cpp_type.gen_serializer_expression( - self._writer, _access_member(field), - field.query_shape == ast.QueryShapeFieldType.CUSTOM, is_catalog_ctxt) - template_params['expression'] = expression + self._writer, + _access_member(field), + field.query_shape == ast.QueryShapeFieldType.CUSTOM, + is_catalog_ctxt, + ) + template_params["expression"] = expression if not field.should_serialize_with_options: self._writer.write_template( - 'builder->append(${field_name}, ${expression});') + "builder->append(${field_name}, ${expression});" + ) elif field.query_shape == ast.QueryShapeFieldType.LITERAL: self._writer.write_template( - 'options.serializeLiteral(${expression}).serializeForIDL(${field_name}, builder);' + "options.serializeLiteral(${expression}).serializeForIDL(${field_name}, builder);" ) elif field.query_shape == ast.QueryShapeFieldType.ANONYMIZE: self._writer.write_template( - 'builder->append(${field_name}, options.serializeFieldPathFromString(${expression}));' + "builder->append(${field_name}, options.serializeFieldPathFromString(${expression}));" ) else: assert False - elif field.type.bson_serialization_type[0] == 'any': + elif field.type.bson_serialization_type[0] == "any": def maybe_add_serialization_context(*args): """Append 'SerializationContext' arg if needed.""" if field.type.deserialize_with_tenant: - args += ('${serialization_context}', ) + args += ("${serialization_context}",) return args def maybe_add_serialization_options(*args): """Append 'SerializationOptions' arg if needed.""" if field.query_shape == ast.QueryShapeFieldType.CUSTOM: - args += ('${serialization_options}', ) + args += ("${serialization_options}",) return args def generate_args_template(*args): @@ -2423,7 +2901,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): args = maybe_add_serialization_options(*args) if writer.is_function(serializer): # It should be invoked as 'function(subject, ...args);' - return f"{serializer}({generate_args_template(subject, *args)});" + return ( + f"{serializer}({generate_args_template(subject, *args)});" + ) else: # It should be invoked as 'subject.method(...args);' truncated_serializer = writer.get_method_name(serializer) @@ -2432,150 +2912,195 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.type.is_array: # Array variants - we pass an array builder self._writer.write_template( - 'BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));') - with self._block('for (const auto& item : ${access_member}) {', '}'): + "BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));" + ) + with self._block( + "for (const auto& item : ${access_member}) {", "}" + ): template = generate_call_site( field.type.serializer, - 'item', - '&arrayBuilder', + "item", + "&arrayBuilder", ) self._writer.write_template(template) else: # Non-array variants - we pass the field name they should use, and a BSONObjBuilder. - template = generate_call_site(field.type.serializer, '${access_member}', - '${field_name}', 'builder') + template = generate_call_site( + field.type.serializer, + "${access_member}", + "${field_name}", + "builder", + ) self._writer.write_template(template) else: method_name = writer.get_method_name(field.type.serializer) - template_params['method_name'] = method_name + template_params["method_name"] = method_name if field.chained: # Just directly call the serializer for chained structs without opening up a # nested document. - self._writer.write_template('${access_member}.${method_name}(builder);') + self._writer.write_template( + "${access_member}.${method_name}(builder);" + ) elif field.type.is_array: self._writer.write_template( - 'BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));') - with self._block('for (const auto& item : ${access_member}) {', '}'): + "BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));" + ) + with self._block( + "for (const auto& item : ${access_member}) {", "}" + ): self._writer.write_line( - 'BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());') - self._writer.write_template('item.${method_name}(&subObjBuilder);') + "BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());" + ) + self._writer.write_template( + "item.${method_name}(&subObjBuilder);" + ) else: self._writer.write_template( - '${access_member}.${method_name}(${field_name}, builder);') + "${access_member}.${method_name}(${field_name}, builder);" + ) def _gen_serializer_method_struct(self, field): # type: (ast.Field) -> None """Generate the serialize method definition for a struct type.""" template_params = { - 'field_name': _get_field_constant_name(field), - 'access_member': _access_member(field), + "field_name": _get_field_constant_name(field), + "access_member": _access_member(field), } with self._with_template(template_params): - if field.chained: # Just directly call the serializer for chained structs without opening up a nested # document. if not field.should_serialize_with_options: - self._writer.write_template('${access_member}.serialize(builder);') + self._writer.write_template("${access_member}.serialize(builder);") else: - self._writer.write_template('${access_member}.serialize(builder, options);') + self._writer.write_template( + "${access_member}.serialize(builder, options);" + ) elif field.type.is_array: self._writer.write_template( - 'BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));') - with self._block('for (const auto& item : ${access_member}) {', '}'): + "BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));" + ) + with self._block("for (const auto& item : ${access_member}) {", "}"): self._writer.write_line( - 'BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());') + "BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());" + ) if not field.should_serialize_with_options: - self._writer.write_line('item.serialize(&subObjBuilder);') + self._writer.write_line("item.serialize(&subObjBuilder);") else: - self._writer.write_line('item.serialize(&subObjBuilder, options);') + self._writer.write_line( + "item.serialize(&subObjBuilder, options);" + ) else: self._writer.write_template( - 'BSONObjBuilder subObjBuilder(builder->subobjStart(${field_name}));') + "BSONObjBuilder subObjBuilder(builder->subobjStart(${field_name}));" + ) if not field.should_serialize_with_options: - self._writer.write_template('${access_member}.serialize(&subObjBuilder);') + self._writer.write_template( + "${access_member}.serialize(&subObjBuilder);" + ) else: self._writer.write_template( - '${access_member}.serialize(&subObjBuilder, options);') + "${access_member}.serialize(&subObjBuilder, options);" + ) def _gen_serializer_method_array_variant(self, field): template_params = { - 'field_name': _get_field_constant_name(field), - 'access_member': 'item', + "field_name": _get_field_constant_name(field), + "access_member": "item", } with self._with_template(template_params): self._writer.write_template( - 'BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));') - with self._block('for (const auto& item : %s) {' % _access_member(field), '}'): - self._writer.write_line('BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());') - self._gen_serializer_method_variant_helper(field, template_params, - builder='&subObjBuilder') + "BSONArrayBuilder arrayBuilder(builder->subarrayStart(${field_name}));" + ) + with self._block( + "for (const auto& item : %s) {" % _access_member(field), "}" + ): + self._writer.write_line( + "BSONObjBuilder subObjBuilder(arrayBuilder.subobjStart());" + ) + self._gen_serializer_method_variant_helper( + field, template_params, builder="&subObjBuilder" + ) def _gen_serializer_method_variant(self, field): # type: (ast.Field) -> None """Generate the serialize method definition for a variant type.""" template_params = { - 'field_name': _get_field_constant_name(field), - 'access_member': _access_member(field), + "field_name": _get_field_constant_name(field), + "access_member": _access_member(field), } with self._with_template(template_params): self._gen_serializer_method_variant_helper(field, template_params) - def _gen_serializer_method_variant_helper(self, field, template_params, builder='builder'): + def _gen_serializer_method_variant_helper( + self, field, template_params, builder="builder" + ): # type: (ast.Field, Dict[str, str], str) -> None - with self._block('visit(OverloadedVisitor{', '}, ${access_member});'): + with self._block("visit(OverloadedVisitor{", "}, ${access_member});"): for variant_type in itertools.chain( - field.type.variant_types, - field.type.variant_struct_types if field.type.variant_struct_types else []): + field.type.variant_types, + field.type.variant_struct_types + if field.type.variant_struct_types + else [], + ): + template_params["cpp_type"] = ( + "std::vector<" + variant_type.cpp_type + ">" + if variant_type.is_array + else variant_type.cpp_type + ) - template_params[ - 'cpp_type'] = 'std::vector<' + variant_type.cpp_type + '>' if variant_type.is_array else variant_type.cpp_type - - template_params['param_opt'] = "" + template_params["param_opt"] = "" if field.should_serialize_with_options: - template_params['param_opt'] = ', options' - with self._block('[%s${param_opt}](const ${cpp_type}& value) {' % builder, '},'): + template_params["param_opt"] = ", options" + with self._block( + "[%s${param_opt}](const ${cpp_type}& value) {" % builder, "}," + ): bson_cpp_type = cpp_types.get_bson_cpp_type(variant_type) if field.type.is_variant and field.type.is_array: - self._writer.write_template('value.serialize(%s);' % builder) + self._writer.write_template("value.serialize(%s);" % builder) elif bson_cpp_type and bson_cpp_type.has_serializer(): assert not field.type.is_array expression = bson_cpp_type.gen_serializer_expression( - self._writer, 'value', - field.query_shape == ast.QueryShapeFieldType.CUSTOM, False) - template_params['expression'] = expression + self._writer, + "value", + field.query_shape == ast.QueryShapeFieldType.CUSTOM, + False, + ) + template_params["expression"] = expression if not field.should_serialize_with_options: self._writer.write_template( - 'builder->append(${field_name}, ${expression});') + "builder->append(${field_name}, ${expression});" + ) elif field.query_shape == ast.QueryShapeFieldType.LITERAL: self._writer.write_template( - 'options.serializeLiteral(${expression}).serializeForIDL(${field_name}, builder);' + "options.serializeLiteral(${expression}).serializeForIDL(${field_name}, builder);" ) elif field.query_shape == ast.QueryShapeFieldType.ANONYMIZE: self._writer.write_template( - 'builder->append(${field_name}, options.serializeFieldPathFromString(${expression}));' + "builder->append(${field_name}, options.serializeFieldPathFromString(${expression}));" ) else: assert False else: if not field.should_serialize_with_options: self._writer.write_template( - 'idl::idlSerialize(builder, ${field_name}, value);') + "idl::idlSerialize(builder, ${field_name}, value);" + ) elif field.query_shape == ast.QueryShapeFieldType.LITERAL: self._writer.write_template( - 'options.serializeLiteral(value).serializeForIDL(${field_name}, builder);' + "options.serializeLiteral(value).serializeForIDL(${field_name}, builder);" ) elif field.query_shape == ast.QueryShapeFieldType.ANONYMIZE: self._writer.write_template( - 'idl::idlSerialize(builder, ${field_name}, options.serializeFieldPathFromString(value));' + "idl::idlSerialize(builder, ${field_name}, options.serializeFieldPathFromString(value));" ) else: assert False @@ -2589,17 +3114,18 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Is this a scalar bson C++ type? bson_cpp_type = cpp_types.get_bson_cpp_type(field.type) - needs_custom_serializer = field.type.serializer or (bson_cpp_type - and bson_cpp_type.has_serializer()) + needs_custom_serializer = field.type.serializer or ( + bson_cpp_type and bson_cpp_type.has_serializer() + ) optional_block_start = None if field.optional: - optional_block_start = 'if (%s) {' % (member_name) + optional_block_start = "if (%s) {" % (member_name) elif field.type.is_struct or needs_custom_serializer or field.type.is_array: # Introduce a new scope for required nested object serialization. - optional_block_start = '{' + optional_block_start = "{" - with self._block(optional_block_start, '}'): + with self._block(optional_block_start, "}"): if not field.type.is_struct: if needs_custom_serializer: self._gen_serializer_method_custom(field, is_catalog_ctxt) @@ -2614,21 +3140,27 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Note: BSONObjBuilder::append, which all three branches use, has overrides for std::vector also if not field.should_serialize_with_options: self._writer.write_line( - 'builder->append(%s, %s);' % (_get_field_constant_name(field), - _access_member(field))) + "builder->append(%s, %s);" + % (_get_field_constant_name(field), _access_member(field)) + ) elif field.query_shape == ast.QueryShapeFieldType.LITERAL: # serializeLiteral expects an ImplicitValue, which can't be constructed with an int64_t expression_cast = "" if field.type.cpp_type == "std::int64_t": expression_cast = "(long long)" self._writer.write_line( - 'options.serializeLiteral(%s%s).serializeForIDL(%s, builder);' - % (expression_cast, _access_member(field), - _get_field_constant_name(field))) + "options.serializeLiteral(%s%s).serializeForIDL(%s, builder);" + % ( + expression_cast, + _access_member(field), + _get_field_constant_name(field), + ) + ) elif field.query_shape == ast.QueryShapeFieldType.ANONYMIZE: self._writer.write_line( - 'builder->append(%s, options.serializeFieldPathFromString(%s));' % - (_get_field_constant_name(field), _access_member(field))) + "builder->append(%s, options.serializeFieldPathFromString(%s));" + % (_get_field_constant_name(field), _access_member(field)) + ) else: assert False else: @@ -2637,9 +3169,10 @@ class _CppSourceFileWriter(_CppFileWriterBase): if field.always_serialize: # If using field.always_serialize, field.optional must also be true. Add an else block # that appends null when the optional field is not initialized. - with self._block('else {', '}'): + with self._block("else {", "}"): self._writer.write_line( - 'builder->appendNull(%s);' % (_get_field_constant_name(field))) + "builder->appendNull(%s);" % (_get_field_constant_name(field)) + ) def _gen_serializer_methods_common(self, struct, is_op_msg_request): # type: (ast.Struct, bool) -> None @@ -2649,20 +3182,26 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Check all required fields have been specified required_fields = [ - _get_has_field_member_name(field) for field in struct.fields + _get_has_field_member_name(field) + for field in struct.fields if _is_required_serializer_field(field) ] if required_fields: - self._writer.write_line('_hasMembers.required();') + self._writer.write_line("_hasMembers.required();") self._writer.write_empty_line() # Serialize the namespace as the first field if isinstance(struct, ast.Command): if struct.command_field: # Internal-only types aren't serialized or deserialized. - if not (struct.command_field.type and struct.command_field.type.internal_only): - self._gen_serializer_method_common(struct.command_field, struct.is_catalog_ctxt) + if not ( + struct.command_field.type + and struct.command_field.type.internal_only + ): + self._gen_serializer_method_common( + struct.command_field, struct.is_catalog_ctxt + ) else: struct_type_info = struct_types.get_struct_info(struct) struct_type_info.gen_serializer(self._writer) @@ -2697,10 +3236,13 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Append passthrough elements if isinstance(struct, ast.Command): - known_name = "_knownOP_MSGFields" if is_op_msg_request else "_knownBSONFields" + known_name = ( + "_knownOP_MSGFields" if is_op_msg_request else "_knownBSONFields" + ) self._writer.write_line( - "::mongo::appendGenericCommandArguments(commandPassthroughFields, %s, builder);" % - (known_name)) + "::mongo::appendGenericCommandArguments(commandPassthroughFields, %s, builder);" + % (known_name) + ) self._writer.write_empty_line() def gen_bson_serializer_method(self, struct): @@ -2709,7 +3251,9 @@ class _CppSourceFileWriter(_CppFileWriterBase): struct_type_info = struct_types.get_struct_info(struct) - with self._block('%s {' % (struct_type_info.get_serializer_method().get_definition()), '}'): + with self._block( + "%s {" % (struct_type_info.get_serializer_method().get_definition()), "}" + ): self._gen_serializer_methods_common(struct, False) def gen_to_bson_serializer_method(self, struct): @@ -2717,11 +3261,16 @@ class _CppSourceFileWriter(_CppFileWriterBase): """Generate the toBSON method definition.""" struct_type_info = struct_types.get_struct_info(struct) - with self._block('%s {' % (struct_type_info.get_to_bson_method().get_definition()), '}'): - self._writer.write_line('BSONObjBuilder builder;') - self._writer.write_line(struct_type_info.get_serializer_method().get_call(None).replace( - "builder", "&builder")) - self._writer.write_line('return builder.obj();') + with self._block( + "%s {" % (struct_type_info.get_to_bson_method().get_definition()), "}" + ): + self._writer.write_line("BSONObjBuilder builder;") + self._writer.write_line( + struct_type_info.get_serializer_method() + .get_call(None) + .replace("builder", "&builder") + ) + self._writer.write_line("return builder.obj();") def _gen_doc_sequence_serializer(self, struct): # type: (ast.Struct) -> None @@ -2733,43 +3282,60 @@ class _CppSourceFileWriter(_CppFileWriterBase): member_name = _get_field_member_name(field) - optional_block_start = '{' + optional_block_start = "{" if field.optional: - optional_block_start = 'if (%s) {' % (member_name) + optional_block_start = "if (%s) {" % (member_name) - with self._block(optional_block_start, '}'): - self._writer.write_line('OpMsg::DocumentSequence documentSequence;') + with self._block(optional_block_start, "}"): + self._writer.write_line("OpMsg::DocumentSequence documentSequence;") self._writer.write_template( - 'documentSequence.name = %s.toString();' % (_get_field_constant_name(field))) + "documentSequence.name = %s.toString();" + % (_get_field_constant_name(field)) + ) - with self._block('for (const auto& item : %s) {' % (_access_member(field)), '}'): + with self._block( + "for (const auto& item : %s) {" % (_access_member(field)), "}" + ): if not field.type.is_struct: if field.type.is_variant: # _gen_serializer_method_variant expects builder to be a pointer. - self._writer.write_line('BSONObjBuilder objBuilder;') - self._writer.write_line('BSONObjBuilder* builder = &objBuilder;') + self._writer.write_line("BSONObjBuilder objBuilder;") + self._writer.write_line( + "BSONObjBuilder* builder = &objBuilder;" + ) template_params = { - 'field_name': _get_field_constant_name(field), - 'access_member': 'item', + "field_name": _get_field_constant_name(field), + "access_member": "item", } with self._with_template(template_params): - self._gen_serializer_method_variant_helper(field, template_params) + self._gen_serializer_method_variant_helper( + field, template_params + ) self._writer.write_line( - 'documentSequence.objs.push_back(builder->obj());') + "documentSequence.objs.push_back(builder->obj());" + ) elif field.type.serializer: - self._writer.write_line('documentSequence.objs.push_back(item.%s());' % - (writer.get_method_name(field.type.serializer))) + self._writer.write_line( + "documentSequence.objs.push_back(item.%s());" + % (writer.get_method_name(field.type.serializer)) + ) else: - self._writer.write_line('documentSequence.objs.push_back(item);') + self._writer.write_line( + "documentSequence.objs.push_back(item);" + ) else: - self._writer.write_line('BSONObjBuilder builder;') - self._writer.write_line('item.serialize(&builder);') - self._writer.write_line('documentSequence.objs.push_back(builder.obj());') + self._writer.write_line("BSONObjBuilder builder;") + self._writer.write_line("item.serialize(&builder);") + self._writer.write_line( + "documentSequence.objs.push_back(builder.obj());" + ) - self._writer.write_template('request.sequences.emplace_back(documentSequence);') + self._writer.write_template( + "request.sequences.emplace_back(documentSequence);" + ) # Add a blank line after each block self._writer.write_empty_line() @@ -2783,21 +3349,25 @@ class _CppSourceFileWriter(_CppFileWriterBase): struct_type_info = struct_types.get_struct_info(struct) with self._block( - '%s {' % (struct_type_info.get_op_msg_request_serializer_method().get_definition()), - '}'): - self._writer.write_line('BSONObjBuilder localBuilder;') + "%s {" + % ( + struct_type_info.get_op_msg_request_serializer_method().get_definition() + ), + "}", + ): + self._writer.write_line("BSONObjBuilder localBuilder;") - with self._block('{', '}'): - self._writer.write_line('BSONObjBuilder* builder = &localBuilder;') + with self._block("{", "}"): + self._writer.write_line("BSONObjBuilder* builder = &localBuilder;") self._gen_serializer_methods_common(struct, True) - self._writer.write_line('OpMsgRequest request;') - self._writer.write_line('request.body = localBuilder.obj();') + self._writer.write_line("OpMsgRequest request;") + self._writer.write_line("request.body = localBuilder.obj();") self._gen_doc_sequence_serializer(struct) - self._writer.write_line('return request;') + self._writer.write_line("return request;") def gen_authorization_contract_definition(self, struct): # type: (ast.Struct) -> None @@ -2814,15 +3384,21 @@ class _CppSourceFileWriter(_CppFileWriterBase): checks = ",".join([("AccessCheckEnum::" + c) for c in checks_list]) privilege_list = [ac.privilege for ac in struct.access_checks if ac.privilege] - privileges = ",".join([ - "Privilege(ResourcePattern::forAuthorizationContract(MatchTypeEnum::%s), ActionSet{%s})" - % (p.resource_pattern, ",".join(["ActionType::" + at for at in p.action_type])) - for p in privilege_list - ]) + privileges = ",".join( + [ + "Privilege(ResourcePattern::forAuthorizationContract(MatchTypeEnum::%s), ActionSet{%s})" + % ( + p.resource_pattern, + ",".join(["ActionType::" + at for at in p.action_type]), + ) + for p in privilege_list + ] + ) self._writer.write_line( - 'mongo::AuthorizationContract %s::kAuthorizationContract = AuthorizationContract(std::initializer_list{%s}, std::initializer_list{%s});' - % (common.title_case(struct.cpp_name), checks, privileges)) + "mongo::AuthorizationContract %s::kAuthorizationContract = AuthorizationContract(std::initializer_list{%s}, std::initializer_list{%s});" + % (common.title_case(struct.cpp_name), checks, privileges) + ) self._writer.write_empty_line() @@ -2843,13 +3419,22 @@ class _CppSourceFileWriter(_CppFileWriterBase): # type: (ast.Struct, str, bool) -> None """Generate the known fields declaration with specified name.""" block_name = common.template_args( - 'const std::vector ${class_name}::_${name}Fields {', name=name, - class_name=common.title_case(struct.cpp_name)) + "const std::vector ${class_name}::_${name}Fields {", + name=name, + class_name=common.title_case(struct.cpp_name), + ) with self._block(block_name, "};"): - sorted_fields = sorted([ - field for field in struct.fields - if (not field.serialize_op_msg_request_only or include_op_msg_implicit) - ], key=lambda f: f.cpp_name) + sorted_fields = sorted( + [ + field + for field in struct.fields + if ( + not field.serialize_op_msg_request_only + or include_op_msg_implicit + ) + ], + key=lambda f: f.cpp_name, + ) for field in sorted_fields: # Internal only fields are not parsed from BSON objects @@ -2857,13 +3442,19 @@ class _CppSourceFileWriter(_CppFileWriterBase): continue self._writer.write_line( - common.template_args('${class_name}::${constant_name},', - class_name=common.title_case(struct.cpp_name), - constant_name=_get_field_constant_name(field))) + common.template_args( + "${class_name}::${constant_name},", + class_name=common.title_case(struct.cpp_name), + constant_name=_get_field_constant_name(field), + ) + ) self._writer.write_line( - common.template_args('${class_name}::kCommandName,', - class_name=common.title_case(struct.cpp_name))) + common.template_args( + "${class_name}::kCommandName,", + class_name=common.title_case(struct.cpp_name), + ) + ) def gen_known_fields_declaration(self, struct): # type: (ast.Struct) -> None @@ -2879,58 +3470,76 @@ class _CppSourceFileWriter(_CppFileWriterBase): field_list_info = generic_field_list_types.get_field_list_info(struct) klass = common.title_case(struct.cpp_name) self._writer.write_line( - common.template_args('// Map: fieldName -> ${should_forward_name}', - should_forward_name=field_list_info.get_should_forward_name())) + common.template_args( + "// Map: fieldName -> ${should_forward_name}", + should_forward_name=field_list_info.get_should_forward_name(), + ) + ) block_name = common.template_args( - 'const StaticImmortal> ${klass}::_genericFields {{', klass=klass) + "const StaticImmortal> ${klass}::_genericFields {{", + klass=klass, + ) with self._block(block_name, "}};"): sorted_entries = sorted(struct.fields, key=lambda f: f.name) for entry in sorted_entries: self._writer.write_line( common.template_args( - '{"${name}", ${should_forward}},', klass=klass, name=entry.name, - should_forward='true' - if entry.generic_field_info.get_should_forward() else 'false')) + '{"${name}", ${should_forward}},', + klass=klass, + name=entry.name, + should_forward="true" + if entry.generic_field_info.get_should_forward() + else "false", + ) + ) def _gen_server_parameter_specialized(self, param): # type: (ast.ServerParameter) -> None """Generate a specialized ServerParameter.""" - self._writer.write_line('auto sp = makeServerParameter<%s>(%s, %s);' % - (param.cpp_class.name, _encaps(param.name), param.set_at)) + self._writer.write_line( + "auto sp = makeServerParameter<%s>(%s, %s);" + % (param.cpp_class.name, _encaps(param.name), param.set_at) + ) if param.redact: - self._writer.write_line('sp->setRedact();') + self._writer.write_line("sp->setRedact();") if param.omit_in_ftdc: - self._writer.write_line('sp->setOmitInFTDC();') + self._writer.write_line("sp->setOmitInFTDC();") - self._writer.write_line('return sp;') + self._writer.write_line("return sp;") def _gen_server_parameter_class_definitions(self, param): # type: (ast.ServerParameter) -> None """Generate storage for default and/or append method for a specialized ServerParameter.""" cls = param.cpp_class - is_cluster_param = (param.set_at == 'ServerParameterType::kClusterWide') + is_cluster_param = param.set_at == "ServerParameterType::kClusterWide" if param.default or param.redact or is_cluster_param: self.gen_description_comment("%s: %s" % (param.name, param.description)) if param.default: self._writer.write_line( - 'constexpr decltype(%s::kDataDefault) %s::kDataDefault;' % (cls.name, cls.name)) + "constexpr decltype(%s::kDataDefault) %s::kDataDefault;" + % (cls.name, cls.name) + ) self.write_empty_line() if param.redact: with self._block( - 'void %s::append(OperationContext*, BSONObjBuilder* b, StringData name, const boost::optional& tenantId) {' - % (cls.name), '}'): + "void %s::append(OperationContext*, BSONObjBuilder* b, StringData name, const boost::optional& tenantId) {" + % (cls.name), + "}", + ): self._writer.write_line('*b << name << "###";') self.write_empty_line() # Specialized cluster parameters should also provide the implementation of setFromString(). if is_cluster_param: with self._block( - 'Status %s::setFromString(StringData str, const boost::optional& tenantId) {' - % (cls.name), '}'): + "Status %s::setFromString(StringData str, const boost::optional& tenantId) {" + % (cls.name), + "}", + ): self._writer.write_line( 'return {ErrorCodes::BadValue, "setFromString should never be used with cluster server parameters"};' ) @@ -2942,38 +3551,51 @@ class _CppSourceFileWriter(_CppFileWriterBase): if param.feature_flag: self._writer.write_line( common.template_args( - 'auto* ret = makeFeatureFlagServerParameter(${name}, ${storage});', - storage=param.cpp_varname, name=_encaps(param.name))) + "auto* ret = makeFeatureFlagServerParameter(${name}, ${storage});", + storage=param.cpp_varname, + name=_encaps(param.name), + ) + ) else: self._writer.write_line( common.template_args( - 'auto* ret = makeIDLServerParameterWithStorage<${spt}>(${name}, ${storage});', - storage=param.cpp_varname, spt=param.set_at, name=_encaps(param.name))) + "auto* ret = makeIDLServerParameterWithStorage<${spt}>(${name}, ${storage});", + storage=param.cpp_varname, + spt=param.set_at, + name=_encaps(param.name), + ) + ) if param.on_update is not None: - self._writer.write_line('ret->setOnUpdate(%s);' % (param.on_update)) + self._writer.write_line("ret->setOnUpdate(%s);" % (param.on_update)) if param.validator is not None: if param.validator.callback is not None: - self._writer.write_line('ret->addValidator(%s);' % (param.validator.callback)) + self._writer.write_line( + "ret->addValidator(%s);" % (param.validator.callback) + ) - for pred in ['lt', 'gt', 'lte', 'gte']: + for pred in ["lt", "gt", "lte", "gte"]: bound = getattr(param.validator, pred) if bound is not None: - self._writer.write_line('ret->addBound(%s);' % - (pred.upper(), _get_expression(bound))) + self._writer.write_line( + "ret->addBound(%s);" + % (pred.upper(), _get_expression(bound)) + ) if param.redact: - self._writer.write_line('ret->setRedact();') + self._writer.write_line("ret->setRedact();") if param.omit_in_ftdc: - self._writer.write_line('ret->setOmitInFTDC();') + self._writer.write_line("ret->setOmitInFTDC();") if param.default and not (param.cpp_vartype and param.cpp_varname): # Only need to call setDefault() if we haven't in-place initialized the declared var. self._writer.write_line( - 'uassertStatusOK(ret->setDefault(%s));' % (_get_expression(param.default))) + "uassertStatusOK(ret->setDefault(%s));" + % (_get_expression(param.default)) + ) - self._writer.write_line('return ret;') + self._writer.write_line("return ret;") def _gen_server_parameter(self, param): # type: (ast.ServerParameter) -> None @@ -2990,9 +3612,13 @@ class _CppSourceFileWriter(_CppFileWriterBase): for alias_no, alias in enumerate(param.deprecated_name): self._writer.write_line( common.template_args( - '${unused} auto* ${alias_var} = makeIDLServerParameterDeprecatedAlias(${name}, ${param_var});', - unused='[[maybe_unused]]', alias_var='scp_%d_%d' % (param_no, alias_no), - name=_encaps(alias), param_var='scp_%d' % (param_no))) + "${unused} auto* ${alias_var} = makeIDLServerParameterDeprecatedAlias(${name}, ${param_var});", + unused="[[maybe_unused]]", + alias_var="scp_%d_%d" % (param_no, alias_no), + name=_encaps(alias), + param_var="scp_%d" % (param_no), + ) + ) def gen_server_parameters(self, params, header_file_name): # type: (List[ast.ServerParameter], str) -> None @@ -3006,34 +3632,42 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Optional storage declarations. elif (param.cpp_vartype is not None) and (param.cpp_varname is not None): with self._condition(param.condition, preprocessor_only=True): - init = ('{%s}' % (param.default.expr)) if param.default else '' + init = ("{%s}" % (param.default.expr)) if param.default else "" self._writer.write_line( - '%s %s%s;' % (param.cpp_vartype, param.cpp_varname, init)) + "%s %s%s;" % (param.cpp_vartype, param.cpp_varname, init) + ) - blockname = 'idl_' + \ - hashlib.sha1(header_file_name.encode()).hexdigest() - with self._block('MONGO_SERVER_PARAMETER_REGISTER(%s)(InitializerContext*) {' % (blockname), - '}'): + blockname = "idl_" + hashlib.sha1(header_file_name.encode()).hexdigest() + with self._block( + "MONGO_SERVER_PARAMETER_REGISTER(%s)(InitializerContext*) {" % (blockname), + "}", + ): # ServerParameter instances. for param_no, param in enumerate(params): self.gen_description_comment(param.description) with self._condition(param.condition): unused = not (param.test_only or param.deprecated_name) - with self.get_initializer_lambda('auto* scp_%d' % (param_no), unused=unused, - return_type='ServerParameter*'): + with self.get_initializer_lambda( + "auto* scp_%d" % (param_no), + unused=unused, + return_type="ServerParameter*", + ): self._gen_server_parameter(param) if param.test_only: - self._writer.write_line('scp_%d->setTestOnly();' % (param_no)) + self._writer.write_line("scp_%d->setTestOnly();" % (param_no)) if param.condition: if param.condition.feature_flag: - self._writer.write_line('scp_%d->setFeatureFlag(&%s);' % - (param_no, param.condition.feature_flag)) + self._writer.write_line( + "scp_%d->setFeatureFlag(&%s);" + % (param_no, param.condition.feature_flag) + ) if param.condition.min_fcv: self._writer.write_line( 'scp_%d->setMinFCV(FeatureCompatibilityVersionParser::parseVersionForFeatureFlags("%s"));' - % (param_no, param.condition.min_fcv)) + % (param_no, param.condition.min_fcv) + ) self._gen_server_parameter_deprecated_aliases(param_no, param) self.write_empty_line() @@ -3043,118 +3677,154 @@ class _CppSourceFileWriter(_CppFileWriterBase): """Generate Config Option instance.""" # Derive cpp_vartype from arg_vartype if needed. - vartype = ("moe::OptionTypeMap::type" % - (opt.arg_vartype)) if opt.cpp_vartype is None else opt.cpp_vartype + vartype = ( + ("moe::OptionTypeMap::type" % (opt.arg_vartype)) + if opt.cpp_vartype is None + else opt.cpp_vartype + ) # Mark option as coming from IDL autogenerated code. - usage = 'moe::OptionSection::OptionParserUsageType::IDLAutoGeneratedCode' + usage = "moe::OptionSection::OptionParserUsageType::IDLAutoGeneratedCode" with self._condition(opt.condition): - with self._block(section, ';'): + with self._block(section, ";"): self._writer.write_line( common.template_format( - '.addOptionChaining(${name}, ${short}, moe::${argtype}, ${desc}, ${deprname}, ${deprshortname}, ${usage})', + ".addOptionChaining(${name}, ${short}, moe::${argtype}, ${desc}, ${deprname}, ${deprshortname}, ${usage})", { - 'name': _encaps(opt.name), - 'short': _encaps(opt.short_name), - 'argtype': opt.arg_vartype, - 'desc': _get_expression(opt.description), - 'deprname': _encaps_list(opt.deprecated_name), - 'deprshortname': _encaps_list(opt.deprecated_short_name), - 'usage': usage, - })) - self._writer.write_line('.setSources(moe::%s)' % (opt.source)) + "name": _encaps(opt.name), + "short": _encaps(opt.short_name), + "argtype": opt.arg_vartype, + "desc": _get_expression(opt.description), + "deprname": _encaps_list(opt.deprecated_name), + "deprshortname": _encaps_list(opt.deprecated_short_name), + "usage": usage, + }, + ) + ) + self._writer.write_line(".setSources(moe::%s)" % (opt.source)) if opt.hidden: - self._writer.write_line('.hidden()') + self._writer.write_line(".hidden()") if opt.redact: - self._writer.write_line('.redact()') + self._writer.write_line(".redact()") for requires in opt.requires: - self._writer.write_line('.requiresOption(%s)' % (_encaps(requires))) + self._writer.write_line(".requiresOption(%s)" % (_encaps(requires))) for conflicts in opt.conflicts: - self._writer.write_line('.incompatibleWith(%s)' % (_encaps(conflicts))) + self._writer.write_line( + ".incompatibleWith(%s)" % (_encaps(conflicts)) + ) if opt.default: self._writer.write_line( - '.setDefault(moe::Value(%s))' % (_get_expression(opt.default))) + ".setDefault(moe::Value(%s))" % (_get_expression(opt.default)) + ) if opt.implicit: self._writer.write_line( - '.setImplicit(moe::Value(%s))' % (_get_expression(opt.implicit))) + ".setImplicit(moe::Value(%s))" % (_get_expression(opt.implicit)) + ) if opt.duplicates_append: - self._writer.write_line('.composing()') - if (opt.positional_start is not None) and (opt.positional_end is not None): + self._writer.write_line(".composing()") + if (opt.positional_start is not None) and ( + opt.positional_end is not None + ): self._writer.write_line( - '.positional(%d, %d)' % (opt.positional_start, opt.positional_end)) + ".positional(%d, %d)" + % (opt.positional_start, opt.positional_end) + ) if opt.canonicalize: - self._writer.write_line('.canonicalize(%s)' % opt.canonicalize) + self._writer.write_line(".canonicalize(%s)" % opt.canonicalize) if opt.validator: if opt.validator.callback: self._writer.write_line( common.template_args( - '.addConstraint(new moe::CallbackKeyConstraint<${argtype}>(${key}, ${callback}))', - argtype=vartype, key=_encaps(opt.name), - callback=opt.validator.callback)) + ".addConstraint(new moe::CallbackKeyConstraint<${argtype}>(${key}, ${callback}))", + argtype=vartype, + key=_encaps(opt.name), + callback=opt.validator.callback, + ) + ) - if (opt.validator.gt is not None) or (opt.validator.lt is not None) or ( - opt.validator.gte is not None) or (opt.validator.lte is not None): + if ( + (opt.validator.gt is not None) + or (opt.validator.lt is not None) + or (opt.validator.gte is not None) + or (opt.validator.lte is not None) + ): self._writer.write_line( common.template_args( - '.addConstraint(new moe::BoundaryKeyConstraint<${argtype}>(${key}, ${gt}, ${lt}, ${gte}, ${lte}))', - argtype=vartype, key=_encaps(opt.name), gt='boost::none' - if opt.validator.gt is None else _get_expression(opt.validator.gt), - lt='boost::none' - if opt.validator.lt is None else _get_expression(opt.validator.lt), - gte='boost::none' if opt.validator.gte is None else _get_expression( - opt.validator.gte), lte='boost::none' if - opt.validator.lte is None else _get_expression(opt.validator.lte))) + ".addConstraint(new moe::BoundaryKeyConstraint<${argtype}>(${key}, ${gt}, ${lt}, ${gte}, ${lte}))", + argtype=vartype, + key=_encaps(opt.name), + gt="boost::none" + if opt.validator.gt is None + else _get_expression(opt.validator.gt), + lt="boost::none" + if opt.validator.lt is None + else _get_expression(opt.validator.lt), + gte="boost::none" + if opt.validator.gte is None + else _get_expression(opt.validator.gte), + lte="boost::none" + if opt.validator.lte is None + else _get_expression(opt.validator.lte), + ) + ) self.write_empty_line() def _gen_config_options_register(self, root_opts, sections, returns_status): - self._writer.write_line('namespace moe = ::mongo::optionenvironment;') + self._writer.write_line("namespace moe = ::mongo::optionenvironment;") self.write_empty_line() for opt in root_opts: - self.gen_config_option(opt, 'options') + self.gen_config_option(opt, "options") for section_name, section_opts in sections.items(): - with self._block('{', '}'): - self._writer.write_line('moe::OptionSection section(%s);' % (_encaps(section_name))) + with self._block("{", "}"): + self._writer.write_line( + "moe::OptionSection section(%s);" % (_encaps(section_name)) + ) self.write_empty_line() for opt in section_opts: - self.gen_config_option(opt, 'section') + self.gen_config_option(opt, "section") - self._writer.write_line('auto status = options.addSection(section);') + self._writer.write_line("auto status = options.addSection(section);") if returns_status: - with self._block('if (!status.isOK()) {', '}'): - self._writer.write_line('return status;') + with self._block("if (!status.isOK()) {", "}"): + self._writer.write_line("return status;") else: - self._writer.write_line('uassertStatusOK(status);') + self._writer.write_line("uassertStatusOK(status);") self.write_empty_line() if returns_status: - self._writer.write_line('return Status::OK();') + self._writer.write_line("return Status::OK();") def _gen_config_options_store(self, configs, return_status): # Setup initializer for storing configured options in their variables. - self._writer.write_line('namespace moe = ::mongo::optionenvironment;') + self._writer.write_line("namespace moe = ::mongo::optionenvironment;") self.write_empty_line() for opt in configs: if opt.cpp_varname is None: continue - vartype = ("moe::OptionTypeMap::type" % - (opt.arg_vartype)) if opt.cpp_vartype is None else opt.cpp_vartype + vartype = ( + ("moe::OptionTypeMap::type" % (opt.arg_vartype)) + if opt.cpp_vartype is None + else opt.cpp_vartype + ) with self._condition(opt.condition): - with self._block('if (params.count(%s)) {' % (_encaps(opt.name)), '}'): + with self._block("if (params.count(%s)) {" % (_encaps(opt.name)), "}"): self._writer.write_line( - '%s = params[%s].as<%s>();' % (opt.cpp_varname, _encaps(opt.name), vartype)) + "%s = params[%s].as<%s>();" + % (opt.cpp_varname, _encaps(opt.name), vartype) + ) self.write_empty_line() if return_status: - self._writer.write_line('return Status::OK();') + self._writer.write_line("return Status::OK();") def gen_config_options(self, spec, header_file_name): # type: (ast.IDLAST, str) -> None @@ -3166,9 +3836,10 @@ class _CppSourceFileWriter(_CppFileWriterBase): has_storage_targets = True if opt.cpp_vartype is not None: with self._condition(opt.condition, preprocessor_only=True): - init = ('{%s}' % (opt.default.expr)) if opt.default else '' + init = ("{%s}" % (opt.default.expr)) if opt.default else "" self._writer.write_line( - '%s %s%s;' % (opt.cpp_vartype, opt.cpp_varname, init)) + "%s %s%s;" % (opt.cpp_vartype, opt.cpp_varname, init) + ) self.write_empty_line() @@ -3186,20 +3857,27 @@ class _CppSourceFileWriter(_CppFileWriterBase): initializer = spec.globals.configs and spec.globals.configs.initializer blockname = (initializer and initializer.name) or ( - 'idl_' + hashlib.sha1(header_file_name.encode()).hexdigest()) + "idl_" + hashlib.sha1(header_file_name.encode()).hexdigest() + ) if initializer and initializer.register: with self._block( - 'Status %s(optionenvironment::OptionSection* options_ptr) {' % - initializer.register, '}'): - self._writer.write_line('auto& options = *options_ptr;') + "Status %s(optionenvironment::OptionSection* options_ptr) {" + % initializer.register, + "}", + ): + self._writer.write_line("auto& options = *options_ptr;") self._gen_config_options_register(root_opts, sections, True) else: - with self.gen_namespace_block(''): + with self.gen_namespace_block(""): with self._block( - 'MONGO_MODULE_STARTUP_OPTIONS_REGISTER(%s)(InitializerContext*) {' % - (blockname), '}'): - self._writer.write_line('auto& options = optionenvironment::startupOptions;') + "MONGO_MODULE_STARTUP_OPTIONS_REGISTER(%s)(InitializerContext*) {" + % (blockname), + "}", + ): + self._writer.write_line( + "auto& options = optionenvironment::startupOptions;" + ) self._gen_config_options_register(root_opts, sections, False) self.write_empty_line() @@ -3207,17 +3885,21 @@ class _CppSourceFileWriter(_CppFileWriterBase): if has_storage_targets: if initializer and initializer.store: with self._block( - 'Status %s(const optionenvironment::Environment& params) {' % - initializer.store, '}'): + "Status %s(const optionenvironment::Environment& params) {" + % initializer.store, + "}", + ): self._gen_config_options_store(spec.configs, True) else: - with self.gen_namespace_block(''): + with self.gen_namespace_block(""): with self._block( - 'MONGO_STARTUP_OPTIONS_STORE(%s)(InitializerContext*) {' % (blockname), - '}'): + "MONGO_STARTUP_OPTIONS_STORE(%s)(InitializerContext*) {" + % (blockname), + "}", + ): # If all options are guarded by non-passing #ifdefs, then params will be unused. self._writer.write_line( - '[[maybe_unused]] const auto& params = optionenvironment::startupOptionsParsed;' + "[[maybe_unused]] const auto& params = optionenvironment::startupOptionsParsed;" ) self._gen_config_options_store(spec.configs, False) @@ -3239,8 +3921,8 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Generate system includes second header_list = [ - 'bitset', - 'set', + "bitset", + "set", ] for include in header_list: @@ -3250,24 +3932,24 @@ class _CppSourceFileWriter(_CppFileWriterBase): # Generate mongo includes third header_list = [ - 'mongo/util/overloaded_visitor.h', - 'mongo/util/string_map.h', + "mongo/util/overloaded_visitor.h", + "mongo/util/string_map.h", ] if spec.commands: - header_list.append('mongo/db/auth/authorization_contract.h') - header_list.append('mongo/idl/command_generic_argument.h') + header_list.append("mongo/db/auth/authorization_contract.h") + header_list.append("mongo/idl/command_generic_argument.h") elif len([s for s in spec.structs if s.is_command_reply]) > 0: - header_list.append('mongo/idl/command_generic_argument.h') + header_list.append("mongo/idl/command_generic_argument.h") if spec.server_parameters: - header_list.append('mongo/db/server_parameter.h') - header_list.append('mongo/db/server_parameter_with_storage.h') + header_list.append("mongo/db/server_parameter.h") + header_list.append("mongo/db/server_parameter_with_storage.h") if spec.configs: - header_list.append('mongo/util/options_parser/option_section.h') - header_list.append('mongo/util/options_parser/startup_option_init.h') - header_list.append('mongo/util/options_parser/startup_options.h') + header_list.append("mongo/util/options_parser/option_section.h") + header_list.append("mongo/util/options_parser/startup_option_init.h") + header_list.append("mongo/util/options_parser/startup_options.h") header_list.sort() @@ -3353,7 +4035,7 @@ def _generate_header(spec, file_name): str_value = generate_header_str(spec) # Generate structs - with io.open(file_name, mode='wb') as file_handle: + with io.open(file_name, mode="wb") as file_handle: file_handle.write(str_value.encode()) @@ -3376,11 +4058,13 @@ def _generate_source(spec, target_arch, file_name, header_file_name): str_value = generate_source_str(spec, target_arch, header_file_name) # Generate structs - with io.open(file_name, mode='wb') as file_handle: + with io.open(file_name, mode="wb") as file_handle: file_handle.write(str_value.encode()) -def generate_code(spec, target_arch, output_base_dir, header_file_name, source_file_name): +def generate_code( + spec, target_arch, output_base_dir, header_file_name, source_file_name +): # type: (ast.IDLAST, str, str, str, str) -> None """Generate a C++ header and source file from an idl.ast tree.""" @@ -3388,7 +4072,8 @@ def generate_code(spec, target_arch, output_base_dir, header_file_name, source_f if output_base_dir: include_h_file_name = os.path.relpath( - os.path.normpath(header_file_name), os.path.normpath(output_base_dir)) + os.path.normpath(header_file_name), os.path.normpath(output_base_dir) + ) else: include_h_file_name = os.path.abspath(os.path.normpath(header_file_name)) diff --git a/buildscripts/idl/idl/generic_field_list_types.py b/buildscripts/idl/idl/generic_field_list_types.py index 13ef7fd395f..b11ece3b46b 100644 --- a/buildscripts/idl/idl/generic_field_list_types.py +++ b/buildscripts/idl/idl/generic_field_list_types.py @@ -43,7 +43,9 @@ class FieldListInfo: # type: () -> MethodInfo """Get the hasField method for a generic argument or generic reply field list.""" class_name = common.title_case(self.struct.cpp_name) - return MethodInfo(class_name, 'hasField', ['StringData fieldName'], 'bool', static=True) + return MethodInfo( + class_name, "hasField", ["StringData fieldName"], "bool", static=True + ) def get_should_forward_name(self): """Get the name of the shard-forwarding rule for a generic argument or reply field.""" @@ -62,8 +64,13 @@ class FieldListInfo: # type: () -> MethodInfo """Get the method for checking the shard-forwarding rule of an argument or reply field.""" class_name = common.title_case(self.struct.cpp_name) - return MethodInfo(class_name, self.get_should_forward_name(), ['StringData fieldName'], - 'bool', static=True) + return MethodInfo( + class_name, + self.get_should_forward_name(), + ["StringData fieldName"], + "bool", + static=True, + ) def get_field_list_info(struct): diff --git a/buildscripts/idl/idl/parser.py b/buildscripts/idl/idl/parser.py index 0e2d6ff759c..a760899318f 100644 --- a/buildscripts/idl/idl/parser.py +++ b/buildscripts/idl/idl/parser.py @@ -58,8 +58,13 @@ class _RuleDesc(object): REQUIRED = 1 OPTIONAL = 2 - def __init__(self, node_type, required=OPTIONAL, mapping_parser_func=None, - sequence_parser_func=None): + def __init__( + self, + node_type, + required=OPTIONAL, + mapping_parser_func=None, + sequence_parser_func=None, + ): # type: (str, int, Callable[[errors.ParserContext,yaml.nodes.MappingNode], Any], Callable[[errors.ParserContext,yaml.nodes.SequenceNode], Any]) -> None """Construct a parser rule description.""" assert required in (_RuleDesc.REQUIRED, _RuleDesc.OPTIONAL) @@ -72,23 +77,22 @@ class _RuleDesc(object): def _has_field( - node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] - field_name, # type: str + node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] + field_name, # type: str ): # type: (...) -> bool return any(kv[0].value == field_name for kv in node.value) def _generic_parser( - ctxt, # type: errors.ParserContext - node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] - syntax_node_name, # type: str - syntax_node, # type: Any - mapping_rules # type: Dict[str, _RuleDesc] + ctxt, # type: errors.ParserContext + node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] + syntax_node_name, # type: str + syntax_node, # type: Any + mapping_rules, # type: Dict[str, _RuleDesc] ): # type: (...) -> None field_name_set = set() # type: Set[str] for [first_node, second_node] in node.value: - first_name = first_node.value if first_name in field_name_set: @@ -106,32 +110,40 @@ def _generic_parser( syntax_node.__dict__[first_name] = ctxt.get_bool(second_node) elif rule_desc.node_type == "int_scalar": if ctxt.is_scalar_non_negative_int_node(second_node, first_name): - syntax_node.__dict__[first_name] = ctxt.get_non_negative_int(second_node) + syntax_node.__dict__[first_name] = ctxt.get_non_negative_int( + second_node + ) elif rule_desc.node_type == "scalar_or_sequence": if ctxt.is_scalar_sequence_or_scalar_node(second_node, first_name): syntax_node.__dict__[first_name] = rule_desc.sequence_parser_func( - ctxt, second_node) + ctxt, second_node + ) elif rule_desc.node_type == "sequence": if ctxt.is_scalar_sequence(second_node, first_name): syntax_node.__dict__[first_name] = rule_desc.sequence_parser_func( - ctxt, second_node) + ctxt, second_node + ) elif rule_desc.node_type == "sequence_mapping": if ctxt.is_sequence_mapping(second_node, first_name): syntax_node.__dict__[first_name] = rule_desc.sequence_parser_func( - ctxt, second_node) + ctxt, second_node + ) elif rule_desc.node_type == "scalar_or_mapping": if ctxt.is_scalar_or_mapping_node(second_node, first_name): syntax_node.__dict__[first_name] = rule_desc.mapping_parser_func( - ctxt, second_node) + ctxt, second_node + ) elif rule_desc.node_type == "mapping": if ctxt.is_mapping_node(second_node, first_name): syntax_node.__dict__[first_name] = rule_desc.mapping_parser_func( - ctxt, second_node) + ctxt, second_node + ) elif rule_desc.node_type == "required_bool_scalar": syntax_node.__dict__[first_name] = ctxt.get_required_bool(second_node) else: raise errors.IDLError( - "Unknown node_type '%s' for parser rule" % (rule_desc.node_type)) + "Unknown node_type '%s' for parser rule" % (rule_desc.node_type) + ) else: ctxt.add_unknown_node_error(first_node, syntax_node_name) @@ -145,27 +157,28 @@ def _generic_parser( # A bool is never "None" like other types, it simply defaults to "false". # It means "if bool is None" will always return false and there is no support for required # 'bool' at this time. Use the node type 'required_bool_scalar' if this behavior is not desired. - if not rule_desc.node_type == 'bool_scalar': + if not rule_desc.node_type == "bool_scalar": if syntax_node.__dict__[name] is None: ctxt.add_missing_required_field_error(node, syntax_node_name, name) else: raise errors.IDLError( - "Unknown node_type '%s' for parser required rule" % (rule_desc.node_type)) + "Unknown node_type '%s' for parser required rule" + % (rule_desc.node_type) + ) def _parse_mapping( - ctxt, # type: errors.ParserContext - spec, # type: syntax.IDLSpec - node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] - syntax_node_name, # type: str - func # type: Callable[[errors.ParserContext,syntax.IDLSpec,str,Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]], None] + ctxt, # type: errors.ParserContext + spec, # type: syntax.IDLSpec + node, # type: Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode] + syntax_node_name, # type: str + func, # type: Callable[[errors.ParserContext,syntax.IDLSpec,str,Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]], None] ): # type: (...) -> None """Parse a top-level mapping section in the IDL file.""" if not ctxt.is_mapping_node(node, syntax_node_name): return for [first_node, second_node] in node.value: - first_name = first_node.value func(ctxt, spec, first_name, second_node) @@ -174,16 +187,24 @@ def _parse_mapping( def _parse_initializer(ctxt, node): # type: (errors.ParserContext, Union[yaml.nodes.ScalarNode, yaml.nodes.MappingNode]) -> syntax.GlobalInitializer """Parse a global initializer.""" - init = syntax.GlobalInitializer(ctxt.file_name, node.start_mark.line, node.start_mark.column) + init = syntax.GlobalInitializer( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) - if node.id == 'scalar': + if node.id == "scalar": init.name = node.value return init - _generic_parser(ctxt, node, "initializer", init, { - "register": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "store": _RuleDesc('scalar'), - }) + _generic_parser( + ctxt, + node, + "initializer", + init, + { + "register": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "store": _RuleDesc("scalar"), + }, + ) return init @@ -191,14 +212,23 @@ def _parse_initializer(ctxt, node): def _parse_config_global(ctxt, node): # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.ConfigGlobal """Parse global settings for config options.""" - config = syntax.ConfigGlobal(ctxt.file_name, node.start_mark.line, node.start_mark.column) + config = syntax.ConfigGlobal( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "configs", config, { + ctxt, + node, + "configs", + config, + { "section": _RuleDesc("scalar"), "source": _RuleDesc("scalar_or_sequence"), - "initializer": _RuleDesc("scalar_or_mapping", mapping_parser_func=_parse_initializer), - }) + "initializer": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_initializer + ), + }, + ) return config @@ -209,13 +239,21 @@ def _parse_global(ctxt, spec, node): if not ctxt.is_mapping_node(node, "global"): return - idlglobal = syntax.Global(ctxt.file_name, node.start_mark.line, node.start_mark.column) + idlglobal = syntax.Global( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "global", idlglobal, { - "cpp_namespace": _RuleDesc("scalar"), "cpp_includes": _RuleDesc("scalar_or_sequence"), - "configs": _RuleDesc("mapping", mapping_parser_func=_parse_config_global) - }) + ctxt, + node, + "global", + idlglobal, + { + "cpp_namespace": _RuleDesc("scalar"), + "cpp_includes": _RuleDesc("scalar_or_sequence"), + "configs": _RuleDesc("mapping", mapping_parser_func=_parse_config_global), + }, + ) spec.globals = idlglobal @@ -226,7 +264,9 @@ def _parse_imports(ctxt, spec, node): if not ctxt.is_scalar_sequence(node, "imports"): return - imports = syntax.Import(ctxt.file_name, node.start_mark.line, node.start_mark.column) + imports = syntax.Import( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) imports.imports = ctxt.get_list(node) spec.imports = imports @@ -241,18 +281,25 @@ def _parse_type(ctxt, spec, name, node): idltype.name = name _generic_parser( - ctxt, node, "type", idltype, { - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "cpp_type": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "bson_serialization_type": _RuleDesc('scalar_or_sequence', _RuleDesc.REQUIRED), - "is_view": _RuleDesc('bool_scalar'), - "bindata_subtype": _RuleDesc('scalar'), - "serializer": _RuleDesc('scalar'), - "deserializer": _RuleDesc('scalar'), - "deserialize_with_tenant": _RuleDesc('bool_scalar'), - "internal_only": _RuleDesc('bool_scalar'), - "default": _RuleDesc('scalar'), - }) + ctxt, + node, + "type", + idltype, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "cpp_type": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "bson_serialization_type": _RuleDesc( + "scalar_or_sequence", _RuleDesc.REQUIRED + ), + "is_view": _RuleDesc("bool_scalar"), + "bindata_subtype": _RuleDesc("scalar"), + "serializer": _RuleDesc("scalar"), + "deserializer": _RuleDesc("scalar"), + "deserialize_with_tenant": _RuleDesc("bool_scalar"), + "internal_only": _RuleDesc("bool_scalar"), + "default": _RuleDesc("scalar"), + }, + ) spec.symbols.add_type(ctxt, idltype) @@ -260,16 +307,24 @@ def _parse_type(ctxt, spec, name, node): def _parse_expression(ctxt, node): # type: (errors.ParserContext, Union[yaml.nodes.ScalarNode,yaml.nodes.MappingNode]) -> syntax.Expression """Parse an expression as either a scalar or a mapping.""" - expr = syntax.Expression(ctxt.file_name, node.start_mark.line, node.start_mark.column) + expr = syntax.Expression( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) - if node.id == 'scalar': + if node.id == "scalar": expr.literal = node.value return expr - _generic_parser(ctxt, node, "expr", expr, { - "expr": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "is_constexpr": _RuleDesc('bool_scalar'), - }) + _generic_parser( + ctxt, + node, + "expr", + expr, + { + "expr": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "is_constexpr": _RuleDesc("bool_scalar"), + }, + ) return expr @@ -277,16 +332,27 @@ def _parse_expression(ctxt, node): def _parse_validator(ctxt, node): # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.Validator """Parse a validator for a field.""" - validator = syntax.Validator(ctxt.file_name, node.start_mark.line, node.start_mark.column) + validator = syntax.Validator( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "validator", validator, { + ctxt, + node, + "validator", + validator, + { "gt": _RuleDesc("scalar_or_mapping", mapping_parser_func=_parse_expression), "lt": _RuleDesc("scalar_or_mapping", mapping_parser_func=_parse_expression), - "gte": _RuleDesc("scalar_or_mapping", mapping_parser_func=_parse_expression), - "lte": _RuleDesc("scalar_or_mapping", mapping_parser_func=_parse_expression), + "gte": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_expression + ), + "lte": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_expression + ), "callback": _RuleDesc("scalar"), - }) + }, + ) return validator @@ -294,16 +360,23 @@ def _parse_validator(ctxt, node): def _parse_condition(ctxt, node): # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.Condition """Parse a condition.""" - condition = syntax.Condition(ctxt.file_name, node.start_mark.line, node.start_mark.column) + condition = syntax.Condition( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "condition", condition, { + ctxt, + node, + "condition", + condition, + { "preprocessor": _RuleDesc("scalar"), "constexpr": _RuleDesc("scalar"), "expr": _RuleDesc("scalar"), "feature_flag": _RuleDesc("scalar"), "min_fcv": _RuleDesc("scalar"), - }) + }, + ) return condition @@ -322,35 +395,48 @@ def _parse_field_type(ctxt, node): """ if node.id == "mapping": # For now, FieldTypeVariant is the only non-scalar node. - variant = syntax.FieldTypeVariant(ctxt.file_name, node.start_mark.line, - node.start_mark.column) + variant = syntax.FieldTypeVariant( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "type", variant, - {"variant": _RuleDesc("sequence", sequence_parser_func=_parse_variant_alternatives)}) + ctxt, + node, + "type", + variant, + { + "variant": _RuleDesc( + "sequence", sequence_parser_func=_parse_variant_alternatives + ) + }, + ) return variant else: assert node.id == "scalar" - if node.value.startswith('array syntax.ChainedType """Parse a chained type in a struct in the IDL file.""" - chain = syntax.ChainedType(ctxt.file_name, node.start_mark.line, node.start_mark.column) + chain = syntax.ChainedType( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) chain.name = name - _generic_parser(ctxt, node, "chain", chain, { - "cpp_name": _RuleDesc('scalar'), - }) + _generic_parser( + ctxt, + node, + "chain", + chain, + { + "cpp_name": _RuleDesc("scalar"), + }, + ) return chain @@ -495,7 +582,9 @@ def _parse_chained_types(ctxt, node): # Simple Scalar if second_node.id == "scalar": - chain = syntax.ChainedType(ctxt.file_name, node.start_mark.line, node.start_mark.column) + chain = syntax.ChainedType( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) chain.name = first_name chain.cpp_name = second_node.value chained_items.append(chain) @@ -511,12 +600,20 @@ def _parse_chained_types(ctxt, node): def _parse_chained_struct(ctxt, name, node): # type: (errors.ParserContext, str, yaml.nodes.MappingNode) -> syntax.ChainedStruct """Parse a chained struct in a struct in the IDL file.""" - chain = syntax.ChainedStruct(ctxt.file_name, node.start_mark.line, node.start_mark.column) + chain = syntax.ChainedStruct( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) chain.name = name - _generic_parser(ctxt, node, "chain", chain, { - "cpp_name": _RuleDesc('scalar'), - }) + _generic_parser( + ctxt, + node, + "chain", + chain, + { + "cpp_name": _RuleDesc("scalar"), + }, + ) return chain @@ -529,7 +626,6 @@ def _parse_chained_structs(ctxt, node): field_name_set = set() # type: Set[str] for [first_node, second_node] in node.value: - first_name = first_node.value if first_name in field_name_set: @@ -538,8 +634,9 @@ def _parse_chained_structs(ctxt, node): # Simple Scalar if second_node.id == "scalar": - chain = syntax.ChainedStruct(ctxt.file_name, node.start_mark.line, - node.start_mark.column) + chain = syntax.ChainedStruct( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) chain.name = first_name chain.cpp_name = second_node.value chained_items.append(chain) @@ -562,28 +659,42 @@ def _parse_struct(ctxt, spec, name, node): struct.name = name _generic_parser( - ctxt, node, "struct", struct, { - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "fields": _RuleDesc('mapping', mapping_parser_func=_parse_fields), - "chained_types": _RuleDesc('mapping', mapping_parser_func=_parse_chained_types), - "chained_structs": _RuleDesc('mapping', mapping_parser_func=_parse_chained_structs), + ctxt, + node, + "struct", + struct, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "fields": _RuleDesc("mapping", mapping_parser_func=_parse_fields), + "chained_types": _RuleDesc( + "mapping", mapping_parser_func=_parse_chained_types + ), + "chained_structs": _RuleDesc( + "mapping", mapping_parser_func=_parse_chained_structs + ), "strict": _RuleDesc("bool_scalar"), "inline_chained_structs": _RuleDesc("bool_scalar"), - "immutable": _RuleDesc('bool_scalar'), + "immutable": _RuleDesc("bool_scalar"), "generate_comparison_operators": _RuleDesc("bool_scalar"), - "non_const_getter": _RuleDesc('bool_scalar'), - "cpp_validator_func": _RuleDesc('scalar'), - "is_command_reply": _RuleDesc('bool_scalar'), - "is_catalog_ctxt": _RuleDesc('bool_scalar'), - "is_generic_cmd_list": _RuleDesc('scalar'), - "query_shape_component": _RuleDesc('bool_scalar'), - "unsafe_dangerous_disable_extra_field_duplicate_checks": _RuleDesc("bool_scalar"), - }) + "non_const_getter": _RuleDesc("bool_scalar"), + "cpp_validator_func": _RuleDesc("scalar"), + "is_command_reply": _RuleDesc("bool_scalar"), + "is_catalog_ctxt": _RuleDesc("bool_scalar"), + "is_generic_cmd_list": _RuleDesc("scalar"), + "query_shape_component": _RuleDesc("bool_scalar"), + "unsafe_dangerous_disable_extra_field_duplicate_checks": _RuleDesc( + "bool_scalar" + ), + }, + ) # PyLint has difficulty with some iterables: https://github.com/PyCQA/pylint/issues/3105 # pylint: disable=not-an-iterable - if struct.generate_comparison_operators and struct.fields and any( - isinstance(f.type, syntax.FieldTypeVariant) for f in struct.fields): + if ( + struct.generate_comparison_operators + and struct.fields + and any(isinstance(f.type, syntax.FieldTypeVariant) for f in struct.fields) + ): ctxt.add_variant_comparison_error(struct) return @@ -599,11 +710,11 @@ def _parse_arbitrary_value(ctxt, node): # type: (errors.ParserContext, Union[yaml.nodes.MappingNode, yaml.nodes.ScalarNode, yaml.nodes.SequenceNode]) -> Any """Parse a generic YAML type to a Python type.""" - if node.id == 'mapping': + if node.id == "mapping": return {k.value: _parse_arbitrary_value(ctxt, v) for (k, v) in node.value} - elif node.id == 'sequence': + elif node.id == "sequence": return [_parse_arbitrary_value(ctxt, node) for node in node.value] - elif ctxt.is_scalar_node(node, 'node'): + elif ctxt.is_scalar_node(node, "node"): return node.value else: # Error added to context by is_scalar_node case above @@ -619,23 +730,31 @@ def _parse_enum_values(ctxt, node): field_name_set = set() # type: Set[str] for [first_node, second_node] in node.value: - first_name = first_node.value if first_name in field_name_set: ctxt.add_duplicate_error(first_node, first_name) continue - enum_value = syntax.EnumValue(ctxt.file_name, node.start_mark.line, node.start_mark.column) + enum_value = syntax.EnumValue( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) enum_value.name = first_name - if second_node.id == 'mapping': + if second_node.id == "mapping": _generic_parser( - ctxt, second_node, first_name, enum_value, { - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "value": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "extra_data": _RuleDesc('mapping', mapping_parser_func=_parse_arbitrary_value), - }) + ctxt, + second_node, + first_name, + enum_value, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "value": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "extra_data": _RuleDesc( + "mapping", mapping_parser_func=_parse_arbitrary_value + ), + }, + ) elif ctxt.is_scalar_node(second_node, first_name): enum_value.value = second_node.value @@ -656,11 +775,16 @@ def _parse_enum(ctxt, spec, name, node): idl_enum.name = name _generic_parser( - ctxt, node, "enum", idl_enum, { - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "type": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "values": _RuleDesc('mapping', mapping_parser_func=_parse_enum_values), - }) + ctxt, + node, + "enum", + idl_enum, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "type": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "values": _RuleDesc("mapping", mapping_parser_func=_parse_enum_values), + }, + ) if idl_enum.values is None: ctxt.add_empty_enum_error(node, idl_enum.name) @@ -675,14 +799,21 @@ def _parse_privilege(ctxt, node): if not ctxt.is_mapping_node(node, "privilege"): return None - privilege = syntax.Privilege(ctxt.file_name, node.start_mark.line, node.start_mark.column) + privilege = syntax.Privilege( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "privilege", privilege, { - "resource_pattern": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "action_type": _RuleDesc('scalar_or_sequence', _RuleDesc.REQUIRED), - "agg_stage": _RuleDesc('scalar', _RuleDesc.OPTIONAL), - }) + ctxt, + node, + "privilege", + privilege, + { + "resource_pattern": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "action_type": _RuleDesc("scalar_or_sequence", _RuleDesc.REQUIRED), + "agg_stage": _RuleDesc("scalar", _RuleDesc.OPTIONAL), + }, + ) return privilege @@ -691,17 +822,24 @@ def _parse_privilege_or_check(ctxt, node): # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.AccessCheck """Parse a privilege section in an access_check in the IDL file.""" - access_check = syntax.AccessCheck(ctxt.file_name, node.start_mark.line, node.start_mark.column) + access_check = syntax.AccessCheck( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "privilege_or_check", access_check, { - "check": _RuleDesc('scalar'), - "privilege": _RuleDesc('mapping', mapping_parser_func=_parse_privilege), - }) + ctxt, + node, + "privilege_or_check", + access_check, + { + "check": _RuleDesc("scalar"), + "privilege": _RuleDesc("mapping", mapping_parser_func=_parse_privilege), + }, + ) - if (access_check.check is None - and access_check.privilege is None) or (access_check.check is not None - and access_check.privilege is not None): + if (access_check.check is None and access_check.privilege is None) or ( + access_check.check is not None and access_check.privilege is not None + ): ctxt.add_either_check_or_privilege(access_check) return None @@ -718,22 +856,36 @@ def _parse_access_checks(ctxt, node): # type: (errors.ParserContext, yaml.nodes.MappingNode) -> syntax.AccessChecks """Parse an access check section in a struct in the IDL file.""" - access_checks = syntax.AccessChecks(ctxt.file_name, node.start_mark.line, - node.start_mark.column) + access_checks = syntax.AccessChecks( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) _generic_parser( - ctxt, node, "access_check", access_checks, { - "ignore": _RuleDesc('bool_scalar'), - "none": _RuleDesc('bool_scalar'), - "simple": _RuleDesc('mapping', mapping_parser_func=_parse_privilege_or_check), - "complex": _RuleDesc('sequence_mapping', sequence_parser_func=_parse_complex_sequence), - }) + ctxt, + node, + "access_check", + access_checks, + { + "ignore": _RuleDesc("bool_scalar"), + "none": _RuleDesc("bool_scalar"), + "simple": _RuleDesc( + "mapping", mapping_parser_func=_parse_privilege_or_check + ), + "complex": _RuleDesc( + "sequence_mapping", sequence_parser_func=_parse_complex_sequence + ), + }, + ) if ctxt.errors.has_errors(): return None - if (bool(access_checks.ignore) + bool(access_checks.none) + bool(access_checks.simple) + bool( - access_checks.complex)) != 1: + if ( + bool(access_checks.ignore) + + bool(access_checks.none) + + bool(access_checks.simple) + + bool(access_checks.complex) + ) != 1: ctxt.add_empty_access_check(access_checks) return None @@ -747,35 +899,52 @@ def _parse_command(ctxt, spec, name, node): if not ctxt.is_mapping_node(node, "command"): return - command = syntax.Command(ctxt.file_name, node.start_mark.line, node.start_mark.column) + command = syntax.Command( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) command.name = name _generic_parser( - ctxt, node, "command", command, { - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "chained_types": _RuleDesc('mapping', mapping_parser_func=_parse_chained_types), - "chained_structs": _RuleDesc('mapping', mapping_parser_func=_parse_chained_structs), - "fields": _RuleDesc('mapping', mapping_parser_func=_parse_fields), - "namespace": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "cpp_name": _RuleDesc('scalar'), - "type": _RuleDesc('scalar_or_mapping', mapping_parser_func=_parse_field_type), - "command_name": _RuleDesc('scalar'), - "command_alias": _RuleDesc('scalar'), - "reply_type": _RuleDesc('scalar'), - "api_version": _RuleDesc('scalar'), - "is_deprecated": _RuleDesc('bool_scalar'), + ctxt, + node, + "command", + command, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "chained_types": _RuleDesc( + "mapping", mapping_parser_func=_parse_chained_types + ), + "chained_structs": _RuleDesc( + "mapping", mapping_parser_func=_parse_chained_structs + ), + "fields": _RuleDesc("mapping", mapping_parser_func=_parse_fields), + "namespace": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "cpp_name": _RuleDesc("scalar"), + "type": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_field_type + ), + "command_name": _RuleDesc("scalar"), + "command_alias": _RuleDesc("scalar"), + "reply_type": _RuleDesc("scalar"), + "api_version": _RuleDesc("scalar"), + "is_deprecated": _RuleDesc("bool_scalar"), "strict": _RuleDesc("bool_scalar"), "inline_chained_structs": _RuleDesc("bool_scalar"), - "immutable": _RuleDesc('bool_scalar'), + "immutable": _RuleDesc("bool_scalar"), "generate_comparison_operators": _RuleDesc("bool_scalar"), - "allow_global_collection_name": _RuleDesc('bool_scalar'), - "non_const_getter": _RuleDesc('bool_scalar'), - "access_check": _RuleDesc('mapping', mapping_parser_func=_parse_access_checks), - }) + "allow_global_collection_name": _RuleDesc("bool_scalar"), + "non_const_getter": _RuleDesc("bool_scalar"), + "access_check": _RuleDesc( + "mapping", mapping_parser_func=_parse_access_checks + ), + }, + ) valid_commands = [ - common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB, common.COMMAND_NAMESPACE_IGNORED, - common.COMMAND_NAMESPACE_TYPE, common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID + common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB, + common.COMMAND_NAMESPACE_IGNORED, + common.COMMAND_NAMESPACE_TYPE, + common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID, ] if not command.command_name: @@ -789,8 +958,9 @@ def _parse_command(ctxt, spec, name, node): if command.namespace: if command.namespace not in valid_commands: - ctxt.add_bad_command_namespace_error(command, command.name, command.namespace, - valid_commands) + ctxt.add_bad_command_namespace_error( + command, command.name, command.namespace, valid_commands + ) # type property must be specified for a namespace = type if command.namespace == common.COMMAND_NAMESPACE_TYPE and not command.type: @@ -808,7 +978,7 @@ def _parse_command(ctxt, spec, name, node): if not command.api_version: for field in command.fields: - if field.stability is not None and field.stability != 'stable': + if field.stability is not None and field.stability != "stable": ctxt.add_stability_no_api_version(field, command.name) spec.symbols.add_command(ctxt, command) @@ -817,20 +987,27 @@ def _parse_command(ctxt, spec, name, node): def _parse_server_parameter_class(ctxt, node): # type: (errors.ParserContext, Union[yaml.nodes.ScalarNode,yaml.nodes.MappingNode]) -> syntax.ServerParameterClass """Parse a server_parameter.cpp_class as either a scalar or a mapping.""" - spc = syntax.ServerParameterClass(ctxt.file_name, node.start_mark.line, node.start_mark.column) + spc = syntax.ServerParameterClass( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) - if node.id == 'scalar': + if node.id == "scalar": spc.name = node.value return spc _generic_parser( - ctxt, node, "cpp_class", spc, { - "name": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "data": _RuleDesc('scalar'), - "override_ctor": _RuleDesc('bool_scalar'), - "override_set": _RuleDesc('bool_scalar'), - "override_validate": _RuleDesc('bool_scalar'), - }) + ctxt, + node, + "cpp_class", + spc, + { + "name": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "data": _RuleDesc("scalar"), + "override_ctor": _RuleDesc("bool_scalar"), + "override_set": _RuleDesc("bool_scalar"), + "override_validate": _RuleDesc("bool_scalar"), + }, + ) return spc @@ -841,28 +1018,37 @@ def _parse_server_parameter(ctxt, spec, name, node): if not ctxt.is_mapping_node(node, "server_parameters"): return - param = syntax.ServerParameter(ctxt.file_name, node.start_mark.line, node.start_mark.column) + param = syntax.ServerParameter( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) param.name = name # Declare as local to avoid ugly formatting with long line. map_class = _parse_server_parameter_class _generic_parser( - ctxt, node, "server_parameters", param, { - "set_at": _RuleDesc('scalar_or_sequence', _RuleDesc.REQUIRED), - "description": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "cpp_vartype": _RuleDesc('scalar'), - "cpp_varname": _RuleDesc('scalar'), - "condition": _RuleDesc('mapping', mapping_parser_func=_parse_condition), - "redact": _RuleDesc('required_bool_scalar', _RuleDesc.REQUIRED), - "default": _RuleDesc('scalar_or_mapping', mapping_parser_func=_parse_expression), - "test_only": _RuleDesc('bool_scalar'), - "deprecated_name": _RuleDesc('scalar_or_sequence'), - "validator": _RuleDesc('mapping', mapping_parser_func=_parse_validator), + ctxt, + node, + "server_parameters", + param, + { + "set_at": _RuleDesc("scalar_or_sequence", _RuleDesc.REQUIRED), + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "cpp_vartype": _RuleDesc("scalar"), + "cpp_varname": _RuleDesc("scalar"), + "condition": _RuleDesc("mapping", mapping_parser_func=_parse_condition), + "redact": _RuleDesc("required_bool_scalar", _RuleDesc.REQUIRED), + "default": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_expression + ), + "test_only": _RuleDesc("bool_scalar"), + "deprecated_name": _RuleDesc("scalar_or_sequence"), + "validator": _RuleDesc("mapping", mapping_parser_func=_parse_validator), "on_update": _RuleDesc("scalar"), - "omit_in_ftdc": _RuleDesc('bool_scalar'), - "cpp_class": _RuleDesc('scalar_or_mapping', mapping_parser_func=map_class), - }) + "omit_in_ftdc": _RuleDesc("bool_scalar"), + "cpp_class": _RuleDesc("scalar_or_mapping", mapping_parser_func=map_class), + }, + ) spec.server_parameters.append(param) @@ -873,24 +1059,32 @@ def _parse_feature_flag(ctxt, spec, name, node): if not ctxt.is_mapping_node(node, "feature_flags"): return - param = syntax.FeatureFlag(ctxt.file_name, node.start_mark.line, node.start_mark.column) + param = syntax.FeatureFlag( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) param.name = name _generic_parser( - ctxt, node, "feature_flags", param, { - "description": - _RuleDesc('scalar', _RuleDesc.REQUIRED), - "cpp_varname": - _RuleDesc('scalar'), - "default": - _RuleDesc('scalar_or_mapping', _RuleDesc.REQUIRED, - mapping_parser_func=_parse_expression), - "version": - _RuleDesc('scalar'), - "shouldBeFCVGated": - _RuleDesc('scalar_or_mapping', _RuleDesc.REQUIRED, - mapping_parser_func=_parse_expression), - }) + ctxt, + node, + "feature_flags", + param, + { + "description": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "cpp_varname": _RuleDesc("scalar"), + "default": _RuleDesc( + "scalar_or_mapping", + _RuleDesc.REQUIRED, + mapping_parser_func=_parse_expression, + ), + "version": _RuleDesc("scalar"), + "shouldBeFCVGated": _RuleDesc( + "scalar_or_mapping", + _RuleDesc.REQUIRED, + mapping_parser_func=_parse_expression, + ), + }, + ) spec.feature_flags.append(param) @@ -901,33 +1095,46 @@ def _parse_config_option(ctxt, spec, name, node): if not ctxt.is_mapping_node(node, "configs"): return - option = syntax.ConfigOption(ctxt.file_name, node.start_mark.line, node.start_mark.column) + option = syntax.ConfigOption( + ctxt.file_name, node.start_mark.line, node.start_mark.column + ) option.name = name _generic_parser( - ctxt, node, "configs", option, { - "short_name": _RuleDesc('scalar'), - "single_name": _RuleDesc('scalar'), - "deprecated_name": _RuleDesc('scalar_or_sequence'), - "deprecated_short_name": _RuleDesc('scalar_or_sequence'), - "description": _RuleDesc('scalar_or_mapping', _RuleDesc.REQUIRED, _parse_expression), - "section": _RuleDesc('scalar'), - "arg_vartype": _RuleDesc('scalar', _RuleDesc.REQUIRED), - "cpp_vartype": _RuleDesc('scalar'), - "cpp_varname": _RuleDesc('scalar'), - "condition": _RuleDesc('mapping', mapping_parser_func=_parse_condition), - "conflicts": _RuleDesc('scalar_or_sequence'), - "requires": _RuleDesc('scalar_or_sequence'), - "hidden": _RuleDesc('bool_scalar'), - "redact": _RuleDesc('bool_scalar'), - "default": _RuleDesc('scalar_or_mapping', mapping_parser_func=_parse_expression), - "implicit": _RuleDesc('scalar_or_mapping', mapping_parser_func=_parse_expression), - "source": _RuleDesc('scalar_or_sequence'), - "canonicalize": _RuleDesc('scalar'), - "duplicate_behavior": _RuleDesc('scalar'), - "positional": _RuleDesc('scalar'), - "validator": _RuleDesc('mapping', mapping_parser_func=_parse_validator), - }) + ctxt, + node, + "configs", + option, + { + "short_name": _RuleDesc("scalar"), + "single_name": _RuleDesc("scalar"), + "deprecated_name": _RuleDesc("scalar_or_sequence"), + "deprecated_short_name": _RuleDesc("scalar_or_sequence"), + "description": _RuleDesc( + "scalar_or_mapping", _RuleDesc.REQUIRED, _parse_expression + ), + "section": _RuleDesc("scalar"), + "arg_vartype": _RuleDesc("scalar", _RuleDesc.REQUIRED), + "cpp_vartype": _RuleDesc("scalar"), + "cpp_varname": _RuleDesc("scalar"), + "condition": _RuleDesc("mapping", mapping_parser_func=_parse_condition), + "conflicts": _RuleDesc("scalar_or_sequence"), + "requires": _RuleDesc("scalar_or_sequence"), + "hidden": _RuleDesc("bool_scalar"), + "redact": _RuleDesc("bool_scalar"), + "default": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_expression + ), + "implicit": _RuleDesc( + "scalar_or_mapping", mapping_parser_func=_parse_expression + ), + "source": _RuleDesc("scalar_or_sequence"), + "canonicalize": _RuleDesc("scalar"), + "duplicate_behavior": _RuleDesc("scalar"), + "positional": _RuleDesc("scalar"), + "validator": _RuleDesc("mapping", mapping_parser_func=_parse_validator), + }, + ) spec.configs.append(option) @@ -986,13 +1193,13 @@ def parse_file(stream, error_file_name, parse_non_forward_compatible_section=Tru if not root_node.id == "mapping": raise errors.IDLError( - "Expected a YAML mapping node as root node of IDL document, got '%s' instead" % - root_node.id) + "Expected a YAML mapping node as root node of IDL document, got '%s' instead" + % root_node.id + ) field_name_set = set() # type: Set[str] for [first_node, second_node] in root_node.value: - first_name = first_node.value if first_name in field_name_set: @@ -1004,22 +1211,29 @@ def parse_file(stream, error_file_name, parse_non_forward_compatible_section=Tru elif first_name == "imports": _parse_imports(ctxt, spec, second_node) elif first_name == "enums": - _parse_mapping(ctxt, spec, second_node, 'enums', _parse_enum) + _parse_mapping(ctxt, spec, second_node, "enums", _parse_enum) elif first_name == "types": - _parse_mapping(ctxt, spec, second_node, 'types', _parse_type) + _parse_mapping(ctxt, spec, second_node, "types", _parse_type) elif first_name == "structs": - _parse_mapping(ctxt, spec, second_node, 'structs', _parse_struct) + _parse_mapping(ctxt, spec, second_node, "structs", _parse_struct) elif first_name == "commands": - _parse_mapping(ctxt, spec, second_node, 'commands', _parse_command) + _parse_mapping(ctxt, spec, second_node, "commands", _parse_command) elif first_name == "server_parameters": if parse_non_forward_compatible_section: - _parse_mapping(ctxt, spec, second_node, "server_parameters", - _parse_server_parameter) + _parse_mapping( + ctxt, + spec, + second_node, + "server_parameters", + _parse_server_parameter, + ) elif first_name == "configs": _parse_mapping(ctxt, spec, second_node, "configs", _parse_config_option) elif first_name == "feature_flags": if parse_non_forward_compatible_section: - _parse_mapping(ctxt, spec, second_node, "feature_flags", _parse_feature_flag) + _parse_mapping( + ctxt, spec, second_node, "feature_flags", _parse_feature_flag + ) else: ctxt.add_unknown_root_node_error(first_node) @@ -1073,8 +1287,10 @@ def parse(stream, input_file_name, resolver, parse_non_forward_compatible_sectio imports = [] # type: List[Tuple[common.SourceLocation, str, str]] needs_include = [] # type: List[str] if root_doc.spec.imports: - imports = [(root_doc.spec.imports, input_file_name, import_file_name) - for import_file_name in root_doc.spec.imports.imports] + imports = [ + (root_doc.spec.imports, input_file_name, import_file_name) + for import_file_name in root_doc.spec.imports.imports + ] resolved_file_names = [] # type: List[str] @@ -1102,27 +1318,35 @@ def parse(stream, input_file_name, resolver, parse_non_forward_compatible_sectio # Parse imported file with resolver.open(resolved_file_name) as file_stream: - parsed_doc = parse_file(file_stream, resolved_file_name, - parse_non_forward_compatible_section) + parsed_doc = parse_file( + file_stream, resolved_file_name, parse_non_forward_compatible_section + ) # Check for errors if parsed_doc.errors: return parsed_doc # We need to generate includes for imported IDL files which have structs or enums. - if (base_file_name == input_file_name - and (parsed_doc.spec.symbols.structs or parsed_doc.spec.symbols.enums)): + if base_file_name == input_file_name and ( + parsed_doc.spec.symbols.structs or parsed_doc.spec.symbols.enums + ): needs_include.append(imported_file_name) # Add other imported files to the list of files to parse if parsed_doc.spec.imports: - imports += [(parsed_doc.spec.imports, resolved_file_name, import_file_name) - for import_file_name in parsed_doc.spec.imports.imports] + imports += [ + (parsed_doc.spec.imports, resolved_file_name, import_file_name) + for import_file_name in parsed_doc.spec.imports.imports + ] # Merge cpp_includes as needed if parsed_doc.spec.globals and parsed_doc.spec.globals.cpp_includes: root_doc.spec.globals.cpp_includes = list( - set(root_doc.spec.globals.cpp_includes + parsed_doc.spec.globals.cpp_includes)) + set( + root_doc.spec.globals.cpp_includes + + parsed_doc.spec.globals.cpp_includes + ) + ) # Merge symbol tables together root_doc.spec.symbols.add_imported_symbol_table(ctxt, parsed_doc.spec.symbols) diff --git a/buildscripts/idl/idl/struct_types.py b/buildscripts/idl/idl/struct_types.py index 1d477b5ccc4..4ecc96bca11 100644 --- a/buildscripts/idl/idl/struct_types.py +++ b/buildscripts/idl/idl/struct_types.py @@ -36,8 +36,14 @@ from . import ast, common, cpp_types def _is_required_constructor_arg(field): # type: (ast.Field) -> bool """Get whether we require this field to have a value set for constructor purposes.""" - return not field.ignore and not field.optional and not field.default and not field.chained \ - and not field.chained_struct_field and not field.serialize_op_msg_request_only + return ( + not field.ignore + and not field.optional + and not field.default + and not field.chained + and not field.chained_struct_field + and not field.serialize_op_msg_request_only + ) def _get_arg_for_field(field): @@ -54,7 +60,9 @@ def _get_required_parameters(struct): # type: (ast.Struct) -> List[str] """Get a list of arguments for required parameters.""" params = [ - _get_arg_for_field(field) for field in struct.fields if _is_required_constructor_arg(field) + _get_arg_for_field(field) + for field in struct.fields + if _is_required_constructor_arg(field) ] # Since this contains defaults, we need to push this to the end of the list. params.append(_get_serialization_ctx_arg()) @@ -62,7 +70,7 @@ def _get_required_parameters(struct): def _get_serialization_ctx_arg(): - return 'boost::optional serializationContext = boost::none' + return "boost::optional serializationContext = boost::none" class ArgumentInfo(object): @@ -72,12 +80,12 @@ class ArgumentInfo(object): # type: (str) -> None """Create a instance of the ArgumentInfo class by parsing the argument string.""" self.defaults = None - equal_tokens = arg.split('=') + equal_tokens = arg.split("=") if len(equal_tokens) > 1: self.defaults = equal_tokens[-1].strip() - space_tokens = equal_tokens[0].strip().split(' ') - self.type = ' '.join(space_tokens[0:-1]) + space_tokens = equal_tokens[0].strip().split(" ") + self.type = " ".join(space_tokens[0:-1]) self.name = space_tokens[-1] def get_string(self, get_defaults): @@ -93,8 +101,17 @@ class MethodInfo(object): # pylint: disable=too-many-instance-attributes - def __init__(self, class_name, method_name, args, return_type=None, static=False, const=False, - explicit=False, desc_for_comment=None): + def __init__( + self, + class_name, + method_name, + args, + return_type=None, + static=False, + const=False, + explicit=False, + desc_for_comment=None, + ): # type: (str, str, List[str], str, bool, bool, bool, Optional[str]) -> None # pylint: disable=too-many-arguments """Create a MethodInfo instance.""" @@ -110,59 +127,71 @@ class MethodInfo(object): def get_declaration(self): # type: () -> str """Get a declaration for a method.""" - pre_modifiers = '' - post_modifiers = '' - return_type_str = '' + pre_modifiers = "" + post_modifiers = "" + return_type_str = "" if self.static: - pre_modifiers = 'static ' + pre_modifiers = "static " if self.const: - post_modifiers = ' const' + post_modifiers = " const" if self.explicit: - pre_modifiers += 'explicit ' + pre_modifiers += "explicit " if self.return_type: - return_type_str = self.return_type + ' ' + return_type_str = self.return_type + " " return common.template_args( "${pre_modifiers}${return_type}${method_name}(${args})${post_modifiers};", - pre_modifiers=pre_modifiers, return_type=return_type_str, method_name=self.method_name, - args=', '.join( - [arg.get_string(True) for arg in self.args]), post_modifiers=post_modifiers) + pre_modifiers=pre_modifiers, + return_type=return_type_str, + method_name=self.method_name, + args=", ".join([arg.get_string(True) for arg in self.args]), + post_modifiers=post_modifiers, + ) def get_definition(self): # type: () -> str """Get a definition for a method.""" - pre_modifiers = '' - post_modifiers = '' - return_type_str = '' + pre_modifiers = "" + post_modifiers = "" + return_type_str = "" if self.const: - post_modifiers = ' const' + post_modifiers = " const" if self.return_type: - return_type_str = self.return_type + ' ' + return_type_str = self.return_type + " " return common.template_args( "${pre_modifiers}${return_type}${class_name}::${method_name}(${args})${post_modifiers}", - pre_modifiers=pre_modifiers, return_type=return_type_str, class_name=self.class_name, - method_name=self.method_name, args=', '.join( - [arg.get_string(False) for arg in self.args]), post_modifiers=post_modifiers) + pre_modifiers=pre_modifiers, + return_type=return_type_str, + class_name=self.class_name, + method_name=self.method_name, + args=", ".join([arg.get_string(False) for arg in self.args]), + post_modifiers=post_modifiers, + ) def get_call(self, obj): # type: (Optional[str]) -> str """Generate a simple call to the method using the defined args list.""" - args = ', '.join([arg.name for arg in self.args]) + args = ", ".join([arg.name for arg in self.args]) if obj: - return common.template_args("${obj}.${method_name}(${args});", obj=obj, - method_name=self.method_name, args=args) + return common.template_args( + "${obj}.${method_name}(${args});", + obj=obj, + method_name=self.method_name, + args=args, + ) - return common.template_args("${method_name}(${args});", method_name=self.method_name, - args=args) + return common.template_args( + "${method_name}(${args});", method_name=self.method_name, args=args + ) def get_desc_for_comment(self): # type: () -> Optional[str] @@ -280,7 +309,9 @@ class _StructTypeInfo(StructTypeInfoBase): def get_required_constructor_method(self, gen_header=False): # type: (bool) -> MethodInfo class_name = common.title_case(self._struct.cpp_name) - return MethodInfo(class_name, class_name, _get_required_parameters(self._struct)) + return MethodInfo( + class_name, class_name, _get_required_parameters(self._struct) + ) def get_sharing_deserializer_static_method(self): # type: () -> MethodInfo @@ -288,9 +319,14 @@ class _StructTypeInfo(StructTypeInfoBase): comment = textwrap.dedent(f"""\ Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed this way participates in ownership of the data underlying the BSONObj.""") - return MethodInfo(class_name, 'parseSharingOwnership', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name, - static=True, desc_for_comment=comment) + return MethodInfo( + class_name, + "parseSharingOwnership", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + class_name, + static=True, + desc_for_comment=comment, + ) def get_owned_deserializer_static_method(self): # type: () -> MethodInfo @@ -298,9 +334,14 @@ class _StructTypeInfo(StructTypeInfoBase): comment = textwrap.dedent(f"""\ Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed this way takes ownership of the data underlying the BSONObj.""") - return MethodInfo(class_name, 'parseOwned', - ['const IDLParserContext& ctxt', 'BSONObj&& bsonObject'], class_name, - static=True, desc_for_comment=comment) + return MethodInfo( + class_name, + "parseOwned", + ["const IDLParserContext& ctxt", "BSONObj&& bsonObject"], + class_name, + static=True, + desc_for_comment=comment, + ) def get_deserializer_static_method(self): # type: () -> MethodInfo @@ -311,23 +352,36 @@ class _StructTypeInfo(StructTypeInfoBase): ensure the validity any members of this struct that point-into the BSONObj (i.e. unowned objects).""") - return MethodInfo(class_name, 'parse', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name, - static=True, desc_for_comment=comment) + return MethodInfo( + class_name, + "parse", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + class_name, + static=True, + desc_for_comment=comment, + ) def get_deserializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'parseProtected', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void') + common.title_case(self._struct.cpp_name), + "parseProtected", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + "void", + ) def get_serializer_method(self): # type: () -> MethodInfo - args = ['BSONObjBuilder* builder'] + args = ["BSONObjBuilder* builder"] if self._struct.query_shape_component: args.append("const SerializationOptions& options = {}") return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', args, 'void', const=True) + common.title_case(self._struct.cpp_name), + "serialize", + args, + "void", + const=True, + ) def get_to_bson_method(self): # type: () -> MethodInfo @@ -335,7 +389,12 @@ class _StructTypeInfo(StructTypeInfoBase): if self._struct.query_shape_component: args.append("const SerializationOptions& options = {}") return MethodInfo( - common.title_case(self._struct.cpp_name), 'toBSON', args, 'BSONObj', const=True) + common.title_case(self._struct.cpp_name), + "toBSON", + args, + "BSONObj", + const=True, + ) def get_op_msg_request_serializer_method(self): # type: () -> Optional[MethodInfo] @@ -379,21 +438,32 @@ class _CommandBaseTypeInfo(_StructTypeInfo): def get_op_msg_request_serializer_method(self): # type: () -> Optional[MethodInfo] return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', - ['const BSONObj& commandPassthroughFields'], 'OpMsgRequest', const=True) + common.title_case(self._struct.cpp_name), + "serialize", + ["const BSONObj& commandPassthroughFields"], + "OpMsgRequest", + const=True, + ) def get_op_msg_request_deserializer_static_method(self): # type: () -> Optional[MethodInfo] class_name = common.title_case(self._struct.cpp_name) - return MethodInfo(class_name, 'parse', - ['const IDLParserContext& ctxt', 'const OpMsgRequest& request'], - class_name, static=True) + return MethodInfo( + class_name, + "parse", + ["const IDLParserContext& ctxt", "const OpMsgRequest& request"], + class_name, + static=True, + ) def get_op_msg_request_deserializer_method(self): # type: () -> Optional[MethodInfo] return MethodInfo( - common.title_case(self._struct.cpp_name), 'parseProtected', - ['const IDLParserContext& ctxt', 'const OpMsgRequest& request'], 'void') + common.title_case(self._struct.cpp_name), + "parseProtected", + ["const IDLParserContext& ctxt", "const OpMsgRequest& request"], + "void", + ) class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo): @@ -409,20 +479,29 @@ class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo): def get_serializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', - ['const BSONObj& commandPassthroughFields', 'BSONObjBuilder* builder'], 'void', - const=True) + common.title_case(self._struct.cpp_name), + "serialize", + ["const BSONObj& commandPassthroughFields", "BSONObjBuilder* builder"], + "void", + const=True, + ) def get_to_bson_method(self): # type: () -> MethodInfo # Commands that require namespaces require it as a parameter to serialize() return MethodInfo( - common.title_case(self._struct.cpp_name), 'toBSON', - ['const BSONObj& commandPassthroughFields'], 'BSONObj', const=True) + common.title_case(self._struct.cpp_name), + "toBSON", + ["const BSONObj& commandPassthroughFields"], + "BSONObj", + const=True, + ) def gen_serializer(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('builder->append("%s"_sd, 1);' % (self._command.name)) + indented_writer.write_line( + 'builder->append("%s"_sd, 1);' % (self._command.name) + ) def gen_namespace_check(self, indented_writer, db_name, element): # type: (writer.IndentedTextWriter, str, str) -> None @@ -436,8 +515,8 @@ def _get_command_type_parameter(command, gen_header=False): # Use the storage type for the constructor argument since the generated code will use std::move. member_type = cpp_type_info.get_storage_type() result = f"{member_type} {common.camel_case(command.command_field.cpp_name)}" - if not gen_header or '&' in result: - result = 'const ' + result + if not gen_header or "&" in result: + result = "const " + result return result @@ -464,27 +543,41 @@ class _CommandFromType(_CommandBaseTypeInfo): class_name = common.title_case(self._struct.cpp_name) arg = _get_command_type_parameter(self._command, gen_header) - return MethodInfo(class_name, class_name, [arg] + _get_required_parameters(self._struct), - explicit=True) + return MethodInfo( + class_name, + class_name, + [arg] + _get_required_parameters(self._struct), + explicit=True, + ) def get_serializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', - ['const BSONObj& commandPassthroughFields', 'BSONObjBuilder* builder'], 'void', - const=True) + common.title_case(self._struct.cpp_name), + "serialize", + ["const BSONObj& commandPassthroughFields", "BSONObjBuilder* builder"], + "void", + const=True, + ) def get_to_bson_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'toBSON', - ['const BSONObj& commandPassthroughFields'], 'BSONObj', const=True) + common.title_case(self._struct.cpp_name), + "toBSON", + ["const BSONObj& commandPassthroughFields"], + "BSONObj", + const=True, + ) def get_deserializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'parseProtected', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void') + common.title_case(self._struct.cpp_name), + "parseProtected", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + "void", + ) def gen_getter_method(self, indented_writer): # type: (writer.IndentedTextWriter) -> None @@ -516,78 +609,103 @@ class _CommandWithNamespaceTypeInfo(_CommandBaseTypeInfo): @staticmethod def _get_nss_param(gen_header): - nss_param = 'NamespaceString nss' + nss_param = "NamespaceString nss" if not gen_header: - nss_param = 'const ' + nss_param + nss_param = "const " + nss_param return nss_param def get_constructor_method(self, gen_header=False): # type: (bool) -> MethodInfo class_name = common.title_case(self._struct.cpp_name) sc_arg = _get_serialization_ctx_arg() - return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg], - explicit=True) + return MethodInfo( + class_name, + class_name, + [self._get_nss_param(gen_header), sc_arg], + explicit=True, + ) def get_required_constructor_method(self, gen_header=False): # type: (bool) -> MethodInfo class_name = common.title_case(self._struct.cpp_name) return MethodInfo( - class_name, class_name, - [self._get_nss_param(gen_header)] + _get_required_parameters(self._struct)) + class_name, + class_name, + [self._get_nss_param(gen_header)] + _get_required_parameters(self._struct), + ) def get_serializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', - ['const BSONObj& commandPassthroughFields', 'BSONObjBuilder* builder'], 'void', - const=True) + common.title_case(self._struct.cpp_name), + "serialize", + ["const BSONObj& commandPassthroughFields", "BSONObjBuilder* builder"], + "void", + const=True, + ) def get_to_bson_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'toBSON', - ['const BSONObj& commandPassthroughFields'], 'BSONObj', const=True) + common.title_case(self._struct.cpp_name), + "toBSON", + ["const BSONObj& commandPassthroughFields"], + "BSONObj", + const=True, + ) def get_deserializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'parseProtected', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void') + common.title_case(self._struct.cpp_name), + "parseProtected", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + "void", + ) def gen_getter_method(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('const NamespaceString& getNamespace() const { return _nss; }') + indented_writer.write_line( + "const NamespaceString& getNamespace() const { return _nss; }" + ) if self._struct.non_const_getter: - indented_writer.write_line('NamespaceString& getNamespace() { return _nss; }') + indented_writer.write_line( + "NamespaceString& getNamespace() { return _nss; }" + ) def gen_member(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('NamespaceString _nss;') + indented_writer.write_line("NamespaceString _nss;") def gen_serializer(self, indented_writer): # type: (writer.IndentedTextWriter) -> None if self._struct.allow_global_collection_name: indented_writer.write_line( - '_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name)) + '_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name) + ) else: - indented_writer.write_line('invariant(!_nss.isEmpty());') + indented_writer.write_line("invariant(!_nss.isEmpty());") indented_writer.write_line( - 'builder->append("%s"_sd, _nss.coll());' % (self._command.name)) + 'builder->append("%s"_sd, _nss.coll());' % (self._command.name) + ) indented_writer.write_empty_line() def gen_namespace_check(self, indented_writer, db_name, element): # type: (writer.IndentedTextWriter, str, str) -> None # TODO: should the name of the first element be validated?? - indented_writer.write_line('invariant(_nss.isEmpty());') - allow_global = 'true' if self._struct.allow_global_collection_name else 'false' + indented_writer.write_line("invariant(_nss.isEmpty());") + allow_global = "true" if self._struct.allow_global_collection_name else "false" indented_writer.write_line( - 'auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);' % (element, - allow_global)) + "auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);" + % (element, allow_global) + ) indented_writer.write_line( - '_nss = NamespaceStringUtil::deserialize(%s, collectionName);' % (db_name)) + "_nss = NamespaceStringUtil::deserialize(%s, collectionName);" % (db_name) + ) indented_writer.write_line( 'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "' - ' << _nss.toStringForErrorMsg(), _nss.isValid());') + " << _nss.toStringForErrorMsg(), _nss.isValid());" + ) class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo): @@ -602,72 +720,95 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo): @staticmethod def _get_nss_param(gen_header): - nss_param = 'NamespaceStringOrUUID nssOrUUID' + nss_param = "NamespaceStringOrUUID nssOrUUID" if not gen_header: - nss_param = 'const ' + nss_param + nss_param = "const " + nss_param return nss_param def get_constructor_method(self, gen_header=False): # type: (bool) -> MethodInfo class_name = common.title_case(self._struct.cpp_name) sc_arg = _get_serialization_ctx_arg() - return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg], - explicit=True) + return MethodInfo( + class_name, + class_name, + [self._get_nss_param(gen_header), sc_arg], + explicit=True, + ) def get_required_constructor_method(self, gen_header=False): # type: (bool) -> MethodInfo class_name = common.title_case(self._struct.cpp_name) return MethodInfo( - class_name, class_name, - [self._get_nss_param(gen_header)] + _get_required_parameters(self._struct)) + class_name, + class_name, + [self._get_nss_param(gen_header)] + _get_required_parameters(self._struct), + ) def get_serializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'serialize', - ['const BSONObj& commandPassthroughFields', 'BSONObjBuilder* builder'], 'void', - const=True) + common.title_case(self._struct.cpp_name), + "serialize", + ["const BSONObj& commandPassthroughFields", "BSONObjBuilder* builder"], + "void", + const=True, + ) def get_to_bson_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'toBSON', - ['const BSONObj& commandPassthroughFields'], 'BSONObj', const=True) + common.title_case(self._struct.cpp_name), + "toBSON", + ["const BSONObj& commandPassthroughFields"], + "BSONObj", + const=True, + ) def get_deserializer_method(self): # type: () -> MethodInfo return MethodInfo( - common.title_case(self._struct.cpp_name), 'parseProtected', - ['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void') + common.title_case(self._struct.cpp_name), + "parseProtected", + ["const IDLParserContext& ctxt", "const BSONObj& bsonObject"], + "void", + ) def gen_getter_method(self, indented_writer): # type: (writer.IndentedTextWriter) -> None indented_writer.write_line( - 'const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }') + "const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }" + ) if self._struct.non_const_getter: indented_writer.write_line( - 'NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }') + "NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }" + ) def gen_member(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('NamespaceStringOrUUID _nssOrUUID;') + indented_writer.write_line("NamespaceStringOrUUID _nssOrUUID;") def gen_serializer(self, indented_writer): # type: (writer.IndentedTextWriter) -> None - indented_writer.write_line('_nssOrUUID.serialize(builder, "%s"_sd);' % (self._command.name)) + indented_writer.write_line( + '_nssOrUUID.serialize(builder, "%s"_sd);' % (self._command.name) + ) indented_writer.write_empty_line() def gen_namespace_check(self, indented_writer, db_name, element): # type: (writer.IndentedTextWriter, str, str) -> None indented_writer.write_line( - 'auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);' % (element)) + "auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);" % (element) + ) indented_writer.write_line( - '_nssOrUUID = std::holds_alternative(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get(collOrUUID)) : NamespaceStringOrUUID(%s, get(collOrUUID));' - % (db_name, db_name)) + "_nssOrUUID = std::holds_alternative(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get(collOrUUID)) : NamespaceStringOrUUID(%s, get(collOrUUID));" + % (db_name, db_name) + ) indented_writer.write_line( 'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "' - ' << _nssOrUUID.toStringForErrorMsg()' - ', !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());') + " << _nssOrUUID.toStringForErrorMsg()" + ", !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());" + ) def get_struct_info(struct): diff --git a/buildscripts/idl/idl/syntax.py b/buildscripts/idl/idl/syntax.py index cc0ee7f6f06..4c351b325cd 100644 --- a/buildscripts/idl/idl/syntax.py +++ b/buildscripts/idl/idl/syntax.py @@ -45,8 +45,9 @@ class IDLParsedSpec(object): def __init__(self, spec, error_collection): # type: (IDLSpec, errors.ParserErrorCollection) -> None """Must specify either an IDL document or errors, not both.""" - assert (spec is None and error_collection is not None) or (spec is not None - and error_collection is None) + assert (spec is None and error_collection is not None) or ( + spec is not None and error_collection is None + ) self.spec = spec self.errors = error_collection @@ -75,11 +76,11 @@ def parse_array_variant_types(name): if not name.startswith("array>"): return None - name = name[len("array, ...>> types. if variant_type.startswith("array<") and variant_type.endswith(">"): @@ -95,7 +96,7 @@ def parse_array_type(name): if not name.startswith("array<") and not name.endswith(">"): return None - name = name[len("array<"):] + name = name[len("array<") :] name = name[:-1] # V1 restriction, ban nested array types to reduce scope. @@ -114,7 +115,9 @@ def _zip_scalar(items, obj): def _item_and_type(dic): # type: (Dict[Any, List[Any]]) -> Iterator[Tuple[Any, Any]] """Return an Iterator of (key, value) pairs from a dictionary.""" - return itertools.chain.from_iterable((_zip_scalar(value, key) for (key, value) in dic.items())) + return itertools.chain.from_iterable( + (_zip_scalar(value, key) for (key, value) in dic.items()) + ) class SymbolTable(object): @@ -138,19 +141,27 @@ class SymbolTable(object): def _is_duplicate(self, ctxt, location, name, duplicate_class_name): # type: (errors.ParserContext, common.SourceLocation, str, str) -> bool """Return true if the given item already exist in the symbol table.""" - for (item, entity_type) in _item_and_type({ + for item, entity_type in _item_and_type( + { "command": self.commands, "enum": self.enums, "struct": self.structs, "type": self.types, - }): + } + ): if item.name == name: - ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, entity_type) + ctxt.add_duplicate_symbol_error( + location, name, duplicate_class_name, entity_type + ) return True if entity_type == "command": - if name in [item.command_name, item.command_alias if item.command_alias else '']: - ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, - entity_type) + if name in [ + item.command_name, + item.command_alias if item.command_alias else "", + ]: + ctxt.add_duplicate_symbol_error( + location, name, duplicate_class_name, entity_type + ) return True return False @@ -176,8 +187,9 @@ class SymbolTable(object): def add_command(self, ctxt, command): # type: (errors.ParserContext, Command) -> None """Add an IDL command to the symbol table and check for duplicates.""" - if (not self._is_duplicate(ctxt, command, command.name, "command") - and not self._is_duplicate(ctxt, command, command.command_alias, "command")): + if not self._is_duplicate( + ctxt, command, command.name, "command" + ) and not self._is_duplicate(ctxt, command, command.command_alias, "command"): self.commands.append(command) def add_generic_argument_list(self, field_list): @@ -251,16 +263,22 @@ class SymbolTable(object): """Find the type or struct a field refers to or log an error.""" if isinstance(field_type, FieldTypeVariant): - variant = VariantType(field_type.file_name, field_type.line, field_type.column) + variant = VariantType( + field_type.file_name, field_type.line, field_type.column + ) variant.bson_serialization_type = [] for alternative in field_type.variant: - alternative_type = self.resolve_field_type(ctxt, location, field_name, alternative) + alternative_type = self.resolve_field_type( + ctxt, location, field_name, alternative + ) if not alternative_type: # There was an error. return None if isinstance(alternative_type, Enum): - ctxt.add_variant_enum_error(location, field_name, alternative_type.name) + ctxt.add_variant_enum_error( + location, field_name, alternative_type.name + ) return None if isinstance(alternative_type, Struct): @@ -269,7 +287,7 @@ class SymbolTable(object): # cause parsing ambiguity. first_element = alternative_type.fields[0].name if first_element in [ - elem.fields[0].name for elem in variant.variant_struct_types + elem.fields[0].name for elem in variant.variant_struct_types ]: ctxt.add_variant_structs_error(location, field_name) continue @@ -285,18 +303,22 @@ class SymbolTable(object): bson_serialization_type = [] # If alternative_type is an array, element type could be Struct or Type. if isinstance(base_type, Type): - bson_serialization_type = cast(Type, base_type).bson_serialization_type + bson_serialization_type = cast( + Type, base_type + ).bson_serialization_type variant.bson_serialization_type.extend(bson_serialization_type) return variant if isinstance(field_type, FieldTypeArray): - element_type = self.resolve_field_type(ctxt, location, field_name, - field_type.element_type) + element_type = self.resolve_field_type( + ctxt, location, field_name, field_type.element_type + ) if not element_type: - ctxt.add_unknown_type_error(location, field_name, - field_type.element_type.debug_string()) + ctxt.add_unknown_type_error( + location, field_name, field_type.element_type.debug_string() + ) return None if isinstance(element_type, Enum): @@ -307,7 +329,7 @@ class SymbolTable(object): assert isinstance(field_type, FieldTypeSingle) type_name = field_type.type_name - if type_name.startswith('array<'): + if type_name.startswith("array<"): # The caller should've already stripped "array<...>" from type_name, this may be an # illegal nested array like "array>". ctxt.add_bad_array_type_name_error(location, field_name, type_name) @@ -404,13 +426,14 @@ class ArrayType(Type): def __init__(self, element_type): # type: (Union[Struct, Type]) -> None """Construct an ArrayType.""" - super(ArrayType, self).__init__(element_type.file_name, element_type.line, - element_type.column) - self.name = f'array<{element_type.name}>' + super(ArrayType, self).__init__( + element_type.file_name, element_type.line, element_type.column + ) + self.name = f"array<{element_type.name}>" self.element_type = element_type if isinstance(element_type, Type): if element_type.cpp_type: - self.cpp_type = f'std::vector<{element_type.cpp_type}>' + self.cpp_type = f"std::vector<{element_type.cpp_type}>" else: assert isinstance(element_type, VariantType) # cpp_type can't be set here for array of variants as element_type.cpp_type is not set yet. @@ -423,7 +446,7 @@ class VariantType(Type): # type: (str, int, int) -> None """Construct a VariantType.""" super(VariantType, self).__init__(file_name, line, column) - self.name = 'variant' + self.name = "variant" self.variant_types = [] # type: List[Type] self.variant_struct_types = [] # type: List[Struct] @@ -450,9 +473,14 @@ class Validator(common.SourceLocation): super(Validator, self).__init__(file_name, line, column) def __eq__(self, other): - return (isinstance(other, Validator) and self.gt == other.gt and self.lt == other.lt - and self.gte == other.gte and self.lte == other.lte - and self.callback == other.callback) + return ( + isinstance(other, Validator) + and self.gt == other.gt + and self.lt == other.lt + and self.gte == other.gte + and self.lte == other.lte + and self.callback == other.callback + ) def __ne__(self, other): return not self == other @@ -601,7 +629,11 @@ class Privilege(common.SourceLocation): """ location = super(Privilege, self).__str__() msg = "location: %s, resource_pattern: %s, action_type: %s, agg_stage: %s" % ( - location, self.resource_pattern, self.action_type, self.agg_stage) + location, + self.resource_pattern, + self.action_type, + self.agg_stage, + ) return msg # type: ignore @@ -626,7 +658,11 @@ class AccessCheck(common.SourceLocation): location: test.idl: (17, 4), check: get_single_user, privilege: (location: test.idl: (18, 6), resource_pattern: exact_namespace, action_type: ['find', 'insert', 'update', 'remove'], agg_stage: None """ location = super(AccessCheck, self).__str__() - msg = "location: %s, check: %s, privilege: %s" % (location, self.check, self.privilege) + msg = "location: %s, check: %s, privilege: %s" % ( + location, + self.check, + self.privilege, + ) return msg # type: ignore @@ -708,8 +744,11 @@ class EnumValue(common.SourceLocation): super(EnumValue, self).__init__(file_name, line, column) def __eq__(self, other): - return (isinstance(other, EnumValue) and self.name == other.name - and self.value == other.value) + return ( + isinstance(other, EnumValue) + and self.name == other.name + and self.value == other.value + ) def __ne__(self, other): return not self == other @@ -789,12 +828,13 @@ class FieldTypeArray(FieldType): """Construct a FieldTypeArray.""" self.element_type = element_type # type: Union[FieldTypeSingle, FieldTypeVariant] - super(FieldTypeArray, self).__init__(element_type.file_name, element_type.line, - element_type.column) + super(FieldTypeArray, self).__init__( + element_type.file_name, element_type.line, element_type.column + ) def debug_string(self): """Display this field type in error messages.""" - return f'array<{self.element_type.type_name}>' + return f"array<{self.element_type.type_name}>" class FieldTypeVariant(FieldType): @@ -809,7 +849,7 @@ class FieldTypeVariant(FieldType): def debug_string(self): """Display this field type in error messages.""" - return 'variant<%s>' % (', '.join(v.debug_string() for v in self.variant)) + return "variant<%s>" % (", ".join(v.debug_string() for v in self.variant)) class Expression(common.SourceLocation): @@ -826,8 +866,12 @@ class Expression(common.SourceLocation): super(Expression, self).__init__(file_name, line, column) def __eq__(self, other): - return (isinstance(other, Expression) and self.literal == other.literal - and self.expr == other.expr and self.is_constexpr == other.is_constexpr) + return ( + isinstance(other, Expression) + and self.literal == other.literal + and self.expr == other.expr + and self.is_constexpr == other.is_constexpr + ) def __ne__(self, other): return not self == other diff --git a/buildscripts/idl/idl/writer.py b/buildscripts/idl/idl/writer.py index 436f7ef9eb7..e71e3d111f5 100644 --- a/buildscripts/idl/idl/writer.py +++ b/buildscripts/idl/idl/writer.py @@ -36,9 +36,9 @@ _INDENT_SPACE_COUNT = 4 def _fill_spaces(count): # type: (int) -> str """Fill a string full of spaces.""" - fill = '' + fill = "" for _ in range(count * _INDENT_SPACE_COUNT): - fill += ' ' + fill += " " return fill @@ -48,7 +48,7 @@ def _indent_text(count, unindented_text): """Indent each line of a multi-line string.""" lines = unindented_text.splitlines() fill = _fill_spaces(count) - return '\n'.join(fill + line for line in lines) + return "\n".join(fill + line for line in lines) def is_function(name): @@ -65,10 +65,10 @@ def is_function(name): def get_method_name(name): # type: (str) -> str """Get a method name from a fully qualified method name.""" - pos = name.rfind('::') + pos = name.rfind("::") if pos == -1: return name - return name[pos + 2:] + return name[pos + 2 :] def get_method_name_from_qualified_method_name(name): @@ -79,12 +79,12 @@ def get_method_name_from_qualified_method_name(name): if name.startswith("::"): name = name[2:] - prefix = 'mongo::' + prefix = "mongo::" pos = name.find(prefix) if pos == -1: return name - return name[len(prefix):] + return name[len(prefix) :] class IndentedTextWriter(object): @@ -240,7 +240,7 @@ class NamespaceScopeBlock(WriterBlock): # type: () -> None """Write the beginning of the block and do not indent.""" for namespace in self._namespaces: - self._writer.write_unindented_line('namespace %s {' % (namespace)) + self._writer.write_unindented_line("namespace %s {" % (namespace)) def __exit__(self, *args): # type: (*str) -> None @@ -248,7 +248,7 @@ class NamespaceScopeBlock(WriterBlock): self._namespaces.reverse() for namespace in self._namespaces: - self._writer.write_unindented_line('} // namespace %s' % (namespace)) + self._writer.write_unindented_line("} // namespace %s" % (namespace)) class UnindentedBlock(WriterBlock): @@ -301,7 +301,7 @@ def _get_common_prefix(words): """ empty_words = [lw for lw in words if len(lw) == 0] if empty_words: - return '' + return "" first_letters = {w[0] for w in words} @@ -314,7 +314,7 @@ def _get_common_prefix(words): return words[0][0] + _get_common_prefix(suffix_words) else: - return '' + return "" def gen_trie(words, writer, callback): @@ -327,7 +327,7 @@ def gen_trie(words, writer, callback): """ words = sorted(words) - _gen_trie('', words, writer, callback) + _gen_trie("", words, writer, callback) def _gen_trie(prefix, words, writer, callback): @@ -350,12 +350,14 @@ def _gen_trie(prefix, words, writer, callback): suffix = words[0] suffix_len = len(suffix) - predicate = f'fieldName.size() == {len(word_to_check)} && ' \ + predicate = ( + f"fieldName.size() == {len(word_to_check)} && " + f'std::char_traits::compare(fieldName.rawData() + {prefix_len}, "{suffix}", {suffix_len}) == 0' + ) # If there is no trailing text, we just need to check length to validate we matched if suffix_len == 0: - predicate = f'fieldName.size() == {len(word_to_check)}' + predicate = f"fieldName.size() == {len(word_to_check)}" # Optimization: # Checking strings of length 1 or even length is efficient. Strings of 3 byte length are @@ -363,10 +365,12 @@ def _gen_trie(prefix, words, writer, callback): # length strings require just 1. Since we know the field name is zero terminated, we can # just use memcmp and compare with the trailing null byte. elif suffix_len % 4 == 3: - predicate = f'fieldName.size() == {len(word_to_check)} && ' \ + predicate = ( + f"fieldName.size() == {len(word_to_check)} && " + f' memcmp(fieldName.rawData() + {prefix_len}, "{suffix}\\0", {suffix_len + 1}) == 0' + ) - with IndentedScopedBlock(writer, f'if ({predicate}) {{', '}'): + with IndentedScopedBlock(writer, f"if ({predicate}) {{", "}"): callback(word_to_check) return @@ -377,7 +381,9 @@ def _gen_trie(prefix, words, writer, callback): empty_words = [lw for lw in words if len(lw) == 0] if empty_words: word_to_check = prefix - with IndentedScopedBlock(writer, f'if (fieldName.size() == {len(word_to_check)}) {{', "}"): + with IndentedScopedBlock( + writer, f"if (fieldName.size() == {len(word_to_check)}) {{", "}" + ): callback(word_to_check) # Filter out empty words @@ -393,10 +399,11 @@ def _gen_trie(prefix, words, writer, callback): suffix_words = [flw[gcp_len:] for flw in words] with IndentedScopedBlock( - writer, - f'if (fieldName.size() >= {gcp_len} && '\ - + f'std::char_traits::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{', - "}"): + writer, + f"if (fieldName.size() >= {gcp_len} && " + + f'std::char_traits::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{', + "}", + ): _gen_trie(prefix + gcp, suffix_words, writer, callback) return @@ -408,16 +415,16 @@ def _gen_trie(prefix, words, writer, callback): first_letters = sorted(list({w[0] for w in sorted_words})) min_len = len(prefix) + min([len(w) for w in sorted_words]) - with IndentedScopedBlock(writer, f'if (fieldName.size() >= {min_len}) {{', "}"): + with IndentedScopedBlock(writer, f"if (fieldName.size() >= {min_len}) {{", "}"): first_if = True for first_letter in first_letters: - fl_words = [flw[1:] for flw in words if flw[0] == first_letter] - ei = "else " if not first_if else '' + ei = "else " if not first_if else "" with IndentedScopedBlock( - writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}"): + writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}" + ): _gen_trie(prefix + first_letter, fl_words, writer, callback) first_if = False @@ -427,6 +434,8 @@ def gen_string_table_find_function_block(out, in_str, on_match, on_fail, words): # type: (IndentedTextWriter, str, str, str, list[str]) -> None """Wrap a gen_trie generated block as a function.""" index = {word: i for i, word in enumerate(words)} - out.write_line(f'StringData fieldName{{{in_str}}};') - gen_trie(words, out, lambda w: out.write_line(f'return {on_match.format(index[w])};')) - out.write_line(f'return {on_fail};') + out.write_line(f"StringData fieldName{{{in_str}}};") + gen_trie( + words, out, lambda w: out.write_line(f"return {on_match.format(index[w])};") + ) + out.write_line(f"return {on_fail};") diff --git a/buildscripts/idl/idl_check_compatibility.py b/buildscripts/idl/idl_check_compatibility.py index d7e1dff541a..870d5f3f817 100644 --- a/buildscripts/idl/idl_check_compatibility.py +++ b/buildscripts/idl/idl_check_compatibility.py @@ -94,13 +94,12 @@ ALLOW_ANY_TYPE_LIST: List[str] = [ "commandParameterDeserializerNotEqualUnstable-param-deserializerNotEqualParam", "replyFieldDeserializerNotEqualUnstable-reply-deserializerNotEqualReplyUnstableField", "commandDeserializerNotEqualUnstable", - 'create-param-backwards', - 'saslStart-param-payload', - 'saslStart-param-payload', - 'saslStart-reply-payload', - 'saslContinue-param-payload', - 'saslContinue-reply-payload', - + "create-param-backwards", + "saslStart-param-payload", + "saslStart-param-payload", + "saslStart-reply-payload", + "saslContinue-param-payload", + "saslContinue-reply-payload", # These commands (aggregate, find, update, delete, findAndModify, explain) might contain some # fields with type `any`. Currently, it's not possible to avoid the `any` type in those cases. # Instead, here are the preventive measures in-place to catch unintentional breaking changes: @@ -112,79 +111,78 @@ ALLOW_ANY_TYPE_LIST: List[str] = [ # 3- Added further checks to the current script (idl_check_compatibility.py) to check for # changing a custom serializer/deserializer and considering it as a potential breaking # change. - 'aggregate-param-pipeline', - 'aggregate-param-explain', - 'aggregate-param-allowDiskUse', - 'aggregate-param-cursor', - 'aggregate-param-hint', - 'aggregate-param-comment', - 'aggregate-param-allowedIndexes', - 'aggregate-param-indexHints', - 'aggregate-param-needsMerge', - 'aggregate-param-fromMongos', - 'aggregate-param-$_requestReshardingResumeToken', - 'aggregate-param-isMapReduceCommand', - 'bulkWrite-param-shardVersion', - 'bulkWrite-reply-_id', - 'bulkWrite-reply-value', - 'count-param-hint', - 'count-param-limit', - 'count-param-maxTimeMS', - 'find-param-filter', - 'find-param-projection', - 'find-param-sort', - 'find-param-hint', - 'find-param-comment', - 'find-param-allowedIndexes', - 'find-param-indexHints', - 'find-param-collation', - 'find-param-singleBatch', - 'find-param-allowDiskUse', - 'find-param-min', - 'find-param-max', - 'find-param-returnKey', - 'find-param-showRecordId', - 'find-param-$queryOptions', - 'find-param-tailable', - 'find-param-oplogReplay', - 'find-param-noCursorTimeout', - 'find-param-awaitData', - 'find-param-allowPartialResults', - 'find-param-readOnce', - 'find-param-allowSpeculativeMajorityRead', - 'find-param-$_requestResumeToken', - 'find-param-$_resumeAfter', + "aggregate-param-pipeline", + "aggregate-param-explain", + "aggregate-param-allowDiskUse", + "aggregate-param-cursor", + "aggregate-param-hint", + "aggregate-param-comment", + "aggregate-param-allowedIndexes", + "aggregate-param-indexHints", + "aggregate-param-needsMerge", + "aggregate-param-fromMongos", + "aggregate-param-$_requestReshardingResumeToken", + "aggregate-param-isMapReduceCommand", + "bulkWrite-param-shardVersion", + "bulkWrite-reply-_id", + "bulkWrite-reply-value", + "count-param-hint", + "count-param-limit", + "count-param-maxTimeMS", + "find-param-filter", + "find-param-projection", + "find-param-sort", + "find-param-hint", + "find-param-comment", + "find-param-allowedIndexes", + "find-param-indexHints", + "find-param-collation", + "find-param-singleBatch", + "find-param-allowDiskUse", + "find-param-min", + "find-param-max", + "find-param-returnKey", + "find-param-showRecordId", + "find-param-$queryOptions", + "find-param-tailable", + "find-param-oplogReplay", + "find-param-noCursorTimeout", + "find-param-awaitData", + "find-param-allowPartialResults", + "find-param-readOnce", + "find-param-allowSpeculativeMajorityRead", + "find-param-$_requestResumeToken", + "find-param-$_resumeAfter", "find-param-$_startAt", - 'find-param-maxTimeMS', - 'update-param-u', - 'update-param-hint', - 'update-param-upsertSupplied', - 'update-reply-_id', - 'delete-param-limit', - 'delete-param-hint', - 'findAndModify-param-hint', - 'findAndModify-param-update', - 'findAndModify-reply-upserted', - 'insert-reply-opTime', - 'update-reply-opTime', - 'delete-reply-opTime', - 'aggregate-reply-partialResultsReturned', - 'aggregate-reply-invalidated', - 'find-reply-partialResultsReturned', - 'find-reply-invalidated', - 'getMore-reply-partialResultsReturned', - 'getMore-reply-invalidated', - 'listDatabasesForAllTenants-reply-tenant', - 'create-param-min', - 'create-param-max', - 'bulkWrite-param-updateMods', - 'bulkWrite-param-hint', - + "find-param-maxTimeMS", + "update-param-u", + "update-param-hint", + "update-param-upsertSupplied", + "update-reply-_id", + "delete-param-limit", + "delete-param-hint", + "findAndModify-param-hint", + "findAndModify-param-update", + "findAndModify-reply-upserted", + "insert-reply-opTime", + "update-reply-opTime", + "delete-reply-opTime", + "aggregate-reply-partialResultsReturned", + "aggregate-reply-invalidated", + "find-reply-partialResultsReturned", + "find-reply-invalidated", + "getMore-reply-partialResultsReturned", + "getMore-reply-invalidated", + "listDatabasesForAllTenants-reply-tenant", + "create-param-min", + "create-param-max", + "bulkWrite-param-updateMods", + "bulkWrite-param-hint", # No actual user-facing difference - 'bulkWrite-reply-opTime', - 'getMore-param-lastKnownCommittedOpTime', - 'hello-reply-opTime', - 'hello-reply-majorityOpTime', + "bulkWrite-reply-opTime", + "getMore-param-lastKnownCommittedOpTime", + "hello-reply-opTime", + "hello-reply-majorityOpTime", ] # Permit a parameter to move from bson serialisation type any @@ -193,26 +191,26 @@ IGNORE_ANY_TO_NON_ANY_LIST: List[str] = [ # These parameters were type-checked "by hand" previously; # enforcing this from IDL instead does not narrow the range # of permitted values - 'find-param-maxTimeMS', - 'count-param-maxTimeMS', + "find-param-maxTimeMS", + "count-param-maxTimeMS", ] # Permit a parameter to move from a non-any bson serialisation type to any. IGNORE_NON_ANY_TO_ANY_LIST: List[str] = [ - 'aggregate-param-indexHints', - 'bulkWrite-reply-opTime', - 'find-param-indexHints', - 'getMore-param-lastKnownCommittedOpTime', - 'hello-reply-opTime', - 'hello-reply-majorityOpTime', + "aggregate-param-indexHints", + "bulkWrite-reply-opTime", + "find-param-indexHints", + "getMore-param-lastKnownCommittedOpTime", + "hello-reply-opTime", + "hello-reply-majorityOpTime", ] # Permit the cpp type of a parameter to change ALLOW_CPP_TYPE_CHANGE_LIST: List[str] = [ # maxTimeMS has been widened for consistency with # equivalent params for other commands (aggregate) - 'find-param-maxTimeMS-std::int32_t-std::int64_t', - 'count-param-maxTimeMS-std::int32_t-std::int64_t', + "find-param-maxTimeMS-std::int32_t-std::int64_t", + "count-param-maxTimeMS-std::int32_t-std::int64_t", ] # Do not add user visible fields already released in earlier versions. @@ -222,32 +220,30 @@ ALLOW_CPP_TYPE_CHANGE_LIST: List[str] = [ # team. IGNORE_STABLE_TO_UNSTABLE_LIST: List[str] = [ # This list is only used in unit-tests. - 'newReplyFieldUnstableIgnoreList-reply-unstableNewFieldIgnoreList', - 'newTypeFieldUnstableIgnoreList-param-unstableNewFieldIgnoreList', - 'newTypeEnumOrStructIgnoreList-reply-unstableNewFieldIgnoreList', - 'commandParameterUnstableIgnoreList-param-newUnstableParameterIgnoreList', - 'newReplyFieldUnstableOptionalIgnoreList-reply-unstableOptionalNewFieldIgnoreList', - 'newReplyTypeEnumOrStructIgnoreList-reply-newReplyTypeEnumOrStructIgnoreList', - 'newReplyFieldVariantNotSubsetIgnoreList-reply-variantNotSubsetReplyFieldIgnoreList', - 'replyFieldVariantDifferentStructIgnoreList-reply-variantStructRecursiveReplyFieldIgnoreList', - 'replyFieldNonVariantToVariantIgnoreList-reply-nonVariantToVariantReplyFieldIgnoreList', - 'replyFieldNonEnumToEnumIgnoreList-reply-nonEnumToEnumReplyIgnoreList', - 'newUnstableParamTypeChangesIgnoreList-param-newUnstableTypeChangesParamIgnoreList', - 'newUnstableTypeChangesIgnoreList', - 'newUnstableTypeChangesIgnoreList-param-newUnstableTypeChangesFieldIgnoreList', - 'newUnstableReplyFieldTypeChangesIgnoreList-reply-newUnstableTypeChangesReplyFieldIgnoreList', - 'newReplyFieldTypeStructIgnoreList-reply-structReplyField', - 'newReplyFieldTypeStructIgnoreList-reply-unstableNewFieldIgnoreList', - + "newReplyFieldUnstableIgnoreList-reply-unstableNewFieldIgnoreList", + "newTypeFieldUnstableIgnoreList-param-unstableNewFieldIgnoreList", + "newTypeEnumOrStructIgnoreList-reply-unstableNewFieldIgnoreList", + "commandParameterUnstableIgnoreList-param-newUnstableParameterIgnoreList", + "newReplyFieldUnstableOptionalIgnoreList-reply-unstableOptionalNewFieldIgnoreList", + "newReplyTypeEnumOrStructIgnoreList-reply-newReplyTypeEnumOrStructIgnoreList", + "newReplyFieldVariantNotSubsetIgnoreList-reply-variantNotSubsetReplyFieldIgnoreList", + "replyFieldVariantDifferentStructIgnoreList-reply-variantStructRecursiveReplyFieldIgnoreList", + "replyFieldNonVariantToVariantIgnoreList-reply-nonVariantToVariantReplyFieldIgnoreList", + "replyFieldNonEnumToEnumIgnoreList-reply-nonEnumToEnumReplyIgnoreList", + "newUnstableParamTypeChangesIgnoreList-param-newUnstableTypeChangesParamIgnoreList", + "newUnstableTypeChangesIgnoreList", + "newUnstableTypeChangesIgnoreList-param-newUnstableTypeChangesFieldIgnoreList", + "newUnstableReplyFieldTypeChangesIgnoreList-reply-newUnstableTypeChangesReplyFieldIgnoreList", + "newReplyFieldTypeStructIgnoreList-reply-structReplyField", + "newReplyFieldTypeStructIgnoreList-reply-unstableNewFieldIgnoreList", # Real use cases for changing a field from 'stable' to 'unstable'. - 'find-param-maxTimeMS', - 'count-param-maxTimeMS', - + "find-param-maxTimeMS", + "count-param-maxTimeMS", # No actual user-facing difference - 'bulkWrite-reply-opTime', - 'hello-reply-opTime', - 'hello-reply-majorityOpTime', - 'find-param-$_startAt', + "bulkWrite-reply-opTime", + "hello-reply-opTime", + "hello-reply-majorityOpTime", + "find-param-$_startAt", ] # Once a field is part of the stable API, either by direct addition or by changing it from unstable @@ -259,125 +255,128 @@ IGNORE_STABLE_TO_UNSTABLE_LIST: List[str] = [ # team. ALLOWED_STABLE_FIELDS_LIST: List[str] = [ # This list is only used in unit-tests. These cases modify fields from unstable to stable. - 'oldReplyFieldTypeBsonAnyUnstable-reply-oldBsonSerializationTypeAnyUnstableReplyField', - 'newReplyFieldTypeBsonAnyUnstable-reply-newBsonSerializationTypeAnyUnstableReplyField', - 'replyFieldTypeBsonAnyNotAllowedUnstable-reply-bsonSerializationTypeAnyUnstableReplyField', - 'replyFieldCppTypeNotEqualUnstable-reply-cppTypeNotEqualReplyUnstableField', - 'newReplyFieldStable-reply-stableNewField', - 'importedReplyCommand-reply-stableNewField', - 'newReplyFieldTypeStructRecursiveOne-reply-stableNewField', - 'commandParameterStableRequiredNoDefault-param-newRequiredStableParam', - 'oldCommandParamTypeBsonAnyUnstable-param-bsonTypeAnyUnstableParam', - 'newCommandParamTypeBsonAnyUnstable-param-bsonTypeAnyUnstableParam', - 'commandParamTypeBsonAnyNotAllowedUnstable-param-bsonTypeAnyUnstableParam', - 'commandParameterCppTypeNotEqualUnstable-param-cppTypeNotEqualParam', - 'oldTypeBsonAnyUnstable-param-oldBsonSerializationTypeAnyUnstableStructField', - 'newTypeBsonAnyUnstable-param-newBsonSerializationTypeAnyUnstableStructField', - 'typeBsonAnyNotAllowedUnstable-param-bsonSerializationTypeAnyUnstableStructField', - 'commandCppTypeNotEqualUnstable-param-cppTypeNotEqualStructUnstableField', - 'newlyAddedTypeFieldBsonAnyNotAllowed-param-newlyAddedBsonSerializationTypeAnyStructField', - 'typeWithIncompatibleChainedStruct-param-newBsonSerializationTypeAnyUnstableStructField', - 'addedCommandParameterDefault-param-newStableParameter', - 'addedCommandParameterStable-param-newOptionalStableParam', - 'addedCommandParameterStableRequired-param-newStableParam', - 'addedCommandParameterStableWithDefault-param-newStableParamWithDefault', - 'newCommandParameterTypeStructRecursiveOne-param-unstableToStableOptionalField', - 'oldUnstableParamTypeChanges-param-oldUnstableTypeChangesParam', - 'oldUnstableTypeChanges-param-oldUnstableTypeChangesField', - 'newTypeFieldStableOptional-param-stableOptionalTypeField', - 'newTypeFieldStableWithDefault-param-stableWithDefaultTypeField', - + "oldReplyFieldTypeBsonAnyUnstable-reply-oldBsonSerializationTypeAnyUnstableReplyField", + "newReplyFieldTypeBsonAnyUnstable-reply-newBsonSerializationTypeAnyUnstableReplyField", + "replyFieldTypeBsonAnyNotAllowedUnstable-reply-bsonSerializationTypeAnyUnstableReplyField", + "replyFieldCppTypeNotEqualUnstable-reply-cppTypeNotEqualReplyUnstableField", + "newReplyFieldStable-reply-stableNewField", + "importedReplyCommand-reply-stableNewField", + "newReplyFieldTypeStructRecursiveOne-reply-stableNewField", + "commandParameterStableRequiredNoDefault-param-newRequiredStableParam", + "oldCommandParamTypeBsonAnyUnstable-param-bsonTypeAnyUnstableParam", + "newCommandParamTypeBsonAnyUnstable-param-bsonTypeAnyUnstableParam", + "commandParamTypeBsonAnyNotAllowedUnstable-param-bsonTypeAnyUnstableParam", + "commandParameterCppTypeNotEqualUnstable-param-cppTypeNotEqualParam", + "oldTypeBsonAnyUnstable-param-oldBsonSerializationTypeAnyUnstableStructField", + "newTypeBsonAnyUnstable-param-newBsonSerializationTypeAnyUnstableStructField", + "typeBsonAnyNotAllowedUnstable-param-bsonSerializationTypeAnyUnstableStructField", + "commandCppTypeNotEqualUnstable-param-cppTypeNotEqualStructUnstableField", + "newlyAddedTypeFieldBsonAnyNotAllowed-param-newlyAddedBsonSerializationTypeAnyStructField", + "typeWithIncompatibleChainedStruct-param-newBsonSerializationTypeAnyUnstableStructField", + "addedCommandParameterDefault-param-newStableParameter", + "addedCommandParameterStable-param-newOptionalStableParam", + "addedCommandParameterStableRequired-param-newStableParam", + "addedCommandParameterStableWithDefault-param-newStableParamWithDefault", + "newCommandParameterTypeStructRecursiveOne-param-unstableToStableOptionalField", + "oldUnstableParamTypeChanges-param-oldUnstableTypeChangesParam", + "oldUnstableTypeChanges-param-oldUnstableTypeChangesField", + "newTypeFieldStableOptional-param-stableOptionalTypeField", + "newTypeFieldStableWithDefault-param-stableWithDefaultTypeField", # This list is only used in unit-tests. These cases add new fields as stable. - 'newlyAddedReplyFieldTypeBsonAnyNotAllowed-reply-newlyAddedBsonSerializationTypeAnyReplyField', - 'newReplyFieldAdded-reply-addedNewField', - 'replyFieldVariantDifferentStructIgnoreList-reply-fieldOne', - 'replyFieldNonEnumToEnumIgnoreList-reply-replyField', - 'newlyAddedReplyFieldTypeBsonAnyAllowed-reply-newlyAddedBsonSerializationTypeAnyReplyField', - 'newReplyOptionalBool-reply-ok2', - 'commandWithNewArrayTypeParameterAndArrayTypeReply-reply-newArrayTypeField', - 'commandWithNewNestedArrayTypeParameterAndNestedArrayTypeReply-reply-newStructWithArrayTypeField', - 'addedNewReplyFieldMissingUnstableField-reply-missingUnstableFieldAddedNewField', - 'newlyAddedParamBsonAnyNotAllowed-param-newlyAddedBsonAnyNotAllowedParam', - 'addedNewCommandParameterRequired-param-newRequiredParam', - 'newTypeFieldAddedRequired-param-addedRequiredTypeField', - 'arrayCommandParameterTypeError-param-fieldOne', - 'addedNewParameterMissingUnstableField-param-missingUnstableFieldAddedNewParameter', - 'addedNewCommandTypeFieldMissingUnstableField-param-missingUnstableFieldAddedNewField', - 'addedCommandParameter-param-newParameter', - 'newlyAddedParamBsonAnyAllowList-param-newlyAddedBsonAnyAllowListParam', - 'newlyAddedTypeFieldBsonAnyAllowList-param-newlyAddedBsonSerializationTypeAnyStructField', - 'newTypeFieldAddedOptional-param-addedOptionalTypeField', - 'newParameterOptionalBool-param-flag', - 'newCommandTypeOptionalBool-param-ok2', - 'commandWithNewArrayTypeParameterAndArrayTypeReply-param-newArrayTypeParameter', - 'commandWithNewNestedArrayTypeParameterAndNestedArrayTypeReply-param-newNestedArrayTypeParameter', - + "newlyAddedReplyFieldTypeBsonAnyNotAllowed-reply-newlyAddedBsonSerializationTypeAnyReplyField", + "newReplyFieldAdded-reply-addedNewField", + "replyFieldVariantDifferentStructIgnoreList-reply-fieldOne", + "replyFieldNonEnumToEnumIgnoreList-reply-replyField", + "newlyAddedReplyFieldTypeBsonAnyAllowed-reply-newlyAddedBsonSerializationTypeAnyReplyField", + "newReplyOptionalBool-reply-ok2", + "commandWithNewArrayTypeParameterAndArrayTypeReply-reply-newArrayTypeField", + "commandWithNewNestedArrayTypeParameterAndNestedArrayTypeReply-reply-newStructWithArrayTypeField", + "addedNewReplyFieldMissingUnstableField-reply-missingUnstableFieldAddedNewField", + "newlyAddedParamBsonAnyNotAllowed-param-newlyAddedBsonAnyNotAllowedParam", + "addedNewCommandParameterRequired-param-newRequiredParam", + "newTypeFieldAddedRequired-param-addedRequiredTypeField", + "arrayCommandParameterTypeError-param-fieldOne", + "addedNewParameterMissingUnstableField-param-missingUnstableFieldAddedNewParameter", + "addedNewCommandTypeFieldMissingUnstableField-param-missingUnstableFieldAddedNewField", + "addedCommandParameter-param-newParameter", + "newlyAddedParamBsonAnyAllowList-param-newlyAddedBsonAnyAllowListParam", + "newlyAddedTypeFieldBsonAnyAllowList-param-newlyAddedBsonSerializationTypeAnyStructField", + "newTypeFieldAddedOptional-param-addedOptionalTypeField", + "newParameterOptionalBool-param-flag", + "newCommandTypeOptionalBool-param-ok2", + "commandWithNewArrayTypeParameterAndArrayTypeReply-param-newArrayTypeParameter", + "commandWithNewNestedArrayTypeParameterAndNestedArrayTypeReply-param-newNestedArrayTypeParameter", # Add real use cases for allowed new stable or unstable-to-stable fields after this line. # Changes relative to 5.0: - 'collMod-param-isTimeseriesNamespace', - 'collMod-param-cappedSize', - 'collMod-param-cappedMax', - 'createIndexes-param-isTimeseriesNamespace', - 'dropIndexes-param-isTimeseriesNamespace', - 'listIndexes-param-isTimeseriesNamespace', - 'listIndexes-reply-clustered', - 'create-param-encryptedFields', - 'create-param-bucketRoundingSeconds', - 'create-param-temp', - 'endSessions-param-txnNumber', - 'endSessions-param-txnUUID', - 'refreshSessions-param-txnNumber', - 'refreshSessions-param-txnUUID', - 'insert-param-isTimeseriesNamespace', - 'update-param-isTimeseriesNamespace', - 'delete-param-isTimeseriesNamespace', - 'findAndModify-param-stmtId', - 'hello-param-loadBalanced', - 'hello-reply-serviceId', - 'hello-reply-isImplicitDefaultMajorityWC', - 'hello-reply-cwwc', - + "collMod-param-isTimeseriesNamespace", + "collMod-param-cappedSize", + "collMod-param-cappedMax", + "createIndexes-param-isTimeseriesNamespace", + "dropIndexes-param-isTimeseriesNamespace", + "listIndexes-param-isTimeseriesNamespace", + "listIndexes-reply-clustered", + "create-param-encryptedFields", + "create-param-bucketRoundingSeconds", + "create-param-temp", + "endSessions-param-txnNumber", + "endSessions-param-txnUUID", + "refreshSessions-param-txnNumber", + "refreshSessions-param-txnUUID", + "insert-param-isTimeseriesNamespace", + "update-param-isTimeseriesNamespace", + "delete-param-isTimeseriesNamespace", + "findAndModify-param-stmtId", + "hello-param-loadBalanced", + "hello-reply-serviceId", + "hello-reply-isImplicitDefaultMajorityWC", + "hello-reply-cwwc", # BulkWrite fields - 'bulkWrite-param-ops', - 'bulkWrite-param-insert', - 'bulkWrite-param-document', - 'bulkWrite-param-update', - 'bulkWrite-param-filter', - 'bulkWrite-param-multi', - 'bulkWrite-param-updateMods', - 'bulkWrite-param-upsert', - 'bulkWrite-param-arrayFilters', - 'bulkWrite-param-hint', - 'bulkWrite-param-collation', - 'bulkWrite-param-delete', - 'bulkWrite-param-collation', - 'bulkWrite-param-nsInfo', - 'bulkWrite-param-ns', - 'bulkWrite-param-cursor', - 'bulkWrite-param-bypassDocumentValidation', - 'bulkWrite-param-constants', - 'bulkWrite-param-ordered', - 'bulkWrite-param-stmtId', - 'bulkWrite-param-stmtIds', - 'bulkWrite-param-let', - 'bulkWrite-param-errorsOnly', - 'bulkWrite-reply-cursor', - 'bulkWrite-reply-id', - 'bulkWrite-reply-firstBatch', - 'bulkWrite-reply-ns', - 'bulkWrite-reply-electionId', - 'bulkWrite-reply-opTime', - 'bulkWrite-reply-nErrors', - 'bulkWrite-reply-nInserted', - 'bulkWrite-reply-nMatched', - 'bulkWrite-reply-nModified', - 'bulkWrite-reply-nUpserted', - 'bulkWrite-reply-nDeleted', + "bulkWrite-param-ops", + "bulkWrite-param-insert", + "bulkWrite-param-document", + "bulkWrite-param-update", + "bulkWrite-param-filter", + "bulkWrite-param-multi", + "bulkWrite-param-updateMods", + "bulkWrite-param-upsert", + "bulkWrite-param-arrayFilters", + "bulkWrite-param-hint", + "bulkWrite-param-collation", + "bulkWrite-param-delete", + "bulkWrite-param-collation", + "bulkWrite-param-nsInfo", + "bulkWrite-param-ns", + "bulkWrite-param-cursor", + "bulkWrite-param-bypassDocumentValidation", + "bulkWrite-param-constants", + "bulkWrite-param-ordered", + "bulkWrite-param-stmtId", + "bulkWrite-param-stmtIds", + "bulkWrite-param-let", + "bulkWrite-param-errorsOnly", + "bulkWrite-reply-cursor", + "bulkWrite-reply-id", + "bulkWrite-reply-firstBatch", + "bulkWrite-reply-ns", + "bulkWrite-reply-electionId", + "bulkWrite-reply-opTime", + "bulkWrite-reply-nErrors", + "bulkWrite-reply-nInserted", + "bulkWrite-reply-nMatched", + "bulkWrite-reply-nModified", + "bulkWrite-reply-nUpserted", + "bulkWrite-reply-nDeleted", ] SKIPPED_FILES = [ - "unittest.idl", "mozILocalization.idl", "mozILocaleService.idl", "mozIOSPreferences.idl", - "nsICollation.idl", "nsIStringBundle.idl", "nsIScriptableUConv.idl", "nsITextToSubURI.idl" + "unittest.idl", + "mozILocalization.idl", + "mozILocaleService.idl", + "mozIOSPreferences.idl", + "nsICollation.idl", + "nsIStringBundle.idl", + "nsIScriptableUConv.idl", + "nsITextToSubURI.idl", ] # Do not add commands that were visible to users in previously released versions. @@ -385,15 +384,15 @@ IGNORE_COMMANDS_LIST: List[str] = [ # The following commands were released behind a feature flag in 5.3 but were shelved in # favor of getClusterParameter and setClusterParameter. Since the feature flag was not enabled # in 5.3, they were effectively unusable and so can be safely removed from the strict API. - 'getChangeStreamOptions', - 'setChangeStreamOptions', + "getChangeStreamOptions", + "setChangeStreamOptions", ] RENAMED_COMPLEX_ACCESS_CHECKS = dict( # Changed during 6.1 as part of removing multi-auth support. - get_single_user='get_authenticated_user', - get_authenticated_usernames='get_authenticated_username', - get_impersonated_usernames='get_impersonated_username', + get_single_user="get_authenticated_user", + get_authenticated_usernames="get_authenticated_username", + get_impersonated_usernames="get_impersonated_username", ) ALLOWED_NEW_COMPLEX_ACCESS_CHECKS = dict( @@ -403,14 +402,15 @@ ALLOWED_NEW_COMPLEX_ACCESS_CHECKS = dict( # Added in 6.3 due to the new $_analyzeShardKeyReadWriteDistribution stage. "check_cursor_session_privilege" }, - # This list is only used in unit-tests. - complexChecksSupersetAllowed={'checkTwo', 'checkThree'}, - complexChecksSupersetSomeAllowed={'checkTwo'}) + complexChecksSupersetAllowed={"checkTwo", "checkThree"}, + complexChecksSupersetSomeAllowed={"checkTwo"}, +) CHANGED_ACCESS_CHECKS_TYPE = dict( # Changed access checks of update command from 'simple' to 'complex' in 8.1. - update=["simple", "complex"]) + update=["simple", "complex"] +) @dataclass @@ -423,7 +423,9 @@ class AllowedNewPrivilege: @classmethod def create_from(cls, privilege: syntax.Privilege): - return cls(privilege.resource_pattern, privilege.action_type, privilege.agg_stage) + return cls( + privilege.resource_pattern, privilege.action_type, privilege.agg_stage + ) ALLOWED_NEW_ACCESS_CHECK_PRIVILEGES = dict( @@ -435,7 +437,6 @@ ALLOWED_NEW_ACCESS_CHECK_PRIVILEGES = dict( # in the latest patch release. It is guarded by a feature flag so we are allowing this conflict here. AllowedNewPrivilege("cluster", ["queryStatsReadTransformed"], "queryStats"), ], - # This list is only used in unit-tests. complexChecksSupersetAllowed=[ AllowedNewPrivilege("resourcePatternTwo", ["actionTypeTwo"]), @@ -443,15 +444,21 @@ ALLOWED_NEW_ACCESS_CHECK_PRIVILEGES = dict( ], complexCheckPrivilegesSupersetSomeAllowed=[ AllowedNewPrivilege("resourcePatternTwo", ["actionTypeTwo"]) - ]) + ], +) class FieldCompatibility: """Information about a Field to check compatibility.""" - def __init__(self, field_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], - idl_file: syntax.IDLParsedSpec, idl_file_path: str, stability: Optional[str], - optional: bool) -> None: + def __init__( + self, + field_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], + idl_file: syntax.IDLParsedSpec, + idl_file_path: str, + stability: Optional[str], + optional: bool, + ) -> None: """Initialize data members and hand special cases, such as optionalBool type.""" self.field_type = field_type self.idl_file = idl_file @@ -459,11 +466,16 @@ class FieldCompatibility: self.stability = stability self.optional = optional - if isinstance(self.field_type, syntax.Type) and self.field_type.name == "optionalBool": + if ( + isinstance(self.field_type, syntax.Type) + and self.field_type.name == "optionalBool" + ): # special case for optionalBool type, because it is compatible # with bool type, but has bson_serialization_type == 'any' # which is not supported by many checks - self.field_type = syntax.Type(field_type.file_name, field_type.line, field_type.column) + self.field_type = syntax.Type( + field_type.file_name, field_type.line, field_type.column + ) self.field_type.name = "bool" self.field_type.bson_serialization_type = ["bool"] self.optional = True @@ -489,7 +501,7 @@ class ArrayTypeCheckResult(Enum): def is_unstable(stability: Optional[str]) -> bool: """Check whether the given stability value is considered as unstable.""" - return stability is not None and stability != 'stable' + return stability is not None and stability != "stable" def is_stable(stability: Optional[str]) -> bool: @@ -498,7 +510,7 @@ def is_stable(stability: Optional[str]) -> bool: def get_new_commands( - ctxt: IDLCompatibilityContext, new_idl_dir: str, import_directories: List[str] + ctxt: IDLCompatibilityContext, new_idl_dir: str, import_directories: List[str] ) -> Tuple[Dict[str, syntax.Command], Dict[str, syntax.IDLParsedSpec], Dict[str, str]]: """Get new IDL commands and check validity.""" new_commands: Dict[str, syntax.Command] = dict() @@ -507,14 +519,17 @@ def get_new_commands( for dirpath, _, filenames in os.walk(new_idl_dir): for new_filename in filenames: - if not new_filename.endswith('.idl') or new_filename in SKIPPED_FILES: + if not new_filename.endswith(".idl") or new_filename in SKIPPED_FILES: continue new_idl_file_path = os.path.join(dirpath, new_filename) with open(new_idl_file_path) as new_file: new_idl_file = parser.parse( - new_file, new_idl_file_path, - CompilerImportResolver(import_directories + [new_idl_dir]), False) + new_file, + new_idl_file_path, + CompilerImportResolver(import_directories + [new_idl_dir]), + False, + ) if new_idl_file.errors: new_idl_file.errors.dump_errors() raise ValueError(f"Cannot parse {new_idl_file_path}") @@ -527,12 +542,14 @@ def get_new_commands( if new_cmd.api_version != "1": # We're not ready to handle future API versions yet. ctxt.add_command_invalid_api_version_error( - new_cmd.command_name, new_cmd.api_version, new_idl_file_path) + new_cmd.command_name, new_cmd.api_version, new_idl_file_path + ) continue if new_cmd.command_name in new_commands: - ctxt.add_duplicate_command_name_error(new_cmd.command_name, new_idl_dir, - new_idl_file_path) + ctxt.add_duplicate_command_name_error( + new_cmd.command_name, new_idl_dir, new_idl_file_path + ) continue new_commands[new_cmd.command_name] = new_cmd @@ -543,50 +560,74 @@ def get_new_commands( def get_chained_type_or_struct( - chained_type_or_struct: Union[syntax.ChainedType, syntax.ChainedStruct], - idl_file: syntax.IDLParsedSpec, - idl_file_path: str) -> Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]]: + chained_type_or_struct: Union[syntax.ChainedType, syntax.ChainedStruct], + idl_file: syntax.IDLParsedSpec, + idl_file_path: str, +) -> Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]]: """Resolve and get chained type or struct from the IDL file.""" parser_ctxt = errors.ParserContext(idl_file_path, errors.ParserErrorCollection()) - resolved = idl_file.spec.symbols.resolve_type_from_name(parser_ctxt, chained_type_or_struct, - chained_type_or_struct.name, - chained_type_or_struct.name) + resolved = idl_file.spec.symbols.resolve_type_from_name( + parser_ctxt, + chained_type_or_struct, + chained_type_or_struct.name, + chained_type_or_struct.name, + ) if parser_ctxt.errors.has_errors(): parser_ctxt.errors.dump_errors() return resolved -def get_field_type(field: Union[syntax.Field, syntax.Command], idl_file: syntax.IDLParsedSpec, - idl_file_path: str) -> Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]]: +def get_field_type( + field: Union[syntax.Field, syntax.Command], + idl_file: syntax.IDLParsedSpec, + idl_file_path: str, +) -> Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]]: """Resolve and get field type of a field from the IDL file.""" parser_ctxt = errors.ParserContext(idl_file_path, errors.ParserErrorCollection()) - field_type = idl_file.spec.symbols.resolve_field_type(parser_ctxt, field, field.name, - field.type) + field_type = idl_file.spec.symbols.resolve_field_type( + parser_ctxt, field, field.name, field.type + ) if parser_ctxt.errors.has_errors(): parser_ctxt.errors.dump_errors() return field_type -def check_subset(ctxt: IDLCompatibilityContext, cmd_name: str, field_name: str, type_name: str, - sub_list: List[Union[str, syntax.EnumValue]], - super_list: List[Union[str, syntax.EnumValue]], file_path: str): +def check_subset( + ctxt: IDLCompatibilityContext, + cmd_name: str, + field_name: str, + type_name: str, + sub_list: List[Union[str, syntax.EnumValue]], + super_list: List[Union[str, syntax.EnumValue]], + file_path: str, +): """Check if sub_list is a subset of the super_list and log an error if not.""" if not set(sub_list).issubset(super_list): - ctxt.add_reply_field_not_subset_error(cmd_name, field_name, type_name, file_path) + ctxt.add_reply_field_not_subset_error( + cmd_name, field_name, type_name, file_path + ) -def check_superset(ctxt: IDLCompatibilityContext, cmd_name: str, type_name: str, - super_list: List[Union[str, syntax.EnumValue]], - sub_list: List[Union[str, syntax.EnumValue]], file_path: str, - param_name: Optional[str], is_command_parameter: bool): +def check_superset( + ctxt: IDLCompatibilityContext, + cmd_name: str, + type_name: str, + super_list: List[Union[str, syntax.EnumValue]], + sub_list: List[Union[str, syntax.EnumValue]], + file_path: str, + param_name: Optional[str], + is_command_parameter: bool, +): """Check if super_list is a superset of the sub_list and log an error if not.""" if not set(super_list).issuperset(sub_list): - ctxt.add_command_or_param_type_not_superset_error(cmd_name, type_name, file_path, - param_name, is_command_parameter) + ctxt.add_command_or_param_type_not_superset_error( + cmd_name, type_name, file_path, param_name, is_command_parameter + ) -def check_reply_field_type_recursive(ctxt: IDLCompatibilityContext, - field_pair: FieldCompatibilityPair) -> None: +def check_reply_field_type_recursive( + ctxt: IDLCompatibilityContext, field_pair: FieldCompatibilityPair +) -> None: """Check compatibility between old and new reply field type if old field type is a syntax.Type instance.""" old_field = field_pair.old new_field = field_pair.new @@ -601,56 +642,89 @@ def check_reply_field_type_recursive(ctxt: IDLCompatibilityContext, # bson_serialization_type. For all other errors, we check that the old field is stable # before adding an error. if not isinstance(new_field_type, syntax.Type): - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_reply_field_type_enum_or_struct_error( - cmd_name, field_name, new_field_type.name, old_field_type.name, - new_field.idl_file_path) + cmd_name, + field_name, + new_field_type.name, + old_field_type.name, + new_field.idl_file_path, + ) return # If bson_serialization_type switches from 'any' to non-any type. - if "any" in old_field_type.bson_serialization_type and "any" not in new_field_type.bson_serialization_type: - ctxt.add_old_reply_field_bson_any_error(cmd_name, field_name, old_field_type.name, - new_field_type.name, old_field.idl_file_path) + if ( + "any" in old_field_type.bson_serialization_type + and "any" not in new_field_type.bson_serialization_type + ): + ctxt.add_old_reply_field_bson_any_error( + cmd_name, + field_name, + old_field_type.name, + new_field_type.name, + old_field.idl_file_path, + ) return # If bson_serialization_type switches from non-any to 'any' type. - if "any" not in old_field_type.bson_serialization_type and "any" in new_field_type.bson_serialization_type: + if ( + "any" not in old_field_type.bson_serialization_type + and "any" in new_field_type.bson_serialization_type + ): if ignore_list_name not in IGNORE_NON_ANY_TO_ANY_LIST: - ctxt.add_new_reply_field_bson_any_error(cmd_name, field_name, old_field_type.name, - new_field_type.name, new_field.idl_file_path) + ctxt.add_new_reply_field_bson_any_error( + cmd_name, + field_name, + old_field_type.name, + new_field_type.name, + new_field.idl_file_path, + ) return if "any" in old_field_type.bson_serialization_type: # If 'any' is not explicitly allowed as the bson_serialization_type. if ignore_list_name not in ALLOW_ANY_TYPE_LIST: ctxt.add_old_reply_field_bson_any_not_allowed_error( - cmd_name, field_name, old_field_type.name, old_field.idl_file_path) + cmd_name, field_name, old_field_type.name, old_field.idl_file_path + ) return # If cpp_type is changed, it's a potential breaking change. if old_field_type.cpp_type != new_field_type.cpp_type: - ctxt.add_reply_field_cpp_type_not_equal_error(cmd_name, field_name, new_field_type.name, - new_field.idl_file_path) + ctxt.add_reply_field_cpp_type_not_equal_error( + cmd_name, field_name, new_field_type.name, new_field.idl_file_path + ) # If serializer is changed, it's a potential breaking change. - if not is_unstable( - old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST and old_field_type.serializer != new_field_type.serializer: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + and old_field_type.serializer != new_field_type.serializer + ): ctxt.add_reply_field_serializer_not_equal_error( - cmd_name, field_name, new_field_type.name, new_field.idl_file_path) + cmd_name, field_name, new_field_type.name, new_field.idl_file_path + ) # If deserializer is changed, it's a potential breaking change. - if not is_unstable( - old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST and old_field_type.deserializer != new_field_type.deserializer: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + and old_field_type.deserializer != new_field_type.deserializer + ): ctxt.add_reply_field_deserializer_not_equal_error( - cmd_name, field_name, new_field_type.name, new_field.idl_file_path) + cmd_name, field_name, new_field_type.name, new_field.idl_file_path + ) if isinstance(old_field_type, syntax.VariantType): # If the new type is not variant just check the single type. - new_variant_types = new_field_type.variant_types if isinstance( - new_field_type, syntax.VariantType) else [new_field_type] + new_variant_types = ( + new_field_type.variant_types + if isinstance(new_field_type, syntax.VariantType) + else [new_field_type] + ) old_variant_types = old_field_type.variant_types # Check that new variant types are a subset of old variant types. @@ -658,78 +732,138 @@ def check_reply_field_type_recursive(ctxt: IDLCompatibilityContext, for old_variant_type in old_variant_types: if old_variant_type.name == new_variant_type.name: # Check that the old and new version of each variant type is also compatible. - old = FieldCompatibility(old_variant_type, old_field.idl_file, - old_field.idl_file_path, old_field.stability, - old_field.optional) - new = FieldCompatibility(new_variant_type, new_field.idl_file, - new_field.idl_file_path, new_field.stability, - new_field.optional) - check_reply_field_type(ctxt, - FieldCompatibilityPair(old, new, cmd_name, field_name)) + old = FieldCompatibility( + old_variant_type, + old_field.idl_file, + old_field.idl_file_path, + old_field.stability, + old_field.optional, + ) + new = FieldCompatibility( + new_variant_type, + new_field.idl_file, + new_field.idl_file_path, + new_field.stability, + new_field.optional, + ) + check_reply_field_type( + ctxt, FieldCompatibilityPair(old, new, cmd_name, field_name) + ) break else: # new_variant_type was not found in old_variant_types. - if not is_unstable(old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_reply_field_variant_type_not_subset_error( - cmd_name, field_name, new_variant_type.name, new_field.idl_file_path) + cmd_name, + field_name, + new_variant_type.name, + new_field.idl_file_path, + ) # If new type is variant and has a struct as a variant type, compare old and new variant_struct_types. # Since enums can't be part of variant types, we don't explicitly check for enums. - if isinstance(new_field_type, - syntax.VariantType) and new_field_type.variant_struct_types is not None: - if old_field_type.variant_struct_types is None and not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + isinstance(new_field_type, syntax.VariantType) + and new_field_type.variant_struct_types is not None + ): + if ( + old_field_type.variant_struct_types is None + and not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): for variant_type in new_field_type.variant_struct_types: ctxt.add_new_reply_field_variant_type_not_subset_error( - cmd_name, field_name, variant_type.name, new_field.idl_file_path) + cmd_name, field_name, variant_type.name, new_field.idl_file_path + ) return # If the length of both variant_struct_types is 1 then we want to check the struct fields # since an idl name change with the same field names is legal. We do not do this for # lengths > 1 because it would be too ambiguous to tell which pair of variant # types no longer comply with each other. - elif (len(old_field_type.variant_struct_types) == 1) and (len( - new_field_type.variant_struct_types) == 1): - check_reply_fields(ctxt, old_field_type.variant_struct_types[0], - new_field_type.variant_struct_types[0], cmd_name, - old_field.idl_file, new_field.idl_file, old_field.idl_file_path, - new_field.idl_file_path) + elif (len(old_field_type.variant_struct_types) == 1) and ( + len(new_field_type.variant_struct_types) == 1 + ): + check_reply_fields( + ctxt, + old_field_type.variant_struct_types[0], + new_field_type.variant_struct_types[0], + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + ) return for new_variant_type in new_field_type.variant_struct_types: for old_variant_type in old_field_type.variant_struct_types: if old_variant_type.name == new_variant_type.name: - check_reply_fields(ctxt, old_variant_type, new_variant_type, cmd_name, - old_field.idl_file, new_field.idl_file, - old_field.idl_file_path, new_field.idl_file_path) + check_reply_fields( + ctxt, + old_variant_type, + new_variant_type, + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + ) break else: - if not is_unstable(old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): # new_variant_type was not found in old_variant_struct_types ctxt.add_new_reply_field_variant_type_not_subset_error( - cmd_name, field_name, new_variant_type.name, new_field.idl_file_path) + cmd_name, + field_name, + new_variant_type.name, + new_field.idl_file_path, + ) - elif not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + elif ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): if isinstance(new_field_type, syntax.VariantType): - ctxt.add_new_reply_field_variant_type_error(cmd_name, field_name, old_field_type.name, - new_field.idl_file_path) + ctxt.add_new_reply_field_variant_type_error( + cmd_name, field_name, old_field_type.name, new_field.idl_file_path + ) else: - check_subset(ctxt, cmd_name, field_name, new_field_type.name, - new_field_type.bson_serialization_type, - old_field_type.bson_serialization_type, new_field.idl_file_path) + check_subset( + ctxt, + cmd_name, + field_name, + new_field_type.name, + new_field_type.bson_serialization_type, + old_field_type.bson_serialization_type, + new_field.idl_file_path, + ) -def check_reply_field_type(ctxt: IDLCompatibilityContext, field_pair: FieldCompatibilityPair): +def check_reply_field_type( + ctxt: IDLCompatibilityContext, field_pair: FieldCompatibilityPair +): """Check compatibility between old and new reply field type.""" old_field = field_pair.old new_field = field_pair.new cmd_name = field_pair.cmd_name field_name = field_pair.field_name - array_check = check_array_type(ctxt, "reply_field", old_field.field_type, new_field.field_type, - field_pair.cmd_name, 'type', old_field.idl_file_path, - new_field.idl_file_path, is_unstable(old_field.stability)) + array_check = check_array_type( + ctxt, + "reply_field", + old_field.field_type, + new_field.field_type, + field_pair.cmd_name, + "type", + old_field.idl_file_path, + new_field.idl_file_path, + is_unstable(old_field.stability), + ) if array_check == ArrayTypeCheckResult.INVALID: return @@ -742,11 +876,15 @@ def check_reply_field_type(ctxt: IDLCompatibilityContext, field_pair: FieldCompa cmd_name = field_pair.cmd_name field_name = field_pair.field_name if old_field_type is None: - ctxt.add_reply_field_type_invalid_error(cmd_name, field_name, old_field.idl_file_path) + ctxt.add_reply_field_type_invalid_error( + cmd_name, field_name, old_field.idl_file_path + ) ctxt.errors.dump_errors() sys.exit(1) if new_field_type is None: - ctxt.add_reply_field_type_invalid_error(cmd_name, field_name, new_field.idl_file_path) + ctxt.add_reply_field_type_invalid_error( + cmd_name, field_name, new_field.idl_file_path + ) ctxt.errors.dump_errors() sys.exit(1) @@ -755,32 +893,66 @@ def check_reply_field_type(ctxt: IDLCompatibilityContext, field_pair: FieldCompa if isinstance(old_field_type, syntax.Type): check_reply_field_type_recursive(ctxt, field_pair) - elif isinstance(old_field_type, syntax.Enum) and not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + elif ( + isinstance(old_field_type, syntax.Enum) + and not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): if isinstance(new_field_type, syntax.Enum): - check_subset(ctxt, cmd_name, field_name, new_field_type.name, new_field_type.values, - old_field_type.values, new_field.idl_file_path) + check_subset( + ctxt, + cmd_name, + field_name, + new_field_type.name, + new_field_type.values, + old_field_type.values, + new_field.idl_file_path, + ) else: - ctxt.add_new_reply_field_type_not_enum_error(cmd_name, field_name, new_field_type.name, - old_field_type.name, - new_field.idl_file_path) + ctxt.add_new_reply_field_type_not_enum_error( + cmd_name, + field_name, + new_field_type.name, + old_field_type.name, + new_field.idl_file_path, + ) elif isinstance(old_field_type, syntax.Struct): if isinstance(new_field_type, syntax.Struct): - check_reply_fields(ctxt, old_field_type, new_field_type, cmd_name, old_field.idl_file, - new_field.idl_file, old_field.idl_file_path, new_field.idl_file_path) + check_reply_fields( + ctxt, + old_field_type, + new_field_type, + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + ) else: - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_reply_field_type_not_struct_error( - cmd_name, field_name, new_field_type.name, old_field_type.name, - new_field.idl_file_path) + cmd_name, + field_name, + new_field_type.name, + old_field_type.name, + new_field.idl_file_path, + ) -def check_array_type(ctxt: IDLCompatibilityContext, symbol: str, - old_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], - new_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], - cmd_name: str, symbol_name: str, old_idl_file_path: str, - new_idl_file_path: str, old_field_unstable: bool) -> ArrayTypeCheckResult: +def check_array_type( + ctxt: IDLCompatibilityContext, + symbol: str, + old_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], + new_type: Optional[Union[syntax.Enum, syntax.Struct, syntax.Type]], + cmd_name: str, + symbol_name: str, + old_idl_file_path: str, + new_idl_file_path: str, + old_field_unstable: bool, +) -> ArrayTypeCheckResult: """ Check compatibility between old and new ArrayTypes. @@ -795,86 +967,150 @@ def check_array_type(ctxt: IDLCompatibilityContext, symbol: str, return ArrayTypeCheckResult.FALSE if (not old_is_array or not new_is_array) and not old_field_unstable: - ctxt.add_type_not_array_error(symbol, cmd_name, symbol_name, new_type.name, old_type.name, - new_idl_file_path if old_is_array else old_idl_file_path) + ctxt.add_type_not_array_error( + symbol, + cmd_name, + symbol_name, + new_type.name, + old_type.name, + new_idl_file_path if old_is_array else old_idl_file_path, + ) return ArrayTypeCheckResult.INVALID return ArrayTypeCheckResult.TRUE -def check_reply_field(ctxt: IDLCompatibilityContext, old_field: syntax.Field, - new_field: syntax.Field, cmd_name: str, old_idl_file: syntax.IDLParsedSpec, - new_idl_file: syntax.IDLParsedSpec, old_idl_file_path: str, - new_idl_file_path: str): +def check_reply_field( + ctxt: IDLCompatibilityContext, + old_field: syntax.Field, + new_field: syntax.Field, + cmd_name: str, + old_idl_file: syntax.IDLParsedSpec, + new_idl_file: syntax.IDLParsedSpec, + old_idl_file_path: str, + new_idl_file_path: str, +): """Check compatibility between old and new reply field.""" old_field_type = get_field_type(old_field, old_idl_file, old_idl_file_path) new_field_type = get_field_type(new_field, new_idl_file, new_idl_file_path) - old_field_optional = old_field.optional or (old_field_type - and old_field_type.name == "optionalBool") - new_field_optional = new_field.optional or (new_field_type - and new_field_type.name == "optionalBool") + old_field_optional = old_field.optional or ( + old_field_type and old_field_type.name == "optionalBool" + ) + new_field_optional = new_field.optional or ( + new_field_type and new_field_type.name == "optionalBool" + ) ignore_list_name: str = cmd_name + "-reply-" + new_field.name - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: - if is_unstable( - new_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: - ctxt.add_new_reply_field_unstable_error(cmd_name, new_field.name, new_idl_file_path) + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): + if ( + is_unstable(new_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): + ctxt.add_new_reply_field_unstable_error( + cmd_name, new_field.name, new_idl_file_path + ) if new_field_optional and not old_field_optional: - ctxt.add_new_reply_field_optional_error(cmd_name, new_field.name, new_idl_file_path) + ctxt.add_new_reply_field_optional_error( + cmd_name, new_field.name, new_idl_file_path + ) if new_field.validator: if old_field.validator: if new_field.validator != old_field.validator: - ctxt.add_reply_field_validators_not_equal_error(cmd_name, new_field.name, - new_idl_file_path) + ctxt.add_reply_field_validators_not_equal_error( + cmd_name, new_field.name, new_idl_file_path + ) else: - ctxt.add_reply_field_contains_validator_error(cmd_name, new_field.name, - new_idl_file_path) + ctxt.add_reply_field_contains_validator_error( + cmd_name, new_field.name, new_idl_file_path + ) # A reply field may not change from unstable to stable unless explicitly allowed to. - if is_unstable(old_field.stability) and not is_unstable( - new_field.stability) and ignore_list_name not in ALLOWED_STABLE_FIELDS_LIST: - ctxt.add_unstable_reply_field_changed_to_stable_error(cmd_name, new_field.name, - new_idl_file_path) + if ( + is_unstable(old_field.stability) + and not is_unstable(new_field.stability) + and ignore_list_name not in ALLOWED_STABLE_FIELDS_LIST + ): + ctxt.add_unstable_reply_field_changed_to_stable_error( + cmd_name, new_field.name, new_idl_file_path + ) - old_field_compatibility = FieldCompatibility(old_field_type, old_idl_file, old_idl_file_path, - old_field.stability, old_field.optional) - new_field_compatibility = FieldCompatibility(new_field_type, new_idl_file, new_idl_file_path, - new_field.stability, new_field.optional) - field_pair = FieldCompatibilityPair(old_field_compatibility, new_field_compatibility, cmd_name, - old_field.name) + old_field_compatibility = FieldCompatibility( + old_field_type, + old_idl_file, + old_idl_file_path, + old_field.stability, + old_field.optional, + ) + new_field_compatibility = FieldCompatibility( + new_field_type, + new_idl_file, + new_idl_file_path, + new_field.stability, + new_field.optional, + ) + field_pair = FieldCompatibilityPair( + old_field_compatibility, new_field_compatibility, cmd_name, old_field.name + ) check_reply_field_type(ctxt, field_pair) -def check_reply_fields(ctxt: IDLCompatibilityContext, old_reply: syntax.Struct, - new_reply: syntax.Struct, cmd_name: str, old_idl_file: syntax.IDLParsedSpec, - new_idl_file: syntax.IDLParsedSpec, old_idl_file_path: str, - new_idl_file_path: str): +def check_reply_fields( + ctxt: IDLCompatibilityContext, + old_reply: syntax.Struct, + new_reply: syntax.Struct, + cmd_name: str, + old_idl_file: syntax.IDLParsedSpec, + new_idl_file: syntax.IDLParsedSpec, + old_idl_file_path: str, + new_idl_file_path: str, +): """Check compatibility between old and new reply fields.""" for new_chained_type in new_reply.chained_types or []: - resolved_new_chained_type = get_chained_type_or_struct(new_chained_type, new_idl_file, - new_idl_file_path) + resolved_new_chained_type = get_chained_type_or_struct( + new_chained_type, new_idl_file, new_idl_file_path + ) if resolved_new_chained_type is not None: for old_chained_type in old_reply.chained_types or []: resolved_old_chained_type = get_chained_type_or_struct( - old_chained_type, old_idl_file, old_idl_file_path) - if (resolved_old_chained_type is not None - and resolved_old_chained_type.name == resolved_new_chained_type.name): + old_chained_type, old_idl_file, old_idl_file_path + ) + if ( + resolved_old_chained_type is not None + and resolved_old_chained_type.name == resolved_new_chained_type.name + ): # Check that the old and new version of each chained type is also compatible. - old = FieldCompatibility(resolved_old_chained_type, old_idl_file, - old_idl_file_path, stability='stable', optional=False) - new = FieldCompatibility(resolved_new_chained_type, new_idl_file, - new_idl_file_path, stability='stable', optional=False) + old = FieldCompatibility( + resolved_old_chained_type, + old_idl_file, + old_idl_file_path, + stability="stable", + optional=False, + ) + new = FieldCompatibility( + resolved_new_chained_type, + new_idl_file, + new_idl_file_path, + stability="stable", + optional=False, + ) check_reply_field_type( - ctxt, FieldCompatibilityPair(old, new, cmd_name, old_reply.name)) + ctxt, FieldCompatibilityPair(old, new, cmd_name, old_reply.name) + ) break else: # new chained type was not found in old chained types. ctxt.add_new_reply_chained_type_not_subset_error( - cmd_name, new_reply.name, resolved_new_chained_type.name, new_idl_file_path) + cmd_name, + new_reply.name, + resolved_new_chained_type.name, + new_idl_file_path, + ) old_reply_fields = get_all_struct_fields(old_reply, old_idl_file, old_idl_file_path) new_reply_fields = get_all_struct_fields(new_reply, new_idl_file, new_idl_file_path) @@ -883,19 +1119,30 @@ def check_reply_fields(ctxt: IDLCompatibilityContext, old_reply: syntax.Struct, for new_field in new_reply_fields or []: if new_field.name == old_field.name: new_field_exists = True - check_reply_field(ctxt, old_field, new_field, cmd_name, old_idl_file, new_idl_file, - old_idl_file_path, new_idl_file_path) + check_reply_field( + ctxt, + old_field, + new_field, + cmd_name, + old_idl_file, + new_idl_file, + old_idl_file_path, + new_idl_file_path, + ) break if not new_field_exists and not is_unstable(old_field.stability): - ctxt.add_new_reply_field_missing_error(cmd_name, old_field.name, old_idl_file_path) + ctxt.add_new_reply_field_missing_error( + cmd_name, old_field.name, old_idl_file_path + ) for new_field in new_reply_fields or []: # Check that all fields in the new IDL have specified the 'stability' field. if new_field.stability is None: - ctxt.add_new_reply_field_requires_stability_error(cmd_name, new_field.name, - new_idl_file_path) + ctxt.add_new_reply_field_requires_stability_error( + cmd_name, new_field.name, new_idl_file_path + ) # Check that newly added fields do not have an unallowed use of 'any' as the # bson_serialization_type. @@ -906,26 +1153,37 @@ def check_reply_fields(ctxt: IDLCompatibilityContext, old_reply: syntax.Struct, if newly_added: allow_name: str = cmd_name + "-reply-" + new_field.name - if not is_unstable( - new_field.stability) and allow_name not in ALLOWED_STABLE_FIELDS_LIST: - ctxt.add_new_reply_field_added_as_stable_error(cmd_name, new_field.name, - new_idl_file_path) + if ( + not is_unstable(new_field.stability) + and allow_name not in ALLOWED_STABLE_FIELDS_LIST + ): + ctxt.add_new_reply_field_added_as_stable_error( + cmd_name, new_field.name, new_idl_file_path + ) new_field_type = get_field_type(new_field, new_idl_file, new_idl_file_path) # If we encounter a bson_serialization_type of None, we skip checking if 'any' is used. - if isinstance( - new_field_type, syntax.Type - ) and new_field_type.bson_serialization_type is not None and "any" in new_field_type.bson_serialization_type: + if ( + isinstance(new_field_type, syntax.Type) + and new_field_type.bson_serialization_type is not None + and "any" in new_field_type.bson_serialization_type + ): # If 'any' is not explicitly allowed as the bson_serialization_type. - any_allow = allow_name in ALLOW_ANY_TYPE_LIST or new_field_type.name == 'optionalBool' + any_allow = ( + allow_name in ALLOW_ANY_TYPE_LIST + or new_field_type.name == "optionalBool" + ) if not any_allow: ctxt.add_new_reply_field_bson_any_not_allowed_error( - cmd_name, new_field.name, new_field_type.name, new_idl_file_path) + cmd_name, new_field.name, new_field_type.name, new_idl_file_path + ) -def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, - field_pair: FieldCompatibilityPair, - is_command_parameter: bool): +def check_param_or_command_type_recursive( + ctxt: IDLCompatibilityContext, + field_pair: FieldCompatibilityPair, + is_command_parameter: bool, +): """ Check compatibility between old and new command or param type recursively. @@ -939,71 +1197,126 @@ def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, cmd_name = field_pair.cmd_name param_name = field_pair.field_name - ignore_list_name: str = cmd_name + "-param-" + param_name if is_command_parameter else cmd_name + ignore_list_name: str = ( + cmd_name + "-param-" + param_name if is_command_parameter else cmd_name + ) # If the old field is unstable, we only add errors related to the use of 'any' as the # bson_serialization_type. For all other errors, we check that the old field is stable # before adding an error. if not isinstance(new_type, syntax.Type): - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_command_or_param_type_enum_or_struct_error( - cmd_name, new_type.name, old_type.name, new_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + new_type.name, + old_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) return # If bson_serialization_type switches from 'any' to non-any type. - if "any" in old_type.bson_serialization_type and "any" not in new_type.bson_serialization_type: + if ( + "any" in old_type.bson_serialization_type + and "any" not in new_type.bson_serialization_type + ): if ignore_list_name not in IGNORE_ANY_TO_NON_ANY_LIST: ctxt.add_old_command_or_param_type_bson_any_error( - cmd_name, old_type.name, new_type.name, old_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + old_type.name, + new_type.name, + old_field.idl_file_path, + param_name, + is_command_parameter, + ) return # If bson_serialization_type switches from non-any to 'any' type. - if "any" not in old_type.bson_serialization_type and "any" in new_type.bson_serialization_type and ignore_list_name not in IGNORE_NON_ANY_TO_ANY_LIST: - ctxt.add_new_command_or_param_type_bson_any_error(cmd_name, old_type.name, new_type.name, - new_field.idl_file_path, param_name, - is_command_parameter) + if ( + "any" not in old_type.bson_serialization_type + and "any" in new_type.bson_serialization_type + and ignore_list_name not in IGNORE_NON_ANY_TO_ANY_LIST + ): + ctxt.add_new_command_or_param_type_bson_any_error( + cmd_name, + old_type.name, + new_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) return if "any" in old_type.bson_serialization_type: # If 'any' is not explicitly allowed as the bson_serialization_type. if ignore_list_name not in ALLOW_ANY_TYPE_LIST: ctxt.add_old_command_or_param_type_bson_any_not_allowed_error( - cmd_name, old_type.name, old_field.idl_file_path, param_name, is_command_parameter) + cmd_name, + old_type.name, + old_field.idl_file_path, + param_name, + is_command_parameter, + ) return # If cpp_type is changed, it's a potential breaking change. if old_type.cpp_type != new_type.cpp_type: - ignore_list_name_with_types: str = f"{ignore_list_name}-{old_type.cpp_type}-{new_type.cpp_type}" + ignore_list_name_with_types: str = ( + f"{ignore_list_name}-{old_type.cpp_type}-{new_type.cpp_type}" + ) if ignore_list_name_with_types not in ALLOW_CPP_TYPE_CHANGE_LIST: - ctxt.add_command_or_param_cpp_type_not_equal_error(cmd_name, new_type.name, - new_field.idl_file_path, - param_name, is_command_parameter) + ctxt.add_command_or_param_cpp_type_not_equal_error( + cmd_name, + new_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) # If serializer is changed, it's a potential breaking change. - if (not is_unstable(old_field.stability) - and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST - ) and old_type.serializer != new_type.serializer: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ) and old_type.serializer != new_type.serializer: ctxt.add_command_or_param_serializer_not_equal_error( - cmd_name, new_type.name, new_field.idl_file_path, param_name, is_command_parameter) + cmd_name, + new_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) # If deserializer is changed, it's a potential breaking change. - if (not is_unstable(old_field.stability) - and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST - ) and old_type.deserializer != new_type.deserializer: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ) and old_type.deserializer != new_type.deserializer: ctxt.add_command_or_param_deserializer_not_equal_error( - cmd_name, new_type.name, new_field.idl_file_path, param_name, is_command_parameter) + cmd_name, + new_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) if isinstance(old_type, syntax.VariantType): if not isinstance(new_type, syntax.VariantType): - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_command_or_param_type_not_variant_type_error( - cmd_name, new_type.name, new_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + new_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) else: new_variant_types = new_type.variant_types old_variant_types = old_type.variant_types @@ -1013,26 +1326,44 @@ def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, for new_variant_type in new_variant_types: # object->object_owned serialize to the same bson type. object_owned->object is # not always safe so we only limit this special case to object->object_owned. - if (old_variant_type.name == "object" and new_variant_type.name == "object_owned") or \ - old_variant_type.name == new_variant_type.name: + if ( + old_variant_type.name == "object" + and new_variant_type.name == "object_owned" + ) or old_variant_type.name == new_variant_type.name: # Check that the old and new version of each variant type is also compatible. - old = FieldCompatibility(old_variant_type, old_field.idl_file, - old_field.idl_file_path, old_field.stability, - old_field.optional) - new = FieldCompatibility(new_variant_type, new_field.idl_file, - new_field.idl_file_path, new_field.stability, - new_field.optional) + old = FieldCompatibility( + old_variant_type, + old_field.idl_file, + old_field.idl_file_path, + old_field.stability, + old_field.optional, + ) + new = FieldCompatibility( + new_variant_type, + new_field.idl_file, + new_field.idl_file_path, + new_field.stability, + new_field.optional, + ) check_param_or_command_type( - ctxt, FieldCompatibilityPair(old, new, cmd_name, param_name), - is_command_parameter) + ctxt, + FieldCompatibilityPair(old, new, cmd_name, param_name), + is_command_parameter, + ) break else: - if not is_unstable(old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): # old_variant_type was not found in new_variant_types. ctxt.add_new_command_or_param_variant_type_not_superset_error( - cmd_name, old_variant_type.name, new_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + old_variant_type.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) # If old and new types both have a struct as a variant type, compare old and new variant_struct_type. # Since enums can't be part of variant types, we don't explicitly check for enums. @@ -1040,8 +1371,10 @@ def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, return if new_type.variant_struct_types is None: - if is_unstable( - old_field.stability) or ignore_list_name in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + is_unstable(old_field.stability) + or ignore_list_name in IGNORE_STABLE_TO_UNSTABLE_LIST + ): return # If new_type.variant_struct_types in None then add a @@ -1049,20 +1382,32 @@ def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, # old_type.variant_struct_types. for old_variant in old_type.variant_struct_types: ctxt.add_new_command_or_param_variant_type_not_superset_error( - cmd_name, old_variant.name, new_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + old_variant.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) return # If the length of both variant_struct_types is 1 then we want to check the struct fields # since an idl name change with the same field names is legal. We do not do this for # lengths > 1 because it would be too ambiguous to tell which pair of variant # types no longer comply with each other. - if (len(old_type.variant_struct_types) == 1) and (len( - new_type.variant_struct_types) == 1): + if (len(old_type.variant_struct_types) == 1) and ( + len(new_type.variant_struct_types) == 1 + ): check_command_params_or_type_struct_fields( - ctxt, old_type.variant_struct_types[0], new_type.variant_struct_types[0], - cmd_name, old_field.idl_file, new_field.idl_file, old_field.idl_file_path, - new_field.idl_file_path, is_command_parameter) + ctxt, + old_type.variant_struct_types[0], + new_type.variant_struct_types[0], + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + is_command_parameter, + ) return for old_variant in old_type.variant_struct_types: for new_variant in new_type.variant_struct_types: @@ -1070,38 +1415,69 @@ def check_param_or_command_type_recursive(ctxt: IDLCompatibilityContext, # new_type.variant_struct_types, call check_command_params_or_type_struct_fields. if new_variant.name == old_variant.name: check_command_params_or_type_struct_fields( - ctxt, old_variant, new_variant, cmd_name, old_field.idl_file, - new_field.idl_file, old_field.idl_file_path, new_field.idl_file_path, - is_command_parameter) + ctxt, + old_variant, + new_variant, + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + is_command_parameter, + ) break # If an item in old_type.variant_struct_types was not found in # new_type.variant_struct_types then add a new_command_or_param_variant_type_not_superset_error. else: - if not is_unstable(old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_command_or_param_variant_type_not_superset_error( - cmd_name, old_variant.name, new_field.idl_file_path, param_name, - is_command_parameter) + cmd_name, + old_variant.name, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) - elif not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: - check_superset(ctxt, cmd_name, new_type.name, new_type.bson_serialization_type, - old_type.bson_serialization_type, new_field.idl_file_path, param_name, - is_command_parameter) + elif ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): + check_superset( + ctxt, + cmd_name, + new_type.name, + new_type.bson_serialization_type, + old_type.bson_serialization_type, + new_field.idl_file_path, + param_name, + is_command_parameter, + ) -def check_param_or_command_type(ctxt: IDLCompatibilityContext, field_pair: FieldCompatibilityPair, - is_command_parameter: bool): +def check_param_or_command_type( + ctxt: IDLCompatibilityContext, + field_pair: FieldCompatibilityPair, + is_command_parameter: bool, +): """Check compatibility between old and new command parameter type or command type.""" old_field = field_pair.old new_field = field_pair.new field_name = field_pair.field_name cmd_name = field_pair.cmd_name array_check = check_array_type( - ctxt, "command_parameter" if is_command_parameter else "command_namespace", - old_field.field_type, new_field.field_type, field_pair.cmd_name, - field_name if is_command_parameter else "type", old_field.idl_file_path, - new_field.idl_file_path, is_unstable(old_field.stability)) + ctxt, + "command_parameter" if is_command_parameter else "command_namespace", + old_field.field_type, + new_field.field_type, + field_pair.cmd_name, + field_name if is_command_parameter else "type", + old_field.idl_file_path, + new_field.idl_file_path, + is_unstable(old_field.stability), + ) if array_check == ArrayTypeCheckResult.INVALID: return @@ -1112,13 +1488,21 @@ def check_param_or_command_type(ctxt: IDLCompatibilityContext, field_pair: Field old_type = old_field.field_type new_type = new_field.field_type if old_type is None: - ctxt.add_command_or_param_type_invalid_error(cmd_name, old_field.idl_file_path, - field_pair.field_name, is_command_parameter) + ctxt.add_command_or_param_type_invalid_error( + cmd_name, + old_field.idl_file_path, + field_pair.field_name, + is_command_parameter, + ) ctxt.errors.dump_errors() sys.exit(1) if new_type is None: - ctxt.add_command_or_param_type_invalid_error(cmd_name, new_field.idl_file_path, - field_pair.field_name, is_command_parameter) + ctxt.add_command_or_param_type_invalid_error( + cmd_name, + new_field.idl_file_path, + field_pair.field_name, + is_command_parameter, + ) ctxt.errors.dump_errors() sys.exit(1) @@ -1128,32 +1512,69 @@ def check_param_or_command_type(ctxt: IDLCompatibilityContext, field_pair: Field check_param_or_command_type_recursive(ctxt, field_pair, is_command_parameter) # Only add type errors if the old field is stable. - elif isinstance(old_type, syntax.Enum) and not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + elif ( + isinstance(old_type, syntax.Enum) + and not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): if isinstance(new_type, syntax.Enum): - check_superset(ctxt, cmd_name, new_type.name, new_type.values, old_type.values, - new_field.idl_file_path, field_pair.field_name, is_command_parameter) + check_superset( + ctxt, + cmd_name, + new_type.name, + new_type.values, + old_type.values, + new_field.idl_file_path, + field_pair.field_name, + is_command_parameter, + ) else: ctxt.add_new_command_or_param_type_not_enum_error( - cmd_name, new_type.name, old_type.name, new_field.idl_file_path, - field_pair.field_name, is_command_parameter) + cmd_name, + new_type.name, + old_type.name, + new_field.idl_file_path, + field_pair.field_name, + is_command_parameter, + ) elif isinstance(old_type, syntax.Struct): if isinstance(new_type, syntax.Struct): check_command_params_or_type_struct_fields( - ctxt, old_type, new_type, cmd_name, old_field.idl_file, new_field.idl_file, - old_field.idl_file_path, new_field.idl_file_path, is_command_parameter) + ctxt, + old_type, + new_type, + cmd_name, + old_field.idl_file, + new_field.idl_file, + old_field.idl_file_path, + new_field.idl_file_path, + is_command_parameter, + ) else: - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_command_or_param_type_not_struct_error( - cmd_name, new_type.name, old_type.name, new_field.idl_file_path, - field_pair.field_name, is_command_parameter) + cmd_name, + new_type.name, + old_type.name, + new_field.idl_file_path, + field_pair.field_name, + is_command_parameter, + ) -def check_param_or_type_validator(ctxt: IDLCompatibilityContext, old_field: syntax.Field, - new_field: syntax.Field, cmd_name: str, new_idl_file_path: str, - type_name: Optional[str], is_command_parameter: bool): +def check_param_or_type_validator( + ctxt: IDLCompatibilityContext, + old_field: syntax.Field, + new_field: syntax.Field, + cmd_name: str, + new_idl_file_path: str, + type_name: Optional[str], + is_command_parameter: bool, +): """ Check compatibility between old and new validators. @@ -1171,26 +1592,44 @@ def check_param_or_type_validator(ctxt: IDLCompatibilityContext, old_field: synt if new_field.validator: if old_field.validator: old_field_name: str = cmd_name + "-param-" + old_field.name - if new_field.validator != old_field.validator and old_field_name not in ignore_validator_check_list: + if ( + new_field.validator != old_field.validator + and old_field_name not in ignore_validator_check_list + ): ctxt.add_command_or_param_type_validators_not_equal_error( - cmd_name, new_field.name, new_idl_file_path, type_name, is_command_parameter) + cmd_name, + new_field.name, + new_idl_file_path, + type_name, + is_command_parameter, + ) else: new_field_name: str = cmd_name + "-param-" + new_field.name # In SERVER-77382 we fixed the error handling of creating time-series collections by # adding a new validator to two 'stable' fields, but it didn't break any stable API # guarantees. - if new_field_name not in ["create-param-timeField", "create-param-metaField"]: + if new_field_name not in [ + "create-param-timeField", + "create-param-metaField", + ]: ctxt.add_command_or_param_type_contains_validator_error( - cmd_name, new_field.name, new_idl_file_path, type_name, is_command_parameter) + cmd_name, + new_field.name, + new_idl_file_path, + type_name, + is_command_parameter, + ) -def get_all_struct_fields(struct: syntax.Struct, idl_file: syntax.IDLParsedSpec, - idl_file_path: str): +def get_all_struct_fields( + struct: syntax.Struct, idl_file: syntax.IDLParsedSpec, idl_file_path: str +): """Get all the fields of a struct, including the chained struct fields.""" all_fields = struct.fields or [] for chained_struct in struct.chained_structs or []: - resolved_chained_struct = get_chained_type_or_struct(chained_struct, idl_file, - idl_file_path) + resolved_chained_struct = get_chained_type_or_struct( + chained_struct, idl_file, idl_file_path + ) if resolved_chained_struct is not None: for field in resolved_chained_struct.fields: all_fields.append(field) @@ -1199,38 +1638,69 @@ def get_all_struct_fields(struct: syntax.Struct, idl_file: syntax.IDLParsedSpec, def check_command_params_or_type_struct_fields( - ctxt: IDLCompatibilityContext, old_struct: syntax.Struct, new_struct: syntax.Struct, - cmd_name: str, old_idl_file: syntax.IDLParsedSpec, new_idl_file: syntax.IDLParsedSpec, - old_idl_file_path: str, new_idl_file_path: str, is_command_parameter: bool): + ctxt: IDLCompatibilityContext, + old_struct: syntax.Struct, + new_struct: syntax.Struct, + cmd_name: str, + old_idl_file: syntax.IDLParsedSpec, + new_idl_file: syntax.IDLParsedSpec, + old_idl_file_path: str, + new_idl_file_path: str, + is_command_parameter: bool, +): """Check compatibility between old and new parameters or command type fields.""" # Check chained types. for old_chained_type in old_struct.chained_types or []: - resolved_old_chained_type = get_chained_type_or_struct(old_chained_type, old_idl_file, - old_idl_file_path) + resolved_old_chained_type = get_chained_type_or_struct( + old_chained_type, old_idl_file, old_idl_file_path + ) if resolved_old_chained_type is not None: for new_chained_type in new_struct.chained_types or []: resolved_new_chained_type = get_chained_type_or_struct( - new_chained_type, new_idl_file, new_idl_file_path) - if (resolved_new_chained_type is not None - and resolved_old_chained_type.name == resolved_new_chained_type.name): + new_chained_type, new_idl_file, new_idl_file_path + ) + if ( + resolved_new_chained_type is not None + and resolved_old_chained_type.name == resolved_new_chained_type.name + ): # Check that the old and new version of each chained type is also compatible. - old = FieldCompatibility(resolved_old_chained_type, old_idl_file, - old_idl_file_path, stability="stable", optional=False) - new = FieldCompatibility(resolved_new_chained_type, new_idl_file, - new_idl_file_path, stability="stable", optional=False) + old = FieldCompatibility( + resolved_old_chained_type, + old_idl_file, + old_idl_file_path, + stability="stable", + optional=False, + ) + new = FieldCompatibility( + resolved_new_chained_type, + new_idl_file, + new_idl_file_path, + stability="stable", + optional=False, + ) check_param_or_command_type( - ctxt, FieldCompatibilityPair(old, new, cmd_name, old_struct.name), - is_command_parameter=False) + ctxt, + FieldCompatibilityPair(old, new, cmd_name, old_struct.name), + is_command_parameter=False, + ) break else: # old chained type was not found in new chained types. ctxt.add_new_command_or_param_chained_type_not_superset_error( - cmd_name, old_chained_type.name, new_idl_file_path, old_struct.name, - is_command_parameter) + cmd_name, + old_chained_type.name, + new_idl_file_path, + old_struct.name, + is_command_parameter, + ) - old_struct_fields = get_all_struct_fields(old_struct, old_idl_file, old_idl_file_path) - new_struct_fields = get_all_struct_fields(new_struct, new_idl_file, new_idl_file_path) + old_struct_fields = get_all_struct_fields( + old_struct, old_idl_file, old_idl_file_path + ) + new_struct_fields = get_all_struct_fields( + new_struct, new_idl_file, new_idl_file_path + ) # We need to special-case the stmtId parameter because it was removed. However, it's not a # breaking change to the API because it was added and removed behind a feature flag, so it was @@ -1246,15 +1716,32 @@ def check_command_params_or_type_struct_fields( if new_field.name == old_field.name: new_field_exists = True check_command_param_or_type_struct_field( - ctxt, old_field, new_field, cmd_name, old_idl_file, new_idl_file, - old_idl_file_path, new_idl_file_path, old_struct.name, is_command_parameter) + ctxt, + old_field, + new_field, + cmd_name, + old_idl_file, + new_idl_file, + old_idl_file_path, + new_idl_file_path, + old_struct.name, + is_command_parameter, + ) break allow_name: str = cmd_name + "-param-" + old_field.name - if not new_field_exists and not is_unstable( - old_field.stability) and allow_name not in allow_list: + if ( + not new_field_exists + and not is_unstable(old_field.stability) + and allow_name not in allow_list + ): ctxt.add_new_param_or_command_type_field_missing_error( - cmd_name, old_field.name, old_idl_file_path, old_struct.name, is_command_parameter) + cmd_name, + old_field.name, + old_idl_file_path, + old_struct.name, + is_command_parameter, + ) # Check if a new field has been added to the parameters or type struct. # If so, it must be optional. @@ -1262,7 +1749,8 @@ def check_command_params_or_type_struct_fields( # Check that all fields in the new IDL have specified the 'stability' field. if new_field.stability is None: ctxt.add_new_param_or_command_type_field_requires_stability_error( - cmd_name, new_field.name, new_idl_file_path, is_command_parameter) + cmd_name, new_field.name, new_idl_file_path, is_command_parameter + ) newly_added = True for old_field in old_struct_fields or []: @@ -1271,97 +1759,180 @@ def check_command_params_or_type_struct_fields( if newly_added: allow_stable_name: str = cmd_name + "-param-" + new_field.name - if not is_unstable( - new_field.stability) and allow_stable_name not in ALLOWED_STABLE_FIELDS_LIST: + if ( + not is_unstable(new_field.stability) + and allow_stable_name not in ALLOWED_STABLE_FIELDS_LIST + ): ctxt.add_new_param_or_type_field_added_as_stable_error( - cmd_name, new_field.name, new_idl_file_path, is_command_parameter) + cmd_name, new_field.name, new_idl_file_path, is_command_parameter + ) new_field_type = get_field_type(new_field, new_idl_file, new_idl_file_path) - new_field_optional = new_field.optional or (new_field_type - and new_field_type.name == 'optionalBool') - if not new_field_optional and new_field.default is None and not is_unstable( - new_field.stability): + new_field_optional = new_field.optional or ( + new_field_type and new_field_type.name == "optionalBool" + ) + if ( + not new_field_optional + and new_field.default is None + and not is_unstable(new_field.stability) + ): ctxt.add_new_param_or_command_type_field_added_required_error( - cmd_name, new_field.name, new_idl_file_path, new_struct.name, - is_command_parameter) + cmd_name, + new_field.name, + new_idl_file_path, + new_struct.name, + is_command_parameter, + ) - if (is_unstable(new_field.stability) and not new_field.stability == "internal" - and not new_field_optional): + if ( + is_unstable(new_field.stability) + and not new_field.stability == "internal" + and not new_field_optional + ): ctxt.add_new_param_or_type_field_added_as_unstable_required_error( - cmd_name, new_field.name, new_idl_file_path, is_command_parameter) + cmd_name, new_field.name, new_idl_file_path, is_command_parameter + ) # Check that a new field does not have an unallowed use of 'any' as the bson_serialization_type. - any_allow_name: str = (cmd_name + "-param-" + new_field.name - if is_command_parameter else cmd_name) + any_allow_name: str = ( + cmd_name + "-param-" + new_field.name + if is_command_parameter + else cmd_name + ) # If we encounter a bson_serialization_type of None, we skip checking if 'any' is used. - if isinstance( - new_field_type, syntax.Type - ) and new_field_type.bson_serialization_type is not None and "any" in new_field_type.bson_serialization_type: + if ( + isinstance(new_field_type, syntax.Type) + and new_field_type.bson_serialization_type is not None + and "any" in new_field_type.bson_serialization_type + ): # If 'any' is not explicitly allowed as the bson_serialization_type. - any_allow = any_allow_name in ALLOW_ANY_TYPE_LIST or new_field_type.name == 'optionalBool' + any_allow = ( + any_allow_name in ALLOW_ANY_TYPE_LIST + or new_field_type.name == "optionalBool" + ) if not any_allow: ctxt.add_new_command_or_param_type_bson_any_not_allowed_error( - cmd_name, new_field_type.name, old_idl_file_path, new_field.name, - is_command_parameter) + cmd_name, + new_field_type.name, + old_idl_file_path, + new_field.name, + is_command_parameter, + ) def check_command_param_or_type_struct_field( - ctxt: IDLCompatibilityContext, old_field: syntax.Field, new_field: syntax.Field, - cmd_name: str, old_idl_file: syntax.IDLParsedSpec, new_idl_file: syntax.IDLParsedSpec, - old_idl_file_path: str, new_idl_file_path: str, type_name: Optional[str], - is_command_parameter: bool): + ctxt: IDLCompatibilityContext, + old_field: syntax.Field, + new_field: syntax.Field, + cmd_name: str, + old_idl_file: syntax.IDLParsedSpec, + new_idl_file: syntax.IDLParsedSpec, + old_idl_file_path: str, + new_idl_file_path: str, + type_name: Optional[str], + is_command_parameter: bool, +): """Check compatibility between the old and new command parameter or command type struct field.""" ignore_list_name: str = cmd_name + "-param-" + new_field.name - if not is_unstable(old_field.stability) and is_unstable( - new_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: + if ( + not is_unstable(old_field.stability) + and is_unstable(new_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): ctxt.add_new_param_or_command_type_field_unstable_error( - cmd_name, old_field.name, old_idl_file_path, type_name, is_command_parameter) + cmd_name, old_field.name, old_idl_file_path, type_name, is_command_parameter + ) # A command param or type field may not change from unstable to stable unless explicitly allowed to. - if is_unstable(old_field.stability) and not is_unstable( - new_field.stability) and ignore_list_name not in ALLOWED_STABLE_FIELDS_LIST: + if ( + is_unstable(old_field.stability) + and not is_unstable(new_field.stability) + and ignore_list_name not in ALLOWED_STABLE_FIELDS_LIST + ): ctxt.add_unstable_param_or_type_field_to_stable_error( - cmd_name, old_field.name, old_idl_file_path, is_command_parameter) + cmd_name, old_field.name, old_idl_file_path, is_command_parameter + ) # If old field is unstable and new field is stable, the new field should either be optional or # have a default value, unless the old field was a required field. old_field_type = get_field_type(old_field, old_idl_file, old_idl_file_path) new_field_type = get_field_type(new_field, new_idl_file, new_idl_file_path) - old_field_optional = old_field.optional or (old_field_type - and old_field_type.name == "optionalBool") - new_field_optional = new_field.optional or (new_field_type - and new_field_type.name == "optionalBool") - if is_unstable(old_field.stability) and not is_unstable( - new_field.stability) and not new_field_optional and new_field.default is None: + old_field_optional = old_field.optional or ( + old_field_type and old_field_type.name == "optionalBool" + ) + new_field_optional = new_field.optional or ( + new_field_type and new_field_type.name == "optionalBool" + ) + if ( + is_unstable(old_field.stability) + and not is_unstable(new_field.stability) + and not new_field_optional + and new_field.default is None + ): # Only error if the old field was not a required field already. if old_field_optional or old_field.default is not None: ctxt.add_new_param_or_command_type_field_stable_required_no_default_error( - cmd_name, old_field.name, old_idl_file_path, type_name, is_command_parameter) + cmd_name, + old_field.name, + old_idl_file_path, + type_name, + is_command_parameter, + ) - if not is_unstable( - old_field.stability - ) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST and old_field_optional and not new_field_optional: + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + and old_field_optional + and not new_field_optional + ): ctxt.add_new_param_or_command_type_field_required_error( - cmd_name, old_field.name, old_idl_file_path, type_name, is_command_parameter) + cmd_name, old_field.name, old_idl_file_path, type_name, is_command_parameter + ) - if not is_unstable( - old_field.stability) and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST: - check_param_or_type_validator(ctxt, old_field, new_field, cmd_name, new_idl_file_path, - type_name, is_command_parameter) + if ( + not is_unstable(old_field.stability) + and ignore_list_name not in IGNORE_STABLE_TO_UNSTABLE_LIST + ): + check_param_or_type_validator( + ctxt, + old_field, + new_field, + cmd_name, + new_idl_file_path, + type_name, + is_command_parameter, + ) - old_field_compatibility = FieldCompatibility(old_field_type, old_idl_file, old_idl_file_path, - old_field.stability, old_field.optional) - new_field_compatibility = FieldCompatibility(new_field_type, new_idl_file, new_idl_file_path, - new_field.stability, new_field.optional) - field_pair = FieldCompatibilityPair(old_field_compatibility, new_field_compatibility, cmd_name, - old_field.name) + old_field_compatibility = FieldCompatibility( + old_field_type, + old_idl_file, + old_idl_file_path, + old_field.stability, + old_field.optional, + ) + new_field_compatibility = FieldCompatibility( + new_field_type, + new_idl_file, + new_idl_file_path, + new_field.stability, + new_field.optional, + ) + field_pair = FieldCompatibilityPair( + old_field_compatibility, new_field_compatibility, cmd_name, old_field.name + ) check_param_or_command_type(ctxt, field_pair, is_command_parameter) -def check_namespace(ctxt: IDLCompatibilityContext, old_cmd: syntax.Command, new_cmd: syntax.Command, - old_idl_file: syntax.IDLParsedSpec, new_idl_file: syntax.IDLParsedSpec, - old_idl_file_path: str, new_idl_file_path: str): +def check_namespace( + ctxt: IDLCompatibilityContext, + old_cmd: syntax.Command, + new_cmd: syntax.Command, + old_idl_file: syntax.IDLParsedSpec, + new_idl_file: syntax.IDLParsedSpec, + old_idl_file_path: str, + new_idl_file_path: str, +): """Check compatibility between old and new namespace.""" old_namespace = old_cmd.namespace new_namespace = new_cmd.namespace @@ -1369,51 +1940,81 @@ def check_namespace(ctxt: IDLCompatibilityContext, old_cmd: syntax.Command, new_ # IDL parser already checks that namespace must be one of these 4 types. if old_namespace == common.COMMAND_NAMESPACE_IGNORED: if new_namespace != common.COMMAND_NAMESPACE_IGNORED: - ctxt.add_new_namespace_incompatible_error(old_cmd.command_name, old_namespace, - new_namespace, new_idl_file_path) + ctxt.add_new_namespace_incompatible_error( + old_cmd.command_name, old_namespace, new_namespace, new_idl_file_path + ) elif old_namespace == common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID: - if new_namespace not in (common.COMMAND_NAMESPACE_IGNORED, - common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID): - ctxt.add_new_namespace_incompatible_error(old_cmd.command_name, old_namespace, - new_namespace, new_idl_file_path) + if new_namespace not in ( + common.COMMAND_NAMESPACE_IGNORED, + common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB_OR_UUID, + ): + ctxt.add_new_namespace_incompatible_error( + old_cmd.command_name, old_namespace, new_namespace, new_idl_file_path + ) elif old_namespace == common.COMMAND_NAMESPACE_CONCATENATE_WITH_DB: if new_namespace == common.COMMAND_NAMESPACE_TYPE: - ctxt.add_new_namespace_incompatible_error(old_cmd.command_name, old_namespace, - new_namespace, new_idl_file_path) + ctxt.add_new_namespace_incompatible_error( + old_cmd.command_name, old_namespace, new_namespace, new_idl_file_path + ) elif old_namespace == common.COMMAND_NAMESPACE_TYPE: old_type = get_field_type(old_cmd, old_idl_file, old_idl_file_path) if new_namespace == common.COMMAND_NAMESPACE_TYPE: new_type = get_field_type(new_cmd, new_idl_file, new_idl_file_path) - old = FieldCompatibility(old_type, old_idl_file, old_idl_file_path, stability="stable", - optional=False) - new = FieldCompatibility(new_type, new_idl_file, new_idl_file_path, stability="stable", - optional=False) + old = FieldCompatibility( + old_type, + old_idl_file, + old_idl_file_path, + stability="stable", + optional=False, + ) + new = FieldCompatibility( + new_type, + new_idl_file, + new_idl_file_path, + stability="stable", + optional=False, + ) - check_param_or_command_type(ctxt, - FieldCompatibilityPair(old, new, old_cmd.command_name, ""), - is_command_parameter=False) + check_param_or_command_type( + ctxt, + FieldCompatibilityPair(old, new, old_cmd.command_name, ""), + is_command_parameter=False, + ) # If old type is "namespacestring", the new namespace can be changed to any # of the other namespace types. elif old_type.name != "namespacestring": # Otherwise, the new namespace can only be changed to "ignored". if new_namespace != common.COMMAND_NAMESPACE_IGNORED: - ctxt.add_new_namespace_incompatible_error(old_cmd.command_name, old_namespace, - new_namespace, new_idl_file_path) + ctxt.add_new_namespace_incompatible_error( + old_cmd.command_name, + old_namespace, + new_namespace, + new_idl_file_path, + ) else: - assert False, 'unrecognized namespace option' + assert False, "unrecognized namespace option" -def check_error_reply(old_basic_types_path: str, new_basic_types_path: str, - old_import_directories: List[str], - new_import_directories: List[str]) -> IDLCompatibilityErrorCollection: +def check_error_reply( + old_basic_types_path: str, + new_basic_types_path: str, + old_import_directories: List[str], + new_import_directories: List[str], +) -> IDLCompatibilityErrorCollection: """Check IDL compatibility between old and new ErrorReply.""" old_idl_dir = os.path.dirname(old_basic_types_path) new_idl_dir = os.path.dirname(new_basic_types_path) - ctxt = IDLCompatibilityContext(old_idl_dir, new_idl_dir, IDLCompatibilityErrorCollection()) + ctxt = IDLCompatibilityContext( + old_idl_dir, new_idl_dir, IDLCompatibilityErrorCollection() + ) with open(old_basic_types_path) as old_file: - old_idl_file = parser.parse(old_file, old_basic_types_path, - CompilerImportResolver(old_import_directories), False) + old_idl_file = parser.parse( + old_file, + old_basic_types_path, + CompilerImportResolver(old_import_directories), + False, + ) if old_idl_file.errors: old_idl_file.errors.dump_errors() # If parsing old IDL files fails, it might be because the parser has been recently @@ -1426,26 +2027,40 @@ def check_error_reply(old_basic_types_path: str, new_basic_types_path: str, ctxt.add_missing_error_reply_struct_error(old_basic_types_path) else: with open(new_basic_types_path) as new_file: - new_idl_file = parser.parse(new_file, new_basic_types_path, - CompilerImportResolver(new_import_directories), False) + new_idl_file = parser.parse( + new_file, + new_basic_types_path, + CompilerImportResolver(new_import_directories), + False, + ) if new_idl_file.errors: new_idl_file.errors.dump_errors() raise ValueError(f"Cannot parse {new_basic_types_path}") - new_error_reply_struct = new_idl_file.spec.symbols.get_struct("ErrorReply") + new_error_reply_struct = new_idl_file.spec.symbols.get_struct( + "ErrorReply" + ) if new_error_reply_struct is None: ctxt.add_missing_error_reply_struct_error(new_basic_types_path) else: - check_reply_fields(ctxt, old_error_reply_struct, new_error_reply_struct, "n/a", - old_idl_file, new_idl_file, old_basic_types_path, - new_basic_types_path) + check_reply_fields( + ctxt, + old_error_reply_struct, + new_error_reply_struct, + "n/a", + old_idl_file, + new_idl_file, + old_basic_types_path, + new_basic_types_path, + ) ctxt.errors.dump_errors() return ctxt.errors def split_complex_checks( - complex_checks: List[syntax.AccessCheck]) -> Tuple[List[str], List[syntax.Privilege]]: + complex_checks: List[syntax.AccessCheck], +) -> Tuple[List[str], List[syntax.Privilege]]: """Split a list of AccessCheck into checks and privileges.""" checks = [x.check for x in complex_checks if x.check is not None] privileges = [x.privilege for x in complex_checks if x.privilege is not None] @@ -1462,10 +2077,13 @@ def map_complex_access_check_name(name: str) -> str: return name -def check_complex_checks(ctxt: IDLCompatibilityContext, - old_complex_checks: List[syntax.AccessCheck], - new_complex_checks: List[syntax.AccessCheck], cmd: syntax.Command, - new_idl_file_path: str) -> None: +def check_complex_checks( + ctxt: IDLCompatibilityContext, + old_complex_checks: List[syntax.AccessCheck], + new_complex_checks: List[syntax.AccessCheck], + cmd: syntax.Command, + new_idl_file_path: str, +) -> None: """Check the compatibility between complex access checks of the old and new command.""" cmd_name = cmd.command_name old_checks, old_privileges = split_complex_checks(old_complex_checks) @@ -1480,33 +2098,45 @@ def check_complex_checks(ctxt: IDLCompatibilityContext, if cmd_name in ALLOWED_NEW_ACCESS_CHECK_PRIVILEGES: new_privileges = [ - privilege for privilege in new_privileges if AllowedNewPrivilege.create_from(privilege) + privilege + for privilege in new_privileges + if AllowedNewPrivilege.create_from(privilege) not in ALLOWED_NEW_ACCESS_CHECK_PRIVILEGES[cmd_name] ] if (len(new_checks_normalized) + len(new_privileges)) > ( - len(old_checks_normalized) + len(old_privileges)): + len(old_checks_normalized) + len(old_privileges) + ): ctxt.add_new_additional_complex_access_check_error(cmd_name, new_idl_file_path) else: if not new_checks_normalized.issubset(old_checks_normalized): ctxt.add_new_complex_checks_not_subset_error(cmd_name, new_idl_file_path) if len(new_privileges) > len(old_privileges): - ctxt.add_new_complex_privileges_not_subset_error(cmd_name, new_idl_file_path) + ctxt.add_new_complex_privileges_not_subset_error( + cmd_name, new_idl_file_path + ) else: # Check that each new_privilege matches an old_privilege (the resource_pattern is # equal and the action_types are a subset of the old action_types). for new_privilege in new_privileges: for old_privilege in old_privileges: - if (new_privilege.resource_pattern == old_privilege.resource_pattern - and set(new_privilege.action_type).issubset(old_privilege.action_type)): + if ( + new_privilege.resource_pattern == old_privilege.resource_pattern + and set(new_privilege.action_type).issubset( + old_privilege.action_type + ) + ): old_privileges.remove(old_privilege) break else: - ctxt.add_new_complex_privileges_not_subset_error(cmd_name, new_idl_file_path) + ctxt.add_new_complex_privileges_not_subset_error( + cmd_name, new_idl_file_path + ) def split_complex_checks_agg_stages( - complex_checks: List[syntax.AccessCheck]) -> Dict[str, List[syntax.AccessCheck]]: + complex_checks: List[syntax.AccessCheck], +) -> Dict[str, List[syntax.AccessCheck]]: """Split a list of AccessChecks into a map keyed by aggregation stage (defaults to None).""" complex_checks_agg_stages: Dict[str, List[syntax.AccessCheck]] = dict() for access_check in complex_checks: @@ -1520,10 +2150,13 @@ def split_complex_checks_agg_stages( return complex_checks_agg_stages -def check_complex_checks_agg_stages(ctxt: IDLCompatibilityContext, - old_complex_checks: List[syntax.AccessCheck], - new_complex_checks: List[syntax.AccessCheck], - cmd: syntax.Command, new_idl_file_path: str) -> None: +def check_complex_checks_agg_stages( + ctxt: IDLCompatibilityContext, + old_complex_checks: List[syntax.AccessCheck], + new_complex_checks: List[syntax.AccessCheck], + cmd: syntax.Command, + new_idl_file_path: str, +) -> None: """Check the compatibility between complex access checks of the old and new agggreation stages.""" new_complex_checks_agg_stages = split_complex_checks_agg_stages(new_complex_checks) old_complex_checks_agg_stages = split_complex_checks_agg_stages(old_complex_checks) @@ -1533,61 +2166,102 @@ def check_complex_checks_agg_stages(ctxt: IDLCompatibilityContext, # are not present in the previous release. if agg_stage not in old_complex_checks_agg_stages: continue - check_complex_checks(ctxt, old_complex_checks_agg_stages[agg_stage], - new_complex_checks_agg_stages[agg_stage], cmd, new_idl_file_path) + check_complex_checks( + ctxt, + old_complex_checks_agg_stages[agg_stage], + new_complex_checks_agg_stages[agg_stage], + cmd, + new_idl_file_path, + ) -def check_security_access_checks(ctxt: IDLCompatibilityContext, - old_access_checks: syntax.AccessChecks, - new_access_checks: syntax.AccessChecks, cmd: syntax.Command, - new_idl_file_path: str) -> None: +def check_security_access_checks( + ctxt: IDLCompatibilityContext, + old_access_checks: syntax.AccessChecks, + new_access_checks: syntax.AccessChecks, + cmd: syntax.Command, + new_idl_file_path: str, +) -> None: """Check the compatibility between security access checks of the old and new command.""" # pylint:disable=too-many-nested-blocks cmd_name = cmd.command_name if old_access_checks is not None and new_access_checks is not None: old_access_check_type = old_access_checks.get_access_check_type() new_access_check_type = new_access_checks.get_access_check_type() - if old_access_check_type != new_access_check_type and CHANGED_ACCESS_CHECKS_TYPE.get( - cmd_name, None) != [old_access_check_type, new_access_check_type]: - ctxt.add_access_check_type_not_equal_error(cmd_name, old_access_check_type, - new_access_check_type, new_idl_file_path) + if ( + old_access_check_type != new_access_check_type + and CHANGED_ACCESS_CHECKS_TYPE.get(cmd_name, None) + != [old_access_check_type, new_access_check_type] + ): + ctxt.add_access_check_type_not_equal_error( + cmd_name, + old_access_check_type, + new_access_check_type, + new_idl_file_path, + ) else: old_simple_check = old_access_checks.simple new_simple_check = new_access_checks.simple if old_simple_check is not None and new_simple_check is not None: if old_simple_check.check != new_simple_check.check: - ctxt.add_check_not_equal_error(cmd_name, old_simple_check.check, - new_simple_check.check, new_idl_file_path) + ctxt.add_check_not_equal_error( + cmd_name, + old_simple_check.check, + new_simple_check.check, + new_idl_file_path, + ) else: old_privilege = old_simple_check.privilege new_privilege = new_simple_check.privilege if old_privilege is not None and new_privilege is not None: - if old_privilege.resource_pattern != new_privilege.resource_pattern: + if ( + old_privilege.resource_pattern + != new_privilege.resource_pattern + ): ctxt.add_resource_pattern_not_equal_error( - cmd_name, old_privilege.resource_pattern, - new_privilege.resource_pattern, new_idl_file_path) - if not set(new_privilege.action_type).issubset(old_privilege.action_type): - ctxt.add_new_action_types_not_subset_error(cmd_name, new_idl_file_path) + cmd_name, + old_privilege.resource_pattern, + new_privilege.resource_pattern, + new_idl_file_path, + ) + if not set(new_privilege.action_type).issubset( + old_privilege.action_type + ): + ctxt.add_new_action_types_not_subset_error( + cmd_name, new_idl_file_path + ) old_complex_checks = old_access_checks.complex new_complex_checks = new_access_checks.complex if old_complex_checks is not None and new_complex_checks is not None: - check_complex_checks_agg_stages(ctxt, old_complex_checks, new_complex_checks, cmd, - new_idl_file_path) + check_complex_checks_agg_stages( + ctxt, old_complex_checks, new_complex_checks, cmd, new_idl_file_path + ) elif new_access_checks is None and old_access_checks is not None: ctxt.add_removed_access_check_field_error(cmd_name, new_idl_file_path) - elif old_access_checks is None and new_access_checks is not None and cmd.api_version == '1': + elif ( + old_access_checks is None + and new_access_checks is not None + and cmd.api_version == "1" + ): ctxt.add_added_access_check_field_error(cmd_name, new_idl_file_path) -def check_compatibility(old_idl_dir: str, new_idl_dir: str, old_import_directories: List[str], - new_import_directories: List[str]) -> IDLCompatibilityErrorCollection: +def check_compatibility( + old_idl_dir: str, + new_idl_dir: str, + old_import_directories: List[str], + new_import_directories: List[str], +) -> IDLCompatibilityErrorCollection: """Check IDL compatibility between old and new IDL commands.""" - ctxt = IDLCompatibilityContext(old_idl_dir, new_idl_dir, IDLCompatibilityErrorCollection()) + ctxt = IDLCompatibilityContext( + old_idl_dir, new_idl_dir, IDLCompatibilityErrorCollection() + ) new_commands, new_command_file, new_command_file_path = get_new_commands( - ctxt, new_idl_dir, new_import_directories) + ctxt, new_idl_dir, new_import_directories + ) # Check new commands' compatibility with old ones. # Note, a command can be added to V1 at any time, it's ok if a @@ -1595,14 +2269,17 @@ def check_compatibility(old_idl_dir: str, new_idl_dir: str, old_import_directori old_commands: Dict[str, syntax.Command] = dict() for dirpath, _, filenames in os.walk(old_idl_dir): for old_filename in filenames: - if not old_filename.endswith('.idl') or old_filename in SKIPPED_FILES: + if not old_filename.endswith(".idl") or old_filename in SKIPPED_FILES: continue old_idl_file_path = os.path.join(dirpath, old_filename) with open(old_idl_file_path) as old_file: old_idl_file = parser.parse( - old_file, old_idl_file_path, - CompilerImportResolver(old_import_directories + [old_idl_dir]), False) + old_file, + old_idl_file_path, + CompilerImportResolver(old_import_directories + [old_idl_dir]), + False, + ) if old_idl_file.errors: old_idl_file.errors.dump_errors() # If parsing old IDL files fails, it might be because the parser has been @@ -1623,19 +2300,23 @@ def check_compatibility(old_idl_dir: str, new_idl_dir: str, old_import_directori if old_cmd.api_version != "1": # We're not ready to handle future API versions yet. ctxt.add_command_invalid_api_version_error( - old_cmd.command_name, old_cmd.api_version, old_idl_file_path) + old_cmd.command_name, old_cmd.api_version, old_idl_file_path + ) continue if old_cmd.command_name in old_commands: - ctxt.add_duplicate_command_name_error(old_cmd.command_name, old_idl_dir, - old_idl_file_path) + ctxt.add_duplicate_command_name_error( + old_cmd.command_name, old_idl_dir, old_idl_file_path + ) continue old_commands[old_cmd.command_name] = old_cmd if old_cmd.command_name not in new_commands: # Can't remove a command from V1 - ctxt.add_command_removed_error(old_cmd.command_name, old_idl_file_path) + ctxt.add_command_removed_error( + old_cmd.command_name, old_idl_file_path + ) continue new_cmd = new_commands[old_cmd.command_name] @@ -1643,41 +2324,74 @@ def check_compatibility(old_idl_dir: str, new_idl_dir: str, old_import_directori new_idl_file_path = new_command_file_path[old_cmd.command_name] if not old_cmd.strict and new_cmd.strict: - ctxt.add_command_strict_true_error(new_cmd.command_name, new_idl_file_path) + ctxt.add_command_strict_true_error( + new_cmd.command_name, new_idl_file_path + ) # Check compatibility of command's parameters. check_command_params_or_type_struct_fields( - ctxt, old_cmd, new_cmd, old_cmd.command_name, old_idl_file, new_idl_file, - old_idl_file_path, new_idl_file_path, is_command_parameter=True) + ctxt, + old_cmd, + new_cmd, + old_cmd.command_name, + old_idl_file, + new_idl_file, + old_idl_file_path, + new_idl_file_path, + is_command_parameter=True, + ) - check_namespace(ctxt, old_cmd, new_cmd, old_idl_file, new_idl_file, - old_idl_file_path, new_idl_file_path) + check_namespace( + ctxt, + old_cmd, + new_cmd, + old_idl_file, + new_idl_file, + old_idl_file_path, + new_idl_file_path, + ) old_reply = old_idl_file.spec.symbols.get_struct(old_cmd.reply_type) new_reply = new_idl_file.spec.symbols.get_struct(new_cmd.reply_type) - check_reply_fields(ctxt, old_reply, new_reply, old_cmd.command_name, - old_idl_file, new_idl_file, old_idl_file_path, - new_idl_file_path) + check_reply_fields( + ctxt, + old_reply, + new_reply, + old_cmd.command_name, + old_idl_file, + new_idl_file, + old_idl_file_path, + new_idl_file_path, + ) - check_security_access_checks(ctxt, old_cmd.access_check, new_cmd.access_check, - old_cmd, new_idl_file_path) + check_security_access_checks( + ctxt, + old_cmd.access_check, + new_cmd.access_check, + old_cmd, + new_idl_file_path, + ) ctxt.errors.dump_errors() return ctxt.errors -def get_generic_arguments(gen_args_file_path: str, - includes: List[str]) -> Tuple[Set[str], Set[str]]: +def get_generic_arguments( + gen_args_file_path: str, includes: List[str] +) -> Tuple[Set[str], Set[str]]: """Get arguments and reply fields from generic_argument.idl and check validity.""" arguments: Set[str] = set() reply_fields: Set[str] = set() with open(gen_args_file_path) as gen_args_file: - parsed_idl_file = parser.parse(gen_args_file, gen_args_file_path, - CompilerImportResolver(includes), False) + parsed_idl_file = parser.parse( + gen_args_file, gen_args_file_path, CompilerImportResolver(includes), False + ) if parsed_idl_file.errors: parsed_idl_file.errors.dump_errors() - raise ValueError(f"Cannot parse {gen_args_file_path} {parsed_idl_file.errors}") + raise ValueError( + f"Cannot parse {gen_args_file_path} {parsed_idl_file.errors}" + ) # The generic argument/reply field structs have been renamed a few times, so to # account for this when comparing against older releases, we try each set of names. @@ -1687,15 +2401,20 @@ def get_generic_arguments(gen_args_file_path: str, # 8.0.0rc4 ("GenericArgsAPIV1", "GenericReplyFieldsAPIV1"), # Before 8.0.0rc4 - ("generic_args_api_v1", "generic_reply_fields_api_v1") + ("generic_args_api_v1", "generic_reply_fields_api_v1"), ] for args_struct, reply_struct in struct_names: - generic_arguments = parsed_idl_file.spec.symbols.get_generic_argument_list(args_struct) + generic_arguments = parsed_idl_file.spec.symbols.get_generic_argument_list( + args_struct + ) if generic_arguments is None: continue else: - generic_reply_fields = parsed_idl_file.spec.symbols.get_generic_reply_field_list( - reply_struct) + generic_reply_fields = ( + parsed_idl_file.spec.symbols.get_generic_reply_field_list( + reply_struct + ) + ) break for argument in generic_arguments.fields: @@ -1710,18 +2429,28 @@ def get_generic_arguments(gen_args_file_path: str, def check_generic_arguments_compatibility( - old_gen_args_file_path: str, new_gen_args_file_path: str, old_includes: List[str], - new_includes: List[str]) -> IDLCompatibilityErrorCollection: + old_gen_args_file_path: str, + new_gen_args_file_path: str, + old_includes: List[str], + new_includes: List[str], +) -> IDLCompatibilityErrorCollection: """Check IDL compatibility between old and new generic_argument.idl files.""" # IDLCompatibilityContext takes in both 'old_idl_dir' and 'new_idl_dir', # but for generic_argument.idl, the parent directories aren't helpful for logging purposes. # Instead, we pass in "old generic_argument.idl" and "new generic_argument.idl" # to make error messages clearer. - ctxt = IDLCompatibilityContext("old generic_argument.idl", "new generic_argument.idl", - IDLCompatibilityErrorCollection()) + ctxt = IDLCompatibilityContext( + "old generic_argument.idl", + "new generic_argument.idl", + IDLCompatibilityErrorCollection(), + ) - old_arguments, old_reply_fields = get_generic_arguments(old_gen_args_file_path, old_includes) - new_arguments, new_reply_fields = get_generic_arguments(new_gen_args_file_path, new_includes) + old_arguments, old_reply_fields = get_generic_arguments( + old_gen_args_file_path, old_includes + ) + new_arguments, new_reply_fields = get_generic_arguments( + new_gen_args_file_path, new_includes + ) for old_argument in old_arguments: if old_argument not in new_arguments: @@ -1729,7 +2458,9 @@ def check_generic_arguments_compatibility( for old_reply_field in old_reply_fields: if old_reply_field not in new_reply_fields: - ctxt.add_generic_argument_removed_reply_field(old_reply_field, new_gen_args_file_path) + ctxt.add_generic_argument_removed_reply_field( + old_reply_field, new_gen_args_file_path + ) return ctxt.errors @@ -1737,19 +2468,40 @@ def check_generic_arguments_compatibility( def main(): """Run the script.""" arg_parser = argparse.ArgumentParser(description=__doc__) - arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging") - arg_parser.add_argument("--old-include", dest="old_include", type=str, action="append", - default=[], help="Directory to search for old IDL import files") - arg_parser.add_argument("--new-include", dest="new_include", type=str, action="append", - default=[], help="Directory to search for new IDL import files") - arg_parser.add_argument("old_idl_dir", metavar="OLD_IDL_DIR", - help="Directory where old IDL files are located") - arg_parser.add_argument("new_idl_dir", metavar="NEW_IDL_DIR", - help="Directory where new IDL files are located") + arg_parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose logging" + ) + arg_parser.add_argument( + "--old-include", + dest="old_include", + type=str, + action="append", + default=[], + help="Directory to search for old IDL import files", + ) + arg_parser.add_argument( + "--new-include", + dest="new_include", + type=str, + action="append", + default=[], + help="Directory to search for new IDL import files", + ) + arg_parser.add_argument( + "old_idl_dir", + metavar="OLD_IDL_DIR", + help="Directory where old IDL files are located", + ) + arg_parser.add_argument( + "new_idl_dir", + metavar="NEW_IDL_DIR", + help="Directory where new IDL files are located", + ) args = arg_parser.parse_args() - error_coll = check_compatibility(args.old_idl_dir, args.new_idl_dir, args.old_include, - args.new_include) + error_coll = check_compatibility( + args.old_idl_dir, args.new_idl_dir, args.old_include, args.new_include + ) if error_coll.has_errors(): sys.exit(1) @@ -1761,15 +2513,21 @@ def main(): old_basic_types_path = locate_basic_types_idl(args.old_idl_dir) new_basic_types_path = locate_basic_types_idl(args.new_idl_dir) - error_reply_coll = check_error_reply(old_basic_types_path, new_basic_types_path, - args.old_include, args.new_include) + error_reply_coll = check_error_reply( + old_basic_types_path, new_basic_types_path, args.old_include, args.new_include + ) if error_reply_coll.has_errors(): sys.exit(1) - old_generic_args_path = os.path.join(args.old_idl_dir, "mongo/idl/generic_argument.idl") - new_generic_args_path = os.path.join(args.new_idl_dir, "mongo/idl/generic_argument.idl") + old_generic_args_path = os.path.join( + args.old_idl_dir, "mongo/idl/generic_argument.idl" + ) + new_generic_args_path = os.path.join( + args.new_idl_dir, "mongo/idl/generic_argument.idl" + ) error_gen_args_coll = check_generic_arguments_compatibility( - old_generic_args_path, new_generic_args_path, args.old_include, args.new_include) + old_generic_args_path, new_generic_args_path, args.old_include, args.new_include + ) if error_gen_args_coll.has_errors(): sys.exit(1) diff --git a/buildscripts/idl/idl_compatibility_errors.py b/buildscripts/idl/idl_compatibility_errors.py index e536f4d6b64..31f5007e61d 100644 --- a/buildscripts/idl/idl_compatibility_errors.py +++ b/buildscripts/idl/idl_compatibility_errors.py @@ -151,8 +151,15 @@ class IDLCompatibilityError(object): - file - a string, the path to the IDL file where the error occurred. """ - def __init__(self, error_id: str, command_name: str, msg: str, old_idl_dir: str, - new_idl_dir: str, file: str) -> None: + def __init__( + self, + error_id: str, + command_name: str, + msg: str, + old_idl_dir: str, + new_idl_dir: str, + file: str, + ) -> None: """Construct an IDLCompatibility error.""" self.error_id = error_id self.command_name = command_name @@ -169,8 +176,13 @@ class IDLCompatibilityError(object): Error in compatibility_test_pass_new/file.idl: ID0001: 'command' has an invalid API version '2'. """ - msg = "Comparing %s and %s: Error in %s: %s: %s" % (self.old_idl_dir, self.new_idl_dir, - self.file, self.error_id, self.msg) + msg = "Comparing %s and %s: Error in %s: %s: %s" % ( + self.old_idl_dir, + self.new_idl_dir, + self.file, + self.error_id, + self.msg, + ) return msg @@ -181,11 +193,21 @@ class IDLCompatibilityErrorCollection(object): """Initialize IDLCompatibilityErrorCollection.""" self._errors: List[IDLCompatibilityError] = [] - def add(self, error_id: str, command_name: str, msg: str, old_idl_dir: str, new_idl_dir: str, - file: str) -> None: + def add( + self, + error_id: str, + command_name: str, + msg: str, + old_idl_dir: str, + new_idl_dir: str, + file: str, + ) -> None: """Add an error message with directory information.""" self._errors.append( - IDLCompatibilityError(error_id, command_name, msg, old_idl_dir, new_idl_dir, file)) + IDLCompatibilityError( + error_id, command_name, msg, old_idl_dir, new_idl_dir, file + ) + ) def has_errors(self) -> bool: """Have any errors been added to the collection?.""" @@ -209,8 +231,9 @@ class IDLCompatibilityErrorCollection(object): assert error is not None return error - def get_error_by_command_name_and_error_id(self, command_name: str, - error_id: str) -> IDLCompatibilityError: + def get_error_by_command_name_and_error_id( + self, command_name: str, error_id: str + ) -> IDLCompatibilityError: """Get the first error in the error collection from command_name with error_id.""" command_name_list = [a for a in self._errors if a.command_name == command_name] error_id_list = [a for a in command_name_list if a.error_id == error_id] @@ -218,7 +241,9 @@ class IDLCompatibilityErrorCollection(object): assert error is not None return error - def get_all_errors_by_command_name(self, command_name: str) -> List[IDLCompatibilityError]: + def get_all_errors_by_command_name( + self, command_name: str + ) -> List[IDLCompatibilityError]: """Get all the errors in the error collection with the command command_name.""" return [a for a in self._errors if a.command_name == command_name] @@ -229,7 +254,10 @@ class IDLCompatibilityErrorCollection(object): def dump_errors(self) -> None: """Print the list of errors.""" error_list = self.to_list() - print("Errors found while checking IDL compatibility: %s errors:" % (len(error_list))) + print( + "Errors found while checking IDL compatibility: %s errors:" + % (len(error_list)) + ) for error_msg in error_list: print("%s\n\n" % error_msg) print("------------------------------------------------") @@ -240,7 +268,7 @@ class IDLCompatibilityErrorCollection(object): def __str__(self) -> str: """Return a list of errors.""" - return ', '.join(self.to_list()) + return ", ".join(self.to_list()) class IDLCompatibilityContext(object): @@ -252,8 +280,12 @@ class IDLCompatibilityContext(object): - single class responsible for producing actual error messages. """ - def __init__(self, old_idl_dir: str, new_idl_dir: str, - errors: IDLCompatibilityErrorCollection) -> None: + def __init__( + self, + old_idl_dir: str, + new_idl_dir: str, + errors: IDLCompatibilityErrorCollection, + ) -> None: """Construct a new IDLCompatibilityContext.""" self.old_idl_dir = old_idl_dir self.new_idl_dir = new_idl_dir @@ -261,79 +293,128 @@ class IDLCompatibilityContext(object): def _add_error(self, error_id: str, command_name: str, msg: str, file: str) -> None: """Add an error with an error id and error message.""" - self.errors.add(error_id, command_name, msg, self.old_idl_dir, self.new_idl_dir, file) + self.errors.add( + error_id, command_name, msg, self.old_idl_dir, self.new_idl_dir, file + ) - def add_command_invalid_api_version_error(self, command_name: str, api_version: str, - file: str) -> None: + def add_command_invalid_api_version_error( + self, command_name: str, api_version: str, file: str + ) -> None: """Add an error about a command with an invalid api version.""" - self._add_error(ERROR_ID_COMMAND_INVALID_API_VERSION, command_name, - "'%s' has an invalid API version '%s'" % (command_name, api_version), file) + self._add_error( + ERROR_ID_COMMAND_INVALID_API_VERSION, + command_name, + "'%s' has an invalid API version '%s'" % (command_name, api_version), + file, + ) def add_command_removed_error(self, command_name: str, file: str) -> None: """Add an error about a command that was removed.""" self._add_error( - ERROR_ID_REMOVED_COMMAND, command_name, - "The command '%s' was present in the stable API but was removed." % (command_name), - file) + ERROR_ID_REMOVED_COMMAND, + command_name, + "The command '%s' was present in the stable API but was removed." + % (command_name), + file, + ) def add_command_strict_true_error(self, command_name: str, file: str) -> None: """Add an error about a command that changes from strict: false to strict: true.""" self._add_error( - ERROR_ID_COMMAND_STRICT_TRUE_ERROR, command_name, + ERROR_ID_COMMAND_STRICT_TRUE_ERROR, + command_name, "'%s' changes from strict: false in the old definition to strict: true in the new definition." - % (command_name), file) + % (command_name), + file, + ) - def add_duplicate_command_name_error(self, command_name: str, dir_name: str, file: str) -> None: + def add_duplicate_command_name_error( + self, command_name: str, dir_name: str, file: str + ) -> None: """Add an error about a duplicate command name within a directory.""" - self._add_error(ERROR_ID_DUPLICATE_COMMAND_NAME, command_name, - "'%s' has duplicate command: '%s'" % (dir_name, command_name), file) + self._add_error( + ERROR_ID_DUPLICATE_COMMAND_NAME, + command_name, + "'%s' has duplicate command: '%s'" % (dir_name, command_name), + file, + ) - def add_reply_field_not_subset_error(self, command_name: str, field_name: str, type_name: str, - file: str) -> None: + def add_reply_field_not_subset_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """Add an error about the reply field not being a subset.""" self._add_error( - ERROR_ID_REPLY_FIELD_NOT_SUBSET, command_name, + ERROR_ID_REPLY_FIELD_NOT_SUBSET, + command_name, "'%s' has a reply field or sub-field '%s' with type '%s' " "that is not a subset of the type of the older definition " - "of this reply field." % (command_name, field_name, type_name), file) + "of this reply field." % (command_name, field_name, type_name), + file, + ) - def add_command_or_param_type_invalid_error(self, command_name: str, file: str, - field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_type_invalid_error( + self, + command_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the command parameter or type being invalid.""" if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_TYPE_INVALID, command_name, - "The '%s' command has a field or sub-field '%s' that has an invalid type" % - (command_name, field_name), file) + ERROR_ID_COMMAND_PARAMETER_TYPE_INVALID, + command_name, + "The '%s' command has a field or sub-field '%s' that has an invalid type" + % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_INVALID, command_name, - "'%s' has an invalid type or has a sub-struct with an invalid type" % - (command_name), file) + ERROR_ID_COMMAND_TYPE_INVALID, + command_name, + "'%s' has an invalid type or has a sub-struct with an invalid type" + % (command_name), + file, + ) - def add_command_or_param_type_not_superset_error(self, command_name: str, type_name: str, - file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_type_not_superset_error( + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the command or parameter type not being a superset.""" if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET, command_name, + ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET, + command_name, "The new definition of command '%s' has field or sub-field '%s' with type '%s' " "that is not a superset of the " - "type of the existing definition of the field." % (command_name, field_name, - type_name), file) + "type of the existing definition of the field." + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_NOT_SUPERSET, command_name, + ERROR_ID_COMMAND_TYPE_NOT_SUPERSET, + command_name, "The new definition of command '%s' or its sub-struct has type '%s' that is not a " "superset of " - "the type of the existing definition of this command/struct." % (command_name, - type_name), file) + "the type of the existing definition of this command/struct." + % (command_name, type_name), + file, + ) - def add_command_or_param_type_contains_validator_error(self, command_name: str, field_name: str, - file: str, type_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_type_contains_validator_error( + self, + command_name: str, + field_name: str, + file: str, + type_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an error about a type containing a validator. @@ -342,42 +423,69 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_CONTAINS_VALIDATOR, command_name, + ERROR_ID_COMMAND_PARAMETER_CONTAINS_VALIDATOR, + command_name, "The new definition of the field or sub-field '%s' for the command '%s' contains a validator " - "while the old definition does not." % (field_name, command_name), file) + "while the old definition does not." % (field_name, command_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_CONTAINS_VALIDATOR, command_name, + ERROR_ID_COMMAND_TYPE_CONTAINS_VALIDATOR, + command_name, "The new definition of the command '%s' or its sub-struct has type '%s' with field '%s' that " "contains a validator when " - "the old definition did not." % (command_name, type_name, field_name), file) + "the old definition did not." % (command_name, type_name, field_name), + file, + ) def add_command_or_param_type_validators_not_equal_error( - self, command_name: str, field_name: str, file: str, type_name: Optional[str], - is_command_parameter: bool) -> None: + self, + command_name: str, + field_name: str, + file: str, + type_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """Add an error about the new and old command or parameter type validators not being equal.""" if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_VALIDATORS_NOT_EQUAL, command_name, + ERROR_ID_COMMAND_PARAMETER_VALIDATORS_NOT_EQUAL, + command_name, "Validator for field or sub-field '%s' in old definition of command '%s' is not equal " - "to the validator in the new definition of the field" % (field_name, command_name), - file) + "to the validator in the new definition of the field" + % (field_name, command_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_VALIDATORS_NOT_EQUAL, command_name, + ERROR_ID_COMMAND_TYPE_VALIDATORS_NOT_EQUAL, + command_name, "Validator for field '%s' in type '%s' in old definition of command '%s' or its " "sub-struct is not equal to the validator in the new defition of the command/struct." - % (field_name, type_name, command_name), file) + % (field_name, type_name, command_name), + file, + ) def add_missing_error_reply_struct_error(self, file: str) -> None: """Add an error about the file missing the ErrorReply struct.""" - self._add_error(ERROR_ID_MISSING_ERROR_REPLY_STRUCT, "n/a", - ("'%s' is missing the ErrorReply struct") % (file), file) + self._add_error( + ERROR_ID_MISSING_ERROR_REPLY_STRUCT, + "n/a", + ("'%s' is missing the ErrorReply struct") % (file), + file, + ) def add_new_command_or_param_type_bson_any_error( - self, command_name: str, old_type: str, new_type: str, file: str, - field_name: Optional[str], is_command_parameter: bool) -> None: + self, + command_name: str, + old_type: str, + new_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an error about BSON serialization type. @@ -387,22 +495,34 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, command_name, + ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + command_name, "The new definition of '%s' command has field or sub-field '%s' that has type '%s' " "that has a bson serialization type 'any', while the existing older definition had type %s" - " that did not have bson serialization type 'any'" % (command_name, field_name, - new_type, old_type), file) + " that did not have bson serialization type 'any'" + % (command_name, field_name, new_type, old_type), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, command_name, + ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + command_name, "The new definition of '%s' command or its sub-struct has type '%s' that " "has a bson serialization type 'any', while the existing older definition had type %s" - " that did not have bson serialization type 'any'" % (command_name, new_type, - old_type), file) + " that did not have bson serialization type 'any'" + % (command_name, new_type, old_type), + file, + ) def add_new_command_or_param_type_enum_or_struct_error( - self, command_name: str, new_type: str, old_type: str, file: str, - field_name: Optional[str], is_command_parameter: bool) -> None: + self, + command_name: str, + new_type: str, + old_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an error about a type that is an enum or struct. @@ -411,21 +531,33 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT, command_name, + ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT, + command_name, "The command '%s' has field or sub-field '%s' of type '%s' that is an enum or " "struct while the old definition of the field type is a non-enum or " - "non-struct of type '%s'." % (command_name, field_name, new_type, old_type), file) + "non-struct of type '%s'." + % (command_name, field_name, new_type, old_type), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_ENUM_OR_STRUCT, command_name, + ERROR_ID_NEW_COMMAND_TYPE_ENUM_OR_STRUCT, + command_name, "The command '%s' or its sub-struct has type '%s' that is an enum " "or struct while the old definition of the" - "type was a non-enum or struct of type '%s'." % (command_name, new_type, old_type), - file) + "type was a non-enum or struct of type '%s'." + % (command_name, new_type, old_type), + file, + ) def add_new_param_or_command_type_field_added_required_error( - self, command_name: str, field_name: str, file: str, type_name: str, - is_command_parameter: bool) -> None: + self, + command_name: str, + field_name: str, + file: str, + type_name: str, + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add a new added required parameter or command type field error. @@ -436,36 +568,58 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER, command_name, + ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER, + command_name, "New definition of field or sub-field '%s' for command '%s' is required when it should " - "be optional or have a default value." % (field_name, command_name), file) + "be optional or have a default value." % (field_name, command_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_REQUIRED, command_name, + ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_REQUIRED, + command_name, "The new definition of command '%s' or its sub-struct has type '%s' with an added and " "required type field '%s' that did not exist " "in the old definition of struct type. The field should be optional or have a default value." - % (command_name, type_name, field_name), file) + % (command_name, type_name, field_name), + file, + ) - def add_new_param_or_command_type_field_missing_error(self, command_name: str, field_name: str, - file: str, type_name: str, - is_command_parameter: bool) -> None: + def add_new_param_or_command_type_field_missing_error( + self, + command_name: str, + field_name: str, + file: str, + type_name: str, + is_command_parameter: bool, + ) -> None: """Add an error about a parameter or command type field that is missing in the new command.""" if is_command_parameter: self._add_error( - ERROR_ID_REMOVED_COMMAND_PARAMETER, command_name, + ERROR_ID_REMOVED_COMMAND_PARAMETER, + command_name, "Field or sub-field '%s' for command '%s' was removed from the new definition of the" - "command." % (field_name, command_name), file) + "command." % (field_name, command_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_MISSING, command_name, + ERROR_ID_NEW_COMMAND_TYPE_FIELD_MISSING, + command_name, "The command '%s' or its sub-struct has type '%s' that is missing a " - "field '%s' that exists in the old definition of the command/struct." % - (command_name, type_name, field_name), file) + "field '%s' that exists in the old definition of the command/struct." + % (command_name, type_name, field_name), + file, + ) - def add_new_param_or_command_type_field_required_error(self, command_name: str, field_name: str, - file: str, type_name: Optional[str], - is_command_parameter: bool) -> None: + def add_new_param_or_command_type_field_required_error( + self, + command_name: str, + field_name: str, + file: str, + type_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add a required parameter or command type field error. @@ -474,19 +628,30 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_REQUIRED, command_name, + ERROR_ID_COMMAND_PARAMETER_REQUIRED, + command_name, "'%s' has a required field or sub-field '%s' that was optional in the old " - "definition of the struct." % (command_name, field_name), file) + "definition of the struct." % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED, command_name, + ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED, + command_name, "'%s' or its sub-struct has type '%s' with a required type field '%s' " - "that was optional in the old definition of the struct type." % - (command_name, type_name, field_name), file) + "that was optional in the old definition of the struct type." + % (command_name, type_name, field_name), + file, + ) def add_new_param_or_command_type_field_stable_required_no_default_error( - self, struct_name: str, field_name: str, file: str, type_name: Optional[str], - is_command_parameter: bool) -> None: + self, + struct_name: str, + field_name: str, + file: str, + type_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add a stable required parameter or command type field error. @@ -497,22 +662,35 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_STABLE_REQUIRED_NO_DEFAULT, struct_name, + ERROR_ID_COMMAND_PARAMETER_STABLE_REQUIRED_NO_DEFAULT, + struct_name, "'%s' has a stable required field '%s' with no default that was unstable and not required in the" " old definition of the struct." - "The new definition of the field should be optional or have a default value" % - (struct_name, field_name), file) + "The new definition of the field should be optional or have a default value" + % (struct_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_STABLE_REQUIRED_NO_DEFAULT, struct_name, - ("'%s' has type '%s' with a stable and required type field '%s' with no default " - "that was unstable and not required in the old definition of the struct type." - "The new definition of the field should be optional or have a default value") % - (struct_name, type_name, field_name), file) + ERROR_ID_NEW_COMMAND_TYPE_FIELD_STABLE_REQUIRED_NO_DEFAULT, + struct_name, + ( + "'%s' has type '%s' with a stable and required type field '%s' with no default " + "that was unstable and not required in the old definition of the struct type." + "The new definition of the field should be optional or have a default value" + ) + % (struct_name, type_name, field_name), + file, + ) - def add_new_param_or_command_type_field_unstable_error(self, command_name: str, field_name: str, - file: str, type_name: Optional[str], - is_command_parameter: bool) -> None: + def add_new_param_or_command_type_field_unstable_error( + self, + command_name: str, + field_name: str, + file: str, + type_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an unstable parameter or command type field error. @@ -521,19 +699,31 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_COMMAND_PARAMETER_UNSTABLE, command_name, + ERROR_ID_COMMAND_PARAMETER_UNSTABLE, + command_name, "'%s' has an unstable field or sub-field '%s' that was stable in the old definition" - " of the struct." % (command_name, field_name), file) + " of the struct." % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE, command_name, + ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE, + command_name, "'%s' or its sub-struct has type '%s' with an unstable " "field '%s' that was stable in the old definition of the " - "struct type." % (command_name, type_name, field_name), file) + "struct type." % (command_name, type_name, field_name), + file, + ) def add_new_command_or_param_type_not_enum_error( - self, command_name: str, new_type: str, old_type: str, file: str, - field_name: Optional[str], is_command_parameter: bool) -> None: + self, + command_name: str, + new_type: str, + old_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an not enum parameter or command type field error. @@ -542,37 +732,60 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_ENUM, command_name, + ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_ENUM, + command_name, "The '%s' command has field or sub-field '%s' of type '%s' that is " - "not an enum while the old definition of the field type was an enum of type '%s'." % - (command_name, field_name, new_type, old_type), file) + "not an enum while the old definition of the field type was an enum of type '%s'." + % (command_name, field_name, new_type, old_type), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_NOT_ENUM, command_name, + ERROR_ID_NEW_COMMAND_TYPE_NOT_ENUM, + command_name, "'%s' or its sub-struct has type '%s' that is not an enum while the old definition of the type was an enum of type '%s'." - % (command_name, new_type, old_type), file) + % (command_name, new_type, old_type), + file, + ) def add_new_command_or_param_type_not_struct_error( - self, command_name: str, new_type: str, old_type: str, file: str, - field_name: Optional[str], is_command_parameter: bool) -> None: + self, + command_name: str, + new_type: str, + old_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the new command or parameter type not being a struct when the old one is.""" if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_STRUCT, command_name, + ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_STRUCT, + command_name, "The '%s' command has field or sub-field '%s' of type '%s' that is " "not a struct while the old definition of the " - "field type was a struct of type '%s'." % (command_name, field_name, new_type, - old_type), file) + "field type was a struct of type '%s'." + % (command_name, field_name, new_type, old_type), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT, command_name, + ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT, + command_name, "'%s' or its sub-struct has type '%s' that is not a " - "struct while the old definition of the type was a struct of type '%s'." % - (command_name, new_type, old_type), file) + "struct while the old definition of the type was a struct of type '%s'." + % (command_name, new_type, old_type), + file, + ) - def add_new_command_or_param_type_not_variant_type_error(self, command_name: str, new_type: str, - file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_new_command_or_param_type_not_variant_type_error( + self, + command_name: str, + new_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add an error about the new command or parameter type not being a variant type. @@ -582,20 +795,37 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: - self._add_error(ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_VARIANT, command_name, - ("The '%s' command has field or sub-field '%s' of type '%s' that is " - "not variant while the older definition of the field type is variant.") - % (command_name, field_name, new_type), file) + self._add_error( + ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_VARIANT, + command_name, + ( + "The '%s' command has field or sub-field '%s' of type '%s' that is " + "not variant while the older definition of the field type is variant." + ) + % (command_name, field_name, new_type), + file, + ) else: - self._add_error(ERROR_ID_NEW_COMMAND_TYPE_NOT_VARIANT, command_name, - ("'%s' or its sub-struct has type '%s' that is not " - "variant while the " - "older definition of the type is variant.") % (command_name, new_type), - file) + self._add_error( + ERROR_ID_NEW_COMMAND_TYPE_NOT_VARIANT, + command_name, + ( + "'%s' or its sub-struct has type '%s' that is not " + "variant while the " + "older definition of the type is variant." + ) + % (command_name, new_type), + file, + ) def add_new_command_or_param_variant_type_not_superset_error( - self, command_name: str, variant_type_name: str, file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + self, + command_name: str, + variant_type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add an error about the new variant types not being a superset. @@ -605,23 +835,39 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET, command_name, - ("The '%s' command has field or sub-field '%s' of variant types that is not " - "a superset of the older definition of the field variant types: " - "The type '%s' is in the old definition of the field types but not the new " - "definition of the field types.") % (command_name, field_name, variant_type_name), - file) + ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET, + command_name, + ( + "The '%s' command has field or sub-field '%s' of variant types that is not " + "a superset of the older definition of the field variant types: " + "The type '%s' is in the old definition of the field types but not the new " + "definition of the field types." + ) + % (command_name, field_name, variant_type_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET, command_name, - ("'%s' or its sub-struct has variant types that is not a supserset " - "of the older definition of the command variant types: The type '%s' " - "is in the old definition of the command types but not the new definition of " - "the command types.") % (command_name, variant_type_name), file) + ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET, + command_name, + ( + "'%s' or its sub-struct has variant types that is not a supserset " + "of the older definition of the command variant types: The type '%s' " + "is in the old definition of the command types but not the new definition of " + "the command types." + ) + % (command_name, variant_type_name), + file, + ) def add_new_command_or_param_chained_type_not_superset_error( - self, command_name: str, chained_type_name: str, file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + self, + command_name: str, + chained_type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add an error about the new chained types not being a superset. @@ -631,49 +877,77 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAMETER_CHAINED_TYPE_NOT_SUPERSET, command_name, - ("The '%s' command has field or sub-field '%s' of chained types that is not " - "a superset of the corresponding old definition of the field's chained types: " - "The type '%s' is in the old definition of the field types but not the new " - "definition of the field types.") % (command_name, field_name, chained_type_name), - file) + ERROR_ID_NEW_COMMAND_PARAMETER_CHAINED_TYPE_NOT_SUPERSET, + command_name, + ( + "The '%s' command has field or sub-field '%s' of chained types that is not " + "a superset of the corresponding old definition of the field's chained types: " + "The type '%s' is in the old definition of the field types but not the new " + "definition of the field types." + ) + % (command_name, field_name, chained_type_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_CHAINED_TYPE_NOT_SUPERSET, command_name, - ("'%s' or its sub-struct has chained types that is not a supserset " - "of the corresponding old definition of the command chained types: The type '%s' " - "is in the old definition of the command types but not the new definition of the " - "command types.") % (command_name, chained_type_name), file) + ERROR_ID_NEW_COMMAND_CHAINED_TYPE_NOT_SUPERSET, + command_name, + ( + "'%s' or its sub-struct has chained types that is not a supserset " + "of the corresponding old definition of the command chained types: The type '%s' " + "is in the old definition of the command types but not the new definition of the " + "command types." + ) + % (command_name, chained_type_name), + file, + ) - def add_new_namespace_incompatible_error(self, command_name: str, old_namespace: str, - new_namespace: str, file: str) -> None: + def add_new_namespace_incompatible_error( + self, command_name: str, old_namespace: str, new_namespace: str, file: str + ) -> None: """Add an error about the new namespace being incompatible with the old namespace.""" self._add_error( - ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE, command_name, + ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE, + command_name, "The new definition of '%s' has namespace '%s' that is incompatible with the old definition " - " of the command with namespace '%s'." % (command_name, new_namespace, old_namespace), - file) + " of the command with namespace '%s'." + % (command_name, new_namespace, old_namespace), + file, + ) - def add_new_reply_field_missing_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_new_reply_field_missing_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the new command missing a reply field that exists in the old command.""" self._add_error( - ERROR_ID_NEW_REPLY_FIELD_MISSING, command_name, + ERROR_ID_NEW_REPLY_FIELD_MISSING, + command_name, "'%s' is missing a reply field or sub-field '%s' that exists in the old definition of " - "the command." % (command_name, field_name), file) + "the command." % (command_name, field_name), + file, + ) - def add_new_reply_field_optional_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_new_reply_field_optional_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the new command reply field being optional when the old reply field is not.""" self._add_error( - ERROR_ID_NEW_REPLY_FIELD_OPTIONAL, command_name, + ERROR_ID_NEW_REPLY_FIELD_OPTIONAL, + command_name, "'%s' has an optional reply field or sub-field '%s' " - "that was non-optional in the old definition of the command." % (command_name, - field_name), file) + "that was non-optional in the old definition of the command." + % (command_name, field_name), + file, + ) - def add_new_reply_field_bson_any_error(self, command_name: str, field_name: str, - old_field_type: str, new_field_type: str, - file: str) -> None: + def add_new_reply_field_bson_any_error( + self, + command_name: str, + field_name: str, + old_field_type: str, + new_field_type: str, + file: str, + ) -> None: """ Add an error about the new reply field type's 'any' bson serialization type. @@ -681,14 +955,20 @@ class IDLCompatibilityContext(object): 'any' when it was not 'any' in the old type or it is not explicitly allowed. """ self._add_error( - ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, command_name, - ("The new definition of '%s' has a reply field or sub-field '%s' of type '%s' " - "that has a bson serialization type 'any', while the old definition" - " had type '%s' that did not have bson serialization type 'any'") % - (command_name, field_name, new_field_type, old_field_type), file) + ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + command_name, + ( + "The new definition of '%s' has a reply field or sub-field '%s' of type '%s' " + "that has a bson serialization type 'any', while the old definition" + " had type '%s' that did not have bson serialization type 'any'" + ) + % (command_name, field_name, new_field_type, old_field_type), + file, + ) - def add_old_reply_field_bson_any_not_allowed_error(self, command_name: str, field_name: str, - type_name: str, file: str) -> None: + def add_old_reply_field_bson_any_not_allowed_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """ Add an error about the old reply field bson serialization type being 'any'. @@ -696,13 +976,20 @@ class IDLCompatibilityContext(object): type 'any' when it is not explicitly allowed. """ self._add_error( - ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, command_name, - ("The old definition of '%s' has a reply field or sub-field '%s' of type '%s' " - "that has a bson serialization type 'any' when it " - "is not explicitly allowed.") % (command_name, field_name, type_name), file) + ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The old definition of '%s' has a reply field or sub-field '%s' of type '%s' " + "that has a bson serialization type 'any' when it " + "is not explicitly allowed." + ) + % (command_name, field_name, type_name), + file, + ) - def add_new_reply_field_bson_any_not_allowed_error(self, command_name: str, field_name: str, - type_name: str, file: str) -> None: + def add_new_reply_field_bson_any_not_allowed_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """ Add an error about the new reply field bson serialization type being 'any'. @@ -710,92 +997,162 @@ class IDLCompatibilityContext(object): type 'any' when it is not explicitly allowed. """ self._add_error( - ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, command_name, - ("The new definition of '%s' has a reply field or sub-field '%s' of type '%s' " - "that has a bson serialization type 'any' when it " - "is not explicitly allowed.") % (command_name, field_name, type_name), file) + ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The new definition of '%s' has a reply field or sub-field '%s' of type '%s' " + "that has a bson serialization type 'any' when it " + "is not explicitly allowed." + ) + % (command_name, field_name, type_name), + file, + ) - def add_reply_field_cpp_type_not_equal_error(self, command_name: str, field_name: str, - type_name: str, file: str) -> None: + def add_reply_field_cpp_type_not_equal_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """Add an error about the old and new reply field cpp_type not being equal.""" - self._add_error(ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' that has cpp_type " - "that is not equal in the old and new definitions of this command.") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' that has cpp_type " + "that is not equal in the old and new definitions of this command." + ) + % (command_name, field_name, type_name), + file, + ) - def add_reply_field_serializer_not_equal_error(self, command_name: str, field_name: str, - type_name: str, file: str) -> None: + def add_reply_field_serializer_not_equal_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """Add an error about the old and new reply field serializer not being equal.""" self._add_error( - ERROR_ID_REPLY_FIELD_SERIALIZER_NOT_EQUAL, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' that has " - "serializer that is not equal in the old and new definitions of this command.") % - (command_name, field_name, type_name), file) + ERROR_ID_REPLY_FIELD_SERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' that has " + "serializer that is not equal in the old and new definitions of this command." + ) + % (command_name, field_name, type_name), + file, + ) - def add_reply_field_deserializer_not_equal_error(self, command_name: str, field_name: str, - type_name: str, file: str) -> None: + def add_reply_field_deserializer_not_equal_error( + self, command_name: str, field_name: str, type_name: str, file: str + ) -> None: """Add an error about the old and new reply field deserializer not being equal.""" self._add_error( - ERROR_ID_REPLY_FIELD_DESERIALIZER_NOT_EQUAL, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' that has " - "deserializer that is not equal in the old and new definitions of this command.") % - (command_name, field_name, type_name), file) + ERROR_ID_REPLY_FIELD_DESERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' that has " + "deserializer that is not equal in the old and new definitions of this command." + ) + % (command_name, field_name, type_name), + file, + ) - def add_new_reply_field_type_not_enum_error(self, command_name: str, field_name: str, - new_field_type: str, old_field_type: str, - file: str) -> None: + def add_new_reply_field_type_not_enum_error( + self, + command_name: str, + field_name: str, + new_field_type: str, + old_field_type: str, + file: str, + ) -> None: """Add an error about the new reply field type not being an enum when the old one is.""" - self._add_error(ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_ENUM, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' " - "that is not an enum while the corresponding " - "old definition of the reply field was an enum of type '%s'.") % - (command_name, field_name, new_field_type, old_field_type), file) + self._add_error( + ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_ENUM, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' " + "that is not an enum while the corresponding " + "old definition of the reply field was an enum of type '%s'." + ) + % (command_name, field_name, new_field_type, old_field_type), + file, + ) - def add_new_reply_field_type_not_struct_error(self, command_name: str, field_name: str, - new_field_type: str, old_field_type: str, - file: str) -> None: + def add_new_reply_field_type_not_struct_error( + self, + command_name: str, + field_name: str, + new_field_type: str, + old_field_type: str, + file: str, + ) -> None: """Add an error about the new reply field type not being a struct when the old one is.""" - self._add_error(ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_STRUCT, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' " - "that is not a struct while the corresponding " - "old definition of the reply field was a struct of type '%s'.") % - (command_name, field_name, new_field_type, old_field_type), file) + self._add_error( + ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_STRUCT, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' " + "that is not a struct while the corresponding " + "old definition of the reply field was a struct of type '%s'." + ) + % (command_name, field_name, new_field_type, old_field_type), + file, + ) - def add_new_reply_field_type_enum_or_struct_error(self, command_name: str, field_name: str, - new_field_type: str, old_field_type: str, - file: str) -> None: + def add_new_reply_field_type_enum_or_struct_error( + self, + command_name: str, + field_name: str, + new_field_type: str, + old_field_type: str, + file: str, + ) -> None: """ Add an error about a reply field type being incompatible with the old field type. Add an error when the new reply field type is an enum or struct and the old reply field is a non-enum or struct type. """ - self._add_error(ERROR_ID_NEW_REPLY_FIELD_TYPE_ENUM_OR_STRUCT, command_name, - ("'%s' has a reply field or sub-field '%s' of type '%s' that is an " - "enum or struct while the corresponding " - "old definition of the reply field was a non-enum or struct of type '%s'.") - % (command_name, field_name, new_field_type, old_field_type), file) + self._add_error( + ERROR_ID_NEW_REPLY_FIELD_TYPE_ENUM_OR_STRUCT, + command_name, + ( + "'%s' has a reply field or sub-field '%s' of type '%s' that is an " + "enum or struct while the corresponding " + "old definition of the reply field was a non-enum or struct of type '%s'." + ) + % (command_name, field_name, new_field_type, old_field_type), + file, + ) - def add_new_reply_field_unstable_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_new_reply_field_unstable_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the new command reply field being unstable when the old one is stable.""" self._add_error( - ERROR_ID_NEW_REPLY_FIELD_UNSTABLE, command_name, + ERROR_ID_NEW_REPLY_FIELD_UNSTABLE, + command_name, "'%s' has an unstable reply field or sub-field '%s' " - "that was stable in the old definition of the command." % (command_name, field_name), - file) + "that was stable in the old definition of the command." + % (command_name, field_name), + file, + ) - def add_new_reply_field_variant_type_error(self, command_name: str, field_name: str, - old_field_type: str, file: str) -> None: + def add_new_reply_field_variant_type_error( + self, command_name: str, field_name: str, old_field_type: str, file: str + ) -> None: """Add an error about the new reply field type being variant when the old one is not.""" - self._add_error(ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE, command_name, - ("'%s' has a reply field or sub-field '%s' that has a variant " - "type while the corresponding " - "old definition of the reply field type '%s' is not variant.") % - (command_name, field_name, old_field_type), file) + self._add_error( + ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE, + command_name, + ( + "'%s' has a reply field or sub-field '%s' that has a variant " + "type while the corresponding " + "old definition of the reply field type '%s' is not variant." + ) + % (command_name, field_name, old_field_type), + file, + ) def add_new_reply_field_variant_type_not_subset_error( - self, command_name: str, field_name: str, variant_type_name: str, file: str) -> None: + self, command_name: str, field_name: str, variant_type_name: str, file: str + ) -> None: """ Add an error about the reply field variant types not being a subset. @@ -803,14 +1160,21 @@ class IDLCompatibilityContext(object): not being a subset of the old variant types. """ self._add_error( - ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET, command_name, - ("'%s' has a reply field or sub-field '%s' with variant types that is " - "not a subset of the corresponding " - "old definition of the reply field types: The type '%s' is not in the old definition " - "of the reply field types.") % (command_name, field_name, variant_type_name), file) + ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET, + command_name, + ( + "'%s' has a reply field or sub-field '%s' with variant types that is " + "not a subset of the corresponding " + "old definition of the reply field types: The type '%s' is not in the old definition " + "of the reply field types." + ) + % (command_name, field_name, variant_type_name), + file, + ) - def add_new_reply_chained_type_not_subset_error(self, command_name: str, reply_name: str, - chained_type_name: str, file: str) -> None: + def add_new_reply_chained_type_not_subset_error( + self, command_name: str, reply_name: str, chained_type_name: str, file: str + ) -> None: """ Add an error about the reply chained types not being a subset. @@ -818,16 +1182,27 @@ class IDLCompatibilityContext(object): not being a subset of the old chained types. """ self._add_error( - ERROR_ID_NEW_REPLY_CHAINED_TYPE_NOT_SUBSET, command_name, - ("'%s' has a reply '%s' with chained types that is " - "not a subset of the corresponding " - "old definition of the reply chained types: The type '%s' is not in the old " - "definition of the reply chained types.") % (command_name, reply_name, - chained_type_name), file) + ERROR_ID_NEW_REPLY_CHAINED_TYPE_NOT_SUBSET, + command_name, + ( + "'%s' has a reply '%s' with chained types that is " + "not a subset of the corresponding " + "old definition of the reply chained types: The type '%s' is not in the old " + "definition of the reply chained types." + ) + % (command_name, reply_name, chained_type_name), + file, + ) def add_old_command_or_param_type_bson_any_error( - self, command_name: str, old_type: str, new_type: str, file: str, - field_name: Optional[str], is_command_parameter: bool) -> None: + self, + command_name: str, + old_type: str, + new_type: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """ Add an error about BSON serialization type. @@ -837,21 +1212,35 @@ class IDLCompatibilityContext(object): """ if is_command_parameter: self._add_error( - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, command_name, + ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + command_name, "The old definition of the '%s' command has field or sub-field '%s' that has type '%s' " "that has a bson serialization type 'any', while the new definition of the command" - " has type '%s' that does not have bson serialization type 'any'" % - (command_name, field_name, old_type, new_type), file) + " has type '%s' that does not have bson serialization type 'any'" + % (command_name, field_name, old_type, new_type), + file, + ) else: - self._add_error(ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, command_name, ( - "The old definition of the command '%s' or its sub-struct has type '%s' that has a " - "bson serialization type 'any', while the new definition has type '%s' " - "that does not have bson serialization type 'any'") % (command_name, old_type, - new_type), file) + self._add_error( + ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + command_name, + ( + "The old definition of the command '%s' or its sub-struct has type '%s' that has a " + "bson serialization type 'any', while the new definition has type '%s' " + "that does not have bson serialization type 'any'" + ) + % (command_name, old_type, new_type), + file, + ) def add_old_command_or_param_type_bson_any_not_allowed_error( - self, command_name: str, type_name: str, file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add an error about the old command or param type bson serialization type being 'any'. @@ -860,22 +1249,37 @@ class IDLCompatibilityContext(object): being of type 'any' when it is not explicitly allowed. """ if is_command_parameter: - self._add_error(ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, - command_name, - ("The old definition of '%s' has a field or sub-field '%s' of type " - "'%s' that has a bson " - "serialization type 'any' when it is not explicitly allowed.") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The old definition of '%s' has a field or sub-field '%s' of type " + "'%s' that has a bson " + "serialization type 'any' when it is not explicitly allowed." + ) + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, command_name, - ("The old definition of '%s' or its sub-struct has a type '%s' that has a bson " - "serialization type 'any' when it is not explicitly allowed.") % (command_name, - type_name), file) + ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The old definition of '%s' or its sub-struct has a type '%s' that has a bson " + "serialization type 'any' when it is not explicitly allowed." + ) + % (command_name, type_name), + file, + ) def add_new_command_or_param_type_bson_any_not_allowed_error( - self, command_name: str, type_name: str, file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: # pylint: disable=invalid-name """ Add an error about the new command or param type bson serialization type being 'any'. @@ -884,70 +1288,133 @@ class IDLCompatibilityContext(object): being of type 'any' when it is not explicitly allowed. """ if is_command_parameter: - self._add_error(ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, - command_name, - ("The new definition of '%s' has a field or sub-field '%s' of type " - "'%s' that has a bson " - "serialization type 'any' when it is not explicitly allowed.") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The new definition of '%s' has a field or sub-field '%s' of type " + "'%s' that has a bson " + "serialization type 'any' when it is not explicitly allowed." + ) + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, command_name, - ("The new definition of '%s' or its sub-struct has a type '%s' that has a bson " - "serialization type 'any' when it is not explicitly allowed.") % (command_name, - type_name), file) + ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED, + command_name, + ( + "The new definition of '%s' or its sub-struct has a type '%s' that has a bson " + "serialization type 'any' when it is not explicitly allowed." + ) + % (command_name, type_name), + file, + ) - def add_command_or_param_cpp_type_not_equal_error(self, command_name: str, type_name: str, - file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_cpp_type_not_equal_error( + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the old and new command or param cpp_type not being equal.""" if is_command_parameter: - self._add_error(ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL, command_name, - ("'%s' has field or sub-field '%s' of type '%s' that has " - "cpp_type that is not equal in the old and new definitions") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL, + command_name, + ( + "'%s' has field or sub-field '%s' of type '%s' that has " + "cpp_type that is not equal in the old and new definitions" + ) + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL, command_name, - ("'%s' or its sub-struct has command type '%s' that has cpp_type " - "that is not equal in the old and new definitions") % (command_name, type_name), - file) + ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL, + command_name, + ( + "'%s' or its sub-struct has command type '%s' that has cpp_type " + "that is not equal in the old and new definitions" + ) + % (command_name, type_name), + file, + ) - def add_command_or_param_serializer_not_equal_error(self, command_name: str, type_name: str, - file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_serializer_not_equal_error( + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the old and new command or param serializer not being equal.""" if is_command_parameter: - self._add_error(ERROR_ID_COMMAND_PARAMETER_SERIALIZER_NOT_EQUAL, command_name, - ("'%s' has field or sub-field '%s' of type '%s' that has " - "serializer that is not equal in the old and new definitions") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_COMMAND_PARAMETER_SERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' has field or sub-field '%s' of type '%s' that has " + "serializer that is not equal in the old and new definitions" + ) + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_SERIALIZER_NOT_EQUAL, command_name, - ("'%s' or its sub-struct has command type '%s' that has serializer " - "that is not equal in the old and new definitions") % (command_name, type_name), - file) + ERROR_ID_COMMAND_SERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' or its sub-struct has command type '%s' that has serializer " + "that is not equal in the old and new definitions" + ) + % (command_name, type_name), + file, + ) - def add_command_or_param_deserializer_not_equal_error(self, command_name: str, type_name: str, - file: str, field_name: Optional[str], - is_command_parameter: bool) -> None: + def add_command_or_param_deserializer_not_equal_error( + self, + command_name: str, + type_name: str, + file: str, + field_name: Optional[str], + is_command_parameter: bool, + ) -> None: """Add an error about the old and new command or param deserializer not being equal.""" if is_command_parameter: - self._add_error(ERROR_ID_COMMAND_PARAMETER_DESERIALIZER_NOT_EQUAL, command_name, - ("'%s' has field or sub-field '%s' of type '%s' that has " - "deserializer that is not equal in the old and new definitions") % - (command_name, field_name, type_name), file) + self._add_error( + ERROR_ID_COMMAND_PARAMETER_DESERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' has field or sub-field '%s' of type '%s' that has " + "deserializer that is not equal in the old and new definitions" + ) + % (command_name, field_name, type_name), + file, + ) else: self._add_error( - ERROR_ID_COMMAND_DESERIALIZER_NOT_EQUAL, command_name, - ("'%s' or its sub-struct has command type '%s' that has deserializer " - "that is not equal in the old and new definitions") % (command_name, type_name), - file) + ERROR_ID_COMMAND_DESERIALIZER_NOT_EQUAL, + command_name, + ( + "'%s' or its sub-struct has command type '%s' that has deserializer " + "that is not equal in the old and new definitions" + ) + % (command_name, type_name), + file, + ) - def add_old_reply_field_bson_any_error(self, command_name: str, field_name: str, - old_field_type: str, new_field_type: str, - file: str) -> None: + def add_old_reply_field_bson_any_error( + self, + command_name: str, + field_name: str, + old_field_type: str, + new_field_type: str, + file: str, + ) -> None: """ Add an about the old reply field type's 'any' bson serialization type. @@ -955,62 +1422,117 @@ class IDLCompatibilityContext(object): 'any' when the new type is non-any or when it is not explicitly allowed. """ self._add_error( - ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, command_name, - ("The old definition of '%s' has a reply field or sub-field '%s' of type '%s' " - "that has a bson serialization type 'any', while the new definition has" - " type '%s' that does not have bson serialization type 'any'") % - (command_name, field_name, old_field_type, new_field_type), file) + ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + command_name, + ( + "The old definition of '%s' has a reply field or sub-field '%s' of type '%s' " + "that has a bson serialization type 'any', while the new definition has" + " type '%s' that does not have bson serialization type 'any'" + ) + % (command_name, field_name, old_field_type, new_field_type), + file, + ) - def add_reply_field_contains_validator_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_reply_field_contains_validator_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the reply field containing a validator.""" self._add_error( - ERROR_ID_REPLY_FIELD_CONTAINS_VALIDATOR, command_name, - ("The new definition of the command '%s' has a reply field or sub-field '%s' " - "that contains a validator while the old definition does not") % (command_name, - field_name), file) + ERROR_ID_REPLY_FIELD_CONTAINS_VALIDATOR, + command_name, + ( + "The new definition of the command '%s' has a reply field or sub-field '%s' " + "that contains a validator while the old definition does not" + ) + % (command_name, field_name), + file, + ) - def add_reply_field_validators_not_equal_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_reply_field_validators_not_equal_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the reply field containing a validator.""" self._add_error( - ERROR_ID_REPLY_FIELD_VALIDATORS_NOT_EQUAL, command_name, - ("Validator for reply field or sub-field '%s' in old definition of command '%s' " - "is not equal to the validator in the new definition of the reply field") % - (command_name, field_name), file) + ERROR_ID_REPLY_FIELD_VALIDATORS_NOT_EQUAL, + command_name, + ( + "Validator for reply field or sub-field '%s' in old definition of command '%s' " + "is not equal to the validator in the new definition of the reply field" + ) + % (command_name, field_name), + file, + ) - def add_reply_field_type_invalid_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_reply_field_type_invalid_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error about the reply field or sub-field type being invalid.""" - self._add_error(ERROR_ID_REPLY_FIELD_TYPE_INVALID, command_name, - ("'%s' has a reply field or sub-field '%s' that has an invalid type") % - (command_name, field_name), file) + self._add_error( + ERROR_ID_REPLY_FIELD_TYPE_INVALID, + command_name, + ("'%s' has a reply field or sub-field '%s' that has an invalid type") + % (command_name, field_name), + file, + ) - def add_check_not_equal_error(self, command_name: str, old_check: str, new_check: str, - file: str) -> None: + def add_check_not_equal_error( + self, command_name: str, old_check: str, new_check: str, file: str + ) -> None: """Add an error about the command access_check check not being equal.""" - self._add_error(ERROR_ID_CHECK_NOT_EQUAL, command_name, ( - "The new definition of '%s' has a check '%s' that is not equal to the check '%s' in the old definition" - " of the command.") % (command_name, new_check, old_check), file) + self._add_error( + ERROR_ID_CHECK_NOT_EQUAL, + command_name, + ( + "The new definition of '%s' has a check '%s' that is not equal to the check '%s' in the old definition" + " of the command." + ) + % (command_name, new_check, old_check), + file, + ) - def add_resource_pattern_not_equal_error(self, command_name: str, old_resource_pattern: str, - new_resource_pattern: str, file: str) -> None: + def add_resource_pattern_not_equal_error( + self, + command_name: str, + old_resource_pattern: str, + new_resource_pattern: str, + file: str, + ) -> None: """Add an error about the command access_check resource_pattern not being equal.""" self._add_error( - ERROR_ID_RESOURCE_PATTERN_NOT_EQUAL, command_name, - ("The new definition of '%s' has a resource pattern '%s' that is not equal to the " - "the resource pattern '%s' in the old definition of the command.") % - (command_name, new_resource_pattern, old_resource_pattern), file) + ERROR_ID_RESOURCE_PATTERN_NOT_EQUAL, + command_name, + ( + "The new definition of '%s' has a resource pattern '%s' that is not equal to the " + "the resource pattern '%s' in the old definition of the command." + ) + % (command_name, new_resource_pattern, old_resource_pattern), + file, + ) - def add_new_action_types_not_subset_error(self, command_name: str, file: str) -> None: + def add_new_action_types_not_subset_error( + self, command_name: str, file: str + ) -> None: """Add an error about the new access_check action types not being a subset of the old ones.""" self._add_error( - ERROR_ID_NEW_ACTION_TYPES_NOT_SUBSET, command_name, - ("The new definition of '%s' has action types that are not a subset of the action" - " types in the old definition") % (command_name), file) + ERROR_ID_NEW_ACTION_TYPES_NOT_SUBSET, + command_name, + ( + "The new definition of '%s' has action types that are not a subset of the action" + " types in the old definition" + ) + % (command_name), + file, + ) - def add_type_not_array_error(self, symbol: str, command_name: str, symbol_name: str, - new_type: str, old_type: str, file: str) -> None: + def add_type_not_array_error( + self, + symbol: str, + command_name: str, + symbol_name: str, + new_type: str, + old_type: str, + file: str, + ) -> None: """ Add an error about type not being an ArrayType when it should be. @@ -1018,150 +1540,281 @@ class IDLCompatibilityContext(object): command parameter type). """ self._add_error( - ERROR_ID_TYPE_NOT_ARRAY, command_name, + ERROR_ID_TYPE_NOT_ARRAY, + command_name, "The new definition of command '%s' has %s: '%s' with type '%s' that is not an ArrayType while the old definition had type '%s' which was an ArrayType." - % (command_name, symbol, symbol_name, new_type, old_type), file) + % (command_name, symbol, symbol_name, new_type, old_type), + file, + ) - def add_access_check_type_not_equal_error(self, command_name: str, old_type: str, new_type: str, - file: str) -> None: + def add_access_check_type_not_equal_error( + self, command_name: str, old_type: str, new_type: str, file: str + ) -> None: """Add an error about the command access_check types not being equal.""" - self._add_error(ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL, command_name, ( - "'%s' has access_check of type %s in the old definition that is not equal to the access_check of type '%s'" - "in the new definition.") % (command_name, old_type, new_type), file) + self._add_error( + ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL, + command_name, + ( + "'%s' has access_check of type %s in the old definition that is not equal to the access_check of type '%s'" + "in the new definition." + ) + % (command_name, old_type, new_type), + file, + ) - def add_new_complex_checks_not_subset_error(self, command_name: str, file: str) -> None: + def add_new_complex_checks_not_subset_error( + self, command_name: str, file: str + ) -> None: """Add an error about the complex access_check checks not being a subset of the old ones.""" self._add_error( - ERROR_ID_NEW_COMPLEX_CHECKS_NOT_SUBSET, command_name, - ("The new definition of '%s' has complex access_checks checks that are not a subset " - " of the old definition's complex" - " access_check checks.") % (command_name), file) + ERROR_ID_NEW_COMPLEX_CHECKS_NOT_SUBSET, + command_name, + ( + "The new definition of '%s' has complex access_checks checks that are not a subset " + " of the old definition's complex" + " access_check checks." + ) + % (command_name), + file, + ) - def add_new_complex_privileges_not_subset_error(self, command_name: str, file: str) -> None: + def add_new_complex_privileges_not_subset_error( + self, command_name: str, file: str + ) -> None: """Add an error about the complex access_check privileges not being a subset of the old ones.""" self._add_error( - ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET, command_name, - ("'%s' has complex access_checks privileges that have changed the resource_pattern" - " or changed/added an action_type in the new definition of the command.") % - (command_name), file) + ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET, + command_name, + ( + "'%s' has complex access_checks privileges that have changed the resource_pattern" + " or changed/added an action_type in the new definition of the command." + ) + % (command_name), + file, + ) - def add_new_additional_complex_access_check_error(self, command_name: str, file: str) -> None: + def add_new_additional_complex_access_check_error( + self, command_name: str, file: str + ) -> None: """Add an error about an additional complex access_check being added.""" self._add_error( - ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK, command_name, - ("'%s' has additional complex access_checks in the new definition of the command that " - "are not in the old definition of the command") % (command_name), file) + ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK, + command_name, + ( + "'%s' has additional complex access_checks in the new definition of the command that " + "are not in the old definition of the command" + ) + % (command_name), + file, + ) - def add_removed_access_check_field_error(self, command_name: str, file: str) -> None: + def add_removed_access_check_field_error( + self, command_name: str, file: str + ) -> None: """Add an error the new command removing the access_check field.""" self._add_error( - ERROR_ID_REMOVED_ACCESS_CHECK_FIELD, command_name, - ("'%s' has removed the access_check field in the new definition of the command when it " - "exists in the old definition of the command") % (command_name), file) + ERROR_ID_REMOVED_ACCESS_CHECK_FIELD, + command_name, + ( + "'%s' has removed the access_check field in the new definition of the command when it " + "exists in the old definition of the command" + ) + % (command_name), + file, + ) def add_added_access_check_field_error(self, command_name: str, file: str) -> None: """Add an error the new command adding the access_check field when the api_version is '1'.""" - self._add_error(ERROR_ID_ADDED_ACCESS_CHECK_FIELD, command_name, ( - "The new definition of '%s' has added the access_check field when it did not exist in the " - "old definition of the command, while the api_version is '1'") % (command_name), file) + self._add_error( + ERROR_ID_ADDED_ACCESS_CHECK_FIELD, + command_name, + ( + "The new definition of '%s' has added the access_check field when it did not exist in the " + "old definition of the command, while the api_version is '1'" + ) + % (command_name), + file, + ) def add_generic_argument_removed(self, field_name: str, file: str) -> None: """Add an error about a generic argument that was removed.""" - self._add_error(ERROR_ID_GENERIC_ARGUMENT_REMOVED, field_name, - ("The generic argument '%s' was removed from the new definition of the " - "generic_argument.idl file") % (field_name), file) + self._add_error( + ERROR_ID_GENERIC_ARGUMENT_REMOVED, + field_name, + ( + "The generic argument '%s' was removed from the new definition of the " + "generic_argument.idl file" + ) + % (field_name), + file, + ) - def add_generic_argument_removed_reply_field(self, field_name: str, file: str) -> None: + def add_generic_argument_removed_reply_field( + self, field_name: str, file: str + ) -> None: """Add an error about a generic reply field that was removed.""" - self._add_error(ERROR_ID_GENERIC_ARGUMENT_REMOVED_REPLY_FIELD, field_name, - ("The generic reply field '%s' was removed from the new definition of the " - "generic_argument.idl file") % (field_name), file) + self._add_error( + ERROR_ID_GENERIC_ARGUMENT_REMOVED_REPLY_FIELD, + field_name, + ( + "The generic reply field '%s' was removed from the new definition of the " + "generic_argument.idl file" + ) + % (field_name), + file, + ) - def add_new_reply_field_requires_stability_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_new_reply_field_requires_stability_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error that a new reply field requires the 'stability' field.""" self._add_error( - ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY, command_name, - ("The new definition of '%s' has reply field '%s' that requires specifying a value " - "for the 'stability' field") % (command_name, field_name), file) + ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY, + command_name, + ( + "The new definition of '%s' has reply field '%s' that requires specifying a value " + "for the 'stability' field" + ) + % (command_name, field_name), + file, + ) def add_new_param_or_command_type_field_requires_stability_error( - self, command_name: str, field_name: str, file: str, - is_command_parameter: bool) -> None: + self, command_name: str, field_name: str, file: str, is_command_parameter: bool + ) -> None: # pylint: disable=invalid-name """Add an error that a new param or command type field requires the 'stability' field.""" if is_command_parameter: self._add_error( - ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY, command_name, - ("The new definition of '%s' has parameter '%s' that requires specifying a value " - "for the 'stability' field") % (command_name, field_name), file) + ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY, + command_name, + ( + "The new definition of '%s' has parameter '%s' that requires specifying a value " + "for the 'stability' field" + ) + % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY, command_name, - ("The new definition of '%s' has command type field '%s' that requires specifying " - "a value for the 'stability' field") % (command_name, field_name), file) + ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY, + command_name, + ( + "The new definition of '%s' has command type field '%s' that requires specifying " + "a value for the 'stability' field" + ) + % (command_name, field_name), + file, + ) - def add_unstable_reply_field_changed_to_stable_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_unstable_reply_field_changed_to_stable_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error that a reply field may not change from unstable to stable.""" - self._add_error(ERROR_ID_UNSTABLE_REPLY_FIELD_CHANGED_TO_STABLE, command_name, ( - "The command '%s' has reply field '%s' which is unstable and may not be changed to stable in " - "the new definition unless explicitly allowed.") % (command_name, field_name), file) + self._add_error( + ERROR_ID_UNSTABLE_REPLY_FIELD_CHANGED_TO_STABLE, + command_name, + ( + "The command '%s' has reply field '%s' which is unstable and may not be changed to stable in " + "the new definition unless explicitly allowed." + ) + % (command_name, field_name), + file, + ) - def add_unstable_param_or_type_field_to_stable_error(self, command_name: str, field_name: str, - file: str, - is_command_parameter: bool) -> None: + def add_unstable_param_or_type_field_to_stable_error( + self, command_name: str, field_name: str, file: str, is_command_parameter: bool + ) -> None: """Add an error that a command parameter or type field may not change from unstable to stable.""" if is_command_parameter: self._add_error( - ERROR_ID_UNSTABLE_COMMAND_PARAM_FIELD_CHANGED_TO_STABLE, command_name, - ("The command '%s' has command parameter field '%s' which is unstable and may " - "not be changed to stable in the new definition unless explicitly allowed.") % - (command_name, field_name), file) + ERROR_ID_UNSTABLE_COMMAND_PARAM_FIELD_CHANGED_TO_STABLE, + command_name, + ( + "The command '%s' has command parameter field '%s' which is unstable and may " + "not be changed to stable in the new definition unless explicitly allowed." + ) + % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_UNSTABLE_COMMAND_TYPE_FIELD_CHANGED_TO_STABLE, command_name, - ("The command '%s' has command type field '%s' which is unstable and may " - "not be changed to stable in the new definition unless explicitly allowed.") % - (command_name, field_name), file) + ERROR_ID_UNSTABLE_COMMAND_TYPE_FIELD_CHANGED_TO_STABLE, + command_name, + ( + "The command '%s' has command type field '%s' which is unstable and may " + "not be changed to stable in the new definition unless explicitly allowed." + ) + % (command_name, field_name), + file, + ) - def add_new_reply_field_added_as_stable_error(self, command_name: str, field_name: str, - file: str) -> None: + def add_new_reply_field_added_as_stable_error( + self, command_name: str, field_name: str, file: str + ) -> None: """Add an error that a new reply field may not be added as stable unless explicitly allowed.""" self._add_error( - ERROR_ID_NEW_REPLY_FIELD_ADDED_AS_STABLE, command_name, - ("The command '%s' has newly-added reply field '%s' which may not be defined as stable " - "unless that addition is explicitly allowed.") % (command_name, field_name), file) + ERROR_ID_NEW_REPLY_FIELD_ADDED_AS_STABLE, + command_name, + ( + "The command '%s' has newly-added reply field '%s' which may not be defined as stable " + "unless that addition is explicitly allowed." + ) + % (command_name, field_name), + file, + ) - def add_new_param_or_type_field_added_as_stable_error(self, command_name: str, field_name: str, - file: str, - is_command_parameter: bool) -> None: + def add_new_param_or_type_field_added_as_stable_error( + self, command_name: str, field_name: str, file: str, is_command_parameter: bool + ) -> None: """Add an error that a new command param or type field may not be added as stable unless explicitly allowed.""" if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_STABLE, command_name, - ("The command '%s' has newly-added param '%s' which may not be defined as stable " - "unless that addition is explicitly allowed.") % (command_name, field_name), file) + ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_STABLE, + command_name, + ( + "The command '%s' has newly-added param '%s' which may not be defined as stable " + "unless that addition is explicitly allowed." + ) + % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_STABLE, command_name, - ("The command '%s' has newly-added type '%s' which may not be defined as stable " - "unless that addition is explicitly allowed.") % (command_name, field_name), file) + ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_STABLE, + command_name, + ( + "The command '%s' has newly-added type '%s' which may not be defined as stable " + "unless that addition is explicitly allowed." + ) + % (command_name, field_name), + file, + ) def add_new_param_or_type_field_added_as_unstable_required_error( - self, command_name: str, field_name: str, file: str, - is_command_parameter: bool) -> None: + self, command_name: str, field_name: str, file: str, is_command_parameter: bool + ) -> None: """Add an error that a new unstable command param or type field may not be added as required.""" if is_command_parameter: self._add_error( - ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_UNSTABLE_REQUIRED, command_name, - ("The command '%s' has newly-added unstable param field '%s' which should be optional." - ) % (command_name, field_name), file) + ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_UNSTABLE_REQUIRED, + command_name, + ( + "The command '%s' has newly-added unstable param field '%s' which should be optional." + ) + % (command_name, field_name), + file, + ) else: self._add_error( - ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_UNSTABLE_REQUIRED, command_name, - ("The command '%s' has newly-added unstable type field '%s' which should be optional." - ) % (command_name, field_name), file) + ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_UNSTABLE_REQUIRED, + command_name, + ( + "The command '%s' has newly-added unstable type field '%s' which should be optional." + ) + % (command_name, field_name), + file, + ) def _assert_unique_error_messages() -> None: @@ -1174,7 +1827,8 @@ def _assert_unique_error_messages() -> None: error_ids_set = set(error_ids) if len(error_ids) != len(error_ids_set): raise IDLCompatibilityCheckerError( - "IDL Compatibility Checker error codes prefixed with ERROR_ID are not unique.") + "IDL Compatibility Checker error codes prefixed with ERROR_ID are not unique." + ) # On file import, check the error messages are unique diff --git a/buildscripts/idl/idlc.py b/buildscripts/idl/idlc.py index c5833d1b346..6618ca5b85d 100644 --- a/buildscripts/idl/idlc.py +++ b/buildscripts/idl/idlc.py @@ -39,29 +39,47 @@ import idl.compiler def main(): # type: () -> None """Execute Main Entry point.""" - parser = argparse.ArgumentParser(description='MongoDB IDL Compiler.') + parser = argparse.ArgumentParser(description="MongoDB IDL Compiler.") - parser.add_argument('file', type=str, help="IDL input file") + parser.add_argument("file", type=str, help="IDL input file") - parser.add_argument('-o', '--output', type=str, help="IDL output source file") + parser.add_argument("-o", "--output", type=str, help="IDL output source file") - parser.add_argument('--header', type=str, help="IDL output header file") + parser.add_argument("--header", type=str, help="IDL output header file") - parser.add_argument('-i', '--include', type=str, action="append", - help="Directory to search for IDL import files") + parser.add_argument( + "-i", + "--include", + type=str, + action="append", + help="Directory to search for IDL import files", + ) - parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose tracing" + ) - parser.add_argument('--base_dir', type=str, help="IDL output relative base directory") + parser.add_argument( + "--base_dir", type=str, help="IDL output relative base directory" + ) - parser.add_argument('--write-dependencies', action='store_true', - help='only print out a list of dependent imports') + parser.add_argument( + "--write-dependencies", + action="store_true", + help="only print out a list of dependent imports", + ) - parser.add_argument('--write-dependencies-inline', action='store_true', - help='print out a list of dependent imports during file generation') + parser.add_argument( + "--write-dependencies-inline", + action="store_true", + help="print out a list of dependent imports during file generation", + ) - parser.add_argument('--target_arch', type=str, - help="IDL target archiecture (amd64, s390x). defaults to current machine") + parser.add_argument( + "--target_arch", + type=str, + help="IDL target archiecture (amd64, s390x). defaults to current machine", + ) args = parser.parse_args() @@ -81,8 +99,9 @@ def main(): compiler_args.write_dependencies = args.write_dependencies compiler_args.write_dependencies_inline = args.write_dependencies_inline - if (args.output is not None and args.header is None) or \ - (args.output is None and args.header is not None): + if (args.output is not None and args.header is None) or ( + args.output is None and args.header is not None + ): print("ERROR: Either both --header and --output must be specified or neither.") sys.exit(1) @@ -93,5 +112,5 @@ def main(): sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/idl/lib.py b/buildscripts/idl/lib.py index 3c8809b5ff5..a7240fefc15 100644 --- a/buildscripts/idl/lib.py +++ b/buildscripts/idl/lib.py @@ -37,14 +37,17 @@ def list_idls(directory: str) -> Set[str]: """Find all IDL files in the current directory.""" return { os.path.join(dirpath, filename) - for dirpath, dirnames, filenames in os.walk(directory) for filename in filenames - if filename.endswith('.idl') + for dirpath, dirnames, filenames in os.walk(directory) + for filename in filenames + if filename.endswith(".idl") } def parse_idl(idl_path: str, import_directories: List[str]) -> syntax.IDLParsedSpec: """Parse an IDL file or throw an error.""" - parsed_doc = parser.parse(open(idl_path), idl_path, CompilerImportResolver(import_directories)) + parsed_doc = parser.parse( + open(idl_path), idl_path, CompilerImportResolver(import_directories) + ) if parsed_doc.errors: parsed_doc.errors.dump_errors() diff --git a/buildscripts/idl/run_tests.py b/buildscripts/idl/run_tests.py index 896226ac591..db0fa70ce34 100644 --- a/buildscripts/idl/run_tests.py +++ b/buildscripts/idl/run_tests.py @@ -47,11 +47,11 @@ def run_tests(): # my-py type information. all_tests = unittest.defaultTestLoader.discover(start_dir="tests") # type: ignore - runner = XMLTestRunner(verbosity=2, failfast=False, output='results') + runner = XMLTestRunner(verbosity=2, failfast=False, output="results") result = runner.run(all_tests) sys.exit(not result.wasSuccessful()) -if __name__ == '__main__': +if __name__ == "__main__": run_tests() diff --git a/buildscripts/idl/tests/context.py b/buildscripts/idl/tests/context.py index af1cd4e1e3e..52191c5bd48 100644 --- a/buildscripts/idl/tests/context.py +++ b/buildscripts/idl/tests/context.py @@ -30,7 +30,7 @@ import os import sys -sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) import idl.ast # noqa: F401 import idl.binder # noqa: F401 diff --git a/buildscripts/idl/tests/test_binder.py b/buildscripts/idl/tests/test_binder.py index 2b0f2641214..10f6c3c4083 100644 --- a/buildscripts/idl/tests/test_binder.py +++ b/buildscripts/idl/tests/test_binder.py @@ -36,6 +36,7 @@ import unittest if __package__ is None: import sys from os import path + sys.path.append(path.dirname(path.abspath(__file__))) import testcase from context import idl @@ -50,9 +51,9 @@ INDENT_SPACE_COUNT = 4 def fill_spaces(count): # type: (int) -> str """Fill a string full of spaces.""" - fill = '' + fill = "" for _ in range(count * INDENT_SPACE_COUNT): - fill += ' ' + fill += " " return fill @@ -62,7 +63,7 @@ def indent_text(count, unindented_text): """Indent each line of a multi-line string.""" lines = unindented_text.splitlines() fill = fill_spaces(count) - return '\n'.join(fill + line for line in lines) + return "\n".join(fill + line for line in lines) class TestBinder(testcase.IDLTestcase): @@ -141,15 +142,17 @@ class TestBinder(testcase.IDLTestcase): cpp_namespace: 'mongo' cpp_includes: - 'bar' - - 'foo'""")) + - 'foo'""") + ) self.assertEqual(spec.globals.cpp_namespace, "mongo") - self.assertListEqual(spec.globals.cpp_includes, ['bar', 'foo']) + self.assertListEqual(spec.globals.cpp_includes, ["bar", "foo"]) spec = self.assert_bind( textwrap.dedent(""" global: cpp_namespace: 'mongo::nested' - """)) + """) + ) self.assertEqual(spec.globals.cpp_namespace, "mongo::nested") def test_global_negatives(self): @@ -159,7 +162,9 @@ class TestBinder(testcase.IDLTestcase): textwrap.dedent(""" global: cpp_namespace: 'something' - """), idl.errors.ERROR_ID_BAD_CPP_NAMESPACE) + """), + idl.errors.ERROR_ID_BAD_CPP_NAMESPACE, + ) def test_type_positive(self): # type: () -> None @@ -175,15 +180,27 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + """) + ) # Test supported types for bson_type in [ - "bool", "date", "null", "decimal", "double", "int", "long", "objectid", "regex", - "string", "timestamp", "undefined" + "bool", + "date", + "null", + "decimal", + "double", + "int", + "long", + "objectid", + "regex", + "string", + "timestamp", + "undefined", ]: self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -192,19 +209,23 @@ class TestBinder(testcase.IDLTestcase): default: foo deserializer: BSONElement::fake is_view: false - """ % (bson_type))) + """ + % (bson_type) + ) + ) # Test supported numeric types for cpp_type in [ - "std::int32_t", - "std::uint32_t", - "std::int32_t", - "std::uint64_t", - "std::vector", - "std::array", + "std::int32_t", + "std::uint32_t", + "std::int32_t", + "std::uint64_t", + "std::vector", + "std::array", ]: self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -212,7 +233,10 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: int deserializer: BSONElement::fake is_view: false - """ % (cpp_type))) + """ + % (cpp_type) + ) + ) # Test object self.assert_bind( @@ -226,7 +250,8 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + """) + ) # Test 'any' self.assert_bind( @@ -240,7 +265,8 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + """) + ) # Test 'chain' self.assert_bind( @@ -254,12 +280,14 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + """) + ) # Test supported bindata_subtype for bindata_subtype in ["generic", "function", "uuid", "md5"]: self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -268,7 +296,10 @@ class TestBinder(testcase.IDLTestcase): bindata_subtype: %s deserializer: BSONElement::fake is_view: false - """ % (bindata_subtype))) + """ + % (bindata_subtype) + ) + ) def test_type_negative(self): # type: () -> None @@ -283,7 +314,9 @@ class TestBinder(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: foo is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_TYPE) + """), + idl.errors.ERROR_ID_BAD_BSON_TYPE, + ) # Test bad cpp_type name self.assert_bind_fail( @@ -295,50 +328,53 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: string deserializer: bar is_view: false - """), idl.errors.ERROR_ID_NO_STRINGDATA) + """), + idl.errors.ERROR_ID_NO_STRINGDATA, + ) # Test unsupported serialization for cpp_type in [ - "char", - "signed char", - "unsigned char", - "signed short int", - "short int", - "short", - "signed short", - "unsigned short", - "unsigned short int", - "signed int", - "signed", - "unsigned int", - "unsigned", - "signed long int", - "signed long", - "int", - "long int", - "long", - "unsigned long int", - "unsigned long", - "signed long long int", - "signed long long", - "long long int", - "long long", - "unsigned long int", - "unsigned long", - "wchar_t", - "char16_t", - "char32_t", - "int8_t", - "int16_t", - "int32_t", - "int64_t", - "uint8_t", - "uint16_t", - "uint32_t", - "uint64_t", + "char", + "signed char", + "unsigned char", + "signed short int", + "short int", + "short", + "signed short", + "unsigned short", + "unsigned short int", + "signed int", + "signed", + "unsigned int", + "unsigned", + "signed long int", + "signed long", + "int", + "long int", + "long", + "unsigned long int", + "unsigned long", + "signed long long int", + "signed long long", + "long long int", + "long long", + "unsigned long int", + "unsigned long", + "wchar_t", + "char16_t", + "char32_t", + "int8_t", + "int16_t", + "int32_t", + "int64_t", + "uint8_t", + "uint16_t", + "uint32_t", + "uint64_t", ]: self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -346,12 +382,22 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: int deserializer: BSONElement::int is_view: false - """ % (cpp_type)), idl.errors.ERROR_ID_BAD_NUMERIC_CPP_TYPE) + """ + % (cpp_type) + ), + idl.errors.ERROR_ID_BAD_NUMERIC_CPP_TYPE, + ) # Test the std prefix 8 and 16-byte integers fail - for std_cpp_type in ["std::int8_t", "std::int16_t", "std::uint8_t", "std::uint16_t"]: + for std_cpp_type in [ + "std::int8_t", + "std::int16_t", + "std::uint8_t", + "std::uint16_t", + ]: self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -359,7 +405,11 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: int deserializer: BSONElement::int is_view: false - """ % (std_cpp_type)), idl.errors.ERROR_ID_BAD_NUMERIC_CPP_TYPE) + """ + % (std_cpp_type) + ), + idl.errors.ERROR_ID_BAD_NUMERIC_CPP_TYPE, + ) # Test bindata_subtype missing self.assert_bind_fail( @@ -371,7 +421,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: bindata deserializer: BSONElement::fake is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE) + """), + idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, + ) # Test fake bindata_subtype is wrong self.assert_bind_fail( @@ -384,7 +436,9 @@ class TestBinder(testcase.IDLTestcase): bindata_subtype: foo deserializer: BSONElement::fake is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE) + """), + idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, + ) # Test deprecated bindata_subtype 'binary', and 'uuid_old' are wrong self.assert_bind_fail( @@ -396,7 +450,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: bindata bindata_subtype: binary is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE) + """), + idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, + ) self.assert_bind_fail( textwrap.dedent(""" @@ -407,7 +463,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: bindata bindata_subtype: uuid_old is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE) + """), + idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_VALUE, + ) # Test bindata_subtype on wrong type self.assert_bind_fail( @@ -420,7 +478,9 @@ class TestBinder(testcase.IDLTestcase): bindata_subtype: generic deserializer: BSONElement::fake is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_TYPE) + """), + idl.errors.ERROR_ID_BAD_BSON_BINDATA_SUBTYPE_TYPE, + ) # Test bindata with default self.assert_bind_fail( @@ -433,7 +493,9 @@ class TestBinder(testcase.IDLTestcase): bindata_subtype: uuid default: 42 is_view: false - """), idl.errors.ERROR_ID_BAD_BINDATA_DEFAULT) + """), + idl.errors.ERROR_ID_BAD_BINDATA_DEFAULT, + ) # Test bindata in list of types self.assert_bind_fail( @@ -446,7 +508,9 @@ class TestBinder(testcase.IDLTestcase): - bindata - string is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_TYPE) + """), + idl.errors.ERROR_ID_BAD_BSON_TYPE, + ) # Test bindata in list of types self.assert_bind_fail( @@ -459,7 +523,9 @@ class TestBinder(testcase.IDLTestcase): - bindata - string is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_TYPE) + """), + idl.errors.ERROR_ID_BAD_BSON_TYPE, + ) # Test 'any' in list of types self.assert_bind_fail( @@ -472,7 +538,9 @@ class TestBinder(testcase.IDLTestcase): - any - int is_view: false - """), idl.errors.ERROR_ID_BAD_ANY_TYPE_USE) + """), + idl.errors.ERROR_ID_BAD_ANY_TYPE_USE, + ) # Test object in list of types self.assert_bind_fail( @@ -485,7 +553,9 @@ class TestBinder(testcase.IDLTestcase): - object - int is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_TYPE_LIST) + """), + idl.errors.ERROR_ID_BAD_BSON_TYPE_LIST, + ) # Test fake in list of types self.assert_bind_fail( @@ -498,7 +568,9 @@ class TestBinder(testcase.IDLTestcase): - int - fake is_view: false - """), idl.errors.ERROR_ID_BAD_BSON_TYPE) + """), + idl.errors.ERROR_ID_BAD_BSON_TYPE, + ) # Test 'chain' in list of types self.assert_bind_fail( @@ -511,15 +583,27 @@ class TestBinder(testcase.IDLTestcase): - chain - int is_view: false - """), idl.errors.ERROR_ID_BAD_ANY_TYPE_USE) + """), + idl.errors.ERROR_ID_BAD_ANY_TYPE_USE, + ) # Test unsupported serialization for bson_type in [ - "bool", "date", "null", "decimal", "double", "int", "long", "objectid", "regex", - "timestamp", "undefined" + "bool", + "date", + "null", + "decimal", + "double", + "int", + "long", + "objectid", + "regex", + "timestamp", + "undefined", ]: self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -528,11 +612,15 @@ class TestBinder(testcase.IDLTestcase): serializer: foo deserializer: BSONElement::fake is_view: false - """ % (bson_type)), - idl.errors.ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED) + """ + % (bson_type) + ), + idl.errors.ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED, + ) self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ types: foofoo: description: foo @@ -540,8 +628,11 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: %s deserializer: foo is_view: false - """ % (bson_type)), - idl.errors.ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED) + """ + % (bson_type) + ), + idl.errors.ERROR_ID_CUSTOM_SCALAR_SERIALIZATION_NOT_SUPPORTED, + ) # Test 'any' serialization needs deserializer self.assert_bind_fail( @@ -552,7 +643,9 @@ class TestBinder(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: any is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test 'chain' serialization needs deserializer self.assert_bind_fail( @@ -564,7 +657,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: chain serializer: bar is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test 'string' serialization needs deserializer self.assert_bind_fail( @@ -576,7 +671,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: string serializer: bar is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test 'date' serialization needs deserializer self.assert_bind_fail( @@ -587,7 +684,9 @@ class TestBinder(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: date is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test 'chain' serialization needs serializer self.assert_bind_fail( @@ -599,7 +698,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: chain deserializer: bar is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test list of bson types needs deserializer self.assert_bind_fail( @@ -612,7 +713,9 @@ class TestBinder(testcase.IDLTestcase): - int - string is_view: false - """), idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_AST_REQUIRED_FIELD, + ) # Test array as name self.assert_bind_fail( @@ -624,7 +727,9 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: string deserializer: bar is_view: false - """), idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE) + """), + idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE, + ) def test_struct_positive(self): # type: () -> None @@ -640,25 +745,32 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: int deserializer: mongo::BSONElement::_numberInt is_view: false - """)) + """), + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: string - """)) + """) + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: array - """)) + """) + ) def test_struct_negative(self): # type: () -> None @@ -674,27 +786,34 @@ class TestBinder(testcase.IDLTestcase): bson_serialization_type: int deserializer: mongo::BSONElement::_numberInt is_view: false - """)) + """), + ) # Test array as name self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: array: description: foo strict: true fields: foo: string - """), idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE) + """), + idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE, + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: array - """)) + """) + ) def test_variant_positive(self): # type: () -> None @@ -717,9 +836,12 @@ class TestBinder(testcase.IDLTestcase): cpp_type: "std::vector" deserializer: "mongo::BSONElement::_binDataVector" is_view: false - """)) + """), + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -729,9 +851,12 @@ class TestBinder(testcase.IDLTestcase): variant: - string - int - """)) + """) + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -741,9 +866,12 @@ class TestBinder(testcase.IDLTestcase): variant: - string - bindata_function - """)) + """) + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -754,10 +882,13 @@ class TestBinder(testcase.IDLTestcase): - string - int default: 1 - """)) + """) + ) # Test multiple BSON serialization type Object. - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: insert_type: description: foo @@ -774,16 +905,19 @@ class TestBinder(testcase.IDLTestcase): - insert_type - update_type - int - """)) + """) + ) def test_variant_negative(self): # type: () -> None """Negative variant test cases.""" # Setup some common types - test_preamble = self.common_types + indent_text( - 1, - textwrap.dedent(""" + test_preamble = ( + self.common_types + + indent_text( + 1, + textwrap.dedent(""" int: description: foo cpp_type: std::int32_t @@ -800,7 +934,9 @@ class TestBinder(testcase.IDLTestcase): cpp_type: "std::int32_t" deserializer: "mongo::BSONElement::safeNumberInt" is_view: false - """)) + textwrap.dedent(""" + """), + ) + + textwrap.dedent(""" enums: foo_enum: description: foo @@ -809,9 +945,11 @@ class TestBinder(testcase.IDLTestcase): v1: 0 v2: 1 """) + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -820,10 +958,13 @@ class TestBinder(testcase.IDLTestcase): type: variant: - string - """), idl.errors.ERROR_ID_USELESS_VARIANT) + """), + idl.errors.ERROR_ID_USELESS_VARIANT, + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -834,11 +975,15 @@ class TestBinder(testcase.IDLTestcase): - string - int - not_defined - """), idl.errors.ERROR_ID_UNKNOWN_TYPE, True) + """), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + True, + ) # Enums are banned in variants for now. self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -848,10 +993,14 @@ class TestBinder(testcase.IDLTestcase): variant: - string - foo_enum - """), idl.errors.ERROR_ID_NO_VARIANT_ENUM, True) + """), + idl.errors.ERROR_ID_NO_VARIANT_ENUM, + True, + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -861,10 +1010,13 @@ class TestBinder(testcase.IDLTestcase): variant: - string - string - """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + """), + idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES, + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -874,10 +1026,13 @@ class TestBinder(testcase.IDLTestcase): variant: - array - array - """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + """), + idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES, + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: struct0: description: foo @@ -891,11 +1046,14 @@ class TestBinder(testcase.IDLTestcase): variant: - array - array - """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + """), + idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES, + ) # At most one array can have BSON serialization type NumberInt. self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -905,10 +1063,13 @@ class TestBinder(testcase.IDLTestcase): variant: - array - array - """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + """), + idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES, + ) self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: one_string: description: foo @@ -922,12 +1083,16 @@ class TestBinder(testcase.IDLTestcase): - one_string - one_string - int - """), idl.errors.ERROR_ID_VARIANT_STRUCTS, True) + """), + idl.errors.ERROR_ID_VARIANT_STRUCTS, + True, + ) # For multiple BSON serialization type Objects they must have different field names # for their first field. self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: one_string: description: foo @@ -944,11 +1109,15 @@ class TestBinder(testcase.IDLTestcase): - one_string - one_int - int - """), idl.errors.ERROR_ID_VARIANT_STRUCTS, True) + """), + idl.errors.ERROR_ID_VARIANT_STRUCTS, + True, + ) # At most one type can have BSON serialization type NumberInt. self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -958,7 +1127,9 @@ class TestBinder(testcase.IDLTestcase): variant: - safeInt - int - """), idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES) + """), + idl.errors.ERROR_ID_VARIANT_DUPLICATE_TYPES, + ) def test_field_positive(self): # type: () -> None @@ -968,17 +1139,22 @@ class TestBinder(testcase.IDLTestcase): test_preamble = self.common_types # Short type - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: bar: description: foo strict: false fields: foo: string - """)) + """) + ) # Long type - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: bar: description: foo @@ -986,10 +1162,13 @@ class TestBinder(testcase.IDLTestcase): fields: foo: type: string - """)) + """) + ) # Long type with default - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: bar: description: foo @@ -998,22 +1177,28 @@ class TestBinder(testcase.IDLTestcase): foo: type: string default: bar - """)) + """) + ) # Test array as field type - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: array - """)) + """) + ) # Test array as field type - self.assert_bind(self.common_types + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + self.common_types + + indent_text( + 1, + textwrap.dedent(""" arrayfake: description: foo cpp_type: foo @@ -1021,17 +1206,22 @@ class TestBinder(testcase.IDLTestcase): serializer: foo deserializer: foo is_view: false - """)) + textwrap.dedent(""" + """), + ) + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: arrayOfString: arrayfake - """)) + """) + ) # Test always_serialize with optional - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1041,10 +1231,13 @@ class TestBinder(testcase.IDLTestcase): type: string optional: true always_serialize: true - """)) + """) + ) # Test field of a struct type with default=true - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1058,7 +1251,8 @@ class TestBinder(testcase.IDLTestcase): type: foo default: true - """)) + """) + ) def test_field_negative(self): # type: () -> None @@ -1093,7 +1287,8 @@ class TestBinder(testcase.IDLTestcase): # Test field of a struct type with a non-true default self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1107,44 +1302,57 @@ class TestBinder(testcase.IDLTestcase): type: foo default: foo - """), idl.errors.ERROR_ID_DEFAULT_MUST_BE_TRUE_OR_EMPTY_FOR_STRUCT) + """), + idl.errors.ERROR_ID_DEFAULT_MUST_BE_TRUE_OR_EMPTY_FOR_STRUCT, + ) # Test array as field name self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: array: string - """), idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE) + """), + idl.errors.ERROR_ID_ARRAY_NOT_VALID_TYPE, + ) # Test recursive array as field type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: array> - """), idl.errors.ERROR_ID_BAD_ARRAY_TYPE_NAME, True) + """), + idl.errors.ERROR_ID_BAD_ARRAY_TYPE_NAME, + True, + ) # Test inherited default with array self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo strict: true fields: foo: array - """), idl.errors.ERROR_ID_ARRAY_NO_DEFAULT) + """), + idl.errors.ERROR_ID_ARRAY_NO_DEFAULT, + ) # Test non-inherited default with array self.assert_bind_fail( - self.common_types + textwrap.dedent(""" + self.common_types + + textwrap.dedent(""" structs: foo: description: foo @@ -1153,11 +1361,14 @@ class TestBinder(testcase.IDLTestcase): foo: type: array default: 123 - """), idl.errors.ERROR_ID_ARRAY_NO_DEFAULT) + """), + idl.errors.ERROR_ID_ARRAY_NO_DEFAULT, + ) # Test bindata with default self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1166,11 +1377,14 @@ class TestBinder(testcase.IDLTestcase): foo: type: bindata default: 42 - """), idl.errors.ERROR_ID_BAD_BINDATA_DEFAULT) + """), + idl.errors.ERROR_ID_BAD_BINDATA_DEFAULT, + ) # Test default and optional for the same field self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1180,11 +1394,14 @@ class TestBinder(testcase.IDLTestcase): type: string default: 42 optional: true - """), idl.errors.ERROR_ID_ILLEGAL_FIELD_DEFAULT_AND_OPTIONAL) + """), + idl.errors.ERROR_ID_ILLEGAL_FIELD_DEFAULT_AND_OPTIONAL, + ) # Test always_serialize without optional for the same field self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1194,11 +1411,14 @@ class TestBinder(testcase.IDLTestcase): type: string default: 42 always_serialize: true - """), idl.errors.ERROR_ID_ILLEGAL_FIELD_ALWAYS_SERIALIZE_NOT_OPTIONAL) + """), + idl.errors.ERROR_ID_ILLEGAL_FIELD_ALWAYS_SERIALIZE_NOT_OPTIONAL, + ) # Test duplicate comparison order self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1211,11 +1431,14 @@ class TestBinder(testcase.IDLTestcase): bar: type: string comparison_order: 1 - """), idl.errors.ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER) + """), + idl.errors.ERROR_ID_IS_DUPLICATE_COMPARISON_ORDER, + ) # Test field marked with non_const_getter in immutable struct self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1224,17 +1447,21 @@ class TestBinder(testcase.IDLTestcase): foo: type: string non_const_getter: true - """), idl.errors.ERROR_ID_NON_CONST_GETTER_IN_IMMUTABLE_STRUCT) + """), + idl.errors.ERROR_ID_NON_CONST_GETTER_IN_IMMUTABLE_STRUCT, + ) def test_ignored_field_negative(self): # type: () -> None """Test that if a field is marked as ignored, no other properties are set.""" for test_value in [ - "optional: true", - "default: foo", + "optional: true", + "default: foo", ]: self.assert_bind_fail( - self.common_types + textwrap.dedent(""" + self.common_types + + textwrap.dedent( + """ structs: foo: description: foo @@ -1244,7 +1471,11 @@ class TestBinder(testcase.IDLTestcase): type: string ignore: true %s - """ % (test_value)), idl.errors.ERROR_ID_FIELD_MUST_BE_EMPTY_FOR_IGNORED) + """ + % (test_value) + ), + idl.errors.ERROR_ID_FIELD_MUST_BE_EMPTY_FOR_IGNORED, + ) def test_chained_type_positive(self): # type: () -> None @@ -1261,17 +1492,21 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + """), + ) # Chaining only - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo strict: false chained_types: foo1: alias - """)) + """) + ) def test_chained_type_negative(self): # type: () -> None @@ -1287,33 +1522,41 @@ class TestBinder(testcase.IDLTestcase): serializer: foo deserializer: foo is_view: false - """)) + """), + ) # Chaining with strict struct self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo strict: true chained_types: foo1: alias - """), idl.errors.ERROR_ID_CHAINED_NO_TYPE_STRICT) + """), + idl.errors.ERROR_ID_CHAINED_NO_TYPE_STRICT, + ) # Non-'any' type as chained type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo strict: false chained_types: string: alias - """), idl.errors.ERROR_ID_CHAINED_TYPE_WRONG_BSON_TYPE) + """), + idl.errors.ERROR_ID_CHAINED_TYPE_WRONG_BSON_TYPE, + ) # Chaining and fields only with same name self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo @@ -1322,11 +1565,14 @@ class TestBinder(testcase.IDLTestcase): foo1: alias fields: foo1: string - """), idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD) + """), + idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD, + ) # Non-existent chained type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo @@ -1335,11 +1581,15 @@ class TestBinder(testcase.IDLTestcase): foobar1: alias fields: foo1: string - """), idl.errors.ERROR_ID_UNKNOWN_TYPE, True) + """), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + True, + ) # A regular field as a chained type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo @@ -1347,26 +1597,34 @@ class TestBinder(testcase.IDLTestcase): fields: foo1: string foo2: foobar1 - """), idl.errors.ERROR_ID_UNKNOWN_TYPE, True) + """), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + True, + ) # Array of chained types self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: bar1: description: foo strict: true fields: field1: array - """), idl.errors.ERROR_ID_NO_ARRAY_OF_CHAIN) + """), + idl.errors.ERROR_ID_NO_ARRAY_OF_CHAIN, + ) def test_chained_struct_positive(self): # type: () -> None """Positive parser chaining test cases.""" # Setup some common types - test_preamble = self.common_types + indent_text( - 1, - textwrap.dedent(""" + test_preamble = ( + self.common_types + + indent_text( + 1, + textwrap.dedent(""" foo1: description: foo cpp_type: foo @@ -1375,7 +1633,9 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + textwrap.dedent(""" + """), + ) + + textwrap.dedent(""" structs: chained: description: foo @@ -1389,22 +1649,29 @@ class TestBinder(testcase.IDLTestcase): fields: field1: string """) + ) # A struct with only chaining - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: true chained_structs: chained2: alias - """))) + """), + ) + ) # Chaining struct's fields and explicit fields - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: true @@ -1412,12 +1679,16 @@ class TestBinder(testcase.IDLTestcase): chained2: alias fields: str1: string - """))) + """), + ) + ) # Chained types and structs - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: false @@ -1427,12 +1698,16 @@ class TestBinder(testcase.IDLTestcase): chained2: alias fields: str1: string - """))) + """), + ) + ) # Non-strict chained struct - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: false @@ -1440,12 +1715,16 @@ class TestBinder(testcase.IDLTestcase): chained2: alias fields: foo1: string - """))) + """), + ) + ) # Inline Chained struct with strict true - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: true @@ -1461,12 +1740,16 @@ class TestBinder(testcase.IDLTestcase): fields: f1: string - """))) + """), + ) + ) # Inline Chained struct with strict true and inline_chained_structs defaulted - self.assert_bind(test_preamble + indent_text( - 1, - textwrap.dedent(""" + self.assert_bind( + test_preamble + + indent_text( + 1, + textwrap.dedent(""" bar1: description: foo strict: true @@ -1480,15 +1763,19 @@ class TestBinder(testcase.IDLTestcase): bar1: alias fields: f1: string - """))) + """), + ) + ) def test_chained_struct_negative(self): # type: () -> None """Negative parser chaining test cases.""" # Setup some common types - test_preamble = self.common_types + indent_text( - 1, - textwrap.dedent(""" + test_preamble = ( + self.common_types + + indent_text( + 1, + textwrap.dedent(""" foo1: description: foo cpp_type: foo @@ -1497,7 +1784,9 @@ class TestBinder(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """)) + textwrap.dedent(""" + """), + ) + + textwrap.dedent(""" structs: chained: description: foo @@ -1511,10 +1800,12 @@ class TestBinder(testcase.IDLTestcase): fields: field1: string """) + ) # Non-existing chained struct self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1522,11 +1813,16 @@ class TestBinder(testcase.IDLTestcase): strict: true chained_structs: foobar1: alias - """)), idl.errors.ERROR_ID_UNKNOWN_TYPE, True) + """), + ), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + True, + ) # Type as chained struct self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1534,11 +1830,15 @@ class TestBinder(testcase.IDLTestcase): strict: true chained_structs: foo1: alias - """)), idl.errors.ERROR_ID_CHAINED_STRUCT_NOT_FOUND) + """), + ), + idl.errors.ERROR_ID_CHAINED_STRUCT_NOT_FOUND, + ) # Struct as chained type self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1546,11 +1846,15 @@ class TestBinder(testcase.IDLTestcase): strict: false chained_types: chained: alias - """)), idl.errors.ERROR_ID_CHAINED_TYPE_NOT_FOUND) + """), + ), + idl.errors.ERROR_ID_CHAINED_TYPE_NOT_FOUND, + ) # Duplicated field names across chained struct's fields and fields self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1560,11 +1864,15 @@ class TestBinder(testcase.IDLTestcase): chained: alias fields: field1: string - """)), idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD) + """), + ), + idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD, + ) # Duplicated field names across chained structs self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1573,11 +1881,15 @@ class TestBinder(testcase.IDLTestcase): chained_structs: chained: alias chained2: alias - """)), idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD) + """), + ), + idl.errors.ERROR_ID_CHAINED_DUPLICATE_FIELD, + ) # Chained struct with strict true self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1595,11 +1907,15 @@ class TestBinder(testcase.IDLTestcase): fields: f1: string - """)), idl.errors.ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT) + """), + ), + idl.errors.ERROR_ID_CHAINED_NO_NESTED_STRUCT_STRICT, + ) # Chained struct with nested chained struct self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1616,11 +1932,15 @@ class TestBinder(testcase.IDLTestcase): fields: f1: string - """)), idl.errors.ERROR_ID_CHAINED_NO_NESTED_CHAINED) + """), + ), + idl.errors.ERROR_ID_CHAINED_NO_NESTED_CHAINED, + ) # Chained struct with nested chained type self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" bar1: @@ -1637,7 +1957,10 @@ class TestBinder(testcase.IDLTestcase): fields: f1: bar1 - """)), idl.errors.ERROR_ID_CHAINED_NO_NESTED_CHAINED) + """), + ), + idl.errors.ERROR_ID_CHAINED_NO_NESTED_CHAINED, + ) def test_enum_positive(self): # type: () -> None @@ -1654,7 +1977,8 @@ class TestBinder(testcase.IDLTestcase): v1: 3 v2: 1 v3: 2 - """)) + """) + ) # Test int - non continuous self.assert_bind( @@ -1666,7 +1990,8 @@ class TestBinder(testcase.IDLTestcase): values: v1: 0 v3: 2 - """)) + """) + ) # Test string self.assert_bind( @@ -1679,7 +2004,8 @@ class TestBinder(testcase.IDLTestcase): v1: 0 v2: 1 v3: 2 - """)) + """) + ) def test_enum_negative(self): # type: () -> None @@ -1694,7 +2020,9 @@ class TestBinder(testcase.IDLTestcase): type: foo values: v1: 0 - """), idl.errors.ERROR_ID_ENUM_BAD_TYPE) + """), + idl.errors.ERROR_ID_ENUM_BAD_TYPE, + ) # Test int - dups self.assert_bind_fail( @@ -1706,7 +2034,9 @@ class TestBinder(testcase.IDLTestcase): values: v1: 1 v3: 1 - """), idl.errors.ERROR_ID_ENUM_NON_UNIQUE_VALUES) + """), + idl.errors.ERROR_ID_ENUM_NON_UNIQUE_VALUES, + ) # Test int - non-integer value self.assert_bind_fail( @@ -1718,7 +2048,9 @@ class TestBinder(testcase.IDLTestcase): values: v1: foo v3: 1 - """), idl.errors.ERROR_ID_ENUM_BAD_INT_VAUE) + """), + idl.errors.ERROR_ID_ENUM_BAD_INT_VAUE, + ) # Test string - dups self.assert_bind_fail( @@ -1730,7 +2062,9 @@ class TestBinder(testcase.IDLTestcase): values: v1: foo v3: foo - """), idl.errors.ERROR_ID_ENUM_NON_UNIQUE_VALUES) + """), + idl.errors.ERROR_ID_ENUM_NON_UNIQUE_VALUES, + ) def test_struct_enum_negative(self): # type: () -> None @@ -1748,13 +2082,17 @@ class TestBinder(testcase.IDLTestcase): # Test array of enums self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo1: description: foo fields: foo1: array - """), idl.errors.ERROR_ID_NO_ARRAY_ENUM, True) + """), + idl.errors.ERROR_ID_NO_ARRAY_ENUM, + True, + ) def test_command_positive(self): # type: () -> None @@ -1769,7 +2107,9 @@ class TestBinder(testcase.IDLTestcase): foo: string """) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1780,7 +2120,8 @@ class TestBinder(testcase.IDLTestcase): fields: foo1: string reply_type: reply - """)) + """) + ) def test_command_negative(self): # type: () -> None @@ -1790,7 +2131,8 @@ class TestBinder(testcase.IDLTestcase): test_preamble = self.common_types # Commands cannot be fields in other commands self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1807,11 +2149,14 @@ class TestBinder(testcase.IDLTestcase): api_version: "" fields: foo: foo - """), idl.errors.ERROR_ID_FIELD_NO_COMMAND) + """), + idl.errors.ERROR_ID_FIELD_NO_COMMAND, + ) # Commands cannot be fields in structs self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1826,11 +2171,14 @@ class TestBinder(testcase.IDLTestcase): description: foo fields: foo: foo - """), idl.errors.ERROR_ID_FIELD_NO_COMMAND) + """), + idl.errors.ERROR_ID_FIELD_NO_COMMAND, + ) # Commands cannot have a field as the same name self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1839,11 +2187,14 @@ class TestBinder(testcase.IDLTestcase): api_version: "" fields: foo: string - """), idl.errors.ERROR_ID_COMMAND_DUPLICATES_FIELD) + """), + idl.errors.ERROR_ID_COMMAND_DUPLICATES_FIELD, + ) # Reply type must be resolvable self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1851,11 +2202,14 @@ class TestBinder(testcase.IDLTestcase): namespace: ignored api_version: "" reply_type: not_defined - """), idl.errors.ERROR_ID_UNKNOWN_TYPE) + """), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + ) # Reply type must be a struct self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1863,7 +2217,9 @@ class TestBinder(testcase.IDLTestcase): namespace: ignored api_version: "" reply_type: string - """), idl.errors.ERROR_ID_INVALID_REPLY_TYPE) + """), + idl.errors.ERROR_ID_INVALID_REPLY_TYPE, + ) def test_command_doc_sequence_positive(self): # type: () -> None @@ -1879,7 +2235,9 @@ class TestBinder(testcase.IDLTestcase): foo: object """) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1890,9 +2248,12 @@ class TestBinder(testcase.IDLTestcase): foo1: type: array supports_doc_sequence: true - """)) + """) + ) - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1903,7 +2264,8 @@ class TestBinder(testcase.IDLTestcase): foo1: type: array supports_doc_sequence: true - """)) + """) + ) def test_command_doc_sequence_negative(self): # type: () -> None @@ -1923,7 +2285,8 @@ class TestBinder(testcase.IDLTestcase): # A struct self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" structs: foo: description: foo @@ -1931,11 +2294,14 @@ class TestBinder(testcase.IDLTestcase): foo: type: array supports_doc_sequence: true - """), idl.errors.ERROR_ID_STRUCT_NO_DOC_SEQUENCE) + """), + idl.errors.ERROR_ID_STRUCT_NO_DOC_SEQUENCE, + ) # A non-array type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1946,11 +2312,14 @@ class TestBinder(testcase.IDLTestcase): foo: type: object supports_doc_sequence: true - """), idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_ARRAY) + """), + idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_ARRAY, + ) # An array of a scalar self.assert_bind_fail( - test_preamble2 + textwrap.dedent(""" + test_preamble2 + + textwrap.dedent(""" commands: foo: description: foo @@ -1961,11 +2330,14 @@ class TestBinder(testcase.IDLTestcase): foo1: type: array supports_doc_sequence: true - """), idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT) + """), + idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT, + ) # An array of 'any' self.assert_bind_fail( - test_preamble2 + textwrap.dedent(""" + test_preamble2 + + textwrap.dedent(""" commands: foo: description: foo @@ -1976,7 +2348,9 @@ class TestBinder(testcase.IDLTestcase): foo1: type: array supports_doc_sequence: true - """), idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT) + """), + idl.errors.ERROR_ID_NO_DOC_SEQUENCE_FOR_NON_OBJECT, + ) def test_command_type_positive(self): # type: () -> None @@ -1985,7 +2359,9 @@ class TestBinder(testcase.IDLTestcase): test_preamble = self.common_types # string - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1996,10 +2372,13 @@ class TestBinder(testcase.IDLTestcase): type: string fields: field1: string - """)) + """) + ) # array of string - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -2010,7 +2389,8 @@ class TestBinder(testcase.IDLTestcase): type: array fields: field1: string - """)) + """) + ) def test_command_type_negative(self): # type: () -> None @@ -2020,7 +2400,8 @@ class TestBinder(testcase.IDLTestcase): # supports_doc_sequence must be a bool self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -2030,7 +2411,10 @@ class TestBinder(testcase.IDLTestcase): type: int fields: field1: string - """), idl.errors.ERROR_ID_UNKNOWN_TYPE, True) + """), + idl.errors.ERROR_ID_UNKNOWN_TYPE, + True, + ) def test_server_parameter_positive(self): # type: () -> None @@ -2039,19 +2423,24 @@ class TestBinder(testcase.IDLTestcase): # server parameter with storage. # Also try valid set_at values. for set_at in ["startup", "runtime", "[ startup, runtime ]", "cluster"]: - if set_at != 'cluster': + if set_at != "cluster": self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ server_parameters: foo: set_at: %s description: bar redact: false cpp_varname: baz - """ % (set_at))) + """ + % (set_at) + ) + ) else: self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ server_parameters: foo: set_at: %s @@ -2059,7 +2448,10 @@ class TestBinder(testcase.IDLTestcase): redact: false cpp_varname: baz omit_in_ftdc: false - """ % (set_at))) + """ + % (set_at) + ) + ) # server parameter with storage and optional fields. self.assert_bind( @@ -2078,7 +2470,8 @@ class TestBinder(testcase.IDLTestcase): lte: 999 lt: 1000 callback: qux - """)) + """) + ) # Cluster server parameter with storage. self.assert_bind( @@ -2098,7 +2491,8 @@ class TestBinder(testcase.IDLTestcase): lte: 999 lt: 1000 callback: qux - """)) + """) + ) # Bound setting with arbitrary expression default and validators. self.assert_bind( @@ -2120,7 +2514,8 @@ class TestBinder(testcase.IDLTestcase): is_constexpr: false gt: 0 lt: 255 - """)) + """) + ) # Specialized SCPs. self.assert_bind( @@ -2131,7 +2526,8 @@ class TestBinder(testcase.IDLTestcase): description: bar redact: false cpp_class: baz - """)) + """) + ) self.assert_bind( textwrap.dedent(""" @@ -2142,7 +2538,8 @@ class TestBinder(testcase.IDLTestcase): redact: false cpp_class: name: baz - """)) + """) + ) self.assert_bind( textwrap.dedent(""" @@ -2157,7 +2554,8 @@ class TestBinder(testcase.IDLTestcase): override_set: true override_ctor: false override_validate: true - """)) + """) + ) self.assert_bind( textwrap.dedent(""" @@ -2170,7 +2568,8 @@ class TestBinder(testcase.IDLTestcase): redact: true test_only: true deprecated_name: bling - """)) + """) + ) self.assert_bind( textwrap.dedent(""" @@ -2186,7 +2585,8 @@ class TestBinder(testcase.IDLTestcase): omit_in_ftdc: true test_only: true deprecated_name: bling - """)) + """) + ) # Default without data. self.assert_bind( @@ -2198,7 +2598,8 @@ class TestBinder(testcase.IDLTestcase): redact: false cpp_class: baz default: blong - """)) + """) + ) def test_server_parameter_negative(self): # type: () -> None @@ -2213,7 +2614,9 @@ class TestBinder(testcase.IDLTestcase): description: bar redact: false cpp_varname: baz - """), idl.errors.ERROR_ID_BAD_SETAT_SPECIFIER) + """), + idl.errors.ERROR_ID_BAD_SETAT_SPECIFIER, + ) # Mix of specialized with bound storage. self.assert_bind_fail( @@ -2225,7 +2628,9 @@ class TestBinder(testcase.IDLTestcase): redact: false cpp_class: baz cpp_varname: bling - """), idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ) # Startup with omit_in_ftdc=true. self.assert_bind_fail( @@ -2237,7 +2642,9 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: baz redact: false omit_in_ftdc: true - """), idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ) # Startup with omit_in_ftdc=false. self.assert_bind_fail( @@ -2249,7 +2656,9 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: baz redact: false omit_in_ftdc: false - """), idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ) # Runtime with omit_in_ftdc=true. self.assert_bind_fail( @@ -2261,7 +2670,9 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: baz redact: false omit_in_ftdc: true - """), idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ) # Runtime with omit_in_ftdc=false. self.assert_bind_fail( @@ -2273,7 +2684,9 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: baz redact: false omit_in_ftdc: false - """), idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_INVALID_ATTR, + ) # Cluster with omit_in_ftdc unspecified. self.assert_bind_fail( @@ -2292,7 +2705,9 @@ class TestBinder(testcase.IDLTestcase): lte: 999 lt: 1000 callback: qux - """), idl.errors.ERROR_ID_SERVER_PARAMETER_REQUIRED_ATTR) + """), + idl.errors.ERROR_ID_SERVER_PARAMETER_REQUIRED_ATTR, + ) def test_config_option_positive(self): # type: () -> None @@ -2324,7 +2739,8 @@ class TestBinder(testcase.IDLTestcase): gte: 1 lte: 99 callback: doSomething - """)) + """) + ) # Required fields only. self.assert_bind( @@ -2334,7 +2750,8 @@ class TestBinder(testcase.IDLTestcase): description: comment arg_vartype: Switch source: yaml - """)) + """) + ) # List and enum variants. self.assert_bind( @@ -2350,12 +2767,14 @@ class TestBinder(testcase.IDLTestcase): requires: [ d, e, f ] hidden: true duplicate_behavior: overwrite - """)) + """) + ) # Positional variants. - for positional in ['1', '1-', '-2', '1-2']: + for positional in ["1", "1-", "-2", "1-2"]: self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ configs: foo: short_name: foo @@ -2363,17 +2782,24 @@ class TestBinder(testcase.IDLTestcase): arg_vartype: Bool source: cli positional: %s - """ % (positional))) + """ + % (positional) + ) + ) # With implicit short name. self.assert_bind( - textwrap.dedent(""" + textwrap.dedent( + """ configs: foo: description: comment arg_vartype: Bool source: cli positional: %s - """ % (positional))) + """ + % (positional) + ) + ) # Expressions in default, implicit, and validators. self.assert_bind( @@ -2388,7 +2814,8 @@ class TestBinder(testcase.IDLTestcase): validator: gte: { expr: kMinimum } lte: { expr: kMaximum } - """)) + """) + ) def test_config_option_negative(self): # type: () -> None @@ -2402,7 +2829,9 @@ class TestBinder(testcase.IDLTestcase): description: comment arg_vartype: Long source: json - """), idl.errors.ERROR_ID_BAD_SOURCE_SPECIFIER) + """), + idl.errors.ERROR_ID_BAD_SOURCE_SPECIFIER, + ) self.assert_bind_fail( textwrap.dedent(""" @@ -2412,18 +2841,25 @@ class TestBinder(testcase.IDLTestcase): arg_vartype: StringMap source: [ cli, yaml ] duplicate_behavior: guess - """), idl.errors.ERROR_ID_BAD_DUPLICATE_BEHAVIOR_SPECIFIER) + """), + idl.errors.ERROR_ID_BAD_DUPLICATE_BEHAVIOR_SPECIFIER, + ) for positional in ["x", "1-2-3", "-2-", "1--3"]: self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ configs: foo: description: comment arg_vartype: String source: cli positional: %s - """ % (positional)), idl.errors.ERROR_ID_BAD_NUMERIC_RANGE) + """ + % (positional) + ), + idl.errors.ERROR_ID_BAD_NUMERIC_RANGE, + ) self.assert_bind_fail( textwrap.dedent(""" @@ -2433,7 +2869,9 @@ class TestBinder(testcase.IDLTestcase): short_name: "bar.baz" arg_vartype: Bool source: cli - """), idl.errors.ERROR_ID_INVALID_SHORT_NAME) + """), + idl.errors.ERROR_ID_INVALID_SHORT_NAME, + ) self.assert_bind_fail( textwrap.dedent(""" @@ -2444,7 +2882,9 @@ class TestBinder(testcase.IDLTestcase): deprecated_short_name: "baz.qux" arg_vartype: Long source: cli - """), idl.errors.ERROR_ID_INVALID_SHORT_NAME) + """), + idl.errors.ERROR_ID_INVALID_SHORT_NAME, + ) # dottedName is not valid as a shortName. self.assert_bind_fail( @@ -2455,7 +2895,9 @@ class TestBinder(testcase.IDLTestcase): arg_vartype: String source: cli positional: 1 - """), idl.errors.ERROR_ID_MISSING_SHORTNAME_FOR_POSITIONAL) + """), + idl.errors.ERROR_ID_MISSING_SHORTNAME_FOR_POSITIONAL, + ) # Invalid shortname using boost::po format directly. self.assert_bind_fail( @@ -2466,19 +2908,26 @@ class TestBinder(testcase.IDLTestcase): arg_vartype: Switch description: comment source: cli - """), idl.errors.ERROR_ID_INVALID_SHORT_NAME) + """), + idl.errors.ERROR_ID_INVALID_SHORT_NAME, + ) # Invalid single names, must be single alpha char. for name in ["foo", "1", ".", ""]: self.assert_bind_fail( - textwrap.dedent(""" + textwrap.dedent( + """ configs: foo: single_name: "%s" arg_vartype: Switch description: comment source: cli - """ % (name)), idl.errors.ERROR_ID_INVALID_SINGLE_NAME) + """ + % (name) + ), + idl.errors.ERROR_ID_INVALID_SINGLE_NAME, + ) # Single names require a valid short name. self.assert_bind_fail( @@ -2489,7 +2938,9 @@ class TestBinder(testcase.IDLTestcase): arg_vartype: Switch description: comment source: cli - """), idl.errors.ERROR_ID_MISSING_SHORT_NAME_WITH_SINGLE_NAME) + """), + idl.errors.ERROR_ID_MISSING_SHORT_NAME_WITH_SINGLE_NAME, + ) def test_feature_flag(self): # type: () -> None @@ -2504,7 +2955,8 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: gToaster default: false shouldBeFCVGated: false - """)) + """) + ) self.assert_bind( textwrap.dedent(""" @@ -2514,7 +2966,8 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: gToaster default: false shouldBeFCVGated: true - """)) + """) + ) # if shouldBeFCVGated: true, feature flag can default to true with a version self.assert_bind( @@ -2526,7 +2979,8 @@ class TestBinder(testcase.IDLTestcase): default: true version: 123 shouldBeFCVGated: true - """)) + """) + ) # if shouldBeFCVGated: false, we do not need a version self.assert_bind( @@ -2537,7 +2991,8 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: gToaster default: true shouldBeFCVGated: false - """)) + """) + ) # if shouldBeFCVGated: true and default: true, a version is required self.assert_bind_fail( @@ -2548,7 +3003,9 @@ class TestBinder(testcase.IDLTestcase): cpp_varname: gToaster default: true shouldBeFCVGated: true - """), idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_TRUE_MISSING_VERSION) + """), + idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_TRUE_MISSING_VERSION, + ) # false is not allowed with a version and shouldBeFCVGated: true self.assert_bind_fail( @@ -2560,7 +3017,9 @@ class TestBinder(testcase.IDLTestcase): default: false version: 123 shouldBeFCVGated: true - """), idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION) + """), + idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION, + ) # false is not allowed with a version and shouldBeFCVGated: false self.assert_bind_fail( @@ -2572,7 +3031,9 @@ class TestBinder(testcase.IDLTestcase): default: false version: 123 shouldBeFCVGated: false - """), idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION) + """), + idl.errors.ERROR_ID_FEATURE_FLAG_DEFAULT_FALSE_HAS_VERSION, + ) # if shouldBeFCVGated is false, a version is not allowed self.assert_bind_fail( @@ -2584,7 +3045,9 @@ class TestBinder(testcase.IDLTestcase): default: true version: 123 shouldBeFCVGated: false - """), idl.errors.ERROR_ID_FEATURE_FLAG_SHOULD_BE_FCV_GATED_FALSE_HAS_VERSION) + """), + idl.errors.ERROR_ID_FEATURE_FLAG_SHOULD_BE_FCV_GATED_FALSE_HAS_VERSION, + ) def test_access_check(self): # type: () -> None @@ -2620,7 +3083,9 @@ class TestBinder(testcase.IDLTestcase): """) # Test none - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2632,10 +3097,13 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """)) + """) + ) # Test simple with access check - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2648,10 +3116,13 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """)) + """) + ) # Test simple with privilege - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2666,7 +3137,8 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """)) + """) + ) self.assert_parse( textwrap.dedent(""" @@ -2688,7 +3160,8 @@ class TestBinder(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) def test_access_check_negative(self): # type: () -> None @@ -2724,7 +3197,8 @@ class TestBinder(testcase.IDLTestcase): # Test simple with bad access check self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2737,11 +3211,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE) + """), + idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE, + ) # Test simple with bad access check with privilege self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2756,11 +3233,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE) + """), + idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE, + ) # Test simple with bad access check with privilege self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2775,10 +3255,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE) + """), + idl.errors.ERROR_ID_UNKOWN_ENUM_VALUE, + ) # Test simple with access check and privileges - self.assert_bind(test_preamble + textwrap.dedent(""" + self.assert_bind( + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2793,11 +3277,13 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """)) + """) + ) # Test simple with privilege with duplicate action_type self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2812,11 +3298,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_DUPLICATE_ACTION_TYPE) + """), + idl.errors.ERROR_ID_DUPLICATE_ACTION_TYPE, + ) # complex with duplicate check self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2830,11 +3319,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_DUPLICATE_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_DUPLICATE_ACCESS_CHECK, + ) # complex with duplicate priv self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2852,11 +3344,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_DUPLICATE_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_DUPLICATE_ACCESS_CHECK, + ) # api_version != "" but not access_check self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: test1: description: foo @@ -2866,10 +3361,14 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string reply_type: reply - """), idl.errors.ERROR_ID_MISSING_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_MISSING_ACCESS_CHECK, + ) def test_query_shape_component_validation(self): - self.assert_bind(self.common_types + textwrap.dedent(""" + self.assert_bind( + self.common_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -2882,10 +3381,12 @@ class TestBinder(testcase.IDLTestcase): field2: type: bool query_shape: parameter - """)) + """) + ) self.assert_bind_fail( - self.common_types + textwrap.dedent(""" + self.common_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -2897,10 +3398,13 @@ class TestBinder(testcase.IDLTestcase): field2: type: bool query_shape: parameter - """), idl.errors.ERROR_ID_FIELD_MUST_DECLARE_SHAPE_LITERAL) + """), + idl.errors.ERROR_ID_FIELD_MUST_DECLARE_SHAPE_LITERAL, + ) self.assert_bind_fail( - self.common_types + textwrap.dedent(""" + self.common_types + + textwrap.dedent(""" structs: struct1: strict: true @@ -2911,7 +3415,9 @@ class TestBinder(testcase.IDLTestcase): field2: type: bool query_shape: parameter - """), idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL) + """), + idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL, + ) # Validating query_shape_anonymize relies on std::string basic_types = textwrap.dedent(""" @@ -2935,7 +3441,9 @@ class TestBinder(testcase.IDLTestcase): internal_only: true is_view: false """) - self.assert_bind(basic_types + textwrap.dedent(""" + self.assert_bind( + basic_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -2948,9 +3456,12 @@ class TestBinder(testcase.IDLTestcase): field2: query_shape: parameter type: bool - """)) + """) + ) - self.assert_bind(basic_types + textwrap.dedent(""" + self.assert_bind( + basic_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -2963,10 +3474,12 @@ class TestBinder(testcase.IDLTestcase): field2: query_shape: parameter type: bool - """)) + """) + ) self.assert_bind_fail( - basic_types + textwrap.dedent(""" + basic_types + + textwrap.dedent(""" structs: struct1: strict: true @@ -2975,10 +3488,13 @@ class TestBinder(testcase.IDLTestcase): field1: query_shape: blah type: string - """), idl.errors.ERROR_ID_QUERY_SHAPE_INVALID_VALUE) + """), + idl.errors.ERROR_ID_QUERY_SHAPE_INVALID_VALUE, + ) self.assert_bind_fail( - basic_types + textwrap.dedent(""" + basic_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -2991,10 +3507,13 @@ class TestBinder(testcase.IDLTestcase): field2: query_shape: parameter type: bool - """), idl.errors.ERROR_ID_INVALID_TYPE_FOR_SHAPIFY) + """), + idl.errors.ERROR_ID_INVALID_TYPE_FOR_SHAPIFY, + ) self.assert_bind_fail( - basic_types + textwrap.dedent(""" + basic_types + + textwrap.dedent(""" structs: struct1: query_shape_component: true @@ -3007,10 +3526,13 @@ class TestBinder(testcase.IDLTestcase): field2: query_shape: parameter type: bool - """), idl.errors.ERROR_ID_INVALID_TYPE_FOR_SHAPIFY) + """), + idl.errors.ERROR_ID_INVALID_TYPE_FOR_SHAPIFY, + ) self.assert_bind_fail( - basic_types + textwrap.dedent(""" + basic_types + + textwrap.dedent(""" structs: StructZero: strict: true @@ -3019,10 +3541,13 @@ class TestBinder(testcase.IDLTestcase): field1: query_shape: literal type: string - """), idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL) + """), + idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL, + ) self.assert_bind_fail( - basic_types + textwrap.dedent(""" + basic_types + + textwrap.dedent(""" structs: StructZero: strict: true @@ -3039,16 +3564,19 @@ class TestBinder(testcase.IDLTestcase): type: StructZero description: "" query_shape: literal - """), idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL) + """), + idl.errors.ERROR_ID_CANNOT_DECLARE_SHAPE_LITERAL, + ) # pylint: disable=invalid-name - def test_struct_unsafe_dangerous_disable_extra_field_duplicate_checks_negative(self): + def test_struct_unsafe_dangerous_disable_extra_field_duplicate_checks_negative( + self, + ): # type: () -> None """Negative struct tests for unsafe_dangerous_disable_extra_field_duplicate_checks.""" # Setup some common types - test_preamble = self.common_types + \ - textwrap.dedent(""" + test_preamble = self.common_types + textwrap.dedent(""" structs: danger: description: foo @@ -3060,7 +3588,8 @@ class TestBinder(testcase.IDLTestcase): # Test strict and unsafe_dangerous_disable_extra_field_duplicate_checks are not allowed self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" danger1: @@ -3069,11 +3598,15 @@ class TestBinder(testcase.IDLTestcase): unsafe_dangerous_disable_extra_field_duplicate_checks: true fields: foo: string - """)), idl.errors.ERROR_ID_STRICT_AND_DISABLE_CHECK_NOT_ALLOWED) + """), + ), + idl.errors.ERROR_ID_STRICT_AND_DISABLE_CHECK_NOT_ALLOWED, + ) # Test inheritance is prohibited through structs self.assert_bind_fail( - test_preamble + indent_text( + test_preamble + + indent_text( 1, textwrap.dedent(""" danger2: @@ -3082,11 +3615,15 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string d1: danger - """)), idl.errors.ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED) + """), + ), + idl.errors.ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED, + ) # Test inheritance is prohibited through commands self.assert_bind_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: dangerc: description: foo @@ -3097,9 +3634,10 @@ class TestBinder(testcase.IDLTestcase): fields: foo: string d1: danger - """), idl.errors.ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED) + """), + idl.errors.ERROR_ID_INHERITANCE_AND_DISABLE_CHECK_NOT_ALLOWED, + ) -if __name__ == '__main__': - +if __name__ == "__main__": unittest.main() diff --git a/buildscripts/idl/tests/test_compatibility.py b/buildscripts/idl/tests/test_compatibility.py index caff8b41df9..4f29e97bca7 100644 --- a/buildscripts/idl/tests/test_compatibility.py +++ b/buildscripts/idl/tests/test_compatibility.py @@ -48,7 +48,11 @@ class TestIDLCompatibilityChecker(unittest.TestCase): self.assertFalse( idl_check_compatibility.check_compatibility( path.join(dir_path, "compatibility_test_pass/old"), - path.join(dir_path, "compatibility_test_pass/new"), ["src"], ["src"]).has_errors()) + path.join(dir_path, "compatibility_test_pass/new"), + ["src"], + ["src"], + ).has_errors() + ) def test_should_abort(self): """Tests that invalid old and new IDL commands should cause script to abort.""" @@ -56,30 +60,58 @@ class TestIDLCompatibilityChecker(unittest.TestCase): # Test that when old command has a reply field with an invalid reply type, the script aborts. with self.assertRaises(SystemExit): idl_check_compatibility.check_compatibility( - path.join(dir_path, "compatibility_test_fail/abort/invalid_reply_field_type"), - path.join(dir_path, "compatibility_test_fail/abort/valid_reply_field_type"), - ["src"], ["src"]) + path.join( + dir_path, "compatibility_test_fail/abort/invalid_reply_field_type" + ), + path.join( + dir_path, "compatibility_test_fail/abort/valid_reply_field_type" + ), + ["src"], + ["src"], + ) # Test that when new command has a reply field with an invalid reply type, the script aborts. with self.assertRaises(SystemExit): idl_check_compatibility.check_compatibility( - path.join(dir_path, "compatibility_test_fail/abort/valid_reply_field_type"), - path.join(dir_path, "compatibility_test_fail/abort/invalid_reply_field_type"), - ["src"], ["src"]) + path.join( + dir_path, "compatibility_test_fail/abort/valid_reply_field_type" + ), + path.join( + dir_path, "compatibility_test_fail/abort/invalid_reply_field_type" + ), + ["src"], + ["src"], + ) # Test that when new command has a parameter with an invalid type, the script aborts. with self.assertRaises(SystemExit): idl_check_compatibility.check_compatibility( - path.join(dir_path, "compatibility_test_fail/abort/invalid_command_parameter_type"), - path.join(dir_path, "compatibility_test_fail/abort/valid_command_parameter_type"), - ["src"], ["src"]) + path.join( + dir_path, + "compatibility_test_fail/abort/invalid_command_parameter_type", + ), + path.join( + dir_path, + "compatibility_test_fail/abort/valid_command_parameter_type", + ), + ["src"], + ["src"], + ) # Test that when new command has a parameter with an invalid type, the script aborts. with self.assertRaises(SystemExit): idl_check_compatibility.check_compatibility( - path.join(dir_path, "compatibility_test_fail/abort/valid_command_parameter_type"), - path.join(dir_path, "compatibility_test_fail/abort/invalid_command_parameter_type"), - ["src"], ["src"]) + path.join( + dir_path, + "compatibility_test_fail/abort/valid_command_parameter_type", + ), + path.join( + dir_path, + "compatibility_test_fail/abort/invalid_command_parameter_type", + ), + ["src"], + ["src"], + ) # pylint: disable=invalid-name def test_newly_added_commands_should_fail(self): @@ -87,58 +119,94 @@ class TestIDLCompatibilityChecker(unittest.TestCase): dir_path = path.dirname(path.realpath(__file__)) error_collection = idl_check_compatibility.check_compatibility( path.join(dir_path, "compatibility_test_fail/newly_added_commands"), - path.join(dir_path, "compatibility_test_fail/newly_added_commands"), ["src"], ["src"]) + path.join(dir_path, "compatibility_test_fail/newly_added_commands"), + ["src"], + ["src"], + ) self.assertTrue(error_collection.has_errors()) self.assertEqual(error_collection.count(), 6) - new_parameter_no_unstable_field_error = error_collection.get_error_by_command_name( - "newCommandParameterNoUnstableField") - self.assertTrue(new_parameter_no_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY) + new_parameter_no_unstable_field_error = ( + error_collection.get_error_by_command_name( + "newCommandParameterNoUnstableField" + ) + ) + self.assertTrue( + new_parameter_no_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY + ) self.assertRegex( - str(new_parameter_no_unstable_field_error), "newCommandParameterNoUnstableField") + str(new_parameter_no_unstable_field_error), + "newCommandParameterNoUnstableField", + ) new_reply_no_unstable_field_error = error_collection.get_error_by_command_name( - "newCommandReplyNoUnstableField") - self.assertTrue(new_reply_no_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY) - self.assertRegex(str(new_reply_no_unstable_field_error), "newCommandReplyNoUnstableField") + "newCommandReplyNoUnstableField" + ) + self.assertTrue( + new_reply_no_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY + ) + self.assertRegex( + str(new_reply_no_unstable_field_error), "newCommandReplyNoUnstableField" + ) - new_command_type_struct_no_unstable_field_error = error_collection.get_error_by_command_name( - "newCommandTypeStructFieldNoUnstableField") - self.assertTrue(new_command_type_struct_no_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY) + new_command_type_struct_no_unstable_field_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeStructFieldNoUnstableField" + ) + ) + self.assertTrue( + new_command_type_struct_no_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY + ) self.assertRegex( str(new_command_type_struct_no_unstable_field_error), - "newCommandTypeStructFieldNoUnstableField") + "newCommandTypeStructFieldNoUnstableField", + ) - new_parameter_bson_serialization_type_any_error = error_collection.get_error_by_command_name( - "newCommandParameterBsonSerializationTypeAny") + new_parameter_bson_serialization_type_any_error = ( + error_collection.get_error_by_command_name( + "newCommandParameterBsonSerializationTypeAny" + ) + ) self.assertTrue( - new_parameter_bson_serialization_type_any_error.error_id == idl_compatibility_errors. - ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + new_parameter_bson_serialization_type_any_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(new_parameter_bson_serialization_type_any_error), - "newCommandParameterBsonSerializationTypeAny") + "newCommandParameterBsonSerializationTypeAny", + ) - new_reply_bson_serialization_type_any_error = error_collection.get_error_by_command_name( - "newCommandReplyBsonSerializationTypeAny") + new_reply_bson_serialization_type_any_error = ( + error_collection.get_error_by_command_name( + "newCommandReplyBsonSerializationTypeAny" + ) + ) self.assertTrue( - new_reply_bson_serialization_type_any_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + new_reply_bson_serialization_type_any_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(new_reply_bson_serialization_type_any_error), - "newCommandReplyBsonSerializationTypeAny") + "newCommandReplyBsonSerializationTypeAny", + ) - new_command_type_struct_bson_serialization_type_any_error = error_collection.get_error_by_command_name( - "newCommandTypeStructFieldBsonSerializationTypeAny") + new_command_type_struct_bson_serialization_type_any_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeStructFieldBsonSerializationTypeAny" + ) + ) self.assertTrue( - new_command_type_struct_bson_serialization_type_any_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + new_command_type_struct_bson_serialization_type_any_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(new_command_type_struct_bson_serialization_type_any_error), - "newCommandTypeStructFieldBsonSerializationTypeAny") + "newCommandTypeStructFieldBsonSerializationTypeAny", + ) # pylint: disable=invalid-name def test_should_fail(self): @@ -146,1403 +214,2356 @@ class TestIDLCompatibilityChecker(unittest.TestCase): dir_path = path.dirname(path.realpath(__file__)) error_collection = idl_check_compatibility.check_compatibility( path.join(dir_path, "compatibility_test_fail/old"), - path.join(dir_path, "compatibility_test_fail/new"), ["src"], ["src"]) + path.join(dir_path, "compatibility_test_fail/new"), + ["src"], + ["src"], + ) self.assertTrue(error_collection.has_errors()) invalid_api_version_new_error = error_collection.get_error_by_command_name( - "invalidAPIVersionNew") - self.assertTrue(invalid_api_version_new_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_INVALID_API_VERSION) + "invalidAPIVersionNew" + ) + self.assertTrue( + invalid_api_version_new_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_INVALID_API_VERSION + ) self.assertRegex(str(invalid_api_version_new_error), "invalidAPIVersionNew") duplicate_command_new_error = error_collection.get_error_by_command_name( - "duplicateCommandNew") - self.assertTrue(duplicate_command_new_error.error_id == - idl_compatibility_errors.ERROR_ID_DUPLICATE_COMMAND_NAME) + "duplicateCommandNew" + ) + self.assertTrue( + duplicate_command_new_error.error_id + == idl_compatibility_errors.ERROR_ID_DUPLICATE_COMMAND_NAME + ) self.assertRegex(str(duplicate_command_new_error), "duplicateCommandNew") invalid_api_version_old_error = error_collection.get_error_by_command_name( - "invalidAPIVersionOld") - self.assertTrue(invalid_api_version_old_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_INVALID_API_VERSION) + "invalidAPIVersionOld" + ) + self.assertTrue( + invalid_api_version_old_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_INVALID_API_VERSION + ) self.assertRegex(str(invalid_api_version_old_error), "invalidAPIVersionOld") duplicate_command_old_error = error_collection.get_error_by_command_name( - "duplicateCommandOld") - self.assertTrue(duplicate_command_old_error.error_id == - idl_compatibility_errors.ERROR_ID_DUPLICATE_COMMAND_NAME) + "duplicateCommandOld" + ) + self.assertTrue( + duplicate_command_old_error.error_id + == idl_compatibility_errors.ERROR_ID_DUPLICATE_COMMAND_NAME + ) self.assertRegex(str(duplicate_command_old_error), "duplicateCommandOld") removed_command_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND) + idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND + ) self.assertRegex(str(removed_command_error), "removedCommand") strict_false_to_true_command_error = error_collection.get_error_by_command_name( - "strictFalseToTrueCommand") - self.assertTrue(strict_false_to_true_command_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_STRICT_TRUE_ERROR) - self.assertRegex(str(strict_false_to_true_command_error), "strictFalseToTrueCommand") + "strictFalseToTrueCommand" + ) + self.assertTrue( + strict_false_to_true_command_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_STRICT_TRUE_ERROR + ) + self.assertRegex( + str(strict_false_to_true_command_error), "strictFalseToTrueCommand" + ) removed_command_parameter_error = error_collection.get_error_by_command_name( - "removedCommandParameter") - self.assertTrue(removed_command_parameter_error.error_id == - idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND_PARAMETER) - self.assertRegex(str(removed_command_parameter_error), "removedCommandParameter") - - added_required_command_parameter_error = error_collection.get_error_by_command_name( - "addedNewCommandParameterRequired") - self.assertTrue(added_required_command_parameter_error.error_id == - idl_compatibility_errors.ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER) + "removedCommandParameter" + ) + self.assertTrue( + removed_command_parameter_error.error_id + == idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND_PARAMETER + ) self.assertRegex( - str(added_required_command_parameter_error), "addedNewCommandParameterRequired") + str(removed_command_parameter_error), "removedCommandParameter" + ) + + added_required_command_parameter_error = ( + error_collection.get_error_by_command_name( + "addedNewCommandParameterRequired" + ) + ) + self.assertTrue( + added_required_command_parameter_error.error_id + == idl_compatibility_errors.ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER + ) + self.assertRegex( + str(added_required_command_parameter_error), + "addedNewCommandParameterRequired", + ) command_parameter_unstable_error = error_collection.get_error_by_command_name( - "commandParameterUnstable") - self.assertTrue(command_parameter_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE) - self.assertRegex(str(command_parameter_unstable_error), "commandParameterUnstable") + "commandParameterUnstable" + ) + self.assertTrue( + command_parameter_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE + ) + self.assertRegex( + str(command_parameter_unstable_error), "commandParameterUnstable" + ) command_parameter_internal_error = error_collection.get_error_by_command_name( - "commandParameterInternal") - self.assertTrue(command_parameter_internal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE) - self.assertRegex(str(command_parameter_internal_error), "commandParameterInternal") - - command_parameter_stable_required_no_default_error = error_collection.get_error_by_command_name( - "commandParameterStableRequiredNoDefault") + "commandParameterInternal" + ) self.assertTrue( - command_parameter_stable_required_no_default_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_STABLE_REQUIRED_NO_DEFAULT) + command_parameter_internal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE + ) + self.assertRegex( + str(command_parameter_internal_error), "commandParameterInternal" + ) + + command_parameter_stable_required_no_default_error = ( + error_collection.get_error_by_command_name( + "commandParameterStableRequiredNoDefault" + ) + ) + self.assertTrue( + command_parameter_stable_required_no_default_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_STABLE_REQUIRED_NO_DEFAULT + ) self.assertRegex( str(command_parameter_stable_required_no_default_error), - "commandParameterStableRequiredNoDefault") + "commandParameterStableRequiredNoDefault", + ) command_parameter_required_error = error_collection.get_error_by_command_name( - "commandParameterRequired") - self.assertTrue(command_parameter_required_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_REQUIRED) - self.assertRegex(str(command_parameter_required_error), "commandParameterRequired") - - old_command_parameter_type_bson_any_error = error_collection.get_error_by_command_name( - "oldCommandParameterTypeBsonSerializationAny") + "commandParameterRequired" + ) self.assertTrue( - old_command_parameter_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_parameter_required_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_REQUIRED + ) + self.assertRegex( + str(command_parameter_required_error), "commandParameterRequired" + ) + + old_command_parameter_type_bson_any_error = ( + error_collection.get_error_by_command_name( + "oldCommandParameterTypeBsonSerializationAny" + ) + ) + self.assertTrue( + old_command_parameter_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(old_command_parameter_type_bson_any_error), - "oldCommandParameterTypeBsonSerializationAny") + "oldCommandParameterTypeBsonSerializationAny", + ) - new_command_parameter_type_bson_any_error = error_collection.get_error_by_command_name( - "newCommandParameterTypeBsonSerializationAny") + new_command_parameter_type_bson_any_error = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeBsonSerializationAny" + ) + ) self.assertTrue( - new_command_parameter_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + new_command_parameter_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(new_command_parameter_type_bson_any_error), - "newCommandParameterTypeBsonSerializationAny") + "newCommandParameterTypeBsonSerializationAny", + ) - old_param_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "oldParamTypeBsonAnyAllowList") + old_param_type_bson_any_allow_list_error = ( + error_collection.get_error_by_command_name("oldParamTypeBsonAnyAllowList") + ) self.assertTrue( - old_param_type_bson_any_allow_list_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + old_param_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(old_param_type_bson_any_allow_list_error), "oldParamTypeBsonAnyAllowList") + str(old_param_type_bson_any_allow_list_error), + "oldParamTypeBsonAnyAllowList", + ) - new_param_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "newParamTypeBsonAnyAllowList") + new_param_type_bson_any_allow_list_error = ( + error_collection.get_error_by_command_name("newParamTypeBsonAnyAllowList") + ) self.assertTrue( - new_param_type_bson_any_allow_list_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + new_param_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(new_param_type_bson_any_allow_list_error), "newParamTypeBsonAnyAllowList") + str(new_param_type_bson_any_allow_list_error), + "newParamTypeBsonAnyAllowList", + ) - command_parameter_type_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "commandParameterTypeBsonSerializationAnyNotAllowed") + command_parameter_type_bson_any_not_allowed_error = ( + error_collection.get_error_by_command_name( + "commandParameterTypeBsonSerializationAnyNotAllowed" + ) + ) self.assertTrue( - command_parameter_type_bson_any_not_allowed_error.error_id == idl_compatibility_errors. - ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + command_parameter_type_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(command_parameter_type_bson_any_not_allowed_error), - "commandParameterTypeBsonSerializationAnyNotAllowed") + "commandParameterTypeBsonSerializationAnyNotAllowed", + ) - command_parameter_cpp_type_not_equal_error = error_collection.get_error_by_command_name( - "commandParameterCppTypeNotEqual") - self.assertTrue(command_parameter_cpp_type_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL) + command_parameter_cpp_type_not_equal_error = ( + error_collection.get_error_by_command_name( + "commandParameterCppTypeNotEqual" + ) + ) + self.assertTrue( + command_parameter_cpp_type_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL + ) self.assertRegex( - str(command_parameter_cpp_type_not_equal_error), "commandParameterCppTypeNotEqual") + str(command_parameter_cpp_type_not_equal_error), + "commandParameterCppTypeNotEqual", + ) - command_parameter_serializer_not_equal_error = error_collection.get_error_by_command_name( - "commandParameterSerializerNotEqual") - self.assertEqual(command_parameter_serializer_not_equal_error.error_id, - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_SERIALIZER_NOT_EQUAL) + command_parameter_serializer_not_equal_error = ( + error_collection.get_error_by_command_name( + "commandParameterSerializerNotEqual" + ) + ) + self.assertEqual( + command_parameter_serializer_not_equal_error.error_id, + idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_SERIALIZER_NOT_EQUAL, + ) self.assertRegex( - str(command_parameter_serializer_not_equal_error), "commandParameterSerializerNotEqual") + str(command_parameter_serializer_not_equal_error), + "commandParameterSerializerNotEqual", + ) - command_parameter_deserializer_not_equal_error = error_collection.get_error_by_command_name( - "commandParameterDeserializerNotEqual") - self.assertTrue(command_parameter_deserializer_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_DESERIALIZER_NOT_EQUAL) + command_parameter_deserializer_not_equal_error = ( + error_collection.get_error_by_command_name( + "commandParameterDeserializerNotEqual" + ) + ) + self.assertTrue( + command_parameter_deserializer_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_DESERIALIZER_NOT_EQUAL + ) self.assertRegex( str(command_parameter_deserializer_not_equal_error), - "commandParameterDeserializerNotEqual") + "commandParameterDeserializerNotEqual", + ) - old_command_parameter_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "oldCommandParamTypeBsonAnyUnstable") + old_command_parameter_type_bson_any_unstable_error = ( + error_collection.get_error_by_command_name( + "oldCommandParamTypeBsonAnyUnstable" + ) + ) self.assertTrue( - old_command_parameter_type_bson_any_unstable_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + old_command_parameter_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(old_command_parameter_type_bson_any_unstable_error), - "oldCommandParamTypeBsonAnyUnstable") + "oldCommandParamTypeBsonAnyUnstable", + ) - new_command_parameter_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "newCommandParamTypeBsonAnyUnstable") + new_command_parameter_type_bson_any_unstable_error = ( + error_collection.get_error_by_command_name( + "newCommandParamTypeBsonAnyUnstable" + ) + ) self.assertTrue( - new_command_parameter_type_bson_any_unstable_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + new_command_parameter_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(new_command_parameter_type_bson_any_unstable_error), - "newCommandParamTypeBsonAnyUnstable") + "newCommandParamTypeBsonAnyUnstable", + ) - command_parameter_type_bson_any_not_allowed_unstable_error = error_collection.get_error_by_command_name( - "commandParamTypeBsonAnyNotAllowedUnstable") - self.assertTrue(command_parameter_type_bson_any_not_allowed_unstable_error.error_id == - idl_compatibility_errors. - ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + command_parameter_type_bson_any_not_allowed_unstable_error = ( + error_collection.get_error_by_command_name( + "commandParamTypeBsonAnyNotAllowedUnstable" + ) + ) + self.assertTrue( + command_parameter_type_bson_any_not_allowed_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(command_parameter_type_bson_any_not_allowed_unstable_error), - "commandParamTypeBsonAnyNotAllowedUnstable") + "commandParamTypeBsonAnyNotAllowedUnstable", + ) - command_parameter_cpp_type_not_equal_unstable_error = error_collection.get_error_by_command_name( - "commandParameterCppTypeNotEqualUnstable") - self.assertTrue(command_parameter_cpp_type_not_equal_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL) + command_parameter_cpp_type_not_equal_unstable_error = ( + error_collection.get_error_by_command_name( + "commandParameterCppTypeNotEqualUnstable" + ) + ) + self.assertTrue( + command_parameter_cpp_type_not_equal_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CPP_TYPE_NOT_EQUAL + ) self.assertRegex( str(command_parameter_cpp_type_not_equal_unstable_error), - "commandParameterCppTypeNotEqualUnstable") + "commandParameterCppTypeNotEqualUnstable", + ) parameter_field_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariantUnstable", idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertTrue(parameter_field_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex( - str(parameter_field_type_bson_any_with_variant_unstable_error), - "parameterFieldTypeBsonAnyWithVariantUnstable") - - parameter_field_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariantUnstable", idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertTrue(parameter_field_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex( - str(parameter_field_type_bson_any_with_variant_unstable_error), - "parameterFieldTypeBsonAnyWithVariantUnstable") - - newly_added_param_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "newlyAddedParamBsonAnyNotAllowed") + "parameterFieldTypeBsonAnyWithVariantUnstable", + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - newly_added_param_bson_any_not_allowed_error.error_id == idl_compatibility_errors. - ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + parameter_field_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(newly_added_param_bson_any_not_allowed_error), "newlyAddedParamBsonAnyNotAllowed") + str(parameter_field_type_bson_any_with_variant_unstable_error), + "parameterFieldTypeBsonAnyWithVariantUnstable", + ) - new_command_parameter_type_enum_not_superset = error_collection.get_error_by_command_name( - "newCommandParameterTypeEnumNotSuperset") - self.assertTrue(new_command_parameter_type_enum_not_superset.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) + parameter_field_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( + "parameterFieldTypeBsonAnyWithVariantUnstable", + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) + self.assertTrue( + parameter_field_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(parameter_field_type_bson_any_with_variant_unstable_error), + "parameterFieldTypeBsonAnyWithVariantUnstable", + ) + + newly_added_param_bson_any_not_allowed_error = ( + error_collection.get_error_by_command_name( + "newlyAddedParamBsonAnyNotAllowed" + ) + ) + self.assertTrue( + newly_added_param_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) + self.assertRegex( + str(newly_added_param_bson_any_not_allowed_error), + "newlyAddedParamBsonAnyNotAllowed", + ) + + new_command_parameter_type_enum_not_superset = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeEnumNotSuperset" + ) + ) + self.assertTrue( + new_command_parameter_type_enum_not_superset.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_parameter_type_enum_not_superset), - "newCommandParameterTypeEnumNotSuperset") + "newCommandParameterTypeEnumNotSuperset", + ) - new_command_parameter_type_not_enum = error_collection.get_error_by_command_name( - "newCommandParameterTypeNotEnum") - self.assertTrue(new_command_parameter_type_not_enum.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_ENUM) - self.assertRegex(str(new_command_parameter_type_not_enum), "newCommandParameterTypeNotEnum") - - new_command_parameter_type_not_struct = error_collection.get_error_by_command_name( - "newCommandParameterTypeNotStruct") - self.assertTrue(new_command_parameter_type_not_struct.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_STRUCT) + new_command_parameter_type_not_enum = ( + error_collection.get_error_by_command_name("newCommandParameterTypeNotEnum") + ) + self.assertTrue( + new_command_parameter_type_not_enum.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_ENUM + ) self.assertRegex( - str(new_command_parameter_type_not_struct), "newCommandParameterTypeNotStruct") + str(new_command_parameter_type_not_enum), "newCommandParameterTypeNotEnum" + ) - new_command_parameter_type_enum_or_struct_one = error_collection.get_error_by_command_name( - "newCommandParameterTypeEnumOrStructOne") - self.assertTrue(new_command_parameter_type_enum_or_struct_one.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT) + new_command_parameter_type_not_struct = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeNotStruct" + ) + ) + self.assertTrue( + new_command_parameter_type_not_struct.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_STRUCT + ) + self.assertRegex( + str(new_command_parameter_type_not_struct), + "newCommandParameterTypeNotStruct", + ) + + new_command_parameter_type_enum_or_struct_one = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeEnumOrStructOne" + ) + ) + self.assertTrue( + new_command_parameter_type_enum_or_struct_one.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT + ) self.assertRegex( str(new_command_parameter_type_enum_or_struct_one), - "newCommandParameterTypeEnumOrStructOne") + "newCommandParameterTypeEnumOrStructOne", + ) - new_command_parameter_type_enum_or_struct_two = error_collection.get_error_by_command_name( - "newCommandParameterTypeEnumOrStructTwo") - self.assertTrue(new_command_parameter_type_enum_or_struct_two.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT) + new_command_parameter_type_enum_or_struct_two = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeEnumOrStructTwo" + ) + ) + self.assertTrue( + new_command_parameter_type_enum_or_struct_two.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_ENUM_OR_STRUCT + ) self.assertRegex( str(new_command_parameter_type_enum_or_struct_two), - "newCommandParameterTypeEnumOrStructTwo") + "newCommandParameterTypeEnumOrStructTwo", + ) - new_command_parameter_type_bson_not_superset = error_collection.get_error_by_command_name( - "newCommandParameterTypeBsonNotSuperset") - self.assertTrue(new_command_parameter_type_bson_not_superset.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) + new_command_parameter_type_bson_not_superset = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeBsonNotSuperset" + ) + ) + self.assertTrue( + new_command_parameter_type_bson_not_superset.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_parameter_type_bson_not_superset), - "newCommandParameterTypeBsonNotSuperset") + "newCommandParameterTypeBsonNotSuperset", + ) - new_command_parameter_type_recursive_one_error = error_collection.get_error_by_command_name( - "newCommandParameterTypeStructRecursiveOne") - self.assertTrue(new_command_parameter_type_recursive_one_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE) + new_command_parameter_type_recursive_one_error = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeStructRecursiveOne" + ) + ) + self.assertTrue( + new_command_parameter_type_recursive_one_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_UNSTABLE + ) self.assertRegex( str(new_command_parameter_type_recursive_one_error), - "newCommandParameterTypeStructRecursiveOne") + "newCommandParameterTypeStructRecursiveOne", + ) - new_command_parameter_type_recursive_two_error = error_collection.get_error_by_command_name( - "newCommandParameterTypeStructRecursiveTwo") - self.assertTrue(new_command_parameter_type_recursive_two_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) + new_command_parameter_type_recursive_two_error = ( + error_collection.get_error_by_command_name( + "newCommandParameterTypeStructRecursiveTwo" + ) + ) + self.assertTrue( + new_command_parameter_type_recursive_two_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_parameter_type_recursive_two_error), - "newCommandParameterTypeStructRecursiveTwo") + "newCommandParameterTypeStructRecursiveTwo", + ) new_reply_field_unstable_error = error_collection.get_error_by_command_name( - "newReplyFieldUnstable") - self.assertTrue(new_reply_field_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE) + "newReplyFieldUnstable" + ) + self.assertTrue( + new_reply_field_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE + ) self.assertRegex(str(new_reply_field_unstable_error), "newReplyFieldUnstable") new_reply_field_internal_error = error_collection.get_error_by_command_name( - "newReplyFieldInternal") - self.assertTrue(new_reply_field_internal_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE) + "newReplyFieldInternal" + ) + self.assertTrue( + new_reply_field_internal_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE + ) self.assertRegex(str(new_reply_field_internal_error), "newReplyFieldInternal") new_reply_field_optional_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL + ) self.assertRegex(str(new_reply_field_optional_error), "newReplyFieldOptional") new_reply_field_missing_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_MISSING) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_MISSING + ) self.assertRegex(str(new_reply_field_missing_error), "newReplyFieldMissing") - imported_reply_field_unstable_error = error_collection.get_error_by_command_name( - "importedReplyCommand") - self.assertTrue(imported_reply_field_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE) - self.assertRegex(str(imported_reply_field_unstable_error), "importedReplyCommand") - - new_reply_field_type_enum_not_subset_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeEnumNotSubset") - self.assertTrue(new_reply_field_type_enum_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + imported_reply_field_unstable_error = ( + error_collection.get_error_by_command_name("importedReplyCommand") + ) + self.assertTrue( + imported_reply_field_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE + ) self.assertRegex( - str(new_reply_field_type_enum_not_subset_error), "newReplyFieldTypeEnumNotSubset") + str(imported_reply_field_unstable_error), "importedReplyCommand" + ) + + new_reply_field_type_enum_not_subset_error = ( + error_collection.get_error_by_command_name("newReplyFieldTypeEnumNotSubset") + ) + self.assertTrue( + new_reply_field_type_enum_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_field_type_enum_not_subset_error), + "newReplyFieldTypeEnumNotSubset", + ) new_reply_field_type_not_enum_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_ENUM) - self.assertRegex(str(new_reply_field_type_not_enum_error), "newReplyFieldTypeNotEnum") + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_ENUM + ) + self.assertRegex( + str(new_reply_field_type_not_enum_error), "newReplyFieldTypeNotEnum" + ) new_reply_field_type_not_struct_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_STRUCT) - self.assertRegex(str(new_reply_field_type_not_struct_error), "newReplyFieldTypeNotStruct") - - new_reply_field_type_enum_or_struct_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_ENUM_OR_STRUCT) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_NOT_STRUCT + ) self.assertRegex( - str(new_reply_field_type_enum_or_struct_error), "newReplyFieldTypeEnumOrStruct") + str(new_reply_field_type_not_struct_error), "newReplyFieldTypeNotStruct" + ) - new_reply_field_type_bson_not_subset_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeBsonNotSubset") - self.assertTrue(new_reply_field_type_bson_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_type_enum_or_struct_error = ( + error_collection.get_error_by_error_id( + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_TYPE_ENUM_OR_STRUCT + ) + ) self.assertRegex( - str(new_reply_field_type_bson_not_subset_error), "newReplyFieldTypeBsonNotSubset") + str(new_reply_field_type_enum_or_struct_error), + "newReplyFieldTypeEnumOrStruct", + ) - new_reply_field_type_bson_not_subset_two_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeBsonNotSubsetTwo") - self.assertTrue(new_reply_field_type_bson_not_subset_two_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_type_bson_not_subset_error = ( + error_collection.get_error_by_command_name("newReplyFieldTypeBsonNotSubset") + ) + self.assertTrue( + new_reply_field_type_bson_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_field_type_bson_not_subset_error), + "newReplyFieldTypeBsonNotSubset", + ) + + new_reply_field_type_bson_not_subset_two_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldTypeBsonNotSubsetTwo" + ) + ) + self.assertTrue( + new_reply_field_type_bson_not_subset_two_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_type_bson_not_subset_two_error), - "newReplyFieldTypeBsonNotSubsetTwo") + "newReplyFieldTypeBsonNotSubsetTwo", + ) - old_reply_field_type_bson_any_error = error_collection.get_error_by_command_name( - "oldReplyFieldTypeBsonAny") - self.assertTrue(old_reply_field_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(old_reply_field_type_bson_any_error), "oldReplyFieldTypeBsonAny") - - new_reply_field_type_bson_any_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeBsonAny") - self.assertTrue(new_reply_field_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(new_reply_field_type_bson_any_error), "newReplyFieldTypeBsonAny") - - old_reply_field_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "oldReplyFieldTypeBsonAnyAllowList") + old_reply_field_type_bson_any_error = ( + error_collection.get_error_by_command_name("oldReplyFieldTypeBsonAny") + ) self.assertTrue( - old_reply_field_type_bson_any_allow_list_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + old_reply_field_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(old_reply_field_type_bson_any_error), "oldReplyFieldTypeBsonAny" + ) + + new_reply_field_type_bson_any_error = ( + error_collection.get_error_by_command_name("newReplyFieldTypeBsonAny") + ) + self.assertTrue( + new_reply_field_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(new_reply_field_type_bson_any_error), "newReplyFieldTypeBsonAny" + ) + + old_reply_field_type_bson_any_allow_list_error = ( + error_collection.get_error_by_command_name( + "oldReplyFieldTypeBsonAnyAllowList" + ) + ) + self.assertTrue( + old_reply_field_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(old_reply_field_type_bson_any_allow_list_error), - "oldReplyFieldTypeBsonAnyAllowList") + "oldReplyFieldTypeBsonAnyAllowList", + ) - new_reply_field_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeBsonAnyAllowList") + new_reply_field_type_bson_any_allow_list_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldTypeBsonAnyAllowList" + ) + ) self.assertTrue( - new_reply_field_type_bson_any_allow_list_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + new_reply_field_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(new_reply_field_type_bson_any_allow_list_error), - "newReplyFieldTypeBsonAnyAllowList") + "newReplyFieldTypeBsonAnyAllowList", + ) - reply_field_type_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "replyFieldTypeBsonAnyNotAllowed") + reply_field_type_bson_any_not_allowed_error = ( + error_collection.get_error_by_command_name( + "replyFieldTypeBsonAnyNotAllowed" + ) + ) self.assertTrue( - reply_field_type_bson_any_not_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + reply_field_type_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( - str(reply_field_type_bson_any_not_allowed_error), "replyFieldTypeBsonAnyNotAllowed") + str(reply_field_type_bson_any_not_allowed_error), + "replyFieldTypeBsonAnyNotAllowed", + ) reply_field_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariant", - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(reply_field_type_bson_any_with_variant_error), "replyFieldTypeBsonAnyWithVariant") + str(reply_field_type_bson_any_with_variant_error), + "replyFieldTypeBsonAnyWithVariant", + ) reply_field_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariant", - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(reply_field_type_bson_any_with_variant_error), "replyFieldTypeBsonAnyWithVariant") + str(reply_field_type_bson_any_with_variant_error), + "replyFieldTypeBsonAnyWithVariant", + ) - old_reply_field_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "oldReplyFieldTypeBsonAnyUnstable") + old_reply_field_type_bson_any_unstable_error = ( + error_collection.get_error_by_command_name( + "oldReplyFieldTypeBsonAnyUnstable" + ) + ) self.assertTrue( - old_reply_field_type_bson_any_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + old_reply_field_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(old_reply_field_type_bson_any_unstable_error), "oldReplyFieldTypeBsonAnyUnstable") + str(old_reply_field_type_bson_any_unstable_error), + "oldReplyFieldTypeBsonAnyUnstable", + ) - new_reply_field_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeBsonAnyUnstable") + new_reply_field_type_bson_any_unstable_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldTypeBsonAnyUnstable" + ) + ) self.assertTrue( - new_reply_field_type_bson_any_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + new_reply_field_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(new_reply_field_type_bson_any_unstable_error), "newReplyFieldTypeBsonAnyUnstable") + str(new_reply_field_type_bson_any_unstable_error), + "newReplyFieldTypeBsonAnyUnstable", + ) - reply_field_type_bson_any_not_allowed_unstable_error = error_collection.get_error_by_command_name( - "replyFieldTypeBsonAnyNotAllowedUnstable") + reply_field_type_bson_any_not_allowed_unstable_error = ( + error_collection.get_error_by_command_name( + "replyFieldTypeBsonAnyNotAllowedUnstable" + ) + ) self.assertTrue( - reply_field_type_bson_any_not_allowed_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + reply_field_type_bson_any_not_allowed_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(reply_field_type_bson_any_not_allowed_unstable_error), - "replyFieldTypeBsonAnyNotAllowedUnstable") + "replyFieldTypeBsonAnyNotAllowedUnstable", + ) reply_field_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariantUnstable", - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(reply_field_type_bson_any_with_variant_unstable_error), - "replyFieldTypeBsonAnyWithVariantUnstable") + "replyFieldTypeBsonAnyWithVariantUnstable", + ) reply_field_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariantUnstable", - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(reply_field_type_bson_any_with_variant_unstable_error), - "replyFieldTypeBsonAnyWithVariantUnstable") + "replyFieldTypeBsonAnyWithVariantUnstable", + ) - reply_field_cpp_type_not_equal_unstable_error = error_collection.get_error_by_command_name( - "replyFieldCppTypeNotEqualUnstable") - self.assertTrue(reply_field_cpp_type_not_equal_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL) - self.assertRegex( - str(reply_field_cpp_type_not_equal_unstable_error), "replyFieldCppTypeNotEqualUnstable") - - newly_added_reply_field_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "newlyAddedReplyFieldTypeBsonAnyNotAllowed") + reply_field_cpp_type_not_equal_unstable_error = ( + error_collection.get_error_by_command_name( + "replyFieldCppTypeNotEqualUnstable" + ) + ) self.assertTrue( - newly_added_reply_field_bson_any_not_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + reply_field_cpp_type_not_equal_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL + ) + self.assertRegex( + str(reply_field_cpp_type_not_equal_unstable_error), + "replyFieldCppTypeNotEqualUnstable", + ) + + newly_added_reply_field_bson_any_not_allowed_error = ( + error_collection.get_error_by_command_name( + "newlyAddedReplyFieldTypeBsonAnyNotAllowed" + ) + ) + self.assertTrue( + newly_added_reply_field_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(newly_added_reply_field_bson_any_not_allowed_error), - "newlyAddedReplyFieldTypeBsonAnyNotAllowed") + "newlyAddedReplyFieldTypeBsonAnyNotAllowed", + ) reply_field_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariantWithArray", - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(reply_field_type_bson_any_with_variant_with_array_error), - "replyFieldTypeBsonAnyWithVariantWithArray") + "replyFieldTypeBsonAnyWithVariantWithArray", + ) reply_field_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( "replyFieldTypeBsonAnyWithVariantWithArray", - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - reply_field_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY) + reply_field_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(reply_field_type_bson_any_with_variant_with_array_error), - "replyFieldTypeBsonAnyWithVariantWithArray") + "replyFieldTypeBsonAnyWithVariantWithArray", + ) parameter_field_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariant", idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + "parameterFieldTypeBsonAnyWithVariant", + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - parameter_field_type_bson_any_with_variant_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + parameter_field_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(parameter_field_type_bson_any_with_variant_error), - "parameterFieldTypeBsonAnyWithVariant") + "parameterFieldTypeBsonAnyWithVariant", + ) parameter_field_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariant", idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + "parameterFieldTypeBsonAnyWithVariant", + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - parameter_field_type_bson_any_with_variant_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + parameter_field_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(parameter_field_type_bson_any_with_variant_error), - "parameterFieldTypeBsonAnyWithVariant") + "parameterFieldTypeBsonAnyWithVariant", + ) parameter_field_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariantWithArray", idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertTrue(parameter_field_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + "parameterFieldTypeBsonAnyWithVariantWithArray", + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) + self.assertTrue( + parameter_field_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(parameter_field_type_bson_any_with_variant_with_array_error), - "parameterFieldTypeBsonAnyWithVariantWithArray") + "parameterFieldTypeBsonAnyWithVariantWithArray", + ) parameter_field_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( - "parameterFieldTypeBsonAnyWithVariantWithArray", idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertTrue(parameter_field_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY) + "parameterFieldTypeBsonAnyWithVariantWithArray", + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) + self.assertTrue( + parameter_field_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(parameter_field_type_bson_any_with_variant_with_array_error), - "parameterFieldTypeBsonAnyWithVariantWithArray") + "parameterFieldTypeBsonAnyWithVariantWithArray", + ) command_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariant", - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(command_type_bson_any_with_variant_error), "commandTypeBsonAnyWithVariant") + str(command_type_bson_any_with_variant_error), + "commandTypeBsonAnyWithVariant", + ) command_type_bson_any_with_variant_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariant", - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(command_type_bson_any_with_variant_error), "commandTypeBsonAnyWithVariant") + str(command_type_bson_any_with_variant_error), + "commandTypeBsonAnyWithVariant", + ) command_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariantWithArray", - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(command_type_bson_any_with_variant_with_array_error), - "commandTypeBsonAnyWithVariantWithArray") + "commandTypeBsonAnyWithVariantWithArray", + ) command_type_bson_any_with_variant_with_array_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariantWithArray", - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(command_type_bson_any_with_variant_with_array_error), - "commandTypeBsonAnyWithVariantWithArray") + "commandTypeBsonAnyWithVariantWithArray", + ) - reply_field_cpp_type_not_equal_error = error_collection.get_error_by_command_name( - "replyFieldCppTypeNotEqual") - self.assertTrue(reply_field_cpp_type_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL) - self.assertRegex(str(reply_field_cpp_type_not_equal_error), "replyFieldCppTypeNotEqual") - - reply_field_serializer_not_equal_error = error_collection.get_error_by_command_name( - "replyFieldSerializerNotEqual") - self.assertTrue(reply_field_serializer_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_SERIALIZER_NOT_EQUAL) + reply_field_cpp_type_not_equal_error = ( + error_collection.get_error_by_command_name("replyFieldCppTypeNotEqual") + ) + self.assertTrue( + reply_field_cpp_type_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CPP_TYPE_NOT_EQUAL + ) self.assertRegex( - str(reply_field_serializer_not_equal_error), "replyFieldSerializerNotEqual") + str(reply_field_cpp_type_not_equal_error), "replyFieldCppTypeNotEqual" + ) - reply_field_deserializer_not_equal_error = error_collection.get_error_by_command_name( - "replyFieldDeserializerNotEqual") - self.assertTrue(reply_field_deserializer_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_DESERIALIZER_NOT_EQUAL) + reply_field_serializer_not_equal_error = ( + error_collection.get_error_by_command_name("replyFieldSerializerNotEqual") + ) + self.assertTrue( + reply_field_serializer_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_SERIALIZER_NOT_EQUAL + ) self.assertRegex( - str(reply_field_deserializer_not_equal_error), "replyFieldDeserializerNotEqual") + str(reply_field_serializer_not_equal_error), "replyFieldSerializerNotEqual" + ) - new_reply_field_type_struct_one_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeStructRecursiveOne") - self.assertTrue(new_reply_field_type_struct_one_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE) + reply_field_deserializer_not_equal_error = ( + error_collection.get_error_by_command_name("replyFieldDeserializerNotEqual") + ) + self.assertTrue( + reply_field_deserializer_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_DESERIALIZER_NOT_EQUAL + ) self.assertRegex( - str(new_reply_field_type_struct_one_error), "newReplyFieldTypeStructRecursiveOne") + str(reply_field_deserializer_not_equal_error), + "replyFieldDeserializerNotEqual", + ) - new_reply_field_type_struct_two_error = error_collection.get_error_by_command_name( - "newReplyFieldTypeStructRecursiveTwo") - self.assertTrue(new_reply_field_type_struct_two_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_type_struct_one_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldTypeStructRecursiveOne" + ) + ) + self.assertTrue( + new_reply_field_type_struct_one_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_UNSTABLE + ) self.assertRegex( - str(new_reply_field_type_struct_two_error), "newReplyFieldTypeStructRecursiveTwo") + str(new_reply_field_type_struct_one_error), + "newReplyFieldTypeStructRecursiveOne", + ) + + new_reply_field_type_struct_two_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldTypeStructRecursiveTwo" + ) + ) + self.assertTrue( + new_reply_field_type_struct_two_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_field_type_struct_two_error), + "newReplyFieldTypeStructRecursiveTwo", + ) new_namespace_not_ignored_error = error_collection.get_error_by_command_name( - "newNamespaceNotIgnored") - self.assertTrue(new_namespace_not_ignored_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE) + "newNamespaceNotIgnored" + ) + self.assertTrue( + new_namespace_not_ignored_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE + ) self.assertRegex(str(new_namespace_not_ignored_error), "newNamespaceNotIgnored") - new_namespace_not_concatenate_with_db_or_uuid_error = error_collection.get_error_by_command_name( - "newNamespaceNotConcatenateWithDbOrUuid") - self.assertTrue(new_namespace_not_concatenate_with_db_or_uuid_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE) + new_namespace_not_concatenate_with_db_or_uuid_error = ( + error_collection.get_error_by_command_name( + "newNamespaceNotConcatenateWithDbOrUuid" + ) + ) + self.assertTrue( + new_namespace_not_concatenate_with_db_or_uuid_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE + ) self.assertRegex( str(new_namespace_not_concatenate_with_db_or_uuid_error), - "newNamespaceNotConcatenateWithDbOrUuid") + "newNamespaceNotConcatenateWithDbOrUuid", + ) - new_namespace_not_concatenate_with_db_error = error_collection.get_error_by_command_name( - "newNamespaceNotConcatenateWithDb") - self.assertTrue(new_namespace_not_concatenate_with_db_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE) + new_namespace_not_concatenate_with_db_error = ( + error_collection.get_error_by_command_name( + "newNamespaceNotConcatenateWithDb" + ) + ) + self.assertTrue( + new_namespace_not_concatenate_with_db_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE + ) self.assertRegex( - str(new_namespace_not_concatenate_with_db_error), "newNamespaceNotConcatenateWithDb") + str(new_namespace_not_concatenate_with_db_error), + "newNamespaceNotConcatenateWithDb", + ) new_namespace_not_type_error = error_collection.get_error_by_command_name( - "newNamespaceNotType") - self.assertTrue(new_namespace_not_type_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE) + "newNamespaceNotType" + ) + self.assertTrue( + new_namespace_not_type_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_NAMESPACE_INCOMPATIBLE + ) self.assertRegex(str(new_namespace_not_type_error), "newNamespaceNotType") - old_type_bson_any_error = error_collection.get_error_by_command_name("oldTypeBsonAny") - self.assertTrue(old_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + old_type_bson_any_error = error_collection.get_error_by_command_name( + "oldTypeBsonAny" + ) + self.assertTrue( + old_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex(str(old_type_bson_any_error), "oldTypeBsonAny") - new_type_bson_any_error = error_collection.get_error_by_command_name("newTypeBsonAny") - self.assertTrue(new_type_bson_any_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + new_type_bson_any_error = error_collection.get_error_by_command_name( + "newTypeBsonAny" + ) + self.assertTrue( + new_type_bson_any_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex(str(new_type_bson_any_error), "newTypeBsonAny") old_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "oldTypeBsonAnyAllowList") - self.assertTrue(old_type_bson_any_allow_list_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(old_type_bson_any_allow_list_error), "oldTypeBsonAnyAllowList") + "oldTypeBsonAnyAllowList" + ) + self.assertTrue( + old_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(old_type_bson_any_allow_list_error), "oldTypeBsonAnyAllowList" + ) new_type_bson_any_allow_list_error = error_collection.get_error_by_command_name( - "newTypeBsonAnyAllowList") - self.assertTrue(new_type_bson_any_allow_list_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(new_type_bson_any_allow_list_error), "newTypeBsonAnyAllowList") + "newTypeBsonAnyAllowList" + ) + self.assertTrue( + new_type_bson_any_allow_list_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(new_type_bson_any_allow_list_error), "newTypeBsonAnyAllowList" + ) type_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "typeBsonAnyNotAllowed") - self.assertTrue(type_bson_any_not_allowed_error.error_id == idl_compatibility_errors. - ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + "typeBsonAnyNotAllowed" + ) + self.assertTrue( + type_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex(str(type_bson_any_not_allowed_error), "typeBsonAnyNotAllowed") command_cpp_type_not_equal_error = error_collection.get_error_by_command_name( - "commandCppTypeNotEqual") - self.assertTrue(command_cpp_type_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL) - self.assertRegex(str(command_cpp_type_not_equal_error), "commandCppTypeNotEqual") + "commandCppTypeNotEqual" + ) + self.assertTrue( + command_cpp_type_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL + ) + self.assertRegex( + str(command_cpp_type_not_equal_error), "commandCppTypeNotEqual" + ) command_serializer_not_equal_error = error_collection.get_error_by_command_name( - "commandSerializerNotEqual") - self.assertTrue(command_serializer_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_SERIALIZER_NOT_EQUAL) - self.assertRegex(str(command_serializer_not_equal_error), "commandSerializerNotEqual") + "commandSerializerNotEqual" + ) + self.assertTrue( + command_serializer_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_SERIALIZER_NOT_EQUAL + ) + self.assertRegex( + str(command_serializer_not_equal_error), "commandSerializerNotEqual" + ) - command_deserializer_not_equal_error = error_collection.get_error_by_command_name( - "commandDeserializerNotEqual") - self.assertTrue(command_deserializer_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_DESERIALIZER_NOT_EQUAL) - self.assertRegex(str(command_deserializer_not_equal_error), "commandDeserializerNotEqual") + command_deserializer_not_equal_error = ( + error_collection.get_error_by_command_name("commandDeserializerNotEqual") + ) + self.assertTrue( + command_deserializer_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_DESERIALIZER_NOT_EQUAL + ) + self.assertRegex( + str(command_deserializer_not_equal_error), "commandDeserializerNotEqual" + ) old_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "oldTypeBsonAnyUnstable") - self.assertTrue(old_type_bson_any_unstable_error.error_id == idl_compatibility_errors. - ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(old_type_bson_any_unstable_error), "oldTypeBsonAnyUnstable") + "oldTypeBsonAnyUnstable" + ) + self.assertTrue( + old_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(old_type_bson_any_unstable_error), "oldTypeBsonAnyUnstable" + ) new_type_bson_any_unstable_error = error_collection.get_error_by_command_name( - "newTypeBsonAnyUnstable") - self.assertTrue(new_type_bson_any_unstable_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) - self.assertRegex(str(new_type_bson_any_unstable_error), "newTypeBsonAnyUnstable") - - type_bson_any_not_allowed_unstable_error = error_collection.get_error_by_command_name( - "typeBsonAnyNotAllowedUnstable") + "newTypeBsonAnyUnstable" + ) self.assertTrue( - type_bson_any_not_allowed_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + new_type_bson_any_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( - str(type_bson_any_not_allowed_unstable_error), "typeBsonAnyNotAllowedUnstable") + str(new_type_bson_any_unstable_error), "newTypeBsonAnyUnstable" + ) - command_cpp_type_not_equal_unstable_error = error_collection.get_error_by_command_name( - "commandCppTypeNotEqualUnstable") - self.assertTrue(command_cpp_type_not_equal_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL) + type_bson_any_not_allowed_unstable_error = ( + error_collection.get_error_by_command_name("typeBsonAnyNotAllowedUnstable") + ) + self.assertTrue( + type_bson_any_not_allowed_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( - str(command_cpp_type_not_equal_unstable_error), "commandCppTypeNotEqualUnstable") + str(type_bson_any_not_allowed_unstable_error), + "typeBsonAnyNotAllowedUnstable", + ) + + command_cpp_type_not_equal_unstable_error = ( + error_collection.get_error_by_command_name("commandCppTypeNotEqualUnstable") + ) + self.assertTrue( + command_cpp_type_not_equal_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_CPP_TYPE_NOT_EQUAL + ) + self.assertRegex( + str(command_cpp_type_not_equal_unstable_error), + "commandCppTypeNotEqualUnstable", + ) command_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariantUnstable", - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_OLD_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(command_type_bson_any_with_variant_unstable_error), - "commandTypeBsonAnyWithVariantUnstable") + "commandTypeBsonAnyWithVariantUnstable", + ) command_type_bson_any_with_variant_unstable_error = error_collection.get_error_by_command_name_and_error_id( "commandTypeBsonAnyWithVariantUnstable", - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY, + ) self.assertTrue( - command_type_bson_any_with_variant_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + command_type_bson_any_with_variant_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) self.assertRegex( str(command_type_bson_any_with_variant_unstable_error), - "commandTypeBsonAnyWithVariantUnstable") + "commandTypeBsonAnyWithVariantUnstable", + ) - newly_added_type_field_bson_any_not_allowed_error = error_collection.get_error_by_command_name( - "newlyAddedTypeFieldBsonAnyNotAllowed") + newly_added_type_field_bson_any_not_allowed_error = ( + error_collection.get_error_by_command_name( + "newlyAddedTypeFieldBsonAnyNotAllowed" + ) + ) self.assertTrue( - newly_added_type_field_bson_any_not_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED) + newly_added_type_field_bson_any_not_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY_NOT_ALLOWED + ) self.assertRegex( str(newly_added_type_field_bson_any_not_allowed_error), - "newlyAddedTypeFieldBsonAnyNotAllowed") + "newlyAddedTypeFieldBsonAnyNotAllowed", + ) new_type_not_enum_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_ENUM) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_ENUM + ) self.assertRegex(str(new_type_not_enum_error), "newTypeNotEnum") new_type_not_struct_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT + ) self.assertRegex(str(new_type_not_struct_error), "newTypeNotStruct") new_type_enum_or_struct_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_ENUM_OR_STRUCT) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_ENUM_OR_STRUCT + ) self.assertRegex(str(new_type_enum_or_struct_error), "newTypeEnumOrStruct") new_type_not_superset_error = error_collection.get_error_by_command_name( - "newTypeNotSuperset") - self.assertTrue(new_type_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) + "newTypeNotSuperset" + ) + self.assertTrue( + new_type_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) self.assertRegex(str(new_type_not_superset_error), "newTypeNotSuperset") new_type_enum_not_superset_error = error_collection.get_error_by_command_name( - "newTypeEnumNotSuperset") - self.assertTrue(new_type_enum_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) - self.assertRegex(str(new_type_enum_not_superset_error), "newTypeEnumNotSuperset") + "newTypeEnumNotSuperset" + ) + self.assertTrue( + new_type_enum_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_type_enum_not_superset_error), "newTypeEnumNotSuperset" + ) new_type_struct_recursive_error = error_collection.get_error_by_command_name( - "newTypeStructRecursive") - self.assertTrue(new_type_struct_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE) + "newTypeStructRecursive" + ) + self.assertTrue( + new_type_struct_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE + ) self.assertRegex(str(new_type_struct_recursive_error), "newTypeStructRecursive") new_type_field_unstable_error = error_collection.get_error_by_command_name( - "newTypeFieldUnstable") - self.assertTrue(new_type_field_unstable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE) + "newTypeFieldUnstable" + ) + self.assertTrue( + new_type_field_unstable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_UNSTABLE + ) self.assertRegex(str(new_type_field_unstable_error), "newTypeFieldUnstable") new_type_field_required_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED + ) self.assertRegex(str(new_type_field_required_error), "newTypeFieldRequired") new_type_field_missing_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_MISSING) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_MISSING + ) self.assertRegex(str(new_type_field_missing_error), "newTypeFieldMissing") new_type_field_added_required_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_REQUIRED) - self.assertRegex(str(new_type_field_added_required_error), "newTypeFieldAddedRequired") + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_REQUIRED + ) + self.assertRegex( + str(new_type_field_added_required_error), "newTypeFieldAddedRequired" + ) new_type_field_stable_required_no_default_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_STABLE_REQUIRED_NO_DEFAULT) + idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_STABLE_REQUIRED_NO_DEFAULT + ) self.assertRegex( str(new_type_field_stable_required_no_default_error), - "newTypeFieldStableRequiredNoDefault") + "newTypeFieldStableRequiredNoDefault", + ) new_reply_field_variant_type_error = error_collection.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE) - self.assertRegex(str(new_reply_field_variant_type_error), "newReplyFieldVariantType") - - new_reply_field_variant_not_subset_error = error_collection.get_error_by_command_name( - "newReplyFieldVariantNotSubset") - self.assertTrue(new_reply_field_variant_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE + ) self.assertRegex( - str(new_reply_field_variant_not_subset_error), "newReplyFieldVariantNotSubset") + str(new_reply_field_variant_type_error), "newReplyFieldVariantType" + ) - new_reply_field_variant_not_subset_two_errors = error_collection.get_all_errors_by_command_name( - "newReplyFieldVariantNotSubsetTwo") + new_reply_field_variant_not_subset_error = ( + error_collection.get_error_by_command_name("newReplyFieldVariantNotSubset") + ) + self.assertTrue( + new_reply_field_variant_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_field_variant_not_subset_error), + "newReplyFieldVariantNotSubset", + ) + + new_reply_field_variant_not_subset_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newReplyFieldVariantNotSubsetTwo" + ) + ) self.assertTrue(len(new_reply_field_variant_not_subset_two_errors) == 2) for error in new_reply_field_variant_not_subset_two_errors: - self.assertTrue(error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) - new_reply_field_variant_recursive_error = error_collection.get_error_by_command_name( - "replyFieldVariantRecursive") - self.assertTrue(new_reply_field_variant_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) - self.assertRegex(str(new_reply_field_variant_recursive_error), "replyFieldVariantRecursive") + new_reply_field_variant_recursive_error = ( + error_collection.get_error_by_command_name("replyFieldVariantRecursive") + ) + self.assertTrue( + new_reply_field_variant_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_field_variant_recursive_error), "replyFieldVariantRecursive" + ) - new_reply_field_variant_struct_not_subset_error = error_collection.get_error_by_command_name( - "newReplyFieldVariantStructNotSubset") - self.assertTrue(new_reply_field_variant_struct_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + new_reply_field_variant_struct_not_subset_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldVariantStructNotSubset" + ) + ) + self.assertTrue( + new_reply_field_variant_struct_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_struct_not_subset_error), - "newReplyFieldVariantStructNotSubset") + "newReplyFieldVariantStructNotSubset", + ) - new_reply_field_variant_struct_not_subset_two_error = error_collection.get_error_by_command_name( - "newReplyFieldVariantStructNotSubsetTwo") - self.assertTrue(new_reply_field_variant_struct_not_subset_two_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + new_reply_field_variant_struct_not_subset_two_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldVariantStructNotSubsetTwo" + ) + ) + self.assertTrue( + new_reply_field_variant_struct_not_subset_two_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_struct_not_subset_two_error), - "newReplyFieldVariantStructNotSubsetTwo") + "newReplyFieldVariantStructNotSubsetTwo", + ) - new_reply_field_array_variant_struct_not_subset_error = error_collection.get_error_by_command_name( - "newReplyFieldArrayVariantStructNotSubset") - self.assertTrue(new_reply_field_array_variant_struct_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + new_reply_field_array_variant_struct_not_subset_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldArrayVariantStructNotSubset" + ) + ) + self.assertTrue( + new_reply_field_array_variant_struct_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_array_variant_struct_not_subset_error), - "newReplyFieldArrayVariantStructNotSubset") + "newReplyFieldArrayVariantStructNotSubset", + ) - new_reply_field_variant_struct_recursive_error = error_collection.get_error_by_command_name( - "replyFieldVariantStructRecursive") - self.assertTrue(new_reply_field_variant_struct_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_variant_struct_recursive_error = ( + error_collection.get_error_by_command_name( + "replyFieldVariantStructRecursive" + ) + ) + self.assertTrue( + new_reply_field_variant_struct_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) self.assertRegex( - str(new_reply_field_variant_struct_recursive_error), "replyFieldVariantStructRecursive") + str(new_reply_field_variant_struct_recursive_error), + "replyFieldVariantStructRecursive", + ) - new_reply_field_variant_not_subset_with_array_error = error_collection.get_error_by_command_name( - "newReplyFieldVariantNotSubsetWithArray") - self.assertTrue(new_reply_field_variant_not_subset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + new_reply_field_variant_not_subset_with_array_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldVariantNotSubsetWithArray" + ) + ) + self.assertTrue( + new_reply_field_variant_not_subset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_not_subset_with_array_error), - "newReplyFieldVariantNotSubsetWithArray") + "newReplyFieldVariantNotSubsetWithArray", + ) - new_reply_field_variant_not_subset_with_array_two_errors = error_collection.get_all_errors_by_command_name( - "newReplyFieldVariantNotSubsetTwoWithArray") - self.assertTrue(len(new_reply_field_variant_not_subset_with_array_two_errors) == 2) + new_reply_field_variant_not_subset_with_array_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newReplyFieldVariantNotSubsetTwoWithArray" + ) + ) + self.assertTrue( + len(new_reply_field_variant_not_subset_with_array_two_errors) == 2 + ) for error in new_reply_field_variant_not_subset_with_array_two_errors: - self.assertTrue(error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) - new_reply_field_variant_recursive_with_array_error = error_collection.get_error_by_command_name( - "replyFieldVariantRecursiveWithArray") - self.assertTrue(new_reply_field_variant_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_variant_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "replyFieldVariantRecursiveWithArray" + ) + ) + self.assertTrue( + new_reply_field_variant_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_recursive_with_array_error), - "replyFieldVariantRecursiveWithArray") + "replyFieldVariantRecursiveWithArray", + ) - new_reply_field_variant_struct_not_subset_with_array_error = error_collection.get_error_by_command_name( - "newReplyFieldVariantStructNotSubsetWithArray") - self.assertTrue(new_reply_field_variant_struct_not_subset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) + new_reply_field_variant_struct_not_subset_with_array_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldVariantStructNotSubsetWithArray" + ) + ) + self.assertTrue( + new_reply_field_variant_struct_not_subset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_struct_not_subset_with_array_error), - "newReplyFieldVariantStructNotSubsetWithArray") + "newReplyFieldVariantStructNotSubsetWithArray", + ) - new_reply_field_variant_struct_recursive_with_array_error = error_collection.get_error_by_command_name( - "replyFieldVariantStructRecursiveWithArray") - self.assertTrue(new_reply_field_variant_struct_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET) + new_reply_field_variant_struct_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "replyFieldVariantStructRecursiveWithArray" + ) + ) + self.assertTrue( + new_reply_field_variant_struct_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_NOT_SUBSET + ) self.assertRegex( str(new_reply_field_variant_struct_recursive_with_array_error), - "replyFieldVariantStructRecursiveWithArray") + "replyFieldVariantStructRecursiveWithArray", + ) - new_command_parameter_contains_validator_error = error_collection.get_error_by_command_name( - "newCommandParameterValidator") - self.assertTrue(new_command_parameter_contains_validator_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CONTAINS_VALIDATOR) + new_command_parameter_contains_validator_error = ( + error_collection.get_error_by_command_name("newCommandParameterValidator") + ) + self.assertTrue( + new_command_parameter_contains_validator_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_CONTAINS_VALIDATOR + ) self.assertRegex( - str(new_command_parameter_contains_validator_error), "newCommandParameterValidator") + str(new_command_parameter_contains_validator_error), + "newCommandParameterValidator", + ) - command_parameter_validators_not_equal_error = error_collection.get_error_by_command_name( - "commandParameterValidatorsNotEqual") - self.assertTrue(command_parameter_validators_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_VALIDATORS_NOT_EQUAL) + command_parameter_validators_not_equal_error = ( + error_collection.get_error_by_command_name( + "commandParameterValidatorsNotEqual" + ) + ) + self.assertTrue( + command_parameter_validators_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_VALIDATORS_NOT_EQUAL + ) self.assertRegex( - str(command_parameter_validators_not_equal_error), "commandParameterValidatorsNotEqual") + str(command_parameter_validators_not_equal_error), + "commandParameterValidatorsNotEqual", + ) - new_command_type_contains_validator_error = error_collection.get_error_by_command_name( - "newCommandTypeValidator") - self.assertTrue(new_command_type_contains_validator_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_CONTAINS_VALIDATOR) - self.assertRegex(str(new_command_type_contains_validator_error), "newCommandTypeValidator") - - command_type_validators_not_equal_error = error_collection.get_error_by_command_name( - "commandTypeValidatorsNotEqual") - self.assertTrue(command_type_validators_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_VALIDATORS_NOT_EQUAL) + new_command_type_contains_validator_error = ( + error_collection.get_error_by_command_name("newCommandTypeValidator") + ) + self.assertTrue( + new_command_type_contains_validator_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_CONTAINS_VALIDATOR + ) self.assertRegex( - str(command_type_validators_not_equal_error), "commandTypeValidatorsNotEqual") + str(new_command_type_contains_validator_error), "newCommandTypeValidator" + ) + + command_type_validators_not_equal_error = ( + error_collection.get_error_by_command_name("commandTypeValidatorsNotEqual") + ) + self.assertTrue( + command_type_validators_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_VALIDATORS_NOT_EQUAL + ) + self.assertRegex( + str(command_type_validators_not_equal_error), + "commandTypeValidatorsNotEqual", + ) array_command_type_error = error_collection.get_error_by_command_name( - "arrayCommandTypeError") - self.assertTrue(array_command_type_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT) + "arrayCommandTypeError" + ) + self.assertTrue( + array_command_type_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_STRUCT + ) self.assertRegex(str(array_command_type_error), "ArrayTypeStruct") - array_command_param_type_two_errors = error_collection.get_all_errors_by_command_name( - "arrayCommandParameterTypeError") + array_command_param_type_two_errors = ( + error_collection.get_all_errors_by_command_name( + "arrayCommandParameterTypeError" + ) + ) self.assertTrue(len(array_command_param_type_two_errors) == 2) - self.assertTrue(array_command_param_type_two_errors[0].error_id == - idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND_PARAMETER) - self.assertRegex(str(array_command_param_type_two_errors[0]), "ArrayCommandParameter") - self.assertTrue(array_command_param_type_two_errors[1].error_id == - idl_compatibility_errors.ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER) + self.assertTrue( + array_command_param_type_two_errors[0].error_id + == idl_compatibility_errors.ERROR_ID_REMOVED_COMMAND_PARAMETER + ) + self.assertRegex( + str(array_command_param_type_two_errors[0]), "ArrayCommandParameter" + ) + self.assertTrue( + array_command_param_type_two_errors[1].error_id + == idl_compatibility_errors.ERROR_ID_ADDED_REQUIRED_COMMAND_PARAMETER + ) self.assertRegex(str(array_command_param_type_two_errors[1]), "fieldOne") - new_param_variant_not_superset_error = error_collection.get_error_by_command_name( - "newParamVariantNotSuperset") - self.assertTrue(new_param_variant_not_superset_error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) - self.assertRegex(str(new_param_variant_not_superset_error), "newParamVariantNotSuperset") + new_param_variant_not_superset_error = ( + error_collection.get_error_by_command_name("newParamVariantNotSuperset") + ) + self.assertTrue( + new_param_variant_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_param_variant_not_superset_error), "newParamVariantNotSuperset" + ) - new_param_variant_not_superset_two_errors = error_collection.get_all_errors_by_command_name( - "newParamVariantNotSupersetTwo") + new_param_variant_not_superset_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newParamVariantNotSupersetTwo" + ) + ) self.assertTrue(len(new_param_variant_not_superset_two_errors) == 2) for error in new_param_variant_not_superset_two_errors: - self.assertTrue(error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) - new_param_variant_not_superset_three_error = error_collection.get_error_by_command_name( - "newParamVariantNotSupersetThree") + new_param_variant_not_superset_three_error = ( + error_collection.get_error_by_command_name( + "newParamVariantNotSupersetThree" + ) + ) self.assertTrue( - new_param_variant_not_superset_three_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + new_param_variant_not_superset_three_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_param_variant_not_superset_three_error), "newParamVariantNotSupersetThree") + str(new_param_variant_not_superset_three_error), + "newParamVariantNotSupersetThree", + ) - new_param_array_variant_not_superset_error = error_collection.get_error_by_command_name( - "newParamArrayVariantNotSuperset") + new_param_array_variant_not_superset_error = ( + error_collection.get_error_by_command_name( + "newParamArrayVariantNotSuperset" + ) + ) self.assertTrue( - new_param_array_variant_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + new_param_array_variant_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_param_array_variant_not_superset_error), "newParamArrayVariantNotSuperset") + str(new_param_array_variant_not_superset_error), + "newParamArrayVariantNotSuperset", + ) new_param_type_not_variant_error = error_collection.get_error_by_command_name( - "newParamTypeNotVariant") - self.assertTrue(new_param_type_not_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_VARIANT) - self.assertRegex(str(new_param_type_not_variant_error), "newParamTypeNotVariant") + "newParamTypeNotVariant" + ) + self.assertTrue( + new_param_type_not_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_TYPE_NOT_VARIANT + ) + self.assertRegex( + str(new_param_type_not_variant_error), "newParamTypeNotVariant" + ) new_param_variant_recursive_error = error_collection.get_error_by_command_name( - "newParamVariantRecursive") - self.assertTrue(new_param_variant_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) - self.assertRegex(str(new_param_variant_recursive_error), "newParamVariantRecursive") - - new_param_variant_struct_not_superset_error = error_collection.get_error_by_command_name( - "newParamVariantStructNotSuperset") + "newParamVariantRecursive" + ) self.assertTrue( - new_param_variant_struct_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + new_param_variant_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_param_variant_struct_not_superset_error), "newParamVariantStructNotSuperset") + str(new_param_variant_recursive_error), "newParamVariantRecursive" + ) - new_param_variant_struct_recursive_error = error_collection.get_error_by_command_name( - "newParamVariantStructRecursive") - self.assertTrue(new_param_variant_struct_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) + new_param_variant_struct_not_superset_error = ( + error_collection.get_error_by_command_name( + "newParamVariantStructNotSuperset" + ) + ) + self.assertTrue( + new_param_variant_struct_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_param_variant_struct_recursive_error), "newParamVariantStructRecursive") + str(new_param_variant_struct_not_superset_error), + "newParamVariantStructNotSuperset", + ) - new_command_type_variant_not_superset_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantNotSuperset") - self.assertTrue(new_command_type_variant_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + new_param_variant_struct_recursive_error = ( + error_collection.get_error_by_command_name("newParamVariantStructRecursive") + ) + self.assertTrue( + new_param_variant_struct_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_command_type_variant_not_superset_error), "newCommandTypeVariantNotSuperset") + str(new_param_variant_struct_recursive_error), + "newParamVariantStructRecursive", + ) - new_command_type_variant_not_superset_two_errors = error_collection.get_all_errors_by_command_name( - "newCommandTypeVariantNotSupersetTwo") + new_command_type_variant_not_superset_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantNotSuperset" + ) + ) + self.assertTrue( + new_command_type_variant_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_command_type_variant_not_superset_error), + "newCommandTypeVariantNotSuperset", + ) + + new_command_type_variant_not_superset_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newCommandTypeVariantNotSupersetTwo" + ) + ) self.assertTrue(len(new_command_type_variant_not_superset_two_errors) == 2) for error in new_command_type_variant_not_superset_two_errors: - self.assertTrue(error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) new_command_type_not_variant_error = error_collection.get_error_by_command_name( - "newCommandTypeNotVariant") - self.assertTrue(new_command_type_not_variant_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_VARIANT) - self.assertRegex(str(new_command_type_not_variant_error), "newCommandTypeNotVariant") - - new_command_type_variant_recursive_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantRecursive") - self.assertTrue(new_command_type_variant_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) + "newCommandTypeNotVariant" + ) + self.assertTrue( + new_command_type_not_variant_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_NOT_VARIANT + ) self.assertRegex( - str(new_command_type_variant_recursive_error), "newCommandTypeVariantRecursive") + str(new_command_type_not_variant_error), "newCommandTypeNotVariant" + ) - new_command_type_variant_struct_not_superset_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantStructNotSuperset") - self.assertTrue(new_command_type_variant_struct_not_superset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + new_command_type_variant_recursive_error = ( + error_collection.get_error_by_command_name("newCommandTypeVariantRecursive") + ) + self.assertTrue( + new_command_type_variant_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_command_type_variant_recursive_error), + "newCommandTypeVariantRecursive", + ) + + new_command_type_variant_struct_not_superset_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantStructNotSuperset" + ) + ) + self.assertTrue( + new_command_type_variant_struct_not_superset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_struct_not_superset_error), - "newCommandTypeVariantStructNotSuperset") + "newCommandTypeVariantStructNotSuperset", + ) - new_command_type_variant_struct_recursive_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantStructRecursive") - self.assertTrue(new_command_type_variant_struct_recursive_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) + new_command_type_variant_struct_recursive_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantStructRecursive" + ) + ) + self.assertTrue( + new_command_type_variant_struct_recursive_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_struct_recursive_error), - "newCommandTypeVariantStructRecursive") - new_reply_field_contains_validator_error = error_collection.get_error_by_command_name( - "newReplyFieldValidator") - self.assertTrue(new_reply_field_contains_validator_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CONTAINS_VALIDATOR) - self.assertRegex(str(new_reply_field_contains_validator_error), "newReplyFieldValidator") - - reply_field_validators_not_equal_error = error_collection.get_error_by_command_name( - "replyFieldValidatorsNotEqual") - self.assertTrue(reply_field_validators_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_REPLY_FIELD_VALIDATORS_NOT_EQUAL) + "newCommandTypeVariantStructRecursive", + ) + new_reply_field_contains_validator_error = ( + error_collection.get_error_by_command_name("newReplyFieldValidator") + ) + self.assertTrue( + new_reply_field_contains_validator_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_CONTAINS_VALIDATOR + ) self.assertRegex( - str(reply_field_validators_not_equal_error), "replyFieldValidatorsNotEqual") + str(new_reply_field_contains_validator_error), "newReplyFieldValidator" + ) + + reply_field_validators_not_equal_error = ( + error_collection.get_error_by_command_name("replyFieldValidatorsNotEqual") + ) + self.assertTrue( + reply_field_validators_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_REPLY_FIELD_VALIDATORS_NOT_EQUAL + ) + self.assertRegex( + str(reply_field_validators_not_equal_error), "replyFieldValidatorsNotEqual" + ) simple_check_not_equal_error = error_collection.get_error_by_command_name( - "simpleCheckNotEqual") - self.assertTrue(simple_check_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL) + "simpleCheckNotEqual" + ) + self.assertTrue( + simple_check_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL + ) self.assertRegex(str(simple_check_not_equal_error), "simpleCheckNotEqual") simple_check_not_equal_error_two = error_collection.get_error_by_command_name( - "simpleCheckNotEqualTwo") - self.assertTrue(simple_check_not_equal_error_two.error_id == - idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL) - self.assertRegex(str(simple_check_not_equal_error_two), "simpleCheckNotEqualTwo") + "simpleCheckNotEqualTwo" + ) + self.assertTrue( + simple_check_not_equal_error_two.error_id + == idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL + ) + self.assertRegex( + str(simple_check_not_equal_error_two), "simpleCheckNotEqualTwo" + ) simple_check_not_equal_error_three = error_collection.get_error_by_command_name( - "simpleCheckNotEqualThree") - self.assertTrue(simple_check_not_equal_error_three.error_id == - idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL) - self.assertRegex(str(simple_check_not_equal_error_three), "simpleCheckNotEqualThree") - - simple_resource_pattern_not_equal_error = error_collection.get_error_by_command_name( - "simpleResourcePatternNotEqual") - self.assertTrue(simple_resource_pattern_not_equal_error.error_id == - idl_compatibility_errors.ERROR_ID_RESOURCE_PATTERN_NOT_EQUAL) - self.assertRegex( - str(simple_resource_pattern_not_equal_error), "simpleResourcePatternNotEqual") - - new_simple_action_types_not_subset_error = error_collection.get_error_by_command_name( - "newSimpleActionTypesNotSubset") - self.assertTrue(new_simple_action_types_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ACTION_TYPES_NOT_SUBSET) - self.assertRegex( - str(new_simple_action_types_not_subset_error), "newSimpleActionTypesNotSubset") - - new_param_variant_not_superset_with_array_error = error_collection.get_error_by_command_name( - "newParamVariantNotSupersetWithArray") + "simpleCheckNotEqualThree" + ) self.assertTrue( - new_param_variant_not_superset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + simple_check_not_equal_error_three.error_id + == idl_compatibility_errors.ERROR_ID_CHECK_NOT_EQUAL + ) + self.assertRegex( + str(simple_check_not_equal_error_three), "simpleCheckNotEqualThree" + ) + + simple_resource_pattern_not_equal_error = ( + error_collection.get_error_by_command_name("simpleResourcePatternNotEqual") + ) + self.assertTrue( + simple_resource_pattern_not_equal_error.error_id + == idl_compatibility_errors.ERROR_ID_RESOURCE_PATTERN_NOT_EQUAL + ) + self.assertRegex( + str(simple_resource_pattern_not_equal_error), + "simpleResourcePatternNotEqual", + ) + + new_simple_action_types_not_subset_error = ( + error_collection.get_error_by_command_name("newSimpleActionTypesNotSubset") + ) + self.assertTrue( + new_simple_action_types_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ACTION_TYPES_NOT_SUBSET + ) + self.assertRegex( + str(new_simple_action_types_not_subset_error), + "newSimpleActionTypesNotSubset", + ) + + new_param_variant_not_superset_with_array_error = ( + error_collection.get_error_by_command_name( + "newParamVariantNotSupersetWithArray" + ) + ) + self.assertTrue( + new_param_variant_not_superset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_param_variant_not_superset_with_array_error), - "newParamVariantNotSupersetWithArray") + "newParamVariantNotSupersetWithArray", + ) - new_param_variant_not_superset_with_array_two_errors = error_collection.get_all_errors_by_command_name( - "newParamVariantNotSupersetTwoWithArray") + new_param_variant_not_superset_with_array_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newParamVariantNotSupersetTwoWithArray" + ) + ) self.assertTrue(len(new_param_variant_not_superset_with_array_two_errors) == 2) for error in new_param_variant_not_superset_with_array_two_errors: - self.assertTrue(error.error_id == idl_compatibility_errors. - ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) - new_param_variant_recursive_with_array_error = error_collection.get_error_by_command_name( - "newParamVariantRecursiveWithArray") - self.assertTrue(new_param_variant_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) - self.assertRegex( - str(new_param_variant_recursive_with_array_error), "newParamVariantRecursiveWithArray") - - new_param_variant_struct_not_superset_with_array_error = error_collection.get_error_by_command_name( - "newParamVariantStructNotSupersetWithArray") + new_param_variant_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "newParamVariantRecursiveWithArray" + ) + ) self.assertTrue( - new_param_variant_struct_not_superset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET) + new_param_variant_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_param_variant_recursive_with_array_error), + "newParamVariantRecursiveWithArray", + ) + + new_param_variant_struct_not_superset_with_array_error = ( + error_collection.get_error_by_command_name( + "newParamVariantStructNotSupersetWithArray" + ) + ) + self.assertTrue( + new_param_variant_struct_not_superset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_param_variant_struct_not_superset_with_array_error), - "newParamVariantStructNotSupersetWithArray") + "newParamVariantStructNotSupersetWithArray", + ) - new_param_variant_struct_recursive_with_array_error = error_collection.get_error_by_command_name( - "newParamVariantStructRecursiveWithArray") - self.assertTrue(new_param_variant_struct_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) + new_param_variant_struct_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "newParamVariantStructRecursiveWithArray" + ) + ) + self.assertTrue( + new_param_variant_struct_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_param_variant_struct_recursive_with_array_error), - "newParamVariantStructRecursiveWithArray") + "newParamVariantStructRecursiveWithArray", + ) - new_command_type_variant_not_superset_with_array_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantNotSupersetWithArray") - self.assertTrue(new_command_type_variant_not_superset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + new_command_type_variant_not_superset_with_array_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantNotSupersetWithArray" + ) + ) + self.assertTrue( + new_command_type_variant_not_superset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_not_superset_with_array_error), - "newCommandTypeVariantNotSupersetWithArray") + "newCommandTypeVariantNotSupersetWithArray", + ) - new_command_type_variant_not_superset_with_array_two_errors = error_collection.get_all_errors_by_command_name( - "newCommandTypeVariantNotSupersetTwoWithArray") - self.assertTrue(len(new_command_type_variant_not_superset_with_array_two_errors) == 2) + new_command_type_variant_not_superset_with_array_two_errors = ( + error_collection.get_all_errors_by_command_name( + "newCommandTypeVariantNotSupersetTwoWithArray" + ) + ) + self.assertTrue( + len(new_command_type_variant_not_superset_with_array_two_errors) == 2 + ) for error in new_command_type_variant_not_superset_with_array_two_errors: - self.assertTrue(error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + self.assertTrue( + error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) - new_command_type_variant_recursive_with_array_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantRecursiveWithArray") - self.assertTrue(new_command_type_variant_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) + new_command_type_variant_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantRecursiveWithArray" + ) + ) + self.assertTrue( + new_command_type_variant_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_recursive_with_array_error), - "newCommandTypeVariantRecursiveWithArray") + "newCommandTypeVariantRecursiveWithArray", + ) - new_command_type_variant_struct_not_superset_with_array_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantStructNotSupersetWithArray") - self.assertTrue(new_command_type_variant_struct_not_superset_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET) + new_command_type_variant_struct_not_superset_with_array_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantStructNotSupersetWithArray" + ) + ) + self.assertTrue( + new_command_type_variant_struct_not_superset_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_VARIANT_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_struct_not_superset_with_array_error), - "newCommandTypeVariantStructNotSupersetWithArray") + "newCommandTypeVariantStructNotSupersetWithArray", + ) - new_command_type_variant_struct_recursive_with_array_error = error_collection.get_error_by_command_name( - "newCommandTypeVariantStructRecursiveWithArray") - self.assertTrue(new_command_type_variant_struct_recursive_with_array_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) + new_command_type_variant_struct_recursive_with_array_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeVariantStructRecursiveWithArray" + ) + ) + self.assertTrue( + new_command_type_variant_struct_recursive_with_array_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) self.assertRegex( str(new_command_type_variant_struct_recursive_with_array_error), - "newCommandTypeVariantStructRecursiveWithArray") + "newCommandTypeVariantStructRecursiveWithArray", + ) access_check_type_change_error = error_collection.get_error_by_command_name( - "accessCheckTypeChange") - self.assertTrue(access_check_type_change_error.error_id == - idl_compatibility_errors.ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL) + "accessCheckTypeChange" + ) + self.assertTrue( + access_check_type_change_error.error_id + == idl_compatibility_errors.ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL + ) self.assertRegex(str(access_check_type_change_error), "accessCheckTypeChange") access_check_type_change_two_error = error_collection.get_error_by_command_name( - "accessCheckTypeChangeTwo") - self.assertTrue(access_check_type_change_two_error.error_id == - idl_compatibility_errors.ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL) - self.assertRegex(str(access_check_type_change_two_error), "accessCheckTypeChangeTwo") + "accessCheckTypeChangeTwo" + ) + self.assertTrue( + access_check_type_change_two_error.error_id + == idl_compatibility_errors.ERROR_ID_ACCESS_CHECK_TYPE_NOT_EQUAL + ) + self.assertRegex( + str(access_check_type_change_two_error), "accessCheckTypeChangeTwo" + ) complex_checks_not_subset_error = error_collection.get_error_by_command_name( - "complexChecksNotSubset") - self.assertTrue(complex_checks_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_CHECKS_NOT_SUBSET) + "complexChecksNotSubset" + ) + self.assertTrue( + complex_checks_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_CHECKS_NOT_SUBSET + ) self.assertRegex(str(complex_checks_not_subset_error), "complexChecksNotSubset") - complex_checks_not_subset_two_error = error_collection.get_error_by_command_name( - "complexChecksNotSubsetTwo") - self.assertTrue(complex_checks_not_subset_two_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) - self.assertRegex(str(complex_checks_not_subset_two_error), "complexChecksNotSubsetTwo") + complex_checks_not_subset_two_error = ( + error_collection.get_error_by_command_name("complexChecksNotSubsetTwo") + ) + self.assertTrue( + complex_checks_not_subset_two_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) + self.assertRegex( + str(complex_checks_not_subset_two_error), "complexChecksNotSubsetTwo" + ) - complex_check_privileges_superset_none_allowed_error = error_collection.get_error_by_command_name( - "complexCheckPrivilegesSupersetNoneAllowed") - self.assertTrue(complex_check_privileges_superset_none_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) + complex_check_privileges_superset_none_allowed_error = ( + error_collection.get_error_by_command_name( + "complexCheckPrivilegesSupersetNoneAllowed" + ) + ) + self.assertTrue( + complex_check_privileges_superset_none_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) self.assertRegex( str(complex_check_privileges_superset_none_allowed_error), - "complexCheckPrivilegesSupersetNoneAllowed") + "complexCheckPrivilegesSupersetNoneAllowed", + ) - complex_check_privileges_superset_some_allowed_error = error_collection.get_error_by_command_name( - "complexCheckPrivilegesSupersetSomeAllowed") - self.assertTrue(complex_check_privileges_superset_some_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) + complex_check_privileges_superset_some_allowed_error = ( + error_collection.get_error_by_command_name( + "complexCheckPrivilegesSupersetSomeAllowed" + ) + ) + self.assertTrue( + complex_check_privileges_superset_some_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) self.assertRegex( str(complex_check_privileges_superset_some_allowed_error), - "complexCheckPrivilegesSupersetSomeAllowed") + "complexCheckPrivilegesSupersetSomeAllowed", + ) - complex_checks_superset_none_allowed_error = error_collection.get_error_by_command_name( - "complexChecksSupersetNoneAllowed") - self.assertTrue(complex_checks_superset_none_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) + complex_checks_superset_none_allowed_error = ( + error_collection.get_error_by_command_name( + "complexChecksSupersetNoneAllowed" + ) + ) + self.assertTrue( + complex_checks_superset_none_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) self.assertRegex( - str(complex_checks_superset_none_allowed_error), "complexChecksSupersetNoneAllowed") + str(complex_checks_superset_none_allowed_error), + "complexChecksSupersetNoneAllowed", + ) - complex_checks_superset_some_allowed_error = error_collection.get_error_by_command_name( - "complexChecksSupersetSomeAllowed") - self.assertTrue(complex_checks_superset_some_allowed_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) + complex_checks_superset_some_allowed_error = ( + error_collection.get_error_by_command_name( + "complexChecksSupersetSomeAllowed" + ) + ) + self.assertTrue( + complex_checks_superset_some_allowed_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) self.assertRegex( - str(complex_checks_superset_some_allowed_error), "complexChecksSupersetSomeAllowed") + str(complex_checks_superset_some_allowed_error), + "complexChecksSupersetSomeAllowed", + ) - complex_resource_pattern_change_error = error_collection.get_error_by_command_name( - "complexResourcePatternChange") - self.assertTrue(complex_resource_pattern_change_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET) - self.assertRegex(str(complex_resource_pattern_change_error), "complexResourcePatternChange") - - complex_action_types_not_subset_error = error_collection.get_error_by_command_name( - "complexActionTypesNotSubset") - self.assertTrue(complex_action_types_not_subset_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET) - self.assertRegex(str(complex_action_types_not_subset_error), "complexActionTypesNotSubset") - - complex_action_types_not_subset_two_error = error_collection.get_error_by_command_name( - "complexActionTypesNotSubsetTwo") - self.assertTrue(complex_action_types_not_subset_two_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET) + complex_resource_pattern_change_error = ( + error_collection.get_error_by_command_name("complexResourcePatternChange") + ) + self.assertTrue( + complex_resource_pattern_change_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET + ) self.assertRegex( - str(complex_action_types_not_subset_two_error), "complexActionTypesNotSubsetTwo") + str(complex_resource_pattern_change_error), "complexResourcePatternChange" + ) - additional_complex_access_check_error = error_collection.get_error_by_command_name( - "additionalComplexAccessCheck") - self.assertTrue(additional_complex_access_check_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) - self.assertRegex(str(additional_complex_access_check_error), "additionalComplexAccessCheck") + complex_action_types_not_subset_error = ( + error_collection.get_error_by_command_name("complexActionTypesNotSubset") + ) + self.assertTrue( + complex_action_types_not_subset_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET + ) + self.assertRegex( + str(complex_action_types_not_subset_error), "complexActionTypesNotSubset" + ) - additional_complex_access_check_agg_stage_error = error_collection.get_error_by_command_name( - "additionalComplexAccessCheckAggStage") - self.assertTrue(additional_complex_access_check_agg_stage_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK) + complex_action_types_not_subset_two_error = ( + error_collection.get_error_by_command_name("complexActionTypesNotSubsetTwo") + ) + self.assertTrue( + complex_action_types_not_subset_two_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMPLEX_PRIVILEGES_NOT_SUBSET + ) + self.assertRegex( + str(complex_action_types_not_subset_two_error), + "complexActionTypesNotSubsetTwo", + ) + + additional_complex_access_check_error = ( + error_collection.get_error_by_command_name("additionalComplexAccessCheck") + ) + self.assertTrue( + additional_complex_access_check_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) + self.assertRegex( + str(additional_complex_access_check_error), "additionalComplexAccessCheck" + ) + + additional_complex_access_check_agg_stage_error = ( + error_collection.get_error_by_command_name( + "additionalComplexAccessCheckAggStage" + ) + ) + self.assertTrue( + additional_complex_access_check_agg_stage_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_ADDITIONAL_COMPLEX_ACCESS_CHECK + ) self.assertRegex( str(additional_complex_access_check_agg_stage_error), - "additionalComplexAccessCheckAggStage") + "additionalComplexAccessCheckAggStage", + ) removed_access_check_field_error = error_collection.get_error_by_command_name( - "removedAccessCheckField") - self.assertTrue(removed_access_check_field_error.error_id == - idl_compatibility_errors.ERROR_ID_REMOVED_ACCESS_CHECK_FIELD) - self.assertRegex(str(removed_access_check_field_error), "removedAccessCheckField") + "removedAccessCheckField" + ) + self.assertTrue( + removed_access_check_field_error.error_id + == idl_compatibility_errors.ERROR_ID_REMOVED_ACCESS_CHECK_FIELD + ) + self.assertRegex( + str(removed_access_check_field_error), "removedAccessCheckField" + ) added_access_check_field_error = error_collection.get_error_by_command_name( - "addedAccessCheckField") - self.assertTrue(added_access_check_field_error.error_id == - idl_compatibility_errors.ERROR_ID_ADDED_ACCESS_CHECK_FIELD) + "addedAccessCheckField" + ) + self.assertTrue( + added_access_check_field_error.error_id + == idl_compatibility_errors.ERROR_ID_ADDED_ACCESS_CHECK_FIELD + ) self.assertRegex(str(added_access_check_field_error), "addedAccessCheckField") - missing_array_command_type_old_error = error_collection.get_error_by_command_name( - "arrayCommandTypeErrorNoArrayOld") - self.assertTrue(missing_array_command_type_old_error.error_id == - idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY) - self.assertRegex(str(missing_array_command_type_old_error), "array") - - missing_array_command_type_new_error = error_collection.get_error_by_command_name( - "arrayCommandTypeErrorNoArrayNew") - self.assertTrue(missing_array_command_type_new_error.error_id == - idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY) - self.assertRegex(str(missing_array_command_type_new_error), "array") - - missing_array_command_parameter_old_error = error_collection.get_error_by_command_name( - "arrayCommandParameterNoArrayOld") - self.assertTrue(missing_array_command_parameter_old_error.error_id == - idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY) - self.assertRegex(str(missing_array_command_parameter_old_error), "array") - - missing_array_command_parameter_new_error = error_collection.get_error_by_command_name( - "arrayCommandParameterNoArrayNew") - self.assertTrue(missing_array_command_parameter_new_error.error_id == - idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY) - self.assertRegex(str(missing_array_command_parameter_new_error), "array") - - new_reply_field_missing_unstable_field_error = error_collection.get_error_by_command_name( - "newReplyFieldMissingUnstableField") - self.assertTrue(new_reply_field_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY) + missing_array_command_type_old_error = ( + error_collection.get_error_by_command_name( + "arrayCommandTypeErrorNoArrayOld" + ) + ) + self.assertTrue( + missing_array_command_type_old_error.error_id + == idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY + ) self.assertRegex( - str(new_reply_field_missing_unstable_field_error), "newReplyFieldMissingUnstableField") + str(missing_array_command_type_old_error), "array" + ) - new_command_type_field_missing_unstable_field_error = error_collection.get_error_by_command_name( - "newCommandTypeFieldMissingUnstableField") - self.assertTrue(new_command_type_field_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY) + missing_array_command_type_new_error = ( + error_collection.get_error_by_command_name( + "arrayCommandTypeErrorNoArrayNew" + ) + ) + self.assertTrue( + missing_array_command_type_new_error.error_id + == idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY + ) + self.assertRegex( + str(missing_array_command_type_new_error), "array" + ) + + missing_array_command_parameter_old_error = ( + error_collection.get_error_by_command_name( + "arrayCommandParameterNoArrayOld" + ) + ) + self.assertTrue( + missing_array_command_parameter_old_error.error_id + == idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY + ) + self.assertRegex( + str(missing_array_command_parameter_old_error), "array" + ) + + missing_array_command_parameter_new_error = ( + error_collection.get_error_by_command_name( + "arrayCommandParameterNoArrayNew" + ) + ) + self.assertTrue( + missing_array_command_parameter_new_error.error_id + == idl_compatibility_errors.ERROR_ID_TYPE_NOT_ARRAY + ) + self.assertRegex( + str(missing_array_command_parameter_new_error), "array" + ) + + new_reply_field_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "newReplyFieldMissingUnstableField" + ) + ) + self.assertTrue( + new_reply_field_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY + ) + self.assertRegex( + str(new_reply_field_missing_unstable_field_error), + "newReplyFieldMissingUnstableField", + ) + + new_command_type_field_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "newCommandTypeFieldMissingUnstableField" + ) + ) + self.assertTrue( + new_command_type_field_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY + ) self.assertRegex( str(new_command_type_field_missing_unstable_field_error), - "newCommandTypeFieldMissingUnstableField") + "newCommandTypeFieldMissingUnstableField", + ) - new_parameter_missing_unstable_field_error = error_collection.get_error_by_command_name( - "newParameterMissingUnstableField") - self.assertTrue(new_parameter_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY) + new_parameter_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "newParameterMissingUnstableField" + ) + ) + self.assertTrue( + new_parameter_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY + ) self.assertRegex( - str(new_parameter_missing_unstable_field_error), "newParameterMissingUnstableField") + str(new_parameter_missing_unstable_field_error), + "newParameterMissingUnstableField", + ) - added_new_reply_field_missing_unstable_field_error = error_collection.get_error_by_command_name( - "addedNewReplyFieldMissingUnstableField") - self.assertTrue(added_new_reply_field_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY) + added_new_reply_field_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "addedNewReplyFieldMissingUnstableField" + ) + ) + self.assertTrue( + added_new_reply_field_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_REQUIRES_STABILITY + ) self.assertRegex( str(added_new_reply_field_missing_unstable_field_error), - "addedNewReplyFieldMissingUnstableField") + "addedNewReplyFieldMissingUnstableField", + ) - added_new_command_type_field_missing_unstable_field_error = error_collection.get_error_by_command_name( - "addedNewCommandTypeFieldMissingUnstableField") - self.assertTrue(added_new_command_type_field_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY) + added_new_command_type_field_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "addedNewCommandTypeFieldMissingUnstableField" + ) + ) + self.assertTrue( + added_new_command_type_field_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRES_STABILITY + ) self.assertRegex( str(added_new_command_type_field_missing_unstable_field_error), - "addedNewCommandTypeFieldMissingUnstableField") + "addedNewCommandTypeFieldMissingUnstableField", + ) - added_new_parameter_missing_unstable_field_error = error_collection.get_error_by_command_name( - "addedNewParameterMissingUnstableField") - self.assertTrue(added_new_parameter_missing_unstable_field_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY) + added_new_parameter_missing_unstable_field_error = ( + error_collection.get_error_by_command_name( + "addedNewParameterMissingUnstableField" + ) + ) + self.assertTrue( + added_new_parameter_missing_unstable_field_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_PARAMETER_REQUIRES_STABILITY + ) self.assertRegex( str(added_new_parameter_missing_unstable_field_error), - "addedNewParameterMissingUnstableField") + "addedNewParameterMissingUnstableField", + ) chained_struct_incompatible_error = error_collection.get_error_by_command_name( - "chainedStructIncompatible") - self.assertTrue(chained_struct_incompatible_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET) - self.assertRegex(str(chained_struct_incompatible_error), "chainedStructIncompatible") - - reply_with_incompatible_chained_struct_error = error_collection.get_error_by_command_name( - "replyWithIncompatibleChainedStruct") - self.assertTrue(reply_with_incompatible_chained_struct_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET) - self.assertRegex( - str(reply_with_incompatible_chained_struct_error), "replyWithIncompatibleChainedStruct") - - type_with_incompatible_chained_struct_error = error_collection.get_error_by_command_name( - "typeWithIncompatibleChainedStruct") + "chainedStructIncompatible" + ) self.assertTrue( - type_with_incompatible_chained_struct_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY) + chained_struct_incompatible_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(type_with_incompatible_chained_struct_error), "typeWithIncompatibleChainedStruct") + str(chained_struct_incompatible_error), "chainedStructIncompatible" + ) + + reply_with_incompatible_chained_struct_error = ( + error_collection.get_error_by_command_name( + "replyWithIncompatibleChainedStruct" + ) + ) + self.assertTrue( + reply_with_incompatible_chained_struct_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_VARIANT_TYPE_NOT_SUBSET + ) + self.assertRegex( + str(reply_with_incompatible_chained_struct_error), + "replyWithIncompatibleChainedStruct", + ) + + type_with_incompatible_chained_struct_error = ( + error_collection.get_error_by_command_name( + "typeWithIncompatibleChainedStruct" + ) + ) + self.assertTrue( + type_with_incompatible_chained_struct_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_BSON_SERIALIZATION_TYPE_ANY + ) + self.assertRegex( + str(type_with_incompatible_chained_struct_error), + "typeWithIncompatibleChainedStruct", + ) incompatible_chained_type_error = error_collection.get_error_by_command_name( - "incompatibleChainedType") - self.assertTrue(incompatible_chained_type_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET) - self.assertRegex(str(incompatible_chained_type_error), "incompatibleChainedType") - - new_parameter_removed_chained_type_error = error_collection.get_error_by_command_name( - "newParameterRemovedChainedType") + "incompatibleChainedType" + ) self.assertTrue( - new_parameter_removed_chained_type_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_CHAINED_TYPE_NOT_SUPERSET) + incompatible_chained_type_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_TYPE_NOT_SUPERSET + ) self.assertRegex( - str(new_parameter_removed_chained_type_error), "newParameterRemovedChainedType") + str(incompatible_chained_type_error), "incompatibleChainedType" + ) + + new_parameter_removed_chained_type_error = ( + error_collection.get_error_by_command_name("newParameterRemovedChainedType") + ) + self.assertTrue( + new_parameter_removed_chained_type_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAMETER_CHAINED_TYPE_NOT_SUPERSET + ) + self.assertRegex( + str(new_parameter_removed_chained_type_error), + "newParameterRemovedChainedType", + ) new_reply_added_chained_type_error = error_collection.get_error_by_command_name( - "newReplyAddedChainedType") - self.assertTrue(new_reply_added_chained_type_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_CHAINED_TYPE_NOT_SUBSET) - self.assertRegex(str(new_reply_added_chained_type_error), "newReplyAddedChainedType") + "newReplyAddedChainedType" + ) + self.assertTrue( + new_reply_added_chained_type_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_CHAINED_TYPE_NOT_SUBSET + ) + self.assertRegex( + str(new_reply_added_chained_type_error), "newReplyAddedChainedType" + ) - optional_bool_to_bool_parameter_error = error_collection.get_error_by_command_name( - "optionalBoolToBoolParameter") - self.assertTrue(optional_bool_to_bool_parameter_error.error_id == - idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_REQUIRED) + optional_bool_to_bool_parameter_error = ( + error_collection.get_error_by_command_name("optionalBoolToBoolParameter") + ) + self.assertTrue( + optional_bool_to_bool_parameter_error.error_id + == idl_compatibility_errors.ERROR_ID_COMMAND_PARAMETER_REQUIRED + ) - optional_bool_to_bool_command_type_error = error_collection.get_error_by_command_name( - "optionalBoolToBoolCommandType") - self.assertTrue(optional_bool_to_bool_command_type_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED) + optional_bool_to_bool_command_type_error = ( + error_collection.get_error_by_command_name("optionalBoolToBoolCommandType") + ) + self.assertTrue( + optional_bool_to_bool_command_type_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_REQUIRED + ) bool_to_optional_bool_reply_error = error_collection.get_error_by_command_name( - "boolToOptionalBoolReply") - self.assertTrue(bool_to_optional_bool_reply_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL) - - unstable_to_stable_reply_field_error = error_collection.get_error_by_command_name( - "unstableToStableReplyField") - self.assertTrue(unstable_to_stable_reply_field_error.error_id == - idl_compatibility_errors.ERROR_ID_UNSTABLE_REPLY_FIELD_CHANGED_TO_STABLE) - self.assertRegex(str(unstable_to_stable_reply_field_error), "unstableToStableReplyField") - - unstable_to_stable_param_field_error = error_collection.get_error_by_command_name( - "unstableToStableParamField") - self.assertTrue(unstable_to_stable_param_field_error.error_id == idl_compatibility_errors. - ERROR_ID_UNSTABLE_COMMAND_PARAM_FIELD_CHANGED_TO_STABLE) - self.assertRegex(str(unstable_to_stable_param_field_error), "unstableToStableParamField") - - unstable_to_stable_type_field_error = error_collection.get_error_by_command_name( - "unstableToStableTypeField") - self.assertTrue(unstable_to_stable_type_field_error.error_id == idl_compatibility_errors. - ERROR_ID_UNSTABLE_COMMAND_TYPE_FIELD_CHANGED_TO_STABLE) - self.assertRegex(str(unstable_to_stable_type_field_error), "unstableToStableTypeField") - - new_reply_field_added_as_stable_error = error_collection.get_error_by_command_name( - "newStableReplyFieldAdded") - self.assertTrue(new_reply_field_added_as_stable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_ADDED_AS_STABLE) - self.assertRegex(str(new_reply_field_added_as_stable_error), "newStableReplyFieldAdded") - - new_command_param_field_added_as_stable_error = error_collection.get_error_by_command_name( - "newStableParameterAdded") - self.assertTrue(new_command_param_field_added_as_stable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_STABLE) - self.assertRegex( - str(new_command_param_field_added_as_stable_error), "newStableParameterAdded") - - new_command_type_field_added_as_stable_error = error_collection.get_error_by_command_name( - "newStableTypeFieldAdded") - self.assertTrue(new_command_type_field_added_as_stable_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_STABLE) - self.assertRegex( - str(new_command_type_field_added_as_stable_error), "newStableTypeFieldAdded") - - new_type_field_added_as_unstable_required_error = error_collection.get_error_by_command_name( - "commandWithNewRequiredUnstableFieldInType") + "boolToOptionalBoolReply" + ) self.assertTrue( - new_type_field_added_as_unstable_required_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_UNSTABLE_REQUIRED) + bool_to_optional_bool_reply_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL + ) + + unstable_to_stable_reply_field_error = ( + error_collection.get_error_by_command_name("unstableToStableReplyField") + ) + self.assertTrue( + unstable_to_stable_reply_field_error.error_id + == idl_compatibility_errors.ERROR_ID_UNSTABLE_REPLY_FIELD_CHANGED_TO_STABLE + ) + self.assertRegex( + str(unstable_to_stable_reply_field_error), "unstableToStableReplyField" + ) + + unstable_to_stable_param_field_error = ( + error_collection.get_error_by_command_name("unstableToStableParamField") + ) + self.assertTrue( + unstable_to_stable_param_field_error.error_id + == idl_compatibility_errors.ERROR_ID_UNSTABLE_COMMAND_PARAM_FIELD_CHANGED_TO_STABLE + ) + self.assertRegex( + str(unstable_to_stable_param_field_error), "unstableToStableParamField" + ) + + unstable_to_stable_type_field_error = ( + error_collection.get_error_by_command_name("unstableToStableTypeField") + ) + self.assertTrue( + unstable_to_stable_type_field_error.error_id + == idl_compatibility_errors.ERROR_ID_UNSTABLE_COMMAND_TYPE_FIELD_CHANGED_TO_STABLE + ) + self.assertRegex( + str(unstable_to_stable_type_field_error), "unstableToStableTypeField" + ) + + new_reply_field_added_as_stable_error = ( + error_collection.get_error_by_command_name("newStableReplyFieldAdded") + ) + self.assertTrue( + new_reply_field_added_as_stable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_ADDED_AS_STABLE + ) + self.assertRegex( + str(new_reply_field_added_as_stable_error), "newStableReplyFieldAdded" + ) + + new_command_param_field_added_as_stable_error = ( + error_collection.get_error_by_command_name("newStableParameterAdded") + ) + self.assertTrue( + new_command_param_field_added_as_stable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_STABLE + ) + self.assertRegex( + str(new_command_param_field_added_as_stable_error), + "newStableParameterAdded", + ) + + new_command_type_field_added_as_stable_error = ( + error_collection.get_error_by_command_name("newStableTypeFieldAdded") + ) + self.assertTrue( + new_command_type_field_added_as_stable_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_STABLE + ) + self.assertRegex( + str(new_command_type_field_added_as_stable_error), "newStableTypeFieldAdded" + ) + + new_type_field_added_as_unstable_required_error = ( + error_collection.get_error_by_command_name( + "commandWithNewRequiredUnstableFieldInType" + ) + ) + self.assertTrue( + new_type_field_added_as_unstable_required_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_TYPE_FIELD_ADDED_AS_UNSTABLE_REQUIRED + ) self.assertRegex( str(new_type_field_added_as_unstable_required_error), - "commandWithNewRequiredUnstableFieldInType") + "commandWithNewRequiredUnstableFieldInType", + ) - new_param_field_added_as_unstable_required_error = error_collection.get_error_by_command_name( - "newUnstableRequiredParameterAdded") + new_param_field_added_as_unstable_required_error = ( + error_collection.get_error_by_command_name( + "newUnstableRequiredParameterAdded" + ) + ) self.assertTrue( - new_param_field_added_as_unstable_required_error.error_id == - idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_UNSTABLE_REQUIRED) + new_param_field_added_as_unstable_required_error.error_id + == idl_compatibility_errors.ERROR_ID_NEW_COMMAND_PARAM_FIELD_ADDED_AS_UNSTABLE_REQUIRED + ) self.assertRegex( str(new_param_field_added_as_unstable_required_error), - "newUnstableRequiredParameterAdded") + "newUnstableRequiredParameterAdded", + ) self.assertEqual(error_collection.count(), 217) @@ -1552,9 +2573,14 @@ class TestIDLCompatibilityChecker(unittest.TestCase): test_path = path.join(dir_path, "compatibility_test_pass/generic_argument/") include_paths = [path.join(dir_path, "include/")] - error_collection = idl_check_compatibility.check_generic_arguments_compatibility( - path.join(test_path, "old.idl"), path.join(test_path, "new.idl"), include_paths, - include_paths) + error_collection = ( + idl_check_compatibility.check_generic_arguments_compatibility( + path.join(test_path, "old.idl"), + path.join(test_path, "new.idl"), + include_paths, + include_paths, + ) + ) error_collection.dump_errors() @@ -1566,9 +2592,14 @@ class TestIDLCompatibilityChecker(unittest.TestCase): test_path = path.join(dir_path, "compatibility_test_fail/generic_argument/") include_paths = [path.join(dir_path, "include/")] - error_collection = idl_check_compatibility.check_generic_arguments_compatibility( - path.join(test_path, "old.idl"), path.join(test_path, "new.idl"), include_paths, - include_paths) + error_collection = ( + idl_check_compatibility.check_generic_arguments_compatibility( + path.join(test_path, "old.idl"), + path.join(test_path, "new.idl"), + include_paths, + include_paths, + ) + ) error_collection.dump_errors() @@ -1576,16 +2607,24 @@ class TestIDLCompatibilityChecker(unittest.TestCase): self.assertTrue(error_collection.count() == 2) removed_generic_argument_error = error_collection.get_error_by_command_name( - "removedGenericArgument") - self.assertTrue(removed_generic_argument_error.error_id == - idl_compatibility_errors.ERROR_ID_GENERIC_ARGUMENT_REMOVED) + "removedGenericArgument" + ) + self.assertTrue( + removed_generic_argument_error.error_id + == idl_compatibility_errors.ERROR_ID_GENERIC_ARGUMENT_REMOVED + ) self.assertRegex(str(removed_generic_argument_error), "removedGenericArgument") removed_generic_reply_field_error = error_collection.get_error_by_command_name( - "removedGenericReplyField") - self.assertTrue(removed_generic_reply_field_error.error_id == - idl_compatibility_errors.ERROR_ID_GENERIC_ARGUMENT_REMOVED_REPLY_FIELD) - self.assertRegex(str(removed_generic_reply_field_error), "removedGenericReplyField") + "removedGenericReplyField" + ) + self.assertTrue( + removed_generic_reply_field_error.error_id + == idl_compatibility_errors.ERROR_ID_GENERIC_ARGUMENT_REMOVED_REPLY_FIELD + ) + self.assertRegex( + str(removed_generic_reply_field_error), "removedGenericReplyField" + ) def test_error_reply(self): """Tests the compatibility checker with the ErrorReply struct.""" @@ -1594,20 +2633,29 @@ class TestIDLCompatibilityChecker(unittest.TestCase): self.assertFalse( idl_check_compatibility.check_error_reply( path.join(dir_path, "compatibility_test_pass/old/error_reply.idl"), - path.join(dir_path, "compatibility_test_pass/new/error_reply.idl"), [], - []).has_errors()) + path.join(dir_path, "compatibility_test_pass/new/error_reply.idl"), + [], + [], + ).has_errors() + ) error_collection_fail = idl_check_compatibility.check_error_reply( path.join(dir_path, "compatibility_test_fail/old/error_reply.idl"), - path.join(dir_path, "compatibility_test_fail/new/error_reply.idl"), [], []) + path.join(dir_path, "compatibility_test_fail/new/error_reply.idl"), + [], + [], + ) self.assertTrue(error_collection_fail.has_errors()) self.assertTrue(error_collection_fail.count() == 1) - new_error_reply_field_optional_error = error_collection_fail.get_error_by_error_id( - idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL) + new_error_reply_field_optional_error = ( + error_collection_fail.get_error_by_error_id( + idl_compatibility_errors.ERROR_ID_NEW_REPLY_FIELD_OPTIONAL + ) + ) self.assertRegex(str(new_error_reply_field_optional_error), "n/a") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/buildscripts/idl/tests/test_generator.py b/buildscripts/idl/tests/test_generator.py index b159de735d9..468ba73a448 100644 --- a/buildscripts/idl/tests/test_generator.py +++ b/buildscripts/idl/tests/test_generator.py @@ -43,6 +43,7 @@ from textwrap import dedent # import package so that it works regardless of whether we run as a module or file if __package__ is None: import sys + sys.path.append(os.path.dirname(os.path.abspath(__file__))) import testcase from context import idl @@ -61,22 +62,25 @@ class TestGenerator(testcase.IDLTestcase): def _src_dir(self): """Get the directory of the src folder.""" base_dir = os.path.dirname( - os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + ) return os.path.join( base_dir, - 'src', + "src", ) @property def _idl_dir(self): """Get the directory of the idl folder.""" - return os.path.join(self._src_dir, 'mongo', 'idl') + return os.path.join(self._src_dir, "mongo", "idl") def tearDown(self) -> None: """Cleanup resources created by tests.""" for idl_file in self.idl_files_to_test: for ext in ["h", "cpp"]: - file_path = os.path.join(self._idl_dir, f"{idl_file}{self.output_suffix}.{ext}") + file_path = os.path.join( + self._idl_dir, f"{idl_file}{self.output_suffix}.{ext}" + ) if os.path.exists(file_path): os.remove(file_path) @@ -87,14 +91,18 @@ class TestGenerator(testcase.IDLTestcase): args.output_suffix = self.output_suffix args.import_directories = [self._src_dir] - unittest_idl_file = os.path.join(self._idl_dir, f'{self.idl_files_to_test[0]}.idl') + unittest_idl_file = os.path.join( + self._idl_dir, f"{self.idl_files_to_test[0]}.idl" + ) if not os.path.exists(unittest_idl_file): unittest.skip( - "Skipping IDL Generator testing since %s could not be found." % (unittest_idl_file)) + "Skipping IDL Generator testing since %s could not be found." + % (unittest_idl_file) + ) return for idl_file in self.idl_files_to_test: - args.input_file = os.path.join(self._idl_dir, f'{idl_file}.idl') + args.input_file = os.path.join(self._idl_dir, f"{idl_file}.idl") self.assertTrue(idl.compiler.compile_idl(args)) def test_enum_non_const(self): @@ -130,13 +138,15 @@ class TestGenerator(testcase.IDLTestcase): # Make sure the getter is marked as const. # Make sure the return type is not marked as const by validating the getter marked as const # is the only occurrence of the word "const". - header_lines = header.split('\n') + header_lines = header.split("\n") found = False for header_line in header_lines: - if header_line.find("getValue") > 0 \ - and header_line.find("const {") > 0 \ - and header_line.find("const {") == header_line.find("const"): + if ( + header_line.find("getValue") > 0 + and header_line.find("const {") > 0 + and header_line.find("const {") == header_line.find("const") + ): found = True self.assertTrue(found, "Bad Header: " + header) @@ -223,7 +233,9 @@ class TestGenerator(testcase.IDLTestcase): """) self.assertIn(expected, source) - def test_object_type_with_custom_serializer_and_query_shape_specification_custom(self) -> None: + def test_object_type_with_custom_serializer_and_query_shape_specification_custom( + self, + ) -> None: """Serialization with custom query_shape.""" _, source = self.assert_generate(""" types: @@ -267,7 +279,8 @@ class TestGenerator(testcase.IDLTestcase): self.assertIn(expected, source) def test_array_of_object_type_with_custom_serializer_and_query_shape_specification_custom( - self) -> None: + self, + ) -> None: """Serialization with custom query_shape used, array use case.""" _, source = self.assert_generate(""" types: @@ -373,7 +386,9 @@ class TestGenerator(testcase.IDLTestcase): def test_view_struct_generates_anchor(self) -> None: """Test anchor generation on view struct.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: ViewStruct: description: ViewStruct @@ -382,14 +397,17 @@ class TestGenerator(testcase.IDLTestcase): value2: random_type_not_view value3: random_type_not_view value4: random_type_not_view - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertIn(expected, header) def test_non_view_struct_does_not_generate_anchor(self) -> None: """Test anchor is not generated on non view struct.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: NonViewStruct: description: NonViewStruct @@ -398,14 +416,17 @@ class TestGenerator(testcase.IDLTestcase): value2: random_type_not_view value3: random_type_not_view value4: random_type_not_view - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertNotIn(expected, header) def test_compound_view_struct_generates_anchor(self) -> None: """Test anchor generation on view struct with compound type.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: ViewStruct: description: ViewStruct @@ -414,14 +435,17 @@ class TestGenerator(testcase.IDLTestcase): value2: random_type_not_view value3: random_type_not_view value4: random_type_not_view - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertIn(expected, header) def test_compound_non_view_struct_does_not_generate_anchor(self) -> None: """Test anchor is not generated on non view struct with compound type.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: NonViewStruct: description: NonViewStruct @@ -430,14 +454,17 @@ class TestGenerator(testcase.IDLTestcase): value2: random_type_not_view value3: random_type_not_view value4: random_type_not_view - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertNotIn(expected, header) def test_command_view_type_generates_anchor(self) -> None: """Test anchor generation on command with view parameter.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" commands: CommandTypeArrayObjectCommand: description: CommandTypeArrayObjectCommand @@ -445,14 +472,17 @@ class TestGenerator(testcase.IDLTestcase): namespace: type api_version: "" type: array - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertIn(expected, header) def test_command_non_view_type_does_not_generate_anchor(self) -> None: """Test anchor is not generated on command with nont view parameter.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" commands: CommandTypeArrayObjectCommand: description: CommandTypeArrayObjectCommand @@ -460,14 +490,17 @@ class TestGenerator(testcase.IDLTestcase): namespace: type api_version: "" type: array - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertNotIn(expected, header) def test_chained_view_struct_generates_anchor(self) -> None: """Test anchor generation on struct chained with view struct.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: ViewStruct: description: ViewStruct @@ -480,14 +513,17 @@ class TestGenerator(testcase.IDLTestcase): description: ViewStructChainedStruct chained_structs: ViewStruct: ViewStruct - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertIn(expected, header) def test_chained_non_view_struct_does_not_generate_anchor(self) -> None: """Test anchor not generated on struct chained with non view struct.""" - header, _ = self.assert_generate(self.view_test_common_types + dedent(""" + header, _ = self.assert_generate( + self.view_test_common_types + + dedent(""" structs: NonViewStruct: description: NonViewStruct @@ -500,12 +536,12 @@ class TestGenerator(testcase.IDLTestcase): description: NonViewStructChainedStruct chained_structs: NonViewStruct: NonViewStruct - """)) + """) + ) expected = dedent("BSONObj _anchorObj;") self.assertNotIn(expected, header) -if __name__ == '__main__': - +if __name__ == "__main__": unittest.main() diff --git a/buildscripts/idl/tests/test_import.py b/buildscripts/idl/tests/test_import.py index 3381b8a123b..91c9e62c57d 100644 --- a/buildscripts/idl/tests/test_import.py +++ b/buildscripts/idl/tests/test_import.py @@ -37,6 +37,7 @@ import unittest if __package__ is None: import sys from os import path + sys.path.append(path.dirname(path.abspath(__file__))) import testcase from context import idl @@ -87,27 +88,32 @@ class TestImport(testcase.IDLTestcase): imports: - "b.idl" - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) self.assert_parse_fail( textwrap.dedent(""" imports: "basetypes.idl" - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" imports: a: "a.idl" b: "b.idl" - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) def test_import_positive(self): # type: () -> None """Postive import tests.""" import_dict = { - "basetypes.idl": - textwrap.dedent(""" + "basetypes.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -134,8 +140,7 @@ class TestImport(testcase.IDLTestcase): fields: foo: string """), - "recurse1.idl": - textwrap.dedent(""" + "recurse1.idl": textwrap.dedent(""" imports: - "basetypes.idl" @@ -148,8 +153,7 @@ class TestImport(testcase.IDLTestcase): is_view: false """), - "recurse2.idl": - textwrap.dedent(""" + "recurse2.idl": textwrap.dedent(""" imports: - "recurse1.idl" @@ -162,8 +166,7 @@ class TestImport(testcase.IDLTestcase): is_view: false """), - "recurse1b.idl": - textwrap.dedent(""" + "recurse1b.idl": textwrap.dedent(""" imports: - "basetypes.idl" @@ -175,8 +178,7 @@ class TestImport(testcase.IDLTestcase): deserializer: BSONElement::fake is_view: false """), - "cycle1a.idl": - textwrap.dedent(""" + "cycle1a.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -207,8 +209,7 @@ class TestImport(testcase.IDLTestcase): foo: string foo1: bool """), - "cycle1b.idl": - textwrap.dedent(""" + "cycle1b.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -231,8 +232,7 @@ class TestImport(testcase.IDLTestcase): foo: string foo1: bool """), - "cycle2.idl": - textwrap.dedent(""" + "cycle2.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -274,7 +274,9 @@ class TestImport(testcase.IDLTestcase): strict: false fields: foo: string - """), resolver=resolver) + """), + resolver=resolver, + ) # Test nested import self.assert_bind( @@ -293,7 +295,9 @@ class TestImport(testcase.IDLTestcase): foo: string foo1: int foo2: double - """), resolver=resolver) + """), + resolver=resolver, + ) # Test diamond import self.assert_bind( @@ -314,7 +318,9 @@ class TestImport(testcase.IDLTestcase): foo1: int foo2: double foo3: bool - """), resolver=resolver) + """), + resolver=resolver, + ) # Test cycle import self.assert_bind( @@ -332,7 +338,9 @@ class TestImport(testcase.IDLTestcase): fields: foo: string foo1: bool - """), resolver=resolver) + """), + resolver=resolver, + ) # Test self cycle import self.assert_bind( @@ -349,15 +357,16 @@ class TestImport(testcase.IDLTestcase): strict: false fields: foo: string - """), resolver=resolver) + """), + resolver=resolver, + ) def test_import_negative(self): # type: () -> None """Negative import tests.""" import_dict = { - "basetypes.idl": - textwrap.dedent(""" + "basetypes.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -387,8 +396,7 @@ class TestImport(testcase.IDLTestcase): b1: 1 """), - "bug.idl": - textwrap.dedent(""" + "bug.idl": textwrap.dedent(""" global: cpp_namespace: 'mongo' @@ -408,7 +416,10 @@ class TestImport(testcase.IDLTestcase): textwrap.dedent(""" imports: - "notfound.idl" - """), idl.errors.ERROR_ID_BAD_IMPORT, resolver=resolver) + """), + idl.errors.ERROR_ID_BAD_IMPORT, + resolver=resolver, + ) # Duplicate types self.assert_parse_fail( @@ -422,7 +433,10 @@ class TestImport(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: string is_view: false - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + resolver=resolver, + ) # Duplicate structs self.assert_parse_fail( @@ -435,7 +449,10 @@ class TestImport(testcase.IDLTestcase): description: foo fields: foo1: string - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + resolver=resolver, + ) # Duplicate struct and type self.assert_parse_fail( @@ -448,7 +465,10 @@ class TestImport(testcase.IDLTestcase): description: foo fields: foo1: string - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + resolver=resolver, + ) # Duplicate type and struct self.assert_parse_fail( @@ -462,7 +482,10 @@ class TestImport(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: string is_view: false - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + resolver=resolver, + ) # Duplicate enums self.assert_parse_fail( @@ -477,7 +500,10 @@ class TestImport(testcase.IDLTestcase): values: a0: 0 b1: 1 - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + resolver=resolver, + ) # Import a file with errors self.assert_parse_fail( @@ -492,9 +518,11 @@ class TestImport(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: string is_view: false - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, resolver=resolver) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + resolver=resolver, + ) -if __name__ == '__main__': - +if __name__ == "__main__": unittest.main() diff --git a/buildscripts/idl/tests/test_parser.py b/buildscripts/idl/tests/test_parser.py index a7083506b43..3b4d2cd9f9f 100644 --- a/buildscripts/idl/tests/test_parser.py +++ b/buildscripts/idl/tests/test_parser.py @@ -36,6 +36,7 @@ import unittest if __package__ is None: import sys from os import path + sys.path.append(path.dirname(path.abspath(__file__))) import testcase from context import idl @@ -59,20 +60,26 @@ class TestParser(testcase.IDLTestcase): textwrap.dedent(""" fake: cpp_namespace: 'foo' - """), idl.errors.ERROR_ID_UNKNOWN_ROOT) + """), + idl.errors.ERROR_ID_UNKNOWN_ROOT, + ) def test_global_positive(self): # type: () -> None """Postive global tests.""" # cpp_namespace alone - self.assert_parse(textwrap.dedent(""" + self.assert_parse( + textwrap.dedent(""" global: - cpp_namespace: 'foo'""")) + cpp_namespace: 'foo'""") + ) # cpp_includes scalar - self.assert_parse(textwrap.dedent(""" + self.assert_parse( + textwrap.dedent(""" global: - cpp_includes: 'foo'""")) + cpp_includes: 'foo'""") + ) # cpp_includes list self.assert_parse( @@ -80,7 +87,8 @@ class TestParser(testcase.IDLTestcase): global: cpp_includes: - 'bar' - - 'foo'""")) + - 'foo'""") + ) def test_global_negative(self): # type: () -> None @@ -90,7 +98,9 @@ class TestParser(testcase.IDLTestcase): self.assert_parse_fail( textwrap.dedent(""" global: foo - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Duplicate globals self.assert_parse_fail( @@ -99,21 +109,27 @@ class TestParser(testcase.IDLTestcase): cpp_namespace: 'foo' global: cpp_namespace: 'bar' - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # Duplicate cpp_namespace self.assert_parse_fail( textwrap.dedent(""" global: cpp_namespace: 'foo' - cpp_namespace: 'foo'"""), idl.errors.ERROR_ID_DUPLICATE_NODE) + cpp_namespace: 'foo'"""), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # Duplicate cpp_includes self.assert_parse_fail( textwrap.dedent(""" global: cpp_includes: 'foo' - cpp_includes: 'foo'"""), idl.errors.ERROR_ID_DUPLICATE_NODE) + cpp_includes: 'foo'"""), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # cpp_namespace as a sequence self.assert_parse_fail( @@ -121,35 +137,45 @@ class TestParser(testcase.IDLTestcase): global: cpp_namespace: - 'foo' - - 'bar'"""), idl.errors.ERROR_ID_IS_NODE_TYPE) + - 'bar'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # cpp_namespace as a map self.assert_parse_fail( textwrap.dedent(""" global: cpp_namespace: - name: 'foo'"""), idl.errors.ERROR_ID_IS_NODE_TYPE) + name: 'foo'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # cpp_includes as a map self.assert_parse_fail( textwrap.dedent(""" global: cpp_includes: - inc1: 'foo'"""), idl.errors.ERROR_ID_IS_NODE_TYPE_SCALAR_OR_SEQUENCE) + inc1: 'foo'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE_SCALAR_OR_SEQUENCE, + ) # cpp_includes as a sequence of tuples self.assert_parse_fail( textwrap.dedent(""" global: cpp_includes: - - inc1: 'foo'"""), idl.errors.ERROR_ID_IS_NODE_TYPE) + - inc1: 'foo'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Unknown scalar self.assert_parse_fail( textwrap.dedent(""" global: bar: 'foo' - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) def test_type_positive(self): # type: () -> None @@ -168,7 +194,8 @@ class TestParser(testcase.IDLTestcase): default: foo bindata_subtype: foo is_view: false - """)) + """) + ) # Test sequence of bson serialization types self.assert_parse( @@ -181,7 +208,8 @@ class TestParser(testcase.IDLTestcase): - foo - bar is_view: false - """)) + """) + ) def test_type_negative(self): # type: () -> None @@ -202,13 +230,17 @@ class TestParser(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: int is_view: false - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # Test scalar fails self.assert_parse_fail( textwrap.dedent(""" types: - foo: 'bar'"""), idl.errors.ERROR_ID_IS_NODE_TYPE) + foo: 'bar'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Test unknown field self.assert_parse_fail( @@ -220,7 +252,9 @@ class TestParser(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: is_view: false - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) # test duplicate field self.assert_parse_fail( @@ -232,14 +266,19 @@ class TestParser(testcase.IDLTestcase): cpp_type: foo bson_serialization_type: is_view: false - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # test list instead of scalar self.assert_parse_fail( textwrap.dedent(""" types: - foo: - """), idl.errors.ERROR_ID_IS_NODE_TYPE, multiple=True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + multiple=True, + ) # test list instead of scalar self.assert_parse_fail( @@ -247,7 +286,10 @@ class TestParser(testcase.IDLTestcase): types: foo: - bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE, multiple=True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + multiple=True, + ) # test map instead of scalar self.assert_parse_fail( @@ -256,7 +298,10 @@ class TestParser(testcase.IDLTestcase): foo: description: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE, multiple=True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + multiple=True, + ) # test missing bson_serialization_type field self.assert_parse_fail( @@ -266,7 +311,9 @@ class TestParser(testcase.IDLTestcase): description: foo cpp_type: foo is_view: false - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # test missing cpp_type field self.assert_parse_fail( @@ -276,7 +323,9 @@ class TestParser(testcase.IDLTestcase): description: foo bson_serialization_type: foo is_view: false - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) def test_struct_positive(self): # type: () -> None @@ -295,7 +344,8 @@ class TestParser(testcase.IDLTestcase): cpp_validator_func: funcName fields: foo: bar - """)) + """) + ) # All fields with false for bools self.assert_parse( @@ -310,7 +360,8 @@ class TestParser(testcase.IDLTestcase): cpp_validator_func: funcName fields: foo: bar - """)) + """) + ) # Missing fields self.assert_parse( @@ -319,7 +370,8 @@ class TestParser(testcase.IDLTestcase): foo: description: foo strict: true - """)) + """) + ) def test_struct_negative(self): # type: () -> None @@ -330,7 +382,9 @@ class TestParser(testcase.IDLTestcase): textwrap.dedent(""" structs: foo: foo - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # unknown field self.assert_parse_fail( @@ -341,7 +395,9 @@ class TestParser(testcase.IDLTestcase): foo: bar fields: foo: bar - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) # strict is a bool self.assert_parse_fail( @@ -352,7 +408,9 @@ class TestParser(testcase.IDLTestcase): strict: bar fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # immutable is a bool self.assert_parse_fail( @@ -363,7 +421,9 @@ class TestParser(testcase.IDLTestcase): immutable: bar fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # inline_chained_structs is a bool self.assert_parse_fail( @@ -374,7 +434,9 @@ class TestParser(testcase.IDLTestcase): inline_chained_structs: bar fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # generate_comparison_operators is a bool self.assert_parse_fail( @@ -385,7 +447,9 @@ class TestParser(testcase.IDLTestcase): generate_comparison_operators: bar fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # cpp_name is not allowed self.assert_parse_fail( @@ -396,7 +460,9 @@ class TestParser(testcase.IDLTestcase): cpp_name: bar fields: foo: bar - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) def test_variant_positive(self): # type: () -> None @@ -417,7 +483,8 @@ class TestParser(testcase.IDLTestcase): - string - array - object - """)) + """) + ) def test_variant_negative(self): # type: () -> None @@ -432,7 +499,9 @@ class TestParser(testcase.IDLTestcase): my_variant_field: type: variant: {} - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -443,7 +512,9 @@ class TestParser(testcase.IDLTestcase): my_variant_field: type: variant: 1 - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -455,7 +526,9 @@ class TestParser(testcase.IDLTestcase): type: variant: [] unknown_option: true - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -468,7 +541,9 @@ class TestParser(testcase.IDLTestcase): variant: - string - {variant: [string, int]} - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -480,7 +555,9 @@ class TestParser(testcase.IDLTestcase): my_variant_field: type: variant: [string, int] - """), idl.errors.ERROR_ID_VARIANT_COMPARISON) + """), + idl.errors.ERROR_ID_VARIANT_COMPARISON, + ) def test_field_positive(self): # type: () -> None @@ -494,7 +571,8 @@ class TestParser(testcase.IDLTestcase): description: foo fields: foo: short - """)) + """) + ) # Test all fields self.assert_parse( @@ -511,7 +589,8 @@ class TestParser(testcase.IDLTestcase): cpp_name: bar comparison_order: 3 stability: unstable - """)) + """) + ) # Test false bools self.assert_parse( @@ -526,7 +605,8 @@ class TestParser(testcase.IDLTestcase): optional: false ignore: false stability: stable - """)) + """) + ) def test_field_negative(self): # type: () -> None @@ -542,7 +622,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: short foo: int - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # Test bad bool self.assert_parse_fail( @@ -555,7 +637,9 @@ class TestParser(testcase.IDLTestcase): foo: type: string optional: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # Test bad bool self.assert_parse_fail( @@ -568,7 +652,9 @@ class TestParser(testcase.IDLTestcase): foo: type: string ignore: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # Test bad int scalar self.assert_parse_fail( @@ -583,7 +669,9 @@ class TestParser(testcase.IDLTestcase): comparison_order: - a - b - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Test bad int self.assert_parse_fail( @@ -596,7 +684,9 @@ class TestParser(testcase.IDLTestcase): foo: type: string comparison_order: 3.14159 - """), idl.errors.ERROR_ID_IS_NODE_VALID_INT) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_INT, + ) # Test bad negative int self.assert_parse_fail( @@ -609,7 +699,9 @@ class TestParser(testcase.IDLTestcase): foo: type: string comparison_order: -1 - """), idl.errors.ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_NON_NEGATIVE_INT, + ) def test_name_collisions_negative(self): # type: () -> None @@ -633,7 +725,9 @@ class TestParser(testcase.IDLTestcase): strict: true fields: foo: string - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Type after struct self.assert_parse_fail( @@ -654,7 +748,9 @@ class TestParser(testcase.IDLTestcase): deserializer: foo default: foo is_view: false - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) def test_chained_type_positive(self): # type: () -> None @@ -667,7 +763,8 @@ class TestParser(testcase.IDLTestcase): chained_types: foo1: alias foo2: alias - """)) + """) + ) def test_chained_type_negative(self): # type: () -> None @@ -680,7 +777,9 @@ class TestParser(testcase.IDLTestcase): chained_types: foo1 fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -691,7 +790,9 @@ class TestParser(testcase.IDLTestcase): - foo1 fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Duplicate chained types self.assert_parse_fail( @@ -703,7 +804,9 @@ class TestParser(testcase.IDLTestcase): chained_types: foo1: alias foo1: alias - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) def test_chained_struct_positive(self): # type: () -> None @@ -716,7 +819,8 @@ class TestParser(testcase.IDLTestcase): chained_structs: foo1: foo1_cpp foo2: foo2_cpp - """)) + """) + ) def test_chained_struct_negative(self): # type: () -> None @@ -729,7 +833,9 @@ class TestParser(testcase.IDLTestcase): chained_structs: foo1 fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -740,7 +846,9 @@ class TestParser(testcase.IDLTestcase): - foo1 fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Duplicate chained structs self.assert_parse_fail( @@ -752,7 +860,9 @@ class TestParser(testcase.IDLTestcase): chained_structs: chained: alias chained: alias - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) def test_enum_positive(self): # type: () -> None @@ -767,7 +877,8 @@ class TestParser(testcase.IDLTestcase): type: foo values: v1: 0 - """)) + """) + ) # Test extended value self.assert_parse( @@ -780,7 +891,8 @@ class TestParser(testcase.IDLTestcase): v1: description: foo value: 0 - """)) + """) + ) # Test extra_data self.assert_parse( @@ -795,7 +907,8 @@ class TestParser(testcase.IDLTestcase): value: 0 extra_data: bar: baz - """)) + """) + ) def test_enum_negative(self): # type: () -> None @@ -816,13 +929,17 @@ class TestParser(testcase.IDLTestcase): type: int values: v1: 0 - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Test scalar fails self.assert_parse_fail( textwrap.dedent(""" enums: - foo: 'bar'"""), idl.errors.ERROR_ID_IS_NODE_TYPE) + foo: 'bar'"""), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Test unknown field self.assert_parse_fail( @@ -834,7 +951,9 @@ class TestParser(testcase.IDLTestcase): type: foo values: v1: 0 - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) # test duplicate field self.assert_parse_fail( @@ -846,14 +965,19 @@ class TestParser(testcase.IDLTestcase): type: foo values: v1: 0 - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # test list instead of scalar self.assert_parse_fail( textwrap.dedent(""" enums: - foo: - """), idl.errors.ERROR_ID_IS_NODE_TYPE, multiple=True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + multiple=True, + ) # test list instead of scalar self.assert_parse_fail( @@ -861,7 +985,10 @@ class TestParser(testcase.IDLTestcase): enums: foo: - bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE, multiple=True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + multiple=True, + ) # test missing type field self.assert_parse_fail( @@ -871,7 +998,9 @@ class TestParser(testcase.IDLTestcase): description: foo values: v1: 0 - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # test missing values field self.assert_parse_fail( @@ -880,7 +1009,9 @@ class TestParser(testcase.IDLTestcase): foo: description: foo type: foo - """), idl.errors.ERROR_ID_BAD_EMPTY_ENUM) + """), + idl.errors.ERROR_ID_BAD_EMPTY_ENUM, + ) # Test no values self.assert_parse_fail( @@ -889,7 +1020,9 @@ class TestParser(testcase.IDLTestcase): foo: description: foo type: int - """), idl.errors.ERROR_ID_BAD_EMPTY_ENUM) + """), + idl.errors.ERROR_ID_BAD_EMPTY_ENUM, + ) # Name collision with types self.assert_parse_fail( @@ -910,7 +1043,9 @@ class TestParser(testcase.IDLTestcase): type: foo values: v1: 0 - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Name collision with structs self.assert_parse_fail( @@ -938,7 +1073,9 @@ class TestParser(testcase.IDLTestcase): type: foo values: v1: 0 - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Test int - duplicate names self.assert_parse_fail( @@ -950,7 +1087,9 @@ class TestParser(testcase.IDLTestcase): values: v1: 0 v1: 1 - """), idl.errors.ERROR_ID_DUPLICATE_NODE) + """), + idl.errors.ERROR_ID_DUPLICATE_NODE, + ) # Test extra_data invalid type self.assert_parse_fail( @@ -961,7 +1100,9 @@ class TestParser(testcase.IDLTestcase): type: int values: v1: [ 'foo' ] - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Test extended value missing fields (description) self.assert_parse_fail( @@ -973,7 +1114,9 @@ class TestParser(testcase.IDLTestcase): values: v1: value: 0 - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # Test extended value missing fields (value) self.assert_parse_fail( @@ -985,7 +1128,9 @@ class TestParser(testcase.IDLTestcase): values: v1: description: foo - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # Test invalid extra_data (scalar) self.assert_parse_fail( @@ -999,7 +1144,9 @@ class TestParser(testcase.IDLTestcase): description: foo value: 0 extra_data: foo - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # Test invalid extra_data (sequence) self.assert_parse_fail( @@ -1013,7 +1160,9 @@ class TestParser(testcase.IDLTestcase): description: foo value: 0 extra_data: [ foo ] - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) def test_command_positive(self): # type: () -> None @@ -1037,7 +1186,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) # All fields with false for bools self.assert_parse( @@ -1056,7 +1206,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) # All fields with false for bools, empty api_version self.assert_parse( @@ -1075,7 +1226,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) # Quoted api_version self.assert_parse( @@ -1089,7 +1241,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) # Namespace ignored self.assert_parse( @@ -1102,7 +1255,8 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """)) + """) + ) # Namespace concatenate_with_db self.assert_parse( @@ -1115,7 +1269,8 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """)) + """) + ) # No fields self.assert_parse( @@ -1127,7 +1282,8 @@ class TestParser(testcase.IDLTestcase): namespace: ignored api_version: "" strict: true - """)) + """) + ) # Reply type permitted without api_version self.assert_parse( @@ -1139,7 +1295,8 @@ class TestParser(testcase.IDLTestcase): namespace: ignored api_version: "" reply_type: foo_reply_struct - """)) + """) + ) def test_command_negative(self): # type: () -> None @@ -1150,7 +1307,9 @@ class TestParser(testcase.IDLTestcase): textwrap.dedent(""" commands: foo: foo - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # unknown field self.assert_parse_fail( @@ -1164,7 +1323,9 @@ class TestParser(testcase.IDLTestcase): foo: bar fields: foo: bar - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) # strict is a bool self.assert_parse_fail( @@ -1178,7 +1339,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # command_name is required self.assert_parse_fail( @@ -1190,7 +1353,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # command_name is a scalar self.assert_parse_fail( @@ -1203,7 +1368,10 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE, True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + True, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -1215,7 +1383,10 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_TYPE, True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + True, + ) # is_deprecated is a bool self.assert_parse_fail( @@ -1229,7 +1400,9 @@ class TestParser(testcase.IDLTestcase): is_deprecated: bar fields: foo: bar - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) # api_version is required self.assert_parse_fail( @@ -1241,7 +1414,10 @@ class TestParser(testcase.IDLTestcase): namespace: ignored fields: foo: bar - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, True) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + True, + ) # api_version is a scalar self.assert_parse_fail( @@ -1255,7 +1431,10 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_IS_NODE_TYPE, True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + True, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -1268,7 +1447,10 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_IS_NODE_TYPE, True) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + True, + ) # Must specify reply_type if api_version is non-empty self.assert_parse_fail( @@ -1281,7 +1463,9 @@ class TestParser(testcase.IDLTestcase): api_version: 1 fields: foo: bar - """), idl.errors.ERROR_ID_MISSING_REPLY_TYPE) + """), + idl.errors.ERROR_ID_MISSING_REPLY_TYPE, + ) # Namespace is required self.assert_parse_fail( @@ -1293,7 +1477,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # Namespace is wrong self.assert_parse_fail( @@ -1306,7 +1492,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_BAD_COMMAND_NAMESPACE) + """), + idl.errors.ERROR_ID_BAD_COMMAND_NAMESPACE, + ) # Setup some common types test_preamble = textwrap.dedent(""" @@ -1323,7 +1511,8 @@ class TestParser(testcase.IDLTestcase): # Commands and structs with same name self.assert_parse_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: foo: description: foo @@ -1338,11 +1527,14 @@ class TestParser(testcase.IDLTestcase): description: foo fields: foo: foo - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Commands and types with same name self.assert_parse_fail( - test_preamble + textwrap.dedent(""" + test_preamble + + textwrap.dedent(""" commands: string: description: foo @@ -1352,7 +1544,9 @@ class TestParser(testcase.IDLTestcase): strict: true fields: foo: string - """), idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """), + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) self.assert_parse_fail( textwrap.dedent(""" @@ -1365,7 +1559,10 @@ class TestParser(testcase.IDLTestcase): strict: true fields: foo: string - """) + test_preamble, idl.errors.ERROR_ID_DUPLICATE_SYMBOL) + """) + + test_preamble, + idl.errors.ERROR_ID_DUPLICATE_SYMBOL, + ) # Namespace concatenate_with_db self.assert_parse_fail( @@ -1379,7 +1576,9 @@ class TestParser(testcase.IDLTestcase): type: foobar fields: foo: bar - """), idl.errors.ERROR_ID_IS_COMMAND_TYPE_EXTRANEOUS) + """), + idl.errors.ERROR_ID_IS_COMMAND_TYPE_EXTRANEOUS, + ) # Reply type must be a scalar, not a mapping self.assert_parse_fail( @@ -1392,7 +1591,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" reply_type: arbitrary_field: foo - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) def test_command_doc_sequence_positive(self): # type: () -> None @@ -1411,7 +1612,8 @@ class TestParser(testcase.IDLTestcase): foo: type: bar supports_doc_sequence: false - """)) + """) + ) # supports_doc_sequence can be true self.assert_parse( @@ -1426,7 +1628,8 @@ class TestParser(testcase.IDLTestcase): foo: type: bar supports_doc_sequence: true - """)) + """) + ) def test_command_doc_sequence_negative(self): # type: () -> None @@ -1445,7 +1648,9 @@ class TestParser(testcase.IDLTestcase): foo: type: bar supports_doc_sequence: foo - """), idl.errors.ERROR_ID_IS_NODE_VALID_BOOL) + """), + idl.errors.ERROR_ID_IS_NODE_VALID_BOOL, + ) def test_command_type_positive(self): # type: () -> None @@ -1463,7 +1668,8 @@ class TestParser(testcase.IDLTestcase): type: string fields: foo: bar - """)) + """) + ) # array of string self.assert_parse( @@ -1478,7 +1684,8 @@ class TestParser(testcase.IDLTestcase): type: array fields: foo: bar - """)) + """) + ) # no fields self.assert_parse( @@ -1491,7 +1698,8 @@ class TestParser(testcase.IDLTestcase): namespace: type api_version: "" type: string - """)) + """) + ) def test_command_type_negative(self): # type: () -> None @@ -1508,7 +1716,9 @@ class TestParser(testcase.IDLTestcase): api_version: "" fields: foo: bar - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) def test_stability_positive(self): # type: () -> None @@ -1527,7 +1737,8 @@ class TestParser(testcase.IDLTestcase): type: bar stability: {stability} reply_type: foo_reply_struct - """)) + """) + ) def test_stability_negative(self): # type: () -> None @@ -1545,7 +1756,9 @@ class TestParser(testcase.IDLTestcase): type: bar stability: unstable reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_STABILITY_NO_API_VERSION) + """), + idl.errors.ERROR_ID_STABILITY_NO_API_VERSION, + ) self.assert_parse_fail( textwrap.dedent(""" commands: @@ -1559,7 +1772,9 @@ class TestParser(testcase.IDLTestcase): type: bar stability: "unknown" reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_STABILITY_UNKNOWN_VALUE) + """), + idl.errors.ERROR_ID_STABILITY_UNKNOWN_VALUE, + ) self.assert_parse_fail( textwrap.dedent(""" commands: @@ -1574,7 +1789,9 @@ class TestParser(testcase.IDLTestcase): unstable: true stability: "unstable" reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_DUPLICATE_UNSTABLE_STABILITY) + """), + idl.errors.ERROR_ID_DUPLICATE_UNSTABLE_STABILITY, + ) def test_scalar_or_mapping_negative(self): # type: () -> None @@ -1592,7 +1809,9 @@ class TestParser(testcase.IDLTestcase): default: - one - two - """), idl.errors.ERROR_ID_IS_NODE_TYPE_SCALAR_OR_MAPPING) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE_SCALAR_OR_MAPPING, + ) def test_feature_flag(self): # type: () -> None @@ -1606,7 +1825,9 @@ class TestParser(testcase.IDLTestcase): description: "Make toast" cpp_varname: gToaster shouldBeFCVGated: true - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) # Missing shouldBeFCVGated self.assert_parse_fail( @@ -1616,7 +1837,9 @@ class TestParser(testcase.IDLTestcase): description: "Make toast" cpp_varname: gToaster default: false - """), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD) + """), + idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, + ) def test_command_alias(self): # type: () -> None @@ -1636,7 +1859,9 @@ class TestParser(testcase.IDLTestcase): foo: type: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS) + """), + idl.errors.ERROR_ID_COMMAND_DUPLICATES_NAME_AND_ALIAS, + ) def test_access_checks_positive(self): # type: () -> None @@ -1655,7 +1880,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) self.assert_parse( textwrap.dedent(""" @@ -1670,7 +1896,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) self.assert_parse( textwrap.dedent(""" @@ -1686,7 +1913,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) self.assert_parse( textwrap.dedent(""" @@ -1712,7 +1940,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) self.assert_parse( textwrap.dedent(""" @@ -1730,7 +1959,8 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """)) + """) + ) def test_access_checks_negative(self): # type: () -> None @@ -1754,7 +1984,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EITHER_CHECK_OR_PRIVILEGE) + """), + idl.errors.ERROR_ID_EITHER_CHECK_OR_PRIVILEGE, + ) # simple: true fails self.assert_parse_fail( @@ -1770,7 +2002,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_IS_NODE_TYPE) + """), + idl.errors.ERROR_ID_IS_NODE_TYPE, + ) # simple empty fails self.assert_parse_fail( @@ -1786,7 +2020,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EITHER_CHECK_OR_PRIVILEGE) + """), + idl.errors.ERROR_ID_EITHER_CHECK_OR_PRIVILEGE, + ) # duplicate access_check - none and simple self.assert_parse_fail( @@ -1806,7 +2042,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK, + ) # duplicate access_check - none and complex self.assert_parse_fail( @@ -1830,7 +2068,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK, + ) # duplicate access_check - simple and complex self.assert_parse_fail( @@ -1857,7 +2097,9 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK, + ) # duplicate access_check - none, simple and complex self.assert_parse_fail( @@ -1885,11 +2127,14 @@ class TestParser(testcase.IDLTestcase): fields: foo: bar reply_type: foo_reply_struct - """), idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK) + """), + idl.errors.ERROR_ID_EMPTY_ACCESS_CHECK, + ) # pylint: disable=invalid-name - def test_struct_unsafe_dangerous_disable_extra_field_duplicate_checks_negative(self): - + def test_struct_unsafe_dangerous_disable_extra_field_duplicate_checks_negative( + self, + ): # Test commands and unsafe_dangerous_disable_extra_field_duplicate_checks are disallowed self.assert_parse_fail( textwrap.dedent(""" @@ -1903,9 +2148,10 @@ class TestParser(testcase.IDLTestcase): unsafe_dangerous_disable_extra_field_duplicate_checks: true fields: foo: string - """), idl.errors.ERROR_ID_UNKNOWN_NODE) + """), + idl.errors.ERROR_ID_UNKNOWN_NODE, + ) -if __name__ == '__main__': - +if __name__ == "__main__": unittest.main() diff --git a/buildscripts/idl/tests/testcase.py b/buildscripts/idl/tests/testcase.py index 27a7d6a540b..0caf7076077 100644 --- a/buildscripts/idl/tests/testcase.py +++ b/buildscripts/idl/tests/testcase.py @@ -29,9 +29,10 @@ import unittest -if __name__ == 'testcase': +if __name__ == "testcase": import sys from os import path + sys.path.append(path.dirname(path.dirname(path.abspath(__file__)))) from context import idl else: @@ -78,8 +79,9 @@ class IDLTestcase(unittest.TestCase): """Assert a document parsed correctly by the IDL compiler and returned no errors.""" self.assertIsNone( parsed_doc.errors, - "Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s" % - (doc_str, errors_to_str(parsed_doc.errors))) + "Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s" + % (doc_str, errors_to_str(parsed_doc.errors)), + ) self.assertIsNotNone(parsed_doc.spec, "Expected a parsed doc") def assert_parse(self, doc_str, resolver=NothingImportResolver()): @@ -88,8 +90,9 @@ class IDLTestcase(unittest.TestCase): parsed_doc = self._parse(doc_str, resolver) self._assert_parse(doc_str, parsed_doc) - def assert_parse_fail(self, doc_str, error_id, multiple=False, - resolver=NothingImportResolver()): + def assert_parse_fail( + self, doc_str, error_id, multiple=False, resolver=NothingImportResolver() + ): # type: (str, str, bool, idl.parser.ImportResolverBase) -> None """ Assert a document parsed correctly by the YAML parser, but not the by the IDL compiler. @@ -106,12 +109,14 @@ class IDLTestcase(unittest.TestCase): self.assertTrue( multiple or parsed_doc.errors.count() == 1, "For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s" - % (doc_str, error_id, errors_to_str(parsed_doc.errors))) + % (doc_str, error_id, errors_to_str(parsed_doc.errors)), + ) self.assertTrue( parsed_doc.errors.contains(error_id), - "For document:\n%s\nExpected error message '%s' but received only errors:\n %s" % - (doc_str, error_id, errors_to_str(parsed_doc.errors))) + "For document:\n%s\nExpected error message '%s' but received only errors:\n %s" + % (doc_str, error_id, errors_to_str(parsed_doc.errors)), + ) def assert_bind(self, doc_str, resolver=NothingImportResolver()): # type: (str, idl.parser.ImportResolverBase) -> idl.ast.IDLBoundSpec @@ -122,13 +127,17 @@ class IDLTestcase(unittest.TestCase): bound_doc = idl.binder.bind(parsed_doc.spec) self.assertIsNone( - bound_doc.errors, "Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s" - % (doc_str, errors_to_str(bound_doc.errors))) + bound_doc.errors, + "Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s" + % (doc_str, errors_to_str(bound_doc.errors)), + ) self.assertIsNotNone(bound_doc.spec, "Expected a bound doc") return bound_doc.spec - def assert_bind_fail(self, doc_str, error_id, multiple=False, resolver=NothingImportResolver()): + def assert_bind_fail( + self, doc_str, error_id, multiple=False, resolver=NothingImportResolver() + ): # type: (str, str, bool, idl.parser.ImportResolverBase) -> None """ Assert a document parsed correctly by the YAML parser and IDL parser, but not bound by the IDL binder. @@ -140,20 +149,25 @@ class IDLTestcase(unittest.TestCase): bound_doc = idl.binder.bind(parsed_doc.spec) - self.assertIsNone(bound_doc.spec, "Expected no bound doc\nFor document:\n%s\n" % (doc_str)) + self.assertIsNone( + bound_doc.spec, "Expected no bound doc\nFor document:\n%s\n" % (doc_str) + ) self.assertIsNotNone(bound_doc.errors, "Expected binder errors") # Assert that negative test cases are only testing one fault in a test. # This is impossible to assert for all tests though. self.assertTrue( - (multiple and bound_doc.errors.count() >= 1) or bound_doc.errors.count() == 1, + (multiple and bound_doc.errors.count() >= 1) + or bound_doc.errors.count() == 1, "For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s" - % (doc_str, error_id, errors_to_str(bound_doc.errors))) + % (doc_str, error_id, errors_to_str(bound_doc.errors)), + ) self.assertTrue( bound_doc.errors.contains(error_id), - "For document:\n%s\nExpected error message '%s' but received only errors:\n %s" % - (doc_str, error_id, errors_to_str(bound_doc.errors))) + "For document:\n%s\nExpected error message '%s' but received only errors:\n %s" + % (doc_str, error_id, errors_to_str(bound_doc.errors)), + ) def assert_generate(self, doc_str, resolver=NothingImportResolver()): # type: (str, idl.parser.ImportResolverBase) -> Tuple[str,str] diff --git a/buildscripts/install_bazel.py b/buildscripts/install_bazel.py index c50c6305c5e..62254ef34d5 100755 --- a/buildscripts/install_bazel.py +++ b/buildscripts/install_bazel.py @@ -64,8 +64,12 @@ def install_buildozer(download_location: str = "./"): def install_bazel(binary_directory: str) -> str: install_buildozer(binary_directory) - normalized_arch = (platform.machine().lower().replace("aarch64", "arm64").replace( - "x86_64", "amd64")) + normalized_arch = ( + platform.machine() + .lower() + .replace("aarch64", "arm64") + .replace("x86_64", "amd64") + ) normalized_os = sys.platform.replace("win32", "windows").replace("darwin", "macos") is_bazelisk_supported = normalized_arch not in ["ppc64le", "s390x"] binary_filename = "bazelisk" @@ -95,20 +99,23 @@ def install_bazel(binary_directory: str) -> str: def _set_bazel_permissions(binary_path: str) -> None: # Bazel is a self-extracting zip launcher and needs read perms on the executable to read the zip from itself. - perms = (stat.S_IXUSR - | stat.S_IXGRP - | stat.S_IXOTH - | stat.S_IRUSR - | stat.S_IRGRP - | stat.S_IROTH - | stat.S_IWUSR - | stat.S_IWGRP) + perms = ( + stat.S_IXUSR + | stat.S_IXGRP + | stat.S_IXOTH + | stat.S_IRUSR + | stat.S_IRGRP + | stat.S_IROTH + | stat.S_IWUSR + | stat.S_IWGRP + ) os.chmod(binary_path, perms) def create_bazel_to_bazelisk_symlink(binary_directory: str) -> str: - bazel_symlink = os.path.join(binary_directory, - "bazel.exe" if sys.platform == "win32" else "bazel") + bazel_symlink = os.path.join( + binary_directory, "bazel.exe" if sys.platform == "win32" else "bazel" + ) if os.path.exists(bazel_symlink): print(f"Symlink {bazel_symlink} already exists, skipping symlink creation") return bazel_symlink @@ -156,13 +163,19 @@ def main(): else: print("To add it to your PATH, run: \n") if os.path.exists(os.path.expanduser("~/.bashrc")): - print(f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.bashrc') + print( + f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.bashrc' + ) print("source ~/.bashrc") elif os.path.exists(os.path.expanduser("~/.bash_profile")): - print(f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.bash_profile') + print( + f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.bash_profile' + ) print("source ~/.bash_profile") elif os.path.exists(os.path.expanduser("~/.zshrc")): - print(f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.zshrc') + print( + f'echo "export PATH=\\{abs_binary_directory}:$PATH" >> ~/.zshrc' + ) print("source ~/.zshrc") else: print(f"export PATH={abs_binary_directory}:$PATH") diff --git a/buildscripts/jepsen_report.py b/buildscripts/jepsen_report.py index 243c53f5934..84c3e386740 100644 --- a/buildscripts/jepsen_report.py +++ b/buildscripts/jepsen_report.py @@ -1,4 +1,5 @@ """Generate Evergreen reports from the Jepsen list-append workload.""" + import json import os import re @@ -26,7 +27,9 @@ class ParserOutput(TypedDict): _JEPSEN_TIME_FORMAT = "%Y-%m-%d %H:%M:%S" _JEPSEN_MILLI_RE = re.compile("([0-9]+){(.*)}") -_JEPSEN_TIME_RE = re.compile("[0-9]{4}-[0-8]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2},[0-9]+{.*}") +_JEPSEN_TIME_RE = re.compile( + "[0-9]{4}-[0-8]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2},[0-9]+{.*}" +) def _time_parse(time: str): @@ -147,20 +150,26 @@ def parse(text: List[str]) -> ParserOutput: # noqa: D406 target.append(line) assert success_table_matches == len( - successful_tests), "Mismatch in success_table_matches and length of successful_tests" + successful_tests + ), "Mismatch in success_table_matches and length of successful_tests" assert unknown_table_matches == len( - indeterminate_tests), "Mismatch in unknown_table_matches and length of indeterminate_tests" + indeterminate_tests + ), "Mismatch in unknown_table_matches and length of indeterminate_tests" assert crash_table_matches == len( - crashed_tests), "Mismatch in crash_table_matches and length of crashed_tests" + crashed_tests + ), "Mismatch in crash_table_matches and length of crashed_tests" assert fail_table_matches == len( - failed_tests), "Mismatch in fail_table_matches and length of failed_tests" + failed_tests + ), "Mismatch in fail_table_matches and length of failed_tests" - return ParserOutput({ - 'success': successful_tests, - 'unknown': indeterminate_tests, - 'crashed': crashed_tests, - 'failed': failed_tests, - }) + return ParserOutput( + { + "success": successful_tests, + "unknown": indeterminate_tests, + "crashed": crashed_tests, + "failed": failed_tests, + } + ) def _try_find_log_file(store: Optional[str], test_name) -> str: @@ -175,45 +184,84 @@ def _try_find_log_file(store: Optional[str], test_name) -> str: return "" -def report(out: ParserOutput, start_time: int, end_time: int, elapsed: int, - store: Optional[str]) -> Report: +def report( + out: ParserOutput, + start_time: int, + end_time: int, + elapsed: int, + store: Optional[str], +) -> Report: """Given ParserOutput, return report.json as a dict.""" results = [] failures = 0 - for test_name in out['success']: + for test_name in out["success"]: log_raw = _try_find_log_file(store, test_name) start_time, end_time, elapsed_time = _calc_time_from_log(log_raw) results.append( - Result(status='pass', exit_code=0, test_file=test_name, start=start_time, end=end_time, - elapsed=elapsed_time, log_raw=log_raw)) + Result( + status="pass", + exit_code=0, + test_file=test_name, + start=start_time, + end=end_time, + elapsed=elapsed_time, + log_raw=log_raw, + ) + ) - for test_name in out['failed']: + for test_name in out["failed"]: log_raw = _try_find_log_file(store, test_name) start_time, end_time, elapsed_time = _calc_time_from_log(log_raw) failures += 1 results.append( - Result(status='fail', exit_code=1, test_file=test_name, start=start_time, end=end_time, - elapsed=elapsed_time, log_raw=log_raw)) + Result( + status="fail", + exit_code=1, + test_file=test_name, + start=start_time, + end=end_time, + elapsed=elapsed_time, + log_raw=log_raw, + ) + ) - for test_name in out['crashed']: + for test_name in out["crashed"]: log_raw = "Log files are unavailable for crashed tests because Jepsen does not save them separately. You may be able to find the exception and stack trace in the task log" failures += 1 results.append( - Result(status='fail', exit_code=1, test_file=test_name, start=start_time, end=end_time, - elapsed=elapsed, log_raw=log_raw)) + Result( + status="fail", + exit_code=1, + test_file=test_name, + start=start_time, + end=end_time, + elapsed=elapsed, + log_raw=log_raw, + ) + ) - for test_name in out['unknown']: + for test_name in out["unknown"]: log_raw = _try_find_log_file(store, test_name) start_time, end_time, elapsed_time = _calc_time_from_log(log_raw) failures += 1 results.append( - Result(status='fail', exit_code=1, test_file=test_name, start=start_time, end=end_time, - elapsed=elapsed_time, log_raw=log_raw)) - return Report({ - "failures": failures, - "results": results, - }) + Result( + status="fail", + exit_code=1, + test_file=test_name, + start=start_time, + end=end_time, + elapsed=elapsed_time, + log_raw=log_raw, + ) + ) + return Report( + { + "failures": failures, + "results": results, + } + ) def _get_log_lines(filename: str) -> List[str]: @@ -230,26 +278,41 @@ def _put_report(report_: Report) -> None: @click.option("--start_time", type=int, required=True) @click.option("--end_time", type=int, required=True) @click.option("--elapsed", type=int, required=True) -@click.option("--emit_status_files", type=bool, is_flag=True, default=False, - help="If true, emit status files for marking Evergreen tasks as system fails") -@click.option("--store", type=str, default=None, - help="Path to folder containing jepsen 'store' directory") +@click.option( + "--emit_status_files", + type=bool, + is_flag=True, + default=False, + help="If true, emit status files for marking Evergreen tasks as system fails", +) +@click.option( + "--store", + type=str, + default=None, + help="Path to folder containing jepsen 'store' directory", +) @click.argument("filename", type=str) -def main(filename: str, start_time: str, end_time: str, elapsed: str, emit_status_files: bool, - store: Optional[str]): +def main( + filename: str, + start_time: str, + end_time: str, + elapsed: str, + emit_status_files: bool, + store: Optional[str], +): """Generate Evergreen reports from the Jepsen list-append workload.""" out = parse(_get_log_lines(filename)) _put_report(report(out, start_time, end_time, elapsed, store)) exit_code = 255 - if out['crashed']: + if out["crashed"]: exit_code = 2 if emit_status_files: with open("jepsen_system_fail.txt", "w") as fh: fh.write(str(exit_code)) else: - if out['unknown'] or out['failed']: + if out["unknown"] or out["failed"]: exit_code = 1 else: exit_code = 0 diff --git a/buildscripts/jstoh.py b/buildscripts/jstoh.py index adcb69ed2c5..28bd1c39673 100755 --- a/buildscripts/jstoh.py +++ b/buildscripts/jstoh.py @@ -25,7 +25,6 @@ import sys def jsToHeader(target, source): - outFile = target h = [ @@ -52,8 +51,10 @@ def jsToHeader(target, source): h.append("0};") # symbols aren't exported w/o this h.append("extern const JSFile %s;" % objname) - h.append('const JSFile %s = { "%s", StringData(%s, sizeof(%s) - 1) };' % - (objname, filename.replace("\\", "/"), stringname, stringname)) + h.append( + 'const JSFile %s = { "%s", StringData(%s, sizeof(%s) - 1) };' + % (objname, filename.replace("\\", "/"), stringname, stringname) + ) h.append("} // namespace JSFiles") h.append("} // namespace mongo") diff --git a/buildscripts/large_file_check.py b/buildscripts/large_file_check.py index d22c4b1814e..87535ce40c7 100755 --- a/buildscripts/large_file_check.py +++ b/buildscripts/large_file_check.py @@ -13,7 +13,9 @@ from typing import Any, Dict, List, Optional, Tuple import structlog from git import Repo -mongo_dir = os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) +mongo_dir = os.path.dirname( + os.path.dirname(os.path.abspath(os.path.realpath(__file__))) +) # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: sys.path.append(mongo_dir) @@ -63,7 +65,9 @@ def _get_repos_and_revisions() -> Tuple[List[Repo], RevisionMap]: else: repos = [Repo(git.get_base_dir())] - revision_map = generate_revision_map(repos, {"mongo": os.environ.get(MONGO_REVISION_ENV_VAR)}) + revision_map = generate_revision_map( + repos, {"mongo": os.environ.get(MONGO_REVISION_ENV_VAR)} + ) return repos, revision_map @@ -87,7 +91,9 @@ def git_changed_files(excludes: List[pathlib.Path]) -> List[pathlib.Path]: files = [ filename - for filename in list(map(pathlib.Path, find_changed_files_in_repos(repos, revision_map))) + for filename in list( + map(pathlib.Path, find_changed_files_in_repos(repos, revision_map)) + ) if _filter_fn(filename) ] @@ -95,7 +101,9 @@ def git_changed_files(excludes: List[pathlib.Path]) -> List[pathlib.Path]: return files -def diff_file_sizes(size_limit: int, excludes: Optional[List[str]] = None) -> List[pathlib.Path]: +def diff_file_sizes( + size_limit: int, excludes: Optional[List[str]] = None +) -> List[pathlib.Path]: if excludes is None: excludes = [] @@ -129,7 +137,9 @@ def main(*args: str) -> int: type=pathlib.Path, required=False, ) - parser.add_argument("--size-mb", help="File size limit (MiB)", type=int, default="10") + parser.add_argument( + "--size-mb", help="File size limit (MiB)", type=int, default="10" + ) parsed_args = parser.parse_args(args[1:]) if parsed_args.verbose: @@ -139,7 +149,9 @@ def main(*args: str) -> int: logging.basicConfig(level=logging.INFO) structlog.stdlib.filter_by_level(LOGGER, "info", {}) - large_files = diff_file_sizes(parsed_args.size_mb * 1024 * 1024, parsed_args.exclude) + large_files = diff_file_sizes( + parsed_args.size_mb * 1024 * 1024, parsed_args.exclude + ) if len(large_files) == 0: LOGGER.info("All files passed size check") return 0 diff --git a/buildscripts/linter/filediff.py b/buildscripts/linter/filediff.py index 1fc2d00cc7a..cff28a3ce17 100644 --- a/buildscripts/linter/filediff.py +++ b/buildscripts/linter/filediff.py @@ -9,7 +9,9 @@ from git import Repo # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__))))) + sys.path.append( + os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + ) from buildscripts.linter import git from buildscripts.patch_builds.change_data import ( @@ -27,7 +29,9 @@ def _get_repos_and_revisions() -> Tuple[List[Repo], RevisionMap]: """Get the repo object and a map of revisions to compare against.""" repos = [Repo(git.get_base_dir())] - revision_map = generate_revision_map(repos, {"mongo": os.environ.get(MONGO_REVISION_ENV_VAR)}) + revision_map = generate_revision_map( + repos, {"mongo": os.environ.get(MONGO_REVISION_ENV_VAR)} + ) return repos, revision_map @@ -42,7 +46,9 @@ def _filter_file(filename: str, is_interesting_file: Callable[[str], bool]) -> b return os.path.exists(filename) and is_interesting_file(filename) -def gather_changed_files_for_lint(is_interesting_file: Callable[[str], bool]) -> List[str]: +def gather_changed_files_for_lint( + is_interesting_file: Callable[[str], bool], +) -> List[str]: """ Get the files that have changes since the last git commit. @@ -54,7 +60,9 @@ def gather_changed_files_for_lint(is_interesting_file: Callable[[str], bool]) -> candidate_files = find_changed_files_in_repos(repos, revision_map) files = [ - filename for filename in candidate_files if _filter_file(filename, is_interesting_file) + filename + for filename in candidate_files + if _filter_file(filename, is_interesting_file) ] LOGGER.info("Found files to lint", files=files) diff --git a/buildscripts/linter/git.py b/buildscripts/linter/git.py index 4abc223b6ea..3d9aeb73004 100644 --- a/buildscripts/linter/git.py +++ b/buildscripts/linter/git.py @@ -64,7 +64,9 @@ class Repo(_git.Repository): valid_files = list(self.get_candidate_files(filter_function)) # Get the full file name here - valid_files = [os.path.normpath(os.path.join(self.directory, f)) for f in valid_files] + valid_files = [ + os.path.normpath(os.path.join(self.directory, f)) for f in valid_files + ] return valid_files @@ -75,7 +77,11 @@ class Repo(_git.Repository): # This allows us to pick all the interesting files # in the mongo and mongo-enterprise repos - file_list = [line.rstrip() for line in gito.splitlines() if filter_function(line.rstrip())] + file_list = [ + line.rstrip() + for line in gito.splitlines() + if filter_function(line.rstrip()) + ] return file_list @@ -121,7 +127,9 @@ class Repo(_git.Repository): valid_files = list(self.get_working_tree_candidate_files(filter_function)) # Get the full file name here - valid_files = [os.path.normpath(os.path.join(self.directory, f)) for f in valid_files] + valid_files = [ + os.path.normpath(os.path.join(self.directory, f)) for f in valid_files + ] # Filter out files that git thinks exist but were removed. valid_files = [f for f in valid_files if os.path.exists(f)] @@ -166,7 +174,9 @@ def get_valid_files_from_candidates(candidates, filter_fn: Callable[[str], bool] repos = get_repos() valid_files = list( - itertools.chain.from_iterable([r.get_candidates(candidates, filter_fn) for r in repos]) + itertools.chain.from_iterable( + [r.get_candidates(candidates, filter_fn) for r in repos] + ) ) return valid_files @@ -185,7 +195,9 @@ def get_files_to_check(files, filter_function): valid_files = get_valid_files_from_candidates(candidates, filter_function) if files and not valid_files: - raise ValueError("Globs '%s' did not find any files with glob in git." % (files)) + raise ValueError( + "Globs '%s' did not find any files with glob in git." % (files) + ) return valid_files diff --git a/buildscripts/linter/git_base.py b/buildscripts/linter/git_base.py index 852cb74fc28..fe34033d8c3 100644 --- a/buildscripts/linter/git_base.py +++ b/buildscripts/linter/git_base.py @@ -79,7 +79,9 @@ class Repository(object): def get_origin_url(self): """Return the URL of the origin repository.""" - return self._callgito("config", ["--local", "--get", "remote.origin.url"]).rstrip() + return self._callgito( + "config", ["--local", "--get", "remote.origin.url"] + ).rstrip() def get_branch_name(self): """ @@ -110,7 +112,9 @@ class Repository(object): """Return True if the specified parent hash an ancestor of child hash.""" # If the common point between parent_revision and child_revision is # parent_revision, then parent_revision is an ancestor of child_revision. - merge_base = self._callgito("merge-base", [parent_revision, child_revision]).rstrip() + merge_base = self._callgito( + "merge-base", [parent_revision, child_revision] + ).rstrip() return parent_revision == merge_base def is_commit(self, revision): @@ -184,7 +188,7 @@ class Repository(object): params.extend(["rev-parse", "--show-toplevel"]) result = Repository._run_process("rev-parse", params) result.check_returncode() - return result.stdout.decode('utf-8').rstrip() + return result.stdout.decode("utf-8").rstrip() @staticmethod def current_repository(): @@ -195,7 +199,7 @@ class Repository(object): """Call git for this repository, and return the captured output.""" result = self._run_cmd(cmd, args) result.check_returncode() - return result.stdout.decode('utf-8') + return result.stdout.decode("utf-8") def _callgit(self, cmd, args, raise_exception=False): """ @@ -216,14 +220,18 @@ class Repository(object): @staticmethod def _run_process(cmd, params, cwd=None): - process = subprocess.Popen(params, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd) + process = subprocess.Popen( + params, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd + ) (stdout, stderr) = process.communicate() if process.returncode: if stdout: LOGGER.error("Output of '%s': %s", " ".join(params), stdout) if stderr: LOGGER.error("Error output of '%s': %s", " ".join(params), stderr) - return GitCommandResult(cmd, params, process.returncode, stdout=stdout, stderr=stderr) + return GitCommandResult( + cmd, params, process.returncode, stdout=stdout, stderr=stderr + ) class GitException(Exception): @@ -239,8 +247,15 @@ class GitException(Exception): """ - def __init__(self, message, returncode=None, cmd=None, process_args=None, stdout=None, - stderr=None): + def __init__( + self, + message, + returncode=None, + cmd=None, + process_args=None, + stdout=None, + stderr=None, + ): """Initialize GitException.""" Exception.__init__(self, message) self.returncode = returncode @@ -274,6 +289,12 @@ class GitCommandResult(object): """Raise GitException if the exit code is non-zero.""" if self.returncode: raise GitException( - "Command '{0}' failed with code '{1}'".format(" ".join(self.process_args), - self.returncode), self.returncode, - self.cmd, self.process_args, self.stdout, self.stderr) + "Command '{0}' failed with code '{1}'".format( + " ".join(self.process_args), self.returncode + ), + self.returncode, + self.cmd, + self.process_args, + self.stdout, + self.stderr, + ) diff --git a/buildscripts/linter/mongolint.py b/buildscripts/linter/mongolint.py index 46ee87dcfc7..7b2a645bd2d 100644 --- a/buildscripts/linter/mongolint.py +++ b/buildscripts/linter/mongolint.py @@ -10,30 +10,30 @@ import sys _RE_LINT = re.compile("//.*NOLINT") _RE_COMMENT_STRIP = re.compile("//.*") -_RE_GENERIC_FCV_COMMENT = re.compile(r'\(Generic FCV reference\):') +_RE_GENERIC_FCV_COMMENT = re.compile(r"\(Generic FCV reference\):") GENERIC_FCV = [ - r'::kLatest', - r'::kLastContinuous', - r'::kLastLTS', - r'::kUpgradingFromLastLTSToLatest', - r'::kUpgradingFromLastContinuousToLatest', - r'::kDowngradingFromLatestToLastLTS', - r'::kDowngradingFromLatestToLastContinuous', - r'\.isUpgradingOrDowngrading', - r'->isUpgradingOrDowngrading', - r'::kDowngradingFromLatestToLastContinuous', - r'::kUpgradingFromLastLTSToLastContinuous', + r"::kLatest", + r"::kLastContinuous", + r"::kLastLTS", + r"::kUpgradingFromLastLTSToLatest", + r"::kUpgradingFromLastContinuousToLatest", + r"::kDowngradingFromLatestToLastLTS", + r"::kDowngradingFromLatestToLastContinuous", + r"\.isUpgradingOrDowngrading", + r"->isUpgradingOrDowngrading", + r"::kDowngradingFromLatestToLastContinuous", + r"::kUpgradingFromLastLTSToLastContinuous", ] -_RE_GENERIC_FCV_REF = re.compile(r'(' + '|'.join(GENERIC_FCV) + r')\b') -_RE_FEATURE_FLAG_IGNORE_FCV_CHECK_REF = re.compile(r'isEnabledAndIgnoreFCVUnsafe\(\)') -_RE_FEATURE_FLAG_IGNORE_FCV_CHECK_COMMENT = re.compile(r'\(Ignore FCV check\)') -_RE_HEADER = re.compile(r'\.(h|hpp)$') +_RE_GENERIC_FCV_REF = re.compile(r"(" + "|".join(GENERIC_FCV) + r")\b") +_RE_FEATURE_FLAG_IGNORE_FCV_CHECK_REF = re.compile(r"isEnabledAndIgnoreFCVUnsafe\(\)") +_RE_FEATURE_FLAG_IGNORE_FCV_CHECK_COMMENT = re.compile(r"\(Ignore FCV check\)") +_RE_HEADER = re.compile(r"\.(h|hpp)$") class Linter: """Simple C++ Linter.""" - _license_header = '''\ + _license_header = """\ /** * Copyright (C) {year}-present MongoDB, Inc. * @@ -61,7 +61,7 @@ class Linter: * delete this exception statement from your version. If you delete this * exception statement from all source files in the program, then also delete * it in the license file. - */'''.splitlines() + */""".splitlines() def __init__(self, file_name, raw_lines): """Create new linter.""" @@ -104,10 +104,12 @@ class Linter: def _check_newlines(self): """Check that each source file ends with a newline character.""" - if self.raw_lines and self.raw_lines[-1][-1:] != '\n': + if self.raw_lines and self.raw_lines[-1][-1:] != "\n": self._error( - len(self.raw_lines), 'mongo/final_newline', - 'Files must end with a newline character.') + len(self.raw_lines), + "mongo/final_newline", + "Files must end with a newline character.", + ) def _check_and_strip_comments(self): in_multi_line_comment = False @@ -145,9 +147,9 @@ class Linter: self.clean_lines.append(clean_line) - def _license_error(self, linenum, msg, category='legal/license'): - style_url = 'https://github.com/mongodb/mongo/wiki/Server-Code-Style' - self._error(linenum, category, '{} See {}'.format(msg, style_url)) + def _license_error(self, linenum, msg, category="legal/license"): + style_url = "https://github.com/mongodb/mongo/wiki/Server-Code-Style" + self._error(linenum, category, "{} See {}".format(msg, style_url)) return (False, linenum) def _check_for_server_side_public_license(self): @@ -157,22 +159,26 @@ class Linter: for linenum, lic_line in enumerate(self._license_header): src_line = next(src_iter, None) if src_line is None: - self._license_error(linenum, 'Missing or incomplete license header.') + self._license_error(linenum, "Missing or incomplete license header.") return linenum - lic_re = re.escape(lic_line).replace(r'\{year\}', r'\d{4}') + lic_re = re.escape(lic_line).replace(r"\{year\}", r"\d{4}") if not re.fullmatch(lic_re, src_line): self._license_error( - linenum, 'Incorrect license header.\n' + linenum, + "Incorrect license header.\n" ' Expected: "{}"\n' - ' Received: "{}"\n'.format(lic_line, src_line)) + ' Received: "{}"\n'.format(lic_line, src_line), + ) return linenum # Warn if SSPL appears in Enterprise code, which has a different license. expect_sspl_license = "enterprise" not in self.file_name if not expect_sspl_license: - self._license_error(linenum, - 'Incorrect license header found. Expected Enterprise license.', - category='legal/enterprise_license') + self._license_error( + linenum, + "Incorrect license header found. Expected Enterprise license.", + category="legal/enterprise_license", + ) return linenum return linenum @@ -183,27 +189,35 @@ class Linter: i = bisect.bisect_right(self.generic_fcv_comments, linenum) if not i or self.generic_fcv_comments[i - 1] < (linenum - 10): self._error( - linenum, 'mongodb/fcv', - 'Please add a comment containing "(Generic FCV reference):" within 10 lines ' + - 'before the generic FCV reference.') + linenum, + "mongodb/fcv", + 'Please add a comment containing "(Generic FCV reference):" within 10 lines ' + + "before the generic FCV reference.", + ) def _check_for_feature_flag_ignore_fcv(self, linenum): line = self.clean_lines[linenum] if _RE_FEATURE_FLAG_IGNORE_FCV_CHECK_REF.search(line): # Find the first ignore FCV check comment preceding the current line. - i = bisect.bisect_right(self.feature_flag_ignore_fcv_check_comments, linenum) - if not i or self.feature_flag_ignore_fcv_check_comments[i - 1] < (linenum - 10): + i = bisect.bisect_right( + self.feature_flag_ignore_fcv_check_comments, linenum + ) + if not i or self.feature_flag_ignore_fcv_check_comments[i - 1] < ( + linenum - 10 + ): self._error( - linenum, 'mongodb/fcv', - 'Please add a comment containing "(Ignore FCV check)":" within 10 lines ' + - 'before the isEnabledAndIgnoreFCVUnsafe() function call explaining why ' + - 'the FCV check is ignored.') + linenum, + "mongodb/fcv", + 'Please add a comment containing "(Ignore FCV check)":" within 10 lines ' + + "before the isEnabledAndIgnoreFCVUnsafe() function call explaining why " + + "the FCV check is ignored.", + ) def _error(self, linenum, category, message): if linenum in self.nolint_suppression: return - norm_file_name = self.file_name.replace('\\', '/') + norm_file_name = self.file_name.replace("\\", "/") # Custom clang-tidy check tests purposefully produce errors for # tests to find. They should be ignored. @@ -217,31 +231,35 @@ class Linter: # The following files are in the src/mongo/ directory but technically belong # in src/third_party/ because their copyright does not belong to MongoDB. - files_to_ignore = set([ - 'src/mongo/scripting/mozjs/PosixNSPR.cpp', - 'src/mongo/shell/linenoise.cpp', - 'src/mongo/shell/linenoise.h', - 'src/mongo/shell/mk_wcwidth.cpp', - 'src/mongo/shell/mk_wcwidth.h', - 'src/mongo/util/md5.cpp', - 'src/mongo/util/md5.h', - 'src/mongo/util/md5main.cpp', - 'src/mongo/util/net/ssl_stream.cpp', - 'src/mongo/util/scopeguard.h', - ]) + files_to_ignore = set( + [ + "src/mongo/scripting/mozjs/PosixNSPR.cpp", + "src/mongo/shell/linenoise.cpp", + "src/mongo/shell/linenoise.h", + "src/mongo/shell/mk_wcwidth.cpp", + "src/mongo/shell/mk_wcwidth.h", + "src/mongo/util/md5.cpp", + "src/mongo/util/md5.h", + "src/mongo/util/md5main.cpp", + "src/mongo/util/net/ssl_stream.cpp", + "src/mongo/util/scopeguard.h", + ] + ) for file_to_ignore in files_to_ignore: if file_to_ignore in norm_file_name: return # We count internally from 0 but users count from 1 for line numbers - print("Error: %s:%d - %s - %s" % (self.file_name, linenum + 1, category, message)) + print( + "Error: %s:%d - %s - %s" % (self.file_name, linenum + 1, category, message) + ) self._error_count += 1 def lint_file(file_name): """Lint file and print errors to console.""" - with io.open(file_name, encoding='utf-8') as file_stream: + with io.open(file_name, encoding="utf-8") as file_stream: raw_lines = file_stream.readlines() linter = Linter(file_name, raw_lines) @@ -251,11 +269,13 @@ def lint_file(file_name): def main(): # type: () -> int """Execute Main Entry point.""" - parser = argparse.ArgumentParser(description='MongoDB Simple C++ Linter.') + parser = argparse.ArgumentParser(description="MongoDB Simple C++ Linter.") - parser.add_argument('file', type=str, help="C++ input file") + parser.add_argument("file", type=str, help="C++ input file") - parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose tracing" + ) args = parser.parse_args() @@ -273,5 +293,5 @@ def main(): return 2 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/buildscripts/linter/mypy.py b/buildscripts/linter/mypy.py index fc310f4fd26..e0550d5ef25 100644 --- a/buildscripts/linter/mypy.py +++ b/buildscripts/linter/mypy.py @@ -29,6 +29,6 @@ class MypyLinter(base.LinterBase): # Only idl and linter should be type checked by mypy. Other # files return errors under python 3 type checking. If we # return an empty list the runner will skip this file. - if 'idl' in file_name or 'linter' in file_name: + if "idl" in file_name or "linter" in file_name: return args + [file_name] return [] diff --git a/buildscripts/linter/runner.py b/buildscripts/linter/runner.py index 0a9835b0c68..115e2e71630 100644 --- a/buildscripts/linter/runner.py +++ b/buildscripts/linter/runner.py @@ -21,30 +21,41 @@ def _check_version(linter, cmd_path, args): try: cmd = cmd_path + args logging.info(str(cmd)) - process_handle = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + process_handle = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) output, stderr = process_handle.communicate() - decoded_output = output.decode('utf-8') + decoded_output = output.decode("utf-8") if process_handle.returncode: logging.info( "Version check failed for [%s], return code '%d'." - "Standard Output:\n%s\nStandard Error:\n%s", cmd, process_handle.returncode, - decoded_output, stderr) + "Standard Output:\n%s\nStandard Error:\n%s", + cmd, + process_handle.returncode, + decoded_output, + stderr, + ) pattern = r"\b(?:(%s) )?(?P\S+)\b" % (linter.cmd_name) required_version = pkg_resources.parse_version(linter.required_version) match = re.search(pattern, decoded_output) if match: - found_version = match.group('version') + found_version = match.group("version") else: - found_version = '0.0' + found_version = "0.0" if pkg_resources.parse_version(found_version) < required_version: logging.info( "Linter %s has wrong version for '%s'. Expected >= '%s'," - "Standard Output:\n'%s'\nStandard Error:\n%s", linter.cmd_name, cmd, - required_version, decoded_output, stderr) + "Standard Output:\n'%s'\nStandard Error:\n%s", + linter.cmd_name, + cmd, + required_version, + decoded_output, + stderr, + ) return False except OSError as os_error: @@ -89,8 +100,8 @@ def _find_linter(linter, config_dict): cmd = [cmd_str] else: # On Mac and with Homebrew, check for the binaries in /usr/local instead of sys.executable. - if sys.platform == 'darwin' and python_dir.startswith('/usr/local/opt'): - python_dir = '/usr/local/bin' + if sys.platform == "darwin" and python_dir.startswith("/usr/local/opt"): + python_dir = "/usr/local/bin" # On Linux, these scripts are installed in %PYTHONDIR%\bin like # '/opt/mongodbtoolchain/v4/bin', but they may point to the wrong interpreter. @@ -101,8 +112,10 @@ def _find_linter(linter, config_dict): if _check_version(linter, cmd, linter.get_lint_version_cmd_args()): return base.LinterInstance(linter, cmd) - logging.info("First version check failed for linter '%s', trying a different location.", - linter.cmd_name) + logging.info( + "First version check failed for linter '%s', trying a different location.", + linter.cmd_name, + ) # Check 2: Check USERBASE cmd = [os.path.join(site.getuserbase(), "bin", linter.cmd_name)] @@ -116,7 +129,10 @@ def _find_linter(linter, config_dict): # Check 4: When a virtualenv is setup the linter modules are not installed, so we need # to use the linters installed in '/opt/mongodbtoolchain/v4/bin'. - cmd = [sys.executable, os.path.join('/opt/mongodbtoolchain/v4/bin', linter.cmd_name)] + cmd = [ + sys.executable, + os.path.join("/opt/mongodbtoolchain/v4/bin", linter.cmd_name), + ] if _check_version(linter, cmd, linter.get_lint_version_cmd_args()): return base.LinterInstance(linter, cmd) @@ -141,7 +157,10 @@ To fix, install the needed python modules for Python 3.x: These commands are typically available via packages with names like python-pip, or python3-pip. See your OS documentation for help. -""", linter.cmd_name, linter.required_version) +""", + linter.cmd_name, + linter.required_version, + ) return None linter_instances.append(linter_instance) @@ -167,8 +186,13 @@ class LintRunner(object): with self.print_lock: print(line) - def run_lint(self, linter: base.LinterInstance, file_name: str, mongo_path: str, - fix_command: str) -> bool: + def run_lint( + self, + linter: base.LinterInstance, + file_name: str, + mongo_path: str, + fix_command: str, + ) -> bool: """Run the specified linter for the file.""" linter_args = linter.linter.get_lint_cmd_args(file_name) @@ -182,17 +206,17 @@ class LintRunner(object): cmd = linter.cmd_path + linter_args - logging.debug(' '.join(cmd)) + logging.debug(" ".join(cmd)) no_lint_errors = True try: if linter.linter.needs_file_diff(): # Need a file diff - with open(file_name, 'rb') as original_text: - original_file = original_text.read().decode('utf-8') + with open(file_name, "rb") as original_text: + original_file = original_text.read().decode("utf-8") - formatted_file = subprocess.check_output(cmd).decode('utf-8') + formatted_file = subprocess.check_output(cmd).decode("utf-8") if original_file != formatted_file: original_lines = original_file.splitlines() formatted_lines = formatted_file.splitlines() @@ -212,14 +236,18 @@ class LintRunner(object): count += 1 if count == 0: - print("ERROR: The files only differ in trailing whitespace? LF vs CRLF") + print( + "ERROR: The files only differ in trailing whitespace? LF vs CRLF" + ) no_lint_errors = False else: - subprocess.check_output(cmd).decode('utf-8') + subprocess.check_output(cmd).decode("utf-8") except subprocess.CalledProcessError as cpe: - self._safe_print("CMD [%s] failed:\n%s" % (' '.join(cmd), cpe.output.decode('utf-8'))) + self._safe_print( + "CMD [%s] failed:\n%s" % (" ".join(cmd), cpe.output.decode("utf-8")) + ) no_lint_errors = False return no_lint_errors @@ -231,9 +259,9 @@ class LintRunner(object): logging.debug(str(cmd)) try: - subprocess.check_output(cmd).decode('utf-8') + subprocess.check_output(cmd).decode("utf-8") except subprocess.CalledProcessError as cpe: - self._safe_print("CMD [%s] failed:\n%s" % (' '.join(cmd), cpe.output)) + self._safe_print("CMD [%s] failed:\n%s" % (" ".join(cmd), cpe.output)) return False return True diff --git a/buildscripts/lldb/lldb_commands.py b/buildscripts/lldb/lldb_commands.py index 794d5fcae24..d0ea2a711a2 100644 --- a/buildscripts/lldb/lldb_commands.py +++ b/buildscripts/lldb/lldb_commands.py @@ -7,13 +7,17 @@ import shlex def __lldb_init_module(debugger, *_args): """Register custom commands.""" debugger.HandleCommand( - "command script add -o -f lldb_commands.PrintGlobalServiceContext mongodb-service-context") + "command script add -o -f lldb_commands.PrintGlobalServiceContext mongodb-service-context" + ) debugger.HandleCommand( - "command script add -o -f lldb_commands.PrintGlobalServiceContext mongodb-dump-locks") + "command script add -o -f lldb_commands.PrintGlobalServiceContext mongodb-dump-locks" + ) debugger.HandleCommand( - "command script add -o -f lldb_commands.BreakpointOnAssert mongodb-breakpoint-assert") + "command script add -o -f lldb_commands.BreakpointOnAssert mongodb-breakpoint-assert" + ) debugger.HandleCommand( - "command script add -o -f lldb_commands.MongoDBFindBreakpoint mongodb-find-breakpoint") + "command script add -o -f lldb_commands.MongoDBFindBreakpoint mongodb-find-breakpoint" + ) debugger.HandleCommand("command script add -o -f lldb_commands.DumpGSC mongodb-gsc") debugger.HandleCommand("command alias mongodb-help help") @@ -42,13 +46,14 @@ def BreakpointOnAssert(debugger, command, _exec_ctx, _result, _internal_dict): arg_strs = shlex.split(command) - parser = argparse.ArgumentParser(description='Set a breakpoint on a usassert code.') - parser.add_argument('code', metavar='N', type=int, help='uassert code') + parser = argparse.ArgumentParser(description="Set a breakpoint on a usassert code.") + parser.add_argument("code", metavar="N", type=int, help="uassert code") args = parser.parse_args(arg_strs) debugger.HandleCommand( - "breakpoint set -n mongo::uassertedWithLocation -c \"(int)status._error.px->code == %s\"" % - args.code) + 'breakpoint set -n mongo::uassertedWithLocation -c "(int)status._error.px->code == %s"' + % args.code + ) def MongoDBFindBreakpoint(debugger, _command, exec_ctx, _result, _internal_dict): # pylint: disable=invalid-name @@ -90,17 +95,22 @@ def DumpGSC(_debugger, _command, exec_ctx, _result, _internal_dict): # pylint: for child in range(decoration_info.num_children): di = decoration_info.children[child] constructor = di.GetChildMemberWithName("constructor").__str__() - index = di.GetChildMemberWithName("descriptor").GetChildMemberWithName( - "_index").GetValueAsUnsigned() + index = ( + di.GetChildMemberWithName("descriptor") + .GetChildMemberWithName("_index") + .GetValueAsUnsigned() + ) type_name = constructor - type_name = type_name[0:len(type_name) - 1] - type_name = type_name[0:type_name.rindex(">")] - type_name = type_name[type_name.index("constructAt<"):].replace("constructAt<", "") + type_name = type_name[0 : len(type_name) - 1] + type_name = type_name[0 : type_name.rindex(">")] + type_name = type_name[type_name.index("constructAt<") :].replace( + "constructAt<", "" + ) # If the type is a pointer type, strip the * at the end. - if type_name.endswith('*'): - type_name = type_name[0:len(type_name) - 1] + if type_name.endswith("*"): + type_name = type_name[0 : len(type_name) - 1] type_name = type_name.rstrip() type_t = exec_ctx.target.FindTypes(type_name).GetTypeAtIndex(0) diff --git a/buildscripts/lldb/lldb_printers.py b/buildscripts/lldb/lldb_printers.py index f18efdea9d4..6c6073e0a68 100644 --- a/buildscripts/lldb/lldb_printers.py +++ b/buildscripts/lldb/lldb_printers.py @@ -7,6 +7,7 @@ To import script in lldb, run: This file must maintain Python 2 and 3 compatibility until Apple upgrades to Python 3 and updates their LLDB to use it. """ + from __future__ import print_function import datetime @@ -29,29 +30,49 @@ except ImportError: def __lldb_init_module(debugger, *_args): """Register pretty printers.""" debugger.HandleCommand( - "type summary add -s 'A${*var.__ptr_.__value_}' -x '^std::__1::unique_ptr<.+>$'") + "type summary add -s 'A${*var.__ptr_.__value_}' -x '^std::__1::unique_ptr<.+>$'" + ) - debugger.HandleCommand("type summary add -s '${var._value}' -x '^mongo::AtomicWord<.+>$'") - debugger.HandleCommand("type summary add -s '${var._M_base._M_i}' 'std::atomic'") + debugger.HandleCommand( + "type summary add -s '${var._value}' -x '^mongo::AtomicWord<.+>$'" + ) + debugger.HandleCommand( + "type summary add -s '${var._M_base._M_i}' 'std::atomic'" + ) debugger.HandleCommand("type summary add -s '${var._M_i}' -x '^std::atomic<.+>$'") - debugger.HandleCommand("type summary add mongo::BSONObj -F lldb_printers.BSONObjPrinter") debugger.HandleCommand( - "type summary add mongo::BSONElement -F lldb_printers.BSONElementPrinter") - - debugger.HandleCommand("type summary add mongo::Status -F lldb_printers.StatusPrinter") + "type summary add mongo::BSONObj -F lldb_printers.BSONObjPrinter" + ) debugger.HandleCommand( - "type summary add -x '^mongo::StatusWith<.+>$' -F lldb_printers.StatusWithPrinter") + "type summary add mongo::BSONElement -F lldb_printers.BSONElementPrinter" + ) - debugger.HandleCommand("type summary add mongo::StringData -F lldb_printers.StringDataPrinter") - debugger.HandleCommand("type summary add mongo::NamespaceString --summary-string '${var._ns}'") + debugger.HandleCommand( + "type summary add mongo::Status -F lldb_printers.StatusPrinter" + ) + debugger.HandleCommand( + "type summary add -x '^mongo::StatusWith<.+>$' -F lldb_printers.StatusWithPrinter" + ) + + debugger.HandleCommand( + "type summary add mongo::StringData -F lldb_printers.StringDataPrinter" + ) + debugger.HandleCommand( + "type summary add mongo::NamespaceString --summary-string '${var._ns}'" + ) debugger.HandleCommand("type summary add mongo::UUID -F lldb_printers.UUIDPrinter") - debugger.HandleCommand("type summary add mongo::Decimal128 -F lldb_printers.Decimal128Printer") - debugger.HandleCommand("type summary add mongo::Date_t -F lldb_printers.Date_tPrinter") + debugger.HandleCommand( + "type summary add mongo::Decimal128 -F lldb_printers.Decimal128Printer" + ) + debugger.HandleCommand( + "type summary add mongo::Date_t -F lldb_printers.Date_tPrinter" + ) debugger.HandleCommand( - "type summary add --summary-string '${var.m_pathname}' 'boost::filesystem::path'") + "type summary add --summary-string '${var.m_pathname}' 'boost::filesystem::path'" + ) debugger.HandleCommand( "type synthetic add -x '^boost::optional<.+>$' --python-class lldb_printers.OptionalPrinter" @@ -61,9 +82,11 @@ def __lldb_init_module(debugger, *_args): ) debugger.HandleCommand( - "type summary add -x '^boost::optional<.+>$' -F lldb_printers.OptionalSummaryPrinter") + "type summary add -x '^boost::optional<.+>$' -F lldb_printers.OptionalSummaryPrinter" + ) debugger.HandleCommand( - "type summary add mongo::ConstDataRange -F lldb_printers.ConstDataRangePrinter") + "type summary add mongo::ConstDataRange -F lldb_printers.ConstDataRangePrinter" + ) debugger.HandleCommand( "type synthetic add -x '^mongo::stdx::unordered_set<.+>$' --python-class lldb_printers.AbslHashSetPrinter" @@ -85,20 +108,23 @@ def StatusPrinter(valobj, *_args): # pylint: disable=invalid-name if px.GetValueAsUnsigned() == 0: return "Status::OK()" code = px.GetChildMemberWithName("code").GetValue() - reason = px.GetChildMemberWithName("reason").\ - GetSummary() + reason = px.GetChildMemberWithName("reason").GetSummary() return "Status({}, {})".format(code, reason) def StatusWithPrinter(valobj, *_args): # pylint: disable=invalid-name """Extend the StatusPrinter to print the value of With for a StatusWith.""" status = valobj.GetChildMemberWithName("_status") - code = status.GetChildMemberWithName("_error").\ - GetChildMemberWithName("px").\ - GetChildMemberWithName("code").\ - GetValueAsUnsigned() + code = ( + status.GetChildMemberWithName("_error") + .GetChildMemberWithName("px") + .GetChildMemberWithName("code") + .GetValueAsUnsigned() + ) if code == 0: - return "StatusWith(OK, {})".format(valobj.GetChildMemberWithName("_t").children[0]) + return "StatusWith(OK, {})".format( + valobj.GetChildMemberWithName("_t").children[0] + ) rep = StatusPrinter(status) return rep.replace("Status", "StatusWith", 1) @@ -107,10 +133,12 @@ def StringDataPrinter(valobj, *_args): # pylint: disable=invalid-name """Print StringData value.""" ptr = valobj.GetChildMemberWithName("_data").GetValueAsUnsigned() if ptr == 0: - return 'nullptr' + return "nullptr" size1 = valobj.GetChildMemberWithName("_size").GetValueAsUnsigned(0) - return '"{}"'.format(valobj.GetProcess().ReadMemory(ptr, size1, lldb.SBError()).decode("utf-8")) + return '"{}"'.format( + valobj.GetProcess().ReadMemory(ptr, size1, lldb.SBError()).decode("utf-8") + ) def read_memory_as_hex(process, address, size): @@ -179,9 +207,11 @@ def BSONElementPrinter(valobj, *_args): # pylint: disable=invalid-name mem = bytes(memoryview(valobj.GetProcess().ReadMemory(ptr, size, lldb.SBError()))) # Call an internal bson method to directly convert an BSON element to a string - el_tuple = bson._element_to_dict(mem, memoryview(mem), 0, len(mem), DEFAULT_CODEC_OPTIONS) # pylint: disable=protected-access + el_tuple = bson._element_to_dict( + mem, memoryview(mem), 0, len(mem), DEFAULT_CODEC_OPTIONS + ) # pylint: disable=protected-access - return "\"%s\": %s" % (el_tuple[0], el_tuple[1]) + return '"%s": %s' % (el_tuple[0], el_tuple[1]) def Date_tPrinter(valobj, *_args): # pylint: disable=invalid-name @@ -240,8 +270,11 @@ class UniquePtrPrinter: Always prints object pointed at by the ptr. """ if index == 0: - return self.valobj.GetChildMemberWithName("__ptr_").GetChildMemberWithName( - "__value_").Dereference() + return ( + self.valobj.GetChildMemberWithName("__ptr_") + .GetChildMemberWithName("__value_") + .Dereference() + ) else: return None @@ -286,7 +319,10 @@ class OptionalPrinter: def update(self): """Check if the optional has changed.""" - self.is_init = self.valobj.GetChildMemberWithName("m_initialized").GetValueAsUnsigned() != 0 + self.is_init = ( + self.valobj.GetChildMemberWithName("m_initialized").GetValueAsUnsigned() + != 0 + ) self.value = None if self.is_init: temp_type = self.valobj.GetType().GetTemplateArgumentType(0) @@ -358,7 +394,8 @@ class AbslHashSetPrinter: pos += 1 value = self.valobj.GetChildMemberWithName("slots_").CreateChildAtOffset( - "%d" % (index), (pos - 1) * self.data_size, self.data_type) + "%d" % (index), (pos - 1) * self.data_size, self.data_type + ) return value.Dereference() def has_children(self): @@ -366,7 +403,9 @@ class AbslHashSetPrinter: return True def update(self): - self.capacity = self.valobj.GetChildMemberWithName("capacity_").GetValueAsUnsigned() + self.capacity = self.valobj.GetChildMemberWithName( + "capacity_" + ).GetValueAsUnsigned() self.data_type = self.valobj.GetChildMemberWithName("slots_").GetType() @@ -374,7 +413,8 @@ class AbslHashSetPrinter: try: self.data_type = resolve_type_to_base( - self.valobj.GetChildMemberWithName("slots_").GetType()).GetPointerType() + self.valobj.GetChildMemberWithName("slots_").GetType() + ).GetPointerType() except: # pylint: disable=bare-except print("Exception: " + str(sys.exc_info())) @@ -412,15 +452,20 @@ class AbslHashMapPrinter: pos += 1 - return self.valobj.GetChildMemberWithName("slots_").GetChildAtIndex(pos - 1, False, - True).Dereference() + return ( + self.valobj.GetChildMemberWithName("slots_") + .GetChildAtIndex(pos - 1, False, True) + .Dereference() + ) def has_children(self): """Match LLDB's expected API.""" return True def update(self): - self.capacity = self.valobj.GetChildMemberWithName("capacity_").GetValueAsUnsigned() + self.capacity = self.valobj.GetChildMemberWithName( + "capacity_" + ).GetValueAsUnsigned() self.data_type = self.valobj.GetChildMemberWithName("slots_").GetType() self.data_size = self.data_type.GetByteSize() @@ -429,7 +474,6 @@ class AbslHashMapPrinter: # LLDB Debugging utility functions # def print_type_base(data_type): - print("type: %s " % (data_type)) print("basic_type: %s " % (data_type.GetBasicType())) # print("canonical: %s " % (data_type.GetCanonicalType())) @@ -446,13 +490,19 @@ def print_type_base(data_type): print("IsPolymorphicClass: %s " % (data_type.IsPolymorphicClass())) print("GetNumberOfFields: %s " % (data_type.GetNumberOfFields())) print("GetNumberOfMemberFunctions: %s " % (data_type.GetNumberOfMemberFunctions())) - print("GetNumberOfTemplateArguments: %s " % (data_type.GetNumberOfTemplateArguments())) - print("GetNumberOfVirtualBaseClasses: %s " % (data_type.GetNumberOfVirtualBaseClasses())) - print("GetNumberOfDirectBaseClasses: %s " % (data_type.GetNumberOfDirectBaseClasses())) + print( + "GetNumberOfTemplateArguments: %s " % (data_type.GetNumberOfTemplateArguments()) + ) + print( + "GetNumberOfVirtualBaseClasses: %s " + % (data_type.GetNumberOfVirtualBaseClasses()) + ) + print( + "GetNumberOfDirectBaseClasses: %s " % (data_type.GetNumberOfDirectBaseClasses()) + ) def print_type(data_type): - if isinstance(data_type, lldb.SBTypeMember): print("TypeMember: " + str(data_type)) print_type(data_type.GetType()) @@ -465,7 +515,6 @@ def print_type(data_type): def walk_type_to_base(data_type): - print("walk_type: %s " % (data_type)) if data_type.IsPointerType(): print("===P") @@ -480,7 +529,6 @@ def walk_type_to_base(data_type): def resolve_type_to_base(data_type): - if isinstance(data_type, lldb.SBTypeMember): return resolve_type_to_base(data_type.GetType()) diff --git a/buildscripts/make_vcxproj.py b/buildscripts/make_vcxproj.py index 65e0706e434..acee86b08cc 100644 --- a/buildscripts/make_vcxproj.py +++ b/buildscripts/make_vcxproj.py @@ -98,7 +98,9 @@ def _read_vcxproj(file_name): tree = ET.parse(file_name) - interesting_tags = ["{%s}%s" % (VCXPROJ_NAMESPACE, tag) for tag in VCXPROJ_FIELDS_TO_PRESERVE] + interesting_tags = [ + "{%s}%s" % (VCXPROJ_NAMESPACE, tag) for tag in VCXPROJ_FIELDS_TO_PRESERVE + ] save_elements = {} @@ -119,7 +121,9 @@ def _replace_vcxproj(file_name, restore_elements): tree = ET.parse(file_name) - interesting_tags = ["{%s}%s" % (VCXPROJ_NAMESPACE, tag) for tag in VCXPROJ_FIELDS_TO_PRESERVE] + interesting_tags = [ + "{%s}%s" % (VCXPROJ_NAMESPACE, tag) for tag in VCXPROJ_FIELDS_TO_PRESERVE + ] for parent in tree.getroot(): for child in parent: @@ -137,7 +141,9 @@ def _replace_vcxproj(file_name, restore_elements): # Strip the "ns0:" namespace prefix because ElementTree does not support default namespaces. str_value = ( - str_value.replace("\n") + self.vcxproj.write( + "\n" + ) self.vcxproj.write(" \n") for command in self.compiles: @@ -216,7 +224,9 @@ class ProjFileGenerator(object): + "\n" ) else: - self.vcxproj.write(' \n') + self.vcxproj.write( + ' \n' + ) self.vcxproj.write(" \n") self.filters = open(self.target + ".vcxproj.filters", "w") @@ -338,7 +348,9 @@ class ProjFileGenerator(object): self.filters.write(" \n") for file_name in sorted(dirs): self.filters.write(" \n" % file_name) - self.filters.write(" {%s}\n" % uuid.uuid4()) + self.filters.write( + " {%s}\n" % uuid.uuid4() + ) self.filters.write(" \n") self.filters.write(" \n") @@ -347,7 +359,9 @@ class ProjFileGenerator(object): for file_name in sorted(self.files): if not self.__is_header(file_name): self.filters.write(" \n" % file_name) - self.filters.write(" %s\n" % os.path.dirname(file_name)) + self.filters.write( + " %s\n" % os.path.dirname(file_name) + ) self.filters.write(" \n") self.filters.write(" \n") @@ -356,7 +370,9 @@ class ProjFileGenerator(object): for file_name in sorted(self.files): if self.__is_header(file_name): self.filters.write(" \n" % file_name) - self.filters.write(" %s\n" % os.path.dirname(file_name)) + self.filters.write( + " %s\n" % os.path.dirname(file_name) + ) self.filters.write(" \n") self.filters.write(" \n") @@ -364,7 +380,9 @@ class ProjFileGenerator(object): self.filters.write(" \n") for file_name in sorted(bazel_files): self.filters.write(" \n" % file_name) - self.filters.write(" %s\n" % os.path.dirname(file_name)) + self.filters.write( + " %s\n" % os.path.dirname(file_name) + ) self.filters.write(" \n") self.filters.write(" \n") diff --git a/buildscripts/mongo_toolchain.py b/buildscripts/mongo_toolchain.py index 39224dea7f3..4328b5e9900 100644 --- a/buildscripts/mongo_toolchain.py +++ b/buildscripts/mongo_toolchain.py @@ -14,7 +14,9 @@ from bazelisk import get_bazel_path, make_bazel_cmd # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__))))) + sys.path.append( + os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + ) SUPPORTED_VERSIONS = "v5" @@ -49,10 +51,10 @@ class MongoToolchain: def check_exists(self) -> None: for directory in ( - self._root_dir, - self._get_bin_dir_path(), - self._get_include_dir_path(), - self._get_lib_dir_path(), + self._root_dir, + self._get_bin_dir_path(), + self._get_include_dir_path(), + self._get_lib_dir_path(), ): if not directory.is_dir(): raise MongoToolchainNotFoundError(f"{directory} is not a directory") @@ -102,10 +104,17 @@ def _execute_bazel(argv): def _fetch_bazel_toolchain(version: str) -> None: try: _execute_bazel( - ["build", "--bes_backend=", "--bes_results_url=", f"@mongo_toolchain_{version}//:all"]) + [ + "build", + "--bes_backend=", + "--bes_results_url=", + f"@mongo_toolchain_{version}//:all", + ] + ) except subprocess.CalledProcessError as e: raise MongoToolchainNotFoundError( - f"Failed to fetch bazel toolchain: `{e.cmd}` exited with code {e.returncode}") + f"Failed to fetch bazel toolchain: `{e.cmd}` exited with code {e.returncode}" + ) def _get_bazel_execroot() -> Path: @@ -113,7 +122,8 @@ def _get_bazel_execroot() -> Path: execroot_str = _execute_bazel(["info", "execution_root"]) except subprocess.CalledProcessError as e: raise MongoToolchainNotFoundError( - f"Couldn't find bazel execroot: `{e.cmd}` exited with code {e.returncode}") + f"Couldn't find bazel execroot: `{e.cmd}` exited with code {e.returncode}" + ) return Path(execroot_str) @@ -133,7 +143,8 @@ def _get_bazel_toolchain(version: str) -> MongoToolchain: _fetch_bazel_toolchain(version) if not path.is_dir(): raise MongoToolchainNotFoundError( - f"Couldn't find bazel toolchain: {path} is not a directory") + f"Couldn't find bazel toolchain: {path} is not a directory" + ) return _get_toolchain_from_path(path) @@ -145,8 +156,9 @@ def _get_installed_toolchain(version: str): return _get_toolchain_from_path(_get_installed_toolchain_path(version)) -def get_mongo_toolchain(*, version: str | None = None, - from_bazel: bool | None = None) -> MongoToolchain: +def get_mongo_toolchain( + *, version: str | None = None, from_bazel: bool | None = None +) -> MongoToolchain: # When running under bazel this environment variable will be set and will point to the # toolchain the target was configured to use. It can also be set manually to override # a script's selection of toolchain. @@ -197,10 +209,11 @@ if __name__ == "__main__": @_app.command() def main( - tool: Annotated[Optional[str], typer.Argument()] = None, - version: Annotated[Optional[str], typer.Option("--version")] = None, - from_bazel: Annotated[Optional[bool], - typer.Option("--bazel/--no-bazel")] = None, + tool: Annotated[Optional[str], typer.Argument()] = None, + version: Annotated[Optional[str], typer.Option("--version")] = None, + from_bazel: Annotated[ + Optional[bool], typer.Option("--bazel/--no-bazel") + ] = None, ): """ Prints the path to tools in the mongo toolchain or the toolchain's root directory (which diff --git a/buildscripts/mongosymb.py b/buildscripts/mongosymb.py index 3b3beeafaf4..5d3e8fe607d 100755 --- a/buildscripts/mongosymb.py +++ b/buildscripts/mongosymb.py @@ -36,9 +36,9 @@ from tenacity import Retrying, retry_if_result, stop_after_delay, wait_fixed sys.path.append(str(Path(os.getcwd(), __file__).parent.parent)) from buildscripts.build_system_options import ( - PathOptions, #pylint: disable=wrong-import-position + PathOptions, # pylint: disable=wrong-import-position ) -from buildscripts.util.oauth import ( #pylint: disable=wrong-import-position +from buildscripts.util.oauth import ( # pylint: disable=wrong-import-position Configs, get_client_cred_oauth_credentials, get_oauth_credentials, @@ -101,8 +101,11 @@ class S3BuildidDbgFileResolver(DbgFileResolver): self._get_from_s3(build_id) except Exception: # noqa pylint: disable=broad-except ex = sys.exc_info()[0] - sys.stderr.write("Failed to find debug symbols for {} in s3: {}\n".format( - build_id, ex)) + sys.stderr.write( + "Failed to find debug symbols for {} in s3: {}\n".format( + build_id, ex + ) + ) return None if not os.path.exists(build_id_path): return None @@ -111,9 +114,15 @@ class S3BuildidDbgFileResolver(DbgFileResolver): def _get_from_s3(self, build_id): """Download debug symbols from S3.""" subprocess.check_call( - ['wget', 'https://s3.amazonaws.com/{}/{}.debug.gz'.format(self._s3_bucket, build_id)], - cwd=self._cache_dir) - subprocess.check_call(['gunzip', build_id + ".debug.gz"], cwd=self._cache_dir) + [ + "wget", + "https://s3.amazonaws.com/{}/{}.debug.gz".format( + self._s3_bucket, build_id + ), + ], + cwd=self._cache_dir, + ) + subprocess.check_call(["gunzip", build_id + ".debug.gz"], cwd=self._cache_dir) class CachedResults(object): @@ -195,10 +204,19 @@ class PathResolver(DbgFileResolver): default_client_credentials_user_name = "client-user" download_timeout_secs = timedelta(minutes=4).total_seconds() - def __init__(self, host: str = None, cache_size: int = 0, cache_dir: str = None, - client_credentials_scope: str = None, client_credentials_user_name: str = None, - client_id: str = None, client_secret: str = None, redirect_port: int = None, - scope: str = None, auth_domain: str = None): + def __init__( + self, + host: str = None, + cache_size: int = 0, + cache_dir: str = None, + client_credentials_scope: str = None, + client_credentials_user_name: str = None, + client_id: str = None, + client_secret: str = None, + redirect_port: int = None, + scope: str = None, + auth_domain: str = None, + ): """ Initialize instance. @@ -210,17 +228,25 @@ class PathResolver(DbgFileResolver): self._cached_results = CachedResults(max_cache_size=cache_size) self.cache_dir = cache_dir or self.default_cache_dir self.mci_build_dir = None - self.client_credentials_scope = client_credentials_scope or self.default_client_credentials_scope - self.client_credentials_user_name = client_credentials_user_name or self.default_client_credentials_user_name + self.client_credentials_scope = ( + client_credentials_scope or self.default_client_credentials_scope + ) + self.client_credentials_user_name = ( + client_credentials_user_name or self.default_client_credentials_user_name + ) self.client_id = client_id self.client_secret = client_secret self.redirect_port = redirect_port self.scope = scope self.auth_domain = auth_domain - self.configs = Configs(client_credentials_scope=self.client_credentials_scope, - client_credentials_user_name=self.client_credentials_user_name, - client_id=self.client_id, auth_domain=self.auth_domain, - redirect_port=self.redirect_port, scope=self.scope) + self.configs = Configs( + client_credentials_scope=self.client_credentials_scope, + client_credentials_user_name=self.client_credentials_user_name, + client_id=self.client_id, + auth_domain=self.auth_domain, + redirect_port=self.redirect_port, + scope=self.scope, + ) self.http_client = requests.Session() self.path_options = PathOptions() @@ -237,28 +263,41 @@ class PathResolver(DbgFileResolver): if os.path.exists(self.default_creds_file_path): with open(self.default_creds_file_path) as cfile: data = json.loads(cfile.read()) - access_token, expire_time = data.get("access_token"), data.get("expire_time") + access_token, expire_time = ( + data.get("access_token"), + data.get("expire_time"), + ) if time.time() < expire_time: # credentials not expired yet - self.http_client.headers.update({"Authorization": f"Bearer {access_token}"}) + self.http_client.headers.update( + {"Authorization": f"Bearer {access_token}"} + ) return if self.client_id and self.client_secret: # auth using secrets - credentials = get_client_cred_oauth_credentials(self.client_id, self.client_secret, - self.configs) + credentials = get_client_cred_oauth_credentials( + self.client_id, self.client_secret, self.configs + ) else: # since we don't have access to secrets, ask user to auth manually - credentials = get_oauth_credentials(configs=self.configs, print_auth_url=True) - self.http_client.headers.update({"Authorization": f"Bearer {credentials.access_token}"}) + credentials = get_oauth_credentials( + configs=self.configs, print_auth_url=True + ) + self.http_client.headers.update( + {"Authorization": f"Bearer {credentials.access_token}"} + ) # write credentials to local file for further useage with open(self.default_creds_file_path, "w") as cfile: cfile.write( - json.dumps({ - "access_token": credentials.access_token, - "expire_time": time.time() + credentials.expires_in - })) + json.dumps( + { + "access_token": credentials.access_token, + "expire_time": time.time() + credentials.expires_in, + } + ) + ) @staticmethod def is_valid_path(path: str) -> bool: @@ -343,7 +382,9 @@ class PathResolver(DbgFileResolver): :param url: URL string :param local_path: full name for local file """ - with requests.get(url, stream=True, timeout=self.download_timeout_secs) as response: + with requests.get( + url, stream=True, timeout=self.download_timeout_secs + ) as response: with open(local_path, "wb") as file: for chunk in response.iter_content(chunk_size=2 * 1024 * 1024): file.write(chunk) @@ -366,19 +407,29 @@ class PathResolver(DbgFileResolver): search_parameters = {"build_id": build_id} if version: search_parameters["version"] = version - print(f"Getting data from service... Search parameters: {search_parameters}") - response = self.http_client.get(f"{self.host}/find_by_id", params=search_parameters) + print( + f"Getting data from service... Search parameters: {search_parameters}" + ) + response = self.http_client.get( + f"{self.host}/find_by_id", params=search_parameters + ) if response.status_code != 200: sys.stderr.write( f"Server returned unsuccessful status: {response.status_code}, " - f"response body: {response.text}\n") + f"response body: {response.text}\n" + ) return None else: data = response.json().get("data", {}) - path, binary_name = data.get("debug_symbols_url"), data.get("file_name") + path, binary_name = ( + data.get("debug_symbols_url"), + data.get("file_name"), + ) except Exception as err: # noqa pylint: disable=broad-except - sys.stderr.write(f"Error occurred while trying to get response from server " - f"for buildId({build_id}): {err}\n") + sys.stderr.write( + f"Error occurred while trying to get response from server " + f"for buildId({build_id}): {err}\n" + ) return None # update cached results @@ -449,8 +500,13 @@ def parse_input(trace_doc, dbg_path_resolver): addr -= 1 frames.append( dict( - path=dbg_path_resolver.get_dbg_file(soinfo), buildId=soinfo.get("buildId", None), - offset=frame["o"], addr="0x{:x}".format(addr), symbol=frame.get("s", None))) + path=dbg_path_resolver.get_dbg_file(soinfo), + buildId=soinfo.get("buildId", None), + offset=frame["o"], + addr="0x{:x}".format(addr), + symbol=frame.get("s", None), + ) + ) return frames @@ -464,30 +520,38 @@ def get_version(trace_doc: Dict[str, Any]) -> Optional[str]: return trace_doc.get("processInfo", {}).get("mongodbVersion") -def symbolize_frames(trace_doc, dbg_path_resolver, symbolizer_path, dsym_hint, input_format, - **kwargs): +def symbolize_frames( + trace_doc, dbg_path_resolver, symbolizer_path, dsym_hint, input_format, **kwargs +): """Return a list of symbolized stack frames from a trace_doc in MongoDB stack dump format.""" # Keep frames in kwargs to avoid changing the function signature. frames = kwargs.get("frames") if frames is None: total_seconds_for_retries = kwargs.get("total_seconds_for_retries", 0) - frames = preprocess_frames_with_retries(dbg_path_resolver, trace_doc, input_format, - total_seconds_for_retries) + frames = preprocess_frames_with_retries( + dbg_path_resolver, trace_doc, input_format, total_seconds_for_retries + ) if not symbolizer_path: symbolizer_path = os.environ.get(SYMBOLIZER_PATH_ENV) if not symbolizer_path: - print(f"Env value for '{SYMBOLIZER_PATH_ENV}' not found, using" - f" '{DEFAULT_SYMBOLIZER_PATH}' as a default executable path.") + print( + f"Env value for '{SYMBOLIZER_PATH_ENV}' not found, using" + f" '{DEFAULT_SYMBOLIZER_PATH}' as a default executable path." + ) symbolizer_path = DEFAULT_SYMBOLIZER_PATH symbolizer_args = [symbolizer_path] for dh in dsym_hint: symbolizer_args.append("-dsym-hint={}".format(dh)) - symbolizer_process = subprocess.Popen(args=symbolizer_args, close_fds=True, - stdin=subprocess.PIPE, stdout=subprocess.PIPE, - stderr=sys.stdout) + symbolizer_process = subprocess.Popen( + args=symbolizer_args, + close_fds=True, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=sys.stdout, + ) def extract_symbols(stdin): """Extract symbol information from the output of llvm-symbolizer. @@ -512,8 +576,10 @@ def symbolize_frames(trace_doc, dbg_path_resolver, symbolizer_path, dsym_hint, i result.append({"fn": line.strip()}) step = 1 else: - file_name, line, column = line.strip().rsplit(':', 3) - result[-1].update({"file": file_name, "column": int(column), "line": int(line)}) + file_name, line, column = line.strip().rsplit(":", 3) + result[-1].update( + {"file": file_name, "column": int(column), "line": int(line)} + ) step = 0 return result @@ -530,8 +596,9 @@ def symbolize_frames(trace_doc, dbg_path_resolver, symbolizer_path, dsym_hint, i return frames -def preprocess_frames(dbg_path_resolver: DbgFileResolver, trace_doc: Dict[str, Any], - input_format: str) -> List[Dict[str, Any]]: +def preprocess_frames( + dbg_path_resolver: DbgFileResolver, trace_doc: Dict[str, Any], input_format: str +) -> List[Dict[str, Any]]: """ Process the paths in frame objects. @@ -565,9 +632,12 @@ def has_high_not_found_paths_ratio(frames: List[Dict[str, Any]]) -> bool: return not_found_ratio >= 0.5 -def preprocess_frames_with_retries(dbg_path_resolver: DbgFileResolver, trace_doc: Dict[str, Any], - input_format: str, - total_seconds_for_retries: int = 0) -> List[Dict[str, Any]]: +def preprocess_frames_with_retries( + dbg_path_resolver: DbgFileResolver, + trace_doc: Dict[str, Any], + input_format: str, + total_seconds_for_retries: int = 0, +) -> List[Dict[str, Any]]: """ Process the paths in frame objects. @@ -579,9 +649,11 @@ def preprocess_frames_with_retries(dbg_path_resolver: DbgFileResolver, trace_doc """ retrying = Retrying( - retry=retry_if_result(has_high_not_found_paths_ratio), wait=wait_fixed(60), + retry=retry_if_result(has_high_not_found_paths_ratio), + wait=wait_fixed(60), stop=stop_after_delay(total_seconds_for_retries), - retry_error_callback=lambda retry_state: retry_state.outcome.result()) + retry_error_callback=lambda retry_state: retry_state.outcome.result(), + ) return retrying(preprocess_frames, dbg_path_resolver, trace_doc, input_format) @@ -592,10 +664,15 @@ def classic_output(frames, outfile, **kwargs): # pylint: disable=unused-argumen symbinfo = frame.get("symbinfo") if symbinfo: for sframe in symbinfo: - outfile.write(" {file:s}:{line:d}:{column:d}: {fn:s}\n".format(**sframe)) + outfile.write( + " {file:s}:{line:d}:{column:d}: {fn:s}\n".format(**sframe) + ) else: - outfile.write(" Couldn't extract symbols: path={path}\n".format( - path=frame.get('path', 'no value found'))) + outfile.write( + " Couldn't extract symbols: path={path}\n".format( + path=frame.get("path", "no value found") + ) + ) def make_argument_parser(parser=None, **kwargs): @@ -603,42 +680,68 @@ def make_argument_parser(parser=None, **kwargs): if parser is None: parser = argparse.ArgumentParser(**kwargs) - parser.add_argument('--dsym-hint', default=[], action='append') - parser.add_argument('--symbolizer-path', default='') - parser.add_argument('--input-format', choices=['classic', 'thin'], default='classic') - parser.add_argument('--output-format', choices=['classic', 'json'], default='classic', - help='"json" shows some extra information') - parser.add_argument('--debug-file-resolver', choices=['path', 's3', 'pr'], default='pr') - parser.add_argument('--src-dir-to-move', action="store", type=str, default=None, - help="Specify a src dir to move to /data/mci/{original_buildid}/src") + parser.add_argument("--dsym-hint", default=[], action="append") + parser.add_argument("--symbolizer-path", default="") parser.add_argument( - '--total-seconds-for-retries', default=0, type=int, + "--input-format", choices=["classic", "thin"], default="classic" + ) + parser.add_argument( + "--output-format", + choices=["classic", "json"], + default="classic", + help='"json" shows some extra information', + ) + parser.add_argument( + "--debug-file-resolver", choices=["path", "s3", "pr"], default="pr" + ) + parser.add_argument( + "--src-dir-to-move", + action="store", + type=str, + default=None, + help="Specify a src dir to move to /data/mci/{original_buildid}/src", + ) + parser.add_argument( + "--total-seconds-for-retries", + default=0, + type=int, help="If web service fails to find path for given build id, it could be because mapping " "process was not finished yet. We can wait for it to finish and retry again. Each retry" " adds 2 minutes to previous wait time. It is guaranteed that total wait time does not exceed this " - "specified amount.") + "specified amount.", + ) - parser.add_argument('--live', action='store_true') + parser.add_argument("--live", action="store_true") s3_group = parser.add_argument_group( - "s3 options", description='Options used with \'--debug-file-resolver s3\'') - s3_group.add_argument('--s3-cache-dir') - s3_group.add_argument('--s3-bucket') + "s3 options", description="Options used with '--debug-file-resolver s3'" + ) + s3_group.add_argument("--s3-cache-dir") + s3_group.add_argument("--s3-bucket") pr_group = parser.add_argument_group( - 'Path Resolver options (Path Resolver uses a special web service to retrieve URL of debug symbols file for ' + "Path Resolver options (Path Resolver uses a special web service to retrieve URL of debug symbols file for " 'a given BuildID), we use "pr" as a shorter/easier name for this', - description='Options used with \'--debug-file-resolver pr\'') - pr_group.add_argument('--pr-host', default='', - help='URL of web service running the API to get debug symbol URL') - pr_group.add_argument('--pr-cache-dir', default='', - help='Full path to a directory to store cache/files') - pr_group.add_argument('--client-secret', default='', help='Secret key for Okta Oauth') - pr_group.add_argument('--client-id', default='', help='Client id for Okta Oauth') + description="Options used with '--debug-file-resolver pr'", + ) + pr_group.add_argument( + "--pr-host", + default="", + help="URL of web service running the API to get debug symbol URL", + ) + pr_group.add_argument( + "--pr-cache-dir", + default="", + help="Full path to a directory to store cache/files", + ) + pr_group.add_argument( + "--client-secret", default="", help="Secret key for Okta Oauth" + ) + pr_group.add_argument("--client-id", default="", help="Client id for Okta Oauth") # caching mechanism is currently not fully developed and needs more advanced cleaning techniques, we add an option # to enable it after completing the implementation # Look for symbols in the cwd by default. - parser.add_argument('path_to_executable', nargs="?") + parser.add_argument("path_to_executable", nargs="?") return parser @@ -657,7 +760,7 @@ def substitute_stdin(options, resolver): line = line.strip() - if 'Frame: 0x' in line: + if "Frame: 0x" in line: continue if backtrace_indicator in line: @@ -687,13 +790,17 @@ def main(options): """Execute Main program.""" resolver = None - if options.debug_file_resolver == 'path': + if options.debug_file_resolver == "path": resolver = PathDbgFileResolver(options.path_to_executable) - elif options.debug_file_resolver == 's3': + elif options.debug_file_resolver == "s3": resolver = S3BuildidDbgFileResolver(options.s3_cache_dir, options.s3_bucket) - elif options.debug_file_resolver == 'pr': - resolver = PathResolver(host=options.pr_host, cache_dir=options.pr_cache_dir, - client_secret=options.client_secret, client_id=options.client_id) + elif options.debug_file_resolver == "pr": + resolver = PathResolver( + host=options.pr_host, + cache_dir=options.pr_cache_dir, + client_secret=options.client_secret, + client_id=options.client_id, + ) if options.live: print("Entering live mode") @@ -706,8 +813,10 @@ def main(options): trace_doc = sys.stdin.read() if not trace_doc or not trace_doc.strip(): - print("Please provide the backtrace through stdin for symbolization;" - " e.g. `your/symbolization/command < /file/with/stacktrace`") + print( + "Please provide the backtrace through stdin for symbolization;" + " e.g. `your/symbolization/command < /file/with/stacktrace`" + ) # Search the trace_doc for an object having "backtrace" and "processInfo" keys. def bt_search(obj): @@ -725,7 +834,7 @@ def main(options): # given a log file including traceback, # we try to find traceback from that file, analyzing each line until we find it for line in trace_doc.splitlines(): - possible_trace_doc = line[line.find('{'):] + possible_trace_doc = line[line.find("{") :] try: possible_trace_doc = json.JSONDecoder().raw_decode(possible_trace_doc)[0] trace_doc = bt_search(possible_trace_doc) @@ -738,29 +847,32 @@ def main(options): sys.exit(1) output_fn = None - if options.output_format == 'json': + if options.output_format == "json": output_fn = json.dump - if options.output_format == 'classic': + if options.output_format == "classic": output_fn = classic_output - frames = preprocess_frames_with_retries(resolver, trace_doc, options.input_format, - options.total_seconds_for_retries) + frames = preprocess_frames_with_retries( + resolver, trace_doc, options.input_format, options.total_seconds_for_retries + ) if options.src_dir_to_move and resolver.mci_build_dir is not None: try: os.makedirs(resolver.mci_build_dir) os.symlink( os.path.join(os.getcwd(), options.src_dir_to_move), - os.path.join(resolver.mci_build_dir, 'src')) + os.path.join(resolver.mci_build_dir, "src"), + ) except FileExistsError: pass - frames = symbolize_frames(frames=frames, trace_doc=trace_doc, dbg_path_resolver=resolver, - **vars(options)) + frames = symbolize_frames( + frames=frames, trace_doc=trace_doc, dbg_path_resolver=resolver, **vars(options) + ) output_fn(frames, sys.stdout, indent=2) -if __name__ == '__main__': +if __name__ == "__main__": symbolizer_options = make_argument_parser(description=__doc__).parse_args() main(symbolizer_options) sys.exit(0) diff --git a/buildscripts/mongosymb_multithread.py b/buildscripts/mongosymb_multithread.py index 801aa8329ba..6e8a20f2b9b 100755 --- a/buildscripts/mongosymb_multithread.py +++ b/buildscripts/mongosymb_multithread.py @@ -16,7 +16,9 @@ def main(): """Execute Main program.""" parent_parser = mongosymb.make_argument_parser(add_help=False) - parser = argparse.ArgumentParser(parents=[parent_parser], description=__doc__, add_help=True) + parser = argparse.ArgumentParser( + parents=[parent_parser], description=__doc__, add_help=True + ) options = parser.parse_args() # Remember the prologue between lines, @@ -42,21 +44,25 @@ def main(): merged = {**thread_record, **prologue} output_fn = None - if options.output_format == 'json': + if options.output_format == "json": output_fn = json.dump - if options.output_format == 'classic': + if options.output_format == "classic": output_fn = mongosymb.classic_output resolver = None - if options.debug_file_resolver == 'path': + if options.debug_file_resolver == "path": resolver = mongosymb.PathDbgFileResolver(options.path_to_executable) - elif options.debug_file_resolver == 's3': - resolver = mongosymb.S3BuildidDbgFileResolver(options.s3_cache_dir, - options.s3_bucket) + elif options.debug_file_resolver == "s3": + resolver = mongosymb.S3BuildidDbgFileResolver( + options.s3_cache_dir, options.s3_bucket + ) frames = mongosymb.symbolize_frames(merged, resolver, **vars(options)) - print("\nthread {{name='{}', tid={}}}:".format(thread_record["name"], - thread_record["tid"])) + print( + "\nthread {{name='{}', tid={}}}:".format( + thread_record["name"], thread_record["tid"] + ) + ) output_fn(frames, sys.stdout, indent=2) @@ -64,6 +70,6 @@ def main(): print("failed to parse line: `{}`".format(line), file=sys.stderr) -if __name__ == '__main__': +if __name__ == "__main__": main() sys.exit(0) diff --git a/buildscripts/monitor_mongo_fork_10gen.py b/buildscripts/monitor_mongo_fork_10gen.py index 85afb0cf0f4..33f285c6d53 100644 --- a/buildscripts/monitor_mongo_fork_10gen.py +++ b/buildscripts/monitor_mongo_fork_10gen.py @@ -8,8 +8,9 @@ from simple_report import make_report, put_report from buildscripts.util.read_config import read_config_file -def get_installation_access_token(app_id: int, private_key: str, - installation_id: int) -> Optional[str]: # noqa: D406 +def get_installation_access_token( + app_id: int, private_key: str, installation_id: int +) -> Optional[str]: # noqa: D406 """ Obtain an installation access token using JWT. @@ -64,7 +65,9 @@ def are_users_members_of_org(users: List[str], org: str, token: str) -> List[str try: github_client = Github(token) organization = github_client.get_organization(org) - org_member_usernames = set(member.login for member in organization.get_members()) + org_member_usernames = set( + member.login for member in organization.get_members() + ) return [user for user in users if user in org_member_usernames] except GithubException as e: print(f"An exception occurred: {e}") @@ -73,11 +76,22 @@ def are_users_members_of_org(users: List[str], org: str, token: str) -> List[str def main(): # Set up argument parsing - parser = argparse.ArgumentParser(description='Monitor forks of MongoDB repo by 10gen members.') - parser.add_argument("-l", "--log-file", type=str, default="mongo_fork_from_10gen", - help="Log file for storing output.") - parser.add_argument("--expansions-file", "-e", default="../expansions.yml", - help="Expansions file to read GitHub app credentials from.") + parser = argparse.ArgumentParser( + description="Monitor forks of MongoDB repo by 10gen members." + ) + parser.add_argument( + "-l", + "--log-file", + type=str, + default="mongo_fork_from_10gen", + help="Log file for storing output.", + ) + parser.add_argument( + "--expansions-file", + "-e", + default="../expansions.yml", + help="Expansions file to read GitHub app credentials from.", + ) args = parser.parse_args() # Read configurations @@ -85,27 +99,36 @@ def main(): # Obtain installation access tokens using app credentials access_token_mongodb_forks = get_installation_access_token( - expansions["app_id_mongodb_forks"], expansions["private_key_mongodb_forks"], - expansions["installation_id_mongodb_forks"]) + expansions["app_id_mongodb_forks"], + expansions["private_key_mongodb_forks"], + expansions["installation_id_mongodb_forks"], + ) access_token_10gen_member = get_installation_access_token( - expansions["app_id_10gen_member"], expansions["private_key_10gen_member"], - expansions["installation_id_10gen_member"]) + expansions["app_id_10gen_member"], + expansions["private_key_10gen_member"], + expansions["installation_id_10gen_member"], + ) if not access_token_mongodb_forks or not access_token_10gen_member: print("Error obtaining the installation tokens.") return # Retrieve list of users who forked mongodb/mongo repo - forked_users = get_users_who_forked_mongo_repo('mongodb', 'mongo', access_token_mongodb_forks) + forked_users = get_users_who_forked_mongo_repo( + "mongodb", "mongo", access_token_mongodb_forks + ) print(f"Recent forks info: {forked_users}") - #TODO: SERVER-83253: Request for Deletion of mongodb/mongo Fork - #TODO: SERVER-83254: Request for Deletion of mongodb/mongo Fork - exclude_list = ['RedBeard0531', 'hanumantmk'] + # TODO: SERVER-83253: Request for Deletion of mongodb/mongo Fork + # TODO: SERVER-83254: Request for Deletion of mongodb/mongo Fork + exclude_list = ["RedBeard0531", "hanumantmk"] # Filter out users who are members of the specified organization members_from_10gen = [ - user for user in are_users_members_of_org(forked_users, '10gen', access_token_10gen_member) + user + for user in are_users_members_of_org( + forked_users, "10gen", access_token_10gen_member + ) if user not in exclude_list ] @@ -117,11 +140,14 @@ def main(): users_list = [f"+ {user}" for user in members_from_10gen] users_list_message = ( "For each of these names, please make a BF and assign it to that user.\n\n" - "Users who recently forked mongodb/mongo and are members of 10gen:\n" + - '\n'.join(users_list)) + "Users who recently forked mongodb/mongo and are members of 10gen:\n" + + "\n".join(users_list) + ) print(users_list_message) else: - users_list_message = "No users who recently forked mongodb/mongo are members of 10gen." + users_list_message = ( + "No users who recently forked mongodb/mongo are members of 10gen." + ) # Make report exit_code = 1 if members_from_10gen else 0 diff --git a/buildscripts/package_test.py b/buildscripts/package_test.py index 5b9041d7e6f..4b63bd48d84 100644 --- a/buildscripts/package_test.py +++ b/buildscripts/package_test.py @@ -29,7 +29,7 @@ root.setLevel(logging.DEBUG) handler = logging.StreamHandler(sys.stdout) handler.setLevel(logging.DEBUG) -formatter = logging.Formatter('[%(asctime)s]%(levelname)s:%(message)s') +formatter = logging.Formatter("[%(asctime)s]%(levelname)s:%(message)s") handler.setFormatter(formatter) root.addHandler(handler) @@ -56,26 +56,40 @@ OS_DOCKER_LOOKUP = { "amazon2": ( "amazonlinux:2", "yum", - frozenset(["python", "python3", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "python3", ), "amazon2023": ( "amazonlinux:2023", "yum", - frozenset(["python", "python3", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "python3", ), "debian10": ( "debian:10-slim", "apt", - frozenset(["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"] + ), "python3", ), "debian11": ( "debian:11-slim", "apt", frozenset( - ["python3", "python-is-python3", "wget", "pkg-config", "systemd", "procps", "file"] + [ + "python3", + "python-is-python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] ), "python3", ), @@ -83,26 +97,40 @@ OS_DOCKER_LOOKUP = { "debian:12-slim", "apt", frozenset( - ["python3", "python-is-python3", "wget", "pkg-config", "systemd", "procps", "file"] + [ + "python3", + "python-is-python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] ), "python3", ), "debian71": ( "debian:7-slim", "apt", - frozenset(["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"] + ), "python3", ), "debian81": ( "debian:8-slim", "apt", - frozenset(["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"] + ), "python3", ), "debian92": ( "debian:9-slim", "apt", - frozenset(["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"]), + frozenset( + ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"] + ), "python3", ), "linux_i686": None, @@ -116,25 +144,33 @@ OS_DOCKER_LOOKUP = { "rhel70": ( "registry.access.redhat.com/ubi7/ubi", "yum", - frozenset(["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "/opt/rh/rh-python38/root/usr/bin/python3", ), "rhel71": ( "registry.access.redhat.com/ubi7/ubi", "yum", - frozenset(["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "/opt/rh/rh-python38/root/usr/bin/python3", ), "rhel72": ( "registry.access.redhat.com/ubi7/ubi", "yum", - frozenset(["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "/opt/rh/rh-python38/root/usr/bin/python3", ), "rhel79": ( "registry.access.redhat.com/ubi7/ubi", "yum", - frozenset(["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"]), + frozenset( + ["rh-python38.x86_64", "wget", "pkgconfig", "systemd", "procps", "file"] + ), "/opt/rh/rh-python38/root/usr/bin/python3", ), "rhel8": ( @@ -195,43 +231,93 @@ OS_DOCKER_LOOKUP = { "python3", ), # Has the same error as above - 'ubuntu1204': None, - 'ubuntu1404': None, - 'ubuntu1604': ('ubuntu:16.04', "apt", - frozenset([ - "apt-utils", "python", "python3", "wget", "pkg-config", "systemd", "procps", - "file" - ]), "python3"), - 'ubuntu1804': ('ubuntu:18.04', "apt", - frozenset( - ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"]), - "python3"), - 'ubuntu2004': ('ubuntu:20.04', "apt", - frozenset([ - "python3", "python-is-python3", "wget", "pkg-config", "systemd", "procps", - "file" - ]), "python3"), - 'ubuntu2204': ('ubuntu:22.04', "apt", - frozenset([ - "python3", "python-is-python3", "wget", "pkg-config", "systemd", "procps", - "file" - ]), "python3"), - 'ubuntu2404': ('ubuntu:24.04', "apt", - frozenset([ - "python3", "python-is-python3", "wget", "pkg-config", "systemd", "procps", - "file" - ]), "python3"), - 'windows': None, - 'windows_i686': None, - 'windows_x86_64': None, - 'windows_x86_64-2008plus': None, - 'windows_x86_64-2008plus-ssl': None, - 'windows_x86_64-2012plus': None, + "ubuntu1204": None, + "ubuntu1404": None, + "ubuntu1604": ( + "ubuntu:16.04", + "apt", + frozenset( + [ + "apt-utils", + "python", + "python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] + ), + "python3", + ), + "ubuntu1804": ( + "ubuntu:18.04", + "apt", + frozenset( + ["python", "python3", "wget", "pkg-config", "systemd", "procps", "file"] + ), + "python3", + ), + "ubuntu2004": ( + "ubuntu:20.04", + "apt", + frozenset( + [ + "python3", + "python-is-python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] + ), + "python3", + ), + "ubuntu2204": ( + "ubuntu:22.04", + "apt", + frozenset( + [ + "python3", + "python-is-python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] + ), + "python3", + ), + "ubuntu2404": ( + "ubuntu:24.04", + "apt", + frozenset( + [ + "python3", + "python-is-python3", + "wget", + "pkg-config", + "systemd", + "procps", + "file", + ] + ), + "python3", + ), + "windows": None, + "windows_i686": None, + "windows_x86_64": None, + "windows_x86_64-2008plus": None, + "windows_x86_64-2008plus-ssl": None, + "windows_x86_64-2012plus": None, } # These versions are marked "current" but in fact are EOL VERSIONS_TO_SKIP: Set[str] = set( - ['3.0.15', '3.2.22', '3.4.24', '3.6.23', '4.0.28', '4.2.24', '6.3.2']) + ["3.0.15", "3.2.22", "3.4.24", "3.6.23", "4.0.28", "4.2.24", "6.3.2"] +) DISABLED_TESTS: Set[Tuple[str, str]] = set() VALID_TAR_DIRECTORY_ARCHITECTURES = [ @@ -289,12 +375,18 @@ def get_image(test: Test, client: DockerClient) -> Image: tries = 1 while True: try: - logging.info("Pulling base image for %s: %s, try %s", test.os_name, test.base_image, - tries) + logging.info( + "Pulling base image for %s: %s, try %s", + test.os_name, + test.base_image, + tries, + ) base_image = client.images.pull(test.base_image) except docker.errors.ImageNotFound as exc: if tries >= 5: - logging.error("Base image %s not found after %s tries", test.base_image, tries) + logging.error( + "Base image %s not found after %s tries", test.base_image, tries + ) raise exc else: return base_image @@ -303,7 +395,7 @@ def get_image(test: Test, client: DockerClient) -> Image: time.sleep(1) -def join_commands(commands: List[str], sep: str = ' && ') -> str: +def join_commands(commands: List[str], sep: str = " && ") -> str: return sep.join(commands) @@ -316,8 +408,14 @@ def run_test_with_timeout(test: Test, client: DockerClient, timeout: int) -> Res except futures.TimeoutError: end_time = time.time() logging.debug("Test %s timed out", test) - result = Result(status="fail", test_file=test.name(), start=start_time, - log_raw="test timed out", end=end_time, exit_code=1) + result = Result( + status="fail", + test_file=test.name(), + start=start_time, + log_raw="test timed out", + end=end_time, + exit_code=1, + ) return result @@ -338,7 +436,7 @@ def run_test(test: Test, client: DockerClient) -> Result: "yum -y install yum-utils epel-release", "yum-config-manager --enable epel", ] - if test.os_name.startswith('debian92'): + if test.os_name.startswith("debian92"): # Adapted from https://stackoverflow.com/questions/76094428/debian-stretch-repositories-404-not-found # Debian92 renamed its repos to archive # The first two sed commands are to replace debian92's sources list to archive repo @@ -354,15 +452,17 @@ def run_test(test: Test, client: DockerClient) -> Result: test.install_command.format(" ".join(test.base_packages)), ] - if test.python_command != 'python3': + if test.python_command != "python3": commands.append(f"ln -s {test.python_command} /usr/bin/python3") os.makedirs(log_external_path.parent, exist_ok=True) commands.append( f"python3 /mnt/package_test/package_test_internal.py {log_docker_path} {' '.join(test.packages_urls)}" ) - logging.debug("Attempting to run the following docker commands:\n\t%s", - join_commands(commands, sep='\n\t')) + logging.debug( + "Attempting to run the following docker commands:\n\t%s", + join_commands(commands, sep="\n\t"), + ) image: Image | None = None container: Container | None = None @@ -382,42 +482,44 @@ def run_test(test: Test, client: DockerClient) -> Result: ], ) for log in container.logs(stream=True): - result["log_raw"] += log.decode('UTF-8') + result["log_raw"] += log.decode("UTF-8") # This is pretty verbose, lets run this way for a while and we can delete this if it ends up being too much - logging.debug(log.decode('UTF-8').strip()) + logging.debug(log.decode("UTF-8").strip()) exit_code = container.wait() - result["exit_code"] = exit_code['StatusCode'] + result["exit_code"] = exit_code["StatusCode"] except docker.errors.APIError as exc: traceback.print_exception(type(exc), exc, exc.__traceback__) logging.error("Failed to start test") result["end"] = time.time() - result['status'] = 'fail' + result["status"] = "fail" result["exit_code"] = 1 return result try: - with open(log_external_path, 'r') as log_raw: + with open(log_external_path, "r") as log_raw: result["log_raw"] += log_raw.read() except OSError as oserror: logging.error("Failed to open %s with error %s", log_external_path, oserror) - if exit_code['StatusCode'] != 0: + if exit_code["StatusCode"] != 0: logging.error("Failed test %s with exit code %s", test, exit_code) - result['status'] = 'fail' + result["status"] = "fail" result["end"] = time.time() return result logging.info("Attempting to download current mongo releases json") -r = requests.get('https://downloads.mongodb.org/current.json') +r = requests.get("https://downloads.mongodb.org/current.json") current_releases = r.json() logging.info("Attempting to download current mongo tools releases json") -r = requests.get('https://downloads.mongodb.org/tools/db/release.json') +r = requests.get("https://downloads.mongodb.org/tools/db/release.json") current_tools_releases = r.json() logging.info("Attempting to download current mongosh releases json") -r = requests.get('https://s3.amazonaws.com/info-mongodb-com/com-download-center/mongosh.json') +r = requests.get( + "https://s3.amazonaws.com/info-mongodb-com/com-download-center/mongosh.json" +) mongosh_releases = r.json() @@ -437,16 +539,22 @@ def iterate_over_downloads() -> Generator[Dict[str, Any], None, None]: def get_tools_package(arch_name: str, os_name: str) -> Optional[str]: # TODO: MONGOSH-1308 - we need to sub the arch alias until package # architectures are named consistently with the server packages - if arch_name == "aarch64" and not os_name.startswith("amazon") and not os_name.startswith( - "rhel"): + if ( + arch_name == "aarch64" + and not os_name.startswith("amazon") + and not os_name.startswith("rhel") + ): arch_name = "arm64" # Tools packages are only published to the latest RHEL version supported on master, but # the tools binaries are cross compatible with other RHEL versions # (see https://jira.mongodb.org/browse/SERVER-92939) def major_version_matches(download_name: str) -> bool: - if (os_name.startswith("rhel") and download_name.startswith("rhel") - and os_name[4] == download_name[4]): + if ( + os_name.startswith("rhel") + and download_name.startswith("rhel") + and os_name[4] == download_name[4] + ): return True return download_name == os_name @@ -473,10 +581,10 @@ def get_mongosh_package(arch_name: str, os_name: str) -> Optional[str]: def get_arch_aliases(arch_name: str) -> List[str]: - if arch_name in ('amd64', 'x86_64'): - return ['amd64', 'x86_64'] - if arch_name in ('ppc64le', 'ppc64el'): - return ['ppc64le', 'ppc64el'] + if arch_name in ("amd64", "x86_64"): + return ["amd64", "x86_64"] + if arch_name in ("ppc64le", "ppc64el"): + return ["ppc64le", "ppc64el"] return [arch_name] @@ -485,12 +593,19 @@ def get_edition_alias(edition_name: str) -> str: return "org" return edition_name + def validate_top_level_directory(tar_name: str): command = f"tar -tf {tar_name} | head -n 1 | awk -F/ '{{print $1}}'" proc = subprocess.run(command, capture_output=True, shell=True, text=True) top_level_directory = proc.stdout.strip() - if all(os_arch not in top_level_directory for os_arch in VALID_TAR_DIRECTORY_ARCHITECTURES): - raise Exception(f"Found an unexpected os-arch pairing as the top level directory. Top level directory: {top_level_directory}") + if all( + os_arch not in top_level_directory + for os_arch in VALID_TAR_DIRECTORY_ARCHITECTURES + ): + raise Exception( + f"Found an unexpected os-arch pairing as the top level directory. Top level directory: {top_level_directory}" + ) + arches: Set[str] = set() oses: Set[str] = set() @@ -504,12 +619,18 @@ for dl in iterate_over_downloads(): versions.add(dl["version"]) parser = argparse.ArgumentParser( - description= - 'Test packages on various hosts. This will spin up docker containers and test the installs.') -parser.add_argument("--arch", type=str, help="Arch of packages to test", - choices=["auto"] + list(arches), default="auto") -parser.add_argument("-r", "--retries", type=int, help="Number of times to retry failed tests", - default=3) + description="Test packages on various hosts. This will spin up docker containers and test the installs." +) +parser.add_argument( + "--arch", + type=str, + help="Arch of packages to test", + choices=["auto"] + list(arches), + default="auto", +) +parser.add_argument( + "-r", "--retries", type=int, help="Number of times to retry failed tests", default=3 +) parser.add_argument( "--skip-enterprise-check", action="store_true", @@ -519,27 +640,53 @@ parser.add_argument( subparsers = parser.add_subparsers(dest="command") release_test_parser = subparsers.add_parser("release") release_test_parser.add_argument( - "--os", type=str, help= - "OS of docker image to run test(s) on. All means run all os tests on this arch. None means run no os test on this arch (except for one specified in extra-packages.", - choices=["all"] + list(oses), default="all") -release_test_parser.add_argument("-e", "--edition", help="Server edition to run tests for", - choices=["all"] + list(editions), default="all") -release_test_parser.add_argument("-v", "--server-version", type=str, - help="Version of MongoDB to run tests for", - choices=["all"] + list(versions), default="all") + "--os", + type=str, + help="OS of docker image to run test(s) on. All means run all os tests on this arch. None means run no os test on this arch (except for one specified in extra-packages.", + choices=["all"] + list(oses), + default="all", +) release_test_parser.add_argument( - "--evg-project", type=str, help= - "The evergreen project this is intended to run under (master only). Note that this interface is primarly for evergreen to set, and so the script will check if its is appropriate to run the tests.", - default="") + "-e", + "--edition", + help="Server edition to run tests for", + choices=["all"] + list(editions), + default="all", +) +release_test_parser.add_argument( + "-v", + "--server-version", + type=str, + help="Version of MongoDB to run tests for", + choices=["all"] + list(versions), + default="all", +) +release_test_parser.add_argument( + "--evg-project", + type=str, + help="The evergreen project this is intended to run under (master only). Note that this interface is primarly for evergreen to set, and so the script will check if its is appropriate to run the tests.", + default="", +) branch_test_parser = subparsers.add_parser("branch") branch_test_parser.add_argument( - "-t", "--test", type=str, help= - "Space-separated tuple of (test_os, package_archive_path). For example: ubuntu2004 https://s3.amazonaws.com/mciuploads/${project}/${build_variant}/${revision}/artifacts/${build_id}-packages.tgz.", - action='append', nargs=2, default=[]) -branch_test_parser.add_argument("-e", "--edition", type=str, help="Server edition being tested", - required=True) -branch_test_parser.add_argument("-v", "--server-version", type=str, - help="Server version being tested", required=True) + "-t", + "--test", + type=str, + help="Space-separated tuple of (test_os, package_archive_path). For example: ubuntu2004 https://s3.amazonaws.com/mciuploads/${project}/${build_variant}/${revision}/artifacts/${build_id}-packages.tgz.", + action="append", + nargs=2, + default=[], +) +branch_test_parser.add_argument( + "-e", "--edition", type=str, help="Server edition being tested", required=True +) +branch_test_parser.add_argument( + "-v", + "--server-version", + type=str, + help="Server version being tested", + required=True, +) args = parser.parse_args() if args.command == "release": @@ -552,7 +699,8 @@ if args.command == "release": if re.fullmatch(r"mongodb-mongo-v\d\.\d-staging", evg_project): logging.info( "Non-master evergreen project detected: '%s', skipping release package testing which is expected to only be run from master branches.", - evg_project) + evg_project, + ) sys.exit(0) arch: str = args.arch @@ -567,14 +715,17 @@ if args.command == "branch": test_os = test_pair[0] urls = [test_pair[1]] if test_os not in OS_DOCKER_LOOKUP: - logging.error("We have not seen this OS %s before, please add it to OS_DOCKER_LOOKUP", - test_os) + logging.error( + "We have not seen this OS %s before, please add it to OS_DOCKER_LOOKUP", + test_os, + ) sys.exit(1) if not OS_DOCKER_LOOKUP[test_os]: logging.info( "Skipping test on target because the OS has no associated container %s->???", - test_os) + test_os, + ) continue tools_package = get_tools_package(arch, test_os) @@ -592,8 +743,13 @@ if args.command == "branch": sys.exit(1) tests.append( - Test(os_name=test_os, edition=args.edition, version=args.server_version, - packages_urls=urls)) + Test( + os_name=test_os, + edition=args.edition, + version=args.server_version, + packages_urls=urls, + ) + ) validate_top_level_directory("mongo-binaries.tgz") @@ -603,12 +759,16 @@ if args.command == "branch": ) if args.edition != "enterprise": - exception_msg = "Found enterprise code in non-enterprise binary {binfile}." + exception_msg = ( + "Found enterprise code in non-enterprise binary {binfile}." + ) def validate_binaries(sources_text): return "src/mongo/db/modules/enterprise" not in sources_text else: - exception_msg = "Failed to find enterprise code in enterprise binary {binfile}." + exception_msg = ( + "Failed to find enterprise code in enterprise binary {binfile}." + ) def validate_binaries(sources_text): return "src/mongo/db/modules/enterprise" in sources_text @@ -662,7 +822,6 @@ if args.command == "branch": # If os is None we only want to do the tests specified in the arguments if args.command == "release": - for dl in iterate_over_downloads(): if args.os not in ["all", dl["target"]]: continue @@ -681,13 +840,17 @@ if args.command == "release": if not OS_DOCKER_LOOKUP[dl["target"]]: logging.info( "Skipping test on target because the OS has no associated container %s->??? on mongo version %s", - dl['target'], dl['version']) + dl["target"], + dl["version"], + ) continue if "packages" not in dl: logging.info( "Skipping test on target because there are no packages %s->??? on mongo version %s", - dl['target'], dl['version']) + dl["target"], + dl["version"], + ) continue if (dl["target"], dl["version"]) in DISABLED_TESTS: @@ -707,7 +870,7 @@ if args.command == "release": repo_uri: str package: str repo_uri, package = urls[0].rsplit("/", 1) - match = re.match(r'(\w+-(\w+(?:-unstable)?))-[^-_]+((?:-|_).*)', package) + match = re.match(r"(\w+-(\w+(?:-unstable)?))-[^-_]+((?:-|_).*)", package) if match: urls.insert(0, f"{repo_uri}/{match.group(1)}{match.group(3)}") # The actual "edition" may be an unstable package release, so we @@ -719,7 +882,9 @@ if args.command == "release": urls.append(f"{repo_uri}/{match.group(1)}-database{match.group(3)}") if version_major > 4 or (version_major == 4 and version_minor >= 3): - urls.append(f"{repo_uri}/{match.group(1)}-database-tools-extra{match.group(3)}") + urls.append( + f"{repo_uri}/{match.group(1)}-database-tools-extra{match.group(3)}" + ) urls.append(f"{repo_uri}/{match.group(1)}-tools{match.group(3)}") urls.append(f"{repo_uri}/{match.group(1)}-mongos{match.group(3)}") @@ -741,7 +906,9 @@ if args.command == "release": if tools_package: urls.append(tools_package) else: - logging.error("Could not find tools package for %s and %s", arch, test_os) + logging.error( + "Could not find tools package for %s and %s", arch, test_os + ) sys.exit(1) mongosh_package = get_mongosh_package(arch, test_os) @@ -754,11 +921,17 @@ if args.command == "release": sys.exit(1) tests.append( - Test(os_name=test_os, packages_urls=urls, edition=edition, version=server_version)) + Test( + os_name=test_os, + packages_urls=urls, + edition=edition, + version=server_version, + ) + ) docker_client = docker.client.from_env() -docker_username = os.environ.get('docker_username') -docker_password = os.environ.get('docker_password') +docker_username = os.environ.get("docker_username") +docker_password = os.environ.get("docker_password") if all((docker_username, docker_password)): logging.info("Logging into docker.io") response = docker_client.login(username=docker_username, password=docker_password) @@ -771,15 +944,18 @@ with futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as tpe: # Set a timeout of 10mins timeout for a single test SINGLE_TEST_TIMEOUT = 10 * 60 test_futures = { - tpe.submit(run_test_with_timeout, test, docker_client, SINGLE_TEST_TIMEOUT): test + tpe.submit( + run_test_with_timeout, test, docker_client, SINGLE_TEST_TIMEOUT + ): test for test in tests } completed_tests: int = 0 retried_tests: int = 0 total_tests: int = len(tests) while len(test_futures.keys()) > 0: - finished_futures, active_futures = futures.wait(test_futures.keys(), timeout=None, - return_when="FIRST_COMPLETED") + finished_futures, active_futures = futures.wait( + test_futures.keys(), timeout=None, return_when="FIRST_COMPLETED" + ) for f in finished_futures: completed_test = test_futures.pop(f) test_result = f.result() @@ -787,8 +963,9 @@ with futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as tpe: if completed_test.attempts < args.retries: retried_tests += 1 completed_test.attempts += 1 - test_futures[tpe.submit(run_test, completed_test, - docker_client)] = completed_test + test_futures[ + tpe.submit(run_test, completed_test, docker_client) + ] = completed_test continue report["failures"] += 1 @@ -797,7 +974,11 @@ with futures.ThreadPoolExecutor(max_workers=os.cpu_count()) as tpe: logging.info( "Completed %s tests, retried %s tests, total %s tests, %s tests are in progress.", - completed_tests, retried_tests, total_tests, len(test_futures)) + completed_tests, + retried_tests, + total_tests, + len(test_futures), + ) # We are printing here to help diagnose hangs # This adds a bit of logging so we are only going to log running tests after a test completes @@ -808,15 +989,16 @@ with open("report.json", "w") as fh: json.dump(report, fh) if report["failures"] == 0: - logging.info("All %s tests passed :)", len(report['results'])) + logging.info("All %s tests passed :)", len(report["results"])) sys.exit(0) else: failed_tests = [ - test_result["test_file"] for test_result in report["results"] + test_result["test_file"] + for test_result in report["results"] if test_result["exit_code"] != 0 ] - success_count = len(report['results']) - len(failed_tests) - logging.info("%s/%s tests passed", success_count, len(report['results'])) + success_count = len(report["results"]) - len(failed_tests) + logging.info("%s/%s tests passed", success_count, len(report["results"])) if len(failed_tests) > 0: logging.info("Failed tests:\n\t%s", "\n\t".join(failed_tests)) sys.exit(1) diff --git a/buildscripts/package_test_internal.py b/buildscripts/package_test_internal.py index 4a0b731de8c..f651cab74db 100644 --- a/buildscripts/package_test_internal.py +++ b/buildscripts/package_test_internal.py @@ -24,15 +24,17 @@ root.setLevel(logging.DEBUG) stdout_handler = logging.StreamHandler(sys.stdout) stdout_handler.setLevel(logging.DEBUG) -file_handler = WatchedFileHandler(sys.argv[1], mode='w', encoding='utf8') +file_handler = WatchedFileHandler(sys.argv[1], mode="w", encoding="utf8") file_handler.setLevel(logging.DEBUG) -formatter = logging.Formatter('[%(asctime)s]%(levelname)s:%(message)s') +formatter = logging.Formatter("[%(asctime)s]%(levelname)s:%(message)s") stdout_handler.setFormatter(formatter) file_handler.setFormatter(formatter) root.addHandler(stdout_handler) root.addHandler(file_handler) -DOCKER_SYSTEMCTL_REPO = "https://raw.githubusercontent.com/gdraheim/docker-systemctl-replacement" +DOCKER_SYSTEMCTL_REPO = ( + "https://raw.githubusercontent.com/gdraheim/docker-systemctl-replacement" +) SYSTEMCTL_URL = DOCKER_SYSTEMCTL_REPO + "/master/files/docker/systemctl3.py" JOURNALCTL_URL = DOCKER_SYSTEMCTL_REPO + "/master/files/docker/journalctl3.py" @@ -41,7 +43,9 @@ TestArgs = Dict[str, Union[str, int, List[str]]] def run_and_log(cmd: str, end_on_error: bool = True): # type: (str, bool) -> 'subprocess.CompletedProcess[bytes]' - proc = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + proc = subprocess.run( + cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) logging.debug(cmd) logging.debug(proc.stdout.decode("UTF-8").strip()) if end_on_error and proc.returncode != 0: @@ -53,15 +57,15 @@ def run_and_log(cmd: str, end_on_error: bool = True): def download_extract_package(package: str) -> List[str]: # Use wget here because using urllib we get errors like the following # https://stackoverflow.com/questions/27835619/urllib-and-ssl-certificate-verify-failed-error - run_and_log("wget -q \"{}\"".format(package)) - downloaded_file = package.split('/')[-1] + run_and_log('wget -q "{}"'.format(package)) + downloaded_file = package.split("/")[-1] if not package.endswith(".tgz"): return [downloaded_file] extracted_paths = [] # type: List[str] with tarfile.open(downloaded_file) as tf: for member in tf.getmembers(): - if member.name.endswith('.deb') or member.name.endswith('.rpm'): + if member.name.endswith(".deb") or member.name.endswith(".rpm"): extracted_paths.append(member.name) tf.extractall() @@ -79,17 +83,21 @@ def download_extract_all_packages(package_urls: List[str]) -> List[str]: def run_apt_test(packages: List[str]): logging.info("Detected apt running test.") - run_and_log("DEBIAN_FRONTEND=noninteractive apt-get install -y {}".format(' '.join(packages))) + run_and_log( + "DEBIAN_FRONTEND=noninteractive apt-get install -y {}".format( + " ".join(packages) + ) + ) def run_yum_test(packages: List[str]): logging.info("Detected yum running test.") - run_and_log("yum install -y {}".format(' '.join(packages))) + run_and_log("yum install -y {}".format(" ".join(packages))) def run_zypper_test(packages: List[str]): logging.info("Detected zypper running test.") - run_and_log("zypper -n --no-gpg-checks install {}".format(' '.join(packages))) + run_and_log("zypper -n --no-gpg-checks install {}".format(" ".join(packages))) def run_mongo_query(shell, query, should_fail=False, tries=60, interval=1.0): @@ -112,12 +120,15 @@ def run_mongo_query(shell, query, should_fail=False, tries=60, interval=1.0): traceback.print_exception(type(exc), exc, exc.__traceback__) raise else: - if ((not should_fail and exec_result.returncode == 0) - or (should_fail and exec_result.returncode != 0)): + if (not should_fail and exec_result.returncode == 0) or ( + should_fail and exec_result.returncode != 0 + ): return exec_result - logging.error("Command failed with output: %s", - exec_result.stdout.decode('UTF-8').rstrip()) + logging.error( + "Command failed with output: %s", + exec_result.stdout.decode("UTF-8").rstrip(), + ) current_try += 1 @@ -129,10 +140,10 @@ def run_mongo_query(shell, query, should_fail=False, tries=60, interval=1.0): def parse_os_release(path: str) -> Dict[str, str]: result = {} # type: Dict[str, str] - with open(path, 'r', encoding='utf-8') as os_release: + with open(path, "r", encoding="utf-8") as os_release: for line in os_release: try: - key, value = line.rstrip().split('=', 1) + key, value = line.rstrip().split("=", 1) except ValueError: continue value = value.strip('"') @@ -141,19 +152,21 @@ def parse_os_release(path: str) -> Dict[str, str]: def get_os_release() -> Tuple[str, int, int]: - if os.path.exists('/etc/os-release'): - release_info = parse_os_release('/etc/os-release') - elif os.path.exists('/usr/lib/os-release'): - release_info = parse_os_release('/etc/os-release') + if os.path.exists("/etc/os-release"): + release_info = parse_os_release("/etc/os-release") + elif os.path.exists("/usr/lib/os-release"): + release_info = parse_os_release("/etc/os-release") else: logging.error("SYSFAIL: Could not find os-release file") sys.exit(2) - os_name = release_info['ID'] - os_version = release_info['VERSION_ID'] + os_name = release_info["ID"] + os_version = release_info["VERSION_ID"] try: - os_version_major, os_version_minor = (int(text) for text in os_version.split('.')) + os_version_major, os_version_minor = ( + int(text) for text in os_version.split(".") + ) except ValueError: os_version_major = int(os_version) os_version_minor = 0 @@ -163,27 +176,33 @@ def get_os_release() -> Tuple[str, int, int]: def parse_ulimits(pid: int) -> Dict[str, Tuple[int, int, Optional[str]]]: ulimit_line_re = re.compile( - r'(?P.*?)\s{2,}(?P\S+)\s+(?P\S+)(?:\s+(?P\S+))?', re.MULTILINE) + r"(?P.*?)\s{2,}(?P\S+)\s+(?P\S+)(?:\s+(?P\S+))?", + re.MULTILINE, + ) result = {} # type: Dict[str, Tuple[int, int, Optional[str]]] - with open("/proc/{}/limits".format(pid), 'r', encoding='utf-8') as ulimits_file: + with open("/proc/{}/limits".format(pid), "r", encoding="utf-8") as ulimits_file: next(ulimits_file) for line in ulimits_file: limits = ulimit_line_re.match(line) if limits is None: continue try: - soft_limit = int(limits.group('soft')) + soft_limit = int(limits.group("soft")) except ValueError: # unlimited soft_limit = -1 try: - hard_limit = int(limits.group('hard')) + hard_limit = int(limits.group("hard")) except ValueError: # unlimited hard_limit = -1 - result[limits.group('name')] = (soft_limit, hard_limit, limits.group('units')) + result[limits.group("name")] = ( + soft_limit, + hard_limit, + limits.group("units"), + ) return result @@ -192,57 +211,67 @@ def get_test_args(package_manager: str, package_files: List[str]) -> TestArgs: # Set up data for later tests test_args = {} # type: TestArgs - test_args['package_manager'] = package_manager - test_args['package_files'] = package_files + test_args["package_manager"] = package_manager + test_args["package_files"] = package_files os_name, os_version_major, os_version_minor = get_os_release() - test_args['os_name'] = os_name - test_args['os_version_major'] = os_version_major - test_args['os_version_minor'] = os_version_minor + test_args["os_name"] = os_name + test_args["os_version_major"] = os_version_major + test_args["os_version_minor"] = os_version_minor - test_args['systemd_units_dir'] = run_and_log( - "pkg-config systemd --variable=systemdsystemunitdir").stdout.decode('utf-8').strip() - test_args['systemd_presets_dir'] = run_and_log( - "pkg-config systemd --variable=systemdsystempresetdir").stdout.decode('utf-8').strip() + test_args["systemd_units_dir"] = ( + run_and_log("pkg-config systemd --variable=systemdsystemunitdir") + .stdout.decode("utf-8") + .strip() + ) + test_args["systemd_presets_dir"] = ( + run_and_log("pkg-config systemd --variable=systemdsystempresetdir") + .stdout.decode("utf-8") + .strip() + ) - if package_manager in ('yum', 'zypper'): - test_args['mongo_username'] = 'mongod' - test_args['mongo_groupname'] = 'mongod' - test_args['mongo_home_dir'] = '/var/lib/mongo' - test_args['mongo_work_dir'] = '/var/lib/mongo' - test_args['mongo_user_shell'] = '/bin/false' + if package_manager in ("yum", "zypper"): + test_args["mongo_username"] = "mongod" + test_args["mongo_groupname"] = "mongod" + test_args["mongo_home_dir"] = "/var/lib/mongo" + test_args["mongo_work_dir"] = "/var/lib/mongo" + test_args["mongo_user_shell"] = "/bin/false" else: - test_args['mongo_username'] = 'mongodb' - test_args['mongo_groupname'] = 'mongodb' - test_args['mongo_home_dir'] = '/home/mongodb' - test_args['mongo_work_dir'] = '/var/lib/mongodb' + test_args["mongo_username"] = "mongodb" + test_args["mongo_groupname"] = "mongodb" + test_args["mongo_home_dir"] = "/home/mongodb" + test_args["mongo_work_dir"] = "/var/lib/mongodb" - if ((os_name == 'debian' and os_version_major >= 10) - or (os_name == 'ubuntu' and os_version_major >= 18)): - test_args['mongo_user_shell'] = '/usr/sbin/nologin' + if (os_name == "debian" and os_version_major >= 10) or ( + os_name == "ubuntu" and os_version_major >= 18 + ): + test_args["mongo_user_shell"] = "/usr/sbin/nologin" else: - test_args['mongo_user_shell'] = '/bin/false' + test_args["mongo_user_shell"] = "/bin/false" - test_args['arch'] = platform.machine() + test_args["arch"] = platform.machine() - deb_output_re = re.compile(r'(?<=Package: ).*$', re.MULTILINE) + deb_output_re = re.compile(r"(?<=Package: ).*$", re.MULTILINE) def get_package_name(package_file: str) -> str: - if package_manager in ('yum', 'zypper'): + if package_manager in ("yum", "zypper"): result = run_and_log( - "rpm --nosignature -qp --queryformat '%{{NAME}}' {0}".format(package_file)) - return result.stdout.decode('utf-8').strip() + "rpm --nosignature -qp --queryformat '%{{NAME}}' {0}".format( + package_file + ) + ) + return result.stdout.decode("utf-8").strip() else: result = run_and_log("dpkg -I {}".format(package_file)) - match = deb_output_re.search(result.stdout.decode('utf-8')) + match = deb_output_re.search(result.stdout.decode("utf-8")) if match is not None: return match.group(0) - return '' + return "" package_names = [] # type: List[str] for package in package_files: package_names.append(get_package_name(package)) - test_args['package_names'] = package_names + test_args["package_names"] = package_names if pathlib.Path("/usr/bin/systemd").exists(): test_args["systemd_path"] = "/usr/bin" @@ -264,7 +293,9 @@ def setup(test_args: TestArgs): # it in our tests run_and_log("mkdir -p /run/systemd/system") run_and_log("mkdir -p {}".format(test_args["systemd_presets_dir"])) - run_and_log("echo 'disable *' > {}/00-test.preset".format(test_args["systemd_presets_dir"])) + run_and_log( + "echo 'disable *' > {}/00-test.preset".format(test_args["systemd_presets_dir"]) + ) def install_fake_systemd(test_args: TestArgs): @@ -308,22 +339,22 @@ def test_install_is_complete(test_args: TestArgs): logging.info("Checking that the installation is complete.") required_files = [ - pathlib.Path('/etc/mongod.conf'), - pathlib.Path('/usr/bin/mongod'), - pathlib.Path('/var/log/mongodb/mongod.log'), - pathlib.Path(test_args['systemd_units_dir']) / "mongod.service", + pathlib.Path("/etc/mongod.conf"), + pathlib.Path("/usr/bin/mongod"), + pathlib.Path("/var/log/mongodb/mongod.log"), + pathlib.Path(test_args["systemd_units_dir"]) / "mongod.service", ] # type: List[pathlib.Path] required_dirs = [ - pathlib.Path('/run/mongodb'), - pathlib.Path('/var/run/mongodb'), - pathlib.Path(test_args['mongo_work_dir']), + pathlib.Path("/run/mongodb"), + pathlib.Path("/var/run/mongodb"), + pathlib.Path(test_args["mongo_work_dir"]), ] # type: List[pathlib.Path] - if test_args['package_manager'] in ('yum', 'zypper'): + if test_args["package_manager"] in ("yum", "zypper"): # Only RPM-based distros create the home directory. Debian/Ubuntu # distros use a non-existent directory in /home - required_dirs.append(pathlib.Path(test_args['mongo_home_dir'])) + required_dirs.append(pathlib.Path(test_args["mongo_home_dir"])) for path in required_files: if not (path.exists() and path.is_file()): @@ -334,65 +365,83 @@ def test_install_is_complete(test_args: TestArgs): raise RuntimeError("Required directory missing: {}".format(path)) try: - user_info = pwd.getpwnam(test_args['mongo_username']) + user_info = pwd.getpwnam(test_args["mongo_username"]) except KeyError: - raise RuntimeError("Required user missing: {}".format(test_args['mongo_username'])) + raise RuntimeError( + "Required user missing: {}".format(test_args["mongo_username"]) + ) try: - grp.getgrnam(test_args['mongo_groupname']) + grp.getgrnam(test_args["mongo_groupname"]) except KeyError: - raise RuntimeError("Required group missing: {}".format(test_args['mongo_username'])) + raise RuntimeError( + "Required group missing: {}".format(test_args["mongo_username"]) + ) # All of the supplemental groups (the .deb pattern) mongo_user_groups = [ - g.gr_name for g in grp.getgrall() if test_args['mongo_username'] in g.gr_mem + g.gr_name for g in grp.getgrall() if test_args["mongo_username"] in g.gr_mem ] # The user's primary group (the .rpm pattern) mongo_user_groups.append(grp.getgrgid(user_info.pw_gid).gr_name) - if test_args['mongo_groupname'] not in mongo_user_groups: - raise RuntimeError("Required group `{}' is not in configured groups: {}".format( - test_args['mongo_groupname'], mongo_user_groups)) + if test_args["mongo_groupname"] not in mongo_user_groups: + raise RuntimeError( + "Required group `{}' is not in configured groups: {}".format( + test_args["mongo_groupname"], mongo_user_groups + ) + ) - if test_args['package_manager'] in ('yum', 'zypper'): + if test_args["package_manager"] in ("yum", "zypper"): # Only RPM-based distros create the home directory. Debian/Ubuntu # distros use a non-existent directory in /home - if user_info.pw_dir != test_args['mongo_home_dir']: + if user_info.pw_dir != test_args["mongo_home_dir"]: raise RuntimeError( "Configured home directory `{}' does not match required path `{}'".format( - user_info.pw_dir, test_args['mongo_home_dir'])) + user_info.pw_dir, test_args["mongo_home_dir"] + ) + ) - if user_info.pw_shell != test_args['mongo_user_shell']: - raise RuntimeError("Configured user shell `{}' does not match required path `{}'".format( - user_info.pw_shell, test_args['mongo_user_shell'])) + if user_info.pw_shell != test_args["mongo_user_shell"]: + raise RuntimeError( + "Configured user shell `{}' does not match required path `{}'".format( + user_info.pw_shell, test_args["mongo_user_shell"] + ) + ) def test_ulimits_correct(): logging.info("Checking that mongod process limits are correct.") exec_result = run_and_log("pgrep '^mongod$'") - mongod_pid = int(exec_result.stdout.decode('utf-8').strip()) + mongod_pid = int(exec_result.stdout.decode("utf-8").strip()) ulimits = parse_ulimits(mongod_pid) - if ulimits['Max file size'][0] != -1: - raise RuntimeError("RLMIT_FSIZE != unlimited: {}".format(ulimits['Max file size'])) + if ulimits["Max file size"][0] != -1: + raise RuntimeError( + "RLMIT_FSIZE != unlimited: {}".format(ulimits["Max file size"]) + ) - if ulimits['Max cpu time'][0] != -1: - raise RuntimeError("RLMIT_CPU != unlimited: {}".format(ulimits['Max cpu time'])) + if ulimits["Max cpu time"][0] != -1: + raise RuntimeError("RLMIT_CPU != unlimited: {}".format(ulimits["Max cpu time"])) - if ulimits['Max address space'][0] != -1: - raise RuntimeError("RLMIT_AS != unlimited: {}".format(ulimits['Max address space'])) + if ulimits["Max address space"][0] != -1: + raise RuntimeError( + "RLMIT_AS != unlimited: {}".format(ulimits["Max address space"]) + ) - if ulimits['Max open files'][0] != -1 and ulimits['Max open files'][0] < 64000: - raise RuntimeError("RLMIT_NOFILE < 64000: {}".format(ulimits['Max open files'])) + if ulimits["Max open files"][0] != -1 and ulimits["Max open files"][0] < 64000: + raise RuntimeError("RLMIT_NOFILE < 64000: {}".format(ulimits["Max open files"])) - if ulimits['Max resident set'][0] != -1: - raise RuntimeError("RLMIT_RSS != unlimited: {}".format(ulimits['Max resident set'])) + if ulimits["Max resident set"][0] != -1: + raise RuntimeError( + "RLMIT_RSS != unlimited: {}".format(ulimits["Max resident set"]) + ) - if ulimits['Max processes'][0] != -1 and ulimits['Max processes'][0] < 64000: - raise RuntimeError("RLMIT_NPROC < 64000: {}".format(ulimits['Max processes'])) + if ulimits["Max processes"][0] != -1 and ulimits["Max processes"][0] < 64000: + raise RuntimeError("RLMIT_NPROC < 64000: {}".format(ulimits["Max processes"])) def test_restart(): @@ -412,14 +461,18 @@ def test_stop(): run_and_log("systemctl stop mongod.service") logging.debug("Waiting up to 60 seconds for mongod to finish shutting down...") - run_mongo_query(test_args["mongo_shell"], "db.smoke.insertOne({answer: 42})", should_fail=True) + run_mongo_query( + test_args["mongo_shell"], "db.smoke.insertOne({answer: 42})", should_fail=True + ) run_and_log("systemctl is-active mongod.service", end_on_error=False) def test_install_compass(test_args: TestArgs): - - if test_args['arch'] != "x86_64" or test_args['os_name'] not in ["ubuntu", "almalinux"]: + if test_args["arch"] != "x86_64" or test_args["os_name"] not in [ + "ubuntu", + "almalinux", + ]: logging.info( "Not installing compass on unsupported platform, see the docs: https://www.mongodb.com/docs/compass/current/install/" ) @@ -436,28 +489,33 @@ def test_install_compass(test_args: TestArgs): def test_uninstall(test_args: TestArgs): - logging.info("Uninstalling packages:\n\t%s", '\n\t'.join(test_args['package_names'])) + logging.info( + "Uninstalling packages:\n\t%s", "\n\t".join(test_args["package_names"]) + ) - command = '' # type: str - if test_args['package_manager'] == 'apt': - command = 'apt-get remove -y {}' - elif test_args['package_manager'] == 'yum': - command = 'yum remove -y {}' - elif test_args['package_manager'] == 'zypper': - command = 'zypper -n remove {}' + command = "" # type: str + if test_args["package_manager"] == "apt": + command = "apt-get remove -y {}" + elif test_args["package_manager"] == "yum": + command = "yum remove -y {}" + elif test_args["package_manager"] == "zypper": + command = "zypper -n remove {}" else: - raise RuntimeError("Don't know how to uninstall with package manager: {}".format( - test_args['package_manager'])) + raise RuntimeError( + "Don't know how to uninstall with package manager: {}".format( + test_args["package_manager"] + ) + ) - run_and_log(command.format(' '.join(test_args['package_names']))) + run_and_log(command.format(" ".join(test_args["package_names"]))) def test_uninstall_is_complete(test_args: TestArgs): logging.info("Checking that the uninstallation is complete.") leftover_files = [ - pathlib.Path('/usr/bin/mongod'), - pathlib.Path(test_args['systemd_units_dir']) / 'mongod.service', + pathlib.Path("/usr/bin/mongod"), + pathlib.Path(test_args["systemd_units_dir"]) / "mongod.service", ] # type: List[pathlib.Path] for path in leftover_files: @@ -473,20 +531,20 @@ if len(package_urls) == 0: package_files = download_extract_all_packages(package_urls) -package_manager = '' # type: str +package_manager = "" # type: str apt_proc = run_and_log("apt --help", end_on_error=False) yum_proc = run_and_log("yum --help", end_on_error=False) zypper_proc = run_and_log("zypper -n --help", end_on_error=False) # zypper if apt_proc.returncode == 0: run_apt_test(packages=package_files) - package_manager = 'apt' + package_manager = "apt" elif yum_proc.returncode == 0: run_yum_test(packages=package_files) - package_manager = 'yum' + package_manager = "yum" elif zypper_proc.returncode == 0: run_zypper_test(packages=package_files) - package_manager = 'zypper' + package_manager = "zypper" else: logging.error("Found no supported package manager...Failing Test\n") sys.exit(1) diff --git a/buildscripts/packager.py b/buildscripts/packager.py index 7e3ee247c23..8a38fdcd084 100755 --- a/buildscripts/packager.py +++ b/buildscripts/packager.py @@ -419,13 +419,23 @@ def get_args(distros, arch_choices): parser = argparse.ArgumentParser(description="Build MongoDB Packages") parser.add_argument( - "-s", "--server-version", help="Server version to build (e.g. 2.7.8-rc0)", required=True + "-s", + "--server-version", + help="Server version to build (e.g. 2.7.8-rc0)", + required=True, ) parser.add_argument( - "-m", "--metadata-gitspec", help="Gitspec to use for package metadata files", required=False + "-m", + "--metadata-gitspec", + help="Gitspec to use for package metadata files", + required=False, ) parser.add_argument( - "-r", "--release-number", help="RPM release number base", type=int, required=False + "-r", + "--release-number", + help="RPM release number base", + type=int, + required=False, ) parser.add_argument( "-d", @@ -436,7 +446,9 @@ def get_args(distros, arch_choices): default=[], action="append", ) - parser.add_argument("-p", "--prefix", help="Directory to build into", required=False) + parser.add_argument( + "-p", "--prefix", help="Directory to build into", required=False + ) parser.add_argument( "-a", "--arches", @@ -573,7 +585,13 @@ def unpack_binaries_into(build_os, arch, spec, where): try: sysassert(["tar", "xvzf", rootdir + "/" + tarfile(build_os, arch, spec)]) release_dir = glob("mongodb-linux-*")[0] - for releasefile in "bin", "LICENSE-Community.txt", "README", "THIRD-PARTY-NOTICES", "MPL-2": + for releasefile in ( + "bin", + "LICENSE-Community.txt", + "README", + "THIRD-PARTY-NOTICES", + "MPL-2", + ): print("moving file: %s/%s" % (release_dir, releasefile)) os.rename("%s/%s" % (release_dir, releasefile), releasefile) os.rmdir(release_dir) @@ -596,13 +614,16 @@ def make_package(distro, build_os, arch, spec, srcdir): # directory, so the debian directory is needed in all cases (and # innocuous in the debianoids' sdirs). for pkgdir in ["debian", "rpm"]: - print("Copying packaging files from %s to %s" % ("%s/%s" % (srcdir, pkgdir), sdir)) + print( + "Copying packaging files from %s to %s" % ("%s/%s" % (srcdir, pkgdir), sdir) + ) # FIXME: sh-dash-cee is bad. See if tarfile can do this. sysassert( [ "sh", "-c", - '(cd "%s" && tar cf - %s ) | (cd "%s" && tar xvf -)' % (srcdir, pkgdir, sdir), + '(cd "%s" && tar cf - %s ) | (cd "%s" && tar xvf -)' + % (srcdir, pkgdir, sdir), ] ) # Splat the binaries under sdir. The "build" stages of the @@ -737,7 +758,11 @@ Codename: %s/mongodb-org Architectures: amd64 arm64 s390x Components: %s Description: MongoDB packages -""" % (distro.repo_os_version(build_os), distro.repo_os_version(build_os), distro.repo_component()) +""" % ( + distro.repo_os_version(build_os), + distro.repo_os_version(build_os), + distro.repo_component(), + ) if os.path.exists(repo + "../../Release"): os.unlink(repo + "../../Release") if os.path.exists(repo + "../../Release.gpg"): @@ -834,7 +859,8 @@ def write_debian_changelog(path, spec, srcdir): # only commit changes if there are any if len(git_repo.index.diff("HEAD")) != 0: with git_repo.git.custom_environment( - GIT_COMMITTER_NAME="Evergreen", GIT_COMMITTER_EMAIL="evergreen@mongodb.com" + GIT_COMMITTER_NAME="Evergreen", + GIT_COMMITTER_EMAIL="evergreen@mongodb.com", ): git_repo.git.commit("--author='Evergreen <>'", "-m", "temp commit") @@ -842,7 +868,11 @@ def write_debian_changelog(path, spec, srcdir): # FIXME: make consistent with the rest of the code when we have more packaging testing print("Getting changelog for specified gitspec:", spec.metadata_gitspec()) sb = preamble + backtick( - ["sh", "-c", "git archive %s debian/changelog | tar xOf -" % spec.metadata_gitspec()] + [ + "sh", + "-c", + "git archive %s debian/changelog | tar xOf -" % spec.metadata_gitspec(), + ] ).decode("utf-8") # reset branch to original state @@ -854,10 +884,14 @@ def write_debian_changelog(path, spec, srcdir): # If the first line starts with "mongodb", it's not a revision # preamble, and so frob the version number. lines[0] = re.sub( - "^mongodb \\(.*\\)", "mongodb (%s)" % (spec.pversion(Distro("debian"))), lines[0] + "^mongodb \\(.*\\)", + "mongodb (%s)" % (spec.pversion(Distro("debian"))), + lines[0], ) # Rewrite every changelog entry starting in mongodb - lines = [re.sub("^mongodb ", "mongodb%s " % (spec.suffix()), line) for line in lines] + lines = [ + re.sub("^mongodb ", "mongodb%s " % (spec.suffix()), line) for line in lines + ] lines = [re.sub("^ --", " --", line) for line in lines] sb = "\n".join(lines) with open(path, "w") as fh: @@ -909,7 +943,8 @@ def make_rpm(distro, build_os, arch, spec, srcdir): [ "tar", "-cpzf", - topdir + "SOURCES/mongodb%s-%s.tar.gz" % (suffix, spec.pversion(distro)), + topdir + + "SOURCES/mongodb%s-%s.tar.gz" % (suffix, spec.pversion(distro)), os.path.basename(os.path.dirname(sdir)), ] ) @@ -949,7 +984,9 @@ def make_rpm(distro, build_os, arch, spec, srcdir): ensure_dir(repo_dir) # FIXME: see if some combination of shutil.copy and glob # can do this without shelling out. - sysassert(["sh", "-c", 'cp -v "%s/RPMS/%s/"*.rpm "%s"' % (topdir, distro_arch, repo_dir)]) + sysassert( + ["sh", "-c", 'cp -v "%s/RPMS/%s/"*.rpm "%s"' % (topdir, distro_arch, repo_dir)] + ) return repo_dir diff --git a/buildscripts/packager_enterprise.py b/buildscripts/packager_enterprise.py index c972da89e25..7e173a94afa 100755 --- a/buildscripts/packager_enterprise.py +++ b/buildscripts/packager_enterprise.py @@ -55,9 +55,17 @@ class EnterpriseSpec(packager.Spec): def suffix(self): """Suffix.""" if int(self.ver.split(".")[0]) >= 5: - return "-enterprise" if int(self.ver.split(".")[1]) == 0 else "-enterprise-unstable" + return ( + "-enterprise" + if int(self.ver.split(".")[1]) == 0 + else "-enterprise-unstable" + ) else: - return "-enterprise" if int(self.ver.split(".")[1]) % 2 == 0 else "-enterprise-unstable" + return ( + "-enterprise" + if int(self.ver.split(".")[1]) % 2 == 0 + else "-enterprise-unstable" + ) class EnterpriseDistro(packager.Distro): @@ -107,14 +115,26 @@ class EnterpriseDistro(packager.Distro): if re.search("^(debian|ubuntu)", self.dname): return "repo/apt/%s/dists/%s/mongodb-enterprise/%s/%s/binary-%s/" % ( - self.dname, self.repo_os_version(build_os), repo_directory, self.repo_component(), - self.archname(arch)) + self.dname, + self.repo_os_version(build_os), + repo_directory, + self.repo_component(), + self.archname(arch), + ) elif re.search("(redhat|fedora|centos|amazon)", self.dname): return "repo/yum/%s/%s/mongodb-enterprise/%s/%s/RPMS/" % ( - self.dname, self.repo_os_version(build_os), repo_directory, self.archname(arch)) + self.dname, + self.repo_os_version(build_os), + repo_directory, + self.archname(arch), + ) elif re.search("(suse)", self.dname): return "repo/zypper/%s/%s/mongodb-enterprise/%s/%s/RPMS/" % ( - self.dname, self.repo_os_version(build_os), repo_directory, self.archname(arch)) + self.dname, + self.repo_os_version(build_os), + repo_directory, + self.archname(arch), + ) else: raise Exception("BUG: unsupported platform?") @@ -125,33 +145,42 @@ class EnterpriseDistro(packager.Distro): are for redhat, the others are delegated to the super class. """ if arch == "ppc64le": - if self.dname == 'ubuntu': + if self.dname == "ubuntu": return ["ubuntu1604", "ubuntu1804"] - if self.dname == 'redhat': + if self.dname == "redhat": return ["rhel71", "rhel81", "rhel9"] return [] if arch == "s390x": - if self.dname == 'redhat': + if self.dname == "redhat": return ["rhel67", "rhel72", "rhel83", "rhel9"] - if self.dname == 'suse': + if self.dname == "suse": return ["suse11", "suse12", "suse15"] - if self.dname == 'ubuntu': + if self.dname == "ubuntu": return ["ubuntu1604", "ubuntu1804"] return [] if arch == "arm64": - if self.dname == 'ubuntu': + if self.dname == "ubuntu": return ["ubuntu1804", "ubuntu2004", "ubuntu2204", "ubuntu2404"] if arch == "aarch64": - if self.dname == 'redhat': + if self.dname == "redhat": return ["rhel82", "rhel88", "rhel90", "rhel93"] - if self.dname == 'amazon2': + if self.dname == "amazon2": return ["amazon2"] - if self.dname == 'amazon2023': + if self.dname == "amazon2023": return ["amazon2023"] return [] if re.search("(redhat|fedora|centos)", self.dname): - return ["rhel93", "rhel90", "rhel88", "rhel80", "rhel70", "rhel79", "rhel62", "rhel57"] + return [ + "rhel93", + "rhel90", + "rhel88", + "rhel80", + "rhel70", + "rhel79", + "rhel62", + "rhel57", + ] return super(EnterpriseDistro, self).build_os(arch) @@ -162,7 +191,9 @@ def main(): args = packager.get_args(distros, ARCH_CHOICES) - spec = EnterpriseSpec(args.server_version, args.metadata_gitspec, args.release_number) + spec = EnterpriseSpec( + args.server_version, args.metadata_gitspec, args.release_number + ) oldcwd = os.getcwd() srcdir = oldcwd + "/../" @@ -180,11 +211,9 @@ def main(): made_pkg = False # Build a package for each distro/spec/arch tuple, and # accumulate the repository-layout directories. - for (distro, arch) in packager.crossproduct(distros, args.arches): - + for distro, arch in packager.crossproduct(distros, args.arches): for build_os in distro.build_os(arch): if build_os in args.distros or not args.distros: - filename = tarfile(build_os, arch, spec) packager.ensure_dir(filename) shutil.copyfile(args.tarball, filename) @@ -203,7 +232,11 @@ def main(): def tarfile(build_os, arch, spec): """Return the location where we store the downloaded tarball for this package.""" - return "dl/mongodb-linux-%s-enterprise-%s-%s.tar.gz" % (spec.version(), build_os, arch) + return "dl/mongodb-linux-%s-enterprise-%s-%s.tar.gz" % ( + spec.version(), + build_os, + arch, + ) def setupdir(distro, build_os, arch, spec): @@ -214,8 +247,14 @@ def setupdir(distro, build_os, arch, spec): # the following format string is unclear, an example setupdir # would be dst/x86_64/debian-sysvinit/wheezy/mongodb-org-unstable/ # or dst/x86_64/redhat/rhel57/mongodb-org-unstable/ - return "dst/%s/%s/%s/%s%s-%s/" % (arch, distro.name(), build_os, distro.pkgbase(), - spec.suffix(), spec.pversion(distro)) + return "dst/%s/%s/%s/%s%s-%s/" % ( + arch, + distro.name(), + build_os, + distro.pkgbase(), + spec.suffix(), + spec.pversion(distro), + ) def unpack_binaries_into(build_os, arch, spec, where): @@ -228,9 +267,17 @@ def unpack_binaries_into(build_os, arch, spec, where): # thing and chdir into where and run tar there. os.chdir(where) try: - packager.sysassert(["tar", "xvzf", rootdir + "/" + tarfile(build_os, arch, spec)]) - release_dir = glob('mongodb-linux-*')[0] - for releasefile in "bin", "LICENSE-Enterprise.txt", "README", "THIRD-PARTY-NOTICES", "MPL-2": + packager.sysassert( + ["tar", "xvzf", rootdir + "/" + tarfile(build_os, arch, spec)] + ) + release_dir = glob("mongodb-linux-*")[0] + for releasefile in ( + "bin", + "LICENSE-Enterprise.txt", + "README", + "THIRD-PARTY-NOTICES", + "MPL-2", + ): os.rename("%s/%s" % (release_dir, releasefile), releasefile) os.rmdir(release_dir) except Exception: @@ -252,7 +299,9 @@ def make_package(distro, build_os, arch, spec, srcdir): # directory, so the debian directory is needed in all cases (and # innocuous in the debianoids' sdirs). for pkgdir in ["debian", "rpm"]: - print("Copying packaging files from %s to %s" % ("%s/%s" % (srcdir, pkgdir), sdir)) + print( + "Copying packaging files from %s to %s" % ("%s/%s" % (srcdir, pkgdir), sdir) + ) git_repo = git.Repo(srcdir) # get the original HEAD position of repo head_commit_sha = git_repo.head.object.hexsha @@ -262,19 +311,26 @@ def make_package(distro, build_os, arch, spec, srcdir): git_repo.git.add(all=True) # only commit changes if there are any if len(git_repo.index.diff("HEAD")) != 0: - with git_repo.git.custom_environment(GIT_COMMITTER_NAME="Evergreen", - GIT_COMMITTER_EMAIL="evergreen@mongodb.com"): + with git_repo.git.custom_environment( + GIT_COMMITTER_NAME="Evergreen", + GIT_COMMITTER_EMAIL="evergreen@mongodb.com", + ): git_repo.git.commit("--author='Evergreen <>'", "-m", "temp commit") # original command to preserve functionality # FIXME: make consistent with the rest of the code when we have more packaging testing # FIXME: sh-dash-cee is bad. See if tarfile can do this. - print("Copying packaging files from specified gitspec:", spec.metadata_gitspec()) - packager.sysassert([ - "sh", "-c", - "(cd \"%s\" && git archive %s %s/ ) | (cd \"%s\" && tar xvf -)" % - (srcdir, spec.metadata_gitspec(), pkgdir, sdir) - ]) + print( + "Copying packaging files from specified gitspec:", spec.metadata_gitspec() + ) + packager.sysassert( + [ + "sh", + "-c", + '(cd "%s" && git archive %s %s/ ) | (cd "%s" && tar xvf -)' + % (srcdir, spec.metadata_gitspec(), pkgdir, sdir), + ] + ) # reset branch to original state print("Resetting branch to original state") @@ -307,7 +363,9 @@ def make_deb_repo(repo, distro, build_os): try: dirs = { os.path.dirname(deb)[2:] - for deb in packager.backtick(["find", ".", "-name", "*.deb"]).decode('utf-8').split() + for deb in packager.backtick(["find", ".", "-name", "*.deb"]) + .decode("utf-8") + .split() } for directory in dirs: st = packager.backtick(["dpkg-scanpackages", directory, "/dev/null"]) @@ -328,7 +386,11 @@ Codename: %s/mongodb-enterprise Architectures: amd64 ppc64el s390x arm64 Components: %s Description: MongoDB packages -""" % (distro.repo_os_version(build_os), distro.repo_os_version(build_os), distro.repo_component()) +""" % ( + distro.repo_os_version(build_os), + distro.repo_os_version(build_os), + distro.repo_component(), + ) if os.path.exists(repo + "../../Release"): os.unlink(repo + "../../Release") if os.path.exists(repo + "../../Release.gpg"): @@ -337,8 +399,8 @@ Description: MongoDB packages os.chdir(repo + "../../") s2 = packager.backtick(["apt-ftparchive", "release", "."]) try: - with open("Release", 'wb') as fh: - fh.write(s1.encode('utf-8')) + with open("Release", "wb") as fh: + fh.write(s1.encode("utf-8")) fh.write(s2) finally: os.chdir(oldpwd) diff --git a/buildscripts/patch_builds/change_data.py b/buildscripts/patch_builds/change_data.py index f715f2026f1..68c7fb8bd4d 100644 --- a/buildscripts/patch_builds/change_data.py +++ b/buildscripts/patch_builds/change_data.py @@ -24,7 +24,9 @@ def _get_id_from_repo(repo: Repo) -> str: return os.path.basename(repo.working_dir) -def generate_revision_map(repos: List[Repo], revisions_data: Dict[str, str]) -> RevisionMap: +def generate_revision_map( + repos: List[Repo], revisions_data: Dict[str, str] +) -> RevisionMap: """ Generate a revision map for the given repositories using the revisions in the given file. @@ -32,7 +34,9 @@ def generate_revision_map(repos: List[Repo], revisions_data: Dict[str, str]) -> :param revisions_data: Dictionary of revisions to use for repositories. :return: Map of repositories to revisions """ - revision_map = {repo.git_dir: revisions_data.get(_get_id_from_repo(repo)) for repo in repos} + revision_map = { + repo.git_dir: revisions_data.get(_get_id_from_repo(repo)) for repo in repos + } return {k: v for k, v in revision_map.items() if v} @@ -72,7 +76,9 @@ def _modified_files_for_diff(diff: DiffIndex, log: Any) -> Set: return modified_files.union(added_files).union(renamed_files).union(deleted_files) -def find_changed_files(repo: Repo, revision_map: Optional[RevisionMap] = None) -> Set[str]: +def find_changed_files( + repo: Repo, revision_map: Optional[RevisionMap] = None +) -> Set[str]: """ Find files that were new or added to the repository between commits. @@ -85,7 +91,9 @@ def find_changed_files(repo: Repo, revision_map: Optional[RevisionMap] = None) - if not revision_map: revision_map = {} diff = repo.index.diff(None) - work_tree_files = _modified_files_for_diff(diff, LOGGER.bind(diff="working tree diff")) + work_tree_files = _modified_files_for_diff( + diff, LOGGER.bind(diff="working tree diff") + ) commit = repo.index diff = commit.diff(revision_map.get(repo.git_dir, repo.head.commit), R=True) @@ -114,7 +122,9 @@ def find_changed_files_in_repos( :param revision_map: Map of revisions to compare against for repos. :return: Set of changed files. """ - return set(chain.from_iterable([find_changed_files(repo, revision_map) for repo in repos])) + return set( + chain.from_iterable([find_changed_files(repo, revision_map) for repo in repos]) + ) def find_modified_lines_for_files( @@ -157,14 +167,18 @@ def find_modified_lines_for_files( else: start_line = int(start_line_count) if start_line <= len(lines): - line_modifications.append((start_line, lines[start_line - 1].rstrip())) + line_modifications.append( + (start_line, lines[start_line - 1].rstrip()) + ) modified_lines_and_content[file_path] = line_modifications return modified_lines_and_content def find_modified_lines_for_files_in_repos( - repos: Iterable[Repo], changed_files: List[str], revision_map: Optional[RevisionMap] = None + repos: Iterable[Repo], + changed_files: List[str], + revision_map: Optional[RevisionMap] = None, ) -> Dict[str, List[Tuple[int, str]]]: """ Find the modified lines in files with changes. diff --git a/buildscripts/pip_requirements.py b/buildscripts/pip_requirements.py index c29c96efecd..b5aee3028dc 100644 --- a/buildscripts/pip_requirements.py +++ b/buildscripts/pip_requirements.py @@ -12,6 +12,7 @@ import sys class MissingRequirements(Exception): """Raised when when verify_requirements() detects missing requirements.""" + pass @@ -31,9 +32,9 @@ def verify_requirements(silent: bool = False, executable=sys.executable): print(*args, **kwargs) def raiseSuggestion(ex, pip_pkg): - raise MissingRequirements(f"{ex}\n" - f"Try running:\n" - f" {executable} -m pip install {pip_pkg}") from ex + raise MissingRequirements( + f"{ex}\n" f"Try running:\n" f" {executable} -m pip install {pip_pkg}" + ) from ex # Import poetry. If this fails then we know the next function will fail. # This is so the user will have an easier time diagnosing the problem @@ -46,10 +47,22 @@ def verify_requirements(silent: bool = False, executable=sys.executable): try: extras = [] - if platform.machine() in set(["s390x", "ppc64le"]) and ".el9" not in platform.release(): + if ( + platform.machine() in set(["s390x", "ppc64le"]) + and ".el9" not in platform.release() + ): extras = ["--extras", "oldcrypt"] poetry_dry_run_proc = subprocess.run( - [executable, "-m", "poetry", "install", "--no-root", "--sync", "--dry-run", *extras], + [ + executable, + "-m", + "poetry", + "install", + "--no-root", + "--sync", + "--dry-run", + *extras, + ], check=True, text=True, capture_output=True, @@ -63,24 +76,33 @@ def verify_requirements(silent: bool = False, executable=sys.executable): "Detected one or more packages are out of date. " "Try running:\n" " export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring\n" - " python3 -m poetry install --no-root --sync") + " python3 -m poetry install --no-root --sync" + ) # String match should look like the following # Package operations: 2 installs, 3 updates, 0 removals, 165 skipped - match = re.search(r"Package operations: (\d+) \w+, (\d+) \w+, (\d+) \w+, (\d+) \w+", - poetry_dry_run_proc.stdout) + match = re.search( + r"Package operations: (\d+) \w+, (\d+) \w+, (\d+) \w+, (\d+) \w+", + poetry_dry_run_proc.stdout, + ) verbose("Requirements list:") verbose(poetry_dry_run_proc.stdout) installs = int(match[1]) updates = int(match[2]) - if updates == 1 and sys.platform == 'win32' and "Updating pywin32" in poetry_dry_run_proc.stdout: + if ( + updates == 1 + and sys.platform == "win32" + and "Updating pywin32" in poetry_dry_run_proc.stdout + ): # We have no idea why pywin32 thinks it needs to be updated # We could use some more investigation into this verbose( - "Windows detected a single update to pywin32 which is known to be buggy. Continuing.") + "Windows detected a single update to pywin32 which is known to be buggy. Continuing." + ) elif installs + updates > 0: raise MissingRequirements( f"Detected one or more packages are out of date. " f"Try running:\n" f" export PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring\n" - f" {executable} -m poetry install --no-root --sync") + f" {executable} -m poetry install --no-root --sync" + ) diff --git a/buildscripts/powercycle_sentinel.py b/buildscripts/powercycle_sentinel.py index 1dd784e9a65..e307786d7cc 100755 --- a/buildscripts/powercycle_sentinel.py +++ b/buildscripts/powercycle_sentinel.py @@ -3,6 +3,7 @@ Error out when any powercycle task on the same buildvariant runs for more than 2 hours. """ + import logging import os import sys @@ -37,23 +38,32 @@ def get_evergreen_api() -> EvergreenApi: evg_api = RetryingEvergreenApi.get_api(config_file=file) return evg_api - LOGGER.error("Evergreen config not found in locations.", locations=EVERGREEN_CONFIG_LOCATIONS) + LOGGER.error( + "Evergreen config not found in locations.", locations=EVERGREEN_CONFIG_LOCATIONS + ) sys.exit(1) -def watch_tasks(task_ids: List[str], evg_api: EvergreenApi, watch_interval_secs: int) -> List[str]: +def watch_tasks( + task_ids: List[str], evg_api: EvergreenApi, watch_interval_secs: int +) -> List[str]: """Watch tasks if they run longer than exec timeout.""" watch_task_ids = task_ids[:] long_running_task_ids = [] while watch_task_ids: - LOGGER.info("Looking if powercycle tasks are still running on the current buildvariant.") + LOGGER.info( + "Looking if powercycle tasks are still running on the current buildvariant." + ) powercycle_tasks = [evg_api.task_by_id(task_id) for task_id in watch_task_ids] for task in powercycle_tasks: if task.finish_time: watch_task_ids.remove(task.task_id) - elif task.start_time and (datetime.now(timezone.utc) - task.start_time - ).total_seconds() > POWERCYCLE_TASK_EXEC_TIMEOUT_SECS: + elif ( + task.start_time + and (datetime.now(timezone.utc) - task.start_time).total_seconds() + > POWERCYCLE_TASK_EXEC_TIMEOUT_SECS + ): long_running_task_ids.append(task.task_id) watch_task_ids.remove(task.task_id) if watch_task_ids: @@ -87,23 +97,31 @@ def main(expansions_file: str = "expansions.yml") -> None: evg_api = get_evergreen_api() build_tasks = evg_api.tasks_by_build(build_id) - gen_task_id = [task.task_id for task in build_tasks if gen_task_name in task.task_id][0] + gen_task_id = [ + task.task_id for task in build_tasks if gen_task_name in task.task_id + ][0] gen_task_url = f"{EVERGREEN_HOST}/task/{gen_task_id}" while evg_api.task_by_id(gen_task_id).is_active(): LOGGER.info( - f"Waiting for '{gen_task_name}' task to generate powercycle tasks:\n{gen_task_url}") + f"Waiting for '{gen_task_name}' task to generate powercycle tasks:\n{gen_task_url}" + ) time.sleep(WATCH_INTERVAL_SECS) build_tasks = evg_api.tasks_by_build(build_id) powercycle_task_ids = [ - task.task_id for task in build_tasks - if not task.display_only and task.task_id != current_task_id and task.task_id != gen_task_id + task.task_id + for task in build_tasks + if not task.display_only + and task.task_id != current_task_id + and task.task_id != gen_task_id and "powercycle" in task.task_id ] LOGGER.info(f"Watching powercycle tasks:\n{get_links(powercycle_task_ids)}") - long_running_task_ids = watch_tasks(powercycle_task_ids, evg_api, WATCH_INTERVAL_SECS) + long_running_task_ids = watch_tasks( + powercycle_task_ids, evg_api, WATCH_INTERVAL_SECS + ) if long_running_task_ids: LOGGER.error( f"Found powercycle tasks that are running for more than {POWERCYCLE_TASK_EXEC_TIMEOUT_SECS} " @@ -111,9 +129,10 @@ def main(expansions_file: str = "expansions.yml") -> None: ) LOGGER.error( "Hopefully hosts from the tasks are still in run at the time you are seeing this " - "and the Build team is able to check them to diagnose the issue.") + "and the Build team is able to check them to diagnose the issue." + ) sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/quickmongolint.py b/buildscripts/quickmongolint.py index c3fe3dfb4e3..c0b2fb14cf8 100755 --- a/buildscripts/quickmongolint.py +++ b/buildscripts/quickmongolint.py @@ -10,7 +10,9 @@ from typing import List # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__))))) + sys.path.append( + os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + ) from buildscripts.linter import ( git, # pylint: disable=wrong-import-position @@ -18,25 +20,32 @@ from buildscripts.linter import ( parallel, # pylint: disable=wrong-import-position ) -FILES_RE = re.compile('\\.(h|cpp)$') +FILES_RE = re.compile("\\.(h|cpp)$") def is_interesting_file(file_name: str) -> bool: """Return true if this file should be checked.""" - return (file_name.startswith("jstests") - or file_name.startswith("src") and not file_name.startswith("src/third_party/") - and not file_name.startswith("src/mongo/gotools/") - and not file_name.startswith("src/streams/third_party") - and not file_name.startswith("src/mongo/db/modules/enterprise/src/streams/third_party") - and not file_name.endswith(".cstruct.h") - # TODO SERVER-49805: These files should be generated at compile time. - and not file_name == "src/mongo/db/cst/parser_gen.cpp") and FILES_RE.search(file_name) + return ( + file_name.startswith("jstests") + or file_name.startswith("src") + and not file_name.startswith("src/third_party/") + and not file_name.startswith("src/mongo/gotools/") + and not file_name.startswith("src/streams/third_party") + and not file_name.startswith( + "src/mongo/db/modules/enterprise/src/streams/third_party" + ) + and not file_name.endswith(".cstruct.h") + # TODO SERVER-49805: These files should be generated at compile time. + and not file_name == "src/mongo/db/cst/parser_gen.cpp" + ) and FILES_RE.search(file_name) def _lint_files(file_names: List[str]) -> None: """Lint a list of files with clang-format.""" run_lint1 = lambda param1: mongolint.lint_file(param1) == 0 - if not parallel.parallel_process([os.path.abspath(f) for f in file_names], run_lint1): + if not parallel.parallel_process( + [os.path.abspath(f) for f in file_names], run_lint1 + ): print("ERROR: Code Style does not match coding style") sys.exit(1) @@ -77,27 +86,37 @@ def lint_my(origin_branch: List[str]) -> None: def main() -> None: """Execute Main entry point.""" - parser = argparse.ArgumentParser(description='Quick C++ Lint frontend.') + parser = argparse.ArgumentParser(description="Quick C++ Lint frontend.") - parser.add_argument('-v', "--verbose", action='store_true', help="Enable verbose logging") + parser.add_argument( + "-v", "--verbose", action="store_true", help="Enable verbose logging" + ) sub = parser.add_subparsers(title="Linter subcommands", help="sub-command help") - parser_lint = sub.add_parser('lint', help='Lint only Git files') + parser_lint = sub.add_parser("lint", help="Lint only Git files") parser_lint.add_argument("file_names", nargs="*", help="Globs of files to check") parser_lint.set_defaults(func=lint) - parser_lint_all = sub.add_parser('lint-all', help='Lint All files') - parser_lint_all.add_argument("file_names", nargs="*", help="Globs of files to check") + parser_lint_all = sub.add_parser("lint-all", help="Lint All files") + parser_lint_all.add_argument( + "file_names", nargs="*", help="Globs of files to check" + ) parser_lint_all.set_defaults(func=lint_all) - parser_lint_patch = sub.add_parser('lint-patch', help='Lint the files in a patch') - parser_lint_patch.add_argument("file_names", nargs="*", help="Globs of files to check") + parser_lint_patch = sub.add_parser("lint-patch", help="Lint the files in a patch") + parser_lint_patch.add_argument( + "file_names", nargs="*", help="Globs of files to check" + ) parser_lint_patch.set_defaults(func=lint_patch) - parser_lint_my = sub.add_parser('lint-my', help='Lint my files') - parser_lint_my.add_argument("--branch", dest="file_names", default="origin/master", - help="Branch to compare against") + parser_lint_my = sub.add_parser("lint-my", help="Lint my files") + parser_lint_my.add_argument( + "--branch", + dest="file_names", + default="origin/master", + help="Branch to compare against", + ) parser_lint_my.set_defaults(func=lint_my) args = parser.parse_args() diff --git a/buildscripts/resmoke_proxy/resmoke_proxy.py b/buildscripts/resmoke_proxy/resmoke_proxy.py index dd6b9194272..07044eefdbd 100644 --- a/buildscripts/resmoke_proxy/resmoke_proxy.py +++ b/buildscripts/resmoke_proxy/resmoke_proxy.py @@ -1,4 +1,5 @@ """A service to proxy requests to resmoke.""" + from typing import Any, Dict, List import inject diff --git a/buildscripts/resmoke_tests_runtime_validate.py b/buildscripts/resmoke_tests_runtime_validate.py index c5c9c1ca35f..c7d01b2d184 100644 --- a/buildscripts/resmoke_tests_runtime_validate.py +++ b/buildscripts/resmoke_tests_runtime_validate.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Utility to validate resmoke tests runtime.""" + import json import sys from collections import namedtuple @@ -30,14 +31,20 @@ def parse_resmoke_report(report_file: str) -> List[TestInfo]: with open(report_file, "r") as fh: report_data = json.load(fh) test_report = TestReport.from_dict(report_data) - return [test_info for test_info in test_report.test_infos if "jstests" in test_info.test_file] + return [ + test_info + for test_info in test_report.test_infos + if "jstests" in test_info.test_file + ] -def get_historic_stats(project_id: str, task_name: str, - build_variant: str) -> List[HistoricalTestInformation]: +def get_historic_stats( + project_id: str, task_name: str, build_variant: str +) -> List[HistoricalTestInformation]: """Get historic test stats.""" base_task_name = get_task_name_without_suffix(task_name, build_variant).replace( - BURN_IN_PREFIX, "") + BURN_IN_PREFIX, "" + ) return HistoricTaskData.get_stats_from_s3(project_id, base_task_name, build_variant) @@ -55,26 +62,38 @@ def make_stats_map(stats: List[_TestData]) -> Dict[str, List[float]]: @click.command() -@click.option("--resmoke-report-file", type=str, required=True, - help="Location of resmoke's report JSON file.") +@click.option( + "--resmoke-report-file", + type=str, + required=True, + help="Location of resmoke's report JSON file.", +) @click.option("--project-id", type=str, required=True, help="Evergreen project id.") -@click.option("--build-variant", type=str, required=True, help="Evergreen build variant name.") +@click.option( + "--build-variant", type=str, required=True, help="Evergreen build variant name." +) @click.option("--task-name", type=str, required=True, help="Evergreen task name.") -def main(resmoke_report_file: str, project_id: str, build_variant: str, task_name: str) -> None: +def main( + resmoke_report_file: str, project_id: str, build_variant: str, task_name: str +) -> None: """Compare resmoke tests runtime with historic stats.""" enable_logging(verbose=False) current_test_infos = parse_resmoke_report(resmoke_report_file) - current_stats_map = make_stats_map([ - _TestData(test_info.test_file, test_info.end_time - test_info.start_time) - for test_info in current_test_infos - ]) + current_stats_map = make_stats_map( + [ + _TestData(test_info.test_file, test_info.end_time - test_info.start_time) + for test_info in current_test_infos + ] + ) historic_stats = get_historic_stats(project_id, task_name, build_variant) - historic_stats_map = make_stats_map([ - _TestData(test_stats.test_name, test_stats.avg_duration_pass) - for test_stats in historic_stats - ]) + historic_stats_map = make_stats_map( + [ + _TestData(test_stats.test_name, test_stats.avg_duration_pass) + for test_stats in historic_stats + ] + ) failed = False @@ -88,9 +107,13 @@ def main(resmoke_report_file: str, project_id: str, build_variant: str, task_nam historic_max = max(historic_test_stats) target = historic_max * HISTORIC_MAX_MULTIPLIER if current_mean > target: - LOGGER.error("Found long running test.", test_file=test, - current_mean_time=current_mean, maximum_expected_time=target, - historic_max_time=historic_max) + LOGGER.error( + "Found long running test.", + test_file=test, + current_mean_time=current_mean, + maximum_expected_time=target, + historic_max_time=historic_max, + ) failed = True LOGGER.info("Done comparing resmoke tests runtime with historic stats.") @@ -99,15 +122,21 @@ def main(resmoke_report_file: str, project_id: str, build_variant: str, task_nam LOGGER.error( f"The test failed due to its runtime taking {percent}% more than the recent max" " and can negatively contribute to the future patch build experience." - " Consider checking if there is an unexpected regression.") - LOGGER.error("If the test is being intentionally expanded, please split it up into separate" - " JS files that run as separate tests.") + " Consider checking if there is an unexpected regression." + ) + LOGGER.error( + "If the test is being intentionally expanded, please split it up into separate" + " JS files that run as separate tests." + ) LOGGER.error( "If you believe the test has inherently large variability, please consider writing" - " a new test instead of modifying this one.") - LOGGER.error("For any other questions or concerns, please reach out to #server-testing.") + " a new test instead of modifying this one." + ) + LOGGER.error( + "For any other questions or concerns, please reach out to #server-testing." + ) sys.exit(1) -if __name__ == '__main__': +if __name__ == "__main__": main() # pylint: disable=no-value-for-parameter diff --git a/buildscripts/resmokelib/cli.py b/buildscripts/resmokelib/cli.py index 0ba8304c67b..22d23aef6c9 100644 --- a/buildscripts/resmokelib/cli.py +++ b/buildscripts/resmokelib/cli.py @@ -17,13 +17,14 @@ def main(argv): :return: None """ __start_time = time.time() - os.environ['RESMOKE_PARENT_PROCESS'] = str(os.getpid()) - os.environ['RESMOKE_PARENT_CTIME'] = str(psutil.Process().create_time()) + os.environ["RESMOKE_PARENT_PROCESS"] = str(os.getpid()) + os.environ["RESMOKE_PARENT_CTIME"] = str(psutil.Process().create_time()) subcommand = parser.parse_command_line( - argv[1:], start_time=__start_time, + argv[1:], + start_time=__start_time, usage="Resmoke is MongoDB's correctness testing orchestrator.\n" "For more information, see the help message for each subcommand.\n" "For example: resmoke.py run -h\n" - "Note: bisect, setup-multiversion and symbolize subcommands have been moved to db-contrib-tool (https://github.com/10gen/db-contrib-tool#readme).\n" + "Note: bisect, setup-multiversion and symbolize subcommands have been moved to db-contrib-tool (https://github.com/10gen/db-contrib-tool#readme).\n", ) subcommand.execute() diff --git a/buildscripts/resmokelib/config.py b/buildscripts/resmokelib/config.py index 8a33f173262..6d165033786 100644 --- a/buildscripts/resmokelib/config.py +++ b/buildscripts/resmokelib/config.py @@ -132,10 +132,8 @@ DEFAULTS = { "shell_tls_certificate_key_file": None, "mongos_tls_certificate_key_file": None, "mongod_tls_certificate_key_file": None, - # Internal testing options. "internal_params": [], - # Evergreen options. "evergreen_url": "evergreen.mongodb.com", "build_id": None, @@ -152,65 +150,55 @@ DEFAULTS = { "version_id": None, "work_dir": None, "evg_project_config_path": "etc/evergreen.yml", - # WiredTiger options. "wt_coll_config": None, "wt_engine_config": None, "wt_index_config": None, - # Benchmark options. "benchmark_filter": None, "benchmark_list_tests": None, "benchmark_min_time_secs": None, "benchmark_repetitions": None, - # Config Dir "config_dir": "buildscripts/resmokeconfig", - # Directory with jstests "jstests_dir": "jstests", - # UndoDB options "undo_recorder_path": None, - # Generate multiversion exclude tags options "exclude_tags_file_path": "generated_resmoke_config/multiversion_exclude_tags.yml", - # Limit the number of tests to execute "max_test_queue_size": None, - # Sanity check "sanity_check": False, - # otel info "otel_trace_id": None, "otel_parent_id": None, "otel_collector_dir": None, - # The images to build for an External System Under Test "docker_compose_build_images": None, - # Where the `--dockerComposeBuildImages` is happening. "docker_compose_build_env": "local", - # Tag to use for images built & used for an External System Under Test "docker_compose_tag": "development", - # Whether or not this resmoke suite is running against an External System Under Test "external_sut": False, } -_SuiteOptions = collections.namedtuple("_SuiteOptions", [ - "description", - "fail_fast", - "include_tags", - "num_jobs", - "num_repeat_suites", - "num_repeat_tests", - "num_repeat_tests_max", - "num_repeat_tests_min", - "time_repeat_tests_secs", -]) +_SuiteOptions = collections.namedtuple( + "_SuiteOptions", + [ + "description", + "fail_fast", + "include_tags", + "num_jobs", + "num_repeat_suites", + "num_repeat_tests", + "num_repeat_tests_max", + "num_repeat_tests_min", + "time_repeat_tests_secs", + ], +) class SuiteOptions(_SuiteOptions): @@ -247,7 +235,9 @@ class SuiteOptions(_SuiteOptions): combined_value = combined_options[field] if combined_value is not cls.INHERIT and combined_value != value: - raise ValueError("Attempted to set '{}' option multiple times".format(field)) + raise ValueError( + "Attempted to set '{}' option multiple times".format(field) + ) combined_options[field] = value if include_tags_list: @@ -266,17 +256,22 @@ class SuiteOptions(_SuiteOptions): include_tags = None parent = dict( list( - zip(SuiteOptions._fields, [ - description, - FAIL_FAST, - include_tags, - JOBS, - REPEAT_SUITES, - REPEAT_TESTS, - REPEAT_TESTS_MAX, - REPEAT_TESTS_MIN, - REPEAT_TESTS_SECS, - ]))) + zip( + SuiteOptions._fields, + [ + description, + FAIL_FAST, + include_tags, + JOBS, + REPEAT_SUITES, + REPEAT_TESTS, + REPEAT_TESTS_MAX, + REPEAT_TESTS_MIN, + REPEAT_TESTS_SECS, + ], + ) + ) + ) options = self._asdict() for field in SuiteOptions._fields: @@ -287,7 +282,8 @@ class SuiteOptions(_SuiteOptions): SuiteOptions.ALL_INHERITED = SuiteOptions( # type: ignore - **dict(list(zip(SuiteOptions._fields, itertools.repeat(SuiteOptions.INHERIT))))) + **dict(list(zip(SuiteOptions._fields, itertools.repeat(SuiteOptions.INHERIT)))) +) class MultiversionOptions(object): @@ -662,14 +658,31 @@ BENCHMARK_OUT_FORMAT = "json" ORDER_TESTS_BY_NAME = True # Default file names for externally generated lists of tests created during the build. -DEFAULT_BENCHMARK_TEST_LIST = "bazel-bin/install/install-mongo_benchmark-stripped_test_list.txt" +DEFAULT_BENCHMARK_TEST_LIST = ( + "bazel-bin/install/install-mongo_benchmark-stripped_test_list.txt" +) DEFAULT_UNIT_TEST_LIST = "bazel-bin/install/install-mongo_unittest_test_list.txt" -DEFAULT_INTEGRATION_TEST_LIST = "bazel-bin/install/install-mongo_integration_test_test_list.txt" -DEFAULT_LIBFUZZER_TEST_LIST = "bazel-bin/install/install-mongo_fuzzer_test_test_list.txt" -DEFAULT_PRETTY_PRINTER_TEST_LIST = "bazel-bin/install/install-dist-test-stripped_test_list.txt" +DEFAULT_INTEGRATION_TEST_LIST = ( + "bazel-bin/install/install-mongo_integration_test_test_list.txt" +) +DEFAULT_LIBFUZZER_TEST_LIST = ( + "bazel-bin/install/install-mongo_fuzzer_test_test_list.txt" +) +DEFAULT_PRETTY_PRINTER_TEST_LIST = ( + "bazel-bin/install/install-dist-test-stripped_test_list.txt" +) SPLIT_UNITTESTS_LISTS = [ f"bazel-bin/install/install-mongo_unittest_{test_group}_group_test_list.txt" - for test_group in ["first", "second", "third", "fourth", "fifth", "sixth", "seventh", "eighth"] + for test_group in [ + "first", + "second", + "third", + "fourth", + "fifth", + "sixth", + "seventh", + "eighth", + ] ] BENCHMARK_SUITE_TEST_LISTS = [ "bazel-bin/install/install-repl_bm_test_list.txt", @@ -682,10 +695,16 @@ BENCHMARK_SUITE_TEST_LISTS = [ ] # External files or executables, used as suite selectors, that are created during the build and # therefore might not be available when creating a test membership map. -EXTERNAL_SUITE_SELECTORS = (DEFAULT_BENCHMARK_TEST_LIST, DEFAULT_UNIT_TEST_LIST, - DEFAULT_INTEGRATION_TEST_LIST, DEFAULT_DBTEST_EXECUTABLE, - DEFAULT_LIBFUZZER_TEST_LIST, DEFAULT_PRETTY_PRINTER_TEST_LIST, - *SPLIT_UNITTESTS_LISTS, *BENCHMARK_SUITE_TEST_LISTS) +EXTERNAL_SUITE_SELECTORS = ( + DEFAULT_BENCHMARK_TEST_LIST, + DEFAULT_UNIT_TEST_LIST, + DEFAULT_INTEGRATION_TEST_LIST, + DEFAULT_DBTEST_EXECUTABLE, + DEFAULT_LIBFUZZER_TEST_LIST, + DEFAULT_PRETTY_PRINTER_TEST_LIST, + *SPLIT_UNITTESTS_LISTS, + *BENCHMARK_SUITE_TEST_LISTS, +) # Where to look for logging and suite configuration files CONFIG_DIR = None @@ -707,7 +726,7 @@ USE_LEGACY_MULTIVERSION = True # Expansions file location # in CI, the expansions file is located in the ${workdir}, one dir up # from src, the checkout directory -EXPANSIONS_FILE = "../expansions.yml" if 'CI' in os.environ else "expansions.yml" +EXPANSIONS_FILE = "../expansions.yml" if "CI" in os.environ else "expansions.yml" # Symbolizer secrets SYMBOLIZER_CLIENT_SECRET = None diff --git a/buildscripts/resmokelib/configure_resmoke.py b/buildscripts/resmokelib/configure_resmoke.py index 83705d753e9..8d40ab38627 100644 --- a/buildscripts/resmokelib/configure_resmoke.py +++ b/buildscripts/resmokelib/configure_resmoke.py @@ -48,16 +48,19 @@ def validate_and_update_config(parser, args): def _validate_options(parser, args): """Do preliminary validation on the options and error on any invalid options.""" - if 'shell_port' not in args or 'shell_conn_string' not in args: + if "shell_port" not in args or "shell_conn_string" not in args: return if args.shell_port is not None and args.shell_conn_string is not None: parser.error("Cannot specify both `shellPort` and `shellConnString`") if args.executor_file: - parser.error("--executor is superseded by --suites; specify --suites={} {} to run the" - " test(s) under those suite configuration(s)".format( - args.executor_file, " ".join(args.test_files))) + parser.error( + "--executor is superseded by --suites; specify --suites={} {} to run the" + " test(s) under those suite configuration(s)".format( + args.executor_file, " ".join(args.test_files) + ) + ) # The "test_files" positional argument logically overlaps with `--replayFile`. Disallow using both. if args.test_files and args.replay_file: @@ -69,7 +72,8 @@ def _validate_options(parser, args): parser.error("The --shellSeed argument must be used with only one test.") if args.additional_feature_flags_file and not os.path.isfile( - args.additional_feature_flags_file): + args.additional_feature_flags_file + ): parser.error("The specified additional feature flags file does not exist.") def get_set_param_errors(process_params): @@ -93,11 +97,18 @@ def _validate_options(parser, args): return errors config = vars(args) - mongod_set_param_errors = get_set_param_errors(config.get('mongod_set_parameters') or []) - mongos_set_param_errors = get_set_param_errors(config.get('mongos_set_parameters') or []) + mongod_set_param_errors = get_set_param_errors( + config.get("mongod_set_parameters") or [] + ) + mongos_set_param_errors = get_set_param_errors( + config.get("mongos_set_parameters") or [] + ) mongocryptd_set_param_errors = get_set_param_errors( - config.get('mongocryptd_set_parameters') or []) - mongo_set_param_errors = get_set_param_errors(config.get('mongo_set_parameters') or []) + config.get("mongocryptd_set_parameters") or [] + ) + mongo_set_param_errors = get_set_param_errors( + config.get("mongo_set_parameters") or [] + ) error_msgs = {} if mongod_set_param_errors: error_msgs["mongodSetParameters"] = mongod_set_param_errors @@ -129,15 +140,21 @@ def _validate_config(parser): if _config.MIXED_BIN_VERSIONS is not None: for version in _config.MIXED_BIN_VERSIONS: - if version not in set(['old', 'new']): - parser.error("Must specify binary versions as 'old' or 'new' in format" - " 'version1-version2'") + if version not in set(["old", "new"]): + parser.error( + "Must specify binary versions as 'old' or 'new' in format" + " 'version1-version2'" + ) if _config.UNDO_RECORDER_PATH is not None: - if not sys.platform.startswith('linux') or platform.machine() not in [ - "i386", "i686", "x86_64" + if not sys.platform.startswith("linux") or platform.machine() not in [ + "i386", + "i686", + "x86_64", ]: - parser.error("--recordWith is only supported on x86 and x86_64 Linux distributions") + parser.error( + "--recordWith is only supported on x86 and x86_64 Linux distributions" + ) return resolved_path = shutil.which(_config.UNDO_RECORDER_PATH) @@ -156,9 +173,13 @@ def _validate_config(parser): if _config.TLS_CA_FILE: parser.error("--tlsCAFile requires server TLS to be enabled") if _config.MONGOD_TLS_CERTIFICATE_KEY_FILE: - parser.error("--mongodTlsCertificateKeyFile requires server TLS to be enabled") + parser.error( + "--mongodTlsCertificateKeyFile requires server TLS to be enabled" + ) if _config.MONGOS_TLS_CERTIFICATE_KEY_FILE: - parser.error("--mongosTlsCertificateKeyFile requires server TLS to be enabled") + parser.error( + "--mongosTlsCertificateKeyFile requires server TLS to be enabled" + ) if not _config.SHELL_TLS_ENABLED: if _config.SHELL_TLS_CERTIFICATE_KEY_FILE: @@ -179,10 +200,10 @@ def _find_resmoke_wrappers(): def _set_up_tracing( - otel_collector_dir: Optional[str], - trace_id: Optional[str], - parent_span_id: Optional[str], - extra_context: Optional[Dict[str, object]], + otel_collector_dir: Optional[str], + trace_id: Optional[str], + parent_span_id: Optional[str], + extra_context: Optional[Dict[str, object]], ) -> bool: """Try to set up otel tracing. On success return True. On failure return False. @@ -209,7 +230,8 @@ def _set_up_tracing( # Make the file easy to read when ran locally. pretty_print = _config.EVERGREEN_TASK_ID is None processor = BatchedBaggageSpanProcessor( - FileSpanExporter(otel_collector_dir, pretty_print)) + FileSpanExporter(otel_collector_dir, pretty_print) + ) provider.add_span_processor(processor) except OSError: traceback.print_exc() @@ -276,13 +298,17 @@ be invoked as either: def setup_feature_flags(): _config.RUN_ALL_FEATURE_FLAG_TESTS = config.pop("run_all_feature_flag_tests") _config.RUN_NO_FEATURE_FLAG_TESTS = config.pop("run_no_feature_flag_tests") - _config.ADDITIONAL_FEATURE_FLAGS_FILE = config.pop("additional_feature_flags_file") + _config.ADDITIONAL_FEATURE_FLAGS_FILE = config.pop( + "additional_feature_flags_file" + ) if values.command == "run": # These logging messages start with # becuase the output of this file must produce # valid yaml. This comments out these print statements when the output is parsed. print("# Fetching feature flags...") - all_ff = gen_all_feature_flag_list.get_all_feature_flags_turned_off_by_default() + all_ff = ( + gen_all_feature_flag_list.get_all_feature_flags_turned_off_by_default() + ) print("# Fetched feature flags...") else: all_ff = [] @@ -293,18 +319,23 @@ be invoked as either: if _config.ADDITIONAL_FEATURE_FLAGS_FILE: enabled_feature_flags.extend( - process_feature_flag_file(_config.ADDITIONAL_FEATURE_FLAGS_FILE)) + process_feature_flag_file(_config.ADDITIONAL_FEATURE_FLAGS_FILE) + ) # Specify additional feature flags from the command line. # Set running all feature flag tests to True if this options is specified. - additional_feature_flags = _tags_from_list(config.pop("additional_feature_flags")) + additional_feature_flags = _tags_from_list( + config.pop("additional_feature_flags") + ) if additional_feature_flags is not None: enabled_feature_flags.extend(additional_feature_flags) return enabled_feature_flags, all_ff _config.ENABLED_FEATURE_FLAGS, all_feature_flags = setup_feature_flags() - not_enabled_feature_flags = list(set(all_feature_flags) - set(_config.ENABLED_FEATURE_FLAGS)) + not_enabled_feature_flags = list( + set(all_feature_flags) - set(_config.ENABLED_FEATURE_FLAGS) + ) _config.AUTO_KILL = config.pop("auto_kill") _config.ALWAYS_USE_LOG_FILES = config.pop("always_use_log_files") @@ -317,9 +348,12 @@ be invoked as either: # EXCLUDE_WITH_ANY_TAGS will always contain the implicitly defined EXCLUDED_TAG. _config.EXCLUDE_WITH_ANY_TAGS = [_config.EXCLUDED_TAG] _config.EXCLUDE_WITH_ANY_TAGS.extend( - utils.default_if_none(_tags_from_list(config.pop("exclude_with_any_tags")), [])) + utils.default_if_none(_tags_from_list(config.pop("exclude_with_any_tags")), []) + ) - with open("buildscripts/resmokeconfig/fully_disabled_feature_flags.yml") as fully_disabled_ffs: + with open( + "buildscripts/resmokeconfig/fully_disabled_feature_flags.yml" + ) as fully_disabled_ffs: force_disabled_flags = yaml.safe_load(fully_disabled_ffs) _config.EXCLUDE_WITH_ANY_TAGS.extend(force_disabled_flags) @@ -331,11 +365,17 @@ be invoked as either: # Don't run tests with feature flags that are not enabled. _config.EXCLUDE_WITH_ANY_TAGS.extend(not_enabled_feature_flags) _config.EXCLUDE_WITH_ANY_TAGS.extend( - [f"{feature_flag}_incompatible" for feature_flag in _config.ENABLED_FEATURE_FLAGS]) + [ + f"{feature_flag}_incompatible" + for feature_flag in _config.ENABLED_FEATURE_FLAGS + ] + ) _config.DOCKER_COMPOSE_BUILD_IMAGES = config.pop("docker_compose_build_images") if _config.DOCKER_COMPOSE_BUILD_IMAGES is not None: - _config.DOCKER_COMPOSE_BUILD_IMAGES = _config.DOCKER_COMPOSE_BUILD_IMAGES.split(",") + _config.DOCKER_COMPOSE_BUILD_IMAGES = _config.DOCKER_COMPOSE_BUILD_IMAGES.split( + "," + ) _config.DOCKER_COMPOSE_BUILD_ENV = config.pop("docker_compose_build_env") _config.DOCKER_COMPOSE_TAG = config.pop("docker_compose_tag") _config.EXTERNAL_SUT = config.pop("external_sut") @@ -344,7 +384,9 @@ be invoked as either: # (1) We are building images for an External SUT, OR ... # (2) We are running resmoke against an External SUT # This option needs to be set before the _config.CONFIG_SHARD option below - _config.NOOP_MONGO_D_S_PROCESSES = _config.DOCKER_COMPOSE_BUILD_IMAGES is not None or _config.EXTERNAL_SUT + _config.NOOP_MONGO_D_S_PROCESSES = ( + _config.DOCKER_COMPOSE_BUILD_IMAGES is not None or _config.EXTERNAL_SUT + ) # When running resmoke against an External SUT, we are expected to be in # the workload container -- which may require additional setup before running tests. @@ -385,7 +427,9 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py # Normalize the path so that on Windows dist-test/bin # translates to .\dist-test\bin then absolutify it since the # Windows PATH variable requires absolute paths. - _config.INSTALL_DIR = os.path.abspath(_expand_user(os.path.normpath(_config.INSTALL_DIR))) + _config.INSTALL_DIR = os.path.abspath( + _expand_user(os.path.normpath(_config.INSTALL_DIR)) + ) for binary in ["mongo", "mongod", "mongos", "mongot-localdev/mongot", "dbtest"]: keyname = binary + "_executable" @@ -415,9 +459,16 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py _config.CONFIG_FUZZ_SEED = random.randrange(sys.maxsize) else: _config.CONFIG_FUZZ_SEED = int(_config.CONFIG_FUZZ_SEED) - _config.MONGOD_SET_PARAMETERS, _config.WT_ENGINE_CONFIG, _config.WT_COLL_CONFIG, \ - _config.WT_INDEX_CONFIG = mongo_fuzzer_configs.fuzz_mongod_set_parameters( - _config.FUZZ_MONGOD_CONFIGS, _config.CONFIG_FUZZ_SEED, _config.MONGOD_SET_PARAMETERS) + ( + _config.MONGOD_SET_PARAMETERS, + _config.WT_ENGINE_CONFIG, + _config.WT_COLL_CONFIG, + _config.WT_INDEX_CONFIG, + ) = mongo_fuzzer_configs.fuzz_mongod_set_parameters( + _config.FUZZ_MONGOD_CONFIGS, + _config.CONFIG_FUZZ_SEED, + _config.MONGOD_SET_PARAMETERS, + ) _config.EXCLUDE_WITH_ANY_TAGS.extend(["uses_compact"]) _config.EXCLUDE_WITH_ANY_TAGS.extend(["requires_emptycapped"]) @@ -432,12 +483,19 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py _config.CONFIG_FUZZ_SEED = int(_config.CONFIG_FUZZ_SEED) _config.MONGOS_SET_PARAMETERS = mongo_fuzzer_configs.fuzz_mongos_set_parameters( - _config.FUZZ_MONGOS_CONFIGS, _config.CONFIG_FUZZ_SEED, _config.MONGOS_SET_PARAMETERS) + _config.FUZZ_MONGOS_CONFIGS, + _config.CONFIG_FUZZ_SEED, + _config.MONGOS_SET_PARAMETERS, + ) - _config.MONGOCRYPTD_SET_PARAMETERS = _merge_set_params(config.pop("mongocryptd_set_parameters")) + _config.MONGOCRYPTD_SET_PARAMETERS = _merge_set_params( + config.pop("mongocryptd_set_parameters") + ) _config.MONGO_SET_PARAMETERS = _merge_set_params(config.pop("mongo_set_parameters")) - _config.MONGOT_EXECUTABLE = _expand_user(config.pop("mongot-localdev/mongot_executable")) + _config.MONGOT_EXECUTABLE = _expand_user( + config.pop("mongot-localdev/mongot_executable") + ) mongot_set_parameters = config.pop("mongot_set_parameters") _config.MONGOT_SET_PARAMETERS = _merge_set_params(mongot_set_parameters) @@ -449,12 +507,19 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py _config.TLS_MODE = config.pop("tls_mode") _config.TLS_CA_FILE = config.pop("tls_ca_file") _config.SHELL_TLS_ENABLED = config.pop("shell_tls_enabled") - _config.SHELL_TLS_CERTIFICATE_KEY_FILE = config.pop("shell_tls_certificate_key_file") - _config.MONGOD_TLS_CERTIFICATE_KEY_FILE = config.pop("mongod_tls_certificate_key_file") - _config.MONGOS_TLS_CERTIFICATE_KEY_FILE = config.pop("mongos_tls_certificate_key_file") + _config.SHELL_TLS_CERTIFICATE_KEY_FILE = config.pop( + "shell_tls_certificate_key_file" + ) + _config.MONGOD_TLS_CERTIFICATE_KEY_FILE = config.pop( + "mongod_tls_certificate_key_file" + ) + _config.MONGOS_TLS_CERTIFICATE_KEY_FILE = config.pop( + "mongos_tls_certificate_key_file" + ) _config.NUM_SHARDS = config.pop("num_shards") _config.CONFIG_SHARD = utils.pick_catalog_shard_node( - config.pop("config_shard"), _config.NUM_SHARDS) + config.pop("config_shard"), _config.NUM_SHARDS + ) _config.EMBEDDED_ROUTER = config.pop("embedded_router") _config.ORIGIN_SUITE = config.pop("origin_suite") _config.CEDAR_REPORT_FILE = config.pop("cedar_report_file") @@ -542,7 +607,10 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py _config.ARCHIVE_FILE = None else: # Enable archival globally for all mainline variants. - if _config.EVERGREEN_VARIANT_NAME is not None and not _config.EVERGREEN_PATCH_BUILD: + if ( + _config.EVERGREEN_VARIANT_NAME is not None + and not _config.EVERGREEN_PATCH_BUILD + ): _config.FORCE_ARCHIVE_ALL_DATA_FILES = True _config.ARCHIVE_LIMIT_MB = config.pop("archive_limit_mb") @@ -575,9 +643,12 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py # Configure evergreen task documentation if _config.EVERGREEN_TASK_NAME: - task_name = utils.get_task_name_without_suffix(_config.EVERGREEN_TASK_NAME, - _config.EVERGREEN_VARIANT_NAME) - evg_task_doc_file = os.path.join(_config.CONFIG_DIR, "evg_task_doc", "evg_task_doc.yml") + task_name = utils.get_task_name_without_suffix( + _config.EVERGREEN_TASK_NAME, _config.EVERGREEN_VARIANT_NAME + ) + evg_task_doc_file = os.path.join( + _config.CONFIG_DIR, "evg_task_doc", "evg_task_doc.yml" + ) if os.path.exists(evg_task_doc_file): evg_task_doc = utils.load_yaml_file(evg_task_doc_file) if task_name in evg_task_doc: @@ -596,7 +667,9 @@ or explicitly pass --installDir to the run subcommand of buildscripts/resmoke.py # Treat `resmoke run @to_replay` as `resmoke run --replayFile to_replay` if len(test_files) == 1 and test_files[0].startswith("@"): to_replay = test_files[0][1:] - elif len(test_files) > 1 and any(test_file.startswith("@") for test_file in test_files): + elif len(test_files) > 1 and any( + test_file.startswith("@") for test_file in test_files + ): parser.error( "Cannot use @replay with additional test files listed on the command line invocation." ) @@ -652,7 +725,9 @@ def _set_logging_config(): if os.path.exists(pathname): logger_config = utils.load_yaml_file(pathname) _config.LOGGING_CONFIG = logger_config.pop("logging") - _config.SHORTEN_LOGGER_NAME_CONFIG = logger_config.pop("shorten_logger_name") + _config.SHORTEN_LOGGER_NAME_CONFIG = logger_config.pop( + "shorten_logger_name" + ) return root = os.path.abspath(_config.LOGGER_DIR) @@ -662,10 +737,14 @@ def _set_logging_config(): if ext in (".yml", ".yaml") and short_name == pathname: config_file = os.path.join(root, filename) if not os.path.isfile(config_file): - raise ValueError("Expected a logger YAML config, but got '%s'" % pathname) + raise ValueError( + "Expected a logger YAML config, but got '%s'" % pathname + ) logger_config = utils.load_yaml_file(config_file) _config.LOGGING_CONFIG = logger_config.pop("logging") - _config.SHORTEN_LOGGER_NAME_CONFIG = logger_config.pop("shorten_logger_name") + _config.SHORTEN_LOGGER_NAME_CONFIG = logger_config.pop( + "shorten_logger_name" + ) return raise ValueError("Unknown logger '%s'" % pathname) @@ -730,8 +809,9 @@ def add_otel_args(parser: argparse.ArgumentParser): ) -def detect_evergreen_config(parsed_args: argparse.Namespace, - expansions_file: str = "../expansions.yml"): +def detect_evergreen_config( + parsed_args: argparse.Namespace, expansions_file: str = "../expansions.yml" +): if not os.path.exists(expansions_file): return @@ -748,4 +828,6 @@ def detect_evergreen_config(parsed_args: argparse.Namespace, parsed_args.variant_name = expansions.get("build_variant", None) parsed_args.version_id = expansions.get("version_id", None) parsed_args.work_dir = expansions.get("workdir", None) - parsed_args.evg_project_config_path = expansions.get("evergreen_config_file_path", None) + parsed_args.evg_project_config_path = expansions.get( + "evergreen_config_file_path", None + ) diff --git a/buildscripts/resmokelib/core/network.py b/buildscripts/resmokelib/core/network.py index 1f77b396e05..830d25d4bb2 100644 --- a/buildscripts/resmokelib/core/network.py +++ b/buildscripts/resmokelib/core/network.py @@ -22,8 +22,10 @@ def _check_port(func): raise errors.PortAllocationError("Attempted to use a negative port") if port > PortAllocator.MAX_PORT: - raise errors.PortAllocationError("Exhausted all available ports. Consider decreasing" - " the number of jobs, or using a lower base port") + raise errors.PortAllocationError( + "Exhausted all available ports. Consider decreasing" + " the number of jobs, or using a lower base port" + ) return port @@ -73,8 +75,9 @@ class PortAllocator(object): if next_port >= start_port + cls._PORTS_PER_FIXTURE: raise errors.PortAllocationError( - "Fixture has requested more than the %d ports reserved per fixture" % - cls._PORTS_PER_FIXTURE) + "Fixture has requested more than the %d ports reserved per fixture" + % cls._PORTS_PER_FIXTURE + ) return next_port @@ -86,7 +89,9 @@ class PortAllocator(object): Raises a PortAllocationError if that port is higher than the maximum port. """ - return config.BASE_PORT + (job_num * cls._PORTS_PER_JOB) + cls._PORTS_PER_FIXTURE + return ( + config.BASE_PORT + (job_num * cls._PORTS_PER_JOB) + cls._PORTS_PER_FIXTURE + ) @classmethod @_check_port diff --git a/buildscripts/resmokelib/core/pipe.py b/buildscripts/resmokelib/core/pipe.py index e1d322d677d..fd00c1e3302 100644 --- a/buildscripts/resmokelib/core/pipe.py +++ b/buildscripts/resmokelib/core/pipe.py @@ -4,6 +4,7 @@ Helper class to read output of a subprocess. Used to avoid deadlocks from the pipe buffer filling up and blocking the subprocess while it's being waited on. """ + import threading from textwrap import wrap from typing import List diff --git a/buildscripts/resmokelib/core/process.py b/buildscripts/resmokelib/core/process.py index 3e7d1bbb774..1df3ae2e521 100644 --- a/buildscripts/resmokelib/core/process.py +++ b/buildscripts/resmokelib/core/process.py @@ -50,17 +50,20 @@ if sys.platform == "win32": job_object = win32job.CreateJobObject(None, "") # Get the limit and job state information of the newly-created job object. - job_info = win32job.QueryInformationJobObject(job_object, - win32job.JobObjectExtendedLimitInformation) + job_info = win32job.QueryInformationJobObject( + job_object, win32job.JobObjectExtendedLimitInformation + ) # Set up the job object so that closing the last handle to the job object # will terminate all associated processes and destroy the job object itself. - job_info["BasicLimitInformation"]["LimitFlags"] |= \ - win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + job_info["BasicLimitInformation"]["LimitFlags"] |= ( + win32job.JOB_OBJECT_LIMIT_KILL_ON_JOB_CLOSE + ) # Update the limits of the job object. - win32job.SetInformationJobObject(job_object, win32job.JobObjectExtendedLimitInformation, - job_info) + win32job.SetInformationJobObject( + job_object, win32job.JobObjectExtendedLimitInformation, job_info + ) return job_object @@ -89,19 +92,22 @@ class Process(object): self.args = args self.env = utils.default_if_none(env, os.environ.copy()) - if not self.env.get('RESMOKE_PARENT_PROCESS'): - self.env['RESMOKE_PARENT_PROCESS'] = os.environ.get('RESMOKE_PARENT_PROCESS', - str(os.getpid())) - if not self.env.get('RESMOKE_PARENT_CTIME'): - self.env['RESMOKE_PARENT_CTIME'] = os.environ.get('RESMOKE_PARENT_CTIME', - str(psutil.Process().create_time())) + if not self.env.get("RESMOKE_PARENT_PROCESS"): + self.env["RESMOKE_PARENT_PROCESS"] = os.environ.get( + "RESMOKE_PARENT_PROCESS", str(os.getpid()) + ) + if not self.env.get("RESMOKE_PARENT_CTIME"): + self.env["RESMOKE_PARENT_CTIME"] = os.environ.get( + "RESMOKE_PARENT_CTIME", str(psutil.Process().create_time()) + ) if env_vars is not None: self.env.update(env_vars) # If we are running against an External System Under Test & this is a `mongo{d,s}` process, we make this process a NOOP. # `mongo{d,s}` processes are not running locally for an External System Under Test. self.NOOP = _config.NOOP_MONGO_D_S_PROCESSES and os.path.basename( - self.args[0]) in ["mongod", "mongos"] + self.args[0] + ) in ["mongod", "mongos"] # The `pid` attribute is assigned after the local process is started. If this process is a NOOP, we assign it a dummy value. self.pid = 1 if self.NOOP else None @@ -133,42 +139,67 @@ class Process(object): # thread, or concurrently from multiple threads -- from causing another subprocess to wait # for the completion of the newly spawned child process. Closing other file descriptors # isn't supported on Windows when stdout and stderr are redirected. - close_fds = (sys.platform != "win32") + close_fds = sys.platform != "win32" with _POPEN_LOCK: - # Record unittests directly since resmoke doesn't not interact with them and they can finish # too quickly for the recorder to have a chance at attaching. recorder_args = [] - if _config.UNDO_RECORDER_PATH is not None and self.args[0].endswith("_test"): + if _config.UNDO_RECORDER_PATH is not None and self.args[0].endswith( + "_test" + ): now_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") # Only use the process name since we have to be able to correlate the recording name # with the binary name easily. recorder_output_file = "{process}-{t}.undo".format( - process=os.path.basename(self.args[0]), t=now_str) + process=os.path.basename(self.args[0]), t=now_str + ) recorder_args = [_config.UNDO_RECORDER_PATH, "-o", recorder_output_file] - self._process = subprocess.Popen(recorder_args + self.args, bufsize=buffer_size, - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - close_fds=close_fds, env=self.env, - creationflags=creation_flags, cwd=self._cwd) + self._process = subprocess.Popen( + recorder_args + self.args, + bufsize=buffer_size, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + close_fds=close_fds, + env=self.env, + creationflags=creation_flags, + cwd=self._cwd, + ) self.pid = self._process.pid - if _config.UNDO_RECORDER_PATH is not None and (not self.args[0].endswith("_test")) and ( - "mongod" in self.args[0] or "mongos" in self.args[0]): + if ( + _config.UNDO_RECORDER_PATH is not None + and (not self.args[0].endswith("_test")) + and ("mongod" in self.args[0] or "mongos" in self.args[0]) + ): now_str = datetime.now().strftime("%Y-%m-%dT%H-%M-%S") recorder_output_file = "{logger}-{process}-{pid}-{t}.undo".format( - logger=self.logger.name.replace('/', '-'), - process=os.path.basename(self.args[0]), pid=self.pid, t=now_str) + logger=self.logger.name.replace("/", "-"), + process=os.path.basename(self.args[0]), + pid=self.pid, + t=now_str, + ) recorder_args = [ - _config.UNDO_RECORDER_PATH, "-p", - str(self.pid), "-o", recorder_output_file + _config.UNDO_RECORDER_PATH, + "-p", + str(self.pid), + "-o", + recorder_output_file, ] - self._recorder = subprocess.Popen(recorder_args, bufsize=buffer_size, env=self.env, - creationflags=creation_flags) + self._recorder = subprocess.Popen( + recorder_args, + bufsize=buffer_size, + env=self.env, + creationflags=creation_flags, + ) - self._stdout_pipe = pipe.LoggerPipe(self.logger, logging.INFO, self._process.stdout) - self._stderr_pipe = pipe.LoggerPipe(self.logger, logging.ERROR, self._process.stderr) + self._stdout_pipe = pipe.LoggerPipe( + self.logger, logging.INFO, self._process.stdout + ) + self._stderr_pipe = pipe.LoggerPipe( + self.logger, logging.ERROR, self._process.stderr + ) self._stdout_pipe.wait_until_started() self._stderr_pipe.wait_until_started() @@ -193,8 +224,11 @@ class Process(object): mode = fixture_interface.TeardownMode.TERMINATE if sys.platform == "win32": - if mode != fixture_interface.TeardownMode.KILL and self.args and self.args[0].find( - "mongod") != -1: + if ( + mode != fixture_interface.TeardownMode.KILL + and self.args + and self.args[0].find("mongod") != -1 + ): self._request_clean_shutdown_on_windows() else: self._terminate_on_windows() @@ -207,8 +241,10 @@ class Process(object): elif mode == fixture_interface.TeardownMode.ABORT: self._process.send_signal(mode.value) else: - raise errors.ProcessError("Process wrapper given unrecognized teardown mode: " + - mode.value) + raise errors.ProcessError( + "Process wrapper given unrecognized teardown mode: " + + mode.value + ) except OSError as err: # ESRCH (errno=3) is received when the process has already died. @@ -230,7 +266,9 @@ class Process(object): status = None try: # Wait 60 seconds for the program to exit. - status = win32event.WaitForSingleObject(self._process._handle, 60 * 1000) + status = win32event.WaitForSingleObject( + self._process._handle, 60 * 1000 + ) except win32process.error as err: # ERROR_FILE_NOT_FOUND (winerror=2) # ERROR_ACCESS_DENIED (winerror=5) @@ -243,18 +281,22 @@ class Process(object): if status is not None and status != win32event.WAIT_OBJECT_0: self.logger.info( f"Failed to cleanly exit the program, calling TerminateProcess() on PID:" - f" {str(self._process.pid)}") + f" {str(self._process.pid)}" + ) self._terminate_on_windows() return_code = self._process.wait(timeout) if self._recorder is not None: - self.logger.info('Saving the UndoDB recording; it may take a few minutes...') + self.logger.info( + "Saving the UndoDB recording; it may take a few minutes..." + ) recorder_return = self._recorder.wait(timeout) if recorder_return != 0: raise errors.ServerFailure( "UndoDB live-record did not terminate correctly. This is likely a bug with UndoDB. " - "Please record the logs and notify the #server-testing Slack channel") + "Please record the logs and notify the #server-testing Slack channel" + ) if self._stdout_pipe: self._stdout_pipe.wait_until_finished() @@ -312,7 +354,10 @@ class Process(object): _windows_mongo_signal_handle = None try: _windows_mongo_signal_handle = win32event.OpenEvent( - win32event.EVENT_MODIFY_STATE, False, "Global\\Mongo_" + str(self._process.pid)) + win32event.EVENT_MODIFY_STATE, + False, + "Global\\Mongo_" + str(self._process.pid), + ) if not _windows_mongo_signal_handle: # The process has already died. diff --git a/buildscripts/resmokelib/core/programs.py b/buildscripts/resmokelib/core/programs.py index 4717b7f309b..a9b0cf08834 100644 --- a/buildscripts/resmokelib/core/programs.py +++ b/buildscripts/resmokelib/core/programs.py @@ -57,8 +57,9 @@ def get_binary_version(executable): return LATEST_FCV -def remove_set_parameter_if_before_version(set_parameters, parameter_name, bin_version, - required_bin_version): +def remove_set_parameter_if_before_version( + set_parameters, parameter_name, bin_version, required_bin_version +): """ Used for removing a server parameter that does not exist prior to a specified version. @@ -88,13 +89,15 @@ def mongod_program(logger, job_num, executable, process_kwargs, mongod_options): args[0] = os.path.basename(args[0]) mongod_options["set_parameters"]["fassertOnLockTimeoutForStepUpDown"] = 0 mongod_options["set_parameters"].pop("backtraceLogFile", None) - mongod_options.update({ - "logpath": "/var/log/mongodb/mongodb.log", - "dbpath": "/data/db", - "bind_ip": "0.0.0.0", - "oplogSize": "256", - "wiredTigerCacheSizeGB": "1", - }) + mongod_options.update( + { + "logpath": "/var/log/mongodb/mongodb.log", + "dbpath": "/data/db", + "bind_ip": "0.0.0.0", + "oplogSize": "256", + "wiredTigerCacheSizeGB": "1", + } + ) if config.TLS_MODE: mongod_options["tlsMode"] = config.TLS_MODE @@ -115,20 +118,30 @@ def mongod_program(logger, job_num, executable, process_kwargs, mongod_options): suite_set_parameters = mongod_options.get("set_parameters", {}) remove_set_parameter_if_before_version( - suite_set_parameters, "queryAnalysisSamplerConfigurationRefreshSecs", bin_version, "7.0.0") - remove_set_parameter_if_before_version(suite_set_parameters, "queryAnalysisWriterIntervalSecs", - bin_version, "7.0.0") - remove_set_parameter_if_before_version(suite_set_parameters, "defaultConfigCommandTimeoutMS", - bin_version, "7.3.0") - - - remove_set_parameter_if_before_version(suite_set_parameters, "internalQueryStatsRateLimit", - bin_version, "7.3.0") + suite_set_parameters, + "queryAnalysisSamplerConfigurationRefreshSecs", + bin_version, + "7.0.0", + ) remove_set_parameter_if_before_version( - suite_set_parameters, "internalQueryStatsErrorsAreCommandFatal", bin_version, "7.3.0") - remove_set_parameter_if_before_version(suite_set_parameters, "enableAutoCompaction", - bin_version, "7.3.0") + suite_set_parameters, "queryAnalysisWriterIntervalSecs", bin_version, "7.0.0" + ) + remove_set_parameter_if_before_version( + suite_set_parameters, "defaultConfigCommandTimeoutMS", bin_version, "7.3.0" + ) + remove_set_parameter_if_before_version( + suite_set_parameters, "internalQueryStatsRateLimit", bin_version, "7.3.0" + ) + remove_set_parameter_if_before_version( + suite_set_parameters, + "internalQueryStatsErrorsAreCommandFatal", + bin_version, + "7.3.0", + ) + remove_set_parameter_if_before_version( + suite_set_parameters, "enableAutoCompaction", bin_version, "7.3.0" + ) _apply_set_parameters(args, suite_set_parameters) final_mongod_options = mongod_options.copy() @@ -148,7 +161,9 @@ def mongod_program(logger, job_num, executable, process_kwargs, mongod_options): return make_process(logger, args, **process_kwargs), final_mongod_options -def mongos_program(logger, job_num, executable=None, process_kwargs=None, mongos_options=None): +def mongos_program( + logger, job_num, executable=None, process_kwargs=None, mongos_options=None +): """Return a Process instance that starts a mongos with arguments constructed from 'kwargs'.""" bin_version = get_binary_version(executable) args = [executable] @@ -158,7 +173,9 @@ def mongos_program(logger, job_num, executable=None, process_kwargs=None, mongos if config.NOOP_MONGO_D_S_PROCESSES: args[0] = os.path.basename(args[0]) mongos_options["set_parameters"]["fassertOnLockTimeoutForStepUpDown"] = 0 - mongos_options.update({"logpath": "/var/log/mongodb/mongodb.log", "bind_ip": "0.0.0.0"}) + mongos_options.update( + {"logpath": "/var/log/mongodb/mongodb.log", "bind_ip": "0.0.0.0"} + ) if config.TLS_MODE: mongos_options["tlsMode"] = config.TLS_MODE @@ -176,16 +193,24 @@ def mongos_program(logger, job_num, executable=None, process_kwargs=None, mongos suite_set_parameters = mongos_options.get("set_parameters", {}) remove_set_parameter_if_before_version( - suite_set_parameters, "queryAnalysisSamplerConfigurationRefreshSecs", bin_version, "7.0.0") - remove_set_parameter_if_before_version(suite_set_parameters, "defaultConfigCommandTimeoutMS", - bin_version, "7.3.0") - - - remove_set_parameter_if_before_version(suite_set_parameters, "internalQueryStatsRateLimit", - bin_version, "7.3.0") + suite_set_parameters, + "queryAnalysisSamplerConfigurationRefreshSecs", + bin_version, + "7.0.0", + ) remove_set_parameter_if_before_version( - suite_set_parameters, "internalQueryStatsErrorsAreCommandFatal", bin_version, "7.3.0") + suite_set_parameters, "defaultConfigCommandTimeoutMS", bin_version, "7.3.0" + ) + remove_set_parameter_if_before_version( + suite_set_parameters, "internalQueryStatsRateLimit", bin_version, "7.3.0" + ) + remove_set_parameter_if_before_version( + suite_set_parameters, + "internalQueryStatsErrorsAreCommandFatal", + bin_version, + "7.3.0", + ) _apply_set_parameters(args, suite_set_parameters) final_mongos_options = mongos_options.copy() @@ -201,7 +226,9 @@ def mongos_program(logger, job_num, executable=None, process_kwargs=None, mongos return make_process(logger, args, **process_kwargs), final_mongos_options -def mongot_program(logger, job_num, executable=None, process_kwargs=None, mongot_options=None): +def mongot_program( + logger, job_num, executable=None, process_kwargs=None, mongot_options=None +): """Return a Process instance that starts a mongot.""" args = [executable] mongot_options = mongot_options.copy() @@ -212,8 +239,15 @@ def mongot_program(logger, job_num, executable=None, process_kwargs=None, mongot return make_process(logger, args, **process_kwargs), final_mongot_options -def mongo_shell_program(logger, executable=None, connection_string=None, filename=None, - test_filename=None, process_kwargs=None, **kwargs): +def mongo_shell_program( + logger, + executable=None, + connection_string=None, + filename=None, + test_filename=None, + process_kwargs=None, + **kwargs, +): """Return a Process instance that starts a mongo shell. The shell is started with the given connection string and arguments constructed from 'kwargs'. @@ -224,7 +258,9 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam """ executable = utils.default_if_none( - utils.default_if_none(executable, config.MONGO_EXECUTABLE), config.DEFAULT_MONGO_EXECUTABLE) + utils.default_if_none(executable, config.MONGO_EXECUTABLE), + config.DEFAULT_MONGO_EXECUTABLE, + ) args = [executable] eval_sb = [] # String builder. @@ -271,8 +307,9 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam test_data["shellTlsEnabled"] = True if config.SHELL_TLS_CERTIFICATE_KEY_FILE: - test_data["shellTlsCertificateKeyFile"] = config.SHELL_TLS_CERTIFICATE_KEY_FILE - + test_data["shellTlsCertificateKeyFile"] = ( + config.SHELL_TLS_CERTIFICATE_KEY_FILE + ) if config.TLS_CA_FILE: test_data["tlsCAFile"] = config.TLS_CA_FILE @@ -281,10 +318,14 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam test_data["tlsMode"] = config.TLS_MODE if config.MONGOD_TLS_CERTIFICATE_KEY_FILE: - test_data["mongodTlsCertificateKeyFile"] = config.MONGOD_TLS_CERTIFICATE_KEY_FILE + test_data["mongodTlsCertificateKeyFile"] = ( + config.MONGOD_TLS_CERTIFICATE_KEY_FILE + ) if config.MONGOS_TLS_CERTIFICATE_KEY_FILE: - test_data["mongosTlsCertificateKeyFile"] = config.MONGOS_TLS_CERTIFICATE_KEY_FILE + test_data["mongosTlsCertificateKeyFile"] = ( + config.MONGOS_TLS_CERTIFICATE_KEY_FILE + ) global_vars["TestData"] = test_data @@ -323,7 +364,9 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam # Propagate additional setParameters to mongocryptd processes spawned by the mongo shell. # Command line options to resmoke.py override the YAML configuration. if config.MONGOCRYPTD_SET_PARAMETERS is not None: - mongocryptd_set_parameters.update(utils.load_yaml(config.MONGOCRYPTD_SET_PARAMETERS)) + mongocryptd_set_parameters.update( + utils.load_yaml(config.MONGOCRYPTD_SET_PARAMETERS) + ) mongocryptd_set_parameters.update(feature_flag_dict) if config.MONGO_SET_PARAMETERS is not None: @@ -335,18 +378,24 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam # If the 'logComponentVerbosity' setParameter for mongod was not already specified, we set its # value to a default. mongod_set_parameters.setdefault( - "logComponentVerbosity", mongod_launcher.get_default_log_component_verbosity_for_mongod()) + "logComponentVerbosity", + mongod_launcher.get_default_log_component_verbosity_for_mongod(), + ) # If the 'enableFlowControl' setParameter for mongod was not already specified, we set its value # to a default. if config.FLOW_CONTROL is not None: - mongod_set_parameters.setdefault("enableFlowControl", config.FLOW_CONTROL == "on") + mongod_set_parameters.setdefault( + "enableFlowControl", config.FLOW_CONTROL == "on" + ) mongos_launcher = shardedcluster.MongosLauncher(fixturelib) # If the 'logComponentVerbosity' setParameter for mongos was not already specified, we set its # value to a default. - mongos_set_parameters.setdefault("logComponentVerbosity", - mongos_launcher.default_mongos_log_component_verbosity()) + mongos_set_parameters.setdefault( + "logComponentVerbosity", + mongos_launcher.default_mongos_log_component_verbosity(), + ) test_data["setParameters"] = mongod_set_parameters test_data["setParametersMongos"] = mongos_set_parameters @@ -378,7 +427,10 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam test_data["connectionString"] = connection_string connection_string = None - if config.FUZZ_MONGOD_CONFIGS is not None and config.FUZZ_MONGOD_CONFIGS is not False: + if ( + config.FUZZ_MONGOD_CONFIGS is not None + and config.FUZZ_MONGOD_CONFIGS is not False + ): test_data["fuzzMongodConfigs"] = True for var_name in global_vars: @@ -388,37 +440,50 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam eval_sb.append(str(kwargs.pop("eval"))) # Load a callback to check that the cluster-wide metadata is consistent. - eval_sb.append('await import("jstests/libs/override_methods/check_metadata_consistency.js")') + eval_sb.append( + 'await import("jstests/libs/override_methods/check_metadata_consistency.js")' + ) # Load this file to allow a callback to validate collections before shutting down mongod. eval_sb.append( - 'await import("jstests/libs/override_methods/validate_collections_on_shutdown.js")') + 'await import("jstests/libs/override_methods/validate_collections_on_shutdown.js")' + ) # Load a callback to check UUID consistency before shutting down a ShardingTest. eval_sb.append( - 'await import("jstests/libs/override_methods/check_uuids_consistent_across_cluster.js")') + 'await import("jstests/libs/override_methods/check_uuids_consistent_across_cluster.js")' + ) # Load a callback to check index consistency before shutting down a ShardingTest. eval_sb.append( - 'await import("jstests/libs/override_methods/check_indexes_consistent_across_cluster.js")') + 'await import("jstests/libs/override_methods/check_indexes_consistent_across_cluster.js")' + ) # Load a callback to check that all orphans are deleted before shutting down a ShardingTest. - eval_sb.append('await import("jstests/libs/override_methods/check_orphans_are_deleted.js")') + eval_sb.append( + 'await import("jstests/libs/override_methods/check_orphans_are_deleted.js")' + ) # Load a callback to check that the info stored in config.collections and config.chunks is # semantically correct before shutting down a ShardingTest. eval_sb.append( - 'await import("jstests/libs/override_methods/check_routing_table_consistency.js")') + 'await import("jstests/libs/override_methods/check_routing_table_consistency.js")' + ) # Load a callback to check that all shards have correct filtering information before shutting # down a ShardingTest. eval_sb.append( - 'await import("jstests/libs/override_methods/check_shard_filtering_metadata.js")') + 'await import("jstests/libs/override_methods/check_shard_filtering_metadata.js")' + ) - if config.FUZZ_MONGOD_CONFIGS is not None and config.FUZZ_MONGOD_CONFIGS is not False: + if ( + config.FUZZ_MONGOD_CONFIGS is not None + and config.FUZZ_MONGOD_CONFIGS is not False + ): # Prevent commands from running with the config fuzzer. eval_sb.append( - 'await import("jstests/libs/override_methods/config_fuzzer_incompatible_commands.js")') + 'await import("jstests/libs/override_methods/config_fuzzer_incompatible_commands.js")' + ) # Load this file to retry operations that fail due to in-progress background operations. eval_sb.append( @@ -426,7 +491,7 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam ) eval_sb.append( - "(function() { Timestamp.prototype.toString = function() { throw new Error(\"Cannot toString timestamps. Consider using timestampCmp() for comparison or tojson() for output.\"); } })()" + '(function() { Timestamp.prototype.toString = function() { throw new Error("Cannot toString timestamps. Consider using timestampCmp() for comparison or tojson() for output."); } })()' ) eval_str = "; ".join(eval_sb) @@ -440,7 +505,6 @@ def mongo_shell_program(logger, executable=None, connection_string=None, filenam if config.SHELL_TLS_CERTIFICATE_KEY_FILE: kwargs["tlsCertificateKeyFile"] = config.SHELL_TLS_CERTIFICATE_KEY_FILE - if connection_string is not None: # The --host and --port options are ignored by the mongo shell when an explicit connection # string is specified. We remove these options to avoid any ambiguity with what server the @@ -483,7 +547,7 @@ def _format_shell_vars(sb, paths, value): # Convert the list ["a", "b", "c"] into the string 'a["b"]["c"]' def bracketize(lst): - return lst[0] + ''.join(f'["{i}"]' for i in lst[1:]) + return lst[0] + "".join(f'["{i}"]' for i in lst[1:]) # Only need to do special handling for JSON objects. if not isinstance(value, (dict, HistoryDict)): @@ -511,7 +575,7 @@ def dbtest_program(logger, executable=None, suites=None, process_kwargs=None, ** kwargs["storageEngine"] = config.STORAGE_ENGINE if config.FLOW_CONTROL is not None: - kwargs["flowControl"] = (config.FLOW_CONTROL == "on") + kwargs["flowControl"] = config.FLOW_CONTROL == "on" return generic_program(logger, args, process_kwargs=process_kwargs, **kwargs) diff --git a/buildscripts/resmokelib/discovery/__init__.py b/buildscripts/resmokelib/discovery/__init__.py index f6619f27feb..df19e2241cb 100644 --- a/buildscripts/resmokelib/discovery/__init__.py +++ b/buildscripts/resmokelib/discovery/__init__.py @@ -1,4 +1,5 @@ """Subcommands for test discovery.""" + from typing import List, Optional import yaml @@ -92,12 +93,14 @@ class DiscoveryPlugin(PluginInterface): :param subparsers: argparse subparsers """ - parser = subparsers.add_parser(TEST_DISCOVERY_SUBCOMMAND, - help="Discover what tests are run by a suite.") + parser = subparsers.add_parser( + TEST_DISCOVERY_SUBCOMMAND, help="Discover what tests are run by a suite." + ) parser.add_argument("--suite", metavar="SUITE", help="Suite to run against.") - parser = subparsers.add_parser(SUITECONFIG_SUBCOMMAND, - help="Display configuration of a test suite.") + parser = subparsers.add_parser( + SUITECONFIG_SUBCOMMAND, help="Display configuration of a test suite." + ) parser.add_argument("--suite", metavar="SUITE", help="Suite to run against.") def parse(self, subcommand, parser, parsed_args, **kwargs) -> Optional[Subcommand]: diff --git a/buildscripts/resmokelib/errors.py b/buildscripts/resmokelib/errors.py index 051a14a4df8..b7bf3827289 100644 --- a/buildscripts/resmokelib/errors.py +++ b/buildscripts/resmokelib/errors.py @@ -3,31 +3,37 @@ class ResmokeError(Exception): # noqa: D204 """Base class for all resmoke.py exceptions.""" + pass class SuiteNotFound(ResmokeError): # noqa: D204 """A suite that isn't recognized was specified.""" + pass class DuplicateSuiteDefinition(ResmokeError): # noqa: D204 """A suite name with multiple definitions.""" + pass class StopExecution(ResmokeError): # noqa: D204 """Exception raised when resmoke.py should stop executing tests if failing fast is enabled.""" + pass class UserInterrupt(StopExecution): # noqa: D204 """Exception raised when a user signals resmoke.py to unconditionally stop executing tests.""" + EXIT_CODE = 130 # Simulate SIGINT as exit code. class LoggerRuntimeConfigError(StopExecution): # noqa: D204 """Exception raised when a logging handler couldn't be configured at runtime.""" + EXIT_CODE = 75 @@ -36,6 +42,7 @@ class TestFailure(ResmokeError): # noqa: D204 Raised if it determines the the previous test should be marked as a failure. """ + pass @@ -45,6 +52,7 @@ class ServerFailure(TestFailure): # noqa: D204 Raised if it detects that the fixture did not exit cleanly and should be marked as a failure. """ + pass @@ -54,6 +62,7 @@ class PortAllocationError(ResmokeError): # noqa: D204 Raised if a port is requested outside of the range of valid ports, or if a fixture requests more ports than were reserved for that job. """ + pass diff --git a/buildscripts/resmokelib/generate_fcv_constants/__init__.py b/buildscripts/resmokelib/generate_fcv_constants/__init__.py index 1edaeb605bc..67f7e6aa2c1 100644 --- a/buildscripts/resmokelib/generate_fcv_constants/__init__.py +++ b/buildscripts/resmokelib/generate_fcv_constants/__init__.py @@ -1,4 +1,5 @@ """Generate FCV constants for consumption by non-C++ integration tests.""" + import argparse from buildscripts.resmokelib import configure_resmoke, logging @@ -28,6 +29,7 @@ class GenerateFCVConstants(Subcommand): self._setup_logging() import buildscripts.resmokelib.multiversionconstants + buildscripts.resmokelib.multiversionconstants.log_constants(self._exec_logger) diff --git a/buildscripts/resmokelib/generate_fuzz_config/__init__.py b/buildscripts/resmokelib/generate_fuzz_config/__init__.py index 8619f9b2bb0..bc46984a570 100644 --- a/buildscripts/resmokelib/generate_fuzz_config/__init__.py +++ b/buildscripts/resmokelib/generate_fuzz_config/__init__.py @@ -1,4 +1,5 @@ """Generate mongod.conf and mongos.conf using config fuzzer.""" + import json import os.path import shutil @@ -27,10 +28,11 @@ class GenerateFuzzConfig(Subcommand): filename = "mongod.conf" output_file = os.path.join(self._output_path, filename) user_param = utils.dump_yaml({}) - set_parameters, wt_engine_config, wt_coll_config, \ - wt_index_config = mongo_fuzzer_configs.fuzz_mongod_set_parameters(self._mongod_mode, - self._seed, - user_param) + set_parameters, wt_engine_config, wt_coll_config, wt_index_config = ( + mongo_fuzzer_configs.fuzz_mongod_set_parameters( + self._mongod_mode, self._seed, user_param + ) + ) set_parameters = utils.load_yaml(set_parameters) set_parameters["mirrorReads"] = json.dumps(set_parameters["mirrorReads"]) # This is moved from Jepsen mongod.conf to have only one setParameter key value pair. @@ -41,13 +43,14 @@ class GenerateFuzzConfig(Subcommand): set_parameters["numInitialSyncAttempts"] = 20 set_parameters["testingDiagnosticsEnabled"] = True conf = { - "setParameter": set_parameters, "storage": { + "setParameter": set_parameters, + "storage": { "wiredTiger": { "engineConfig": {"configString": wt_engine_config}, "collectionConfig": {"configString": wt_coll_config}, - "indexConfig": {"configString": wt_index_config} + "indexConfig": {"configString": wt_index_config}, } - } + }, } if self._template_path is not None: try: @@ -56,7 +59,7 @@ class GenerateFuzzConfig(Subcommand): pass fuzz_config = utils.dump_yaml(conf) - with open(output_file, 'a') as file: + with open(output_file, "a") as file: file.write(fuzz_config) file.write("\n") @@ -65,7 +68,8 @@ class GenerateFuzzConfig(Subcommand): output_file = os.path.join(self._output_path, filename) user_param = utils.dump_yaml({}) set_parameters = mongo_fuzzer_configs.fuzz_mongos_set_parameters( - self._mongos_mode, self._seed, user_param) + self._mongos_mode, self._seed, user_param + ) set_parameters = utils.load_yaml(set_parameters) conf = {"setParameter": set_parameters} if self._template_path is not None: @@ -74,11 +78,13 @@ class GenerateFuzzConfig(Subcommand): except shutil.SameFileError: pass except FileNotFoundError: - print("There is no mongos template in the path, skip generating mongos.conf.") + print( + "There is no mongos template in the path, skip generating mongos.conf." + ) return fuzz_config = utils.dump_yaml(conf) - with open(output_file, 'a') as file: + with open(output_file, "a") as file: file.write(fuzz_config) file.write("\n") @@ -103,21 +109,38 @@ class GenerateFuzzConfigPlugin(PluginInterface): :return: None """ parser = subparsers.add_parser(_COMMAND, help=_HELP) - parser.add_argument("--template", '-t', type=str, required=False, - help="Path to templates to append config-fuzzer-generated parameters.") - parser.add_argument("--output", '-o', type=str, required=True, - help="Path to the output file.") parser.add_argument( - "--fuzzMongodConfigs", dest="fuzz_mongod_configs", + "--template", + "-t", + type=str, + required=False, + help="Path to templates to append config-fuzzer-generated parameters.", + ) + parser.add_argument( + "--output", "-o", type=str, required=True, help="Path to the output file." + ) + parser.add_argument( + "--fuzzMongodConfigs", + dest="fuzz_mongod_configs", help="Randomly chooses mongod parameters that were not specified. Use 'stress' to fuzz " "all configs including stressful storage configurations that may significantly " "slow down the server. Use 'normal' to only fuzz non-stressful configurations. ", - metavar="MODE", choices=('normal', 'stress')) - parser.add_argument("--fuzzMongosConfigs", dest="fuzz_mongos_configs", - help="Randomly chooses mongos parameters that were not specified", - metavar="MODE", choices=('normal', )) - parser.add_argument("--configFuzzSeed", dest="config_fuzz_seed", metavar="PATH", - help="Sets the seed used by mongod and mongos config fuzzers") + metavar="MODE", + choices=("normal", "stress"), + ) + parser.add_argument( + "--fuzzMongosConfigs", + dest="fuzz_mongos_configs", + help="Randomly chooses mongos parameters that were not specified", + metavar="MODE", + choices=("normal",), + ) + parser.add_argument( + "--configFuzzSeed", + dest="config_fuzz_seed", + metavar="PATH", + help="Sets the seed used by mongod and mongos config fuzzers", + ) def parse(self, subcommand, parser, parsed_args, **kwargs): """ @@ -133,6 +156,10 @@ class GenerateFuzzConfigPlugin(PluginInterface): if subcommand != _COMMAND: return None - return GenerateFuzzConfig(parsed_args.template, parsed_args.output, - parsed_args.fuzz_mongod_configs, parsed_args.fuzz_mongos_configs, - parsed_args.config_fuzz_seed) + return GenerateFuzzConfig( + parsed_args.template, + parsed_args.output, + parsed_args.fuzz_mongod_configs, + parsed_args.fuzz_mongos_configs, + parsed_args.config_fuzz_seed, + ) diff --git a/buildscripts/resmokelib/hang_analyzer/attach_core_analyzer_task.py b/buildscripts/resmokelib/hang_analyzer/attach_core_analyzer_task.py index 091692e4c2e..a6d549da09b 100644 --- a/buildscripts/resmokelib/hang_analyzer/attach_core_analyzer_task.py +++ b/buildscripts/resmokelib/hang_analyzer/attach_core_analyzer_task.py @@ -18,18 +18,23 @@ from buildscripts.resmokelib.utils import evergreen_conn from buildscripts.util.read_config import read_config_file -def matches_generated_task_pattern(original_task_name: str, - generated_task_name: str) -> Optional[str]: +def matches_generated_task_pattern( + original_task_name: str, generated_task_name: str +) -> Optional[str]: regex = re.match( f"{GENERATED_TASK_PREFIX}_{original_task_name}([0-9]{{1,2}})_[a-zA-Z0-9]{{{RANDOM_STRING_LENGTH}}}", - generated_task_name) + generated_task_name, + ) return regex.group(1) if regex else None -def maybe_attach_core_analyzer_task(expansions_file: str, conditional_file: str, - artifact_output_file: str, results_output_file: str): - +def maybe_attach_core_analyzer_task( + expansions_file: str, + conditional_file: str, + artifact_output_file: str, + results_output_file: str, +): # This script runs for every task even if the task is passing. # This statement checks to see if the file that was made to generate a core analyzer task # lives on the machine or not. @@ -48,10 +53,14 @@ def maybe_attach_core_analyzer_task(expansions_file: str, conditional_file: str, # If the task is a part of a display task, search the parent's execution tasks # If the task has no parent search the whole build variant parent_id = current_task.parent_task_id - search_tasks = evg_api.task_by_id(parent_id).execution_tasks if parent_id else build.tasks + search_tasks = ( + evg_api.task_by_id(parent_id).execution_tasks if parent_id else build.tasks + ) # The task id uses underscores instead of hyphens - task_id_search_term = f"{GENERATED_TASK_PREFIX}_{current_task_name.replace('-', '_')}" + task_id_search_term = ( + f"{GENERATED_TASK_PREFIX}_{current_task_name.replace('-', '_')}" + ) matching_task = None matching_execution = None @@ -78,13 +87,19 @@ def maybe_attach_core_analyzer_task(expansions_file: str, conditional_file: str, # Check if the core analysis is from the current execution or a previous one gen_from_cur_execution = current_task.execution == int(matching_execution) - artifact_name = "Core Analyzer Task" if gen_from_cur_execution else f"Core Analyzer Task (Previous Execution #{matching_execution})" + artifact_name = ( + "Core Analyzer Task" + if gen_from_cur_execution + else f"Core Analyzer Task (Previous Execution #{matching_execution})" + ) - core_analyzer_task_artifact = [{ - "name": artifact_name, - "link": core_analysis_task_url, - "visibility": "public", - }] + core_analyzer_task_artifact = [ + { + "name": artifact_name, + "link": core_analysis_task_url, + "visibility": "public", + } + ] with open(artifact_output_file, "w") as file: json.dump(core_analyzer_task_artifact, file, indent=4) @@ -106,27 +121,34 @@ def maybe_attach_core_analyzer_task(expansions_file: str, conditional_file: str, file.write("\n".join(file_lines)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--expansions-file", "-e", - help="Expansions file to read task info and aws credentials from.", - default="../expansions.yml") + parser.add_argument( + "--expansions-file", + "-e", + help="Expansions file to read task info and aws credentials from.", + default="../expansions.yml", + ) parser.add_argument( "--conditional-file", help="Path to file. If this file exists, that means task generation was successful.", - default="hang_analyzer_task.json") + default="hang_analyzer_task.json", + ) parser.add_argument( "--artifact-output-file", help="Name of output file to write artifact of the core analyzer task url to.", - default="core_analyzer_artifact.json") + default="core_analyzer_artifact.json", + ) parser.add_argument( "--results-output-file", help="Name of output file to write the temperary core analyzer result text.", - default="core_analyzer_results.txt") + default="core_analyzer_results.txt", + ) args = parser.parse_args() expansions_file = args.expansions_file conditional_file = args.conditional_file artifact_output_file = args.artifact_output_file results_output_file = args.results_output_file - maybe_attach_core_analyzer_task(expansions_file, conditional_file, artifact_output_file, - results_output_file) + maybe_attach_core_analyzer_task( + expansions_file, conditional_file, artifact_output_file, results_output_file + ) diff --git a/buildscripts/resmokelib/hang_analyzer/core_analyzer.py b/buildscripts/resmokelib/hang_analyzer/core_analyzer.py index 2d0a24a0390..41398ff894c 100644 --- a/buildscripts/resmokelib/hang_analyzer/core_analyzer.py +++ b/buildscripts/resmokelib/hang_analyzer/core_analyzer.py @@ -39,19 +39,30 @@ class CoreAnalyzer(Subcommand): if file.read().strip() == self.task_id: skip_download = True self.root_logger.info( - "Files from task id provided were already on disk, skipping download.") + "Files from task id provided were already on disk, skipping download." + ) multiversion_dir = os.path.join(base_dir, "multiversion") if not skip_download and not download_task_artifacts( - self.root_logger, self.task_id, base_dir, dumpers.dbg, multiversion_dir, - self.execution): + self.root_logger, + self.task_id, + base_dir, + dumpers.dbg, + multiversion_dir, + self.execution, + ): self.root_logger.error("Artifacts were not found.") - current_span.set_attributes({ - "core_analyzer_execute_error": "Artifacts were not found.", - }) - current_span.set_status(StatusCode.ERROR, description="Artifacts were not found.") + current_span.set_attributes( + { + "core_analyzer_execute_error": "Artifacts were not found.", + } + ) + current_span.set_status( + StatusCode.ERROR, description="Artifacts were not found." + ) raise RuntimeError( - "Artifacts were not found for specified task. Could not analyze cores.") + "Artifacts were not found for specified task. Could not analyze cores." + ) with open(task_id_file, "w") as file: file.write(self.task_id) @@ -59,14 +70,20 @@ class CoreAnalyzer(Subcommand): core_dump_dir = os.path.join(base_dir, "core-dumps") install_dir = os.path.join(base_dir, "install") else: # if a task id was not specified, look for input files on the current machine - install_dir = self.options.install_dir or os.path.join(os.path.curdir, "build", - "install") + install_dir = self.options.install_dir or os.path.join( + os.path.curdir, "build", "install" + ) core_dump_dir = self.options.core_dir or os.path.curdir multiversion_dir = self.options.multiversion_dir or os.path.curdir analysis_dir = os.path.join(base_dir, "analysis") - report = dumpers.dbg.analyze_cores(core_dump_dir, install_dir, analysis_dir, - multiversion_dir, self.gdb_index_cache) + report = dumpers.dbg.analyze_cores( + core_dump_dir, + install_dir, + analysis_dir, + multiversion_dir, + self.gdb_index_cache, + ) if self.options.generate_report: with open("report.json", "w") as file: @@ -86,10 +103,15 @@ class CoreAnalyzer(Subcommand): class CoreAnalyzerPlugin(PluginInterface): """Integration-point for core-analyzer.""" - def parse(self, subcommand: str, parser: argparse.ArgumentParser, - parsed_args: argparse.Namespace, **kwargs) -> Optional[Subcommand]: + def parse( + self, + subcommand: str, + parser: argparse.ArgumentParser, + parsed_args: argparse.Namespace, + **kwargs, + ) -> Optional[Subcommand]: """Parse command-line options.""" - if subcommand == 'core-analyzer': + if subcommand == "core-analyzer": configure_resmoke.detect_evergreen_config(parsed_args) configure_resmoke.validate_and_update_config(parser, parsed_args) return CoreAnalyzer(parsed_args) @@ -99,11 +121,13 @@ class CoreAnalyzerPlugin(PluginInterface): """Create and add the parser for the core analyzer subcommand.""" parser = subparsers.add_parser( - "core-analyzer", help="Analyzes the core dumps from the specified input files.") + "core-analyzer", + help="Analyzes the core dumps from the specified input files.", + ) parser.add_argument( "--task-id", - '-t', + "-t", action="store", type=str, default=None, @@ -112,27 +136,58 @@ class CoreAnalyzerPlugin(PluginInterface): ) parser.add_argument( - "--execution", '-e', action="store", type=int, default=None, + "--execution", + "-e", + action="store", + type=int, + default=None, help="The execution of the task you want to download core dumps for." - " This will default to the latest execution if left blank.") - - parser.add_argument("--install-dir", '-b', action="store", type=str, default=None, - help="Directory that contains binaires and debugsymbols.") - - parser.add_argument("--multiversion-dir", '-m', action="store", type=str, default=None, - help="Directory that contains multiversion binaries and debugsymbols.") - - parser.add_argument("--core-dir", '-c', action="store", type=str, default=None, - help="Directory that contains core dumps.") - - parser.add_argument( - "--working-dir", '-w', action="store", type=str, default="core-analyzer", - help="Directory that downloaded artifacts will be stored and output will be written to." + " This will default to the latest execution if left blank.", ) parser.add_argument( - "--generate-report", '-r', action="store_true", default=False, - help="Whether to generate a report used to log individual tests in evergreen.") + "--install-dir", + "-b", + action="store", + type=str, + default=None, + help="Directory that contains binaires and debugsymbols.", + ) + + parser.add_argument( + "--multiversion-dir", + "-m", + action="store", + type=str, + default=None, + help="Directory that contains multiversion binaries and debugsymbols.", + ) + + parser.add_argument( + "--core-dir", + "-c", + action="store", + type=str, + default=None, + help="Directory that contains core dumps.", + ) + + parser.add_argument( + "--working-dir", + "-w", + action="store", + type=str, + default="core-analyzer", + help="Directory that downloaded artifacts will be stored and output will be written to.", + ) + + parser.add_argument( + "--generate-report", + "-r", + action="store_true", + default=False, + help="Whether to generate a report used to log individual tests in evergreen.", + ) parser.add_argument( "--gdb-index-cache", diff --git a/buildscripts/resmokelib/hang_analyzer/dumper.py b/buildscripts/resmokelib/hang_analyzer/dumper.py index d1e9bfef2a2..642d302a8fb 100644 --- a/buildscripts/resmokelib/hang_analyzer/dumper.py +++ b/buildscripts/resmokelib/hang_analyzer/dumper.py @@ -25,7 +25,7 @@ from buildscripts.resmokelib.hang_analyzer.process_list import Pinfo from buildscripts.resmokelib.utils.otel_utils import get_default_current_span from buildscripts.simple_report import Report, Result -Dumpers = namedtuple('Dumpers', ['dbg', 'jstack']) +Dumpers = namedtuple("Dumpers", ["dbg", "jstack"]) TRACER = trace.get_tracer("resmoke") @@ -71,9 +71,9 @@ class Dumper(metaclass=ABCMeta): @abstractmethod def dump_info( - self, - pinfo: Pinfo, - take_dump: bool, + self, + pinfo: Pinfo, + take_dump: bool, ): """ Perform dump for a process. @@ -81,25 +81,35 @@ class Dumper(metaclass=ABCMeta): :param pinfo: A Pinfo describing the process :param take_dump: Whether to take a core dump """ - raise NotImplementedError("dump_info must be implemented in OS-specific subclasses") + raise NotImplementedError( + "dump_info must be implemented in OS-specific subclasses" + ) @abstractmethod def get_dump_ext(self): """Return the dump file extension.""" - raise NotImplementedError("get_dump_ext must be implemented in OS-specific subclasses") + raise NotImplementedError( + "get_dump_ext must be implemented in OS-specific subclasses" + ) @abstractmethod def _find_debugger(self): """Find the installed debugger.""" - raise NotImplementedError("_find_debugger must be implemented in OS-specific subclasses") + raise NotImplementedError( + "_find_debugger must be implemented in OS-specific subclasses" + ) @abstractmethod def _prefix(self): """Return the commands to set up a debugger process.""" - raise NotImplementedError("_prefix must be implemented in OS-specific subclasses") + raise NotImplementedError( + "_prefix must be implemented in OS-specific subclasses" + ) @abstractmethod - def _process_specific(self, pinfo: Pinfo, take_dump: bool, logger: logging.Logger = None): + def _process_specific( + self, pinfo: Pinfo, take_dump: bool, logger: logging.Logger = None + ): """ Return the commands that attach to each process, dump info and detach. @@ -107,7 +117,9 @@ class Dumper(metaclass=ABCMeta): :param take_dump: Whether to take a core dump :param logger: Logger to output dump info to """ - raise NotImplementedError("_process_specific must be implemented in OS-specific subclasses") + raise NotImplementedError( + "_process_specific must be implemented in OS-specific subclasses" + ) @abstractmethod def analyze_cores(self, core_file_dir: str, install_dir: str, analysis_dir: str): @@ -117,18 +129,23 @@ class Dumper(metaclass=ABCMeta): :param core_file_dir: Directory to be scanned for core dumps :param install_dir: Directory to be scanned for binaries and debugsymbols """ - raise NotImplementedError("analyze_cores must be implemented in OS-specific subclasses") + raise NotImplementedError( + "analyze_cores must be implemented in OS-specific subclasses" + ) @abstractmethod def _postfix(self): """Return the commands to exit the debugger.""" - raise NotImplementedError("_postfix must be implemented in OS-specific subclasses") + raise NotImplementedError( + "_postfix must be implemented in OS-specific subclasses" + ) @abstractmethod def get_binary_from_core_dump(self, core_file_path): """Return the name of the binary that created the input core dump.""" raise NotImplementedError( - "get_binary_from_core_dump must be implemented in OS-specific subclasses") + "get_binary_from_core_dump must be implemented in OS-specific subclasses" + ) class WindowsDumper(Dumper): @@ -151,10 +168,15 @@ class WindowsDumper(Dumper): root_dir = shell.SHGetFolderPath(0, shellcon.CSIDL_PROGRAM_FILESX86, None, 0) # Construct the debugger search paths in most-recent order - debugger_paths = [os.path.join(root_dir, "Windows Kits", "10", "Debuggers", "x64")] + debugger_paths = [ + os.path.join(root_dir, "Windows Kits", "10", "Debuggers", "x64") + ] for idx in reversed(range(0, 2)): debugger_paths.append( - os.path.join(root_dir, "Windows Kits", "8." + str(idx), "Debuggers", "x64")) + os.path.join( + root_dir, "Windows Kits", "8." + str(idx), "Debuggers", "x64" + ) + ) for dbg_path in debugger_paths: self._root_logger.info("Checking for debugger in %s", dbg_path) @@ -180,8 +202,11 @@ class WindowsDumper(Dumper): if take_dump: # Dump to file, dump_..mdmp - dump_file = "dump_%s.%d.%s" % (os.path.splitext(pinfo.name)[0], pinfo.pidv, - self.get_dump_ext()) + dump_file = "dump_%s.%d.%s" % ( + os.path.splitext(pinfo.name)[0], + pinfo.pidv, + self.get_dump_ext(), + ) dump_command = ".dump /ma %s" % dump_file self._root_logger.info("Dumping core to %s", dump_file) @@ -213,21 +238,33 @@ class WindowsDumper(Dumper): dbg = self._find_debugger() if dbg is None: - self._root_logger.warning("Debugger not found, skipping dumping of %s", str(pinfo.pidv)) + self._root_logger.warning( + "Debugger not found, skipping dumping of %s", str(pinfo.pidv) + ) return - self._root_logger.info("Debugger %s, analyzing %s processes with PIDs %s", dbg, pinfo.name, - str(pinfo.pidv)) + self._root_logger.info( + "Debugger %s, analyzing %s processes with PIDs %s", + dbg, + pinfo.name, + str(pinfo.pidv), + ) for pid in pinfo.pidv: logger = _get_process_logger(self._dbg_output, pinfo.name, pid=pid) process = Pinfo(name=pinfo.name, pidv=pid) - cmds = self._prefix() + self._process_specific(process, take_dump) + self._postfix() + cmds = ( + self._prefix() + + self._process_specific(process, take_dump) + + self._postfix() + ) - call([dbg, '-c', ";".join(cmds), '-p', str(pid)], logger) + call([dbg, "-c", ";".join(cmds), "-p", str(pid)], logger) - self._root_logger.info("Done analyzing %s process with PID %d", pinfo.name, pid) + self._root_logger.info( + "Done analyzing %s process with PID %d", pinfo.name, pid + ) def analyze_cores(self, core_file_dir: str, install_dir: str, analysis_dir: str): install_dir = os.path.abspath(install_dir) @@ -244,11 +281,13 @@ class WindowsDumper(Dumper): def analyze_core(self, core_file_path: str, install_dir: str): filename = os.path.basename(core_file_path) - regex = re.search(fr"dump_(.+)\.([0-9]+)\.{self.get_dump_ext()}", filename) + regex = re.search(rf"dump_(.+)\.([0-9]+)\.{self.get_dump_ext()}", filename) if not regex: self._root_logger.warning( - "Core dump file name does not match expected pattern, skipping %s", filename) + "Core dump file name does not match expected pattern, skipping %s", + filename, + ) return binary_name = f"{regex.group(1)}.exe" @@ -262,8 +301,12 @@ class WindowsDumper(Dumper): return if len(binary_files) > 1: - logger.error("More than one file found in %s matching %s", install_dir, binary_name) - raise RuntimeError(f"More than one file found in {install_dir} matching {binary_name}") + logger.error( + "More than one file found in %s matching %s", install_dir, binary_name + ) + raise RuntimeError( + f"More than one file found in {install_dir} matching {binary_name}" + ) binary_path = binary_files[0] symbol_path = binary_path.replace(".exe", ".pdb") @@ -271,26 +314,45 @@ class WindowsDumper(Dumper): dbg = self._find_debugger() if dbg is None: - self._root_logger.warning("Debugger not found, skipping dumping of %s", filename) + self._root_logger.warning( + "Debugger not found, skipping dumping of %s", filename + ) return - cmds = self._prefix() + [ - "!peb", # Dump current exe, & environment variables - "lm", # Dump loaded modules - "!uniqstack -pn", # Dump All unique Threads with function arguments - "!cs -l", # Dump all locked critical sections - ] + self._postfix() + cmds = ( + self._prefix() + + [ + "!peb", # Dump current exe, & environment variables + "lm", # Dump loaded modules + "!uniqstack -pn", # Dump All unique Threads with function arguments + "!cs -l", # Dump all locked critical sections + ] + + self._postfix() + ) call( - [dbg, "-i", binary_path, "-z", core_file_path, "-y", symbol_path, "-v", ";".join(cmds)], - logger) + [ + dbg, + "-i", + binary_path, + "-z", + core_file_path, + "-y", + symbol_path, + "-v", + ";".join(cmds), + ], + logger, + ) def get_dump_ext(self): """Return the dump file extension.""" return "mdmp" def get_binary_from_core_dump(self, core_file_path): - raise NotImplementedError("get_binary_from_core_dump is not implemented on windows") + raise NotImplementedError( + "get_binary_from_core_dump is not implemented on windows" + ) # LLDB dumper is for MacOS X @@ -301,7 +363,7 @@ class LLDBDumper(Dumper): def _find_debugger(): """Find the installed debugger.""" debugger = "lldb" - return find_program(debugger, ['/usr/bin']) + return find_program(debugger, ["/usr/bin"]) def _prefix(self): pass @@ -353,11 +415,17 @@ class LLDBDumper(Dumper): logger = _get_process_logger(self._dbg_output, pinfo.name) if dbg is None: - self._root_logger.warning("Debugger not found, skipping dumping of %s", str(pinfo.pidv)) + self._root_logger.warning( + "Debugger not found, skipping dumping of %s", str(pinfo.pidv) + ) return - self._root_logger.info("Debugger %s, analyzing %s processes with PIDs %s", dbg, pinfo.name, - str(pinfo.pidv)) + self._root_logger.info( + "Debugger %s, analyzing %s processes with PIDs %s", + dbg, + pinfo.name, + str(pinfo.pidv), + ) lldb_version = callo([dbg, "--version"], logger) @@ -368,18 +436,18 @@ class LLDBDumper(Dumper): # XCode (7.2): lldb-340.4.119 # LLVM - lldb version 3.7.0 ( revision ) - if 'version' not in lldb_version: + if "version" not in lldb_version: # We have XCode's lldb - lldb_version = lldb_version[lldb_version.index("lldb-"):] - lldb_version = lldb_version.replace('lldb-', '') - lldb_major_version = int(lldb_version[:lldb_version.index('.')]) + lldb_version = lldb_version[lldb_version.index("lldb-") :] + lldb_version = lldb_version.replace("lldb-", "") + lldb_major_version = int(lldb_version[: lldb_version.index(".")]) if lldb_major_version < 340: logger.warning("Debugger lldb is too old, please upgrade to XCode 7.2") return cmds = self._process_specific(pinfo, take_dump) + self._postfix() - tf = tempfile.NamedTemporaryFile(mode='w', encoding='utf-8') + tf = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8") for cmd in cmds: tf.write(cmd + "\n") @@ -387,12 +455,13 @@ class LLDBDumper(Dumper): tf.flush() # Works on in MacOS 10.9 & later - #call([dbg] + list( itertools.chain.from_iterable([['-o', b] for b in cmds])), logger) - call(['cat', tf.name], logger) - call([dbg, '--source', tf.name], logger) + # call([dbg] + list( itertools.chain.from_iterable([['-o', b] for b in cmds])), logger) + call(["cat", tf.name], logger) + call([dbg, "--source", tf.name], logger) - self._root_logger.info("Done analyzing %s processes with PIDs %s", pinfo.name, - str(pinfo.pidv)) + self._root_logger.info( + "Done analyzing %s processes with PIDs %s", pinfo.name, str(pinfo.pidv) + ) if take_dump: need_sigabrt = {} @@ -418,20 +487,26 @@ class LLDBDumper(Dumper): return files def get_binary_from_core_dump(self, core_file_path): - raise NotImplementedError("get_binary_from_core_dump is not implemented on macos") + raise NotImplementedError( + "get_binary_from_core_dump is not implemented on macos" + ) # GDB dumper is for Linux class GDBDumper(Dumper): """GDBDumper class.""" - def __init__(self, root_logger: logging.Logger, dbg_output: str, - timeout_seconds_for_gdb_process=720): + def __init__( + self, + root_logger: logging.Logger, + dbg_output: str, + timeout_seconds_for_gdb_process=720, + ): """Initialize GDBDumper.""" if resmoke_config.EVERGREEN_TASK_ID is None: # Set 24 hours time out for hang analyzer being run in locally timeout_seconds_for_gdb_process = 86400 - #Timeout for hang analyzer, default timeout is 12mins(out of total 15mins) in Evergreen + # Timeout for hang analyzer, default timeout is 12mins(out of total 15mins) in Evergreen self._timeout_seconds_for_gdb_process = timeout_seconds_for_gdb_process super().__init__(root_logger, dbg_output) @@ -446,12 +521,14 @@ class GDBDumper(Dumper): def _find_debugger(self): """Find the installed debugger.""" debugger = "gdb" - return find_program(debugger, ['/opt/mongodbtoolchain/v4/bin', '/usr/bin']) + return find_program(debugger, ["/opt/mongodbtoolchain/v4/bin", "/usr/bin"]) def _prefix(self): """Return the commands to set up a debugger process.""" - add_venv_sys_path = f"py sys.path.extend({sys.path})" # Makes venv packages available in GDB + add_venv_sys_path = ( + f"py sys.path.extend({sys.path})" # Makes venv packages available in GDB + ) cmds = [ "set interactive-mode off", @@ -494,47 +571,58 @@ class GDBDumper(Dumper): if logger: base, ext = os.path.splitext(filename) set_logging_on_commands = [ - 'set logging file %s_%d%s' % (base, pid, ext), 'set logging on' + "set logging file %s_%d%s" % (base, pid, ext), + "set logging on", ] - set_logging_off_commands = ['set logging off'] + set_logging_off_commands = ["set logging off"] raw_stacks_filename = "%s_%d_raw_stacks%s" % (base, pid, ext) raw_stacks_commands = [ - 'echo \\nWriting raw stacks to %s.\\n' % raw_stacks_filename, + "echo \\nWriting raw stacks to %s.\\n" % raw_stacks_filename, # This sends output to log file rather than stdout until we turn logging off. - 'set logging redirect on', - 'set logging file ' + raw_stacks_filename, - 'set logging on', - 'thread apply all bt', - 'set logging off', - 'set logging redirect off', + "set logging redirect on", + "set logging file " + raw_stacks_filename, + "set logging on", + "thread apply all bt", + "set logging off", + "set logging redirect off", ] else: set_logging_on_commands = [] set_logging_off_commands = [] raw_stacks_commands = [] - mongodb_waitsfor_graph = "mongodb-waitsfor-graph debugger_waitsfor_%s_%d.gv" % \ - (pinfo.name, pid) + mongodb_waitsfor_graph = ( + "mongodb-waitsfor-graph debugger_waitsfor_%s_%d.gv" + % (pinfo.name, pid) + ) - cmds += set_logging_on_commands + [ - "attach %d" % pid, - "handle SIGSTOP ignore noprint", - "info sharedlibrary", - "info threads", # Dump a simple list of commands to get the thread name - ] + set_logging_off_commands + raw_stacks_commands + set_logging_on_commands + [ - mongodb_uniqstack, - # Lock the scheduler, before running commands, which execute code in the attached process. - "set scheduler-locking on", - mongodb_dump_locks, - mongodb_show_locks, - mongodb_waitsfor_graph, - mongodb_javascript_stack, - mongod_dump_sessions, - mongodb_dump_mutexes, - mongodb_dump_recovery_units, - mongodb_dump_storage_engine_info, - "detach", - ] + set_logging_off_commands + cmds += ( + set_logging_on_commands + + [ + "attach %d" % pid, + "handle SIGSTOP ignore noprint", + "info sharedlibrary", + "info threads", # Dump a simple list of commands to get the thread name + ] + + set_logging_off_commands + + raw_stacks_commands + + set_logging_on_commands + + [ + mongodb_uniqstack, + # Lock the scheduler, before running commands, which execute code in the attached process. + "set scheduler-locking on", + mongodb_dump_locks, + mongodb_show_locks, + mongodb_waitsfor_graph, + mongodb_javascript_stack, + mongod_dump_sessions, + mongodb_dump_mutexes, + mongodb_dump_recovery_units, + mongodb_dump_storage_engine_info, + "detach", + ] + + set_logging_off_commands + ) return cmds @@ -551,21 +639,33 @@ class GDBDumper(Dumper): _start_time = datetime.now() if dbg is None: - self._root_logger.warning("Debugger not found, skipping dumping of %s", str(pinfo.pidv)) + self._root_logger.warning( + "Debugger not found, skipping dumping of %s", str(pinfo.pidv) + ) return if self._timeout_seconds_for_gdb_process <= 0: self._root_logger.warning( "Skipping dumping of %s processes with PIDs %s because the time limit expired", - pinfo.name, str(pinfo.pidv)) + pinfo.name, + str(pinfo.pidv), + ) return - self._root_logger.info("Debugger %s, analyzing %s processes with PIDs %s", dbg, pinfo.name, - str(pinfo.pidv)) + self._root_logger.info( + "Debugger %s, analyzing %s processes with PIDs %s", + dbg, + pinfo.name, + str(pinfo.pidv), + ) call([dbg, "--version"], logger) - cmds = self._prefix() + self._process_specific(pinfo, take_dump, logger) + self._postfix() + cmds = ( + self._prefix() + + self._process_specific(pinfo, take_dump, logger) + + self._postfix() + ) # gcore is both a command within GDB and a script packaged alongside gdb. The gcore script # invokes the gdb binary with --readnever to avoid spending time loading the debug symbols @@ -579,30 +679,37 @@ class GDBDumper(Dumper): # supported in the hang analyzer. The only gdb commands we run here are to take core # dumps of the running processes. if take_dump: - call([dbg, "--quiet", "--nx"] + skip_reading_symbols_on_take_dump + list( - itertools.chain.from_iterable([['-ex', b] for b in cmds])), logger, - self._timeout_seconds_for_gdb_process, pinfo) + call( + [dbg, "--quiet", "--nx"] + + skip_reading_symbols_on_take_dump + + list(itertools.chain.from_iterable([["-ex", b] for b in cmds])), + logger, + self._timeout_seconds_for_gdb_process, + pinfo, + ) time_period = (datetime.now() - _start_time).total_seconds() self._reduce_timeout_for_gdb_process(time_period) - self._root_logger.info("Done analyzing %s processes with PIDs %s", pinfo.name, - str(pinfo.pidv)) + self._root_logger.info( + "Done analyzing %s processes with PIDs %s", pinfo.name, str(pinfo.pidv) + ) @TRACER.start_as_current_span("core_analyzer.analyze_cores") def analyze_cores( - self, - core_file_dir: str, - install_dir: str, - analysis_dir: str, - multiversion_dir: str, - gdb_index_cache: str, + self, + core_file_dir: str, + install_dir: str, + analysis_dir: str, + multiversion_dir: str, + gdb_index_cache: str, ) -> Report: core_files = find_files(f"*.{self.get_dump_ext()}", core_file_dir) analyze_cores_span = get_default_current_span() if not core_files: analyze_cores_span.set_status(StatusCode.ERROR, "No core dumps found") - analyze_cores_span.set_attribute("analyze_cores_error", - f"No core dumps found in {core_file_dir}") + analyze_cores_span.set_attribute( + "analyze_cores_error", f"No core dumps found in {core_file_dir}" + ) raise RuntimeError(f"No core dumps found in {core_file_dir}") tmp_dir = os.path.join(analysis_dir, "tmp") @@ -617,7 +724,9 @@ class GDBDumper(Dumper): handler = logging.StreamHandler(log_stream) handler.setFormatter(logging.Formatter(fmt="%(message)s")) logger.addHandler(handler) - with TRACER.start_as_current_span("core_analyzer.analyze_core") as analyze_core_span: + with TRACER.start_as_current_span( + "core_analyzer.analyze_core" + ) as analyze_core_span: analyze_core_span.set_status(StatusCode.OK) try: exit_code, status = self.analyze_core( @@ -634,43 +743,56 @@ class GDBDumper(Dumper): exit_code = 1 status = "fail" - analyze_core_span.set_attributes({ - "analyze_core_status": status, - "core_file": core_file_path, - }) + analyze_core_span.set_attributes( + { + "analyze_core_status": status, + "core_file": core_file_path, + } + ) if status == "fail": - analyze_core_span.set_status(StatusCode.ERROR, - description="Failed to analyze core dump.") + analyze_core_span.set_status( + StatusCode.ERROR, description="Failed to analyze core dump." + ) output = log_stream.getvalue() result = Result( - Result({ - "status": status, "exit_code": exit_code, "test_file": basename, - "log_raw": output - })) + Result( + { + "status": status, + "exit_code": exit_code, + "test_file": basename, + "log_raw": output, + } + ) + ) if exit_code == 1: report["failures"] += 1 report["results"].append(result) - self._root_logger.info("Analysis of %s ended with status %s", basename, status) + self._root_logger.info( + "Analysis of %s ended with status %s", basename, status + ) analyze_cores_span.set_attributes( - {"failures": report["failures"], "core_dump_count": len(core_files)}) + {"failures": report["failures"], "core_dump_count": len(core_files)} + ) shutil.rmtree(tmp_dir) return report def analyze_core( - self, - core_file_path: str, - install_dir: str, - analysis_dir: str, - tmp_dir: str, - multiversion_dir: str, - logger: logging.Logger, - gdb_index_cache: str, + self, + core_file_path: str, + install_dir: str, + analysis_dir: str, + tmp_dir: str, + multiversion_dir: str, + logger: logging.Logger, + gdb_index_cache: str, ) -> Tuple[int, str]: # returns (exit_code, test_status) cmds = [] dbg = self._find_debugger() basename = os.path.basename(core_file_path) if dbg is None: - self._root_logger.error("Debugger not found, skipping dumping of %s", basename) + self._root_logger.error( + "Debugger not found, skipping dumping of %s", basename + ) return 1, "fail" # ensure debugger version is loggged @@ -687,11 +809,15 @@ class GDBDumper(Dumper): return 0, "skip" if len(binary_files) > 1: - logger.error("More than one file found in %s matching %s", install_dir, binary_name) + logger.error( + "More than one file found in %s matching %s", install_dir, binary_name + ) return 1, "fail" binary_path = os.path.realpath(os.path.abspath(binary_files[0])) - lib_dir = os.path.abspath(os.path.join(os.path.dirname(binary_files[0]), "..", "lib")) + lib_dir = os.path.abspath( + os.path.join(os.path.dirname(binary_files[0]), "..", "lib") + ) basename = os.path.basename(core_file_path) logging_dir = os.path.join(analysis_dir, basename) @@ -710,13 +836,15 @@ class GDBDumper(Dumper): def add_commands(command: str, name: str): file_path = os.path.join(logging_dir, f"{basename}.{name}.txt") - cmds.extend([ - f"echo \\nWriting {name} to {file_path}.\\n", - f"set logging file {file_path}", - "set logging enabled on", - command, - "set logging enabled off", - ]) + cmds.extend( + [ + f"echo \\nWriting {name} to {file_path}.\\n", + f"set logging file {file_path}", + "set logging enabled on", + command, + "set logging enabled off", + ] + ) add_commands("info threads", "info_threads") add_commands("thread apply all bt", "backtraces") @@ -726,11 +854,15 @@ class GDBDumper(Dumper): add_commands("mongodb-dump-mutexes", "dump_mutexes") add_commands("mongodb-dump-recovery-units", "dump_recovery_units") # depends on gdbmongo python dependency - add_commands("python print(gdbmongo.LockManagerPrinter.from_global().val)", "dump_locks") + add_commands( + "python print(gdbmongo.LockManagerPrinter.from_global().val)", "dump_locks" + ) cmds = self._prefix() + cmds + self._postfix() - args = [dbg, "--nx"] + list(itertools.chain.from_iterable([['-ex', b] for b in cmds])) + args = [dbg, "--nx"] + list( + itertools.chain.from_iterable([["-ex", b] for b in cmds]) + ) exit_code = call(args, logger, check=False) current_span = trace.get_current_span() @@ -748,13 +880,21 @@ class GDBDumper(Dumper): def get_binary_from_core_dump(self, core_file_path): dbg = self._find_debugger() if dbg is None: - raise RuntimeError("Debugger not found, can't run get_binary_from_core_dump") - process = subprocess.run([dbg, "-batch", "--quiet", "-ex", f"core {core_file_path}"], - check=True, capture_output=True, text=True) + raise RuntimeError( + "Debugger not found, can't run get_binary_from_core_dump" + ) + process = subprocess.run( + [dbg, "-batch", "--quiet", "-ex", f"core {core_file_path}"], + check=True, + capture_output=True, + text=True, + ) regex = re.search("Core was generated by `(.*)'.", process.stdout) if not regex: - raise RuntimeError("gdb output did not match pattern, could not find binary name") + raise RuntimeError( + "gdb output did not match pattern, could not find binary name" + ) binary_path = regex.group(1) binary_name = binary_path.split(" ")[0] @@ -783,7 +923,7 @@ class JstackDumper(object): def _find_debugger(): """Find the installed jstack debugger.""" debugger = "jstack" - return find_program(debugger, ['/usr/bin']) + return find_program(debugger, ["/usr/bin"]) def dump_info(self, root_logger, dbg_output, pid, process_name): """Dump java thread stack traces to the console.""" @@ -794,7 +934,9 @@ class JstackDumper(object): logger.warning("Debugger not found, skipping dumping of %d", pid) return - root_logger.info("Debugger %s, analyzing %s process with PID %d", jstack, process_name, pid) + root_logger.info( + "Debugger %s, analyzing %s process with PID %d", jstack, process_name, pid + ) call([jstack, "-l", str(pid)], logger) @@ -809,7 +951,9 @@ class JstackWindowsDumper(object): def dump_info(root_logger, pid): """Dump java thread stack traces to the logger.""" - root_logger.warning("Debugger jstack not supported, skipping dumping of %d", pid) + root_logger.warning( + "Debugger jstack not supported, skipping dumping of %d", pid + ) def _get_process_logger(dbg_output, pname: str, pid: int = None): @@ -817,12 +961,12 @@ def _get_process_logger(dbg_output, pname: str, pid: int = None): process_logger = logging.Logger("process", level=logging.DEBUG) process_logger.mongo_process_filename = None - if 'stdout' in dbg_output: + if "stdout" in dbg_output: s_handler = logging.StreamHandler(sys.stdout) s_handler.setFormatter(logging.Formatter(fmt="%(message)s")) process_logger.addHandler(s_handler) - if 'file' in dbg_output: + if "file" in dbg_output: if pid: filename = "debugger_%s_%d.log" % (os.path.splitext(pname)[0], pid) else: @@ -842,8 +986,14 @@ class DumpError(Exception): Tracks what cores still need to be generated. """ - def __init__(self, dump_pids, message=("Failed to create core dumps for some processes," - " SIGABRT will be sent as a fallback if -k is set.")): + def __init__( + self, + dump_pids, + message=( + "Failed to create core dumps for some processes," + " SIGABRT will be sent as a fallback if -k is set." + ), + ): """Initialize error.""" self.dump_pids = dump_pids self.message = message diff --git a/buildscripts/resmokelib/hang_analyzer/extractor.py b/buildscripts/resmokelib/hang_analyzer/extractor.py index d0c21c3f1eb..e0e802bd2e2 100644 --- a/buildscripts/resmokelib/hang_analyzer/extractor.py +++ b/buildscripts/resmokelib/hang_analyzer/extractor.py @@ -1,4 +1,5 @@ """Extracts `mongo-debugsymbols.tgz` in an idempotent manner for performance.""" + import concurrent.futures import glob import gzip @@ -35,13 +36,18 @@ from buildscripts.resmokelib.utils.otel_thread_pool_executor import ( from buildscripts.resmokelib.utils.otel_utils import get_default_current_span from evergreen.task import Artifact, Task -_DEBUG_FILE_BASE_NAMES = ['mongo', 'mongod', 'mongos'] +_DEBUG_FILE_BASE_NAMES = ["mongo", "mongod", "mongos"] TOOLCHAIN_ROOT = "/opt/mongodbtoolchain/v4" TRACER = trace.get_tracer("resmoke") -def run_with_retries(root_logger: Logger, func: Callable[..., bool], timeout_secs: int, - retry_secs: int, **kwargs) -> bool: +def run_with_retries( + root_logger: Logger, + func: Callable[..., bool], + timeout_secs: int, + retry_secs: int, + **kwargs, +) -> bool: start_time = time.time() while True: try: @@ -54,16 +60,24 @@ def run_with_retries(root_logger: Logger, func: Callable[..., bool], timeout_sec time_difference = time.time() - start_time if time_difference > timeout_secs: root_logger.error( - f"Timeout hit for function {func.__name__} after {time_difference} seconds.") + f"Timeout hit for function {func.__name__} after {time_difference} seconds." + ) return False - root_logger.error(f"Failed to run {func.__name__}, retrying in {retry_secs} seconds...") + root_logger.error( + f"Failed to run {func.__name__}, retrying in {retry_secs} seconds..." + ) time.sleep(retry_secs) @TRACER.start_as_current_span("core_analyzer.download_core_dumps") -def download_core_dumps(root_logger: Logger, task: Task, download_dir: str, debugger: Dumper, - multiversion_versions: set) -> bool: +def download_core_dumps( + root_logger: Logger, + task: Task, + download_dir: str, + debugger: Dumper, + multiversion_versions: set, +) -> bool: root_logger.info("Looking for core dumps") artifacts = task.artifacts core_dumps_found = 0 @@ -79,10 +93,12 @@ def download_core_dumps(root_logger: Logger, task: Task, download_dir: str, debu extract_path = os.path.join(core_dumps_dir, extracted_name) with TRACER.start_as_current_span( - "core_analyzer.download_core_dump", attributes={ - "core_dump_file_name": file_name, - "core_dump_file_url": artifact.url, - }) as core_dump_span: + "core_analyzer.download_core_dump", + attributes={ + "core_dump_file_name": file_name, + "core_dump_file_url": artifact.url, + }, + ) as core_dump_span: attempts = 0 core_dump_span.set_status(StatusCode.OK) try: @@ -98,17 +114,20 @@ def download_core_dumps(root_logger: Logger, task: Task, download_dir: str, debu root_logger.info(f"Extracting core dump: {file_name}") if os.path.exists(extract_path): os.remove(extract_path) - with gzip.open(file_name, 'rb') as f_in: - with open(extract_path, 'wb') as f_out: + with gzip.open(file_name, "rb") as f_in: + with open(extract_path, "wb") as f_out: shutil.copyfileobj(f_in, f_out) - core_dump_span.set_attributes({ - "core_dump_compressed_size": os.path.getsize(file_name), - "core_dump_extracted_size": os.path.getsize(extract_path), - "core_dump_download_attempts": attempts, - }) + core_dump_span.set_attributes( + { + "core_dump_compressed_size": os.path.getsize(file_name), + "core_dump_extracted_size": os.path.getsize(extract_path), + "core_dump_download_attempts": attempts, + } + ) root_logger.info( - f"Done extracting core dump {extracted_name} to {extract_path}") + f"Done extracting core dump {extracted_name} to {extract_path}" + ) os.remove(file_name) _, bin_version = debugger.get_binary_from_core_dump(extract_path) @@ -121,53 +140,88 @@ def download_core_dumps(root_logger: Logger, task: Task, download_dir: str, debu except Exception as ex: root_logger.error( "An error occured while trying to download and extract core dump %s", - extracted_name) + extracted_name, + ) root_logger.error(ex) - core_dump_span.set_status(StatusCode.ERROR, "Failed to download core dump.") - core_dump_span.set_attributes({ - "core_dump_download_attempts": attempts, - "core_dump_error": ex, - }) + core_dump_span.set_status( + StatusCode.ERROR, "Failed to download core dump." + ) + core_dump_span.set_attributes( + { + "core_dump_download_attempts": attempts, + "core_dump_error": ex, + } + ) - core_dump_directory_size = sum(f.stat().st_size for f in Path("./").glob('**/*') if f.is_file()) - current_span.set_attributes({ - "core_dumps_dir": core_dumps_dir, - "core_dump_directory_size": core_dump_directory_size, - }) + core_dump_directory_size = sum( + f.stat().st_size for f in Path("./").glob("**/*") if f.is_file() + ) + current_span.set_attributes( + { + "core_dumps_dir": core_dumps_dir, + "core_dump_directory_size": core_dump_directory_size, + } + ) if not core_dumps_found: root_logger.error("No core dumps found") current_span.set_status(StatusCode.ERROR, description="No core dumps found") - current_span.set_attributes({ - "core_dumps_error": "No core dumps found", - }) + current_span.set_attributes( + { + "core_dumps_error": "No core dumps found", + } + ) return False return True @TRACER.start_as_current_span("core_analyzer.download_multiversion_artifact") -def download_multiversion_artifact(root_logger: Logger, version_id: str, variant: str, - download_options: _DownloadOptions, download_dir: str, name: str, - bin_version: str = None) -> bool: +def download_multiversion_artifact( + root_logger: Logger, + version_id: str, + variant: str, + download_options: _DownloadOptions, + download_dir: str, + name: str, + bin_version: str = None, +) -> bool: current_span = get_default_current_span( - {"downloaded_artifact_type": name, "version": bin_version if bin_version else "current"}) + { + "downloaded_artifact_type": name, + "version": bin_version if bin_version else "current", + } + ) try: root_logger.info("Downloading %s", name) - multiversion_setup = SetupMultiversion(download_options=download_options, - ignore_failed_push=True, - link_dir=os.path.abspath(download_dir)) - urlinfo = multiversion_setup.get_urls(version=version_id, buildvariant_name=variant) + multiversion_setup = SetupMultiversion( + download_options=download_options, + ignore_failed_push=True, + link_dir=os.path.abspath(download_dir), + ) + urlinfo = multiversion_setup.get_urls( + version=version_id, buildvariant_name=variant + ) if bin_version: - install_dir = os.path.abspath(os.path.join(download_dir, bin_version, "install")) + install_dir = os.path.abspath( + os.path.join(download_dir, bin_version, "install") + ) os.makedirs(install_dir, exist_ok=True) multiversion_setup.download_and_extract_from_urls( - urlinfo.urls, bin_suffix=bin_version, install_dir=install_dir, skip_symlinks=False) + urlinfo.urls, + bin_suffix=bin_version, + install_dir=install_dir, + skip_symlinks=False, + ) else: install_dir = os.path.abspath(os.path.join(download_dir, "install")) os.makedirs(install_dir, exist_ok=True) multiversion_setup.download_and_extract_from_urls( - urlinfo.urls, bin_suffix=None, install_dir=install_dir, skip_symlinks=True) + urlinfo.urls, + bin_suffix=None, + install_dir=install_dir, + skip_symlinks=True, + ) root_logger.info("Downloaded %s", name) return True except Exception as ex: @@ -180,16 +234,24 @@ def download_multiversion_artifact(root_logger: Logger, version_id: str, variant @TRACER.start_as_current_span("core_analyzer.post_install_gdb_optimization") def post_install_gdb_optimization(download_dir: str, root_looger: Logger): - @TRACER.start_as_current_span("core_analyzer.post_install_gdb_optimization.add_index") + @TRACER.start_as_current_span( + "core_analyzer.post_install_gdb_optimization.add_index" + ) def add_index(file_path: str): """Generate and add gdb-index to ELF binary.""" - current_span = get_default_current_span({ - "file": file_path, "add_index_status": "success", - "add_index_original_file_size": os.path.getsize(file_path) - }) + current_span = get_default_current_span( + { + "file": file_path, + "add_index_status": "success", + "add_index_original_file_size": os.path.getsize(file_path), + } + ) start_time = time.time() - process = subprocess.run([f"{TOOLCHAIN_ROOT}/bin/llvm-dwarfdump", "-r", "0", file_path], - capture_output=True, text=True) + process = subprocess.run( + [f"{TOOLCHAIN_ROOT}/bin/llvm-dwarfdump", "-r", "0", file_path], + capture_output=True, + text=True, + ) # it is normal for non debug binaries to fail this command # there also can be some python files in the bin dir that will fail @@ -200,11 +262,15 @@ def post_install_gdb_optimization(download_dir: str, root_looger: Logger): # find dwarf version from output, it should always be present regex = re.search("version = 0x([0-9]{4}),", process.stdout) if not regex: - current_span.set_status(StatusCode.ERROR, "Could not find dwarf version in file.") - current_span.set_attributes({ - "add_index_status": "failed", - "add_index_error": "Could not find dwarf version in file", - }) + current_span.set_status( + StatusCode.ERROR, "Could not find dwarf version in file." + ) + current_span.set_attributes( + { + "add_index_status": "failed", + "add_index_error": "Could not find dwarf version in file", + } + ) raise RuntimeError(f"Could not find dwarf version in file {file_path}") version = int(regex.group(1)) @@ -214,91 +280,156 @@ def post_install_gdb_optimization(download_dir: str, root_looger: Logger): try: # logic copied from https://sourceware.org/gdb/onlinedocs/gdb/Index-Files.html if version == 5: - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/gdb", "--batch-silent", "--quiet", "--nx", - "--eval-command", f"save gdb-index -dwarf-5 {target_dir}", file_path - ], check=True) - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/objcopy", "--dump-section", - f".debug_str={file_path}.debug_str.new", file_path - ]) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/gdb", + "--batch-silent", + "--quiet", + "--nx", + "--eval-command", + f"save gdb-index -dwarf-5 {target_dir}", + file_path, + ], + check=True, + ) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/objcopy", + "--dump-section", + f".debug_str={file_path}.debug_str.new", + file_path, + ] + ) with open(f"{file_path}.debug_str", "r") as file1: with open(f"{file_path}.debug_str.new", "a") as file2: file2.write(file1.read()) - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/objcopy", "--add-section", - f".debug_names={file_path}.debug_names", "--set-section-flags", - ".debug_names=readonly", "--update-section", - f".debug_str={file_path}.debug_str.new", file_path, file_path - ], check=True) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/objcopy", + "--add-section", + f".debug_names={file_path}.debug_names", + "--set-section-flags", + ".debug_names=readonly", + "--update-section", + f".debug_str={file_path}.debug_str.new", + file_path, + file_path, + ], + check=True, + ) os.remove(f"{file_path}.debug_str.new") os.remove(f"{file_path}.debug_str") os.remove(f"{file_path}.debug_names") elif version == 4: - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/gdb", "--batch-silent", "--quiet", "--nx", - "--eval-command", f"save gdb-index {target_dir}", file_path - ], check=True) - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/objcopy", "--add-section", - f".gdb_index={file_path}.gdb-index", "--set-section-flags", - ".gdb_index=readonly", file_path, file_path - ], check=True) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/gdb", + "--batch-silent", + "--quiet", + "--nx", + "--eval-command", + f"save gdb-index {target_dir}", + file_path, + ], + check=True, + ) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/objcopy", + "--add-section", + f".gdb_index={file_path}.gdb-index", + "--set-section-flags", + ".gdb_index=readonly", + file_path, + file_path, + ], + check=True, + ) os.remove(f"{file_path}.gdb-index") else: - current_span.set_status(StatusCode.ERROR, f"Unsupported dwarf version: {version}") - current_span.set_attributes({ - "add_index_status": "failed", - "add_index_error": "Does not support dwarf version", - }) + current_span.set_status( + StatusCode.ERROR, f"Unsupported dwarf version: {version}" + ) + current_span.set_attributes( + { + "add_index_status": "failed", + "add_index_error": "Does not support dwarf version", + } + ) raise RuntimeError(f"Does not support dwarf version {version}") except Exception as ex: root_looger.exception("Failed to add gdb index to %s", file_path) current_span.set_status(StatusCode.ERROR, "Failed to add gdb index") - current_span.set_attributes({ - "add_index_status": "failed", - "add_index_error": ex, - }) + current_span.set_attributes( + { + "add_index_status": "failed", + "add_index_error": ex, + } + ) return - current_span.set_attribute("add_index_changed_file_size", os.path.getsize(file_path)) + current_span.set_attribute( + "add_index_changed_file_size", os.path.getsize(file_path) + ) - root_looger.debug("Finished creating gdb-index for %s in %s", file_path, - (time.time() - start_time)) + root_looger.debug( + "Finished creating gdb-index for %s in %s", + file_path, + (time.time() - start_time), + ) - @TRACER.start_as_current_span("core_analyzer.post_install_gdb_optimization.recalc_debuglink") + @TRACER.start_as_current_span( + "core_analyzer.post_install_gdb_optimization.recalc_debuglink" + ) def recalc_debuglink(file_path: str): """ Recalcuate the debuglink for ELF binaries. - + After creating the index file in a separate debug file, the debuglink CRC is no longer valid, this will simply recreate the debuglink and therefore update the CRC to match. """ current_span = get_default_current_span( - {"file": file_path, "recalc_debuglink_status": "success"}) + {"file": file_path, "recalc_debuglink_status": "success"} + ) - process = subprocess.run([f"{TOOLCHAIN_ROOT}/bin/eu-readelf", "-S", file_path], - capture_output=True, text=True) + process = subprocess.run( + [f"{TOOLCHAIN_ROOT}/bin/eu-readelf", "-S", file_path], + capture_output=True, + text=True, + ) if process.returncode != 0 or ".gnu_debuglink" not in process.stdout: current_span.set_attribute("recalc_debuglink_status", "skipped") return try: subprocess.run( - [f"{TOOLCHAIN_ROOT}/bin/objcopy", "--remove-section", ".gnu_debuglink", file_path], - check=True) - subprocess.run([ - f"{TOOLCHAIN_ROOT}/bin/objcopy", "--add-gnu-debuglink", - f"{os.path.abspath(file_path)}.debug", file_path - ], check=True) + [ + f"{TOOLCHAIN_ROOT}/bin/objcopy", + "--remove-section", + ".gnu_debuglink", + file_path, + ], + check=True, + ) + subprocess.run( + [ + f"{TOOLCHAIN_ROOT}/bin/objcopy", + "--add-gnu-debuglink", + f"{os.path.abspath(file_path)}.debug", + file_path, + ], + check=True, + ) except Exception as ex: root_looger.exception("Failed to recalculate debuglink") current_span.set_status(StatusCode.ERROR, "Failed to recalculate debuglink") - current_span.set_attributes({ - "recalc_debuglink_status": "failed", - "recalc_debuglink_error": ex, - }) + current_span.set_attributes( + { + "recalc_debuglink_status": "failed", + "recalc_debuglink_error": ex, + } + ) return root_looger.debug("Finished recalculating the debuglink for %s", file_path) @@ -312,7 +443,8 @@ def post_install_gdb_optimization(download_dir: str, root_looger: Logger): with OtelThreadPoolExecutor() as executor: with TRACER.start_as_current_span( - "core_analyzer.post_install_gdb_optimization.add_indexes") as current_span: + "core_analyzer.post_install_gdb_optimization.add_indexes" + ) as current_span: futures = [] current_span.set_status(StatusCode.OK) for file_path in lib_files: @@ -324,7 +456,8 @@ def post_install_gdb_optimization(download_dir: str, root_looger: Logger): concurrent.futures.wait(futures) with TRACER.start_as_current_span( - "core_analyzer.post_install_gdb_optimization.recalc_debuglink"): + "core_analyzer.post_install_gdb_optimization.recalc_debuglink" + ): futures = [] current_span = get_default_current_span() for file_path in lib_files: @@ -338,9 +471,16 @@ def post_install_gdb_optimization(download_dir: str, root_looger: Logger): @TRACER.start_as_current_span("core_analyzer.download_task_artifacts") -def download_task_artifacts(root_logger: Logger, task_id: str, download_dir: str, debugger: Dumper, - multiversion_dir: str, execution: Optional[int] = None, - retry_secs: int = 10, download_timeout_secs: int = 30 * 60) -> bool: +def download_task_artifacts( + root_logger: Logger, + task_id: str, + download_dir: str, + debugger: Dumper, + multiversion_dir: str, + execution: Optional[int] = None, + retry_secs: int = 10, + download_timeout_secs: int = 30 * 60, +) -> bool: if os.path.exists(download_dir): # quick sanity check to ensure we don't delete a repo if os.path.exists(os.path.join(download_dir, ".git")): @@ -358,7 +498,9 @@ def download_task_artifacts(root_logger: Logger, task_id: str, download_dir: str else: task_info = evg_api.task_by_id(task_id) binary_download_options = _DownloadOptions(db=True, ds=False, da=False, dv=False) - debugsymbols_download_options = _DownloadOptions(db=False, ds=True, da=False, dv=False) + debugsymbols_download_options = _DownloadOptions( + db=False, ds=True, da=False, dv=False + ) @retry(tries=3, delay=5) def get_multiversion_download_links(task: Task) -> Optional[dict]: @@ -391,70 +533,122 @@ def download_task_artifacts(root_logger: Logger, task_id: str, download_dir: str with OtelThreadPoolExecutor() as executor: futures = [] futures.append( - executor.submit(run_with_retries, root_logger=root_logger, func=download_core_dumps, - timeout_secs=download_timeout_secs, retry_secs=retry_secs, - task=task_info, download_dir=download_dir, debugger=debugger, - multiversion_versions=multiversion_versions)) + executor.submit( + run_with_retries, + root_logger=root_logger, + func=download_core_dumps, + timeout_secs=download_timeout_secs, + retry_secs=retry_secs, + task=task_info, + download_dir=download_dir, + debugger=debugger, + multiversion_versions=multiversion_versions, + ) + ) futures.append( - executor.submit(run_with_retries, func=download_multiversion_artifact, - timeout_secs=download_timeout_secs, retry_secs=retry_secs, - root_logger=root_logger, version_id=version_id, variant=variant, - download_options=binary_download_options, download_dir=download_dir, - name="binaries")) + executor.submit( + run_with_retries, + func=download_multiversion_artifact, + timeout_secs=download_timeout_secs, + retry_secs=retry_secs, + root_logger=root_logger, + version_id=version_id, + variant=variant, + download_options=binary_download_options, + download_dir=download_dir, + name="binaries", + ) + ) futures.append( - executor.submit(run_with_retries, func=download_multiversion_artifact, - timeout_secs=download_timeout_secs, retry_secs=retry_secs, - root_logger=root_logger, version_id=version_id, variant=variant, - download_options=debugsymbols_download_options, - download_dir=download_dir, name="debugsymbols")) + executor.submit( + run_with_retries, + func=download_multiversion_artifact, + timeout_secs=download_timeout_secs, + retry_secs=retry_secs, + root_logger=root_logger, + version_id=version_id, + variant=variant, + download_options=debugsymbols_download_options, + download_dir=download_dir, + name="debugsymbols", + ) + ) for future in concurrent.futures.as_completed(futures): if not future.result(): - current_span.set_status(StatusCode.ERROR, "Errors occured while fetching artifacts") - current_span.set_attribute("download_task_artifacts_error", - "Errors occured while fetching artifacts") + current_span.set_status( + StatusCode.ERROR, "Errors occured while fetching artifacts" + ) + current_span.set_attribute( + "download_task_artifacts_error", + "Errors occured while fetching artifacts", + ) root_logger.error("Errors occured while fetching artifacts") all_downloaded = False break if multiversion_versions: if not multiversion_downloads: - raise RuntimeError("Multiversion core dumps were found without download links.") + raise RuntimeError( + "Multiversion core dumps were found without download links." + ) with OtelThreadPoolExecutor() as executor: futures = [] for version in multiversion_versions: version_downloads = next( filter( - lambda actual, desired=version: actual.get("bin_suffix") == desired, + lambda actual, desired=version: actual.get("bin_suffix") + == desired, multiversion_downloads, ) ) version_id = version_downloads["evg_urls_info"]["evg_version_id"] variant = version_downloads["evg_urls_info"]["evg_build_variant"] futures.append( - executor.submit(run_with_retries, func=download_multiversion_artifact, - timeout_secs=download_timeout_secs, retry_secs=retry_secs, - root_logger=root_logger, version_id=version_id, variant=variant, - download_options=binary_download_options, - download_dir=multiversion_dir, name=f"binaries-{version}", - bin_version=version)) + executor.submit( + run_with_retries, + func=download_multiversion_artifact, + timeout_secs=download_timeout_secs, + retry_secs=retry_secs, + root_logger=root_logger, + version_id=version_id, + variant=variant, + download_options=binary_download_options, + download_dir=multiversion_dir, + name=f"binaries-{version}", + bin_version=version, + ) + ) futures.append( - executor.submit(run_with_retries, func=download_multiversion_artifact, - timeout_secs=download_timeout_secs, retry_secs=retry_secs, - root_logger=root_logger, version_id=version_id, variant=variant, - download_options=debugsymbols_download_options, - download_dir=multiversion_dir, name=f"debugsymbols-{version}", - bin_version=version)) + executor.submit( + run_with_retries, + func=download_multiversion_artifact, + timeout_secs=download_timeout_secs, + retry_secs=retry_secs, + root_logger=root_logger, + version_id=version_id, + variant=variant, + download_options=debugsymbols_download_options, + download_dir=multiversion_dir, + name=f"debugsymbols-{version}", + bin_version=version, + ) + ) for future in concurrent.futures.as_completed(futures): if not future.result(): - current_span.set_status(StatusCode.ERROR, - "Errors occured while fetching old version artifacts") + current_span.set_status( + StatusCode.ERROR, + "Errors occured while fetching old version artifacts", + ) current_span.set_attribute( "download_task_artifacts_error", - "Errors occured while fetching old version artifacts") - root_logger.error("Errors occured while fetching old version artifacts") + "Errors occured while fetching old version artifacts", + ) + root_logger.error( + "Errors occured while fetching old version artifacts" + ) all_downloaded = False break @@ -462,13 +656,19 @@ def download_task_artifacts(root_logger: Logger, task_id: str, download_dir: str post_install_gdb_optimization(download_dir, root_logger) for version in multiversion_versions: - post_install_gdb_optimization(os.path.join(multiversion_dir, version), root_logger) + post_install_gdb_optimization( + os.path.join(multiversion_dir, version), root_logger + ) return all_downloaded -def download_debug_symbols(root_logger, symbolizer: Symbolizer, retry_secs: int = 10, - download_timeout_secs: int = 10 * 60): +def download_debug_symbols( + root_logger, + symbolizer: Symbolizer, + retry_secs: int = 10, + download_timeout_secs: int = 10 * 60, +): """ Extract debug symbols. Idempotent. @@ -484,7 +684,8 @@ def download_debug_symbols(root_logger, symbolizer: Symbolizer, retry_secs: int if len(sym_files) >= len(_DEBUG_FILE_BASE_NAMES): root_logger.info( - "Skipping downloading debug symbols as there are already symbol files present") + "Skipping downloading debug symbols as there are already symbol files present" + ) return while True: @@ -495,21 +696,30 @@ def download_debug_symbols(root_logger, symbolizer: Symbolizer, retry_secs: int except (tarfile.ReadError, DownloadError): root_logger.warn( "Debug symbols unavailable after %s secs, retrying in %s secs, waiting for a total of %s secs", - compare_start_time(time.time()), retry_secs, download_timeout_secs) + compare_start_time(time.time()), + retry_secs, + download_timeout_secs, + ) time.sleep(retry_secs) if compare_start_time(time.time()) > download_timeout_secs: root_logger.warn( - 'Debug-symbols archive-file does not exist after %s secs; ' - 'Hang-Analyzer may not complete successfully.', download_timeout_secs) + "Debug-symbols archive-file does not exist after %s secs; " + "Hang-Analyzer may not complete successfully.", + download_timeout_secs, + ) break def _get_symbol_files(): out = [] - for ext in ['debug', 'dSYM', 'pdb']: + for ext in ["debug", "dSYM", "pdb"]: for file in _DEBUG_FILE_BASE_NAMES: - haystack = build_hygienic_bin_path(child='{file}.{ext}'.format(file=file, ext=ext)) + haystack = build_hygienic_bin_path( + child="{file}.{ext}".format(file=file, ext=ext) + ) for needle in glob.glob(haystack): - out.append((needle, os.path.join(os.getcwd(), os.path.basename(needle)))) + out.append( + (needle, os.path.join(os.getcwd(), os.path.basename(needle))) + ) return out diff --git a/buildscripts/resmokelib/hang_analyzer/gen_hang_analyzer_tasks.py b/buildscripts/resmokelib/hang_analyzer/gen_hang_analyzer_tasks.py index 74370204e06..670ec150aba 100644 --- a/buildscripts/resmokelib/hang_analyzer/gen_hang_analyzer_tasks.py +++ b/buildscripts/resmokelib/hang_analyzer/gen_hang_analyzer_tasks.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Generate a task to run core analysis on uploaded core dumps in evergreen.""" + import argparse import json import os @@ -31,14 +32,18 @@ MULTIVERSION_BIN_DIR = os.path.normpath("/data/multiversion") def get_generated_task_name(current_task_name: str, execution: str) -> str: # random string so we do not define the same task name for multiple variants which causes issues - random_string = ''.join( - random.choices(string.ascii_uppercase + string.digits + string.ascii_lowercase, - k=RANDOM_STRING_LENGTH)) + random_string = "".join( + random.choices( + string.ascii_uppercase + string.digits + string.ascii_lowercase, + k=RANDOM_STRING_LENGTH, + ) + ) return f"{GENERATED_TASK_PREFIX}_{current_task_name}{execution}_{random_string}" -def get_core_analyzer_commands(task_id: str, execution: str, core_analyzer_results_url: str, - gdb_index_cache: str) -> List[FunctionCall]: +def get_core_analyzer_commands( + task_id: str, execution: str, core_analyzer_results_url: str, gdb_index_cache: str +) -> List[FunctionCall]: """Return setup commands.""" return [ FunctionCall("f_expansions_write"), @@ -51,7 +56,8 @@ def get_core_analyzer_commands(task_id: str, execution: str, core_analyzer_resul FunctionCall("upload pip requirements"), FunctionCall("configure evergreen api credentials"), BuiltInCommand( - "subprocess.exec", { + "subprocess.exec", + { "binary": "bash", "args": [ "src/evergreen/run_python_script.sh", @@ -67,15 +73,19 @@ def get_core_analyzer_commands(task_id: str, execution: str, core_analyzer_resul "OTEL_PARENT_ID": "${otel_parent_id}", "OTEL_COLLECTOR_DIR": "../build/OTelTraces/", }, - }), + }, + ), BuiltInCommand( - "archive.targz_pack", { + "archive.targz_pack", + { "target": "src/mongo-coreanalysis.tgz", "source_dir": "src", "include": ["./core-analyzer/analysis/**"], - }), + }, + ), BuiltInCommand( - "s3.put", { + "s3.put", + { "aws_key": "${aws_key}", "aws_secret": "${aws_secret}", "local_file": "src/mongo-coreanalysis.tgz", @@ -84,22 +94,28 @@ def get_core_analyzer_commands(task_id: str, execution: str, core_analyzer_resul "permissions": "public-read", "content_type": "application/gzip", "display_name": "Core Analyzer Output - Execution ${execution}", - }), + }, + ), # We delete the core dumps after we are done processing them so they are not # reuploaded to s3 in the generated task's post task block FunctionCall( - "remove files", { - "files": - " ".join([ - "src/core-analyzer/core-dumps/*.core", "src/core-analyzer/core-dumps/*.mdmp" - ]) - }), + "remove files", + { + "files": " ".join( + [ + "src/core-analyzer/core-dumps/*.core", + "src/core-analyzer/core-dumps/*.mdmp", + ] + ) + }, + ), ] -def generate(expansions_file: str = "../expansions.yml", - output_file: str = "hang_analyzer_task.json") -> None: - +def generate( + expansions_file: str = "../expansions.yml", + output_file: str = "hang_analyzer_task.json", +) -> None: if not sys.platform.startswith("linux"): print("This platform is not supported, skipping core analysis task generation.") return @@ -107,14 +123,20 @@ def generate(expansions_file: str = "../expansions.yml", # gather information from the current task being run expansions = read_config_file(expansions_file) distro = None - for distro_expansion in ["core_analyzer_distro_name", "large_distro_name", "distro_id"]: + for distro_expansion in [ + "core_analyzer_distro_name", + "large_distro_name", + "distro_id", + ]: if distro := expansions.get(distro_expansion, None): break assert distro is not None current_task_name = expansions.get("task_name") task_id = expansions.get("task_id") execution = expansions.get("execution") - gdb_index_cache = "off" if expansions.get("core_analyzer_gdb_index_cache") == "off" else "on" + gdb_index_cache = ( + "off" if expansions.get("core_analyzer_gdb_index_cache") == "off" else "on" + ) build_variant_name = expansions.get("build_variant") core_analyzer_results_url = expansions.get("core_analyzer_results_url") compile_variant = expansions.get("compile_variant") @@ -124,13 +146,16 @@ def generate(expansions_file: str = "../expansions.yml", except RuntimeError: print( "WARNING: Cannot generate core analysis because the evergreen api file could not be found.", - file=sys.stderr) + file=sys.stderr, + ) print( "This is probably not an error, if you want core analysis to run on this task make sure", - file=sys.stderr) + file=sys.stderr, + ) print( "the evergreen function 'configure evergreen api credentials' is called before this task", - file=sys.stderr) + file=sys.stderr, + ) return task_info = evg_api.task_by_id(task_id) @@ -144,7 +169,9 @@ def generate(expansions_file: str = "../expansions.yml", # LOCAL_BIN_DIR does not exists on non-resmoke tasks, so return early as there is no work to be done. if not os.path.exists(LOCAL_BIN_DIR): - print(f"Skipping task generation because binary directory not found: {LOCAL_BIN_DIR}") + print( + f"Skipping task generation because binary directory not found: {LOCAL_BIN_DIR}" + ) return # See if any core dumps were uploaded for this task @@ -163,7 +190,9 @@ def generate(expansions_file: str = "../expansions.yml", if binary_name in binary_files: has_known_core_dumps = True break - print(f"{core_file} was generated by {binary_name} but the binary was not found.") + print( + f"{core_file} was generated by {binary_name} but the binary was not found." + ) if not has_known_core_dumps: print( @@ -179,12 +208,15 @@ def generate(expansions_file: str = "../expansions.yml", # Make the evergreen variant that will be generated build_variant = BuildVariant(name=build_variant_name, activate=True) - commands = get_core_analyzer_commands(task_id, execution, core_analyzer_results_url, - gdb_index_cache) + commands = get_core_analyzer_commands( + task_id, execution, core_analyzer_results_url, gdb_index_cache + ) deps = {TaskDependency("archive_dist_test_debug", compile_variant)} # TODO SERVER-92571 add archive_jstestshell_debug dep for variants that have it. - sub_tasks = set([Task(get_generated_task_name(current_task_name, execution), commands, deps)]) + sub_tasks = set( + [Task(get_generated_task_name(current_task_name, execution), commands, deps)] + ) if display_task_name: # If the task is already in a display task add the new task to the current display task @@ -207,13 +239,18 @@ def generate(expansions_file: str = "../expansions.yml", write_file(output_file, json.dumps(output_dict)) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--expansions-file", help="Location of evergreen expansions file.", - default="../expansions.yml") - parser.add_argument("--output-file", - help="Name of output file to write the generated task config to.", - default="hang_analyzer_task.json") + parser.add_argument( + "--expansions-file", + help="Location of evergreen expansions file.", + default="../expansions.yml", + ) + parser.add_argument( + "--output-file", + help="Name of output file to write the generated task config to.", + default="hang_analyzer_task.json", + ) args = parser.parse_args() expansions_file = args.expansions_file output_file = args.output_file diff --git a/buildscripts/resmokelib/hang_analyzer/hang_analyzer.py b/buildscripts/resmokelib/hang_analyzer/hang_analyzer.py index cfe67416bbb..ccd44f894c5 100755 --- a/buildscripts/resmokelib/hang_analyzer/hang_analyzer.py +++ b/buildscripts/resmokelib/hang_analyzer/hang_analyzer.py @@ -9,6 +9,7 @@ A prototype hang analyzer for Evergreen integration to help investigate test tim Supports Linux, MacOS X, and Windows. """ + import getpass import logging import os @@ -48,7 +49,7 @@ class HangAnalyzer(Subcommand): "mongod", "mongos", "_test", - "dbtest" + "dbtest", ] self.go_processes = [] self.process_ids = [] @@ -60,7 +61,8 @@ class HangAnalyzer(Subcommand): raise ValueError( "The Evergreen Task ID (tid) should be either passed in through `resmoke.py run` " "or through `resmoke.py hang-analyzer` but not both. run tid: %s, hang-analyzer tid: %s" - % (run_tid, hang_analyzer_tid)) + % (run_tid, hang_analyzer_tid) + ) return run_tid or hang_analyzer_tid self.task_id = configure_task_id() @@ -69,8 +71,12 @@ class HangAnalyzer(Subcommand): def kill_rogue_processes(self): """Kill any processes that are currently being analyzed.""" - processes = process_list.get_processes(self.process_ids, self.interesting_processes, - self.options.process_match, self.root_logger) + processes = process_list.get_processes( + self.process_ids, + self.interesting_processes, + self.options.process_match, + self.root_logger, + ) process.teardown_processes(self.root_logger, processes, dump_pids={}) def execute(self): @@ -85,8 +91,12 @@ class HangAnalyzer(Subcommand): dumpers = dumper.get_dumpers(self.root_logger, self.options.debugger_output) - processes = process_list.get_processes(self.process_ids, self.interesting_processes, - self.options.process_match, self.root_logger) + processes = process_list.get_processes( + self.process_ids, + self.interesting_processes, + self.options.process_match, + self.root_logger, + ) def is_python_process(pname: str): # "live-record*" and "python*" are Python processes. Sending SIGUSR1 causes resmoke.py @@ -96,7 +106,9 @@ class HangAnalyzer(Subcommand): # Suspending all processes, except python, to prevent them from getting unstuck when # the hang analyzer attaches to them. - for pinfo in [pinfo for pinfo in processes if not is_python_process(pinfo.name)]: + for pinfo in [ + pinfo for pinfo in processes if not is_python_process(pinfo.name) + ]: for pid in pinfo.pidv: process.pause_process(self.root_logger, pinfo.name, pid) @@ -112,9 +124,11 @@ class HangAnalyzer(Subcommand): # Dump core files of all processes, except python & java. if self.options.dump_core: take_core_processes = [ - pinfo for pinfo in processes if not re.match("^(java|python)", pinfo.name) + pinfo + for pinfo in processes + if not re.match("^(java|python)", pinfo.name) ] - if (os.getenv('ASAN_OPTIONS') or os.getenv('TSAN_OPTIONS')): + if os.getenv("ASAN_OPTIONS") or os.getenv("TSAN_OPTIONS"): quit_processes: list[psutil.Process] = [] for pinfo in take_core_processes: for pid in pinfo.pidv: @@ -123,7 +137,8 @@ class HangAnalyzer(Subcommand): process.resume_process(self.root_logger, pinfo.name, pid) self.root_logger.info( "Process %d may be running a sanitizer which uses a large amount of virtual memory.", - pid) + pid, + ) self.root_logger.info( "Attempting to send SIGABRT from resmoke to capture a more manageable sized core dump" ) @@ -138,8 +153,10 @@ class HangAnalyzer(Subcommand): alive_processes = [] # This loop filters out processes that have ended or become a zombie for quit_process in quit_processes: - if quit_process.is_running( - ) and quit_process.status() != psutil.STATUS_ZOMBIE: + if ( + quit_process.is_running() + and quit_process.status() != psutil.STATUS_ZOMBIE + ): alive_processes.append(quit_process) # Update the quit_processes list with only the ones left alive @@ -154,7 +171,7 @@ class HangAnalyzer(Subcommand): f"The following processes took too long to end after SIGABRT: {alive_processes}" ) - time.sleep(.1) + time.sleep(0.1) self.root_logger.info("Finished waiting for all processes to end.") else: for pinfo in take_core_processes: @@ -165,14 +182,17 @@ class HangAnalyzer(Subcommand): self.root_logger.error(err.message) dump_pids = {**err.dump_pids, **dump_pids} except Exception as err: # pylint: disable=broad-except - self.root_logger.info("Error encountered when invoking debugger %s", - err) + self.root_logger.info( + "Error encountered when invoking debugger %s", err + ) trapped_exceptions.append(traceback.format_exc()) else: self.root_logger.info( "Not enough space for a core dump, skipping %s processes with PIDs %s", - pinfo.name, str(pinfo.pidv)) + pinfo.name, + str(pinfo.pidv), + ) # Download symbols after pausing if the task ID is not None and not running with sanitizers. # Sanitizer builds are not stripped and don't require debug symbols. @@ -182,21 +202,28 @@ class HangAnalyzer(Subcommand): download_debug_symbols(self.root_logger, my_symbolizer) # Dump info of all processes, except python & java. - for pinfo in [pinfo for pinfo in processes if not re.match("^(java|python)", pinfo.name)]: + for pinfo in [ + pinfo for pinfo in processes if not re.match("^(java|python)", pinfo.name) + ]: try: dumpers.dbg.dump_info(pinfo, take_dump=False) except Exception as err: # pylint: disable=broad-except - self.root_logger.info("Error encountered when invoking debugger %s", err) + self.root_logger.info( + "Error encountered when invoking debugger %s", err + ) trapped_exceptions.append(traceback.format_exc()) # Dump java processes using jstack. for pinfo in [pinfo for pinfo in processes if pinfo.name.startswith("java")]: for pid in pinfo.pidv: try: - dumpers.jstack.dump_info(self.root_logger, self.options.debugger_output, - pinfo.name, pid) + dumpers.jstack.dump_info( + self.root_logger, self.options.debugger_output, pinfo.name, pid + ) except Exception as err: # pylint: disable=broad-except - self.root_logger.info("Error encountered when invoking debugger %s", err) + self.root_logger.info( + "Error encountered when invoking debugger %s", err + ) trapped_exceptions.append(traceback.format_exc()) # Signal go processes to ensure they print out stack traces, and die on POSIX OSes. @@ -205,8 +232,11 @@ class HangAnalyzer(Subcommand): # Note: The stacktrace output may be captured elsewhere (i.e. resmoke). for pinfo in [pinfo for pinfo in processes if pinfo.name in self.go_processes]: for pid in pinfo.pidv: - self.root_logger.info("Sending signal SIGABRT to go process %s with PID %d", - pinfo.name, pid) + self.root_logger.info( + "Sending signal SIGABRT to go process %s with PID %d", + pinfo.name, + pid, + ) process.signal_process(self.root_logger, pid, signal.SIGABRT) self.root_logger.info("Done analyzing all processes for hangs") @@ -216,7 +246,9 @@ class HangAnalyzer(Subcommand): process.teardown_processes(self.root_logger, processes, dump_pids) else: # Resuming all suspended processes. - for pinfo in [pinfo for pinfo in processes if not pinfo.name.startswith("python")]: + for pinfo in [ + pinfo for pinfo in processes if not pinfo.name.startswith("python") + ]: for pid in pinfo.pidv: process.resume_process(self.root_logger, pinfo.name, pid) @@ -224,22 +256,23 @@ class HangAnalyzer(Subcommand): self.root_logger.info(exception) if trapped_exceptions: raise RuntimeError( - "Exceptions were thrown while dumping. There may still be some valid dumps.") + "Exceptions were thrown while dumping. There may still be some valid dumps." + ) def _configure_processes(self): if self.options.debugger_output is None: - self.options.debugger_output = ['stdout'] + self.options.debugger_output = ["stdout"] # add != "" check to avoid empty process_ids if self.options.process_ids is not None and self.options.process_ids != "": # self.process_ids is an int list of PIDs - self.process_ids = [int(pid) for pid in self.options.process_ids.split(',')] + self.process_ids = [int(pid) for pid in self.options.process_ids.split(",")] if self.options.process_names is not None: - self.interesting_processes = self.options.process_names.split(',') + self.interesting_processes = self.options.process_names.split(",") if self.options.go_process_names is not None: - self.go_processes = self.options.go_process_names.split(',') + self.go_processes = self.options.go_process_names.split(",") self.interesting_processes += self.go_processes def _setup_logging(self, logger): @@ -259,10 +292,14 @@ class HangAnalyzer(Subcommand): if sys.platform in ["win32", "cygwin"]: self.root_logger.info("Windows Distribution: %s", platform.win32_ver()) else: - self.root_logger.info("Linux Distribution: %s", distro.linux_distribution()) + self.root_logger.info( + "Linux Distribution: %s", distro.linux_distribution() + ) except AttributeError: - self.root_logger.warning("Cannot determine Linux distro since Python is too old") + self.root_logger.warning( + "Cannot determine Linux distro since Python is too old" + ) try: current_login = getpass.getuser() @@ -271,7 +308,8 @@ class HangAnalyzer(Subcommand): self.root_logger.info("Current UID: %s", uid) except AttributeError: self.root_logger.warning( - "Cannot determine Unix Current Login, not supported on Windows") + "Cannot determine Unix Current Login, not supported on Windows" + ) def _check_enough_free_space(self): usage_percent = psutil.disk_usage(".").percent @@ -284,7 +322,7 @@ class HangAnalyzerPlugin(PluginInterface): def parse(self, subcommand, parser, parsed_args, **kwargs): """Parse command-line options.""" - if subcommand == 'hang-analyzer': + if subcommand == "hang-analyzer": return HangAnalyzer(parsed_args, task_id=parsed_args.task_id, **kwargs) return None @@ -293,36 +331,79 @@ class HangAnalyzerPlugin(PluginInterface): parser = subparsers.add_parser("hang-analyzer", help=__doc__) parser.add_argument( - '-m', '--process-match', dest='process_match', choices=('contains', - 'exact'), default='contains', + "-m", + "--process-match", + dest="process_match", + choices=("contains", "exact"), + default="contains", help="Type of match for process names (-p & -g), specify 'contains', or" " 'exact'. Note that the process name match performs the following" " conversions: change all process names to lowecase, strip off the file" - " extension, like '.exe' on Windows. Default is 'contains'.") - parser.add_argument('-p', '--process-names', dest='process_names', - help='Comma separated list of process names to analyze') - parser.add_argument('-g', '--go-process-names', dest='go_process_names', - help='Comma separated list of go process names to analyze') + " extension, like '.exe' on Windows. Default is 'contains'.", + ) parser.add_argument( - '-d', '--process-ids', dest='process_ids', default=None, - help='Comma separated list of process ids (PID) to analyze, overrides -p &' - ' -g') - parser.add_argument('-c', '--dump-core', dest='dump_core', action="store_true", - default=False, help='Dump core file for each analyzed process') - parser.add_argument('-s', '--max-disk-usage-percent', dest='max_disk_usage_percent', - default=90, help='Maximum disk usage percent for a core dump') + "-p", + "--process-names", + dest="process_names", + help="Comma separated list of process names to analyze", + ) parser.add_argument( - '-o', '--debugger-output', dest='debugger_output', action="append", choices=('file', - 'stdout'), - default=None, help="If 'stdout', then the debugger's output is written to the Python" + "-g", + "--go-process-names", + dest="go_process_names", + help="Comma separated list of go process names to analyze", + ) + parser.add_argument( + "-d", + "--process-ids", + dest="process_ids", + default=None, + help="Comma separated list of process ids (PID) to analyze, overrides -p &" + " -g", + ) + parser.add_argument( + "-c", + "--dump-core", + dest="dump_core", + action="store_true", + default=False, + help="Dump core file for each analyzed process", + ) + parser.add_argument( + "-s", + "--max-disk-usage-percent", + dest="max_disk_usage_percent", + default=90, + help="Maximum disk usage percent for a core dump", + ) + parser.add_argument( + "-o", + "--debugger-output", + dest="debugger_output", + action="append", + choices=("file", "stdout"), + default=None, + help="If 'stdout', then the debugger's output is written to the Python" " process's stdout. If 'file', then the debugger's output is written" " to a file named debugger__.log for each process it" " attaches to. This option can be specified multiple times on the" " command line to have the debugger's output written to multiple" " locations. By default, the debugger's output is written only to the" - " Python process's stdout.") - parser.add_argument('-k', '--kill-processes', dest='kill_processes', action="store_true", - default=False, - help="Kills the analyzed processes after analysis completes.") - parser.add_argument("--task-id", '-t', action="store", type=str, default=None, - help="Fetch corresponding symbols given an Evergreen task ID") + " Python process's stdout.", + ) + parser.add_argument( + "-k", + "--kill-processes", + dest="kill_processes", + action="store_true", + default=False, + help="Kills the analyzed processes after analysis completes.", + ) + parser.add_argument( + "--task-id", + "-t", + action="store", + type=str, + default=None, + help="Fetch corresponding symbols given an Evergreen task ID", + ) diff --git a/buildscripts/resmokelib/hang_analyzer/process.py b/buildscripts/resmokelib/hang_analyzer/process.py index 7e06ec91b20..13e62bea7e6 100644 --- a/buildscripts/resmokelib/hang_analyzer/process.py +++ b/buildscripts/resmokelib/hang_analyzer/process.py @@ -13,7 +13,7 @@ import psutil from buildscripts.resmokelib import core -_IS_WINDOWS = (sys.platform == "win32") +_IS_WINDOWS = sys.platform == "win32" if _IS_WINDOWS: import win32api @@ -34,8 +34,11 @@ def call(args, logger, timeout_seconds=None, pinfo=None, check=True) -> int: try: ret = process.wait(timeout=timeout_seconds) except subprocess.TimeoutExpired: - logger.error("Killing %s processes with PIDs %s because time limit expired", pinfo.name, - str(pinfo.pidv)) + logger.error( + "Killing %s processes with PIDs %s because time limit expired", + pinfo.name, + str(pinfo.pidv), + ) process.kill() process.wait() logger_pipe.wait_until_finished() @@ -63,7 +66,7 @@ def find_program(prog, paths): def callo(args, logger): """Call subprocess on args string.""" logger.info("%s", str(args)) - return subprocess.check_output(args).decode('utf-8', 'replace') + return subprocess.check_output(args).decode("utf-8", "replace") def signal_python(logger, pname, pid): @@ -78,10 +81,14 @@ def signal_python(logger, pname, pid): # On Windows, we set up an event object to wait on a signal. For Cygwin, we register # a signal handler to wait for the signal since it supports POSIX signals. if _IS_WINDOWS: - logger.info("Calling SetEvent to signal python process %s with PID %d", pname, pid) + logger.info( + "Calling SetEvent to signal python process %s with PID %d", pname, pid + ) signal_event_object(logger, pid) else: - logger.info("Sending signal SIGUSR1 to python process %s with PID %d", pname, pid) + logger.info( + "Sending signal SIGUSR1 to python process %s with PID %d", pname, pid + ) signal_process(logger, pid, signal.SIGUSR1) logger.info("Waiting for process to report") @@ -97,7 +104,9 @@ def signal_event_object(logger, pid): try: desired_access = win32event.EVENT_MODIFY_STATE inherit_handle = False - task_timeout_handle = win32event.OpenEvent(desired_access, inherit_handle, event_name) + task_timeout_handle = win32event.OpenEvent( + desired_access, inherit_handle, event_name + ) except win32event.error as err: logger.info("Exception from win32event.OpenEvent with error: %s", err) return @@ -143,7 +152,9 @@ def resume_process(logger, pname, pid): def teardown_processes(logger, processes, dump_pids): """Kill processes with SIGKILL or SIGABRT.""" - logger.info("Starting to kill or abort processes. Logs should be ignored from this point.") + logger.info( + "Starting to kill or abort processes. Logs should be ignored from this point." + ) for pinfo in processes: for pid in pinfo.pidv: try: diff --git a/buildscripts/resmokelib/hang_analyzer/process_list.py b/buildscripts/resmokelib/hang_analyzer/process_list.py index d8ec232a23a..238bdf74f78 100644 --- a/buildscripts/resmokelib/hang_analyzer/process_list.py +++ b/buildscripts/resmokelib/hang_analyzer/process_list.py @@ -37,15 +37,18 @@ def get_processes(process_ids, interesting_processes, process_match, logger): # Canonicalize the process names to lowercase to handle cases where the name of the Python # process is /System/Library/.../Python on OS X and -p python is specified to the hang analyzer. all_processes = [ - Pinfo(name=process_name.lower(), pidv=pid) for (pid, process_name) in all_processes + Pinfo(name=process_name.lower(), pidv=pid) + for (pid, process_name) in all_processes ] if process_ids: running_pids = {pidv for (pname, pidv) in all_processes} missing_pids = set(process_ids) - running_pids if missing_pids: - logger.warning("The following requested process ids are not running %s", - list(missing_pids)) + logger.warning( + "The following requested process ids are not running %s", + list(missing_pids), + ) processes_to_keep = [] for process in all_processes: @@ -60,15 +63,21 @@ def get_processes(process_ids, interesting_processes, process_match, logger): # if we don't have a list of pids, make sure the process matches # the list of interesting processes - if not process_ids and interesting_processes and not _pname_match( - process_match, process.name, interesting_processes): + if ( + not process_ids + and interesting_processes + and not _pname_match(process_match, process.name, interesting_processes) + ): continue processes_to_keep.append(process) process_types = {pname for (pname, _) in processes_to_keep} processes = [ - Pinfo(name=ptype, pidv=[pidv for (pname, pidv) in processes_to_keep if pname == ptype]) + Pinfo( + name=ptype, + pidv=[pidv for (pname, pidv) in processes_to_keep if pname == ptype], + ) for ptype in process_types ] @@ -101,7 +110,9 @@ class _ProcessList(object): :param logger: Where to log output. :return: A list of process names. """ - raise NotImplementedError("dump_process must be implemented in OS-specific subclasses") + raise NotImplementedError( + "dump_process must be implemented in OS-specific subclasses" + ) class _WindowsProcessList(_ProcessList): @@ -132,7 +143,7 @@ class _DarwinProcessList(_ProcessList): @staticmethod def __find_ps(): """Find ps.""" - return find_program('ps', ['/bin']) + return find_program("ps", ["/bin"]) def dump_processes(self, logger): """Get list of [Pid, Process Name].""" @@ -143,7 +154,9 @@ class _DarwinProcessList(_ProcessList): ret = callo([ps, "-axco", "pid,comm"], logger) buff = io.StringIO(ret) - csv_reader = csv.reader(buff, delimiter=' ', quoting=csv.QUOTE_NONE, skipinitialspace=True) + csv_reader = csv.reader( + buff, delimiter=" ", quoting=csv.QUOTE_NONE, skipinitialspace=True + ) return [[int(row[0]), row[1]] for row in csv_reader if row[0] != "PID"] @@ -154,7 +167,7 @@ class _LinuxProcessList(_ProcessList): @staticmethod def __find_ps(): """Find ps.""" - return find_program('ps', ['/bin', '/usr/bin']) + return find_program("ps", ["/bin", "/usr/bin"]) def dump_processes(self, logger): """Get list of [Pid, Process Name].""" @@ -167,15 +180,26 @@ class _LinuxProcessList(_ProcessList): ret = callo([ps, "-eo", "pid,args"], logger) buff = io.StringIO(ret) - csv_reader = csv.reader(buff, delimiter=' ', quoting=csv.QUOTE_NONE, skipinitialspace=True) + csv_reader = csv.reader( + buff, delimiter=" ", quoting=csv.QUOTE_NONE, skipinitialspace=True + ) - return [[int(row[0]), os.path.split(row[1])[1]] for row in csv_reader if row[0] != "PID"] + return [ + [int(row[0]), os.path.split(row[1])[1]] + for row in csv_reader + if row[0] != "PID" + ] def _pname_match(match_type, pname, interesting_processes): """Return True if the pname matches an interesting_processes.""" pname = os.path.splitext(pname)[0] for ip in interesting_processes: - if match_type == 'exact' and pname == ip or match_type == 'contains' and ip in pname: + if ( + match_type == "exact" + and pname == ip + or match_type == "contains" + and ip in pname + ): return True return False diff --git a/buildscripts/resmokelib/logging/buildlogger.py b/buildscripts/resmokelib/logging/buildlogger.py index e3f7d1a2215..db0c72b4389 100644 --- a/buildscripts/resmokelib/logging/buildlogger.py +++ b/buildscripts/resmokelib/logging/buildlogger.py @@ -107,15 +107,22 @@ class _LogsSplitter(object): class _BaseBuildloggerHandler(handlers.BufferedHandler): """Base class of the buildlogger handler for global logs and handler for test logs.""" - def __init__(self, build_config, endpoint, capacity=_SEND_AFTER_LINES, - interval_secs=_SEND_AFTER_SECS): + def __init__( + self, + build_config, + endpoint, + capacity=_SEND_AFTER_LINES, + interval_secs=_SEND_AFTER_SECS, + ): """Initialize the buildlogger handler with the build id and credentials.""" handlers.BufferedHandler.__init__(self, capacity, interval_secs) username = build_config["username"] password = build_config["password"] - self.http_handler = handlers.HTTPHandler(_config.BUILDLOGGER_URL, username, password) + self.http_handler = handlers.HTTPHandler( + _config.BUILDLOGGER_URL, username, password + ) self.endpoint = endpoint self.retry_buffer = [] @@ -174,11 +181,13 @@ class _BaseBuildloggerHandler(handlers.BufferedHandler): new_max_size = response_data["max_size"] if self.max_size and new_max_size >= self.max_size: BUILDLOGGER_FALLBACK.exception( - "Received an HTTP 413 code, but already had max_size set") + "Received an HTTP 413 code, but already had max_size set" + ) return 0 BUILDLOGGER_FALLBACK.warning( "Received an HTTP 413 code, updating the request max_size to %s", - new_max_size) + new_max_size, + ) self.max_size = new_max_size return self._append_logs(log_lines_chunk) BUILDLOGGER_FALLBACK.error("Encountered an HTTP error: %s", err) @@ -207,7 +216,8 @@ class _BaseBuildloggerHandler(handlers.BufferedHandler): # the Evergreen database. BUILDLOGGER_FALLBACK.warning( "Failed to flush all log output (%d messages) to logkeeper.", - len(self.retry_buffer)) + len(self.retry_buffer), + ) # We set a flag to indicate that we failed to flush all log output to logkeeper so # resmoke.py can exit with a special return code. @@ -219,23 +229,33 @@ class _BaseBuildloggerHandler(handlers.BufferedHandler): class BuildloggerTestHandler(_BaseBuildloggerHandler): """Buildlogger handler for the test logs.""" - def __init__(self, build_config, build_id, test_id, capacity=_SEND_AFTER_LINES, - interval_secs=_SEND_AFTER_SECS): + def __init__( + self, + build_config, + build_id, + test_id, + capacity=_SEND_AFTER_LINES, + interval_secs=_SEND_AFTER_SECS, + ): """Initialize the buildlogger handler with the credentials, build id, and test id.""" endpoint = APPEND_TEST_LOGS_ENDPOINT % { "build_id": build_id, "test_id": test_id, } - _BaseBuildloggerHandler.__init__(self, build_config, endpoint, capacity, interval_secs) + _BaseBuildloggerHandler.__init__( + self, build_config, endpoint, capacity, interval_secs + ) @_log_on_error def _finish_test(self, failed=False): """Send a POST request to the APPEND_TEST_LOGS_ENDPOINT with the test status.""" self.post( - self.endpoint, headers={ + self.endpoint, + headers={ "X-Sendlogs-Test-Done": "true", "X-Sendlogs-Test-Failed": "true" if failed else "false", - }) + }, + ) def close(self): """Close the buildlogger handler.""" @@ -249,11 +269,18 @@ class BuildloggerTestHandler(_BaseBuildloggerHandler): class BuildloggerGlobalHandler(_BaseBuildloggerHandler): """Buildlogger handler for the global logs.""" - def __init__(self, build_config, build_id, capacity=_SEND_AFTER_LINES, - interval_secs=_SEND_AFTER_SECS): + def __init__( + self, + build_config, + build_id, + capacity=_SEND_AFTER_LINES, + interval_secs=_SEND_AFTER_SECS, + ): """Initialize the buildlogger handler with the credentials and build id.""" endpoint = APPEND_GLOBAL_LOGS_ENDPOINT % {"build_id": build_id} - _BaseBuildloggerHandler.__init__(self, build_config, endpoint, capacity, interval_secs) + _BaseBuildloggerHandler.__init__( + self, build_config, endpoint, capacity, interval_secs + ) class BuildloggerServer(object): @@ -269,8 +296,12 @@ class BuildloggerServer(object): tmp_globals = {} self.config = {} exec( - compile(open(_BUILDLOGGER_CONFIG, "rb").read(), _BUILDLOGGER_CONFIG, 'exec'), - tmp_globals, self.config) + compile( + open(_BUILDLOGGER_CONFIG, "rb").read(), _BUILDLOGGER_CONFIG, "exec" + ), + tmp_globals, + self.config, + ) # Rename "slavename" to "username" if present. if "slavename" in self.config and "username" not in self.config: @@ -290,35 +321,46 @@ class BuildloggerServer(object): builder = "%s_%s" % (self.config["builder"], suffix) build_num = int(self.config["build_num"]) - handler = handlers.HTTPHandler(url_root=_config.BUILDLOGGER_URL, username=username, - password=password, should_retry=True) + handler = handlers.HTTPHandler( + url_root=_config.BUILDLOGGER_URL, + username=username, + password=password, + should_retry=True, + ) response = handler.post( - CREATE_BUILD_ENDPOINT, data={ + CREATE_BUILD_ENDPOINT, + data={ "builder": builder, "buildnum": build_num, "task_id": _config.EVERGREEN_TASK_ID, "execution": _config.EVERGREEN_EXECUTION, - }) + }, + ) return response["id"] @_log_on_error def new_test_id(self, build_id, test_filename, test_command): """Return a new test id for sending test logs to.""" - handler = handlers.HTTPHandler(url_root=_config.BUILDLOGGER_URL, - username=self.config["username"], - password=self.config["password"], should_retry=True) + handler = handlers.HTTPHandler( + url_root=_config.BUILDLOGGER_URL, + username=self.config["username"], + password=self.config["password"], + should_retry=True, + ) endpoint = CREATE_TEST_ENDPOINT % {"build_id": build_id} response = handler.post( - endpoint, data={ + endpoint, + data={ "test_filename": test_filename, "command": test_command, "phase": self.config.get("build_phase", "unknown"), "task_id": _config.EVERGREEN_TASK_ID, "execution": _config.EVERGREEN_EXECUTION, - }) + }, + ) return response["id"] @@ -341,7 +383,10 @@ class BuildloggerServer(object): def get_test_log_url(build_id, test_id): """Return the test log URL.""" base_url = _config.BUILDLOGGER_URL.rstrip("/") - endpoint = APPEND_TEST_LOGS_ENDPOINT % {"build_id": build_id, "test_id": test_id} + endpoint = APPEND_TEST_LOGS_ENDPOINT % { + "build_id": build_id, + "test_id": test_id, + } return "%s/%s" % (base_url, endpoint.strip("/")) @staticmethod diff --git a/buildscripts/resmokelib/logging/flush.py b/buildscripts/resmokelib/logging/flush.py index ce504b088d0..d62cdf8b001 100644 --- a/buildscripts/resmokelib/logging/flush.py +++ b/buildscripts/resmokelib/logging/flush.py @@ -9,7 +9,7 @@ import threading import time _FLUSH_THREAD_LOCK = threading.Lock() -_FLUSH_THREAD: '_FlushThread' = None +_FLUSH_THREAD: "_FlushThread" = None def start_thread(): diff --git a/buildscripts/resmokelib/logging/formatters.py b/buildscripts/resmokelib/logging/formatters.py index b85eee46a43..351550d6d3c 100644 --- a/buildscripts/resmokelib/logging/formatters.py +++ b/buildscripts/resmokelib/logging/formatters.py @@ -31,4 +31,6 @@ class EvergreenLogFormatter(logging.Formatter): def format(self, record): ts = int(record.created * 1e9) - return "\n".join([f"{ts} {line}" for line in super().format(record).split("\n")]) + return "\n".join( + [f"{ts} {line}" for line in super().format(record).split("\n")] + ) diff --git a/buildscripts/resmokelib/logging/handlers.py b/buildscripts/resmokelib/logging/handlers.py index c8998df32db..09f4a8e6b21 100644 --- a/buildscripts/resmokelib/logging/handlers.py +++ b/buildscripts/resmokelib/logging/handlers.py @@ -71,8 +71,10 @@ class ExceptionExtractor: self.exception_detected = True if self.current_exception_is_truncated: self.current_exception.appendleft( - "[LAST Part of Exception]" if self.truncate == - Truncate.FIRST else "[FIRST Part of Exception]") + "[LAST Part of Exception]" + if self.truncate == Truncate.FIRST + else "[FIRST Part of Exception]" + ) def get_exception(self): """Get the exception as a list of strings if it exists.""" @@ -176,7 +178,10 @@ class BufferedHandler(logging.Handler): # be None after this point. self.__flush_event = flush.flush_after(self, delay=self.interval_secs) - if not self.__flush_scheduled_by_emit and len(self.__emit_buffer) >= self.capacity: + if ( + not self.__flush_scheduled_by_emit + and len(self.__emit_buffer) >= self.capacity + ): # Attempt to flush the buffer early if we haven't already done so. We don't bother # calling flush.cancel() and flush.flush_after() when 'self.__flush_event' is # already scheduled to happen as soon as possible to avoid introducing unnecessary @@ -214,8 +219,10 @@ class BufferedHandler(logging.Handler): def _flush_buffer_with_lock(self, buf, close_called): """Ensure all logging output has been flushed.""" - raise NotImplementedError("_flush_buffer_with_lock must be implemented by BufferedHandler" - " subclasses") + raise NotImplementedError( + "_flush_buffer_with_lock must be implemented by BufferedHandler" + " subclasses" + ) def close(self): """Flush the buffer and tidies up any resources used by this handler.""" @@ -269,11 +276,12 @@ class HTTPHandler(object): retry = urllib3_retry.Retry( backoff_factor=0.1, # Enable backoff starting at 0.1s. allowed_methods=False, # Support all HTTP verbs. - status_forcelist=retry_status) + status_forcelist=retry_status, + ) adapter = requests.adapters.HTTPAdapter(max_retries=retry) - self.session.mount('http://', adapter) - self.session.mount('https://', adapter) + self.session.mount("http://", adapter) + self.session.mount("https://", adapter) self.url_root = url_root @@ -298,7 +306,9 @@ class HTTPHandler(object): with warnings.catch_warnings(): if urllib3_exceptions is not None: try: - warnings.simplefilter("ignore", urllib3_exceptions.InsecurePlatformWarning) + warnings.simplefilter( + "ignore", urllib3_exceptions.InsecurePlatformWarning + ) except AttributeError: # Versions of urllib3 prior to 1.10.3 didn't define InsecurePlatformWarning. # Versions of requests prior to 2.6.0 didn't have a vendored copy of urllib3 @@ -306,15 +316,23 @@ class HTTPHandler(object): pass try: - warnings.simplefilter("ignore", urllib3_exceptions.InsecureRequestWarning) + warnings.simplefilter( + "ignore", urllib3_exceptions.InsecureRequestWarning + ) except AttributeError: # Versions of urllib3 prior to 1.9 didn't define InsecureRequestWarning. # Versions of requests prior to 2.4.0 didn't have a vendored copy of urllib3 # that defined InsecureRequestWarning. pass - response = self.session.post(url, data=data, headers=headers, timeout=timeout_secs, - auth=self.auth_handler, verify=True) + response = self.session.post( + url, + data=data, + headers=headers, + timeout=timeout_secs, + auth=self.auth_handler, + verify=True, + ) response.raise_for_status() diff --git a/buildscripts/resmokelib/logging/loggers.py b/buildscripts/resmokelib/logging/loggers.py index 7b66569cf51..cf26cd1a5e7 100644 --- a/buildscripts/resmokelib/logging/loggers.py +++ b/buildscripts/resmokelib/logging/loggers.py @@ -45,7 +45,9 @@ _FIXTURE_LOGGER_REGISTRY: dict = {} # URL of parsley logs. RAW_TEST_LOGS_URL = "https://evergreen.mongodb.com/rest/v2/tasks/{task_id}/build/TestLogs/job{job_num}%2F{test_id}.log?execution={execution}&print_time=true" RAW_JOBS_LOGS_URL = "https://evergreen.mongodb.com/rest/v2/tasks/{task_id}/build/TestLogs/job{job_num}?execution={execution}&print_time=true" -PARSLEY_JOBS_LOGS_URL = "https://parsley.mongodb.com/test/{task_id}/{execution}/job{job_num}/all" +PARSLEY_JOBS_LOGS_URL = ( + "https://parsley.mongodb.com/test/{task_id}/{execution}/job{job_num}/all" +) def _build_logger_server(): @@ -73,10 +75,13 @@ def _setup_redirects(): redirect_cmds.append("mrlog") redirect_cmds.append(["tee", config.USER_FRIENDLY_OUTPUT]) - redirect_cmds.append([ - "grep", "-Ea", - r"Summary of|Running.*\.\.\.|invariant|fassert|BACKTRACE|Invalid access|Workload\(s\) started|Workload\(s\)|WiredTiger error|AddressSanitizer|threads with tids|failed to load|Completed cmd|Completed stepdown" - ]) + redirect_cmds.append( + [ + "grep", + "-Ea", + r"Summary of|Running.*\.\.\.|invariant|fassert|BACKTRACE|Invalid access|Workload\(s\) started|Workload\(s\)|WiredTiger error|AddressSanitizer|threads with tids|failed to load|Completed cmd|Completed stepdown", + ] + ) for idx, redirect in enumerate(redirect_cmds): # The first redirect reads from stdout. Otherwise read from the previous redirect. @@ -94,7 +99,8 @@ def configure_loggers(): # The 'buildlogger' prefix is not added to the fallback logger since the prefix of the original # logger will be there as part of the logged message. buildlogger.BUILDLOGGER_FALLBACK.addHandler( - _fallback_buildlogger_handler(include_logger_name=False)) + _fallback_buildlogger_handler(include_logger_name=False) + ) global BUILDLOGGER_SERVER # pylint: disable=global-statement BUILDLOGGER_SERVER = _build_logger_server() @@ -157,7 +163,8 @@ def new_job_logger(test_kind, job_num) -> logging.Logger: buildlogger.set_log_output_incomplete() raise errors.LoggerRuntimeConfigError( "Encountered an error configuring buildlogger for job #{:d}: Failed to get a" - " new build_id".format(job_num)) + " new build_id".format(job_num) + ) url = BUILDLOGGER_SERVER.get_build_log_url(build_id) ROOT_EXECUTOR_LOGGER.info("Writing output of job #%d to %s.", job_num, url) @@ -239,7 +246,9 @@ def configure_exception_capture(test_logger): return [js_exception, py_exception] -def new_test_logger(test_shortname, test_basename, command, parent, job_num, test_id, job_logger): +def new_test_logger( + test_shortname, test_basename, command, parent, job_num, test_id, job_logger +): """Create a new test logger that will be a child of the given parent.""" name = "%s:%s" % (parent.name, test_shortname) logger = logging.Logger(name) @@ -259,7 +268,8 @@ def new_test_logger(test_shortname, test_basename, command, parent, job_num, tes buildlogger.set_log_output_incomplete() raise errors.LoggerRuntimeConfigError( "Encountered an error configuring buildlogger for test {}: Failed to get a new" - " test_id".format(test_basename)) + " test_id".format(test_basename) + ) url = BUILDLOGGER_SERVER.get_test_log_url(build_id, test_id) parsley_url = BUILDLOGGER_SERVER.get_parsley_log_url(build_id, test_id) @@ -274,8 +284,11 @@ def new_test_logger(test_shortname, test_basename, command, parent, job_num, tes def new_test_thread_logger(parent, test_kind, thread_id, tenant_id=None): """Create a new test thread logger that will be the child of the given parent.""" - name = "%s:%s:%s" % (test_kind, thread_id, tenant_id) if tenant_id else "%s:%s" % (test_kind, - thread_id) + name = ( + "%s:%s:%s" % (test_kind, thread_id, tenant_id) + if tenant_id + else "%s:%s" % (test_kind, thread_id) + ) logger = logging.Logger(name) logger.parent = parent return logger @@ -296,8 +309,9 @@ def _add_handler(logger, handler_info, formatter): """Add non-buildlogger handlers to a logger based on configuration.""" handler_class = handler_info["class"] if handler_class == "logging.FileHandler": - handler = logging.FileHandler(filename=handler_info["filename"], mode=handler_info.get( - "mode", "w")) + handler = logging.FileHandler( + filename=handler_info["filename"], mode=handler_info.get("mode", "w") + ) elif handler_class == "logging.NullHandler": handler = logging.NullHandler() elif handler_class == "logging.StreamHandler": @@ -319,7 +333,9 @@ def _add_build_logger_handler(logger, job_num, test_id=None): handler_info = _get_buildlogger_handler_info(logger_info) if handler_info is not None: if test_id is not None: - handler = BUILDLOGGER_SERVER.get_test_handler(build_id, test_id, handler_info) + handler = BUILDLOGGER_SERVER.get_test_handler( + build_id, test_id, handler_info + ) else: handler = BUILDLOGGER_SERVER.get_global_handler(build_id, handler_info) handler.setFormatter(_get_formatter(logger_info)) @@ -410,12 +426,17 @@ def _add_evergreen_handler(logger, job_num, test_id=None, test_name=None): break if evergreen_handler_info: - fp = f"{_get_evergreen_log_dirname()}/{get_evergreen_log_name(job_num, test_id)}" + fp = ( + f"{_get_evergreen_log_dirname()}/{get_evergreen_log_name(job_num, test_id)}" + ) os.makedirs(os.path.dirname(fp), exist_ok=True) handler = BufferedFileHandler(fp) handler.setFormatter( - formatters.EvergreenLogFormatter(fmt=logger_info.get("format", _DEFAULT_FORMAT))) + formatters.EvergreenLogFormatter( + fmt=logger_info.get("format", _DEFAULT_FORMAT) + ) + ) logger.addHandler(handler) if test_id: @@ -426,7 +447,9 @@ def _add_evergreen_handler(logger, job_num, test_id=None, test_name=None): execution=config.EVERGREEN_EXECUTION, ) ROOT_EXECUTOR_LOGGER.info("Writing output of %s to %s.", test_id, fp) - ROOT_EXECUTOR_LOGGER.info("Raw logs for %s can be viewed at %s", test_name, raw_url) + ROOT_EXECUTOR_LOGGER.info( + "Raw logs for %s can be viewed at %s", test_name, raw_url + ) else: parsley_url = PARSLEY_JOBS_LOGS_URL.format( task_id=config.EVERGREEN_TASK_ID, @@ -439,9 +462,12 @@ def _add_evergreen_handler(logger, job_num, test_id=None, test_name=None): execution=config.EVERGREEN_EXECUTION, ) ROOT_EXECUTOR_LOGGER.info("Writing output of job #%d to %s.", job_num, fp) - ROOT_EXECUTOR_LOGGER.info("Parsley logs for job #%s can be viewed at %s", job_num, - parsley_url) - ROOT_EXECUTOR_LOGGER.info("Raw logs for job #%s can be viewed at %s", job_num, raw_url) + ROOT_EXECUTOR_LOGGER.info( + "Parsley logs for job #%s can be viewed at %s", job_num, parsley_url + ) + ROOT_EXECUTOR_LOGGER.info( + "Raw logs for job #%s can be viewed at %s", job_num, raw_url + ) def _get_evergreen_log_dirname(): diff --git a/buildscripts/resmokelib/mongo_fuzzer_configs.py b/buildscripts/resmokelib/mongo_fuzzer_configs.py index 004e085862f..a72f3892460 100644 --- a/buildscripts/resmokelib/mongo_fuzzer_configs.py +++ b/buildscripts/resmokelib/mongo_fuzzer_configs.py @@ -12,10 +12,13 @@ def generate_eviction_configs(rng, mode): eviction_trigger = rng.randint(eviction_target + 1, 99) # Fuzz eviction_dirty_target and trigger both as relative and absolute values - target_bytes_min = 50 * 1024 * 1024 # 50MB # 5% of 1GB default cache size on Evergreen + target_bytes_min = ( + 50 * 1024 * 1024 + ) # 50MB # 5% of 1GB default cache size on Evergreen target_bytes_max = 256 * 1024 * 1024 # 256MB # 1GB default cache size on Evergreen eviction_dirty_target = rng.choice( - [rng.randint(5, 50), rng.randint(target_bytes_min, target_bytes_max)]) + [rng.randint(5, 50), rng.randint(target_bytes_min, target_bytes_max)] + ) trigger_max = 75 if eviction_dirty_target <= 50 else target_bytes_max eviction_dirty_trigger = rng.randint(eviction_dirty_target + 1, trigger_max) @@ -26,9 +29,13 @@ def generate_eviction_configs(rng, mode): # values of the corresponding eviction dirty target and trigger. They need to stay less than the # dirty equivalents. The default updates target is 2.5% of the cache, so let's start fuzzing # from 2%. - updates_target_min = 2 if eviction_dirty_target <= 100 else 20 * 1024 * 1024 # 2% of 1GB cache + updates_target_min = ( + 2 if eviction_dirty_target <= 100 else 20 * 1024 * 1024 + ) # 2% of 1GB cache eviction_updates_target = rng.randint(updates_target_min, eviction_dirty_target - 1) - eviction_updates_trigger = rng.randint(eviction_updates_target + 1, eviction_dirty_trigger - 1) + eviction_updates_trigger = rng.randint( + eviction_updates_target + 1, eviction_dirty_trigger - 1 + ) # Fuzz File manager settings close_idle_time_secs = rng.randint(1, 100) @@ -41,33 +48,37 @@ def generate_eviction_configs(rng, mode): # realloc_exact - Finds more memory bugs by allocating the memory for the exact size asked # rollback_error - Forces WiredTiger to return a rollback error every Nth call # slow_checkpoint - Adds internal delays in processing internal leaf pages during a checkpoint - dbg_eviction = rng.choice(['true', 'false']) - dbg_realloc_exact = rng.choice(['true', 'false']) + dbg_eviction = rng.choice(["true", "false"]) + dbg_realloc_exact = rng.choice(["true", "false"]) # Rollback every Nth transaction. The values have been tuned after looking at how many # WiredTiger transactions happen per second for the config-fuzzed jstests. # The setting is trigerring bugs, disabled until they get resolved. # dbg_rollback_error = rng.choice([0, rng.randint(250, 1500)]) dbg_rollback_error = 0 - dbg_slow_checkpoint = 'false' if mode != 'stress' else rng.choice(['true', 'false']) + dbg_slow_checkpoint = "false" if mode != "stress" else rng.choice(["true", "false"]) - return "debug_mode=(eviction={0},realloc_exact={1},rollback_error={2}, slow_checkpoint={3}),"\ - "eviction_checkpoint_target={4},eviction_dirty_target={5},eviction_dirty_trigger={6},"\ - "eviction_target={7},eviction_trigger={8},eviction_updates_target={9},"\ - "eviction_updates_trigger={10},file_manager=(close_handle_minimum={11},"\ - "close_idle_time={12},close_scan_interval={13})".format(dbg_eviction, - dbg_realloc_exact, - dbg_rollback_error, - dbg_slow_checkpoint, - eviction_checkpoint_target, - eviction_dirty_target, - eviction_dirty_trigger, - eviction_target, - eviction_trigger, - eviction_updates_target, - eviction_updates_trigger, - close_handle_minimum, - close_idle_time_secs, - close_scan_interval) + return ( + "debug_mode=(eviction={0},realloc_exact={1},rollback_error={2}, slow_checkpoint={3})," + "eviction_checkpoint_target={4},eviction_dirty_target={5},eviction_dirty_trigger={6}," + "eviction_target={7},eviction_trigger={8},eviction_updates_target={9}," + "eviction_updates_trigger={10},file_manager=(close_handle_minimum={11}," + "close_idle_time={12},close_scan_interval={13})".format( + dbg_eviction, + dbg_realloc_exact, + dbg_rollback_error, + dbg_slow_checkpoint, + eviction_checkpoint_target, + eviction_dirty_target, + eviction_dirty_trigger, + eviction_target, + eviction_trigger, + eviction_updates_target, + eviction_updates_trigger, + close_handle_minimum, + close_idle_time_secs, + close_scan_interval, + ) + ) def generate_table_configs(rng): @@ -80,21 +91,28 @@ def generate_table_configs(rng): memory_page_max_lower_bound = leaf_page_max # Assume WT cache size of 1GB as most MDB tests specify this as the cache size. memory_page_max_upper_bound = round( - (rng.randint(256, 1024) * 1024 * 1024) / 10) # cache_size / 10 - memory_page_max = rng.randint(memory_page_max_lower_bound, memory_page_max_upper_bound) + (rng.randint(256, 1024) * 1024 * 1024) / 10 + ) # cache_size / 10 + memory_page_max = rng.randint( + memory_page_max_lower_bound, memory_page_max_upper_bound + ) split_pct = rng.choice([50, 60, 75, 100]) prefix_compression = rng.choice(["true", "false"]) block_compressor = rng.choice(["none", "snappy", "zlib", "zstd"]) - return "block_compressor={0},internal_page_max={1},leaf_page_max={2},leaf_value_max={3},"\ - "memory_page_max={4},prefix_compression={5},split_pct={6}".format(block_compressor, - internal_page_max, - leaf_page_max, - leaf_value_max, - memory_page_max, - prefix_compression, - split_pct) + return ( + "block_compressor={0},internal_page_max={1},leaf_page_max={2},leaf_value_max={3}," + "memory_page_max={4},prefix_compression={5},split_pct={6}".format( + block_compressor, + internal_page_max, + leaf_page_max, + leaf_value_max, + memory_page_max, + prefix_compression, + split_pct, + ) + ) def generate_flow_control_parameters(rng): @@ -126,8 +144,9 @@ def generate_mongod_parameters(rng, mode): # ret["lockCodeSegmentsInMemory"] = rng.choice([True, False]) if not ret["disableLogicalSessionCacheRefresh"]: ret["logicalSessionRefreshMillis"] = rng.choice([100, 1000, 10000, 100000]) - ret["maxNumberOfTransactionOperationsInSingleOplogEntry"] = rng.randint(1, 10) * rng.choice( - [1, 10, 100]) + ret["maxNumberOfTransactionOperationsInSingleOplogEntry"] = rng.randint( + 1, 10 + ) * rng.choice([1, 10, 100]) ret["minSnapshotHistoryWindowInSeconds"] = rng.choice([300, rng.randint(30, 600)]) ret["mirrorReads"] = {"samplingRate": rng.random()} ret["queryAnalysisWriterMaxMemoryUsageBytes"] = rng.randint(1, 100) * 1024 * 1024 @@ -135,20 +154,25 @@ def generate_mongod_parameters(rng, mode): ret["wiredTigerCursorCacheSize"] = rng.randint(-100, 100) ret["wiredTigerSessionCloseIdleTimeSecs"] = rng.randint(0, 300) ret["storageEngineConcurrencyAdjustmentAlgorithm"] = rng.choices( - ["throughputProbing", "fixedConcurrentTransactions"], weights=[10, 1])[0] + ["throughputProbing", "fixedConcurrentTransactions"], weights=[10, 1] + )[0] ret["storageEngineConcurrencyAdjustmentIntervalMillis"] = rng.randint(10, 1000) ret["throughputProbingStepMultiple"] = rng.uniform(0.1, 0.5) ret["throughputProbingInitialConcurrency"] = rng.randint(4, 128) - ret["throughputProbingMinConcurrency"] = rng.randint(4, - ret["throughputProbingInitialConcurrency"]) - ret["throughputProbingMaxConcurrency"] = rng.randint(ret["throughputProbingInitialConcurrency"], - 128) + ret["throughputProbingMinConcurrency"] = rng.randint( + 4, ret["throughputProbingInitialConcurrency"] + ) + ret["throughputProbingMaxConcurrency"] = rng.randint( + ret["throughputProbingInitialConcurrency"], 128 + ) ret["throughputProbingReadWriteRatio"] = rng.uniform(0, 1) ret["throughputProbingConcurrencyMovingAverageWeight"] = 1 - rng.random() ret["wiredTigerConcurrentWriteTransactions"] = rng.randint(5, 32) ret["wiredTigerConcurrentReadTransactions"] = rng.randint(5, 32) - ret["wiredTigerStressConfig"] = False if mode != 'stress' else rng.choice([True, False]) + ret["wiredTigerStressConfig"] = ( + False if mode != "stress" else rng.choice([True, False]) + ) ret["wiredTigerSizeStorerPeriodicSyncHits"] = rng.randint(1, 100000) ret["wiredTigerSizeStorerPeriodicSyncPeriodMillis"] = rng.randint(1, 60000) @@ -166,12 +190,14 @@ def generate_mongod_parameters(rng, mode): # because the generated mongod parameters are used for every node in the replica set, so the # secondaries in the replica set will not be able to find a valid sync source. ret["initialSyncSourceReadPreference"] = rng.choice( - ["nearest", "primary", "primaryPreferred", "secondaryPreferred"]) + ["nearest", "primary", "primaryPreferred", "secondaryPreferred"] + ) ret["initialSyncMethod"] = rng.choice(["fileCopyBased", "logical"]) # Query parameters - ret["internalQueryExecYieldIterations"] = rng.choices([1, rng.randint(1, 1000)], - weights=[1, 10])[0] + ret["internalQueryExecYieldIterations"] = rng.choices( + [1, rng.randint(1, 1000)], weights=[1, 10] + )[0] ret["internalQueryExecYieldPeriodMS"] = rng.randint(1, 100) # We need a higher timeout to account for test slowness @@ -202,7 +228,10 @@ def fuzz_mongod_set_parameters(mode, seed, user_provided_params): rng = random.Random(seed) ret = {} - params = [generate_flow_control_parameters(rng), generate_mongod_parameters(rng, mode)] + params = [ + generate_flow_control_parameters(rng), + generate_mongod_parameters(rng, mode), + ] for dct in params: for key, value in dct.items(): ret[key] = value @@ -210,8 +239,12 @@ def fuzz_mongod_set_parameters(mode, seed, user_provided_params): for key, value in utils.load_yaml(user_provided_params).items(): ret[key] = value - return utils.dump_yaml(ret), generate_eviction_configs(rng, mode), generate_table_configs(rng), \ - generate_table_configs(rng) + return ( + utils.dump_yaml(ret), + generate_eviction_configs(rng, mode), + generate_table_configs(rng), + generate_table_configs(rng), + ) def fuzz_mongos_set_parameters(mode, seed, user_provided_params): diff --git a/buildscripts/resmokelib/multiversion/__init__.py b/buildscripts/resmokelib/multiversion/__init__.py index 1e3fdfa1581..ced98e7fea4 100644 --- a/buildscripts/resmokelib/multiversion/__init__.py +++ b/buildscripts/resmokelib/multiversion/__init__.py @@ -1,4 +1,5 @@ """Subcommand for multiversion config.""" + import argparse from typing import List, Optional @@ -57,9 +58,14 @@ class MultiversionConfigSubcommand(Subcommand): def determine_multiversion_config() -> MultiversionConfig: """Discover the current multiversion configuration.""" from buildscripts.resmokelib import multiversionconstants + multiversion_service = MultiversionService( - mongo_version=MongoVersion.from_yaml_file(multiversionconstants.MONGO_VERSION_YAML), - mongo_releases=MongoReleases.from_yaml_file(multiversionconstants.RELEASES_YAML), + mongo_version=MongoVersion.from_yaml_file( + multiversionconstants.MONGO_VERSION_YAML + ), + mongo_releases=MongoReleases.from_yaml_file( + multiversionconstants.RELEASES_YAML + ), ) version_constants = multiversion_service.calculate_version_constants() return MultiversionConfig( @@ -81,14 +87,27 @@ class MultiversionPlugin(PluginInterface): :param subparsers: argparse subparsers """ - parser = subparsers.add_parser(MULTIVERSION_SUBCOMMAND, - help="Display configuration for multiversion testing") + parser = subparsers.add_parser( + MULTIVERSION_SUBCOMMAND, + help="Display configuration for multiversion testing", + ) - parser.add_argument("--config-file-output", '-f', action="store", type=str, default=None, - help="File to write the multiversion config to.") + parser.add_argument( + "--config-file-output", + "-f", + action="store", + type=str, + default=None, + help="File to write the multiversion config to.", + ) - def parse(self, subcommand: str, parser: argparse.ArgumentParser, - parsed_args: argparse.Namespace, **kwargs) -> Optional[Subcommand]: + def parse( + self, + subcommand: str, + parser: argparse.ArgumentParser, + parsed_args: argparse.Namespace, + **kwargs, + ) -> Optional[Subcommand]: """ Resolve command-line options to a Subcommand or None. diff --git a/buildscripts/resmokelib/multiversion/multiversion_service.py b/buildscripts/resmokelib/multiversion/multiversion_service.py index 1703b31d782..65e0944a045 100644 --- a/buildscripts/resmokelib/multiversion/multiversion_service.py +++ b/buildscripts/resmokelib/multiversion/multiversion_service.py @@ -1,4 +1,5 @@ """A service for working with multiversion testing.""" + from __future__ import annotations import re @@ -13,7 +14,7 @@ from pydantic import BaseModel, Field # These values must match the include paths for artifacts.tgz in evergreen.yml. MONGO_VERSION_YAML = ".resmoke_mongo_version.yml" RELEASES_YAML = ".resmoke_mongo_release_values.yml" -VERSION_RE = re.compile(r'^[0-9]+\.[0-9]+') +VERSION_RE = re.compile(r"^[0-9]+\.[0-9]+") LOGGER = structlog.getLogger(__name__) @@ -126,7 +127,7 @@ class MongoVersion(BaseModel): :param yaml_file: Path to yaml file. :return: MongoVersion read from file. """ - mongo_version_yml_file = open(yaml_file, 'r') + mongo_version_yml_file = open(yaml_file, "r") return cls(**yaml.safe_load(mongo_version_yml_file)) def get_version(self) -> Version: @@ -134,7 +135,8 @@ class MongoVersion(BaseModel): version_match = VERSION_RE.match(self.mongo_version) if version_match is None: raise ValueError( - f"Could not determine version from mongo version string '{self.mongo_version}'") + f"Could not determine version from mongo version string '{self.mongo_version}'" + ) return Version(version_match.group(0)) @@ -149,11 +151,14 @@ class MongoReleases(BaseModel): LTS. """ - feature_compatibility_versions: List[str] = Field(alias="featureCompatibilityVersions") + feature_compatibility_versions: List[str] = Field( + alias="featureCompatibilityVersions" + ) long_term_support_releases: List[str] = Field(alias="longTermSupportReleases") eol_versions: List[str] = Field(alias="eolVersions") - generate_fcv_lower_bound_override: Optional[str] = Field(None, - alias="generateFCVLowerBoundOverride") + generate_fcv_lower_bound_override: Optional[str] = Field( + None, alias="generateFCVLowerBoundOverride" + ) @classmethod def from_yaml_file(cls, yaml_file: str) -> MongoReleases: @@ -164,16 +169,18 @@ class MongoReleases(BaseModel): :return: MongoReleases read from file. """ - with open(yaml_file, 'r') as mongo_releases_file: + with open(yaml_file, "r") as mongo_releases_file: yaml_contents = mongo_releases_file.read() safe_load_result = yaml.safe_load(yaml_contents) try: return cls(**safe_load_result) except: - LOGGER.info("MongoReleases.from_yaml_file() failed\n" - f"yaml_file = {yaml_file}\n" - f"yaml_contents = {yaml_contents}\n" - f"safe_load_result = {safe_load_result}") + LOGGER.info( + "MongoReleases.from_yaml_file() failed\n" + f"yaml_file = {yaml_file}\n" + f"yaml_contents = {yaml_contents}\n" + f"safe_load_result = {safe_load_result}" + ) raise def get_fcv_versions(self) -> List[Version]: @@ -192,7 +199,9 @@ class MongoReleases(BaseModel): class MultiversionService: """A service for working with multiversion information.""" - def __init__(self, mongo_version: MongoVersion, mongo_releases: MongoReleases) -> None: + def __init__( + self, mongo_version: MongoVersion, mongo_releases: MongoReleases + ) -> None: """ Initialize the service. @@ -216,11 +225,13 @@ class MultiversionService: last_lts = lts[bisect_left(lts, latest) - 1] # All FCVs greater than last LTS, up to latest. - requires_fcv_tag_list = fcvs[bisect_right(fcvs, last_lts):bisect_right(fcvs, latest)] + requires_fcv_tag_list = fcvs[ + bisect_right(fcvs, last_lts) : bisect_right(fcvs, latest) + ] requires_fcv_tag_list_continuous = [latest] # All FCVs less than latest. - fcvs_less_than_latest = fcvs[:bisect_left(fcvs, latest)] + fcvs_less_than_latest = fcvs[: bisect_left(fcvs, latest)] return VersionConstantValues( latest=latest, diff --git a/buildscripts/resmokelib/multiversionconstants.py b/buildscripts/resmokelib/multiversionconstants.py index e1f30e511b3..2d6d4fa625a 100644 --- a/buildscripts/resmokelib/multiversionconstants.py +++ b/buildscripts/resmokelib/multiversionconstants.py @@ -1,4 +1,5 @@ """FCV and Server binary version constants used for multiversion testing.""" + import http import os import shutil @@ -37,10 +38,12 @@ def generate_mongo_version_file(): try: res = check_output("git describe", shell=True, text=True) except CalledProcessError as exp: - raise ChildProcessError("Failed to run git describe to get the latest tag") from exp + raise ChildProcessError( + "Failed to run git describe to get the latest tag" + ) from exp # Write the current MONGO_VERSION to a data file. - with open(MONGO_VERSION_YAML, 'w') as mongo_version_fh: + with open(MONGO_VERSION_YAML, "w") as mongo_version_fh: # E.g. res = 'r5.1.0-alpha-597-g8c345c6693\n' res = res[1:] # Remove the leading "r" character. mongo_version_fh.write("mongo_version: " + res) @@ -55,11 +58,14 @@ def get_releases_file_from_remote(): if response.status_code != http.HTTPStatus.OK: raise RuntimeError( f"Fetching releases.yml file returned unsuccessful status: {response.status_code}, " - f"response body: {response.text}\n") + f"response body: {response.text}\n" + ) file.write(response.content) LOGGER.info(f"Got releases.yml file remotely: {MASTER_RELEASES_REMOTE_FILE}") except Exception as exc: - LOGGER.warning(f"Could not get releases.yml file remotely: {MASTER_RELEASES_REMOTE_FILE}") + LOGGER.warning( + f"Could not get releases.yml file remotely: {MASTER_RELEASES_REMOTE_FILE}" + ) raise exc @@ -69,7 +75,9 @@ def get_releases_file_locally_or_fallback_to_remote(): LOGGER.info(f"Found releases.yml file locally: {RELEASES_LOCAL_FILE}") shutil.copyfile(RELEASES_LOCAL_FILE, RELEASES_YAML) else: - LOGGER.warning(f"Could not find releases.yml file locally: {RELEASES_LOCAL_FILE}") + LOGGER.warning( + f"Could not find releases.yml file locally: {RELEASES_LOCAL_FILE}" + ) get_releases_file_from_remote() @@ -88,7 +96,8 @@ def in_git_root_dir(): return False git_root_dir = os.path.realpath( - check_output("git rev-parse --show-toplevel", shell=True, text=True).strip()) + check_output("git rev-parse --show-toplevel", shell=True, text=True).strip() + ) curr_dir = os.path.realpath(os.getcwd()) return git_root_dir == curr_dir @@ -96,7 +105,9 @@ def in_git_root_dir(): if in_git_root_dir(): generate_mongo_version_file() else: - LOGGER.info("Skipping generating mongo version file since we're not in the root of a git repo") + LOGGER.info( + "Skipping generating mongo version file since we're not in the root of a git repo" + ) # Avoiding regenerating the releases file if this flag is set. Should only be set if there are # multiple processes attempting to set up multiversion concurrently. @@ -104,12 +115,13 @@ if not USE_EXISTING_RELEASES_FILE: generate_releases_file() else: LOGGER.info( - "Skipping generating releases file since the --useExistingReleasesFile flag has been set") + "Skipping generating releases file since the --useExistingReleasesFile flag has been set" + ) def evg_project_str(version): """Return the evergreen project name for the given version.""" - return 'mongodb-mongo-v{}.{}'.format(version.major, version.minor) + return "mongodb-mongo-v{}.{}".format(version.major, version.minor) multiversion_service = MultiversionService( @@ -143,13 +155,17 @@ REQUIRES_FCV_TAG = version_constants.get_fcv_tag_list() REQUIRES_FCV_TAGS_LESS_THAN_LATEST = version_constants.get_fcv_tags_less_than_latest() # Generate evergreen project names for all FCVs less than latest. -EVERGREEN_PROJECTS = ['mongodb-mongo-master'] -EVERGREEN_PROJECTS.extend([evg_project_str(fcv) for fcv in version_constants.fcvs_less_than_latest]) +EVERGREEN_PROJECTS = ["mongodb-mongo-master"] +EVERGREEN_PROJECTS.extend( + [evg_project_str(fcv) for fcv in version_constants.fcvs_less_than_latest] +) -OLD_VERSIONS = [ - LAST_LTS -] if LAST_CONTINUOUS_FCV == LAST_LTS_FCV or LAST_CONTINUOUS_FCV in version_constants.get_eols( -) else [LAST_LTS, LAST_CONTINUOUS] +OLD_VERSIONS = ( + [LAST_LTS] + if LAST_CONTINUOUS_FCV == LAST_LTS_FCV + or LAST_CONTINUOUS_FCV in version_constants.get_eols() + else [LAST_LTS, LAST_CONTINUOUS] +) def log_constants(exec_log): diff --git a/buildscripts/resmokelib/parser.py b/buildscripts/resmokelib/parser.py index 0b755e37fe4..7d270141e6f 100644 --- a/buildscripts/resmokelib/parser.py +++ b/buildscripts/resmokelib/parser.py @@ -31,11 +31,18 @@ def get_parser(usage=None): """Get the resmoke parser.""" parser = argparse.ArgumentParser(usage=usage) subparsers = parser.add_subparsers(dest="command") - parser.add_argument("--configDir", dest="config_dir", metavar="CONFIG_DIR", - help="Directory to search for resmoke configuration files") parser.add_argument( - "--jstestsDir", dest="jstests_dir", metavar="CONFIG_DIR", - help="Directory to search for jstests files existence while suite validation") + "--configDir", + dest="config_dir", + metavar="CONFIG_DIR", + help="Directory to search for resmoke configuration files", + ) + parser.add_argument( + "--jstestsDir", + dest="jstests_dir", + metavar="CONFIG_DIR", + help="Directory to search for jstests files existence while suite validation", + ) # Add sub-commands. for plugin in _PLUGINS: @@ -64,10 +71,12 @@ def parse_command_line(sys_args, usage=None, **kwargs): if subcommand_obj is not None: return subcommand_obj - raise RuntimeError(f"Resmoke configuration has invalid subcommand: {subcommand}. Try '--help'") + raise RuntimeError( + f"Resmoke configuration has invalid subcommand: {subcommand}. Try '--help'" + ) -def set_run_options(argstr=''): +def set_run_options(argstr=""): """Populate the config module variables for the 'run' subcommand with the default options.""" - parser, parsed_args = parse(['run'] + shlex.split(argstr)) + parser, parsed_args = parse(["run"] + shlex.split(argstr)) configure_resmoke.validate_and_update_config(parser, parsed_args) diff --git a/buildscripts/resmokelib/plugin.py b/buildscripts/resmokelib/plugin.py index 75a3ff9f317..5d62306be6e 100644 --- a/buildscripts/resmokelib/plugin.py +++ b/buildscripts/resmokelib/plugin.py @@ -8,7 +8,9 @@ class Subcommand(object): def execute(self): """Execute the subcommand.""" - raise NotImplementedError("execute must be implemented by Subcommand subclasses") + raise NotImplementedError( + "execute must be implemented by Subcommand subclasses" + ) class PluginInterface(abc.ABC): diff --git a/buildscripts/resmokelib/powercycle/__init__.py b/buildscripts/resmokelib/powercycle/__init__.py index f2c027921bc..d75fc6240eb 100644 --- a/buildscripts/resmokelib/powercycle/__init__.py +++ b/buildscripts/resmokelib/powercycle/__init__.py @@ -7,6 +7,7 @@ Client & server side powercycle test script. This script is used in conjunction with certain Evergreen hosts created with the `evergreen host create` command. """ + import argparse from buildscripts.resmokelib.plugin import PluginInterface, Subcommand @@ -48,9 +49,10 @@ class Powercycle(Subcommand): def execute(self): """Execute powercycle test.""" return { - self.RUN: self._exec_powercycle_main, self.HOST_SETUP: self._exec_powercycle_host_setup, + self.RUN: self._exec_powercycle_main, + self.HOST_SETUP: self._exec_powercycle_host_setup, self.SAVE_DIAG: self._exec_powercycle_save_diagnostics, - self.REMOTE_HANG_ANALYZER: self._exec_powercycle_hang_analyzer + self.REMOTE_HANG_ANALYZER: self._exec_powercycle_hang_analyzer, }[self.options.run_option]() def _exec_powercycle_main(self): @@ -62,7 +64,6 @@ class Powercycle(Subcommand): @staticmethod def _exec_powercycle_save_diagnostics(): - # The event logs on Windows are a useful diagnostic to have when determining if something bad # happened to the remote machine after it was repeatedly crashed during powercycle testing. For # example, the Application and System event logs have previously revealed that the mongod.exe @@ -92,24 +93,29 @@ class PowercyclePlugin(PluginInterface): """Add sub-subcommands for powercycle.""" sub_parsers = parent_parser.add_subparsers() - setup_parser = sub_parsers.add_parser("setup-host", - help="Step 1. Set up the host for powercycle") + setup_parser = sub_parsers.add_parser( + "setup-host", help="Step 1. Set up the host for powercycle" + ) setup_parser.set_defaults(run_option=Powercycle.HOST_SETUP) run_parser = sub_parsers.add_parser( - "run", help="Step 2. Run the Powercycle test of your choice;" - "search for 'powercycle invocation' in evg task logs") + "run", + help="Step 2. Run the Powercycle test of your choice;" + "search for 'powercycle invocation' in evg task logs", + ) run_parser.set_defaults(run_option=Powercycle.RUN) save_parser = sub_parsers.add_parser( "save-diagnostics", help="Copy Powercycle diagnostics to local machine; mainly used by Evergreen. For" - "local invocation, consider instead ssh-ing into the Powercycle host directly") + "local invocation, consider instead ssh-ing into the Powercycle host directly", + ) save_parser.set_defaults(run_option=Powercycle.SAVE_DIAG) save_parser = sub_parsers.add_parser( "remote-hang-analyzer", - help="Run the hang analyzer on the remote machine; mainly used by Evergreen") + help="Run the hang analyzer on the remote machine; mainly used by Evergreen", + ) save_parser.set_defaults(run_option=Powercycle.REMOTE_HANG_ANALYZER) # Only need to return run_parser for further processing; others don't need additional args. @@ -118,7 +124,9 @@ class PowercyclePlugin(PluginInterface): def add_subcommand(self, subparsers): """Create and add the parser for the subcommand.""" intermediate_parser = subparsers.add_parser( - SUBCOMMAND, help=__doc__, usage=""" + SUBCOMMAND, + help=__doc__, + usage=""" MongoDB Powercycle Tests. To run a powercycle test locally, use the following steps: 1. Spin up an Evergreen spawnhost or virtual workstation that supports running @@ -142,7 +150,8 @@ MongoDB Powercycle Tests. To run a powercycle test locally, use the following st 5. You're ready to run powercycle! See the help message for individual subcommands for more detail. - """) + """, + ) parser = self._add_powercycle_commands(intermediate_parser) @@ -152,61 +161,97 @@ MongoDB Powercycle Tests. To run a powercycle test locally, use the following st program_options = parser.add_argument_group("Program Options") # Test options - test_options.add_argument("--sshUserHost", dest="ssh_user_host", - help="Remote server ssh user/host, i.e., user@host (REQUIRED)", - required=True) + test_options.add_argument( + "--sshUserHost", + dest="ssh_user_host", + help="Remote server ssh user/host, i.e., user@host (REQUIRED)", + required=True, + ) test_options.add_argument( - "--sshConnection", dest="ssh_connection_options", + "--sshConnection", + dest="ssh_connection_options", help="Remote server ssh additional connection options, i.e., '-i ident.pem'" - " which are added to '{}'".format(powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS), - default=None) + " which are added to '{}'".format( + powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS + ), + default=None, + ) test_options.add_argument( - "--taskName", dest="task_name", + "--taskName", + dest="task_name", help=f"Powercycle task name. Based on this value additional" f" config values will be used from '{powercycle_config.POWERCYCLE_TASKS_CONFIG}'." - f" [default: '%(default)s']", default="powercycle") + f" [default: '%(default)s']", + default="powercycle", + ) - test_options.add_argument("--sshAccessRetryCount", dest="ssh_access_retry_count", - help=argparse.SUPPRESS, type=int, default=5) + test_options.add_argument( + "--sshAccessRetryCount", + dest="ssh_access_retry_count", + help=argparse.SUPPRESS, + type=int, + default=5, + ) # MongoDB options mongodb_options.add_argument( - "--downloadUrl", dest="tarball_url", - help="URL of tarball to test, if unspecifed latest tarball will be" - " used", default="latest") + "--downloadUrl", + dest="tarball_url", + help="URL of tarball to test, if unspecifed latest tarball will be" " used", + default="latest", + ) # mongod options # The current host used to start and connect to mongod. Not meant to be specified # by the user. - mongod_options.add_argument("--mongodHost", dest="host", help=argparse.SUPPRESS, - default=None) + mongod_options.add_argument( + "--mongodHost", dest="host", help=argparse.SUPPRESS, default=None + ) # The current port used to start and connect to mongod. Not meant to be specified # by the user. - mongod_options.add_argument("--mongodPort", dest="port", help=argparse.SUPPRESS, type=int, - default=None) + mongod_options.add_argument( + "--mongodPort", dest="port", help=argparse.SUPPRESS, type=int, default=None + ) # Program options log_levels = ["debug", "info", "warning", "error"] program_options.add_argument( - "--logLevel", dest="log_level", choices=log_levels, + "--logLevel", + dest="log_level", + choices=log_levels, help="The log level. Accepted values are: {}." - " [default: '%(default)s'].".format(log_levels), default="info") + " [default: '%(default)s'].".format(log_levels), + default="info", + ) program_options.add_argument( - "--logFile", dest="log_file", - help="The destination file for the log output. Defaults to stdout.", default=None) + "--logFile", + dest="log_file", + help="The destination file for the log output. Defaults to stdout.", + default=None, + ) # Remote options, include commands and options sent from client to server under test. # These are 'internal' options, not meant to be directly specifed. # More than one remote operation can be provided and they are specified in the program args. - program_options.add_argument("--remoteOperation", dest="remote_operation", - help=argparse.SUPPRESS, action="store_true", default=False) + program_options.add_argument( + "--remoteOperation", + dest="remote_operation", + help=argparse.SUPPRESS, + action="store_true", + default=False, + ) - program_options.add_argument("--rsyncDest", dest="rsync_dest", nargs=2, - help=argparse.SUPPRESS, default=None) + program_options.add_argument( + "--rsyncDest", + dest="rsync_dest", + nargs=2, + help=argparse.SUPPRESS, + default=None, + ) parser.add_argument("remote_operations", nargs="*", help=argparse.SUPPRESS) diff --git a/buildscripts/resmokelib/powercycle/lib/__init__.py b/buildscripts/resmokelib/powercycle/lib/__init__.py index 043d8d317e0..085d03536b2 100644 --- a/buildscripts/resmokelib/powercycle/lib/__init__.py +++ b/buildscripts/resmokelib/powercycle/lib/__init__.py @@ -1,4 +1,5 @@ """Library functions for powercycle.""" + import getpass import logging import os @@ -24,7 +25,9 @@ class PowercycleCommand(Subcommand): def __init__(self): """Initialize PowercycleCommand.""" self.expansions = yaml.safe_load(open(powercycle_constants.EXPANSIONS_FILE)) - self.ssh_connection_options = f"-i powercycle.pem {powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS}" + self.ssh_connection_options = ( + f"-i powercycle.pem {powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS}" + ) self.sudo = "" if self.is_windows() else "sudo" # The username on the Windows image that powercycle uses is currently the default user. self.user = "Administrator" if self.is_windows() else getpass.getuser() @@ -44,7 +47,9 @@ class PowercycleCommand(Subcommand): def _call(cmd): cmd = shlex.split(cmd) # Use a common pipe for stdout & stderr for logging. - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT + ) buff_stdout, _ = process.communicate() buff = buff_stdout.decode("utf-8", "replace") return process.poll(), buff diff --git a/buildscripts/resmokelib/powercycle/lib/named_temp_file.py b/buildscripts/resmokelib/powercycle/lib/named_temp_file.py index cb5e67fbb99..953e0ba9b38 100644 --- a/buildscripts/resmokelib/powercycle/lib/named_temp_file.py +++ b/buildscripts/resmokelib/powercycle/lib/named_temp_file.py @@ -1,4 +1,5 @@ """Wrapper for the NamedTempFile class.""" + import logging import os import shutil @@ -20,8 +21,9 @@ class NamedTempFile(object): LOGGER.debug("Creating temporary directory %s", directory) os.makedirs(directory) cls._DIR_LIST.append(directory) - temp_file = tempfile.NamedTemporaryFile(mode="w+", newline=newline, suffix=suffix, - dir=directory, delete=False) + temp_file = tempfile.NamedTemporaryFile( + mode="w+", newline=newline, suffix=suffix, dir=directory, delete=False + ) cls._FILE_MAP[temp_file.name] = temp_file return temp_file.name @@ -44,7 +46,9 @@ class NamedTempFile(object): try: os.remove(name) except (IOError, OSError) as err: - LOGGER.warning("Unable to delete temporary file %s with error %s", name, err) + LOGGER.warning( + "Unable to delete temporary file %s with error %s", name, err + ) if not os.path.exists(name): del cls._FILE_MAP[name] @@ -60,7 +64,9 @@ class NamedTempFile(object): try: shutil.rmtree(directory) except (IOError, OSError) as err: - LOGGER.warning("Unable to delete temporary directory %s with error %s", directory, err) + LOGGER.warning( + "Unable to delete temporary directory %s with error %s", directory, err + ) if not os.path.exists(directory): cls._DIR_LIST.remove(directory) diff --git a/buildscripts/resmokelib/powercycle/lib/process_control.py b/buildscripts/resmokelib/powercycle/lib/process_control.py index 4df437de749..bb0c92452aa 100644 --- a/buildscripts/resmokelib/powercycle/lib/process_control.py +++ b/buildscripts/resmokelib/powercycle/lib/process_control.py @@ -1,4 +1,5 @@ """Wrapper for the ProcessControl class.""" + import logging import psutil @@ -54,5 +55,7 @@ class ProcessControl(object): try: proc.kill() except psutil.NoSuchProcess: - LOGGER.info("Could not kill process with pid %d, as it no longer exists", - proc.pid) + LOGGER.info( + "Could not kill process with pid %d, as it no longer exists", + proc.pid, + ) diff --git a/buildscripts/resmokelib/powercycle/lib/remote_operations.py b/buildscripts/resmokelib/powercycle/lib/remote_operations.py index c9ab577de7c..a46bdd660f0 100644 --- a/buildscripts/resmokelib/powercycle/lib/remote_operations.py +++ b/buildscripts/resmokelib/powercycle/lib/remote_operations.py @@ -36,25 +36,38 @@ class SSHOperation(object): def posix_path(path): """Return posix path, used on Windows since scp requires posix style paths.""" # If path is already quoted, we need to remove the quotes before calling - path_quote = "\'" if path.startswith("\'") else "" - path_quote = "\"" if path.startswith("\"") else path_quote + path_quote = "'" if path.startswith("'") else "" + path_quote = '"' if path.startswith('"') else path_quote if path_quote: path = path[1:-1] drive, new_path = os.path.splitdrive(path) if drive: - new_path = posixpath.join("/cygdrive", drive.split(":")[0], *re.split("/|\\\\", new_path)) + new_path = posixpath.join( + "/cygdrive", drive.split(":")[0], *re.split("/|\\\\", new_path) + ) return "{quote}{path}{quote}".format(quote=path_quote, path=new_path) class RemoteOperations(object): """Class to support remote operations.""" - def __init__(self, user_host, ssh_connection_options=None, ssh_options=None, scp_options=None, - shell_binary="/bin/bash", use_shell=False, ignore_ret=False, access_retry_count=5): + def __init__( + self, + user_host, + ssh_connection_options=None, + ssh_options=None, + scp_options=None, + shell_binary="/bin/bash", + use_shell=False, + ignore_ret=False, + access_retry_count=5, + ): """Initialize RemoteOperations.""" self.user_host = user_host - self.ssh_connection_options = ssh_connection_options if ssh_connection_options else "" + self.ssh_connection_options = ( + ssh_connection_options if ssh_connection_options else "" + ) self.ssh_options = ssh_options if ssh_options else "" self.scp_options = scp_options if scp_options else "" self.retry_sleep = 10 @@ -71,8 +84,9 @@ class RemoteOperations(object): if not self.use_shell: cmd = shlex.split(cmd) # Use a common pipe for stdout & stderr for logging. - process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - shell=self.use_shell) + process = subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, shell=self.use_shell + ) buff_stdout, _ = process.communicate() buff = buff_stdout.decode("utf-8", "replace") print("Result of command:") @@ -84,21 +98,27 @@ class RemoteOperations(object): while True: ret, buff = self._call(cmd) # Ignore any connection errors before sshd has fully initialized. - if not ret and not any(ssh_error in buff for ssh_error in _SSH_CONNECTION_ERRORS): + if not ret and not any( + ssh_error in buff for ssh_error in _SSH_CONNECTION_ERRORS + ): return ret, buff attempt_num += 1 if attempt_num > retry_count: print("Exhausted all retry attempts.") break - print("Remote attempt {} unsuccessful, retrying in {} seconds".format( - attempt_num, self.retry_sleep)) + print( + "Remote attempt {} unsuccessful, retrying in {} seconds".format( + attempt_num, self.retry_sleep + ) + ) time.sleep(self.retry_sleep) return ret, buff def _remote_access(self): """Check if a remote session is possible.""" - cmd = "ssh {} {} {} date".format(self.ssh_connection_options, self.ssh_options, - self.user_host) + cmd = "ssh {} {} {} date".format( + self.ssh_connection_options, self.ssh_options, self.user_host + ) return self._call_retries(cmd, self.access_retry_count) def _perform_operation(self, cmd, retry, retry_count): @@ -125,8 +145,14 @@ class RemoteOperations(object): return message.startswith("ssh:") # pylint: disable=inconsistent-return-statements - def operation(self, operation_type, operation_param, operation_dir=None, retry=False, - retry_count=5): + def operation( + self, + operation_type, + operation_param, + operation_dir=None, + retry=False, + retry_count=5, + ): """Execute Main entry for remote operations. Returns (code, output). 'operation_type' supports remote shell and copy operations. @@ -136,7 +162,9 @@ class RemoteOperations(object): if not self.access_established(): code, output = self.access_info() - print(f"Exiting, unable to establish access. Code=${code}, output=${output}") + print( + f"Exiting, unable to establish access. Code=${code}, output=${output}" + ) return # File names with a space must be quoted, since we permit the @@ -154,40 +182,51 @@ class RemoteOperations(object): # See https://stackoverflow.com/questions/8254120/ # how-to-escape-a-single-quote-in-single-quote-string-in-bash operation_param = "{}".format(operation_param.replace("'", r"\'")) - operation_param = "{}".format(operation_param.replace("\"", r"\"")) + operation_param = "{}".format(operation_param.replace('"', r"\"")) dollar = "$" - cmd = "ssh {} {} {} {} -c \"{}'{}'\"".format(self.ssh_connection_options, - self.ssh_options, self.user_host, - self.shell_binary, dollar, operation_param) + cmd = "ssh {} {} {} {} -c \"{}'{}'\"".format( + self.ssh_connection_options, + self.ssh_options, + self.user_host, + self.shell_binary, + dollar, + operation_param, + ) elif operation_type == "copy_to": cmd = "scp -r {} {} ".format(self.ssh_connection_options, self.scp_options) # To support spaces in the filename or directory, we quote them one at a time. for copy_file in operation_param: # Quote file on Posix. - quote = "\"" if not _IS_WINDOWS else "" - cmd += "{quote}{file}{quote} ".format(quote=quote, file=posix_path(copy_file)) + quote = '"' if not _IS_WINDOWS else "" + cmd += "{quote}{file}{quote} ".format( + quote=quote, file=posix_path(copy_file) + ) operation_dir = operation_dir if operation_dir else "" cmd += " {}:{}".format(self.user_host, posix_path(operation_dir)) elif operation_type == "copy_from": operation_dir = operation_dir if operation_dir else "." if not os.path.isdir(operation_dir): - raise ValueError("Local directory '{}' does not exist.".format(operation_dir)) + raise ValueError( + "Local directory '{}' does not exist.".format(operation_dir) + ) # We support multiple files being copied from the remote host # by invoking scp for each file specified. # Note - this is a method which scp does not support directly. for copy_file in operation_param: copy_file = posix_path(copy_file) - cmd = "scp -r {} {} {}:".format(self.ssh_connection_options, self.scp_options, - self.user_host) + cmd = "scp -r {} {} {}:".format( + self.ssh_connection_options, self.scp_options, self.user_host + ) # Quote (on Posix), and escape the file if there are spaces. # Note - we do not support other non-ASCII characters in a file name. - quote = "\"" if not _IS_WINDOWS else "" + quote = '"' if not _IS_WINDOWS else "" if " " in copy_file: - copy_file = re.escape("{quote}{file}{quote}".format( - quote=quote, file=copy_file)) + copy_file = re.escape( + "{quote}{file}{quote}".format(quote=quote, file=copy_file) + ) cmd += "{} {}".format(copy_file, posix_path(operation_dir)) else: @@ -209,15 +248,24 @@ class RemoteOperations(object): def shell(self, operation_param, operation_dir=None): """Provide helper for remote shell operations.""" - return self.operation(operation_type="shell", operation_param=operation_param, - operation_dir=operation_dir) + return self.operation( + operation_type="shell", + operation_param=operation_param, + operation_dir=operation_dir, + ) def copy_to(self, operation_param, operation_dir=None): """Provide helper for remote copy_to operations.""" - return self.operation(operation_type="copy_to", operation_param=operation_param, - operation_dir=operation_dir) + return self.operation( + operation_type="copy_to", + operation_param=operation_param, + operation_dir=operation_dir, + ) def copy_from(self, operation_param, operation_dir=None): """Provide helper for remote copy_from operations.""" - return self.operation(operation_type="copy_from", operation_param=operation_param, - operation_dir=operation_dir) + return self.operation( + operation_type="copy_from", + operation_param=operation_param, + operation_dir=operation_dir, + ) diff --git a/buildscripts/resmokelib/powercycle/lib/services.py b/buildscripts/resmokelib/powercycle/lib/services.py index 9b5feb34143..cf83f94b195 100644 --- a/buildscripts/resmokelib/powercycle/lib/services.py +++ b/buildscripts/resmokelib/powercycle/lib/services.py @@ -1,4 +1,5 @@ """Wrapper for OS Service Wrappers.""" + import importlib import os import sys @@ -62,11 +63,18 @@ class WindowsService(object): def create(self): """Create service, if not installed. Return (code, output) tuple.""" if self.status() in list(self._states.values()): - return 1, "Service '{}' already installed, status: {}".format(self.name, self.status()) + return 1, "Service '{}' already installed, status: {}".format( + self.name, self.status() + ) try: - win32serviceutil.InstallService(pythonClassString="Service.{}".format( - self.name), serviceName=self.name, displayName=self.name, startType=self.start_type, - exeName=self.bin_path, exeArgs=self.bin_options) + win32serviceutil.InstallService( + pythonClassString="Service.{}".format(self.name), + serviceName=self.name, + displayName=self.name, + startType=self.start_type, + exeName=self.bin_path, + exeArgs=self.bin_options, + ) ret = 0 output = "Service '{}' created".format(self.name) except pywintypes.error as err: @@ -80,9 +88,14 @@ class WindowsService(object): if self.status() not in self._states.values(): return 1, "Service update '{}' status: {}".format(self.name, self.status()) try: - win32serviceutil.ChangeServiceConfig(pythonClassString="Service.{}".format( - self.name), serviceName=self.name, displayName=self.name, startType=self.start_type, - exeName=self.bin_path, exeArgs=self.bin_options) + win32serviceutil.ChangeServiceConfig( + pythonClassString="Service.{}".format(self.name), + serviceName=self.name, + displayName=self.name, + startType=self.start_type, + exeName=self.bin_path, + exeArgs=self.bin_options, + ) ret = 0 output = "Service '{}' updated".format(self.name) except pywintypes.error as err: @@ -149,7 +162,9 @@ class WindowsService(object): # (winerror=109) when stopping the "mongod-powercycle-test" service on # Windows Server 2016 and the underlying mongod process has already exited. ret = 0 - output = f"Assuming service '{self.name}' stopped despite error: {output}" + output = ( + f"Assuming service '{self.name}' stopped despite error: {output}" + ) return ret, output @@ -160,7 +175,8 @@ class WindowsService(object): # (scvType, svcState, svcControls, err, svcErr, svcCP, svcWH) # See https://msdn.microsoft.com/en-us/library/windows/desktop/ms685996(v=vs.85).aspx scv_type, svc_state, svc_controls, err, svc_err, svc_cp, svc_wh = ( - win32serviceutil.QueryServiceStatus(serviceName=self.name)) + win32serviceutil.QueryServiceStatus(serviceName=self.name) + ) if svc_state in self._states: return self._states[svc_state] return "unknown" diff --git a/buildscripts/resmokelib/powercycle/powercycle.py b/buildscripts/resmokelib/powercycle/powercycle.py index 868c91142f5..3e51ceb7fed 100755 --- a/buildscripts/resmokelib/powercycle/powercycle.py +++ b/buildscripts/resmokelib/powercycle/powercycle.py @@ -1,4 +1,5 @@ """Powercycle test helper functions.""" + import atexit import collections import copy @@ -151,10 +152,14 @@ def register_signal_handler(handler): # Wait for task time out to dump stacks. ret = win32event.WaitForSingleObject(event_handle, win32event.INFINITE) if ret != win32event.WAIT_OBJECT_0: - LOGGER.error("_handle_set_event WaitForSingleObject failed: %d", ret) + LOGGER.error( + "_handle_set_event WaitForSingleObject failed: %d", ret + ) return except win32event.error as err: - LOGGER.error("Exception from win32event.WaitForSingleObject with error: %s", err) + LOGGER.error( + "Exception from win32event.WaitForSingleObject with error: %s", err + ) else: handler(None, None) @@ -167,8 +172,9 @@ def register_signal_handler(handler): security_attributes = None manual_reset = False initial_state = False - task_timeout_handle = win32event.CreateEvent(security_attributes, manual_reset, - initial_state, event_name) + task_timeout_handle = win32event.CreateEvent( + security_attributes, manual_reset, initial_state, event_name + ) except win32event.error as err: LOGGER.error("Exception from win32event.CreateEvent with error: %s", err) return @@ -179,8 +185,9 @@ def register_signal_handler(handler): # Create thread. event_handler_thread = threading.Thread( target=_handle_set_event, - kwargs={"event_handle": task_timeout_handle, - "handler": handler}, name="windows_event_handler_thread") + kwargs={"event_handle": task_timeout_handle, "handler": handler}, + name="windows_event_handler_thread", + ) event_handler_thread.daemon = True event_handler_thread.start() else: @@ -232,7 +239,9 @@ def kill_process(parent, kill_children=True): LOGGER.debug("Killing process '%s' pid %d", proc.name(), proc.pid) proc.kill() except psutil.NoSuchProcess: - LOGGER.warning("Could not kill process %d, as it no longer exists", proc.pid) + LOGGER.warning( + "Could not kill process %d, as it no longer exists", proc.pid + ) _, alive = psutil.wait_procs(procs, timeout=30, callback=None) if alive: @@ -247,7 +256,9 @@ def kill_processes(procs, kill_children=True): LOGGER.debug("Starting kill of parent process %d", proc.pid) kill_process(proc, kill_children=kill_children) ret = proc.wait() - LOGGER.debug("Finished kill of parent process %d has return code of %d", proc.pid, ret) + LOGGER.debug( + "Finished kill of parent process %d has return code of %d", proc.pid, ret + ) def get_extension(filename): @@ -269,8 +280,11 @@ def abs_path(path): cmd = "cygpath -wa {}".format(path) ret, output = execute_cmd(cmd, use_file=True) if ret: - raise Exception("Command \"{}\" failed with code {} and output message: {}".format( - cmd, ret, output)) + raise Exception( + 'Command "{}" failed with code {} and output message: {}'.format( + cmd, ret, output + ) + ) return output.rstrip().replace("\\", "/") return os.path.abspath(os.path.normpath(path)) @@ -279,7 +293,8 @@ def symlink_dir(source_dir, dest_dir): """Symlink the 'dest_dir' to 'source_dir'.""" if _IS_WINDOWS: win32file.CreateSymbolicLink( # pylint: disable=undefined-variable - dest_dir, source_dir, win32file.SYMBOLIC_LINK_FLAG_DIRECTORY) # pylint: disable=undefined-variable + dest_dir, source_dir, win32file.SYMBOLIC_LINK_FLAG_DIRECTORY + ) # pylint: disable=undefined-variable else: os.symlink(source_dir, dest_dir) @@ -349,7 +364,7 @@ def parse_options(options): options_map[opt_name] = (None, opt_form) else: opt_name = opt[opt_idx:eq_idx] - options_map[opt_name] = (opt[eq_idx + 1:], opt_form) + options_map[opt_name] = (opt[eq_idx + 1 :], opt_form) opt_name = None elif opt_name: options_map[opt_name] = (opt, opt_form) @@ -361,7 +376,6 @@ def download_file(url, file_name, download_retries=5): LOGGER.info("Downloading %s to %s", url, file_name) while download_retries > 0: - with requests.Session() as session: adapter = requests.adapters.HTTPAdapter(max_retries=download_retries) session.mount(url, adapter) @@ -375,7 +389,9 @@ def download_file(url, file_name, download_retries=5): except requests.exceptions.ChunkedEncodingError as err: download_retries -= 1 if download_retries == 0: - raise Exception("Incomplete download for URL {}: {}".format(url, err)) + raise Exception( + "Incomplete download for URL {}: {}".format(url, err) + ) continue # Check if file download was completed. @@ -386,9 +402,12 @@ def download_file(url, file_name, download_retries=5): if url_content_length != file_size: download_retries -= 1 if download_retries == 0: - raise Exception("Downloaded file size ({} bytes) doesn't match content length" - "({} bytes) for URL {}".format(file_size, url_content_length, - url)) + raise Exception( + "Downloaded file size ({} bytes) doesn't match content length" + "({} bytes) for URL {}".format( + file_size, url_content_length, url + ) + ) continue return True @@ -406,12 +425,16 @@ def install_tarball(tarball, root_dir): if ext == ".tgz": with tarfile.open(tarball, "r:gz") as tar_handle: tar_handle.extractall(path=root_dir) - output = "Unzipped {} to {}: {}".format(tarball, root_dir, tar_handle.getnames()) + output = "Unzipped {} to {}: {}".format( + tarball, root_dir, tar_handle.getnames() + ) ret = 0 elif ext == ".zip": with zipfile.ZipFile(tarball, "r") as zip_handle: zip_handle.extractall(root_dir) - output = "Unzipped {} to {}: {}".format(tarball, root_dir, zip_handle.namelist()) + output = "Unzipped {} to {}: {}".format( + tarball, root_dir, zip_handle.namelist() + ) ret = 0 elif ext == ".msi": if not _IS_WINDOWS: @@ -445,8 +468,10 @@ def install_tarball(tarball, root_dir): ret, output = execute_cmd(cmds, use_file=True) shutil.rmtree(tmp_dir) else: - raise Exception("Unsupported file extension to unzip {}," - " supported extensions are {}".format(tarball, extensions)) + raise Exception( + "Unsupported file extension to unzip {}," + " supported extensions are {}".format(tarball, extensions) + ) LOGGER.debug(output) if ret: @@ -474,20 +499,30 @@ def chmod_w_file(chmod_file): # questions/12168110/setting-folder-permissions-in-windows-using-python # pylint: disable=undefined-variable,unused-variable user, domain, sec_type = win32security.LookupAccountName("", "Everyone") - file_sd = win32security.GetFileSecurity(chmod_file, win32security.DACL_SECURITY_INFORMATION) + file_sd = win32security.GetFileSecurity( + chmod_file, win32security.DACL_SECURITY_INFORMATION + ) dacl = file_sd.GetSecurityDescriptorDacl() - dacl.AddAccessAllowedAce(win32security.ACL_REVISION, ntsecuritycon.FILE_GENERIC_WRITE, user) + dacl.AddAccessAllowedAce( + win32security.ACL_REVISION, ntsecuritycon.FILE_GENERIC_WRITE, user + ) file_sd.SetSecurityDescriptorDacl(1, dacl, 0) - win32security.SetFileSecurity(chmod_file, win32security.DACL_SECURITY_INFORMATION, file_sd) + win32security.SetFileSecurity( + chmod_file, win32security.DACL_SECURITY_INFORMATION, file_sd + ) # pylint: enable=undefined-variable,unused-variable else: - os.chmod(chmod_file, os.stat(chmod_file) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH) + os.chmod( + chmod_file, os.stat(chmod_file) | stat.S_IWUSR | stat.S_IWGRP | stat.S_IWOTH + ) def set_windows_bootstatuspolicy(): """For Windows hosts that are physical, this prevents boot to prompt after failure.""" - LOGGER.info("Setting bootstatuspolicy to ignoreallfailures & boot timeout to 5 seconds") + LOGGER.info( + "Setting bootstatuspolicy to ignoreallfailures & boot timeout to 5 seconds" + ) cmds = """ echo 'Setting bootstatuspolicy to ignoreallfailures & boot timeout to 5 seconds' ; bcdedit /set {default} bootstatuspolicy ignoreallfailures ; @@ -524,10 +559,11 @@ def _do_install_mongod(bin_dir=None, tarball_url="latest", root_dir=None): if _IS_WINDOWS: # MSI default: # https://fastdl.mongodb.org/win32/mongodb-win32-x86_64-2008plus-ssl-latest-signed.msi - tarball_url = ( - "https://fastdl.mongodb.org/win32/mongodb-win32-x86_64-2008plus-ssl-latest.zip") + tarball_url = "https://fastdl.mongodb.org/win32/mongodb-win32-x86_64-2008plus-ssl-latest.zip" elif _IS_LINUX: - tarball_url = "https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-latest.tgz" + tarball_url = ( + "https://fastdl.mongodb.org/linux/mongodb-linux-x86_64-latest.tgz" + ) tarball = os.path.split(urllib.parse.urlsplit(tarball_url).path)[-1] download_file(tarball_url, tarball) @@ -549,20 +585,24 @@ def get_boot_datetime(uptime_string): match = re.search(r"last booted (.*), up", uptime_string) if match: return datetime.datetime( - *list(map(int, list(map(float, re.split("[ :-]", - match.groups()[0])))))) + *list(map(int, list(map(float, re.split("[ :-]", match.groups()[0]))))) + ) return -1 def print_uptime(): """Print the last time the system was booted, and the uptime (in seconds).""" boot_time_epoch = psutil.boot_time() - boot_time = datetime.datetime.fromtimestamp(boot_time_epoch).strftime('%Y-%m-%d %H:%M:%S.%f') + boot_time = datetime.datetime.fromtimestamp(boot_time_epoch).strftime( + "%Y-%m-%d %H:%M:%S.%f" + ) uptime = int(time.time() - boot_time_epoch) LOGGER.info("System was last booted %s, up %d seconds", boot_time, uptime) -def call_remote_operation(local_ops, remote_python, script_name, client_args, operation): +def call_remote_operation( + local_ops, remote_python, script_name, client_args, operation +): """Call the remote operation and return tuple (ret, ouput).""" client_call = f"{remote_python} {script_name} {client_args} {operation}" ret, output = local_ops.shell(client_call) @@ -628,8 +668,9 @@ class MongodControl(object): self._service = PosixService # After mongod has been installed, self.bin_path is defined. if self.bin_path: - self.service = self._service("mongod-powercycle-test", self.bin_path, - self.mongod_options(), db_path) + self.service = self._service( + "mongod-powercycle-test", self.bin_path, self.mongod_options(), db_path + ) def set_mongod_option(self, option, option_value=None, option_form="--"): """Set mongod command line option.""" @@ -659,7 +700,9 @@ class MongodControl(object): if os.path.isdir(root_dir): LOGGER.warning("Root dir %s already exists", root_dir) else: - _do_install_mongod(bin_dir=self.bin_dir, tarball_url=tarball_url, root_dir=root_dir) + _do_install_mongod( + bin_dir=self.bin_dir, tarball_url=tarball_url, root_dir=root_dir + ) self.bin_dir = get_bin_dir(root_dir) if not self.bin_dir: ret, output = execute_cmd("ls -lR '{}'".format(root_dir), use_file=True) @@ -668,8 +711,9 @@ class MongodControl(object): self.bin_path = os.path.join(self.bin_dir, self.process_name) # We need to instantiate the Service when installing, since the bin_path # is only known after install_mongod runs. - self.service = self._service("mongod-powercycle-test", self.bin_path, self.mongod_options(), - db_path=None) + self.service = self._service( + "mongod-powercycle-test", self.bin_path, self.mongod_options(), db_path=None + ) ret, output = self.service.create() return ret, output @@ -724,14 +768,26 @@ class LocalToRemoteOperations(object): Return (return code, output). """ - def __init__(self, user_host, ssh_connection_options=None, ssh_options=None, - shell_binary="/bin/bash", use_shell=False, access_retry_count=5): + def __init__( + self, + user_host, + ssh_connection_options=None, + ssh_options=None, + shell_binary="/bin/bash", + use_shell=False, + access_retry_count=5, + ): """Initialize LocalToRemoteOperations.""" self.remote_op = remote_operations.RemoteOperations( - user_host=user_host, ssh_connection_options=ssh_connection_options, - ssh_options=ssh_options, shell_binary=shell_binary, use_shell=use_shell, - ignore_ret=True, access_retry_count=access_retry_count) + user_host=user_host, + ssh_connection_options=ssh_connection_options, + ssh_options=ssh_options, + shell_binary=shell_binary, + use_shell=use_shell, + ignore_ret=True, + access_retry_count=access_retry_count, + ) def shell(self, cmds, remote_dir=None): """Return tuple (ret, output) from performing remote shell operation.""" @@ -784,10 +840,17 @@ def remote_handler(options, task_config, root_dir): db_path = abs_path(powercycle_constants.DB_PATH) log_path = abs_path(powercycle_constants.LOG_PATH) - mongod = MongodControl(bin_dir=bin_dir, db_path=db_path, log_path=log_path, port=options.port, - options=mongod_options) + mongod = MongodControl( + bin_dir=bin_dir, + db_path=db_path, + log_path=log_path, + port=options.port, + options=mongod_options, + ) - mongo_client_opts = get_mongo_client_args(host=host, port=options.port, task_config=task_config) + mongo_client_opts = get_mongo_client_args( + host=host, port=options.port, task_config=task_config + ) # Perform the sequence of operations specified. If any operation fails then return immediately. for operation in options.remote_operations: @@ -858,9 +921,15 @@ def remote_handler(options, task_config, root_dir): ret, output = mongod.start() LOGGER.info(output) if ret: - LOGGER.error("Failed to start mongod on port %d: %s", options.port, output) + LOGGER.error( + "Failed to start mongod on port %d: %s", options.port, output + ) return ret - LOGGER.info("Started mongod running on port %d pid %s", options.port, mongod.get_pids()) + LOGGER.info( + "Started mongod running on port %d pid %s", + options.port, + mongod.get_pids(), + ) mongo = pymongo.MongoClient(**mongo_client_opts) # Limit retries to a reasonable value for _ in range(100): @@ -892,8 +961,11 @@ def remote_handler(options, task_config, root_dir): def rsync_data(): rsync_dir, new_rsync_dir = options.rsync_dest - ret, output = rsync(powercycle_constants.DB_PATH, rsync_dir, - powercycle_constants.RSYNC_EXCLUDE_FILES) + ret, output = rsync( + powercycle_constants.DB_PATH, + rsync_dir, + powercycle_constants.RSYNC_EXCLUDE_FILES, + ) if output: LOGGER.info(output) # Rename the rsync_dir only if it has a different name than new_rsync_dir. @@ -905,14 +977,19 @@ def remote_handler(options, task_config, root_dir): def seed_docs(): mongo = pymongo.MongoClient(**mongo_client_opts) - return mongo_seed_docs(mongo, powercycle_constants.DB_NAME, - powercycle_constants.COLLECTION_NAME, task_config.seed_doc_num) + return mongo_seed_docs( + mongo, + powercycle_constants.DB_NAME, + powercycle_constants.COLLECTION_NAME, + task_config.seed_doc_num, + ) def set_fcv(): mongo = pymongo.MongoClient(**mongo_client_opts) try: - ret = mongo.admin.command("setFeatureCompatibilityVersion", task_config.fcv, - confirm=True) + ret = mongo.admin.command( + "setFeatureCompatibilityVersion", task_config.fcv, confirm=True + ) ret = 0 if ret["ok"] == 1 else 1 except pymongo.errors.OperationFailure as err: LOGGER.error("%s", err) @@ -939,18 +1016,29 @@ def remote_handler(options, task_config, root_dir): LOGGER.info("Running chkdsk command for %s drive", drive_letter) cmds = f"chkdsk '{drive_letter}'" ret, output = execute_cmd(cmds, use_file=True) - LOGGER.warning("chkdsk command for %s drive exited with code %d:\n%s", - drive_letter, ret, output) + LOGGER.warning( + "chkdsk command for %s drive exited with code %d:\n%s", + drive_letter, + ret, + output, + ) if ret != 0: return ret return ret op_map = { - "noop": noop, "crash_server": crash_server, "kill_mongod": kill_mongod, - "install_mongod": install_mongod, "start_mongod": start_mongod, "stop_mongod": - stop_mongod, "shutdown_mongod": shutdown_mongod, "rsync_data": rsync_data, - "seed_docs": seed_docs, "set_fcv": set_fcv, "check_disk": check_disk + "noop": noop, + "crash_server": crash_server, + "kill_mongod": kill_mongod, + "install_mongod": install_mongod, + "start_mongod": start_mongod, + "stop_mongod": stop_mongod, + "shutdown_mongod": shutdown_mongod, + "rsync_data": rsync_data, + "seed_docs": seed_docs, + "set_fcv": set_fcv, + "check_disk": check_disk, } if operation not in op_map: @@ -998,8 +1086,13 @@ def rsync(src_dir, dest_dir, exclude_files=None): if ret == 0 or "No medium found" not in rsync_output: break - LOGGER.warning("[%d/%d] rsync command failed (code=%d): %s", attempt, max_attempts, ret, - rsync_output) + LOGGER.warning( + "[%d/%d] rsync command failed (code=%d): %s", + attempt, + max_attempts, + ret, + rsync_output, + ) # If the rsync command failed with an "No medium found" error message, then we log some # basic information about the /log mount point. @@ -1053,23 +1146,34 @@ def internal_crash(): return 1, "Crash did not occur" -def crash_server_or_kill_mongod(task_config, crash_canary, local_ops, script_name, client_args): +def crash_server_or_kill_mongod( + task_config, crash_canary, local_ops, script_name, client_args +): """Crash server or kill mongod and optionally write canary doc. Return tuple (ret, output).""" crash_wait_time = powercycle_constants.CRASH_WAIT_TIME + random.randint( - 0, powercycle_constants.CRASH_WAIT_TIME_JITTER) - message_prefix = "Killing mongod" if task_config.crash_method == "kill" else "Crashing server" + 0, powercycle_constants.CRASH_WAIT_TIME_JITTER + ) + message_prefix = ( + "Killing mongod" if task_config.crash_method == "kill" else "Crashing server" + ) LOGGER.info("%s in %d seconds", message_prefix, crash_wait_time) time.sleep(crash_wait_time) if task_config.crash_method in ["internal", "kill"]: - crash_cmd = "crash_server" if task_config.crash_method == "internal" else "kill_mongod" + crash_cmd = ( + "crash_server" if task_config.crash_method == "internal" else "kill_mongod" + ) crash_func = local_ops.shell remote_python = get_remote_python() - crash_args = [f"{remote_python} {script_name} {client_args} --remoteOperation {crash_cmd}"] + crash_args = [ + f"{remote_python} {script_name} {client_args} --remoteOperation {crash_cmd}" + ] else: - message = "Unsupported crash method '{}' provided".format(task_config.crash_method) + message = "Unsupported crash method '{}' provided".format( + task_config.crash_method + ) LOGGER.error(message) return 1, message @@ -1087,7 +1191,9 @@ def wait_for_mongod_shutdown(mongod_control, timeout=2 * ONE_HOUR_SECS): status = mongod_control.status() while status != "stopped": if time.time() - start >= timeout: - LOGGER.error("The mongod process has not stopped, current status is %s", status) + LOGGER.error( + "The mongod process has not stopped, current status is %s", status + ) return 1 LOGGER.info("Waiting for mongod process to stop, current status is %s ", status) time.sleep(3) @@ -1101,9 +1207,14 @@ def wait_for_mongod_shutdown(mongod_control, timeout=2 * ONE_HOUR_SECS): return 0 -def get_mongo_client_args(host=None, port=None, task_config=None, - server_selection_timeout_ms=2 * ONE_HOUR_SECS * 1000, - socket_timeout_ms=2 * ONE_HOUR_SECS * 1000, direct_connection=True): +def get_mongo_client_args( + host=None, + port=None, + task_config=None, + server_selection_timeout_ms=2 * ONE_HOUR_SECS * 1000, + socket_timeout_ms=2 * ONE_HOUR_SECS * 1000, + direct_connection=True, +): """Return keyword arg dict used in PyMongo client.""" # Set the default serverSelectionTimeoutMS & socketTimeoutMS to 10 minutes. mongo_args = { @@ -1128,7 +1239,11 @@ def get_mongo_client_args(host=None, port=None, task_config=None, def mongo_shell(mongo_path, work_dir, host_port, mongo_cmds, retries=5, retry_sleep=5): """Start mongo_path from work_dir, connecting to host_port and executes mongo_cmds.""" cmds = "cd {}; echo {} | {} {}".format( - pipes.quote(work_dir), pipes.quote(mongo_cmds), pipes.quote(mongo_path), host_port) + pipes.quote(work_dir), + pipes.quote(mongo_cmds), + pipes.quote(mongo_path), + host_port, + ) attempt_num = 0 while True: ret, output = execute_cmd(cmds, use_file=True) @@ -1214,11 +1329,18 @@ def mongo_seed_docs(mongo, db_name, coll_name, num_docs): def rand_string(max_length=1024): """Return random string of random length.""" - return ''.join( - random.choice(string.ascii_letters) for _ in range(random.randint(1, max_length))) + return "".join( + random.choice(string.ascii_letters) + for _ in range(random.randint(1, max_length)) + ) - LOGGER.info("Seeding DB '%s' collection '%s' with %d documents, %d already exist", db_name, - coll_name, num_docs, mongo[db_name][coll_name].estimated_document_count()) + LOGGER.info( + "Seeding DB '%s' collection '%s' with %d documents, %d already exist", + db_name, + coll_name, + num_docs, + mongo[db_name][coll_name].estimated_document_count(), + ) random.seed() base_num = 100000 bulk_num = min(num_docs, 10000) @@ -1228,9 +1350,15 @@ def mongo_seed_docs(mongo, db_name, coll_name, num_docs): if num_coll_docs >= num_docs: break mongo[db_name][coll_name].insert_many( - [{"x": random.randint(0, base_num), "doc": rand_string(1024)} for _ in range(bulk_num)]) - LOGGER.info("After seeding there are %d documents in the collection", - mongo[db_name][coll_name].estimated_document_count()) + [ + {"x": random.randint(0, base_num), "doc": rand_string(1024)} + for _ in range(bulk_num) + ] + ) + LOGGER.info( + "After seeding there are %d documents in the collection", + mongo[db_name][coll_name].estimated_document_count(), + ) return 0 @@ -1238,15 +1366,20 @@ def mongo_validate_canary(mongo, db_name, coll_name, doc): """Validate a canary document, return 0 if the document exists.""" if not doc: return 0 - LOGGER.info("Validating canary document using %s.%s.find_one(%s)", db_name, coll_name, doc) + LOGGER.info( + "Validating canary document using %s.%s.find_one(%s)", db_name, coll_name, doc + ) return 0 if mongo[db_name][coll_name].find_one(doc) else 1 def mongo_insert_canary(mongo, db_name, coll_name, doc): """Insert a canary document with 'j' True, return 0 if successful.""" - LOGGER.info("Inserting canary document using %s.%s.insert_one(%s)", db_name, coll_name, doc) + LOGGER.info( + "Inserting canary document using %s.%s.insert_one(%s)", db_name, coll_name, doc + ) coll = mongo[db_name][coll_name].with_options( - write_concern=pymongo.write_concern.WriteConcern(j=True)) + write_concern=pymongo.write_concern.WriteConcern(j=True) + ) res = coll.insert_one(doc) return 0 if res.inserted_id else 1 @@ -1255,7 +1388,12 @@ def new_resmoke_config(config_file, new_config_file, test_data, eval_str=""): """Create 'new_config_file', from 'config_file', with an update from 'test_data'.""" new_config = { "executor": { - "config": {"shell_options": {"eval": eval_str, "global_vars": {"TestData": test_data}}} + "config": { + "shell_options": { + "eval": eval_str, + "global_vars": {"TestData": test_data}, + } + } } } with open(config_file, "r") as yaml_stream: @@ -1265,20 +1403,30 @@ def new_resmoke_config(config_file, new_config_file, test_data, eval_str=""): yaml.safe_dump(config, yaml_stream) -def resmoke_client(work_dir, mongo_path, host_port, js_test, resmoke_suite, repeat_num=1, - no_wait=False, log_file=None): +def resmoke_client( + work_dir, + mongo_path, + host_port, + js_test, + resmoke_suite, + repeat_num=1, + no_wait=False, + log_file=None, +): """Start resmoke client from work_dir, connecting to host_port and executes js_test.""" log_output = f">> {log_file} 2>&1" if log_file else "" - cmds = (f"cd {pipes.quote(work_dir)};" - f" python {powercycle_constants.RESMOKE_PATH}" - f" run" - f" --mongo {pipes.quote(mongo_path)}" - f" --suites {pipes.quote(resmoke_suite)}" - f" --shellConnString mongodb://{host_port}" - f" --continueOnFailure" - f" --repeat {repeat_num}" - f" {pipes.quote(js_test)}" - f" {log_output}") + cmds = ( + f"cd {pipes.quote(work_dir)};" + f" python {powercycle_constants.RESMOKE_PATH}" + f" run" + f" --mongo {pipes.quote(mongo_path)}" + f" --suites {pipes.quote(resmoke_suite)}" + f" --shellConnString mongodb://{host_port}" + f" --continueOnFailure" + f" --repeat {repeat_num}" + f" {pipes.quote(js_test)}" + f" {log_output}" + ) ret, output = None, None if no_wait: Processes.create(cmds, use_file=True) @@ -1291,7 +1439,9 @@ def get_remote_python(): """Return remote python.""" python_bin_dir = "Scripts" if _IS_WINDOWS else "bin" - remote_python = f". {powercycle_constants.VIRTUALENV_DIR}/{python_bin_dir}/activate; python -u" + remote_python = ( + f". {powercycle_constants.VIRTUALENV_DIR}/{python_bin_dir}/activate; python -u" + ) return remote_python @@ -1310,8 +1460,11 @@ def main(parser_actions, options): atexit.register(exit_handler) register_signal_handler(dump_stacks_and_exit) - logging.basicConfig(format="%(asctime)s %(levelname)s %(message)s", level=logging.ERROR, - filename=options.log_file) + logging.basicConfig( + format="%(asctime)s %(levelname)s %(message)s", + level=logging.ERROR, + filename=options.log_file, + ) logging.getLogger(__name__).setLevel(options.log_level.upper()) logging.Formatter.converter = time.gmtime @@ -1324,7 +1477,9 @@ def main(parser_actions, options): # Initialize the mongod options # Note - We use posixpath for Windows client to Linux server scenarios. - root_dir = f"{powercycle_constants.REMOTE_DIR}/mongodb-powercycle-test-{int(time.time())}" + root_dir = ( + f"{powercycle_constants.REMOTE_DIR}/mongodb-powercycle-test-{int(time.time())}" + ) set_fcv_cmd = "set_fcv" if task_config.fcv is not None else "" # Error out earlier if these options are not properly specified @@ -1340,11 +1495,17 @@ def main(parser_actions, options): EXIT_YML_FILE = powercycle_constants.POWERCYCLE_EXIT_FILE REPORT_JSON_FILE = powercycle_constants.REPORT_JSON_FILE REPORT_JSON = { - "failures": - 0, "results": [{ - "status": "fail", "test_file": task_name, "exit_code": 1, "elapsed": 0, - "start": int(time.time()), "end": int(time.time()) - }] + "failures": 0, + "results": [ + { + "status": "fail", + "test_file": task_name, + "exit_code": 1, + "elapsed": 0, + "start": int(time.time()), + "end": int(time.time()), + } + ], } LOGGER.debug("Creating report JSON %s", REPORT_JSON) @@ -1374,15 +1535,18 @@ def main(parser_actions, options): backup_path_after = f"{backup_path_after}-1" # Setup the mongo client, mongo_path is required if there are local clients. - mongo_executable = shutil.which(cmd="dist-test/bin/mongo", - path=os.getcwd() + os.pathsep + os.environ["PATH"]) + mongo_executable = shutil.which( + cmd="dist-test/bin/mongo", path=os.getcwd() + os.pathsep + os.environ["PATH"] + ) # Note: No check for `if mongo_executable is None` mongo_path = os.path.abspath(os.path.normpath(mongo_executable)) # Setup the CRUD & FSM clients. if not os.path.isfile(powercycle_constants.CONFIG_CRUD_CLIENT): - LOGGER.error("config crud client %s does not exist", - powercycle_constants.CONFIG_CRUD_CLIENT) + LOGGER.error( + "config crud client %s does not exist", + powercycle_constants.CONFIG_CRUD_CLIENT, + ) local_exit(1) with_external_server = powercycle_constants.CONFIG_CRUD_CLIENT fsm_client = powercycle_constants.FSM_CLIENT @@ -1422,7 +1586,8 @@ def main(parser_actions, options): # user-specified --sshConnection options. ssh_connection_options = ( f"{options.ssh_connection_options if options.ssh_connection_options else ''}" - f" {powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS}") + f" {powercycle_constants.DEFAULT_SSH_CONNECTION_OPTIONS}" + ) # For remote operations requiring sudo, force pseudo-tty allocation, # see https://stackoverflow.com/questions/10310299/proper-way-to-sudo-over-ssh. # Note - the ssh option RequestTTY was added in OpenSSH 5.9, so we use '-tt'. @@ -1430,8 +1595,12 @@ def main(parser_actions, options): # Instantiate the local handler object. local_ops = LocalToRemoteOperations( - user_host=ssh_user_host, ssh_connection_options=ssh_connection_options, - ssh_options=ssh_options, use_shell=True, access_retry_count=options.ssh_access_retry_count) + user_host=ssh_user_host, + ssh_connection_options=ssh_connection_options, + ssh_options=ssh_options, + use_shell=True, + access_retry_count=options.ssh_access_retry_count, + ) verify_remote_access(local_ops) # Pass client_args to the remote script invocation. @@ -1453,15 +1622,22 @@ def main(parser_actions, options): option_value = " ".join(map(str, option_value)) client_args = f"{client_args} {action.option_strings[-1]} {option_value}" - script_name = f"{powercycle_constants.REMOTE_DIR}/{powercycle_constants.RESMOKE_PATH}" + script_name = ( + f"{powercycle_constants.REMOTE_DIR}/{powercycle_constants.RESMOKE_PATH}" + ) script_name = abs_path(script_name) LOGGER.info("%s %s", script_name, client_args) remote_python = get_remote_python() # Remote install of MongoDB. - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - "--remoteOperation install_mongod") + ret, output = call_remote_operation( + local_ops, + remote_python, + script_name, + client_args, + "--remoteOperation install_mongod", + ) LOGGER.info("****install_mongod: %d %s****", ret, output) if ret: local_exit(ret) @@ -1487,7 +1663,9 @@ def main(parser_actions, options): # ========= while True: loop_num += 1 - LOGGER.info("****Starting test loop %d test time %d seconds****", loop_num, test_time) + LOGGER.info( + "****Starting test loop %d test time %d seconds****", loop_num, test_time + ) temp_client_files = [] @@ -1504,16 +1682,19 @@ def main(parser_actions, options): # Optionally, rsync the pre-recovery database. # Start monogd on the secret port. # Optionally validate collections, validate the canary and seed the collection. - remote_operation = (f"--remoteOperation" - f" {rsync_opt}" - f" --mongodHost {mongod_host}" - f" --mongodPort {secret_port}" - f" {rsync_cmd}" - f" start_mongod" - f" {set_fcv_cmd if loop_num == 1 else ''}" - f" {seed_docs if loop_num == 1 else ''}") - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - remote_operation) + remote_operation = ( + f"--remoteOperation" + f" {rsync_opt}" + f" --mongodHost {mongod_host}" + f" --mongodPort {secret_port}" + f" {rsync_cmd}" + f" start_mongod" + f" {set_fcv_cmd if loop_num == 1 else ''}" + f" {seed_docs if loop_num == 1 else ''}" + ) + ret, output = call_remote_operation( + local_ops, remote_python, script_name, client_args, remote_operation + ) rsync_text = "rsync_data beforerecovery & " LOGGER.info("****%sstart mongod: %d %s****", rsync_text, ret, output) if ret: @@ -1521,11 +1702,20 @@ def main(parser_actions, options): # Optionally validate canary document locally. if validate_canary_local: - mongo = pymongo.MongoClient(**get_mongo_client_args( - host=mongod_host, port=secret_port, server_selection_timeout_ms=one_hour_ms, - socket_timeout_ms=one_hour_ms)) - ret = mongo_validate_canary(mongo, powercycle_constants.DB_NAME, - powercycle_constants.COLLECTION_NAME, canary_doc) + mongo = pymongo.MongoClient( + **get_mongo_client_args( + host=mongod_host, + port=secret_port, + server_selection_timeout_ms=one_hour_ms, + socket_timeout_ms=one_hour_ms, + ) + ) + ret = mongo_validate_canary( + mongo, + powercycle_constants.DB_NAME, + powercycle_constants.COLLECTION_NAME, + canary_doc, + ) LOGGER.info("Local canary validation: %d", ret) if ret: local_exit(ret) @@ -1535,15 +1725,20 @@ def main(parser_actions, options): new_config_file = NamedTempFile.create(suffix=".yml", directory="tmp") temp_client_files.append(new_config_file) validation_test_data = { - "skipValidationOnNamespaceNotFound": True, "allowUncleanShutdowns": True + "skipValidationOnNamespaceNotFound": True, + "allowUncleanShutdowns": True, } new_resmoke_config(with_external_server, new_config_file, validation_test_data) - ret, output = resmoke_client(mongo_repo_root_dir, mongo_path, host_port, - "jstests/hooks/run_validate_collections.js", new_config_file) + ret, output = resmoke_client( + mongo_repo_root_dir, + mongo_path, + host_port, + "jstests/hooks/run_validate_collections.js", + new_config_file, + ) LOGGER.info("Local collection validation: %d %s", ret, output) if ret: - network_error = ( - f"network error while attempting to run command 'isMaster' on host '{host_port}'") + network_error = f"network error while attempting to run command 'isMaster' on host '{host_port}'" # Mark this error as ssh failure, since it happens during the first test loop before # the first server crash and likely related to port forwarding not working, which # uses ssh tunnel command. @@ -1553,8 +1748,9 @@ def main(parser_actions, options): # Shutdown mongod on secret port. remote_op = f"--remoteOperation --mongodPort {secret_port} shutdown_mongod" - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - remote_op) + ret, output = call_remote_operation( + local_ops, remote_python, script_name, client_args, remote_op + ) LOGGER.info("****shutdown_mongod: %d %s****", ret, output) if ret: local_exit(ret) @@ -1567,14 +1763,17 @@ def main(parser_actions, options): # Optionally, rsync the post-recovery database. # Start monogd on the standard port. - remote_op = (f"--remoteOperation" - f" {rsync_opt}" - f" --mongodHost {mongod_host}" - f" --mongodPort {standard_port}" - f" {rsync_cmd}" - f" start_mongod") - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - remote_op) + remote_op = ( + f"--remoteOperation" + f" {rsync_opt}" + f" --mongodHost {mongod_host}" + f" --mongodPort {standard_port}" + f" {rsync_cmd}" + f" start_mongod" + ) + ret, output = call_remote_operation( + local_ops, remote_python, script_name, client_args, remote_op + ) rsync_text = "rsync_data afterrecovery & " LOGGER.info("****%s start mongod: %d %s****", rsync_text, ret, output) if ret: @@ -1586,12 +1785,22 @@ def main(parser_actions, options): host_port = f"{mongod_host}:{standard_port}" for i in range(num_crud_clients): crud_config_file = NamedTempFile.create(suffix=".yml", directory="tmp") - crud_test_data["collectionName"] = f"{powercycle_constants.COLLECTION_NAME}-{i}" - new_resmoke_config(with_external_server, crud_config_file, crud_test_data, eval_str) - _, _ = resmoke_client(work_dir=mongo_repo_root_dir, mongo_path=mongo_path, - host_port=host_port, js_test=powercycle_constants.CRUD_CLIENT, - resmoke_suite=crud_config_file, repeat_num=100, no_wait=True, - log_file=f"crud_{i}.log") + crud_test_data["collectionName"] = ( + f"{powercycle_constants.COLLECTION_NAME}-{i}" + ) + new_resmoke_config( + with_external_server, crud_config_file, crud_test_data, eval_str + ) + _, _ = resmoke_client( + work_dir=mongo_repo_root_dir, + mongo_path=mongo_path, + host_port=host_port, + js_test=powercycle_constants.CRUD_CLIENT, + resmoke_suite=crud_config_file, + repeat_num=100, + no_wait=True, + log_file=f"crud_{i}.log", + ) LOGGER.info("****Started %d CRUD client(s)****", num_crud_clients) @@ -1601,11 +1810,19 @@ def main(parser_actions, options): fsm_test_data["dbNamePrefix"] = f"fsm-{i}" # Do collection validation only for the first FSM client. fsm_test_data["validateCollections"] = bool(i == 0) - new_resmoke_config(with_external_server, fsm_config_file, fsm_test_data, eval_str) - _, _ = resmoke_client(work_dir=mongo_repo_root_dir, mongo_path=mongo_path, - host_port=host_port, js_test=fsm_client, - resmoke_suite=fsm_config_file, repeat_num=100, no_wait=True, - log_file=f"fsm_{i}.log") + new_resmoke_config( + with_external_server, fsm_config_file, fsm_test_data, eval_str + ) + _, _ = resmoke_client( + work_dir=mongo_repo_root_dir, + mongo_path=mongo_path, + host_port=host_port, + js_test=fsm_client, + resmoke_suite=fsm_config_file, + repeat_num=100, + no_wait=True, + log_file=f"fsm_{i}.log", + ) LOGGER.info("****Started %d FSM client(s)****", num_fsm_clients) @@ -1613,15 +1830,24 @@ def main(parser_actions, options): crash_canary = {} canary_doc = {"x": time.time()} orig_canary_doc = copy.deepcopy(canary_doc) - mongo = pymongo.MongoClient(**get_mongo_client_args(host=mongod_host, port=standard_port, - server_selection_timeout_ms=one_hour_ms, - socket_timeout_ms=one_hour_ms)) + mongo = pymongo.MongoClient( + **get_mongo_client_args( + host=mongod_host, + port=standard_port, + server_selection_timeout_ms=one_hour_ms, + socket_timeout_ms=one_hour_ms, + ) + ) crash_canary["function"] = mongo_insert_canary crash_canary["args"] = [ - mongo, powercycle_constants.DB_NAME, powercycle_constants.COLLECTION_NAME, canary_doc + mongo, + powercycle_constants.DB_NAME, + powercycle_constants.COLLECTION_NAME, + canary_doc, ] - ret, output = crash_server_or_kill_mongod(task_config, crash_canary, local_ops, script_name, - client_args) + ret, output = crash_server_or_kill_mongod( + task_config, crash_canary, local_ops, script_name, client_args + ) LOGGER.info("Crash server or Kill mongod: %d %s****", ret, output) @@ -1644,31 +1870,49 @@ def main(parser_actions, options): NamedTempFile.delete(temp_file) # Reestablish remote access after crash. - local_ops = LocalToRemoteOperations(user_host=ssh_user_host, - ssh_connection_options=ssh_connection_options, - ssh_options=ssh_options, use_shell=True, - access_retry_count=options.ssh_access_retry_count) + local_ops = LocalToRemoteOperations( + user_host=ssh_user_host, + ssh_connection_options=ssh_connection_options, + ssh_options=ssh_options, + use_shell=True, + access_retry_count=options.ssh_access_retry_count, + ) verify_remote_access(local_ops) - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - "--remoteOperation noop") + ret, output = call_remote_operation( + local_ops, remote_python, script_name, client_args, "--remoteOperation noop" + ) boot_time_after_crash = get_boot_datetime(output) if boot_time_after_crash == -1 or boot_time_after_recovery == -1: LOGGER.warning( "Cannot compare boot time after recovery: %s with boot time after crash: %s", - boot_time_after_recovery, boot_time_after_crash) - elif task_config.crash_method != "kill" and boot_time_after_crash <= boot_time_after_recovery: - raise Exception(f"System boot time after crash ({boot_time_after_crash}) is not newer" - f" than boot time before crash ({boot_time_after_recovery})") + boot_time_after_recovery, + boot_time_after_crash, + ) + elif ( + task_config.crash_method != "kill" + and boot_time_after_crash <= boot_time_after_recovery + ): + raise Exception( + f"System boot time after crash ({boot_time_after_crash}) is not newer" + f" than boot time before crash ({boot_time_after_recovery})" + ) canary_doc = copy.deepcopy(orig_canary_doc) test_time = int(time.time()) - start_time - LOGGER.info("****Completed test loop %d test time %d seconds****", loop_num, test_time) + LOGGER.info( + "****Completed test loop %d test time %d seconds****", loop_num, test_time + ) if loop_num == test_loops: break - ret, output = call_remote_operation(local_ops, remote_python, script_name, client_args, - "--remoteOperation check_disk") + ret, output = call_remote_operation( + local_ops, + remote_python, + script_name, + client_args, + "--remoteOperation check_disk", + ) if ret != 0: LOGGER.error("****check_disk: %d %s****", ret, output) diff --git a/buildscripts/resmokelib/powercycle/powercycle_config.py b/buildscripts/resmokelib/powercycle/powercycle_config.py index 91592431cc1..093fb2df427 100644 --- a/buildscripts/resmokelib/powercycle/powercycle_config.py +++ b/buildscripts/resmokelib/powercycle/powercycle_config.py @@ -1,4 +1,5 @@ """Powercycle tasks config.""" + import yaml from buildscripts.resmokelib.powercycle import powercycle, powercycle_constants @@ -13,17 +14,24 @@ class PowercycleTaskConfig: """Initialize.""" self.name = task_yaml.get("name", "") - self.crash_method = task_yaml.get("crash_method", powercycle_constants.DEFAULT_CRASH_METHOD) - self.test_loops = task_yaml.get("test_loops", powercycle_constants.DEFAULT_TEST_LOOPS) - self.seed_doc_num = task_yaml.get("seed_doc_num", powercycle_constants.DEFAULT_SEED_DOC_NUM) + self.crash_method = task_yaml.get( + "crash_method", powercycle_constants.DEFAULT_CRASH_METHOD + ) + self.test_loops = task_yaml.get( + "test_loops", powercycle_constants.DEFAULT_TEST_LOOPS + ) + self.seed_doc_num = task_yaml.get( + "seed_doc_num", powercycle_constants.DEFAULT_SEED_DOC_NUM + ) self.write_concern = task_yaml.get("write_concern", "{}") self.read_concern_level = task_yaml.get("read_concern_level", None) self.fcv = task_yaml.get("fcv", None) self.repl_set = task_yaml.get("repl_set", None) - self.mongod_options = task_yaml.get("mongod_options", - powercycle_constants.DEFAULT_MONGOD_OPTIONS) + self.mongod_options = task_yaml.get( + "mongod_options", powercycle_constants.DEFAULT_MONGOD_OPTIONS + ) def __str__(self): """Return as dict.""" @@ -36,7 +44,8 @@ def get_task_config(task_name, is_remote): if is_remote: config_location = powercycle.abs_path( - f"{powercycle_constants.REMOTE_DIR}/{POWERCYCLE_TASKS_CONFIG}") + f"{powercycle_constants.REMOTE_DIR}/{POWERCYCLE_TASKS_CONFIG}" + ) else: config_location = powercycle.abs_path(POWERCYCLE_TASKS_CONFIG) @@ -48,6 +57,8 @@ def get_task_config(task_name, is_remote): if single_task_yaml["name"] == task_name: return PowercycleTaskConfig(single_task_yaml) - raise Exception(f"Task with name '{task_name}' is not found" - f" in powercycle tasks configuration file '{POWERCYCLE_TASKS_CONFIG}'." - f" Please add a task there with the appropriate name.") + raise Exception( + f"Task with name '{task_name}' is not found" + f" in powercycle tasks configuration file '{POWERCYCLE_TASKS_CONFIG}'." + f" Please add a task there with the appropriate name." + ) diff --git a/buildscripts/resmokelib/powercycle/powercycle_constants.py b/buildscripts/resmokelib/powercycle/powercycle_constants.py index ddeeaab3f04..731a9d8aea1 100644 --- a/buildscripts/resmokelib/powercycle/powercycle_constants.py +++ b/buildscripts/resmokelib/powercycle/powercycle_constants.py @@ -2,7 +2,7 @@ import os -if 'CI' in os.environ: +if "CI" in os.environ: # in CI, the expansions file is located in the ${workdir}, one dir up # from src, the checkout directory EXPANSIONS_FILE = "../expansions.yml" @@ -12,14 +12,16 @@ else: # For ssh disable the options GSSAPIAuthentication, CheckHostIP, StrictHostKeyChecking # & UserKnownHostsFile, since these are local connections from one AWS instance to another. -DEFAULT_SSH_CONNECTION_OPTIONS = ("-o ServerAliveCountMax=10" - " -o ServerAliveInterval=6" - " -o StrictHostKeyChecking=no" - " -o ConnectTimeout=30" - " -o ConnectionAttempts=3" - " -o UserKnownHostsFile=/dev/null" - " -o GSSAPIAuthentication=no" - " -o CheckHostIP=no") +DEFAULT_SSH_CONNECTION_OPTIONS = ( + "-o ServerAliveCountMax=10" + " -o ServerAliveInterval=6" + " -o StrictHostKeyChecking=no" + " -o ConnectTimeout=30" + " -o ConnectionAttempts=3" + " -o UserKnownHostsFile=/dev/null" + " -o GSSAPIAuthentication=no" + " -o CheckHostIP=no" +) MONITOR_PROC_FILE = "proc.json" MONITOR_SYSTEM_FILE = "system.json" @@ -54,7 +56,9 @@ CRUD_CLIENT = "jstests/hooks/crud_client.js" CONFIG_CRUD_CLIENT = "buildscripts/resmokeconfig/suites/with_external_server.yml" NUM_FSM_CLIENTS = 20 FSM_CLIENT = "jstests/libs/fsm_serial_client.js" -SET_READ_AND_WRITE_CONCERN = "jstests/libs/override_methods/set_read_and_write_concerns.js" +SET_READ_AND_WRITE_CONCERN = ( + "jstests/libs/override_methods/set_read_and_write_concerns.js" +) REPORT_JSON_FILE = "report.json" POWERCYCLE_EXIT_FILE = "powercycle_exit.yml" @@ -62,7 +66,9 @@ POWERCYCLE_EXIT_FILE = "powercycle_exit.yml" DEFAULT_CRASH_METHOD = "internal" DEFAULT_TEST_LOOPS = 15 DEFAULT_SEED_DOC_NUM = 10_000 -DEFAULT_MONGOD_OPTIONS = ("--setParameter enableTestCommands=1" - " --setParameter logComponentVerbosity='{storage:{recovery:2}}'" - " --storageEngine wiredTiger" - " --wiredTigerEngineConfigString 'debug_mode=[table_logging=true]'") +DEFAULT_MONGOD_OPTIONS = ( + "--setParameter enableTestCommands=1" + " --setParameter logComponentVerbosity='{storage:{recovery:2}}'" + " --storageEngine wiredTiger" + " --wiredTigerEngineConfigString 'debug_mode=[table_logging=true]'" +) diff --git a/buildscripts/resmokelib/powercycle/remote_hang_analyzer/__init__.py b/buildscripts/resmokelib/powercycle/remote_hang_analyzer/__init__.py index d0ab930ff92..1693b9c8f70 100644 --- a/buildscripts/resmokelib/powercycle/remote_hang_analyzer/__init__.py +++ b/buildscripts/resmokelib/powercycle/remote_hang_analyzer/__init__.py @@ -17,11 +17,17 @@ class RunHangAnalyzerOnRemoteInstance(PowercycleCommand): """:return: None.""" if "private_ip_address" not in self.expansions: return - hang_analyzer_processes = "dbtest,java,mongo,mongod,mongos,python,_test" if "hang_analyzer_processes" not in self.expansions else self.expansions[ - "hang_analyzer_processes"] + hang_analyzer_processes = ( + "dbtest,java,mongo,mongod,mongos,python,_test" + if "hang_analyzer_processes" not in self.expansions + else self.expansions["hang_analyzer_processes"] + ) hang_analyzer_option = f"-o file -o stdout -p {hang_analyzer_processes}" - hang_analyzer_dump_core = True if "hang_analyzer_dump_core" not in self.expansions else self.expansions[ - "hang_analyzer_dump_core"] + hang_analyzer_dump_core = ( + True + if "hang_analyzer_dump_core" not in self.expansions + else self.expansions["hang_analyzer_dump_core"] + ) if hang_analyzer_dump_core: hang_analyzer_option = f"-c {hang_analyzer_option}" diff --git a/buildscripts/resmokelib/powercycle/save_diagnostics/__init__.py b/buildscripts/resmokelib/powercycle/save_diagnostics/__init__.py index cdba94caa37..9cac5e4d2bd 100644 --- a/buildscripts/resmokelib/powercycle/save_diagnostics/__init__.py +++ b/buildscripts/resmokelib/powercycle/save_diagnostics/__init__.py @@ -81,7 +81,7 @@ class GatherRemoteMongoCoredumps(PowercycleCommand): remote_dir = powercycle_constants.REMOTE_DIR # Find all core files and move to $remote_dir cmds = "core_files=$(/usr/bin/find -H . \\( -name '*.core' -o -name '*.mdmp' \\) 2> /dev/null)" - cmds = f"{cmds}; if [ -z \"$core_files\" ]; then exit 0; fi" + cmds = f'{cmds}; if [ -z "$core_files" ]; then exit 0; fi' cmds = f"{cmds}; echo Found remote core files $core_files, moving to $(pwd)" cmds = f"{cmds}; for core_file in $core_files" cmds = f"{cmds}; do base_name=$(echo $core_file | sed 's/.*///')" @@ -108,7 +108,9 @@ class CopyRemoteMongoCoredumps(PowercycleCommand): remote_dir = powercycle_constants.REMOTE_DIR # Core file may not exist so we ignore the return code. - self.remote_op.operation(SSHOperation.SHELL, f"{remote_dir}/*.{core_suffix}", None, True) + self.remote_op.operation( + SSHOperation.SHELL, f"{remote_dir}/*.{core_suffix}", None, True + ) class CopyEC2MonitorFiles(PowercycleCommand): @@ -122,4 +124,4 @@ class CopyEC2MonitorFiles(PowercycleCommand): cmd = f"{tar_cmd} czf ec2_monitor_files.tgz {powercycle_constants.EC2_MONITOR_FILES}" self.remote_op.operation(SSHOperation.SHELL, cmd, None) - self.remote_op.operation(SSHOperation.COPY_FROM, 'ec2_monitor_files.tgz', None) + self.remote_op.operation(SSHOperation.COPY_FROM, "ec2_monitor_files.tgz", None) diff --git a/buildscripts/resmokelib/powercycle/setup/__init__.py b/buildscripts/resmokelib/powercycle/setup/__init__.py index ef2c0308926..6e10e219090 100644 --- a/buildscripts/resmokelib/powercycle/setup/__init__.py +++ b/buildscripts/resmokelib/powercycle/setup/__init__.py @@ -16,7 +16,9 @@ class SetUpEC2Instance(PowercycleCommand): """:return: None.""" default_retry_count = 2 - retry_count = int(self.expansions.get("set_up_retry_count", default_retry_count)) + retry_count = int( + self.expansions.get("set_up_retry_count", default_retry_count) + ) # First operation - # Create remote_dir. @@ -34,35 +36,49 @@ class SetUpEC2Instance(PowercycleCommand): cmds = f"{self.sudo} mkdir -p {remote_dir}; {self.sudo} chown -R {user_group} {remote_dir}; {set_permission_stmt} {remote_dir}; ls -ld {remote_dir}" cmds = f"{cmds}; {self.sudo} mkdir -p {db_path}; {self.sudo} chown -R {user_group} {db_path}; {set_permission_stmt} {db_path}; ls -ld {db_path}" - self.remote_op.operation(SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count) + self.remote_op.operation( + SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count + ) # Second operation - # Copy buildscripts and mongoDB executables to the remote host. - files = ["etc", "buildscripts", "dist-test/bin", "poetry.lock", "pyproject.toml"] + files = [ + "etc", + "buildscripts", + "dist-test/bin", + "poetry.lock", + "pyproject.toml", + ] shared_libs = "dist-test/lib" if os.path.isdir(shared_libs): files.append(shared_libs) - self.remote_op.operation(SSHOperation.COPY_TO, files, remote_dir, retry=True, - retry_count=retry_count) + self.remote_op.operation( + SSHOperation.COPY_TO, files, remote_dir, retry=True, retry_count=retry_count + ) # Third operation - # Set up virtualenv on remote. venv = powercycle_constants.VIRTUALENV_DIR - python = "/opt/mongodbtoolchain/v4/bin/python3" if "python" not in self.expansions else self.expansions[ - "python"] + python = ( + "/opt/mongodbtoolchain/v4/bin/python3" + if "python" not in self.expansions + else self.expansions["python"] + ) cmds = f"python_loc=$(which {python})" cmds = f"{cmds}; remote_dir={remote_dir}" - cmds = f"{cmds}; if [ \"Windows_NT\" = \"$OS\" ]; then python_loc=$(cygpath -w $python_loc); remote_dir=$(cygpath -w $remote_dir); fi" + cmds = f'{cmds}; if [ "Windows_NT" = "$OS" ]; then python_loc=$(cygpath -w $python_loc); remote_dir=$(cygpath -w $remote_dir); fi' cmds = f"{cmds}; $python_loc -m venv --system-site-packages {venv}" cmds = f"{cmds}; activate=$(find {venv} -name 'activate')" cmds = f"{cmds}; . $activate" cmds = f"{cmds}; python3 -m pip install 'poetry==2.0.0'" cmds = f"{cmds}; pushd $remote_dir && python3 -m poetry install --no-root --sync && popd" - self.remote_op.operation(SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count) + self.remote_op.operation( + SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count + ) # Operation below that enables core dumps is commented out since it causes failures on Ubuntu 18.04. # It might be a race condition, so `nohup reboot` command is likely a culprit here. @@ -113,7 +129,9 @@ class SetUpEC2Instance(PowercycleCommand): monitor_proc_file = powercycle_constants.MONITOR_PROC_FILE if self.is_windows(): # Since curator runs as SYSTEM user, ensure the output files can be accessed. - cmds = f"{cmds}; touch {monitor_system_file}; chmod 777 {monitor_system_file}" + cmds = ( + f"{cmds}; touch {monitor_system_file}; chmod 777 {monitor_system_file}" + ) cmds = f"{cmds}; cygrunsrv --install curator_sys --path curator --chdir $HOME --args 'stat system --file {monitor_system_file}'" cmds = f"{cmds}; touch {monitor_proc_file}; chmod 777 {monitor_proc_file}" cmds = f"{cmds}; cygrunsrv --install curator_proc --path curator --chdir $HOME --args 'stat process-all --file {monitor_proc_file}'" @@ -121,14 +139,16 @@ class SetUpEC2Instance(PowercycleCommand): cmds = f"{cmds}; cygrunsrv --start curator_proc" else: cmds = f"{cmds}; touch {monitor_system_file} {monitor_proc_file}" - cmds = f"{cmds}; cmd=\"@reboot cd $HOME && {self.sudo} ./curator stat system >> {monitor_system_file}\"" - cmds = f"{cmds}; (crontab -l ; echo \"$cmd\") | crontab -" - cmds = f"{cmds}; cmd=\"@reboot cd $HOME && $sudo ./curator stat process-all >> {monitor_proc_file}\"" - cmds = f"{cmds}; (crontab -l ; echo \"$cmd\") | crontab -" + cmds = f'{cmds}; cmd="@reboot cd $HOME && {self.sudo} ./curator stat system >> {monitor_system_file}"' + cmds = f'{cmds}; (crontab -l ; echo "$cmd") | crontab -' + cmds = f'{cmds}; cmd="@reboot cd $HOME && $sudo ./curator stat process-all >> {monitor_proc_file}"' + cmds = f'{cmds}; (crontab -l ; echo "$cmd") | crontab -' cmds = f"{cmds}; crontab -l" cmds = f"{cmds}; {{ {self.sudo} $HOME/curator stat system --file {monitor_system_file} > /dev/null 2>&1 & {self.sudo} $HOME/curator stat process-all --file {monitor_proc_file} > /dev/null 2>&1 & }} & disown" - self.remote_op.operation(SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count) + self.remote_op.operation( + SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count + ) # Seventh operation - # Install NotMyFault, used to crash Windows. @@ -140,4 +160,6 @@ class SetUpEC2Instance(PowercycleCommand): cmds = f"curl -s -o {windows_crash_zip} {windows_crash_dl}" cmds = f"{cmds}; unzip -q {windows_crash_zip} -d {windows_crash_dir}" cmds = f"{cmds}; chmod +x {windows_crash_dir}/*.exe" - self.remote_op.operation(SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count) + self.remote_op.operation( + SSHOperation.SHELL, cmds, retry=True, retry_count=retry_count + ) diff --git a/buildscripts/resmokelib/run/__init__.py b/buildscripts/resmokelib/run/__init__.py index b311a511b68..749f364b2c8 100644 --- a/buildscripts/resmokelib/run/__init__.py +++ b/buildscripts/resmokelib/run/__init__.py @@ -88,8 +88,9 @@ class TestRunner(Subcommand): flush_success = logging.flush.stop_thread() if not flush_success: self._resmoke_logger.error( - 'Failed to flush all logs within a reasonable amount of time, ' - 'treating logs as incomplete') + "Failed to flush all logs within a reasonable amount of time, " + "treating logs as incomplete" + ) if not flush_success or logging.buildlogger.is_log_output_incomplete(): self._exit_on_incomplete_logging() @@ -101,12 +102,16 @@ class TestRunner(Subcommand): # or cause a JIRA ticket to be created. self._resmoke_logger.info( "We failed to flush all log output to logkeeper but all tests passed, so" - " ignoring.") + " ignoring." + ) else: exit_code = errors.LoggerRuntimeConfigError.EXIT_CODE self._resmoke_logger.info( "Exiting with code %d rather than requested code %d because we failed to flush all" - " log output to logkeeper.", exit_code, self._exit_code) + " log output to logkeeper.", + exit_code, + self._exit_code, + ) self._exit_code = exit_code # Force exit the process without cleaning up or calling the finally block @@ -143,7 +148,9 @@ class TestRunner(Subcommand): def list_suites(self): """List the suites that are available to execute.""" suite_names = suitesconfig.get_named_suites() - self._resmoke_logger.info("Suites available to execute:\n%s", "\n".join(suite_names)) + self._resmoke_logger.info( + "Suites available to execute:\n%s", "\n".join(suite_names) + ) def find_suites(self): """List the suites that run the specified tests.""" @@ -151,8 +158,9 @@ class TestRunner(Subcommand): suites_by_test = self._find_suites_by_test(suites) for test in sorted(suites_by_test): suite_names = suites_by_test[test] - self._resmoke_logger.info("%s will be run by the following suite(s): %s", test, - suite_names) + self._resmoke_logger.info( + "%s will be run by the following suite(s): %s", test, suite_names + ) def list_tags(self): """ @@ -173,8 +181,9 @@ class TestRunner(Subcommand): for single_tag_block in splitted_tags_block: tag_name, doc = list_tags.get_tag_doc(single_tag_block) - if tag_name and (tag_name not in tag_docs - or len(doc) > len(tag_docs[tag_name])): + if tag_name and ( + tag_name not in tag_docs or len(doc) > len(tag_docs[tag_name]) + ): tag_docs[tag_name] = doc if suite_name in config.SUITE_FILES: # pylint: disable=unsupported-membership-test @@ -183,9 +192,13 @@ class TestRunner(Subcommand): if config.SUITE_FILES == [config.DEFAULTS["suite_files"]]: out_tag_docs = tag_docs else: - out_tag_docs = {tag: doc for tag, doc in tag_docs.items() if tag in out_tag_names} + out_tag_docs = { + tag: doc for tag, doc in tag_docs.items() if tag in out_tag_names + } - self._resmoke_logger.info("Found tags in suites:%s", list_tags.make_output(out_tag_docs)) + self._resmoke_logger.info( + "Found tags in suites:%s", list_tags.make_output(out_tag_docs) + ) def generate_multiversion_exclude_tags(self): """Generate multiversion exclude tags file.""" @@ -214,9 +227,15 @@ class TestRunner(Subcommand): suites = self._get_suites() for suite in suites: self._shuffle_tests(suite) - sb = ["Tests that would be run in suite {}".format(suite.get_display_name())] + sb = [ + "Tests that would be run in suite {}".format(suite.get_display_name()) + ] sb.extend(suite.tests or ["(no tests)"]) - sb.append("Tests that would be excluded from suite {}".format(suite.get_display_name())) + sb.append( + "Tests that would be excluded from suite {}".format( + suite.get_display_name() + ) + ) sb.extend(suite.excluded or ["(no tests)"]) self._exec_logger.info("\n".join(sb)) @@ -226,24 +245,33 @@ class TestRunner(Subcommand): if config.REQUIRES_WORKLOAD_CONTAINER_SETUP: self._setup_workload_container() - self._resmoke_logger.info("verbatim resmoke.py invocation: %s", - " ".join([shlex.quote(arg) for arg in sys.argv])) + self._resmoke_logger.info( + "verbatim resmoke.py invocation: %s", + " ".join([shlex.quote(arg) for arg in sys.argv]), + ) self._check_for_mongo_processes() if config.EVERGREEN_TASK_DOC: - self._resmoke_logger.info("Evergreen task documentation:\n%s", - config.EVERGREEN_TASK_DOC) + self._resmoke_logger.info( + "Evergreen task documentation:\n%s", config.EVERGREEN_TASK_DOC + ) elif config.EVERGREEN_TASK_NAME: - self._resmoke_logger.info("Evergreen task documentation is absent for this task.") - task_name = utils.get_task_name_without_suffix(config.EVERGREEN_TASK_NAME, - config.EVERGREEN_VARIANT_NAME) + self._resmoke_logger.info( + "Evergreen task documentation is absent for this task." + ) + task_name = utils.get_task_name_without_suffix( + config.EVERGREEN_TASK_NAME, config.EVERGREEN_VARIANT_NAME + ) self._resmoke_logger.info( "If you are familiar with the functionality of %s task, " - "please consider adding documentation for it in %s", task_name, - os.path.join(config.CONFIG_DIR, "evg_task_doc", "evg_task_doc.yml")) + "please consider adding documentation for it in %s", + task_name, + os.path.join(config.CONFIG_DIR, "evg_task_doc", "evg_task_doc.yml"), + ) self._log_local_resmoke_invocation() from buildscripts.resmokelib import multiversionconstants + multiversionconstants.log_constants(self._resmoke_logger) suites = None @@ -254,7 +282,9 @@ class TestRunner(Subcommand): for suite in suites: self._interrupted = self._run_suite(suite) - if self._interrupted or (suite.options.fail_fast and suite.return_code != 0): + if self._interrupted or ( + suite.options.fail_fast and suite.return_code != 0 + ): self._log_resmoke_summary(suites) self.exit(suite.return_code) @@ -283,15 +313,24 @@ class TestRunner(Subcommand): # Currently, you can only run one suite at a time from within a workload container suite = self._get_suites()[0] - if "jstestfuzz/out/*.js" in suite.get_selector_config().get("roots", []) and not any( - filename.endswith(".js") for filename in os.listdir(jstestfuzz_tests_dir)): - subprocess.run([ - "./src/scripts/npm_run.sh", - "jstestfuzz", - "--", - "--jsTestsDir", - jstests_dir, - ], cwd=jstestfuzz_repo_dir, stdout=sys.stdout, stderr=sys.stderr, check=True) + if "jstestfuzz/out/*.js" in suite.get_selector_config().get( + "roots", [] + ) and not any( + filename.endswith(".js") for filename in os.listdir(jstestfuzz_tests_dir) + ): + subprocess.run( + [ + "./src/scripts/npm_run.sh", + "jstestfuzz", + "--", + "--jsTestsDir", + jstests_dir, + ], + cwd=jstestfuzz_repo_dir, + stdout=sys.stdout, + stderr=sys.stderr, + check=True, + ) def _run_suite(self, suite: Suite): """Run a test suite.""" @@ -307,7 +346,9 @@ class TestRunner(Subcommand): # Do not log local args if this is not being ran in evergreen if not config.EVERGREEN_TASK_ID: - print("Skipping local invocation because evergreen task id was not provided.") + print( + "Skipping local invocation because evergreen task id was not provided." + ) return evg_conf = parse_evergreen_file(config.EVERGREEN_PROJECT_CONFIG_PATH) @@ -323,9 +364,11 @@ class TestRunner(Subcommand): # The suite names should be in the evergreen functions in this case if task is None: for current_task in evg_conf.tasks: - func = current_task.find_func_command("run tests") \ - or current_task.find_func_command("generate resmoke tasks") \ + func = ( + current_task.find_func_command("run tests") + or current_task.find_func_command("generate resmoke tasks") or current_task.find_func_command("run benchmark tests") + ) if func and get_dict_value(func, ["vars", "suite"]) == suite_name: task = current_task break @@ -343,51 +386,66 @@ class TestRunner(Subcommand): break if task is None: - raise RuntimeError(f"Error: Could not find evergreen task definition for {suite_name}") + raise RuntimeError( + f"Error: Could not find evergreen task definition for {suite_name}" + ) is_multiversion = "multiversion" in task.tags generate_func = task.find_func_command("generate resmoke tasks") is_jstestfuzz = False if generate_func: - is_jstestfuzz = get_dict_value(generate_func, ["vars", "is_jstestfuzz"]) == "true" + is_jstestfuzz = ( + get_dict_value(generate_func, ["vars", "is_jstestfuzz"]) == "true" + ) local_args = to_local_args() local_args = strip_fuzz_config_params(local_args) local_resmoke_invocation = ( - f"{os.path.join('buildscripts', 'resmoke.py')} {' '.join(local_args)}") + f"{os.path.join('buildscripts', 'resmoke.py')} {' '.join(local_args)}" + ) using_config_fuzzer = False if config.FUZZ_MONGOD_CONFIGS: using_config_fuzzer = True - local_resmoke_invocation += f" --fuzzMongodConfigs={config.FUZZ_MONGOD_CONFIGS}" + local_resmoke_invocation += ( + f" --fuzzMongodConfigs={config.FUZZ_MONGOD_CONFIGS}" + ) - self._resmoke_logger.info("Fuzzed mongodSetParameters:\n%s", - config.MONGOD_SET_PARAMETERS) - self._resmoke_logger.info("Fuzzed wiredTigerConnectionString: %s", - config.WT_ENGINE_CONFIG) + self._resmoke_logger.info( + "Fuzzed mongodSetParameters:\n%s", config.MONGOD_SET_PARAMETERS + ) + self._resmoke_logger.info( + "Fuzzed wiredTigerConnectionString: %s", config.WT_ENGINE_CONFIG + ) if config.FUZZ_MONGOS_CONFIGS: using_config_fuzzer = True - local_resmoke_invocation += f" --fuzzMongosConfigs={config.FUZZ_MONGOS_CONFIGS}" + local_resmoke_invocation += ( + f" --fuzzMongosConfigs={config.FUZZ_MONGOS_CONFIGS}" + ) - self._resmoke_logger.info("Fuzzed mongosSetParameters:\n%s", - config.MONGOS_SET_PARAMETERS) + self._resmoke_logger.info( + "Fuzzed mongosSetParameters:\n%s", config.MONGOS_SET_PARAMETERS + ) if using_config_fuzzer: - local_resmoke_invocation += f" --configFuzzSeed={str(config.CONFIG_FUZZ_SEED)}" + local_resmoke_invocation += ( + f" --configFuzzSeed={str(config.CONFIG_FUZZ_SEED)}" + ) if multiversion_bin_version: default_tag_file = config.DEFAULTS["exclude_tags_file_path"] local_resmoke_invocation += f" --tagFile={default_tag_file}" - resmoke_env_options = '' - if os.path.exists('resmoke_env_options.txt'): - with open('resmoke_env_options.txt') as fin: + resmoke_env_options = "" + if os.path.exists("resmoke_env_options.txt"): + with open("resmoke_env_options.txt") as fin: resmoke_env_options = fin.read().strip() local_resmoke_invocation = f"{resmoke_env_options} {local_resmoke_invocation}" - self._resmoke_logger.info("resmoke.py invocation for local usage: %s", - local_resmoke_invocation) + self._resmoke_logger.info( + "resmoke.py invocation for local usage: %s", local_resmoke_invocation + ) lines = [] @@ -447,15 +505,15 @@ class TestRunner(Subcommand): def _check_for_mongo_processes(self): """Check for existing mongo processes as they could interfere with running the tests.""" - if config.AUTO_KILL == 'off' or config.SHELL_CONN_STRING is not None: + if config.AUTO_KILL == "off" or config.SHELL_CONN_STRING is not None: return rogue_procs = [] # Iterate over all running process for proc in psutil.process_iter(): try: - parent_resmoke_pid = proc.environ().get('RESMOKE_PARENT_PROCESS') - parent_resmoke_ctime = proc.environ().get('RESMOKE_PARENT_CTIME') + parent_resmoke_pid = proc.environ().get("RESMOKE_PARENT_PROCESS") + parent_resmoke_ctime = proc.environ().get("RESMOKE_PARENT_CTIME") if not parent_resmoke_pid: continue if platform.system() == "Darwin": @@ -469,8 +527,10 @@ class TestRunner(Subcommand): # where there are no rogue resmoke processes. cmd = ["ps", "-E", str(proc.pid)] ps_proc = subprocess.run(cmd, capture_output=True) - if (f"RESMOKE_PARENT_PROCESS={parent_resmoke_pid}" not in - ps_proc.stdout.decode()): + if ( + f"RESMOKE_PARENT_PROCESS={parent_resmoke_pid}" + not in ps_proc.stdout.decode() + ): continue if psutil.pid_exists(int(parent_resmoke_pid)): # Double check `parent_resmoke_pid` is really a rooting resmoke process. Having @@ -489,11 +549,13 @@ class TestRunner(Subcommand): if rogue_procs: msg = "detected existing mongo processes. Please clean up these processes as they may affect tests:" - if config.AUTO_KILL == 'on': - msg += textwrap.dedent("""\ + if config.AUTO_KILL == "on": + msg += textwrap.dedent( + """\ Congratulations, you have selected auto kill mode: - HASTA LA VISTA MONGO""" + r""" + HASTA LA VISTA MONGO""" + + r""" ______ <((((((\\\ / . }\ @@ -509,27 +571,33 @@ class TestRunner(Subcommand): \ '\ / \ | | _/ / \ \ \ | | / / \ \ \ / - """) + """ + ) print(f"WARNING: {msg}") else: self._resmoke_logger.error("ERROR: %s", msg) for proc in rogue_procs: - if config.AUTO_KILL == 'on': + if config.AUTO_KILL == "on": proc_msg = f" Target acquired: pid: {str(proc.pid).ljust(5)} name: {proc.exe()}" try: proc.kill() - except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess) as exc: + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ) as exc: proc_msg += f" - target escaped: {type(exc).__name__ }" else: proc_msg += " - target destroyed\n" print(proc_msg) else: - self._resmoke_logger.error(" pid: %s name: %s", - str(proc.pid).ljust(5), proc.exe()) + self._resmoke_logger.error( + " pid: %s name: %s", str(proc.pid).ljust(5), proc.exe() + ) - if config.AUTO_KILL == 'on': + if config.AUTO_KILL == "on": print("I'll be back...\n") else: raise errors.ResmokeError( @@ -537,7 +605,8 @@ class TestRunner(Subcommand): Failing because existing mongo processes detected. You can use --autoKillResmokeMongo=on to automatically kill the processes, or --autoKillResmokeMongo=off to ignore them. - """)) + """) + ) def _log_resmoke_summary(self, suites): """Log a summary of the resmoke run.""" @@ -548,8 +617,11 @@ class TestRunner(Subcommand): def _log_suite_summary(self, suite: Suite): """Log a summary of the suite run.""" self._resmoke_logger.info("=" * 80) - self._resmoke_logger.info("Summary of %s suite: %s", suite.get_display_name(), - self._get_suite_summary(suite)) + self._resmoke_logger.info( + "Summary of %s suite: %s", + suite.get_display_name(), + self._get_suite_summary(suite), + ) @TRACER.start_as_current_span("run.__init__._execute_suite") def _execute_suite(self, suite: Suite) -> bool: @@ -566,15 +638,21 @@ class TestRunner(Subcommand): self._exec_logger.info("Skipping %s, no tests to run", suite.test_kind) suite.return_code = 0 execute_suite_span.set_status(StatusCode.OK) - execute_suite_span.set_attributes({ - Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, - Suite.METRIC_NAMES.RETURN_STATUS: "skipped", - }) + execute_suite_span.set_attributes( + { + Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, + Suite.METRIC_NAMES.RETURN_STATUS: "skipped", + } + ) return False executor_config = suite.get_executor_config() try: executor = testing.executor.TestSuiteExecutor( - self._exec_logger, suite, archive_instance=self._archive, **executor_config) + self._exec_logger, + suite, + archive_instance=self._archive, + **executor_config, + ) # If this is a "docker compose build", we just build the docker compose images for # this resmoke configuration and exit. if config.DOCKER_COMPOSE_BUILD_IMAGES: @@ -583,44 +661,62 @@ class TestRunner(Subcommand): else: executor.run() except (errors.UserInterrupt, errors.LoggerRuntimeConfigError) as err: - self._exec_logger.error("Encountered an error when running %ss of suite %s: %s", - suite.test_kind, suite.get_display_name(), err) + self._exec_logger.error( + "Encountered an error when running %ss of suite %s: %s", + suite.test_kind, + suite.get_display_name(), + err, + ) suite.return_code = err.EXIT_CODE - return_status = "user_interrupt" if isinstance( - err, errors.UserInterrupt) else "logger_runtime_config" + return_status = ( + "user_interrupt" + if isinstance(err, errors.UserInterrupt) + else "logger_runtime_config" + ) execute_suite_span.set_status(StatusCode.ERROR, description=return_status) - execute_suite_span.set_attributes({ - Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, - Suite.METRIC_NAMES.RETURN_STATUS: return_status, - }) + execute_suite_span.set_attributes( + { + Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, + Suite.METRIC_NAMES.RETURN_STATUS: return_status, + } + ) return True except OSError as err: self._exec_logger.error("Encountered an OSError: %s", err) suite.return_code = 74 # Exit code for OSError on POSIX systems. return_status = "os_error" execute_suite_span.set_status(StatusCode.ERROR, description=return_status) - execute_suite_span.set_attributes({ - Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, - Suite.METRIC_NAMES.RETURN_STATUS: return_status, - Suite.METRIC_NAMES.ERRORNO: err.errno - }) + execute_suite_span.set_attributes( + { + Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, + Suite.METRIC_NAMES.RETURN_STATUS: return_status, + Suite.METRIC_NAMES.ERRORNO: err.errno, + } + ) return True except: # pylint: disable=bare-except - self._exec_logger.exception("Encountered an error when running %ss of suite %s.", - suite.test_kind, suite.get_display_name()) + self._exec_logger.exception( + "Encountered an error when running %ss of suite %s.", + suite.test_kind, + suite.get_display_name(), + ) suite.return_code = 2 return_status = "unknown_error" execute_suite_span.set_status(StatusCode.ERROR, description=return_status) - execute_suite_span.set_attributes({ - Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, - Suite.METRIC_NAMES.RETURN_STATUS: return_status, - }) + execute_suite_span.set_attributes( + { + Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, + Suite.METRIC_NAMES.RETURN_STATUS: return_status, + } + ) return False execute_suite_span.set_status(StatusCode.OK) - execute_suite_span.set_attributes({ - Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, - Suite.METRIC_NAMES.RETURN_STATUS: "success", - }) + execute_suite_span.set_attributes( + { + Suite.METRIC_NAMES.RETURN_CODE: suite.return_code, + Suite.METRIC_NAMES.RETURN_STATUS: "success", + } + ) return False def _shuffle_tests(self, suite: Suite): @@ -628,8 +724,12 @@ class TestRunner(Subcommand): random.seed(config.RANDOM_SEED) if not config.SHUFFLE: return - self._exec_logger.info("Shuffling order of tests for %ss in suite %s. The seed is %d.", - suite.test_kind, suite.get_display_name(), config.RANDOM_SEED) + self._exec_logger.info( + "Shuffling order of tests for %ss in suite %s. The seed is %d.", + suite.test_kind, + suite.get_display_name(), + config.RANDOM_SEED, + ) random.shuffle(suite.tests) # pylint: disable=inconsistent-return-statements @@ -638,7 +738,9 @@ class TestRunner(Subcommand): try: return suitesconfig.get_suites(config.SUITE_FILES, config.TEST_FILES) except errors.SuiteNotFound as err: - self._resmoke_logger.error("Failed to parse YAML suite definition: %s", str(err)) + self._resmoke_logger.error( + "Failed to parse YAML suite definition: %s", str(err) + ) self.list_suites() self.exit(1) except errors.InvalidMatrixSuiteError as err: @@ -647,17 +749,21 @@ class TestRunner(Subcommand): except errors.TestExcludedFromSuiteError as err: self._resmoke_logger.error( "Cannot run excluded test in suite config. Use '--force-excluded-tests' to override: %s", - str(err)) + str(err), + ) self.exit(1) return [] def _log_suite_config(self, suite: Suite): sb = [ "YAML configuration of suite {}".format(suite.get_display_name()), - utils.dump_yaml({"test_kind": suite.get_test_kind_config()}), "", - utils.dump_yaml({"selector": suite.get_selector_config()}), "", - utils.dump_yaml({"executor": suite.get_executor_config()}), "", - utils.dump_yaml({"logging": config.LOGGING_CONFIG}) + utils.dump_yaml({"test_kind": suite.get_test_kind_config()}), + "", + utils.dump_yaml({"selector": suite.get_selector_config()}), + "", + utils.dump_yaml({"executor": suite.get_executor_config()}), + "", + utils.dump_yaml({"logging": config.LOGGING_CONFIG}), ] self._resmoke_logger.info("\n".join(sb)) @@ -676,8 +782,11 @@ class TestRunner(Subcommand): """Set up the archival feature if enabled in the cli options.""" if config.ARCHIVE_FILE: self._archive = utils.archival.Archival( - archival_json_file=config.ARCHIVE_FILE, limit_size_mb=config.ARCHIVE_LIMIT_MB, - limit_files=config.ARCHIVE_LIMIT_TESTS, logger=self._exec_logger) + archival_json_file=config.ARCHIVE_FILE, + limit_size_mb=config.ARCHIVE_LIMIT_MB, + limit_files=config.ARCHIVE_LIMIT_TESTS, + logger=self._exec_logger, + ) def _exit_archival(self): """Finish up archival tasks before exit if enabled in the cli options.""" @@ -691,7 +800,9 @@ class TestRunner(Subcommand): sys.exit(exit_code) -_TagInfo = collections.namedtuple("_TagInfo", ["tag_name", "evergreen_aware", "suite_options"]) +_TagInfo = collections.namedtuple( + "_TagInfo", ["tag_name", "evergreen_aware", "suite_options"] +) class TestRunnerEvg(TestRunner): @@ -705,7 +816,9 @@ class TestRunnerEvg(TestRunner): tag_name="resource_intensive", evergreen_aware=False, suite_options=config.SuiteOptions.ALL_INHERITED._replace( # type: ignore - num_jobs=1)) + num_jobs=1 + ), + ) @staticmethod def _make_evergreen_aware_tags(tag_name): @@ -723,12 +836,18 @@ class TestRunnerEvg(TestRunner): tags_format.append("{tag_name}|{task_name}|{variant_name}") if config.EVERGREEN_DISTRO_ID is not None: - tags_format.append("{tag_name}|{task_name}|{variant_name}|{distro_id}") + tags_format.append( + "{tag_name}|{task_name}|{variant_name}|{distro_id}" + ) return [ - tag.format(tag_name=tag_name, task_name=config.EVERGREEN_TASK_NAME, - variant_name=config.EVERGREEN_VARIANT_NAME, - distro_id=config.EVERGREEN_DISTRO_ID) for tag in tags_format + tag.format( + tag_name=tag_name, + task_name=config.EVERGREEN_TASK_NAME, + variant_name=config.EVERGREEN_VARIANT_NAME, + distro_id=config.EVERGREEN_DISTRO_ID, + ) + for tag in tags_format ] @classmethod @@ -741,8 +860,12 @@ class TestRunnerEvg(TestRunner): combinations = [] - combinations.append(("resource intensive", [(cls.RESOURCE_INTENSIVE_TAG, True)])) - combinations.append(("not resource intensive", [(cls.RESOURCE_INTENSIVE_TAG, False)])) + combinations.append( + ("resource intensive", [(cls.RESOURCE_INTENSIVE_TAG, True)]) + ) + combinations.append( + ("not resource intensive", [(cls.RESOURCE_INTENSIVE_TAG, False)]) + ) return combinations @@ -763,10 +886,10 @@ class TestRunnerEvg(TestRunner): suites.append(suite) continue - for (tag_desc, tag_combo) in self._make_tag_combinations(): + for tag_desc, tag_combo in self._make_tag_combinations(): suite_options_list = [] - for (tag_info, enabled) in tag_combo: + for tag_info, enabled in tag_combo: if tag_info.evergreen_aware: tags = self._make_evergreen_aware_tags(tag_info.tag_name) include_tags = {"$anyOf": tags} @@ -774,10 +897,13 @@ class TestRunnerEvg(TestRunner): include_tags = tag_info.tag_name if enabled: - suite_options = tag_info.suite_options._replace(include_tags=include_tags) + suite_options = tag_info.suite_options._replace( + include_tags=include_tags + ) else: suite_options = config.SuiteOptions.ALL_INHERITED._replace( - include_tags={"$not": include_tags}) + include_tags={"$not": include_tags} + ) suite_options_list.append(suite_options) @@ -814,8 +940,14 @@ class RunPlugin(PluginInterface): :param kwargs: additional args :return: None or a Subcommand """ - if subcommand in ('find-suites', 'list-suites', 'list-tags', 'run', - 'generate-multiversion-exclude-tags', 'generate-matrix-suites'): + if subcommand in ( + "find-suites", + "list-suites", + "list-tags", + "run", + "generate-multiversion-exclude-tags", + "generate-matrix-suites", + ): configure_resmoke.validate_and_update_config(parser, parsed_args) if config.EVERGREEN_TASK_ID is not None: return TestRunnerEvg(subcommand, **kwargs) @@ -828,61 +960,113 @@ class RunPlugin(PluginInterface): """Create and add the parser for the Run subcommand.""" parser = subparsers.add_parser("run", help="Runs the specified tests.") - parser.set_defaults(dry_run="off", shuffle="auto", stagger_jobs="off", - majority_read_concern="on") - - parser.add_argument("test_files", metavar="TEST_FILES", nargs="*", - help="Explicit test files to run") + parser.set_defaults( + dry_run="off", + shuffle="auto", + stagger_jobs="off", + majority_read_concern="on", + ) parser.add_argument( - "--suites", dest="suite_files", metavar="SUITE1,SUITE2", - help=("Comma separated list of YAML files that each specify the configuration" - " of a suite. If the file is located in the resmokeconfig/suites/" - " directory, then the basename without the .yml extension can be" - " specified, e.g. 'core'. If a list of files is passed in as" - " positional arguments, they will be run using the suites'" - " configurations.")) + "test_files", + metavar="TEST_FILES", + nargs="*", + help="Explicit test files to run", + ) parser.add_argument( - "--autoKillResmokeMongo", dest="auto_kill", choices=['on', 'error', - 'off'], default='on', - help=("When resmoke starts up, existing mongo processes created from resmoke " - " could cause issues when running tests. This option causes resmoke to kill" - " the existing processes and continue running the test, or if 'error' option" - " is used, prints the offending processes and fails the test.")) - - parser.add_argument("--installDir", dest="install_dir", metavar="INSTALL_DIR", - help="Directory to search for MongoDB binaries") + "--suites", + dest="suite_files", + metavar="SUITE1,SUITE2", + help=( + "Comma separated list of YAML files that each specify the configuration" + " of a suite. If the file is located in the resmokeconfig/suites/" + " directory, then the basename without the .yml extension can be" + " specified, e.g. 'core'. If a list of files is passed in as" + " positional arguments, they will be run using the suites'" + " configurations." + ), + ) parser.add_argument( - "--alwaysUseLogFiles", dest="always_use_log_files", action="store_true", - help=("Logs server output to a file located in the db path and prevents the" - " cleaning of dbpaths after testing. Note that conflicting options" - " passed in from test files may cause an error.")) + "--autoKillResmokeMongo", + dest="auto_kill", + choices=["on", "error", "off"], + default="on", + help=( + "When resmoke starts up, existing mongo processes created from resmoke " + " could cause issues when running tests. This option causes resmoke to kill" + " the existing processes and continue running the test, or if 'error' option" + " is used, prints the offending processes and fails the test." + ), + ) parser.add_argument( - "--basePort", dest="base_port", metavar="PORT", - help=("The starting port number to use for mongod and mongos processes" - " spawned by resmoke.py or the tests themselves. Each fixture and Job" - " allocates a contiguous range of ports.")) - - parser.add_argument("--continueOnFailure", action="store_true", dest="continue_on_failure", - help="Executes all tests in all suites, even if some of them fail.") - - parser.add_argument("--dbtest", dest="dbtest_executable", metavar="PATH", - help="The path to the dbtest executable for resmoke to use.") + "--installDir", + dest="install_dir", + metavar="INSTALL_DIR", + help="Directory to search for MongoDB binaries", + ) parser.add_argument( - "--excludeWithAnyTags", action="append", dest="exclude_with_any_tags", + "--alwaysUseLogFiles", + dest="always_use_log_files", + action="store_true", + help=( + "Logs server output to a file located in the db path and prevents the" + " cleaning of dbpaths after testing. Note that conflicting options" + " passed in from test files may cause an error." + ), + ) + + parser.add_argument( + "--basePort", + dest="base_port", + metavar="PORT", + help=( + "The starting port number to use for mongod and mongos processes" + " spawned by resmoke.py or the tests themselves. Each fixture and Job" + " allocates a contiguous range of ports." + ), + ) + + parser.add_argument( + "--continueOnFailure", + action="store_true", + dest="continue_on_failure", + help="Executes all tests in all suites, even if some of them fail.", + ) + + parser.add_argument( + "--dbtest", + dest="dbtest_executable", + metavar="PATH", + help="The path to the dbtest executable for resmoke to use.", + ) + + parser.add_argument( + "--excludeWithAnyTags", + action="append", + dest="exclude_with_any_tags", metavar="TAG1,TAG2", - help=("Comma separated list of tags. Any jstest that contains any of the" - " specified tags will be excluded from any suites that are run." - " The tag '{}' is implicitly part of this list.".format(config.EXCLUDED_TAG))) + help=( + "Comma separated list of tags. Any jstest that contains any of the" + " specified tags will be excluded from any suites that are run." + " The tag '{}' is implicitly part of this list.".format( + config.EXCLUDED_TAG + ) + ), + ) parser.add_argument( - "--force-excluded-tests", dest="force_excluded_tests", action="store_true", - help=("Allows running tests in a suite config's excluded test roots" - " when passed as positional arg(s).")) + "--force-excluded-tests", + dest="force_excluded_tests", + action="store_true", + help=( + "Allows running tests in a suite config's excluded test roots" + " when passed as positional arg(s)." + ), + ) parser.add_argument( "--skipSymbolization", @@ -891,95 +1075,170 @@ class RunPlugin(PluginInterface): help="Skips symbolizing stacktraces generated by tests.", ) - parser.add_argument("--genny", dest="genny_executable", metavar="PATH", - help="The path to the genny executable for resmoke to use.") - parser.add_argument( - "--includeWithAnyTags", action="append", dest="include_with_any_tags", - metavar="TAG1,TAG2", - help=("Comma separated list of tags. For the jstest portion of the suite(s)," - " only tests which have at least one of the specified tags will be" - " run.")) - - parser.add_argument( - "--dockerComposeBuildImages", dest="docker_compose_build_images", - metavar="IMAGE1,IMAGE2,IMAGE3", help= - ("Comma separated list of base images to build for running resmoke against an External System Under Test:" - " (1) `workload`: Your mongo repo with a python development environment setup." - " (2) `mongo-binaries`: The `mongo`, `mongod`, `mongos` binaries to run tests with." - " (3) `config`: The target suite's `docker-compose.yml` file, startup scripts & configuration." - " All three images are needed to successfully setup an External System Under Test." - " This will not run any tests. It will just build the images and generate" - " the `docker-compose.yml` configuration to set up the External System Under Test for the desired suite." - )) - - parser.add_argument( - "--dockerComposeBuildEnv", dest="docker_compose_build_env", - choices=["local", "evergreen"], default="local", help= - ("Set the environment where this `--dockerComposeBuildImages` is happening -- defaults to: `local`." - )) - - parser.add_argument( - "--dockerComposeTag", dest="docker_compose_tag", metavar="TAG", default="development", - help=("The `tag` name to use for images built during a `--dockerComposeBuildImages`.")) - - parser.add_argument( - "--externalSUT", dest="external_sut", action="store_true", default=False, help= - ("This option should only be used when running resmoke against an External System Under Test." - " The External System Under Test should be setup via the command generated after" - " running: `buildscripts/resmoke.py run --suite [suite_name] ... --dockerComposeBuildImages" - " config,workload,mongo-binaries`.")) - - parser.add_argument( - "--sanityCheck", action="store_true", dest="sanity_check", help= - "Truncate the test queue to 1 item, just in order to verify the suite is properly set up." + "--genny", + dest="genny_executable", + metavar="PATH", + help="The path to the genny executable for resmoke to use.", ) parser.add_argument( - "--includeWithAllTags", action="append", dest="include_with_all_tags", + "--includeWithAnyTags", + action="append", + dest="include_with_any_tags", metavar="TAG1,TAG2", - help=("Comma separated list of tags. For the jstest portion of the suite(s)," - "tests that have all of the specified tags will be run.")) - - parser.add_argument("-n", action="store_const", const="tests", dest="dry_run", - help="Outputs the tests that would be run.") + help=( + "Comma separated list of tags. For the jstest portion of the suite(s)," + " only tests which have at least one of the specified tags will be" + " run." + ), + ) parser.add_argument( - "--recordWith", dest="undo_recorder_path", metavar="PATH", + "--dockerComposeBuildImages", + dest="docker_compose_build_images", + metavar="IMAGE1,IMAGE2,IMAGE3", + help=( + "Comma separated list of base images to build for running resmoke against an External System Under Test:" + " (1) `workload`: Your mongo repo with a python development environment setup." + " (2) `mongo-binaries`: The `mongo`, `mongod`, `mongos` binaries to run tests with." + " (3) `config`: The target suite's `docker-compose.yml` file, startup scripts & configuration." + " All three images are needed to successfully setup an External System Under Test." + " This will not run any tests. It will just build the images and generate" + " the `docker-compose.yml` configuration to set up the External System Under Test for the desired suite." + ), + ) + + parser.add_argument( + "--dockerComposeBuildEnv", + dest="docker_compose_build_env", + choices=["local", "evergreen"], + default="local", + help=( + "Set the environment where this `--dockerComposeBuildImages` is happening -- defaults to: `local`." + ), + ) + + parser.add_argument( + "--dockerComposeTag", + dest="docker_compose_tag", + metavar="TAG", + default="development", + help=( + "The `tag` name to use for images built during a `--dockerComposeBuildImages`." + ), + ) + + parser.add_argument( + "--externalSUT", + dest="external_sut", + action="store_true", + default=False, + help=( + "This option should only be used when running resmoke against an External System Under Test." + " The External System Under Test should be setup via the command generated after" + " running: `buildscripts/resmoke.py run --suite [suite_name] ... --dockerComposeBuildImages" + " config,workload,mongo-binaries`." + ), + ) + + parser.add_argument( + "--sanityCheck", + action="store_true", + dest="sanity_check", + help="Truncate the test queue to 1 item, just in order to verify the suite is properly set up.", + ) + + parser.add_argument( + "--includeWithAllTags", + action="append", + dest="include_with_all_tags", + metavar="TAG1,TAG2", + help=( + "Comma separated list of tags. For the jstest portion of the suite(s)," + "tests that have all of the specified tags will be run." + ), + ) + + parser.add_argument( + "-n", + action="store_const", + const="tests", + dest="dry_run", + help="Outputs the tests that would be run.", + ) + + parser.add_argument( + "--recordWith", + dest="undo_recorder_path", + metavar="PATH", help="Record execution of mongo, mongod and mongos processes;" - "specify the path to UndoDB's 'live-record' binary") + "specify the path to UndoDB's 'live-record' binary", + ) # TODO: add support for --dryRun=commands parser.add_argument( - "--dryRun", action="store", dest="dry_run", choices=("off", "tests"), metavar="MODE", - help=("Instead of running the tests, outputs the tests that would be run" - " (if MODE=tests). Defaults to MODE=%(default)s.")) + "--dryRun", + action="store", + dest="dry_run", + choices=("off", "tests"), + metavar="MODE", + help=( + "Instead of running the tests, outputs the tests that would be run" + " (if MODE=tests). Defaults to MODE=%(default)s." + ), + ) parser.add_argument( - "-j", "--jobs", type=int, dest="jobs", metavar="JOBS", - help=("The number of Job instances to use. Each instance will receive its" - " own MongoDB deployment to dispatch tests to.")) + "-j", + "--jobs", + type=int, + dest="jobs", + metavar="JOBS", + help=( + "The number of Job instances to use. Each instance will receive its" + " own MongoDB deployment to dispatch tests to." + ), + ) parser.set_defaults(logger_file="console") parser.add_argument( - "--shellSeed", action="store", dest="shell_seed", default=None, - help=("Sets the seed for replset and sharding fixtures to use. " - "This only works when only one test is input into resmoke.")) + "--shellSeed", + action="store", + dest="shell_seed", + default=None, + help=( + "Sets the seed for replset and sharding fixtures to use. " + "This only works when only one test is input into resmoke." + ), + ) parser.add_argument( - "--mongocryptdSetParameters", dest="mongocryptd_set_parameters", action="append", + "--mongocryptdSetParameters", + dest="mongocryptd_set_parameters", + action="append", metavar="{key1: value1, key2: value2, ..., keyN: valueN}", - help=("Passes one or more --setParameter options to all mongocryptd processes" - " started by resmoke.py. The argument is specified as bracketed YAML -" - " i.e. JSON with support for single quoted and unquoted keys.")) - - parser.add_argument("--numClientsPerFixture", type=int, dest="num_clients_per_fixture", - help="Number of clients running tests per fixture.") + help=( + "Passes one or more --setParameter options to all mongocryptd processes" + " started by resmoke.py. The argument is specified as bracketed YAML -" + " i.e. JSON with support for single quoted and unquoted keys." + ), + ) parser.add_argument( - "--useTenantClient", default=False, dest="use_tenant_client", action="store_true", help= - "Use tenant client. If set, each client will be constructed with a generated tenant id." + "--numClientsPerFixture", + type=int, + dest="num_clients_per_fixture", + help="Number of clients running tests per fixture.", + ) + + parser.add_argument( + "--useTenantClient", + default=False, + dest="use_tenant_client", + action="store_true", + help="Use tenant client. If set, each client will be constructed with a generated tenant id.", ) parser.add_argument( @@ -1001,189 +1260,337 @@ class RunPlugin(PluginInterface): help="Overrides the default fixture and connects with a mongodb:// connection" " string to an existing MongoDB cluster instead. This is useful for" " connecting to a MongoDB deployment started outside of resmoke.py including" - " one running in a debugger.") + " one running in a debugger.", + ) parser.add_argument( - "--shellPort", dest="shell_port", metavar="PORT", + "--shellPort", + dest="shell_port", + metavar="PORT", help="Convenience form of --shellConnString for connecting to an" " existing MongoDB cluster with the URL mongodb://localhost:[PORT]." - " This is useful for connecting to a server running in a debugger.") - - parser.add_argument("--shellTls", dest="shell_tls_enabled", action="store_true", - help="Whether to use TLS when connecting.") - - parser.add_argument("--shellTlsCertificateKeyFile", dest="shell_tls_certificate_key_file", - metavar="SHELL_TLS_CERTIFICATE_KEY_FILE", - help="The TLS certificate to use when connecting.") - - parser.add_argument("--repeat", "--repeatSuites", type=int, dest="repeat_suites", - metavar="N", - help="Repeats the given suite(s) N times, or until one fails.") + " This is useful for connecting to a server running in a debugger.", + ) parser.add_argument( - "--repeatTests", type=int, dest="repeat_tests", metavar="N", + "--shellTls", + dest="shell_tls_enabled", + action="store_true", + help="Whether to use TLS when connecting.", + ) + + parser.add_argument( + "--shellTlsCertificateKeyFile", + dest="shell_tls_certificate_key_file", + metavar="SHELL_TLS_CERTIFICATE_KEY_FILE", + help="The TLS certificate to use when connecting.", + ) + + parser.add_argument( + "--repeat", + "--repeatSuites", + type=int, + dest="repeat_suites", + metavar="N", + help="Repeats the given suite(s) N times, or until one fails.", + ) + + parser.add_argument( + "--repeatTests", + type=int, + dest="repeat_tests", + metavar="N", help="Repeats the tests inside each suite N times. This applies to tests" " defined in the suite configuration as well as tests defined on the command" - " line.") + " line.", + ) parser.add_argument( - "--repeatTestsMax", type=int, dest="repeat_tests_max", metavar="N", + "--repeatTestsMax", + type=int, + dest="repeat_tests_max", + metavar="N", help="Repeats the tests inside each suite no more than N time when" " --repeatTestsSecs is specified. This applies to tests defined in the suite" - " configuration as well as tests defined on the command line.") + " configuration as well as tests defined on the command line.", + ) parser.add_argument( - "--repeatTestsMin", type=int, dest="repeat_tests_min", metavar="N", + "--repeatTestsMin", + type=int, + dest="repeat_tests_min", + metavar="N", help="Repeats the tests inside each suite at least N times when" " --repeatTestsSecs is specified. This applies to tests defined in the suite" - " configuration as well as tests defined on the command line.") + " configuration as well as tests defined on the command line.", + ) parser.add_argument( - "--repeatTestsSecs", type=float, dest="repeat_tests_secs", metavar="SECONDS", + "--repeatTestsSecs", + type=float, + dest="repeat_tests_secs", + metavar="SECONDS", help="Repeats the tests inside each suite this amount of time. Note that" " this option is mutually exclusive with --repeatTests. This applies to" " tests defined in the suite configuration as well as tests defined on the" - " command line.") + " command line.", + ) parser.add_argument( - "--seed", type=int, dest="seed", metavar="SEED", - help=("Seed for the random number generator. Useful in combination with the" - " --shuffle option for producing a consistent test execution order.")) - - parser.add_argument("--mongo", dest="mongo_executable", metavar="PATH", - help="The path to the mongo shell executable for resmoke.py to use.") + "--seed", + type=int, + dest="seed", + metavar="SEED", + help=( + "Seed for the random number generator. Useful in combination with the" + " --shuffle option for producing a consistent test execution order." + ), + ) parser.add_argument( - "--shuffle", action="store_const", const="on", dest="shuffle", - help=("Randomizes the order in which tests are executed. This is equivalent" - " to specifying --shuffleMode=on.")) + "--mongo", + dest="mongo_executable", + metavar="PATH", + help="The path to the mongo shell executable for resmoke.py to use.", + ) parser.add_argument( - "--shuffleMode", action="store", dest="shuffle", choices=("on", "off", "auto"), + "--shuffle", + action="store_const", + const="on", + dest="shuffle", + help=( + "Randomizes the order in which tests are executed. This is equivalent" + " to specifying --shuffleMode=on." + ), + ) + + parser.add_argument( + "--shuffleMode", + action="store", + dest="shuffle", + choices=("on", "off", "auto"), metavar="ON|OFF|AUTO", - help=("Controls whether to randomize the order in which tests are executed." - " Defaults to auto when not supplied. auto enables randomization in" - " all cases except when the number of jobs requested is 1.")) + help=( + "Controls whether to randomize the order in which tests are executed." + " Defaults to auto when not supplied. auto enables randomization in" + " all cases except when the number of jobs requested is 1." + ), + ) parser.add_argument( - "--executor", dest="executor_file", + "--executor", + dest="executor_file", help="OBSOLETE: Superceded by --suites; specify --suites=SUITE path/to/test" - " to run a particular test under a particular suite configuration.") - - parser.add_argument( - "--linearChain", action="store", dest="linear_chain", choices=("on", "off"), - metavar="ON|OFF", help="Enable or disable linear chaining for tests using " - "ReplicaSetFixture.") - - parser.add_argument( - "--backupOnRestartDir", action="store", type=str, dest="backup_on_restart_dir", - metavar="DIRECTORY", help= - "Every time a mongod restarts on existing data files, the data files will be backed up underneath the input directory." + " to run a particular test under a particular suite configuration.", ) parser.add_argument( - "--replayFile", action="store", type=str, dest="replay_file", metavar="FILE", help= - "Run the tests listed in the input file. This is an alternative to passing test files as positional arguments on the command line. Each line in the file must be a path to a test file relative to the current working directory. A short-hand for `resmoke run --replay_file foo` is `resmoke run @foo`." + "--linearChain", + action="store", + dest="linear_chain", + choices=("on", "off"), + metavar="ON|OFF", + help="Enable or disable linear chaining for tests using " + "ReplicaSetFixture.", ) parser.add_argument( - "--mrlog", action="store_const", const="mrlog", dest="mrlog", help= - "Pipe output through the `mrlog` binary for converting logv2 logs to human readable logs." + "--backupOnRestartDir", + action="store", + type=str, + dest="backup_on_restart_dir", + metavar="DIRECTORY", + help="Every time a mongod restarts on existing data files, the data files will be backed up underneath the input directory.", ) parser.add_argument( - "--userFriendlyOutput", action="store", type=str, dest="user_friendly_output", - metavar="FILE", help= - "Have resmoke redirect all output to FILE. Additionally, stdout will contain lines that typically indicate that the test is making progress, or an error has happened. If `mrlog` is in the path it will be used. `tee` and `egrep` must be in the path." + "--replayFile", + action="store", + type=str, + dest="replay_file", + metavar="FILE", + help="Run the tests listed in the input file. This is an alternative to passing test files as positional arguments on the command line. Each line in the file must be a path to a test file relative to the current working directory. A short-hand for `resmoke run --replay_file foo` is `resmoke run @foo`.", ) parser.add_argument( - "--runAllFeatureFlagTests", dest="run_all_feature_flag_tests", action="store_true", - help= - "Run MongoDB servers with all feature flags enabled and only run tests tags with these feature flags" + "--mrlog", + action="store_const", + const="mrlog", + dest="mrlog", + help="Pipe output through the `mrlog` binary for converting logv2 logs to human readable logs.", ) parser.add_argument( - "--runNoFeatureFlagTests", dest="run_no_feature_flag_tests", action="store_true", - help=("Do not run any tests tagged with enabled feature flags." - " This argument has precedence over --runAllFeatureFlagTests" - "; used for multiversion suites")) + "--userFriendlyOutput", + action="store", + type=str, + dest="user_friendly_output", + metavar="FILE", + help="Have resmoke redirect all output to FILE. Additionally, stdout will contain lines that typically indicate that the test is making progress, or an error has happened. If `mrlog` is in the path it will be used. `tee` and `egrep` must be in the path.", + ) - parser.add_argument("--additionalFeatureFlags", dest="additional_feature_flags", - action="append", metavar="featureFlag1, featureFlag2, ...", - help="Additional feature flags") + parser.add_argument( + "--runAllFeatureFlagTests", + dest="run_all_feature_flag_tests", + action="store_true", + help="Run MongoDB servers with all feature flags enabled and only run tests tags with these feature flags", + ) - parser.add_argument("--additionalFeatureFlagsFile", dest="additional_feature_flags_file", - action="store", metavar="FILE", - help="The path to a file with feature flags, delimited by newlines.") + parser.add_argument( + "--runNoFeatureFlagTests", + dest="run_no_feature_flag_tests", + action="store_true", + help=( + "Do not run any tests tagged with enabled feature flags." + " This argument has precedence over --runAllFeatureFlagTests" + "; used for multiversion suites" + ), + ) - parser.add_argument("--maxTestQueueSize", type=int, dest="max_test_queue_size", - help=argparse.SUPPRESS) + parser.add_argument( + "--additionalFeatureFlags", + dest="additional_feature_flags", + action="append", + metavar="featureFlag1, featureFlag2, ...", + help="Additional feature flags", + ) - parser.add_argument("--tagFile", action="append", dest="tag_files", metavar="TAG_FILES", - help="One or more YAML files that associate tests and tags.") + parser.add_argument( + "--additionalFeatureFlagsFile", + dest="additional_feature_flags_file", + action="store", + metavar="FILE", + help="The path to a file with feature flags, delimited by newlines.", + ) + + parser.add_argument( + "--maxTestQueueSize", + type=int, + dest="max_test_queue_size", + help=argparse.SUPPRESS, + ) + + parser.add_argument( + "--tagFile", + action="append", + dest="tag_files", + metavar="TAG_FILES", + help="One or more YAML files that associate tests and tags.", + ) configure_resmoke.add_otel_args(parser) mongodb_server_options = parser.add_argument_group( title=_MONGODB_SERVER_OPTIONS_TITLE, - description=("Options related to starting a MongoDB cluster that are forwarded from" - " resmoke.py to the fixture.")) + description=( + "Options related to starting a MongoDB cluster that are forwarded from" + " resmoke.py to the fixture." + ), + ) mongodb_server_options.add_argument( - "--mongod", dest="mongod_executable", metavar="PATH", - help="The path to the mongod executable for resmoke.py to use.") + "--mongod", + dest="mongod_executable", + metavar="PATH", + help="The path to the mongod executable for resmoke.py to use.", + ) mongodb_server_options.add_argument( - "--mongos", dest="mongos_executable", metavar="PATH", - help="The path to the mongos executable for resmoke.py to use.") + "--mongos", + dest="mongos_executable", + metavar="PATH", + help="The path to the mongos executable for resmoke.py to use.", + ) mongodb_server_options.add_argument( - "--mongodSetParameters", dest="mongod_set_parameters", action="append", + "--mongodSetParameters", + dest="mongod_set_parameters", + action="append", metavar="{key1: value1, key2: value2, ..., keyN: valueN}", - help=("Passes one or more --setParameter options to all mongod processes" - " started by resmoke.py. The argument is specified as bracketed YAML -" - " i.e. JSON with support for single quoted and unquoted keys.")) + help=( + "Passes one or more --setParameter options to all mongod processes" + " started by resmoke.py. The argument is specified as bracketed YAML -" + " i.e. JSON with support for single quoted and unquoted keys." + ), + ) mongodb_server_options.add_argument( - "--mongosSetParameters", dest="mongos_set_parameters", action="append", + "--mongosSetParameters", + dest="mongos_set_parameters", + action="append", metavar="{key1: value1, key2: value2, ..., keyN: valueN}", - help=("Passes one or more --setParameter options to all mongos processes" - " started by resmoke.py. The argument is specified as bracketed YAML -" - " i.e. JSON with support for single quoted and unquoted keys.")) + help=( + "Passes one or more --setParameter options to all mongos processes" + " started by resmoke.py. The argument is specified as bracketed YAML -" + " i.e. JSON with support for single quoted and unquoted keys." + ), + ) mongodb_server_options.add_argument( - "--dbpathPrefix", dest="dbpath_prefix", metavar="PATH", - help=("The directory which will contain the dbpaths of any mongod's started" - " by resmoke.py or the tests themselves.")) + "--dbpathPrefix", + dest="dbpath_prefix", + metavar="PATH", + help=( + "The directory which will contain the dbpaths of any mongod's started" + " by resmoke.py or the tests themselves." + ), + ) mongodb_server_options.add_argument( - "--majorityReadConcern", action="store", dest="majority_read_concern", choices=("on", - "off"), - metavar="ON|OFF", help=("Enable or disable majority read concern support." - " Defaults to %(default)s.")) + "--majorityReadConcern", + action="store", + dest="majority_read_concern", + choices=("on", "off"), + metavar="ON|OFF", + help=( + "Enable or disable majority read concern support." + " Defaults to %(default)s." + ), + ) mongodb_server_options.add_argument( - "--enableEnterpriseTests", action="store", dest="enable_enterprise_tests", default="on", - choices=("on", "off"), metavar="ON|OFF", - help=("Enable or disable enterprise tests. Defaults to 'on'.")) - - mongodb_server_options.add_argument("--flowControl", action="store", dest="flow_control", - choices=("on", "off"), metavar="ON|OFF", - help=("Enable or disable flow control.")) - - mongodb_server_options.add_argument("--flowControlTicketOverride", type=int, action="store", - dest="flow_control_tickets", metavar="TICKET_OVERRIDE", - help=("Number of tickets available for flow control.")) - - mongodb_server_options.add_argument("--storageEngine", dest="storage_engine", - metavar="ENGINE", - help="The storage engine used by dbtests and jstests.") + "--enableEnterpriseTests", + action="store", + dest="enable_enterprise_tests", + default="on", + choices=("on", "off"), + metavar="ON|OFF", + help=("Enable or disable enterprise tests. Defaults to 'on'."), + ) mongodb_server_options.add_argument( - "--storageEngineCacheSizeGB", dest="storage_engine_cache_size_gb", metavar="CONFIG", + "--flowControl", + action="store", + dest="flow_control", + choices=("on", "off"), + metavar="ON|OFF", + help=("Enable or disable flow control."), + ) + + mongodb_server_options.add_argument( + "--flowControlTicketOverride", + type=int, + action="store", + dest="flow_control_tickets", + metavar="TICKET_OVERRIDE", + help=("Number of tickets available for flow control."), + ) + + mongodb_server_options.add_argument( + "--storageEngine", + dest="storage_engine", + metavar="ENGINE", + help="The storage engine used by dbtests and jstests.", + ) + + mongodb_server_options.add_argument( + "--storageEngineCacheSizeGB", + dest="storage_engine_cache_size_gb", + metavar="CONFIG", help="Sets the storage engine cache size configuration" - " setting for all mongod's.") + " setting for all mongod's.", + ) mongodb_server_options.add_argument( "--storageEngineCacheSizePct", @@ -1194,80 +1601,131 @@ class RunPlugin(PluginInterface): ) mongodb_server_options.add_argument( - "--tlsMode", dest="tls_mode", metavar="TLS_MODE", help="Indicates what TLS mode mongod " + "--tlsMode", + dest="tls_mode", + metavar="TLS_MODE", + help="Indicates what TLS mode mongod " "and mongos servers should be started with. See also: https://www.mongodb.com" - "/docs/manual/reference/configuration-options/#mongodb-setting-net.tls.mode") + "/docs/manual/reference/configuration-options/#mongodb-setting-net.tls.mode", + ) mongodb_server_options.add_argument( - "--tlsCAFile", dest="tls_ca_file", metavar="TLS_CA_FILE", - help="Path to the CA certificate file to be used by all clients and servers.") + "--tlsCAFile", + dest="tls_ca_file", + metavar="TLS_CA_FILE", + help="Path to the CA certificate file to be used by all clients and servers.", + ) mongodb_server_options.add_argument( - "--mongodTlsCertificateKeyFile", dest="mongod_tls_certificate_key_file", + "--mongodTlsCertificateKeyFile", + dest="mongod_tls_certificate_key_file", metavar="MONGOD_TLS_CERTIFICATE_KEY_FILE", - help="Path to the TLS certificate to be used by all mongods.") + help="Path to the TLS certificate to be used by all mongods.", + ) mongodb_server_options.add_argument( - "--mongosTlsCertificateKeyFile", dest="mongos_tls_certificate_key_file", + "--mongosTlsCertificateKeyFile", + dest="mongos_tls_certificate_key_file", metavar="MONGOS_TLS_CERTIFICATE_KEY_FILE", - help="Path to the TLS certificate to be used by all mongoses.") + help="Path to the TLS certificate to be used by all mongoses.", + ) mongodb_server_options.add_argument( - "--numReplSetNodes", type=int, dest="num_replset_nodes", metavar="N", + "--numReplSetNodes", + type=int, + dest="num_replset_nodes", + metavar="N", help="The number of nodes to initialize per ReplicaSetFixture. This is also " "used to indicate the number of replica set members per shard in a " - "ShardedClusterFixture.") + "ShardedClusterFixture.", + ) mongodb_server_options.add_argument( - "--numShards", type=int, dest="num_shards", metavar="N", - help="The number of shards to use in a ShardedClusterFixture.") + "--numShards", + type=int, + dest="num_shards", + metavar="N", + help="The number of shards to use in a ShardedClusterFixture.", + ) mongodb_server_options.add_argument( - "--wiredTigerCollectionConfigString", dest="wt_coll_config", metavar="CONFIG", - help="Sets the WiredTiger collection configuration setting for all mongod's.") + "--wiredTigerCollectionConfigString", + dest="wt_coll_config", + metavar="CONFIG", + help="Sets the WiredTiger collection configuration setting for all mongod's.", + ) mongodb_server_options.add_argument( - "--wiredTigerEngineConfigString", dest="wt_engine_config", metavar="CONFIG", - help="Sets the WiredTiger engine configuration setting for all mongod's.") + "--wiredTigerEngineConfigString", + dest="wt_engine_config", + metavar="CONFIG", + help="Sets the WiredTiger engine configuration setting for all mongod's.", + ) mongodb_server_options.add_argument( - "--wiredTigerIndexConfigString", dest="wt_index_config", metavar="CONFIG", - help="Sets the WiredTiger index configuration setting for all mongod's.") + "--wiredTigerIndexConfigString", + dest="wt_index_config", + metavar="CONFIG", + help="Sets the WiredTiger index configuration setting for all mongod's.", + ) mongodb_server_options.add_argument( - "--fuzzMongodConfigs", dest="fuzz_mongod_configs", + "--fuzzMongodConfigs", + dest="fuzz_mongod_configs", help="Randomly chooses mongod parameters that were not specified. Use 'stress' to fuzz " "all configs including stressful storage configurations that may significantly " "slow down the server. Use 'normal' to only fuzz non-stressful configurations. ", - metavar="MODE", choices=('normal', 'stress')) + metavar="MODE", + choices=("normal", "stress"), + ) mongodb_server_options.add_argument( - "--fuzzMongosConfigs", dest="fuzz_mongos_configs", - help="Randomly chooses mongos parameters that were not specified", metavar="MODE", - choices=('normal', )) + "--fuzzMongosConfigs", + dest="fuzz_mongos_configs", + help="Randomly chooses mongos parameters that were not specified", + metavar="MODE", + choices=("normal",), + ) mongodb_server_options.add_argument( - "--configFuzzSeed", dest="config_fuzz_seed", metavar="PATH", - help="Sets the seed used by mongod and mongos config fuzzers") + "--configFuzzSeed", + dest="config_fuzz_seed", + metavar="PATH", + help="Sets the seed used by mongod and mongos config fuzzers", + ) mongodb_server_options.add_argument( - "--configShard", dest="config_shard", metavar="CONFIG", - help="If set, specifies which node is the config shard. Can also be set to 'any'.") + "--configShard", + dest="config_shard", + metavar="CONFIG", + help="If set, specifies which node is the config shard. Can also be set to 'any'.", + ) mongodb_server_options.add_argument( - "--embeddedRouter", dest="embedded_router", metavar="CONFIG", - help="If set, uses embedded routers instead of dedicated mongos.") + "--embeddedRouter", + dest="embedded_router", + metavar="CONFIG", + help="If set, uses embedded routers instead of dedicated mongos.", + ) internal_options = parser.add_argument_group( title=_INTERNAL_OPTIONS_TITLE, - description=("Internal options for advanced users and resmoke developers." - " These are not meant to be invoked when running resmoke locally.")) + description=( + "Internal options for advanced users and resmoke developers." + " These are not meant to be invoked when running resmoke locally." + ), + ) internal_options.add_argument( - "--log", dest="logger_file", metavar="LOGGER", - help=("A YAML file that specifies the logging configuration. If the file is" - " located in the resmokeconfig/suites/ directory, then the basename" - " without the .yml extension can be specified, e.g. 'console'.")) + "--log", + dest="logger_file", + metavar="LOGGER", + help=( + "A YAML file that specifies the logging configuration. If the file is" + " located in the resmokeconfig/suites/ directory, then the basename" + " without the .yml extension can be specified, e.g. 'console'." + ), + ) # Used for testing resmoke. # @@ -1284,206 +1742,340 @@ class RunPlugin(PluginInterface): # `test_analysis`: # When specified, the hang-analyzer writes out the pids it will analyze without # actually running analysis, which can be time and resource intensive. - internal_options.add_argument("--internalParam", action="append", dest="internal_params", - help=argparse.SUPPRESS) - - internal_options.add_argument("--cedarReportFile", dest="cedar_report_file", - metavar="CEDAR_REPORT", - help="Writes a JSON file with performance test results.") + internal_options.add_argument( + "--internalParam", + action="append", + dest="internal_params", + help=argparse.SUPPRESS, + ) internal_options.add_argument( - "--reportFile", dest="report_file", metavar="REPORT", - help="Writes a JSON file with test status and timing information.") + "--cedarReportFile", + dest="cedar_report_file", + metavar="CEDAR_REPORT", + help="Writes a JSON file with performance test results.", + ) internal_options.add_argument( - "--staggerJobs", action="store", dest="stagger_jobs", choices=("on", "off"), - metavar="ON|OFF", help=("Enables or disables the stagger of launching resmoke jobs." - " Defaults to %(default)s.")) + "--reportFile", + dest="report_file", + metavar="REPORT", + help="Writes a JSON file with test status and timing information.", + ) internal_options.add_argument( - "--exportMongodConfig", dest="export_mongod_config", choices=("off", "regular", - "detailed"), - help=("Exports a yaml containing the history of each mongod config option to" - " {nodeName}_config.yml." - " Defaults to 'off'. A 'detailed' export will include locations of accesses.")) + "--staggerJobs", + action="store", + dest="stagger_jobs", + choices=("on", "off"), + metavar="ON|OFF", + help=( + "Enables or disables the stagger of launching resmoke jobs." + " Defaults to %(default)s." + ), + ) + + internal_options.add_argument( + "--exportMongodConfig", + dest="export_mongod_config", + choices=("off", "regular", "detailed"), + help=( + "Exports a yaml containing the history of each mongod config option to" + " {nodeName}_config.yml." + " Defaults to 'off'. A 'detailed' export will include locations of accesses." + ), + ) evergreen_options = parser.add_argument_group( - title=_EVERGREEN_ARGUMENT_TITLE, description=( + title=_EVERGREEN_ARGUMENT_TITLE, + description=( "Options used to propagate information about the Evergreen task running this" - " script.")) - - evergreen_options.add_argument("--evergreenURL", dest="evergreen_url", - metavar="EVERGREEN_URL", - help=("The URL of the Evergreen service.")) + " script." + ), + ) evergreen_options.add_argument( - "--archiveLimitMb", type=int, dest="archive_limit_mb", metavar="ARCHIVE_LIMIT_MB", - help=("Sets the limit (in MB) for archived files to S3. A value of 0" - " indicates there is no limit.")) + "--evergreenURL", + dest="evergreen_url", + metavar="EVERGREEN_URL", + help=("The URL of the Evergreen service."), + ) evergreen_options.add_argument( - "--archiveLimitTests", type=int, dest="archive_limit_tests", + "--archiveLimitMb", + type=int, + dest="archive_limit_mb", + metavar="ARCHIVE_LIMIT_MB", + help=( + "Sets the limit (in MB) for archived files to S3. A value of 0" + " indicates there is no limit." + ), + ) + + evergreen_options.add_argument( + "--archiveLimitTests", + type=int, + dest="archive_limit_tests", metavar="ARCHIVE_LIMIT_TESTS", - help=("Sets the maximum number of tests to archive to S3. A value" - " of 0 indicates there is no limit.")) - - evergreen_options.add_argument("--buildId", dest="build_id", metavar="BUILD_ID", - help="Sets the build ID of the task.") - - evergreen_options.add_argument("--buildloggerUrl", action="store", dest="buildlogger_url", - metavar="URL", - help="The root url of the buildlogger server.") + help=( + "Sets the maximum number of tests to archive to S3. A value" + " of 0 indicates there is no limit." + ), + ) evergreen_options.add_argument( - "--distroId", dest="distro_id", metavar="DISTRO_ID", - help=("Sets the identifier for the Evergreen distro running the" - " tests.")) + "--buildId", + dest="build_id", + metavar="BUILD_ID", + help="Sets the build ID of the task.", + ) evergreen_options.add_argument( - "--executionNumber", type=int, dest="execution_number", metavar="EXECUTION_NUMBER", - help=("Sets the number for the Evergreen execution running the" - " tests.")) + "--buildloggerUrl", + action="store", + dest="buildlogger_url", + metavar="URL", + help="The root url of the buildlogger server.", + ) evergreen_options.add_argument( - "--gitRevision", dest="git_revision", metavar="GIT_REVISION", - help=("Sets the git revision for the Evergreen task running the" - " tests.")) + "--distroId", + dest="distro_id", + metavar="DISTRO_ID", + help=("Sets the identifier for the Evergreen distro running the" " tests."), + ) + + evergreen_options.add_argument( + "--executionNumber", + type=int, + dest="execution_number", + metavar="EXECUTION_NUMBER", + help=("Sets the number for the Evergreen execution running the" " tests."), + ) + + evergreen_options.add_argument( + "--gitRevision", + dest="git_revision", + metavar="GIT_REVISION", + help=("Sets the git revision for the Evergreen task running the" " tests."), + ) # We intentionally avoid adding a new command line option that starts with --suite so it doesn't # become ambiguous with the --suites option and break how engineers run resmoke.py locally. evergreen_options.add_argument( - "--originSuite", dest="origin_suite", metavar="SUITE", - help=("Indicates the name of the test suite prior to the" - " evergreen_generate_resmoke_tasks.py script splitting it" - " up.")) + "--originSuite", + dest="origin_suite", + metavar="SUITE", + help=( + "Indicates the name of the test suite prior to the" + " evergreen_generate_resmoke_tasks.py script splitting it" + " up." + ), + ) evergreen_options.add_argument( - "--patchBuild", action="store_true", dest="patch_build", - help=("Indicates that the Evergreen task running the tests is a" - " patch build.")) + "--patchBuild", + action="store_true", + dest="patch_build", + help=( + "Indicates that the Evergreen task running the tests is a" + " patch build." + ), + ) evergreen_options.add_argument( - "--projectName", dest="project_name", metavar="PROJECT_NAME", - help=("Sets the name of the Evergreen project running the tests.")) - - evergreen_options.add_argument("--revisionOrderId", dest="revision_order_id", - metavar="REVISION_ORDER_ID", - help="Sets the chronological order number of this commit.") + "--projectName", + dest="project_name", + metavar="PROJECT_NAME", + help=("Sets the name of the Evergreen project running the tests."), + ) evergreen_options.add_argument( - "--taskName", dest="task_name", metavar="TASK_NAME", - help="Sets the name of the Evergreen task running the tests.") - - evergreen_options.add_argument("--taskId", dest="task_id", metavar="TASK_ID", - help="Sets the Id of the Evergreen task running the tests.") + "--revisionOrderId", + dest="revision_order_id", + metavar="REVISION_ORDER_ID", + help="Sets the chronological order number of this commit.", + ) evergreen_options.add_argument( - "--variantName", dest="variant_name", metavar="VARIANT_NAME", - help=("Sets the name of the Evergreen build variant running the" - " tests.")) - - evergreen_options.add_argument("--versionId", dest="version_id", metavar="VERSION_ID", - help="Sets the version ID of the task.") - - evergreen_options.add_argument("--taskWorkDir", dest="work_dir", metavar="TASK_WORK_DIR", - help="Sets the working directory of the task.") + "--taskName", + dest="task_name", + metavar="TASK_NAME", + help="Sets the name of the Evergreen task running the tests.", + ) evergreen_options.add_argument( - "--projectConfigPath", dest="evg_project_config_path", - help="Sets the path to evergreen project configuration yaml.") + "--taskId", + dest="task_id", + metavar="TASK_ID", + help="Sets the Id of the Evergreen task running the tests.", + ) + + evergreen_options.add_argument( + "--variantName", + dest="variant_name", + metavar="VARIANT_NAME", + help=("Sets the name of the Evergreen build variant running the" " tests."), + ) + + evergreen_options.add_argument( + "--versionId", + dest="version_id", + metavar="VERSION_ID", + help="Sets the version ID of the task.", + ) + + evergreen_options.add_argument( + "--taskWorkDir", + dest="work_dir", + metavar="TASK_WORK_DIR", + help="Sets the working directory of the task.", + ) + + evergreen_options.add_argument( + "--projectConfigPath", + dest="evg_project_config_path", + help="Sets the path to evergreen project configuration yaml.", + ) benchmark_options = parser.add_argument_group( title=_BENCHMARK_ARGUMENT_TITLE, - description="Options for running Benchmark/Benchrun tests") + description="Options for running Benchmark/Benchrun tests", + ) - benchmark_options.add_argument("--benchmarkFilter", type=str, dest="benchmark_filter", - metavar="BENCHMARK_FILTER", - help="Regex to filter Google benchmark tests to run.") + benchmark_options.add_argument( + "--benchmarkFilter", + type=str, + dest="benchmark_filter", + metavar="BENCHMARK_FILTER", + help="Regex to filter Google benchmark tests to run.", + ) benchmark_options.add_argument( "--benchmarkListTests", dest="benchmark_list_tests", action="store_true", # metavar="BENCHMARK_LIST_TESTS", - help=("Lists all Google benchmark test configurations in each" - " test file.")) + help=( + "Lists all Google benchmark test configurations in each" " test file." + ), + ) benchmark_min_time_help = ( "Minimum time to run each benchmark/benchrun test for. Use this option instead of " - "--benchmarkRepetitions to make a test run for a longer or shorter duration.") - benchmark_options.add_argument("--benchmarkMinTimeSecs", type=int, - dest="benchmark_min_time_secs", metavar="BENCHMARK_MIN_TIME", - help=benchmark_min_time_help) + "--benchmarkRepetitions to make a test run for a longer or shorter duration." + ) + benchmark_options.add_argument( + "--benchmarkMinTimeSecs", + type=int, + dest="benchmark_min_time_secs", + metavar="BENCHMARK_MIN_TIME", + help=benchmark_min_time_help, + ) benchmark_repetitions_help = ( "Set --benchmarkRepetitions=1 if you'd like to run the benchmark/benchrun tests only once." " By default, each test is run multiple times to provide statistics on the variance" " between runs; use --benchmarkMinTimeSecs if you'd like to run a test for a longer or" - " shorter duration.") + " shorter duration." + ) benchmark_options.add_argument( - "--benchmarkRepetitions", type=int, dest="benchmark_repetitions", - metavar="BENCHMARK_REPETITIONS", help=benchmark_repetitions_help) + "--benchmarkRepetitions", + type=int, + dest="benchmark_repetitions", + metavar="BENCHMARK_REPETITIONS", + help=benchmark_repetitions_help, + ) @classmethod def _add_list_suites(cls, subparsers: argparse._SubParsersAction): """Create and add the parser for the list-suites subcommand.""" - parser = subparsers.add_parser("list-suites", - help="Lists the names of the suites available to execute.") + parser = subparsers.add_parser( + "list-suites", help="Lists the names of the suites available to execute." + ) parser.add_argument( - "--log", dest="logger_file", metavar="LOGGER", - help=("A YAML file that specifies the logging configuration. If the file is" - " located in the resmokeconfig/suites/ directory, then the basename" - " without the .yml extension can be specified, e.g. 'console'.")) + "--log", + dest="logger_file", + metavar="LOGGER", + help=( + "A YAML file that specifies the logging configuration. If the file is" + " located in the resmokeconfig/suites/ directory, then the basename" + " without the .yml extension can be specified, e.g. 'console'." + ), + ) parser.set_defaults(logger_file="console") @classmethod def _add_generate(cls, subparsers: argparse._SubParsersAction): """Create and add the parser for the generate subcommand.""" - subparsers.add_parser("generate-matrix-suites", - help="Generate matrix suite config files from the mapping files.") + subparsers.add_parser( + "generate-matrix-suites", + help="Generate matrix suite config files from the mapping files.", + ) @classmethod def _add_find_suites(cls, subparsers: argparse._SubParsersAction): """Create and add the parser for the find-suites subcommand.""" parser = subparsers.add_parser( "find-suites", - help="Lists the names of the suites that will execute the specified tests.") + help="Lists the names of the suites that will execute the specified tests.", + ) # find-suites shares a lot of code with 'run' (for now), and this option needs be specified, # though it is not used. parser.set_defaults(logger_file="console") - parser.add_argument("test_files", metavar="TEST_FILES", nargs="*", - help="Explicit test files to run") + parser.add_argument( + "test_files", + metavar="TEST_FILES", + nargs="*", + help="Explicit test files to run", + ) @classmethod def _add_list_tags(cls, subparsers: argparse._SubParsersAction): """Create and add the parser for the list-tags subcommand.""" parser = subparsers.add_parser( - "list-tags", help="Lists the tags and their documentation available in the suites.") + "list-tags", + help="Lists the tags and their documentation available in the suites.", + ) parser.set_defaults(logger_file="console") parser.add_argument( - "--suites", dest="suite_files", metavar="SUITE1,SUITE2", - help=("Comma separated list of suite names to get tags from." - " All suites are used if unspecified.")) + "--suites", + dest="suite_files", + metavar="SUITE1,SUITE2", + help=( + "Comma separated list of suite names to get tags from." + " All suites are used if unspecified." + ), + ) @classmethod - def _add_generate_multiversion_exclude_tags(cls, subparser: argparse._SubParsersAction): + def _add_generate_multiversion_exclude_tags( + cls, subparser: argparse._SubParsersAction + ): """Create and add the parser for the generate-multiversion-exclude-tags subcommand.""" parser = subparser.add_parser( "generate-multiversion-exclude-tags", help="Create a tag file associating multiversion tests to tags for exclusion." " Compares the BACKPORTS_REQUIRED_FILE on the current branch with the same file on the" - " last-lts and/or last-continuous branch to determine which tests should be denylisted." + " last-lts and/or last-continuous branch to determine which tests should be denylisted.", ) parser.set_defaults(logger_file="console") parser.add_argument( - "--oldBinVersion", type=str, dest="old_bin_version", + "--oldBinVersion", + type=str, + dest="old_bin_version", choices=config.MultiversionOptions.all_options(), - help="Choose the multiversion binary version as last-lts or last-continuous.") - parser.add_argument("--excludeTagsFilePath", type=str, dest="exclude_tags_file_path", - help="Where to output the generated tags.") + help="Choose the multiversion binary version as last-lts or last-continuous.", + ) + parser.add_argument( + "--excludeTagsFilePath", + type=str, + dest="exclude_tags_file_path", + help="Where to output the generated tags.", + ) def to_local_args(input_args: Optional[List[str]] = None): @@ -1499,7 +2091,7 @@ def to_local_args(input_args: Optional[List[str]] = None): (parser, parsed_args) = main_parser.parse(input_args) - if parsed_args.command != 'run': + if parsed_args.command != "run": raise TypeError( f"to_local_args can only be called for the 'run' subcommand. Instead was called on '{parsed_args.command}'" ) @@ -1513,7 +2105,8 @@ def to_local_args(input_args: Optional[List[str]] = None): # The top-level parser has one subparser that contains all subcommand parsers. command_subparser = [ - action for action in parser._actions # pylint: disable=protected-access + action + for action in parser._actions # pylint: disable=protected-access if action.dest == "command" ][0] @@ -1559,12 +2152,14 @@ def to_local_args(input_args: Optional[List[str]] = None): continue # Skip any evergreen centric args. elif group.title in [ - _INTERNAL_OPTIONS_TITLE, _EVERGREEN_ARGUMENT_TITLE, _CEDAR_ARGUMENT_TITLE + _INTERNAL_OPTIONS_TITLE, + _EVERGREEN_ARGUMENT_TITLE, + _CEDAR_ARGUMENT_TITLE, ]: continue elif arg_dest in skipped_args: continue - elif group.title == 'positional arguments': + elif group.title == "positional arguments": positional_args.extend(arg_value) # Keep all remaining args. else: @@ -1591,8 +2186,12 @@ def to_local_args(input_args: Optional[List[str]] = None): else: other_local_args.append(arg) - return ["run"] + [arg for arg in (suites_arg, storage_engine_arg) if arg is not None - ] + other_local_args + positional_args + return ( + ["run"] + + [arg for arg in (suites_arg, storage_engine_arg) if arg is not None] + + other_local_args + + positional_args + ) def strip_fuzz_config_params(input_args: List[str]): @@ -1600,7 +2199,9 @@ def strip_fuzz_config_params(input_args: List[str]): ret = [] for arg in input_args: - if not arg.startswith(("--fuzzMongodConfigs", "--fuzzMongosConfigs", "--configFuzzSeed")): + if not arg.startswith( + ("--fuzzMongodConfigs", "--fuzzMongosConfigs", "--configFuzzSeed") + ): ret.append(arg) return ret diff --git a/buildscripts/resmokelib/run/generate_multiversion_exclude_tags.py b/buildscripts/resmokelib/run/generate_multiversion_exclude_tags.py index 447c153002a..20a3440316f 100755 --- a/buildscripts/resmokelib/run/generate_multiversion_exclude_tags.py +++ b/buildscripts/resmokelib/run/generate_multiversion_exclude_tags.py @@ -1,4 +1,5 @@ """Generate multiversion exclude tags file.""" + import logging import os import re @@ -31,23 +32,27 @@ def get_backports_required_hash_for_shell_version(mongo_shell_path: str | None = if is_windows(): mongo_shell = mongo_shell_path + ".exe" - shell_version = check_output(f"{mongo_shell} --version", shell=True, - env=env_vars).decode('utf-8') + shell_version = check_output( + f"{mongo_shell} --version", shell=True, env=env_vars + ).decode("utf-8") for line in shell_version.splitlines(): if "gitVersion" in line: - version_line = line.split(':')[1] + version_line = line.split(":")[1] # We identify the commit hash as the string enclosed by double quotation marks. result = re.search(r'"(.*?)"', version_line) if result: commit_hash = result.group().strip('"') if not commit_hash.isalnum(): - raise ValueError(f"Error parsing commit hash. Expected an " - f"alpha-numeric string but got: {commit_hash}") + raise ValueError( + f"Error parsing commit hash. Expected an " + f"alpha-numeric string but got: {commit_hash}" + ) return commit_hash else: break raise ValueError( - f"Could not find a valid commit hash from the {mongo_shell_path} mongo binary.") + f"Could not find a valid commit hash from the {mongo_shell_path} mongo binary." + ) def get_git_file_content(commit_hash: str) -> str: @@ -63,8 +68,12 @@ def get_git_file_content(commit_hash: str) -> str: try: # If the git show command failed once, we attempt to shallow fetch the commit # to ensure we have the commit's contents then try again. - _ = subprocess.run(git_fetch_command, capture_output=True, text=True, check=True) - result = subprocess.run(git_command, capture_output=True, text=True, check=True) + _ = subprocess.run( + git_fetch_command, capture_output=True, text=True, check=True + ) + result = subprocess.run( + git_command, capture_output=True, text=True, check=True + ) except subprocess.CalledProcessError as err: raise RuntimeError( f"Failed to retrieve file content using command: {' '.join(git_command)}. Error: {err.stderr}" @@ -87,7 +96,9 @@ def get_old_yaml(commit_hash: str): return backports_required_old -def generate_exclude_yaml(old_bin_version: str, output: str, logger: logging.Logger) -> None: +def generate_exclude_yaml( + old_bin_version: str, output: str, logger: logging.Logger +) -> None: """ Create a tag file associating multiversion tests to tags for exclusion. @@ -100,7 +111,9 @@ def generate_exclude_yaml(old_bin_version: str, output: str, logger: logging.Log if not os.path.isdir(location): os.makedirs(location) - backports_required_latest = read_yaml_file(os.path.join(ETC_DIR, BACKPORTS_REQUIRED_FILE)) + backports_required_latest = read_yaml_file( + os.path.join(ETC_DIR, BACKPORTS_REQUIRED_FILE) + ) # Get the state of the backports_required_for_multiversion_tests.yml file for the old # binary we are running tests against. We do this by using the commit hash from the old @@ -113,36 +126,46 @@ def generate_exclude_yaml(old_bin_version: str, output: str, logger: logging.Log }[old_bin_version] old_version_commit_hash = get_backports_required_hash_for_shell_version( - mongo_shell_path=shell_version) + mongo_shell_path=shell_version + ) # Get the yaml contents from the old commit. - logger.info(f"Downloading file from commit hash of old branch {old_version_commit_hash}") + logger.info( + f"Downloading file from commit hash of old branch {old_version_commit_hash}" + ) backports_required_old = get_old_yaml(old_version_commit_hash) def diff(list1, list2): return [elem for elem in (list1 or []) if elem not in (list2 or [])] def get_suite_exclusions(version_key): - _suites_latest = backports_required_latest[version_key]["suites"] or {} # Check if the changed syntax for etc/backports_required_for_multiversion_tests.yml has been # backported. # This variable and all branches where it's not set can be deleted after backporting the change. change_backported = version_key in backports_required_old.keys() if change_backported: - _always_exclude = diff(backports_required_latest[version_key]["all"], - backports_required_old[version_key]["all"]) + _always_exclude = diff( + backports_required_latest[version_key]["all"], + backports_required_old[version_key]["all"], + ) _suites_old: defaultdict = defaultdict( - list, backports_required_old[version_key]["suites"] or {}) + list, backports_required_old[version_key]["suites"] or {} + ) else: - _always_exclude = diff(backports_required_latest[version_key]["all"], - backports_required_old["all"]) - _suites_old: defaultdict = defaultdict(list, backports_required_old["suites"] or {}) + _always_exclude = diff( + backports_required_latest[version_key]["all"], + backports_required_old["all"], + ) + _suites_old: defaultdict = defaultdict( + list, backports_required_old["suites"] or {} + ) return _suites_latest, _suites_old, _always_exclude suites_latest, suites_old, always_exclude = get_suite_exclusions( - old_bin_version.replace("_", "-")) + old_bin_version.replace("_", "-") + ) tags = _tags.TagsConfig() @@ -159,5 +182,7 @@ def generate_exclude_yaml(old_bin_version: str, output: str, logger: logging.Log tags.add_tag("js_test", test, f"{suite}_{BACKPORT_REQUIRED_TAG}") logger.info(f"Writing exclude tags to {output}.") - tags.write_file(filename=output, - preamble="Tag file that specifies exclusions from multiversion suites.") + tags.write_file( + filename=output, + preamble="Tag file that specifies exclusions from multiversion suites.", + ) diff --git a/buildscripts/resmokelib/run/list_tags.py b/buildscripts/resmokelib/run/list_tags.py index 6119875f816..d9d623c9300 100644 --- a/buildscripts/resmokelib/run/list_tags.py +++ b/buildscripts/resmokelib/run/list_tags.py @@ -61,6 +61,6 @@ def make_output(tag_docs): output = "" for tag, doc in sorted(tag_docs.items()): newline = "\n" - wrapped_doc = textwrap.indent(doc, '\t') + wrapped_doc = textwrap.indent(doc, "\t") output = f"{output}{newline}{tag}:{newline}{wrapped_doc}" return output diff --git a/buildscripts/resmokelib/run/runtime_recorder.py b/buildscripts/resmokelib/run/runtime_recorder.py index 67fe7453756..71a7d934620 100644 --- a/buildscripts/resmokelib/run/runtime_recorder.py +++ b/buildscripts/resmokelib/run/runtime_recorder.py @@ -22,6 +22,8 @@ def compare_start_time(cur_time_secs): cur_timefile = utils.load_yaml_file(_START_TIME_FILE) start_time_secs = cur_timefile["start_time"] except (FileNotFoundError, KeyError) as erros: - raise FileNotFoundError("resmoke.py did not successfully record its start time") from erros + raise FileNotFoundError( + "resmoke.py did not successfully record its start time" + ) from erros return cur_time_secs - start_time_secs diff --git a/buildscripts/resmokelib/selector.py b/buildscripts/resmokelib/selector.py index a99289c269f..eabe3dfff82 100644 --- a/buildscripts/resmokelib/selector.py +++ b/buildscripts/resmokelib/selector.py @@ -25,10 +25,12 @@ ENTERPRISE_TEST_DIR = os.path.normpath("src/mongo/db/modules/enterprise/jstests" _DO_NOT_MATCH_ANY_EXISTING_TEST_FILES_MESSAGE = ( "Pattern(s) and/or filename(s) in `{config_section}`" - " do not match any existing test files: {not_matching_paths}") + " do not match any existing test files: {not_matching_paths}" +) _DO_NOT_MATCH_ANY_TEST_FILES_FROM_ROOTS_MESSAGE = ( "Pattern(s) and/or filename(s) in `{config_section}`" - " do not match any test files from `roots`: {not_matching_paths}") + " do not match any test files from `roots`: {not_matching_paths}" +) class TestFileExplorer(object): @@ -99,9 +101,12 @@ class TestFileExplorer(object): returncode, stdout, stderr = self._run_program(dbtest_binary, ["--list"]) if returncode != 0: - raise errors.ResmokeError("Getting list of dbtest suites failed" - ", dbtest_binary=`{}`: stdout=`{}`, stderr=`{}`".format( - dbtest_binary, stdout, stderr)) + raise errors.ResmokeError( + "Getting list of dbtest suites failed" + ", dbtest_binary=`{}`: stdout=`{}`, stderr=`{}`".format( + dbtest_binary, stdout, stderr + ) + ) return stdout.splitlines() @staticmethod @@ -116,7 +121,9 @@ class TestFileExplorer(object): """ command = [binary] command.extend(args) - program = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + program = subprocess.Popen( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) stdout, stderr = program.communicate() return program.returncode, stdout.decode("utf-8"), stderr.decode("utf-8") @@ -150,7 +157,9 @@ class TestFileExplorer(object): # TODO SERVER-77265 always validate tag file input when mongo-task-generator # no longer passes in invalid tag files if not config.EVERGREEN_TASK_ID: - raise errors.TagFileDoesNotExistError(f"A tag file was not found at {tag_file}") + raise errors.TagFileDoesNotExistError( + f"A tag file was not found at {tag_file}" + ) return tagged_tests @@ -179,8 +188,12 @@ class _TestList(object): 'include_files()' or 'exclude_files()' will raise an TypeError. """ - def __init__(self, test_file_explorer: TestFileExplorer, roots: List[str], - tests_are_files: bool = True) -> None: + def __init__( + self, + test_file_explorer: TestFileExplorer, + roots: List[str], + tests_are_files: bool = True, + ) -> None: """Initialize the _TestList with a TestFileExplorer component and a list of root tests.""" self._test_file_explorer = test_file_explorer self._tests_are_files = tests_are_files @@ -199,7 +212,8 @@ class _TestList(object): len_after = len(evaluated) if len_after == len_before and path.startswith( - self._test_file_explorer.get_jstests_dir()): + self._test_file_explorer.get_jstests_dir() + ): unrecognized.append(path) elif self._test_file_explorer.isfile(path): @@ -215,7 +229,9 @@ class _TestList(object): if len(paths.unrecognized) > 0: raise errors.SuiteSelectorConfigurationError( _DO_NOT_MATCH_ANY_EXISTING_TEST_FILES_MESSAGE.format( - config_section="roots", not_matching_paths=paths.unrecognized)) + config_section="roots", not_matching_paths=paths.unrecognized + ) + ) return paths.evaluated def include_files(self, include_files: List[str]) -> None: @@ -231,16 +247,24 @@ class _TestList(object): if len(paths.unrecognized) > 0: raise errors.SuiteSelectorConfigurationError( _DO_NOT_MATCH_ANY_EXISTING_TEST_FILES_MESSAGE.format( - config_section="include_files", not_matching_paths=paths.unrecognized)) + config_section="include_files", + not_matching_paths=paths.unrecognized, + ) + ) paths_missing_from_roots = [ - path for path in paths.evaluated if - path.startswith(self._test_file_explorer.get_jstests_dir()) and path not in self._roots + path + for path in paths.evaluated + if path.startswith(self._test_file_explorer.get_jstests_dir()) + and path not in self._roots ] if len(paths_missing_from_roots) > 0: raise errors.SuiteSelectorConfigurationError( _DO_NOT_MATCH_ANY_TEST_FILES_FROM_ROOTS_MESSAGE.format( - config_section="include_files", not_matching_paths=paths_missing_from_roots)) + config_section="include_files", + not_matching_paths=paths_missing_from_roots, + ) + ) self._filtered = set(paths.evaluated) @@ -257,16 +281,24 @@ class _TestList(object): if len(paths.unrecognized) > 0: raise errors.SuiteSelectorConfigurationError( _DO_NOT_MATCH_ANY_EXISTING_TEST_FILES_MESSAGE.format( - config_section="exclude_files", not_matching_paths=paths.unrecognized)) + config_section="exclude_files", + not_matching_paths=paths.unrecognized, + ) + ) paths_missing_from_roots = [ - path for path in paths.evaluated if - path.startswith(self._test_file_explorer.get_jstests_dir()) and path not in self._roots + path + for path in paths.evaluated + if path.startswith(self._test_file_explorer.get_jstests_dir()) + and path not in self._roots ] if len(paths_missing_from_roots) > 0: raise errors.SuiteSelectorConfigurationError( _DO_NOT_MATCH_ANY_TEST_FILES_FROM_ROOTS_MESSAGE.format( - config_section="exclude_files", not_matching_paths=paths_missing_from_roots)) + config_section="exclude_files", + not_matching_paths=paths_missing_from_roots, + ) + ) for path in paths.evaluated: self._filtered.discard(path) @@ -275,7 +307,8 @@ class _TestList(object): """Exclude tests that start with the enterprise module directory from the test list.""" self._filtered = { test - for test in self._filtered if not os.path.normpath(test).startswith(ENTERPRISE_TEST_DIR) + for test in self._filtered + if not os.path.normpath(test).startswith(ENTERPRISE_TEST_DIR) } def match_tag_expression(self, tag_expression, get_tags): @@ -287,7 +320,9 @@ class _TestList(object): get_tags: a callable object that takes a test and returns the corresponding list of tags. """ - self._filtered = {test for test in self._filtered if tag_expression(get_tags(test))} + self._filtered = { + test for test in self._filtered if tag_expression(get_tags(test)) + } def include_any_pattern(self, patterns): """Filter the test list to only include tests that match any provided glob patterns.""" @@ -402,9 +437,18 @@ def _make_expression_list(configs): class _SelectorConfig(object): """Base object to represent the configuration for test selection.""" - def __init__(self, root=None, roots=None, include_files=None, exclude_files=None, - include_tags=None, exclude_tags=None, include_with_any_tags=None, - exclude_with_any_tags=None, tag_file=None): + def __init__( + self, + root=None, + roots=None, + include_files=None, + exclude_files=None, + include_tags=None, + exclude_tags=None, + include_with_any_tags=None, + exclude_with_any_tags=None, + tag_file=None, + ): """Initialize the _SelectorConfig from the configuration elements. Args: @@ -424,24 +468,32 @@ class _SelectorConfig(object): if root and roots: raise ValueError("root and roots cannot be specified at the same time") if include_tags and exclude_tags: - raise ValueError("include_tags and exclude_tags cannot be specified at the same time") + raise ValueError( + "include_tags and exclude_tags cannot be specified at the same time" + ) self.root = root self.roots = roots self.tag_file = tag_file self.include_files = utils.default_if_none(include_files, []) self.exclude_files = utils.default_if_none(exclude_files, []) - include_with_any_tags = self.__merge_lists(include_with_any_tags, - config.INCLUDE_WITH_ANY_TAGS) - exclude_with_any_tags = self.__merge_lists(exclude_with_any_tags, - config.EXCLUDE_WITH_ANY_TAGS) + include_with_any_tags = self.__merge_lists( + include_with_any_tags, config.INCLUDE_WITH_ANY_TAGS + ) + exclude_with_any_tags = self.__merge_lists( + exclude_with_any_tags, config.EXCLUDE_WITH_ANY_TAGS + ) # This is functionally similar to `include_tags` but contains a list of tags rather # than an expression. include_with_all_tags = config.INCLUDE_TAGS self.tags_expression = self.__make_tags_expression( - include_tags, exclude_tags, include_with_any_tags, exclude_with_any_tags, - include_with_all_tags) + include_tags, + exclude_tags, + include_with_any_tags, + exclude_with_any_tags, + include_with_all_tags, + ) @staticmethod def __merge_lists(list_a, list_b): @@ -454,13 +506,20 @@ class _SelectorConfig(object): return None @staticmethod - def __make_tags_expression(include_tags, exclude_tags, include_with_any_tags, - exclude_with_any_tags, include_with_all_tags): + def __make_tags_expression( + include_tags, + exclude_tags, + include_with_any_tags, + exclude_with_any_tags, + include_with_all_tags, + ): expressions = [] if include_tags: expressions.append(make_expression(include_tags)) if include_with_all_tags: - include_with_all_tags_expr = make_expression({"$allOf": include_with_all_tags}) + include_with_all_tags_expr = make_expression( + {"$allOf": include_with_all_tags} + ) expressions.append(include_with_all_tags_expr) elif exclude_tags: expressions.append(_NotExpression(make_expression(exclude_tags))) @@ -468,7 +527,9 @@ class _SelectorConfig(object): include_with_any_expr = make_expression({"$anyOf": include_with_any_tags}) expressions.append(include_with_any_expr) if exclude_with_any_tags: - exclude_with_any_expr = make_expression({"$not": {"$anyOf": exclude_with_any_tags}}) + exclude_with_any_expr = make_expression( + {"$not": {"$anyOf": exclude_with_any_tags}} + ) expressions.append(exclude_with_any_expr) if expressions: @@ -513,7 +574,9 @@ class _Selector(object): test_list.exclude_files(selector_config.exclude_files) # 4. Apply the tag filters. if selector_config.tags_expression: - test_list.match_tag_expression(selector_config.tags_expression, self.get_tags) + test_list.match_tag_expression( + selector_config.tags_expression, self.get_tags + ) # 5. Apply the include files last with force=True to take precedence over the tags. if self._tests_are_files and selector_config.include_files: test_list.include_files(selector_config.include_files) @@ -544,14 +607,28 @@ class _Selector(object): class _JSTestSelectorConfig(_SelectorConfig): """_SelectorConfig subclass for JavaScript tests.""" - def __init__(self, roots=None, include_files=None, exclude_files=None, - include_with_any_tags=None, exclude_with_any_tags=None, include_tags=None, - exclude_tags=None, tag_file=None): + def __init__( + self, + roots=None, + include_files=None, + exclude_files=None, + include_with_any_tags=None, + exclude_with_any_tags=None, + include_tags=None, + exclude_tags=None, + tag_file=None, + ): _SelectorConfig.__init__( - self, roots=roots, include_files=include_files, exclude_files=exclude_files, + self, + roots=roots, + include_files=include_files, + exclude_files=exclude_files, include_with_any_tags=include_with_any_tags, - exclude_with_any_tags=exclude_with_any_tags, include_tags=include_tags, - exclude_tags=exclude_tags, tag_file=tag_file) + exclude_with_any_tags=exclude_with_any_tags, + include_tags=include_tags, + exclude_tags=exclude_tags, + tag_file=tag_file, + ) class _JSTestSelector(_Selector): @@ -559,11 +636,14 @@ class _JSTestSelector(_Selector): def __init__(self, test_file_explorer): _Selector.__init__(self, test_file_explorer) - self._tags = self._test_file_explorer.parse_tag_files("js_test", config.TAG_FILES) + self._tags = self._test_file_explorer.parse_tag_files( + "js_test", config.TAG_FILES + ) def select(self, selector_config): - self._tags = self._test_file_explorer.parse_tag_files("js_test", [selector_config.tag_file], - self._tags) + self._tags = self._test_file_explorer.parse_tag_files( + "js_test", [selector_config.tag_file], self._tags + ) return _Selector.select(self, selector_config) def get_tags(self, test_file): @@ -632,7 +712,7 @@ class _MultiJSTestSelector(_JSTestSelector): random.shuffle(recycled_tests) corpus = corpus[start:] + recycled_tests start = 0 - grouped_tests.append(corpus[start:start + group_size]) + grouped_tests.append(corpus[start : start + group_size]) start += group_size return grouped_tests, excluded @@ -645,17 +725,30 @@ class _MultiJSTestSelector(_JSTestSelector): class _CppTestSelectorConfig(_SelectorConfig): """_SelectorConfig subclass for cpp_integration_test and cpp_unit_test tests.""" - def __init__(self, root=config.DEFAULT_INTEGRATION_TEST_LIST, roots=None, include_files=None, - exclude_files=None): + def __init__( + self, + root=config.DEFAULT_INTEGRATION_TEST_LIST, + roots=None, + include_files=None, + exclude_files=None, + ): """Initialize _CppTestSelectorConfig.""" if roots: # The 'roots' argument is only present when tests are specified on the command line # and in that case they take precedence over the tests in the root file. - _SelectorConfig.__init__(self, roots=roots, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, + roots=roots, + include_files=include_files, + exclude_files=exclude_files, + ) else: - _SelectorConfig.__init__(self, root=root, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, + root=root, + include_files=include_files, + exclude_files=exclude_files, + ) class _CppTestSelector(_Selector): @@ -678,17 +771,30 @@ class _CppTestSelector(_Selector): class _PrettyPrinterTestSelectorConfig(_SelectorConfig): """_SelectorConfig subclass for pretty-printer-tests.""" - def __init__(self, root=config.DEFAULT_INTEGRATION_TEST_LIST, roots=None, include_files=None, - exclude_files=None): + def __init__( + self, + root=config.DEFAULT_INTEGRATION_TEST_LIST, + roots=None, + include_files=None, + exclude_files=None, + ): """Initialize _PrettyPrinterTestSelectorConfig.""" if roots: # The 'roots' argument is only present when tests are specified on the command line # and in that case they take precedence over the tests in the root file. - _SelectorConfig.__init__(self, roots=roots, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, + roots=roots, + include_files=include_files, + exclude_files=exclude_files, + ) else: - _SelectorConfig.__init__(self, root=root, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, + root=root, + include_files=include_files, + exclude_files=exclude_files, + ) class _PrettyPrinterTestSelector(_Selector): @@ -766,8 +872,9 @@ class _FileBasedSelectorConfig(_SelectorConfig): def __init__(self, roots, include_files=None, exclude_files=None): """Initialize _FileBasedSelectorConfig.""" - _SelectorConfig.__init__(self, roots=roots, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, roots=roots, include_files=include_files, exclude_files=exclude_files + ) class _SleepTestCaseSelectorConfig(_SelectorConfig): @@ -790,8 +897,9 @@ class _PyTestCaseSelectorConfig(_SelectorConfig): """_SelectorConfig subclass for py_test tests.""" def __init__(self, roots, include_files=None, exclude_files=None): - _SelectorConfig.__init__(self, roots=roots, include_files=include_files, - exclude_files=exclude_files) + _SelectorConfig.__init__( + self, roots=roots, include_files=include_files, exclude_files=exclude_files + ) class _GennylibTestCaseSelectorConfig(_SelectorConfig): @@ -819,7 +927,10 @@ _DEFAULT_TEST_FILE_EXPLORER = TestFileExplorer() _SELECTOR_REGISTRY = { "cpp_integration_test": (_CppTestSelectorConfig, _CppTestSelector), "cpp_unit_test": (_CppTestSelectorConfig, _CppTestSelector), - "pretty_printer_test": (_PrettyPrinterTestSelectorConfig, _PrettyPrinterTestSelector), + "pretty_printer_test": ( + _PrettyPrinterTestSelectorConfig, + _PrettyPrinterTestSelector, + ), "benchmark_test": (_CppTestSelectorConfig, _CppTestSelector), "sdam_json_test": (_FileBasedSelectorConfig, _Selector), "server_selection_json_test": (_FileBasedSelectorConfig, _Selector), @@ -843,7 +954,9 @@ _SELECTOR_REGISTRY = { } -def filter_tests(test_kind, selector_config, test_file_explorer=_DEFAULT_TEST_FILE_EXPLORER): +def filter_tests( + test_kind, selector_config, test_file_explorer=_DEFAULT_TEST_FILE_EXPLORER +): """Filter the tests according to a specified configuration. Args: diff --git a/buildscripts/resmokelib/setup_multiversion/config.py b/buildscripts/resmokelib/setup_multiversion/config.py index e81fd445f58..77bcf80caae 100644 --- a/buildscripts/resmokelib/setup_multiversion/config.py +++ b/buildscripts/resmokelib/setup_multiversion/config.py @@ -1,7 +1,10 @@ """Setup multiversion config.""" + from typing import List -SETUP_MULTIVERSION_CONFIG = "buildscripts/resmokeconfig/setup_multiversion/setup_multiversion_config.yml" +SETUP_MULTIVERSION_CONFIG = ( + "buildscripts/resmokeconfig/setup_multiversion/setup_multiversion_config.yml" +) # Records the paths of installed multiversion binaries on Windows. WINDOWS_BIN_PATHS_FILE = "windows_binary_paths.txt" diff --git a/buildscripts/resmokelib/setup_multiversion/download.py b/buildscripts/resmokelib/setup_multiversion/download.py index 322d96d2929..d5b71af2418 100644 --- a/buildscripts/resmokelib/setup_multiversion/download.py +++ b/buildscripts/resmokelib/setup_multiversion/download.py @@ -1,4 +1,5 @@ """Helper functions to download.""" + import contextlib import errno import glob @@ -74,7 +75,7 @@ def download_from_s3(url): raise DownloadError("Download URL not found") LOGGER.info("Downloading.", url=url) - filename = os.path.join(mkdtemp_in_build_dir(), url.split('/')[-1].split('?')[0]) + filename = os.path.join(mkdtemp_in_build_dir(), url.split("/")[-1].split("?")[0]) arch = platform.uname().machine.lower() @@ -119,7 +120,9 @@ def _rsync_move_dir(source_dir, dest_dir): def extract_archive(archive_file, install_dir): """Uncompress file and return root of extracted directory.""" - LOGGER.info("Extracting archive data.", archive=archive_file, install_dir=install_dir) + LOGGER.info( + "Extracting archive data.", archive=archive_file, install_dir=install_dir + ) temp_dir = mkdtemp_in_build_dir() archive_name = os.path.basename(archive_file) _, file_suffix = os.path.splitext(archive_name) @@ -204,6 +207,7 @@ def symlink_version(suffix, installed_dir, link_dir=None): def symlink_ms(source, symlink_name): """Provide symlink for Windows.""" import ctypes + csl = ctypes.windll.kernel32.CreateSymbolicLinkW csl.argtypes = (ctypes.c_wchar_p, ctypes.c_wchar_p, ctypes.c_uint32) csl.restype = ctypes.c_ubyte @@ -213,7 +217,11 @@ def symlink_version(suffix, installed_dir, link_dir=None): link_method = symlink_ms link_method(executable, executable_link) - LOGGER.debug("Symlink created.", executable=executable, executable_link=executable_link) + LOGGER.debug( + "Symlink created.", + executable=executable, + executable_link=executable_link, + ) except OSError as exc: if exc.errno == errno.EEXIST: @@ -221,5 +229,7 @@ def symlink_version(suffix, installed_dir, link_dir=None): else: raise - LOGGER.info("Symlinks for all executables are created in the directory.", link_dir=link_dir) + LOGGER.info( + "Symlinks for all executables are created in the directory.", link_dir=link_dir + ) return link_dir diff --git a/buildscripts/resmokelib/setup_multiversion/github_conn.py b/buildscripts/resmokelib/setup_multiversion/github_conn.py index 90a3ffe4f97..e538e20be53 100644 --- a/buildscripts/resmokelib/setup_multiversion/github_conn.py +++ b/buildscripts/resmokelib/setup_multiversion/github_conn.py @@ -1,4 +1,5 @@ """Helper functions to interact with github.""" + from github import Github, GithubException @@ -30,4 +31,6 @@ def get_git_tag_and_commit(github_oauth_token, version): return None, git_commit.sha except GithubException as gh_exception: - raise GithubConnError(f"Commit hash for {version} not found. Error: {str(gh_exception)}") + raise GithubConnError( + f"Commit hash for {version} not found. Error: {str(gh_exception)}" + ) diff --git a/buildscripts/resmokelib/setup_multiversion/setup_multiversion.py b/buildscripts/resmokelib/setup_multiversion/setup_multiversion.py index 276ada10711..31fc24ac0e7 100644 --- a/buildscripts/resmokelib/setup_multiversion/setup_multiversion.py +++ b/buildscripts/resmokelib/setup_multiversion/setup_multiversion.py @@ -5,6 +5,7 @@ to include its version) into an install directory and symlinks the binaries with versions to another directory. This script supports community and enterprise builds. """ + import argparse import logging import os @@ -32,19 +33,21 @@ def infer_platform(edition=None, version=None): """Infer platform for popular OS.""" syst = platform.system() pltf = None - if syst == 'Darwin': - pltf = 'osx' - elif syst == 'Windows': - pltf = 'windows' - if edition == 'base' and version == "4.2": - pltf += '_x86_64-2012plus' - elif syst == 'Linux': + if syst == "Darwin": + pltf = "osx" + elif syst == "Windows": + pltf = "windows" + if edition == "base" and version == "4.2": + pltf += "_x86_64-2012plus" + elif syst == "Linux": id_name = distro.id() - if id_name in ('ubuntu', 'rhel'): + if id_name in ("ubuntu", "rhel"): pltf = id_name + distro.major_version() + distro.minor_version() if pltf is None: - raise ValueError("Platform cannot be inferred. Please specify platform explicitly with -p. " - f"Available platforms can be found in {config.SETUP_MULTIVERSION_CONFIG}.") + raise ValueError( + "Platform cannot be inferred. Please specify platform explicitly with -p. " + f"Available platforms can be found in {config.SETUP_MULTIVERSION_CONFIG}." + ) else: return pltf @@ -52,10 +55,15 @@ def infer_platform(edition=None, version=None): def get_merge_base_commit(version: str, logger: logging.Logger) -> Optional[str]: """Get merge-base commit hash between origin/master and version.""" cmd = ["git", "merge-base", "origin/master", f"origin/v{version}"] - result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False) + result = subprocess.run( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False + ) if result.returncode: - logger.warning("Git merge-base command failed. Falling back to latest master", cmd=cmd, - error=result.stderr.decode("utf-8").strip()) + logger.warning( + "Git merge-base command failed. Falling back to latest master", + cmd=cmd, + error=result.stderr.decode("utf-8").strip(), + ) return None commit_hash = result.stdout.decode("utf-8").strip() logger.info("Found merge-base commit.", cmd=cmd, commit=commit_hash) @@ -73,24 +81,24 @@ class SetupMultiversion(Subcommand): """Main class for the setup multiversion subcommand.""" def __init__( - self, - download_options, - install_dir="", - link_dir="", - mv_platform=None, - edition=None, - architecture=None, - use_latest=None, - versions=None, - variant=None, - install_last_lts=None, - install_last_continuous=None, - evergreen_config=None, - github_oauth_token=None, - debug=None, - ignore_failed_push=False, - evg_versions_file=None, - logger: Optional[logging.Logger] = None, + self, + download_options, + install_dir="", + link_dir="", + mv_platform=None, + edition=None, + architecture=None, + use_latest=None, + versions=None, + variant=None, + install_last_lts=None, + install_last_continuous=None, + evergreen_config=None, + github_oauth_token=None, + debug=None, + ignore_failed_push=False, + evg_versions_file=None, + logger: Optional[logging.Logger] = None, ): """Initialize.""" @@ -115,8 +123,9 @@ class SetupMultiversion(Subcommand): self.evg_api = evergreen_conn.get_evergreen_api(evergreen_config) # In evergreen github oauth token is stored as `token ******`, so we remove the leading part - self.github_oauth_token = github_oauth_token.replace("token ", - "") if github_oauth_token else None + self.github_oauth_token = ( + github_oauth_token.replace("token ", "") if github_oauth_token else None + ) with open(config.SETUP_MULTIVERSION_CONFIG) as file_handle: raw_yaml = yaml.safe_load(file_handle) self.config = config.SetupMultiversionConfig(raw_yaml) @@ -145,7 +154,10 @@ class SetupMultiversion(Subcommand): logger = logging.Logger("SetupMultiversion", level=log_level) handler = logging.StreamHandler(sys.stdout) handler.setFormatter( - logging.Formatter(fmt="[%(asctime)s - %(name)s - %(levelname)s] %(message)s")) + logging.Formatter( + fmt="[%(asctime)s - %(name)s - %(levelname)s] %(message)s" + ) + ) logger.addHandler(handler) return logger @@ -163,29 +175,39 @@ class SetupMultiversion(Subcommand): # Use the Evergreen project ID as fallback. return re.search(r"(\d+\.\d+$)", evg_project_id).group(0) - def _get_release_versions(self, install_last_lts: Optional[bool], - install_last_continuous: Optional[bool]) -> List[str]: + def _get_release_versions( + self, install_last_lts: Optional[bool], install_last_continuous: Optional[bool] + ) -> List[str]: """Return last-LTS and/or last-continuous versions.""" out = [] if not os.path.isfile( - os.path.join(os.getcwd(), "buildscripts", "resmokelib", - "multiversionconstants.py")): - self.logger.error("This command should be run from the root of the mongo repo.") + os.path.join( + os.getcwd(), "buildscripts", "resmokelib", "multiversionconstants.py" + ) + ): + self.logger.error( + "This command should be run from the root of the mongo repo." + ) self.logger.error( "If you're running it from the root of the mongo repo and still seeing" - " this error, please reach out in #server-testing slack channel.") + " this error, please reach out in #server-testing slack channel." + ) exit(1) try: from buildscripts.resmokelib import multiversionconstants except ImportError: - self.logger.error("Could not import `buildscripts.resmokelib.multiversionconstants`.") + self.logger.error( + "Could not import `buildscripts.resmokelib.multiversionconstants`." + ) self.logger.error( "If you're passing `--installLastLTS` and/or `--installLastContinuous`" " flags, this module is required to automatically calculate last-LTS" - " and/or last-continuous versions.") + " and/or last-continuous versions." + ) self.logger.error( "Try omitting these flags if you don't need the automatic calculation." - " Otherwise please reach out in #server-testing slack channel.") + " Otherwise please reach out in #server-testing slack channel." + ) exit(1) else: releases = { @@ -200,8 +222,10 @@ class SetupMultiversion(Subcommand): """Execute setup multiversion mongodb.""" if self.install_last_lts or self.install_last_continuous: self.versions.extend( - self._get_release_versions(self, self.install_last_lts, - self.install_last_continuous)) + self._get_release_versions( + self, self.install_last_lts, self.install_last_continuous + ) + ) self.versions = list(set(self.versions)) downloaded_versions = [] @@ -211,24 +235,33 @@ class SetupMultiversion(Subcommand): self.logger.info("Fetching download URL from Evergreen.") try: - self.platform = infer_platform(self.edition, - version) if self.inferred_platform else self.platform + self.platform = ( + infer_platform(self.edition, version) + if self.inferred_platform + else self.platform + ) urls_info = EvgURLInfo() if self.use_latest: urls_info = self.get_latest_urls(version) if self.use_latest and not urls_info.urls: - self.logger.warning("Latest URL is not available, falling back" - " to getting the URL from 'mongodb-mongo-master'" - " project preceding the merge-base commit.") + self.logger.warning( + "Latest URL is not available, falling back" + " to getting the URL from 'mongodb-mongo-master'" + " project preceding the merge-base commit." + ) merge_base_revision = get_merge_base_commit(version, self.logger) urls_info = self.get_latest_urls("master", merge_base_revision) if not urls_info.urls: - self.logger.warning("Latest URL is not available or not requested," - " falling back to getting the URL for a specific" - " version.") + self.logger.warning( + "Latest URL is not available or not requested," + " falling back to getting the URL for a specific" + " version." + ) urls_info = self.get_urls(version, self.variant) if not urls_info: - self.logger.error("URL is not available for the version. version=%s", version) + self.logger.error( + "URL is not available for the version. version=%s", version + ) exit(1) urls = urls_info.urls @@ -240,8 +273,11 @@ class SetupMultiversion(Subcommand): install_dir = os.path.join(self.install_dir, version) self.download_and_extract_from_urls(self, urls, bin_suffix, install_dir) - except (github_conn.GithubConnError, evergreen_conn.EvergreenConnError, - download.DownloadError) as ex: + except ( + github_conn.GithubConnError, + evergreen_conn.EvergreenConnError, + download.DownloadError, + ) as ex: self.logger.error(ex) exit(1) @@ -253,42 +289,66 @@ class SetupMultiversion(Subcommand): self._write_windows_install_paths(self, self._windows_bin_install_dirs) if self.evg_versions_file: - self._write_evg_versions_file(self, self.evg_versions_file, downloaded_versions) + self._write_evg_versions_file( + self, self.evg_versions_file, downloaded_versions + ) - def download_and_extract_from_urls(self, urls, bin_suffix, install_dir, skip_symlinks=False): + def download_and_extract_from_urls( + self, urls, bin_suffix, install_dir, skip_symlinks=False + ): """Download and extract values indicated in `urls`.""" artifacts_url = urls.get("Artifacts", "") if self.download_artifacts else None binaries_url = urls.get("Binaries", "") if self.download_binaries else None - python_venv_url = urls.get("Python venv (see included README.txt)", "") or urls.get( - "Python venv (see included venv_readme.txt)", "") if self.download_python_venv else None + python_venv_url = ( + urls.get("Python venv (see included README.txt)", "") + or urls.get("Python venv (see included venv_readme.txt)", "") + if self.download_python_venv + else None + ) download_symbols_url = None if self.download_symbols: for name in [ - " mongo-debugsymbols.tgz", " mongo-debugsymbols.zip", "mongo-debugsymbols.tgz", - "mongo-debugsymbols.zip" + " mongo-debugsymbols.tgz", + " mongo-debugsymbols.zip", + "mongo-debugsymbols.tgz", + "mongo-debugsymbols.zip", ]: download_symbols_url = urls.get(name, None) if download_symbols_url: break if self.download_symbols and not download_symbols_url: - raise download.DownloadError("Symbols download requested but not URL available") + raise download.DownloadError( + "Symbols download requested but not URL available" + ) if self.download_artifacts and not artifacts_url: raise download.DownloadError( - "Evergreen artifacts download requested but not URL available") + "Evergreen artifacts download requested but not URL available" + ) if self.download_binaries and not binaries_url: - raise download.DownloadError("Binaries download requested but not URL available") + raise download.DownloadError( + "Binaries download requested but not URL available" + ) if self.download_python_venv and not python_venv_url: - raise download.DownloadError("Python venv download requested but not URL available") + raise download.DownloadError( + "Python venv download requested but not URL available" + ) - self.setup_mongodb(artifacts_url, binaries_url, download_symbols_url, python_venv_url, - install_dir, bin_suffix, link_dir=self.link_dir, - install_dir_list=self._windows_bin_install_dirs, - skip_symlinks=skip_symlinks) + self.setup_mongodb( + artifacts_url, + binaries_url, + download_symbols_url, + python_venv_url, + install_dir, + bin_suffix, + link_dir=self.link_dir, + install_dir_list=self._windows_bin_install_dirs, + skip_symlinks=skip_symlinks, + ) def _write_windows_install_paths(self, paths): with open(config.WINDOWS_BIN_PATHS_FILE, "a") as out: @@ -296,18 +356,23 @@ class SetupMultiversion(Subcommand): out.write(os.pathsep) out.write(os.pathsep.join(paths)) - self.logger.info("Finished writing binary paths on Windows to %s", - config.WINDOWS_BIN_PATHS_FILE) + self.logger.info( + "Finished writing binary paths on Windows to %s", + config.WINDOWS_BIN_PATHS_FILE, + ) def _write_evg_versions_file(self, file_name: str, versions: List[str]): with open(file_name, "a") as out: out.write("\n".join(versions)) - self.logger.info("Finished writing downloaded Evergreen versions to %s", - os.path.abspath(file_name)) + self.logger.info( + "Finished writing downloaded Evergreen versions to %s", + os.path.abspath(file_name), + ) - def get_latest_urls(self, version: str, - start_from_revision: Optional[str] = None) -> EvgURLInfo: + def get_latest_urls( + self, version: str, start_from_revision: Optional[str] = None + ) -> EvgURLInfo: """Return latest urls.""" urls = {} actual_version_id = None @@ -334,8 +399,13 @@ class SetupMultiversion(Subcommand): for evg_version in chain(iter([evg_version]), evg_versions): # Skip all versions until we get the revision we should start looking from - if found_start_revision is False and evg_version.revision != start_from_revision: - self.logger.warning("Skipping evergreen version. evg_version=%s", evg_version) + if ( + found_start_revision is False + and evg_version.revision != start_from_revision + ): + self.logger.warning( + "Skipping evergreen version. evg_version=%s", evg_version + ) continue else: found_start_revision = True @@ -345,8 +415,11 @@ class SetupMultiversion(Subcommand): continue curr_urls = evergreen_conn.get_compile_artifact_urls( - self.evg_api, evg_version, buildvariant_name, - ignore_failed_push=self.ignore_failed_push) + self.evg_api, + evg_version, + buildvariant_name, + ignore_failed_push=self.ignore_failed_push, + ) if "Binaries" in curr_urls: urls = curr_urls actual_version_id = evg_version.version_id @@ -354,22 +427,30 @@ class SetupMultiversion(Subcommand): return EvgURLInfo(urls=urls, evg_version_id=actual_version_id) - def get_urls(self, version: str, buildvariant_name: Optional[str] = None) -> EvgURLInfo: + def get_urls( + self, version: str, buildvariant_name: Optional[str] = None + ) -> EvgURLInfo: """Return multiversion urls for a given version (as binary version or commit hash or evergreen_version_id).""" evg_version = evergreen_conn.get_evergreen_version(self.evg_api, version) if evg_version is None: - git_tag, commit_hash = github_conn.get_git_tag_and_commit(self.github_oauth_token, - version) - self.logger.info("Found git attributes. git_tag=%s, commit_hash=%s", git_tag, - commit_hash) - evg_version = evergreen_conn.get_evergreen_version(self.evg_api, commit_hash) + git_tag, commit_hash = github_conn.get_git_tag_and_commit( + self.github_oauth_token, version + ) + self.logger.info( + "Found git attributes. git_tag=%s, commit_hash=%s", git_tag, commit_hash + ) + evg_version = evergreen_conn.get_evergreen_version( + self.evg_api, commit_hash + ) if evg_version is None: return EvgURLInfo() if not buildvariant_name: evg_project = evg_version.project_identifier - self.logger.debug("Found evergreen project. evergreen_project=%s", evg_project) + self.logger.debug( + "Found evergreen project. evergreen_project=%s", evg_project + ) try: major_minor_version = re.findall(r"\d+\.\d+", evg_project)[-1] @@ -377,21 +458,37 @@ class SetupMultiversion(Subcommand): major_minor_version = "master" buildvariant_name = self.get_buildvariant_name(major_minor_version) - self.logger.debug("Found buildvariant. buildvariant_name=%s", buildvariant_name) + self.logger.debug( + "Found buildvariant. buildvariant_name=%s", buildvariant_name + ) if buildvariant_name not in evg_version.build_variants_map: raise ValueError( f"Buildvariant {buildvariant_name} not found in evergreen. " - f"Available buildvariants can be found in {config.SETUP_MULTIVERSION_CONFIG}.") + f"Available buildvariants can be found in {config.SETUP_MULTIVERSION_CONFIG}." + ) - urls = evergreen_conn.get_compile_artifact_urls(self.evg_api, evg_version, - buildvariant_name, - ignore_failed_push=self.ignore_failed_push) + urls = evergreen_conn.get_compile_artifact_urls( + self.evg_api, + evg_version, + buildvariant_name, + ignore_failed_push=self.ignore_failed_push, + ) return EvgURLInfo(urls=urls, evg_version_id=evg_version.version_id) - def setup_mongodb(self, artifacts_url, binaries_url, symbols_url, python_venv_url, install_dir, - bin_suffix=None, link_dir=None, install_dir_list=None, skip_symlinks=False): + def setup_mongodb( + self, + artifacts_url, + binaries_url, + symbols_url, + python_venv_url, + install_dir, + bin_suffix=None, + link_dir=None, + install_dir_list=None, + skip_symlinks=False, + ): """Download, extract and symlink.""" for url in [artifacts_url, binaries_url, symbols_url, python_venv_url]: @@ -400,7 +497,9 @@ class SetupMultiversion(Subcommand): def try_download(download_url): self.logger.info("Downloading '%s'", download_url) tarball = download.download_from_s3(download_url) - self.logger.info("Extracting '%s' in '%s' folder", tarball, install_dir) + self.logger.info( + "Extracting '%s' in '%s' folder", tarball, install_dir + ) download.extract_archive(tarball, install_dir) self.logger.info("Removing tarball '%s'", tarball) os.remove(tarball) @@ -409,7 +508,9 @@ class SetupMultiversion(Subcommand): try_download(url) except Exception as err: # pylint: disable=broad-except self.logger.warning( - "Setting up tarball failed with error, retrying once... error=%s", err) + "Setting up tarball failed with error, retrying once... error=%s", + err, + ) time.sleep(1) try_download(url) @@ -422,7 +523,8 @@ class SetupMultiversion(Subcommand): else: self.logger.info( "Linking to install_dir on Windows; executable have to live in different working" - " directories to avoid DLLs for different versions clobbering each other") + " directories to avoid DLLs for different versions clobbering each other" + ) link_dir = download.symlink_version(bin_suffix, install_dir, None) install_dir_list.append(link_dir) @@ -436,8 +538,12 @@ class SetupMultiversion(Subcommand): return self.variant return evergreen_conn.get_buildvariant_name( - config=self.config, edition=self.edition, platform=self.platform, - architecture=self.architecture, major_minor_version=major_minor_version) + config=self.config, + edition=self.edition, + platform=self.platform, + architecture=self.architecture, + major_minor_version=major_minor_version, + ) class _DownloadOptions(object): @@ -459,96 +565,196 @@ class SetupMultiversionPlugin(PluginInterface): # Shorthand for brevity. args = parsed_args - download_options = _DownloadOptions(db=args.download_binaries, ds=args.download_symbols, - da=args.download_artifacts, - dv=args.download_python_venv) + download_options = _DownloadOptions( + db=args.download_binaries, + ds=args.download_symbols, + da=args.download_artifacts, + dv=args.download_python_venv, + ) if args.use_existing_releases_file: multiversionsetupconstants.USE_EXISTING_RELEASES_FILE = True return SetupMultiversion( - install_dir=args.install_dir, link_dir=args.link_dir, mv_platform=args.platform, - edition=args.edition, architecture=args.architecture, use_latest=args.use_latest, - versions=args.versions, install_last_lts=args.install_last_lts, variant=args.variant, - install_last_continuous=args.install_last_continuous, download_options=download_options, - evergreen_config=args.evergreen_config, github_oauth_token=args.github_oauth_token, - ignore_failed_push=(not args.require_push), evg_versions_file=args.evg_versions_file, - debug=args.debug, logger=SetupMultiversion.setup_logger(parsed_args.debug)) + install_dir=args.install_dir, + link_dir=args.link_dir, + mv_platform=args.platform, + edition=args.edition, + architecture=args.architecture, + use_latest=args.use_latest, + versions=args.versions, + install_last_lts=args.install_last_lts, + variant=args.variant, + install_last_continuous=args.install_last_continuous, + download_options=download_options, + evergreen_config=args.evergreen_config, + github_oauth_token=args.github_oauth_token, + ignore_failed_push=(not args.require_push), + evg_versions_file=args.evg_versions_file, + debug=args.debug, + logger=SetupMultiversion.setup_logger(parsed_args.debug), + ) @classmethod def _add_args_to_parser(cls, parser): - parser.add_argument("-i", "--installDir", dest="install_dir", required=True, - help="Directory to install the download archive. [REQUIRED]") parser.add_argument( - "-l", "--linkDir", dest="link_dir", required=True, + "-i", + "--installDir", + dest="install_dir", + required=True, + help="Directory to install the download archive. [REQUIRED]", + ) + parser.add_argument( + "-l", + "--linkDir", + dest="link_dir", + required=True, help="Directory to contain links to all binaries for each version " - "in the install directory. [REQUIRED]") + "in the install directory. [REQUIRED]", + ) editions = ("base", "enterprise", "targeted") - parser.add_argument("-e", "--edition", dest="edition", choices=editions, - default="enterprise", - help="Edition of the build to download, [default: %(default)s].") parser.add_argument( - "-p", "--platform", dest="platform", help="Platform to download. " - f"Available platforms can be found in {config.SETUP_MULTIVERSION_CONFIG}.") + "-e", + "--edition", + dest="edition", + choices=editions, + default="enterprise", + help="Edition of the build to download, [default: %(default)s].", + ) parser.add_argument( - "-a", "--architecture", dest="architecture", default="x86_64", + "-p", + "--platform", + dest="platform", + help="Platform to download. " + f"Available platforms can be found in {config.SETUP_MULTIVERSION_CONFIG}.", + ) + parser.add_argument( + "-a", + "--architecture", + dest="architecture", + default="x86_64", help="Architecture to download, [default: %(default)s]. Examples include: " - "'arm64', 'ppc64le', 's390x' and 'x86_64'.") + "'arm64', 'ppc64le', 's390x' and 'x86_64'.", + ) parser.add_argument( - "-v", "--variant", dest="variant", default=None, help="Specify a variant to use, " - "which supersedes the --platform, --edition and --architecture options.") + "-v", + "--variant", + dest="variant", + default=None, + help="Specify a variant to use, " + "which supersedes the --platform, --edition and --architecture options.", + ) parser.add_argument( - "-u", "--useLatest", dest="use_latest", action="store_true", + "-u", + "--useLatest", + dest="use_latest", + action="store_true", help="If specified, the latest version from Evergreen will be downloaded, if it exists, " "for the version specified. For example, if specifying version 4.4 for download, the latest " "version from `mongodb-mongo-v4.4` Evergreen project will be downloaded. Otherwise the latest " - "by git tag version will be downloaded.") + "by git tag version will be downloaded.", + ) parser.add_argument( - "versions", nargs="*", + "versions", + nargs="*", help="Accepts binary versions, full git commit hashes, evergreen version ids. " "Binary version examples: 4.0, 4.0.1, 4.0.0-rc0. If 'rc' is included in the version name, " "we'll use the exact rc, otherwise we'll pull the highest non-rc version compatible with the " - "version specified.") - parser.add_argument("--installLastLTS", dest="install_last_lts", action="store_true", - help="If specified, the last LTS version will be installed") - parser.add_argument("--installLastContinuous", dest="install_last_continuous", - action="store_true", - help="If specified, the last continuous version will be installed") + "version specified.", + ) + parser.add_argument( + "--installLastLTS", + dest="install_last_lts", + action="store_true", + help="If specified, the last LTS version will be installed", + ) + parser.add_argument( + "--installLastContinuous", + dest="install_last_continuous", + action="store_true", + help="If specified, the last continuous version will be installed", + ) - parser.add_argument("-db", "--downloadBinaries", dest="download_binaries", - action="store_true", default=True, - help="whether to download binaries, [default: %(default)s].") - parser.add_argument("-ds", "--downloadSymbols", dest="download_symbols", - action="store_true", default=False, - help="whether to download debug symbols, [default: %(default)s].") - parser.add_argument("-da", "--downloadArtifacts", dest="download_artifacts", - action="store_true", default=False, - help="whether to download artifacts, [default: %(default)s].") - parser.add_argument("-dv", "--downloadPythonVenv", dest="download_python_venv", - action="store_true", default=False, - help="whether to download python venv, [default: %(default)s].") parser.add_argument( - "-ec", "--evergreenConfig", dest="evergreen_config", + "-db", + "--downloadBinaries", + dest="download_binaries", + action="store_true", + default=True, + help="whether to download binaries, [default: %(default)s].", + ) + parser.add_argument( + "-ds", + "--downloadSymbols", + dest="download_symbols", + action="store_true", + default=False, + help="whether to download debug symbols, [default: %(default)s].", + ) + parser.add_argument( + "-da", + "--downloadArtifacts", + dest="download_artifacts", + action="store_true", + default=False, + help="whether to download artifacts, [default: %(default)s].", + ) + parser.add_argument( + "-dv", + "--downloadPythonVenv", + dest="download_python_venv", + action="store_true", + default=False, + help="whether to download python venv, [default: %(default)s].", + ) + parser.add_argument( + "-ec", + "--evergreenConfig", + dest="evergreen_config", help="Location of evergreen configuration file. If not specified it will look " - f"for it in the following locations: {evergreen_conn.EVERGREEN_CONFIG_LOCATIONS}") + f"for it in the following locations: {evergreen_conn.EVERGREEN_CONFIG_LOCATIONS}", + ) parser.add_argument( - "-gt", "--githubOauthToken", dest="github_oauth_token", + "-gt", + "--githubOauthToken", + dest="github_oauth_token", help="Set the token to increase your rate limit. In most cases it works without auth. " "Otherwise you can pass OAuth token to increase the github API rate limit. See " - "https://developer.github.com/v3/#rate-limiting") - parser.add_argument("-d", "--debug", dest="debug", action="store_true", default=False, - help="Set DEBUG logging level.") + "https://developer.github.com/v3/#rate-limiting", + ) parser.add_argument( - "-rp", "--require-push", dest="require_push", action="store_true", default=False, - help="Require the push task to be successful for assets to be downloaded") + "-d", + "--debug", + dest="debug", + action="store_true", + default=False, + help="Set DEBUG logging level.", + ) + parser.add_argument( + "-rp", + "--require-push", + dest="require_push", + action="store_true", + default=False, + help="Require the push task to be successful for assets to be downloaded", + ) # Hidden flag that determines if we should generate a new releases yaml file. This flag # should be set to True if we are invoking setup_multiversion multiple times in parallel, # to prevent multiple processes from modifying the releases yaml file simultaneously. - parser.add_argument("--useExistingReleasesFile", dest="use_existing_releases_file", - action="store_true", default=False, help=argparse.SUPPRESS) + parser.add_argument( + "--useExistingReleasesFile", + dest="use_existing_releases_file", + action="store_true", + default=False, + help=argparse.SUPPRESS, + ) # Hidden flag to write out the Evergreen versions of the downloaded binaries. - parser.add_argument("--evgVersionsFile", dest="evg_versions_file", default=None, - help=argparse.SUPPRESS) + parser.add_argument( + "--evgVersionsFile", + dest="evg_versions_file", + default=None, + help=argparse.SUPPRESS, + ) def add_subcommand(self, subparsers): """Create and add the parser for the subcommand.""" diff --git a/buildscripts/resmokelib/sighandler.py b/buildscripts/resmokelib/sighandler.py index abb56306d9c..1e7bf6d02fa 100644 --- a/buildscripts/resmokelib/sighandler.py +++ b/buildscripts/resmokelib/sighandler.py @@ -13,7 +13,7 @@ import psutil from buildscripts.resmokelib import config, parser, reportfile, testing from buildscripts.resmokelib.flags import HANG_ANALYZER_CALLED -_IS_WINDOWS = (sys.platform == "win32") +_IS_WINDOWS = sys.platform == "win32" if _IS_WINDOWS: import win32api import win32event @@ -45,10 +45,14 @@ def register(logger, suites, start_time): # Wait for task time out to dump stacks. ret = win32event.WaitForSingleObject(event_handle, win32event.INFINITE) if ret != win32event.WAIT_OBJECT_0: - logger.error("_handle_set_event WaitForSingleObject failed: %d" % ret) + logger.error( + "_handle_set_event WaitForSingleObject failed: %d" % ret + ) return except win32event.error as err: - logger.error("Exception from win32event.WaitForSingleObject with error: %s" % err) + logger.error( + "Exception from win32event.WaitForSingleObject with error: %s" % err + ) else: HANG_ANALYZER_CALLED.set() header_msg = "Dumping stacks due to signal from win32event.SetEvent" @@ -62,7 +66,7 @@ def register(logger, suites, start_time): testing.suite.Suite.log_summaries(logger, suites, time.time() - start_time) - if 'is_inner_level' not in config.INTERNAL_PARAMS: + if "is_inner_level" not in config.INTERNAL_PARAMS: # Gather and analyze pids of all subprocesses. # Do nothing for child resmoke process started by another resmoke process # (e.g. backup_restore.js) The child processes of the child resmoke will be @@ -81,8 +85,9 @@ def register(logger, suites, start_time): security_attributes = None manual_reset = False initial_state = False - task_timeout_handle = win32event.CreateEvent(security_attributes, manual_reset, - initial_state, event_name) + task_timeout_handle = win32event.CreateEvent( + security_attributes, manual_reset, initial_state, event_name + ) except win32event.error as err: logger.error("Exception from win32event.CreateEvent with error: %s" % err) return @@ -91,9 +96,11 @@ def register(logger, suites, start_time): atexit.register(win32api.CloseHandle, task_timeout_handle) # Create thread. - event_handler_thread = threading.Thread(target=_handle_set_event, - kwargs={"event_handle": task_timeout_handle}, - name="windows_event_handler_thread") + event_handler_thread = threading.Thread( + target=_handle_set_event, + kwargs={"event_handle": task_timeout_handle}, + name="windows_event_handler_thread", + ) event_handler_thread.daemon = True event_handler_thread.start() else: @@ -126,7 +133,7 @@ def _get_pids(): for child in parent.children(recursive=True): # Don't signal python threads. They have already been signalled in the evergreen timeout # section. - if 'python' not in child.name().lower(): + if "python" not in child.name().lower(): pids.append(child.pid) return pids @@ -136,8 +143,10 @@ def _analyze_pids(logger, pids): """Analyze the PIDs spawned by the current resmoke process.""" # If 'test_analysis' is specified, we will just write the pids out to a file and kill them # Instead of running analysis. This option will only be specified in resmoke selftests. - if 'test_analysis' in config.INTERNAL_PARAMS: - with open(os.path.join(config.DBPATH_PREFIX, "test_analysis.txt"), "w") as analysis_file: + if "test_analysis" in config.INTERNAL_PARAMS: + with open( + os.path.join(config.DBPATH_PREFIX, "test_analysis.txt"), "w" + ) as analysis_file: analysis_file.write("\n".join([str(pid) for pid in pids])) for pid in pids: try: @@ -153,8 +162,15 @@ def _analyze_pids(logger, pids): # See hang-analyzer argument options here: # https://github.com/10gen/mongo/blob/8636ede10bd70b32ff4b6cd115132ab0f22b89c7/buildscripts/resmokelib/hang_analyzer/hang_analyzer.py#L245 hang_analyzer_args = [ - 'hang-analyzer', '-c', '-o', 'file', '-o', 'stdout', '-k', '-d', - ','.join([str(p) for p in pids]) + "hang-analyzer", + "-c", + "-o", + "file", + "-o", + "stdout", + "-k", + "-d", + ",".join([str(p) for p in pids]), ] _hang_analyzer = parser.parse_command_line(hang_analyzer_args, logger=logger) @@ -175,7 +191,9 @@ def _analyze_pids(logger, pids): logger.warning( "Resmoke invoked hang analyzer thread did not finish, but will continue running in the background. The thread may be disruputed and may show extraneous output." ) - logger.warning("Cleaning up resmoke child processes so that resmoke can fail gracefully.") + logger.warning( + "Cleaning up resmoke child processes so that resmoke can fail gracefully." + ) _hang_analyzer.kill_rogue_processes() else: diff --git a/buildscripts/resmokelib/suitesconfig.py b/buildscripts/resmokelib/suitesconfig.py index 13a3dc2d959..cd0e5a309fc 100644 --- a/buildscripts/resmokelib/suitesconfig.py +++ b/buildscripts/resmokelib/suitesconfig.py @@ -1,4 +1,5 @@ """Module for retrieving the configuration of resmoke.py test suites.""" + import collections import copy import os @@ -39,7 +40,8 @@ def get_named_suites() -> List[SuiteName]: dbtest = {"dbtest"} explicit_suite_names = [ - name for name in ExplicitSuiteConfig.get_named_suites() + name + for name in ExplicitSuiteConfig.get_named_suites() if (name not in executor_only and name not in dbtest) ] composed_suite_names = MatrixSuiteConfig.get_named_suites() @@ -50,7 +52,9 @@ def get_named_suites() -> List[SuiteName]: def get_suite_files() -> Dict[str, str]: """Get the physical files defining these suites for parsing comments.""" - return merge_dicts(ExplicitSuiteConfig.get_suite_files(), MatrixSuiteConfig.get_suite_files()) + return merge_dicts( + ExplicitSuiteConfig.get_suite_files(), MatrixSuiteConfig.get_suite_files() + ) def create_test_membership_map(fail_on_missing_selector=False, test_kind=None): @@ -105,7 +109,11 @@ def get_suites(suite_names_or_paths, test_files) -> List[_suite.Suite]: # specified. If an option is specified, then sort the tests for consistent execution order. _config.ORDER_TESTS_BY_NAME = any( tag_filter is not None - for tag_filter in (_config.EXCLUDE_WITH_ANY_TAGS, _config.INCLUDE_WITH_ANY_TAGS)) + for tag_filter in ( + _config.EXCLUDE_WITH_ANY_TAGS, + _config.INCLUDE_WITH_ANY_TAGS, + ) + ) # Build configuration for list of files to run. suite_roots = _make_suite_roots(test_files) @@ -121,11 +129,13 @@ def get_suites(suite_names_or_paths, test_files) -> List[_suite.Suite]: for test in override_suite.tests: if test in suite.excluded: if _config.FORCE_EXCLUDED_TESTS: - loggers.ROOT_EXECUTOR_LOGGER.warning("Will forcibly run excluded test: %s", - test) + loggers.ROOT_EXECUTOR_LOGGER.warning( + "Will forcibly run excluded test: %s", test + ) else: raise errors.TestExcludedFromSuiteError( - f"'{test}' excluded in '{suite.get_name()}'") + f"'{test}' excluded in '{suite.get_name()}'" + ) suite = override_suite suites.append(suite) return suites @@ -180,7 +190,9 @@ class ExplicitSuiteConfig(SuiteConfigInterface): if os.path.isfile(suite_name): suite_path = suite_name else: - raise ValueError("Expected a suite YAML config, but got '%s'" % suite_name) + raise ValueError( + "Expected a suite YAML config, but got '%s'" % suite_name + ) else: # Not an explicit suite, return None. return None @@ -226,8 +238,9 @@ class MatrixSuiteConfig(SuiteConfigInterface): """Get all YAML files in the given directory.""" return { short_name: load_yaml_file(path) - for short_name, path in cls.__get_suite_files_in_dir(os.path.abspath( - target_dir)).items() + for short_name, path in cls.__get_suite_files_in_dir( + os.path.abspath(target_dir) + ).items() } @staticmethod @@ -246,8 +259,8 @@ class MatrixSuiteConfig(SuiteConfigInterface): generated_path = cls.get_generated_suite_path(suite_name) if not os.path.exists(generated_path): raise errors.InvalidMatrixSuiteError( - f"No generated suite file was found for {suite_name}" + - "To (re)generate the matrix suite files use `python3 buildscripts/resmoke.py generate-matrix-suites`" + f"No generated suite file was found for {suite_name}" + + "To (re)generate the matrix suite files use `python3 buildscripts/resmoke.py generate-matrix-suites`" ) new_text = cls.generate_matrix_suite_text(suite_name) @@ -260,8 +273,7 @@ class MatrixSuiteConfig(SuiteConfigInterface): loggers.ROOT_EXECUTOR_LOGGER.error(new_text) raise errors.InvalidMatrixSuiteError( f"The generated file found on disk did not match the mapping file for {suite_name}. " - + - "To (re)generate the matrix suite files use `python3 buildscripts/resmoke.py generate-matrix-suites`" + + "To (re)generate the matrix suite files use `python3 buildscripts/resmoke.py generate-matrix-suites`" ) return config @@ -291,14 +303,16 @@ class MatrixSuiteConfig(SuiteConfigInterface): base_suite = ExplicitSuiteConfig.get_config_obj_no_verify(base_suite_name) if base_suite is None: - raise ValueError(f"Unknown base suite {base_suite_name} for matrix suite {suite_name}") + raise ValueError( + f"Unknown base suite {base_suite_name} for matrix suite {suite_name}" + ) res = copy.deepcopy(base_suite) - res['matrix_suite'] = True + res["matrix_suite"] = True overrides = copy.deepcopy(overrides) if description: - res['description'] = description + res["description"] = description if override_names: for override_name in override_names: @@ -310,7 +324,9 @@ class MatrixSuiteConfig(SuiteConfigInterface): for key in excludes_dict: if key not in ["exclude_with_any_tags", "exclude_files"]: - raise ValueError(f"{excludes_name} is not supported in the 'excludes' tag") + raise ValueError( + f"{excludes_name} is not supported in the 'excludes' tag" + ) value = excludes_dict[key] if not isinstance(value, list): @@ -360,11 +376,14 @@ class MatrixSuiteConfig(SuiteConfigInterface): for filename, override_config_file in overrides_files.items(): for override_config in override_config_file: if "name" in override_config and "value" in override_config: - cls._all_overrides[ - f"{filename}.{override_config['name']}"] = override_config["value"] + cls._all_overrides[f"{filename}.{override_config['name']}"] = ( + override_config["value"] + ) else: - raise ValueError("Invalid override configuration, missing required keys. ", - override_config) + raise ValueError( + "Invalid override configuration, missing required keys. ", + override_config, + ) return cls._all_overrides @classmethod @@ -394,8 +413,10 @@ class MatrixSuiteConfig(SuiteConfigInterface): if "base_suite" in suite_config: cls._all_mappings[suite_name] = suite_config else: - raise ValueError("Invalid suite configuration, missing required keys. ", - suite_config) + raise ValueError( + "Invalid suite configuration, missing required keys. ", + suite_config, + ) return cls._all_mappings @classmethod @@ -456,7 +477,7 @@ class MatrixSuiteConfig(SuiteConfigInterface): def generate_matrix_suite_file(cls, suite_name): text = cls.generate_matrix_suite_text(suite_name) path = cls.get_generated_suite_path(suite_name) - with open(path, 'w+') as file: + with open(path, "w+") as file: file.write(text) print(f"Generated matrix suite file {path}") @@ -481,7 +502,8 @@ class SuiteFinder(object): if explicit_suite and matrix_suite: raise errors.DuplicateSuiteDefinition( - "Multiple definitions for suite '%s'" % suite_path) + "Multiple definitions for suite '%s'" % suite_path + ) suite = matrix_suite or explicit_suite diff --git a/buildscripts/resmokelib/symbolizer/__init__.py b/buildscripts/resmokelib/symbolizer/__init__.py index 722a6f1b1c9..166e13f861a 100644 --- a/buildscripts/resmokelib/symbolizer/__init__.py +++ b/buildscripts/resmokelib/symbolizer/__init__.py @@ -1,4 +1,5 @@ """Wrapper around mongosym to download everything required.""" + import logging import os import shutil @@ -33,12 +34,12 @@ class Symbolizer(Subcommand): """Interact with Symbolizer.""" def __init__( - self, - task_id, - download_symbols_only, - bin_name=None, - all_args=None, - logger: Optional[logging.Logger] = None, + self, + task_id, + download_symbols_only, + bin_name=None, + all_args=None, + logger: Optional[logging.Logger] = None, ): """Constructor.""" @@ -49,7 +50,9 @@ class Symbolizer(Subcommand): self.logger = logger or self.setup_logger() - self.evg_api: evergreen_conn.RetryingEvergreenApi = evergreen_conn.get_evergreen_api() + self.evg_api: evergreen_conn.RetryingEvergreenApi = ( + evergreen_conn.get_evergreen_api() + ) self.multiversion_setup = self._get_multiversion_setup() self.task_info = self.evg_api.task_by_id(task_id) @@ -81,7 +84,10 @@ class Symbolizer(Subcommand): logger = logging.Logger("symbolizer", level=log_level) handler = logging.StreamHandler(sys.stdout) handler.setFormatter( - logging.Formatter(fmt="[%(asctime)s - %(name)s - %(levelname)s] %(message)s")) + logging.Formatter( + fmt="[%(asctime)s - %(name)s - %(levelname)s] %(message)s" + ) + ) logger.addHandler(handler) return logger @@ -90,21 +96,26 @@ class Symbolizer(Subcommand): if not task_id: raise ValueError( "A valid Evergreen Task ID is required. You can get it by double clicking the" - " Evergreen URL after `/task/` on any task page") + " Evergreen URL after `/task/` on any task page" + ) if not download_symbols_only: if not bin_name: raise ValueError( - "A binary base name is required. This is usually `mongod` or `mongos`") + "A binary base name is required. This is usually `mongod` or `mongos`" + ) if not os.path.isfile(DEFAULT_SYMBOLIZER_LOCATION): raise ValueError( "llvm-symbolizer in MongoDB toolchain not found. Please run this on a " - "virtual workstation or install the toolchain manually") + "virtual workstation or install the toolchain manually" + ) if not os.access("/data/mci", os.W_OK): - raise ValueError("Please ensure you have write access to /data/mci. " - "E.g. with `sudo mkdir -p /data/mci; sudo chown $USER /data/mci`") + raise ValueError( + "Please ensure you have write access to /data/mci. " + "E.g. with `sudo mkdir -p /data/mci; sudo chown $USER /data/mci`" + ) def _get_multiversion_setup(self): if self.download_symbols_only: @@ -112,7 +123,9 @@ class Symbolizer(Subcommand): else: download_options = _DownloadOptions(db=True, ds=True, da=False, dv=False) return SetupMultiversion( - download_options=download_options, ignore_failed_push=True, logger=self.logger + download_options=download_options, + ignore_failed_push=True, + logger=self.logger, ) def _get_compile_artifacts(self): @@ -120,12 +133,14 @@ class Symbolizer(Subcommand): version_id = self.task_info.version_id buildvariant_name = self.task_info.build_variant - urlinfo = self.multiversion_setup.get_urls(version=version_id, - buildvariant_name=buildvariant_name) + urlinfo = self.multiversion_setup.get_urls( + version=version_id, buildvariant_name=buildvariant_name + ) self.logger.info("Found urls to download and extract %s", urlinfo.urls) - self.multiversion_setup.download_and_extract_from_urls(urlinfo.urls, bin_suffix=None, - install_dir=self.dest_dir) + self.multiversion_setup.download_and_extract_from_urls( + urlinfo.urls, bin_suffix=None, install_dir=self.dest_dir + ) def _patch_diff_by_id(self): version_id = self.task_info.version_id @@ -138,9 +153,13 @@ class Symbolizer(Subcommand): for module_name, diff in module_diffs.items(): # TODO: enterprise. if "mongodb-mongo-" in module_name: - with open(os.path.join(self.dest_dir, "patch.diff"), 'w') as git_diff_file: + with open( + os.path.join(self.dest_dir, "patch.diff"), "w" + ) as git_diff_file: git_diff_file.write(diff) - subprocess.run(["git", "apply", "patch.diff"], cwd=self.dest_dir, check=True) + subprocess.run( + ["git", "apply", "patch.diff"], cwd=self.dest_dir, check=True + ) def _get_source(self): revision = self.task_info.revision @@ -149,8 +168,11 @@ class Symbolizer(Subcommand): try: cache_dir = mkdtemp_in_build_dir() - subprocess.run(["curl", "-L", "-o", "source.zip", source_url], cwd=cache_dir, - check=True) + subprocess.run( + ["curl", "-L", "-o", "source.zip", source_url], + cwd=cache_dir, + check=True, + ) subprocess.run(["unzip", "-q", "source.zip"], cwd=cache_dir, check=True) subprocess.run(["rm", "source.zip"], cwd=cache_dir, check=True) @@ -171,13 +193,15 @@ class Symbolizer(Subcommand): try: if os.path.isdir(self.dest_dir): self.logger.info( - "directory for build already exists, skipping fetching source and symbols") + "directory for build already exists, skipping fetching source and symbols" + ) return self.logger.info("Getting source from GitHub...") self._get_source() self.logger.info( - "Downloading debug symbols and binaries, this may take a few minutes...") + "Downloading debug symbols and binaries, this may take a few minutes..." + ) self._get_compile_artifacts() self.logger.info("Applying patch diff (if any)...") self._patch_diff_by_id() @@ -194,18 +218,23 @@ class Symbolizer(Subcommand): def _parse_mongosymb_args(self): symbolizer_path = self.mongosym_args.symbolizer_path if symbolizer_path: - raise ValueError("Must use the default symbolizer from the toolchain," - f"not {symbolizer_path}") + raise ValueError( + "Must use the default symbolizer from the toolchain," + f"not {symbolizer_path}" + ) self.mongosym_args.symbolizer_path = DEFAULT_SYMBOLIZER_LOCATION sym_search_path = self.mongosym_args.path_to_executable if sym_search_path: - raise ValueError(f"Must not specify path_to_executable, the original path that " - f"generated the symbols will be used: {sym_search_path}") + raise ValueError( + f"Must not specify path_to_executable, the original path that " + f"generated the symbols will be used: {sym_search_path}" + ) # TODO: support non-hygienic builds. self.mongosym_args.path_to_executable = build_hygienic_bin_path( - parent=self.dest_dir, child=self.bin_name) + parent=self.dest_dir, child=self.bin_name + ) self.mongosym_args.src_dir_to_move = self.dest_dir @@ -237,22 +266,44 @@ class SymbolizerPlugin(PluginInterface): """ parser = subparsers.add_parser(_COMMAND, help=_HELP) parser.add_argument( - "--task-id", '-t', action="store", type=str, required=True, - help="Fetch corresponding binaries and/or symbols given an Evergreen task ID") + "--task-id", + "-t", + action="store", + type=str, + required=True, + help="Fetch corresponding binaries and/or symbols given an Evergreen task ID", + ) parser.add_argument( - "--binary-name", "-b", action="store", type=str, default="mongod", - help="Base name of the binary that generated the stacktrace; e.g. `mongod` or `mongos`") + "--binary-name", + "-b", + action="store", + type=str, + default="mongod", + help="Base name of the binary that generated the stacktrace; e.g. `mongod` or `mongos`", + ) - parser.add_argument("--download-symbols-only", "-s", action="store_true", default=False, - help="Just download the debug symbol tarball for the given task-id") + parser.add_argument( + "--download-symbols-only", + "-s", + action="store_true", + default=False, + help="Just download the debug symbol tarball for the given task-id", + ) - parser.add_argument("--debug", "-d", dest="debug", action="store_true", default=False, - help="Set DEBUG logging level.") + parser.add_argument( + "--debug", + "-d", + dest="debug", + action="store_true", + default=False, + help="Set DEBUG logging level.", + ) group = parser.add_argument_group( "Verbatim mongosymb.py options for advanced usages", - description="Compatibility not guaranteed, use at your own risk") + description="Compatibility not guaranteed, use at your own risk", + ) mongosymb.make_argument_parser(group) def parse(self, subcommand, parser, parsed_args, **kwargs): diff --git a/buildscripts/resmokelib/testing/docker_cluster_image_builder.py b/buildscripts/resmokelib/testing/docker_cluster_image_builder.py index f137e21df23..0e36f678e67 100644 --- a/buildscripts/resmokelib/testing/docker_cluster_image_builder.py +++ b/buildscripts/resmokelib/testing/docker_cluster_image_builder.py @@ -55,11 +55,17 @@ class DockerComposeImageBuilder: self.WORKLOAD_BUILD_CONTEXT = "buildscripts/antithesis/base_images/workload" self.WORKLOAD_DOCKERFILE = f"{self.WORKLOAD_BUILD_CONTEXT}/Dockerfile" - self.MONGO_BINARIES_BUILD_CONTEXT = "buildscripts/antithesis/base_images/mongo_binaries" - self.MONGO_BINARIES_DOCKERFILE = f"{self.MONGO_BINARIES_BUILD_CONTEXT}/Dockerfile" + self.MONGO_BINARIES_BUILD_CONTEXT = ( + "buildscripts/antithesis/base_images/mongo_binaries" + ) + self.MONGO_BINARIES_DOCKERFILE = ( + f"{self.MONGO_BINARIES_BUILD_CONTEXT}/Dockerfile" + ) # Artifact constants - self.DIST_TEST_DIR = "dist-test" if self.in_evergreen else "antithesis-dist-test" + self.DIST_TEST_DIR = ( + "dist-test" if self.in_evergreen else "antithesis-dist-test" + ) self.MONGODB_BINARIES_DIR = os.path.join(self.DIST_TEST_DIR, "bin") self.MONGODB_LIBRARIES_DIR = os.path.join(self.DIST_TEST_DIR, "lib") self.TSAN_SUPPRESSIONS_SOURCE = "etc/tsan.suppressions" @@ -72,7 +78,9 @@ class DockerComposeImageBuilder: # MongoDB Enterprise Modules constants self.MODULES_RELATIVE_PATH = "src/mongo/db/modules" - self.MONGO_ENTERPRISE_MODULES_RELATIVE_PATH = f"{self.MODULES_RELATIVE_PATH}/enterprise" + self.MONGO_ENTERPRISE_MODULES_RELATIVE_PATH = ( + f"{self.MODULES_RELATIVE_PATH}/enterprise" + ) # Port suffix ranging from 1-24 is subject to fault injection while ports 130+ are safe. self.next_available_fault_enabled_ip = 2 @@ -105,7 +113,9 @@ class DockerComposeImageBuilder: # (2) it should add the `--externalSUT` flag command = sys.argv rm_index = command.index("--dockerComposeBuildImages") - return ' '.join(command[0:rm_index] + command[rm_index + 2:] + ["--externalSUT"]) + return " ".join( + command[0:rm_index] + command[rm_index + 2 :] + ["--externalSUT"] + ) def _add_docker_compose_configuration_to_build_context(self, build_context) -> None: """ @@ -129,53 +139,73 @@ class DockerComposeImageBuilder: ip_suffix = self.next_available_fault_disabled_ip self.next_available_fault_disabled_ip += 1 return { - "container_name": name, "hostname": name, + "container_name": name, + "hostname": name, "image": f'{"workload" if name == "workload" else "mongo-binaries"}:{self.tag}', "volumes": [ f"./logs/{name}:/var/log/mongodb/", "./scripts:/scripts/", f"./data/{name}:/data/db", - ], "command": f"/bin/bash /scripts/{name}.sh", "networks": { + ], + "command": f"/bin/bash /scripts/{name}.sh", + "networks": { "antithesis-net": {"ipv4_address": f"10.20.20.{ip_suffix}"} - }, "depends_on": depends_on + }, + "depends_on": depends_on, } docker_compose_yml = { - "version": "3.0", "services": { - "workload": - create_docker_compose_service( - "workload", fault_injection=False, depends_on=[ - process.logger.external_sut_hostname - for process in self.suite_fixture.all_processes() - ]) - }, "networks": { + "version": "3.0", + "services": { + "workload": create_docker_compose_service( + "workload", + fault_injection=False, + depends_on=[ + process.logger.external_sut_hostname + for process in self.suite_fixture.all_processes() + ], + ) + }, + "networks": { "antithesis-net": { - "driver": "bridge", "ipam": {"config": [{"subnet": "10.20.20.0/24"}]} + "driver": "bridge", + "ipam": {"config": [{"subnet": "10.20.20.0/24"}]}, } - } + }, } print("Writing workload init script...") - with open(os.path.join(build_context, "scripts", "workload.sh"), "w") as workload_init: + with open( + os.path.join(build_context, "scripts", "workload.sh"), "w" + ) as workload_init: workload_init.write("tail -f /dev/null\n") print("Writing resmoke run script for convenience...") - with open(os.path.join(build_context, "scripts", "run_resmoke.sh"), "w") as run_resmoke: + with open( + os.path.join(build_context, "scripts", "run_resmoke.sh"), "w" + ) as run_resmoke: run_resmoke.write(f'{self.get_resmoke_run_command()} "$@"\n') print("Writing mongo{d,s} init scripts...") for process in self.suite_fixture.all_processes(): # Add the `Process` as a service in the docker-compose.yml service_name = process.logger.external_sut_hostname - docker_compose_yml["services"][service_name] = create_docker_compose_service( - service_name, fault_injection=True, depends_on=[]) + docker_compose_yml["services"][service_name] = ( + create_docker_compose_service( + service_name, fault_injection=True, depends_on=[] + ) + ) # Write the `Process` args as an init script - with open(os.path.join(build_context, "scripts", f"{service_name}.sh"), "w") as file: - file.write(" ".join(map(shlex.quote, process.args)) + '\n') + with open( + os.path.join(build_context, "scripts", f"{service_name}.sh"), "w" + ) as file: + file.write(" ".join(map(shlex.quote, process.args)) + "\n") print("Writing `docker-compose.yml`...") - with open(os.path.join(build_context, "docker-compose.yml"), "w") as docker_compose: - docker_compose.write(yaml.dump(docker_compose_yml) + '\n') + with open( + os.path.join(build_context, "docker-compose.yml"), "w" + ) as docker_compose: + docker_compose.write(yaml.dump(docker_compose_yml) + "\n") print("Writing Dockerfile...") with open(os.path.join(build_context, "Dockerfile"), "w") as dockerfile: @@ -212,7 +242,7 @@ class DockerComposeImageBuilder: def _get_docker_build_san_args(self): args = [] for key, value in self.san_options.items(): - args += ["--build-arg", f'{key}={value}'] + args += ["--build-arg", f"{key}={value}"] return args def _docker_build(self, image_name, dockerfile, build_context): @@ -229,23 +259,35 @@ class DockerComposeImageBuilder: :return: None """ # Build out the directory structure and write the startup scripts for the config image - print(f"Preparing antithesis config image build context for `{self.suite_name}`...") + print( + f"Preparing antithesis config image build context for `{self.suite_name}`..." + ) self._initialize_docker_compose_build_context(self.DOCKER_COMPOSE_BUILD_CONTEXT) - self._add_docker_compose_configuration_to_build_context(self.DOCKER_COMPOSE_BUILD_CONTEXT) + self._add_docker_compose_configuration_to_build_context( + self.DOCKER_COMPOSE_BUILD_CONTEXT + ) # Our official builds happen in Evergreen. Assert debug symbols are on system. # If this is running locally, this is for development purposes only and debug symbols are not required. if self.in_evergreen: - assert os.path.exists(self.MONGODB_DEBUGSYMBOLS - ), f"No debug symbols available at: {self.MONGODB_DEBUGSYMBOLS}" + assert os.path.exists( + self.MONGODB_DEBUGSYMBOLS + ), f"No debug symbols available at: {self.MONGODB_DEBUGSYMBOLS}" print("Running in Evergreen -- copying debug symbols to build context...") - shutil.copy(self.MONGODB_DEBUGSYMBOLS, - os.path.join(self.DOCKER_COMPOSE_BUILD_CONTEXT, "debug")) + shutil.copy( + self.MONGODB_DEBUGSYMBOLS, + os.path.join(self.DOCKER_COMPOSE_BUILD_CONTEXT, "debug"), + ) - print(f"Done setting up antithesis config image build context for `{self.suite_name}...") + print( + f"Done setting up antithesis config image build context for `{self.suite_name}..." + ) print("Building antithesis config image...") - self._docker_build(self.suite_name, f"{self.DOCKER_COMPOSE_BUILD_CONTEXT}/Dockerfile", - self.DOCKER_COMPOSE_BUILD_CONTEXT) + self._docker_build( + self.suite_name, + f"{self.DOCKER_COMPOSE_BUILD_CONTEXT}/Dockerfile", + self.DOCKER_COMPOSE_BUILD_CONTEXT, + ) print("Done building antithesis config image.") def build_workload_image(self): @@ -272,7 +314,9 @@ class DockerComposeImageBuilder: # Build docker image print("Building workload image...") - self._docker_build("workload", self.WORKLOAD_DOCKERFILE, self.WORKLOAD_BUILD_CONTEXT) + self._docker_build( + "workload", self.WORKLOAD_DOCKERFILE, self.WORKLOAD_BUILD_CONTEXT + ) print("Done building workload image.") def build_mongo_binaries_image(self): @@ -292,8 +336,11 @@ class DockerComposeImageBuilder: # Build docker image print("Building mongo binaries image...") - self._docker_build("mongo-binaries", self.MONGO_BINARIES_DOCKERFILE, - self.MONGO_BINARIES_BUILD_CONTEXT) + self._docker_build( + "mongo-binaries", + self.MONGO_BINARIES_DOCKERFILE, + self.MONGO_BINARIES_BUILD_CONTEXT, + ) print("Done building mongo binaries image.") def _fetch_mongodb_binaries(self): @@ -322,41 +369,56 @@ class DockerComposeImageBuilder: More info on `db-contrib-tool` available here: `https://github.com/10gen/db-contrib-tool` """ - assert subprocess.run(["which", - "db-contrib-tool"]).returncode == 0, db_contrib_tool_error + assert ( + subprocess.run(["which", "db-contrib-tool"]).returncode == 0 + ), db_contrib_tool_error # Use `db-contrib-tool` to get MongoDB binaries for this image print("Running Locally - Fetching All MongoDB binaries for image build...") - subprocess.run([ - "db-contrib-tool", - "setup-repro-env", - "--variant", - "ubuntu2204", - "--linkDir", - mongodb_binaries_destination, - "--installLastContinuous", - "--installLastLTS", - "master", - ], stdout=sys.stdout, stderr=sys.stderr, check=True) + subprocess.run( + [ + "db-contrib-tool", + "setup-repro-env", + "--variant", + "ubuntu2204", + "--linkDir", + mongodb_binaries_destination, + "--installLastContinuous", + "--installLastLTS", + "master", + ], + stdout=sys.stdout, + stderr=sys.stderr, + check=True, + ) elif self.in_evergreen: print( "Running in Evergreen - Fetching `last-continuous` and `last-lts` MongoDB binaries for image build..." ) - subprocess.run([ - "db-contrib-tool", - "setup-repro-env", - "--variant", - "ubuntu2204", - "--linkDir", - mongodb_binaries_destination, - "--installLastContinuous", - "--installLastLTS", - "--evergreenConfig", - "./.evergreen.yml", - ], stdout=sys.stdout, stderr=sys.stderr, check=True) + subprocess.run( + [ + "db-contrib-tool", + "setup-repro-env", + "--variant", + "ubuntu2204", + "--linkDir", + mongodb_binaries_destination, + "--installLastContinuous", + "--installLastLTS", + "--evergreenConfig", + "./.evergreen.yml", + ], + stdout=sys.stdout, + stderr=sys.stderr, + check=True, + ) # Verify the binaries were downloaded successfully - for required_binary in [self.MONGO_BINARY, self.MONGOD_BINARY, self.MONGOS_BINARY]: + for required_binary in [ + self.MONGO_BINARY, + self.MONGOD_BINARY, + self.MONGOS_BINARY, + ]: assert os.path.exists( required_binary ), f"Could not find Ubuntu 22.04 MongoDB binary at: {required_binary}" @@ -364,12 +426,14 @@ class DockerComposeImageBuilder: # Our official builds happen in Evergreen. # We want to ensure the binaries are linked with `libvoidstar.so` during image build. if self.in_evergreen: - assert "libvoidstar" in subprocess.run( - ["ldd", required_binary], - check=True, - capture_output=True, - ).stdout.decode( - "utf-8"), f"MongoDB binary is not linked to `libvoidstar.so`: {required_binary}" + assert ( + "libvoidstar" + in subprocess.run( + ["ldd", required_binary], + check=True, + capture_output=True, + ).stdout.decode("utf-8") + ), f"MongoDB binary is not linked to `libvoidstar.so`: {required_binary}" def _copy_mongo_binary_to_build_context(self, dir_path): """ @@ -403,8 +467,11 @@ class DockerComposeImageBuilder: # Instead, we rely on Evergreen's `git.get_project` which will correctly clone the repo and apply changes from the `patch`. if self.in_evergreen: assert os.path.exists( - mongo_repo_destination), f"No `mongo` repo available at: {mongo_repo_destination}" - print("Running in Evergreen -- no need to clone `mongo` repo since it already exists.") + mongo_repo_destination + ), f"No `mongo` repo available at: {mongo_repo_destination}" + print( + "Running in Evergreen -- no need to clone `mongo` repo since it already exists." + ) return # Clean up any old artifacts in the build context. @@ -438,7 +505,9 @@ class DockerComposeImageBuilder: print(f"\n\tFound existing QA repo at: {qa_repo_destination}\n") else: print("Cloning QA repo to build context...") - self._clone_repo("10gen", "QA", qa_repo_destination, get_expansion("github_token_qa")) + self._clone_repo( + "10gen", "QA", qa_repo_destination, get_expansion("github_token_qa") + ) print("Done cloning QA repo to build context.") def _clone_jstestfuzz_to_build_context(self, dir_path): @@ -451,7 +520,9 @@ class DockerComposeImageBuilder: # Clone jstestfuzz repo if it does not already exist if os.path.exists(jstestfuzz_repo_destination): - print(f"\n\tFound existing jstestfuzz repo at: {jstestfuzz_repo_destination}\n") + print( + f"\n\tFound existing jstestfuzz repo at: {jstestfuzz_repo_destination}\n" + ) else: print("Cloning jstestfuzz repo to build context...") self._clone_repo( @@ -523,7 +594,8 @@ class DockerComposeImageBuilder: # Our official builds happen in Evergreen. Assert a "real" `libvoidstar.so` is on system. if self.in_evergreen: assert os.path.exists( - self.LIBVOIDSTAR_PATH), f"No `libvoidstar.so` available at: {self.LIBVOIDSTAR_PATH}" + self.LIBVOIDSTAR_PATH + ), f"No `libvoidstar.so` available at: {self.LIBVOIDSTAR_PATH}" print("Running in Evergreen -- using system `libvoidstar.so`.") shutil.copy(self.LIBVOIDSTAR_PATH, libvoidstar_destination) print("Done copying `libvoidstar.so` from system to build context") @@ -548,7 +620,9 @@ class DockerComposeImageBuilder: url = f"https://x-access-token:{token}@github.com/{owner}/{repo}.git" else: print(f"No token found for {owner}/{repo} git repo, using ssh clone") - assert not self.in_evergreen, "SSH cloning should only be done when not in evergreen" + assert ( + not self.in_evergreen + ), "SSH cloning should only be done when not in evergreen" url = f"git@github.com:{owner}/{repo}.git" git.Repo.clone_from(url, destination) diff --git a/buildscripts/resmokelib/testing/executor.py b/buildscripts/resmokelib/testing/executor.py index e1d4dd4ff1b..59fed937bf8 100644 --- a/buildscripts/resmokelib/testing/executor.py +++ b/buildscripts/resmokelib/testing/executor.py @@ -35,9 +35,16 @@ class TestSuiteExecutor(object): _TIMEOUT = 24 * 60 * 60 # =1 day (a long time to have tests run) - def __init__(self, exec_logger: Logger, suite: Suite, config=None, - fixture: Optional[Fixture] = None, hooks=None, archive_instance=None, - archive=None): + def __init__( + self, + exec_logger: Logger, + suite: Suite, + config=None, + fixture: Optional[Fixture] = None, + hooks=None, + archive_instance=None, + archive=None, + ): """Initialize the TestSuiteExecutor with the test suite to run.""" self.logger = exec_logger @@ -46,7 +53,7 @@ class TestSuiteExecutor(object): # specified in the YAML configuration to be the external fixture. self.fixture_config = { "class": fixtures.EXTERNAL_FIXTURE_CLASS, - "shell_conn_string": _config.SHELL_CONN_STRING + "shell_conn_string": _config.SHELL_CONN_STRING, } else: self.fixture_config = fixture @@ -56,8 +63,9 @@ class TestSuiteExecutor(object): self.archival = None if archive_instance: - self.archival = archival.HookTestArchival(suite, self.hooks_config, archive_instance, - archive) + self.archival = archival.HookTestArchival( + suite, self.hooks_config, archive_instance, archive + ) self._suite = suite self.test_queue_logger = logging.loggers.new_testqueue_logger(suite.test_kind) @@ -105,8 +113,9 @@ class TestSuiteExecutor(object): # We use the 'hook_failure_flag' to distinguish hook failures from other failures, # so that we can return a separate return code when a hook has failed. hook_failure_flag = threading.Event() - (report, interrupted) = self._run_tests(test_queue, setup_flag, teardown_flag, - hook_failure_flag) + (report, interrupted) = self._run_tests( + test_queue, setup_flag, teardown_flag, hook_failure_flag + ) self._suite.record_test_end(report) @@ -144,7 +153,9 @@ class TestSuiteExecutor(object): if test_results_num < test_queue.num_tests: raise errors.ResmokeError( "{} reported tests is less than {} expected tests".format( - test_results_num, test_queue.num_tests)) + test_results_num, test_queue.num_tests + ) + ) # Clear the report so it can be reused for the next execution. for job in self._jobs: @@ -158,11 +169,11 @@ class TestSuiteExecutor(object): self._suite.return_code = return_code def _run_tests( - self, - test_queue: 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]', - setup_flag: Optional[threading.Event], - teardown_flag: Optional[threading.Event], - hook_failure_flag: Optional[threading.Event], + self, + test_queue: "TestQueue[Union[QueueElemRepeatTime, QueueElem]]", + setup_flag: Optional[threading.Event], + teardown_flag: Optional[threading.Event], + hook_failure_flag: Optional[threading.Event], ): """Start a thread for each Job instance and block until all of the tests are run. @@ -178,12 +189,15 @@ class TestSuiteExecutor(object): # Run each Job instance in its own thread. for job in self._jobs: thr = threading.Thread( - target=job, args=(test_queue, interrupt_flag), kwargs=dict( + target=job, + args=(test_queue, interrupt_flag), + kwargs=dict( parent_context=context.get_current(), setup_flag=setup_flag, teardown_flag=teardown_flag, hook_failure_flag=hook_failure_flag, - )) + ), + ) # Do not wait for tests to finish executing if interrupted by the user. thr.daemon = True thr.start() @@ -234,8 +248,11 @@ class TestSuiteExecutor(object): success = True for job in self._jobs: if not job.manager.teardown_fixture(self.logger): - self.logger.warning("Teardown of %s of job %s was not successful", job.fixture, - job.job_num) + self.logger.warning( + "Teardown of %s of job %s was not successful", + job.fixture, + job.job_num, + ) success = False return success @@ -250,7 +267,9 @@ class TestSuiteExecutor(object): fixture_logger = logging.loggers.new_fixture_logger(fixture_class, job_num) - return fixtures.make_fixture(fixture_class, fixture_logger, job_num, **fixture_config) + return fixtures.make_fixture( + fixture_class, fixture_logger, job_num, **fixture_config + ) def _make_hooks(self, fixture, job_num) -> List[Hook]: """Create the hooks for the job's fixture.""" @@ -281,8 +300,16 @@ class TestSuiteExecutor(object): report = _report.TestReport(job_logger, self._suite.options, job_num) - return _job.Job(job_num, job_logger, fixture, hooks, report, self.archival, - self._suite.options, self.test_queue_logger) + return _job.Job( + job_num, + job_logger, + fixture, + hooks, + report, + self.archival, + self._suite.options, + self.test_queue_logger, + ) def _create_queue_elem_for_test_name(self, test_name): """ @@ -291,11 +318,12 @@ class TestSuiteExecutor(object): :param test_name: Name of test to be queued. :return: queue_elem representing the test_name to be run. """ - test_case = testcases.make_test_case(self._suite.test_kind, self.test_queue_logger, - test_name, **self.test_config) + test_case = testcases.make_test_case( + self._suite.test_kind, self.test_queue_logger, test_name, **self.test_config + ) return queue_elem_factory(test_case, self.test_config, self._suite.options) - def _make_test_queue(self) -> 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]': + def _make_test_queue(self) -> "TestQueue[Union[QueueElemRepeatTime, QueueElem]]": """ Create a queue of test cases to run. @@ -324,11 +352,13 @@ class TestSuiteExecutor(object): def _log_timeout_warning(self, seconds): """Log a message if any thread fails to terminate after `seconds`.""" self.logger.warning( - '*** Still waiting for processes to terminate after %s seconds. Try using ctrl-\\ ' - 'to send a SIGQUIT on Linux or ctrl-c again on Windows ***', seconds) + "*** Still waiting for processes to terminate after %s seconds. Try using ctrl-\\ " + "to send a SIGQUIT on Linux or ctrl-c again on Windows ***", + seconds, + ) -T = TypeVar('T') +T = TypeVar("T") class TestQueue(_queue.Queue, Generic[T]): @@ -341,13 +371,18 @@ class TestQueue(_queue.Queue, Generic[T]): def __init__(self): """Initialize test queue.""" self.num_tests = 0 - self.max_test_queue_size = utils.default_if_none(_config.MAX_TEST_QUEUE_SIZE, -1) + self.max_test_queue_size = utils.default_if_none( + _config.MAX_TEST_QUEUE_SIZE, -1 + ) super().__init__() def add_test_cases(self, test_cases: List[QueueElem]) -> None: """Add test cases to the queue.""" for test_case in test_cases: - if self.max_test_queue_size < 0 or self.num_tests < self.max_test_queue_size: + if ( + self.max_test_queue_size < 0 + or self.num_tests < self.max_test_queue_size + ): self.put(test_case) self.num_tests += 1 else: diff --git a/buildscripts/resmokelib/testing/fixtures/_builder.py b/buildscripts/resmokelib/testing/fixtures/_builder.py index 7d5a55f27f2..e1edebf47ba 100644 --- a/buildscripts/resmokelib/testing/fixtures/_builder.py +++ b/buildscripts/resmokelib/testing/fixtures/_builder.py @@ -1,4 +1,5 @@ """Utilities for constructing fixtures that may span multiple versions.""" + import logging import threading from abc import ABC, abstractmethod @@ -29,7 +30,9 @@ RETRIEVE_LOCK = threading.Lock() _BUILDERS = {} # type: ignore -def make_fixture(class_name, logger, job_num, *args, enable_feature_flags=True, **kwargs): +def make_fixture( + class_name, logger, job_num, *args, enable_feature_flags=True, **kwargs +): """Provide factory function for creating Fixture instances.""" fixturelib = FixtureLib() @@ -44,8 +47,14 @@ def make_fixture(class_name, logger, job_num, *args, enable_feature_flags=True, # Special case MongoDFixture or _MongosFixture for now since we only add one option. # If there's more logic, we should add a builder class for them. if class_name in ["MongoDFixture", "_MongoSFixture"] and enable_feature_flags: - return _FIXTURES[class_name](logger, job_num, fixturelib, *args, - add_feature_flags=bool(config.ENABLED_FEATURE_FLAGS), **kwargs) + return _FIXTURES[class_name]( + logger, + job_num, + fixturelib, + *args, + add_feature_flags=bool(config.ENABLED_FEATURE_FLAGS), + **kwargs, + ) return _FIXTURES[class_name](logger, job_num, fixturelib, *args, **kwargs) @@ -70,7 +79,9 @@ def make_dummy_fixture(suite_name): return make_fixture(fixture_class, fixture_logger, job_num=0, **fixture_config) -class FixtureBuilder(ABC, metaclass=registry.make_registry_metaclass(_BUILDERS, type(ABC))): # pylint: disable=invalid-metaclass +class FixtureBuilder( + ABC, metaclass=registry.make_registry_metaclass(_BUILDERS, type(ABC)) +): # pylint: disable=invalid-metaclass """ ABC for fixture builders. @@ -82,7 +93,9 @@ class FixtureBuilder(ABC, metaclass=registry.make_registry_metaclass(_BUILDERS, REGISTERED_NAME = "Builder" @abstractmethod - def build_fixture(self, logger, job_num, fixturelib, *args, existing_nodes=None, **kwargs): + def build_fixture( + self, logger, job_num, fixturelib, *args, existing_nodes=None, **kwargs + ): """Abstract method to build a fixture.""" return @@ -90,8 +103,8 @@ class FixtureBuilder(ABC, metaclass=registry.make_registry_metaclass(_BUILDERS, class BinVersionEnum(object): """Enumeration version types.""" - OLD = 'old' - NEW = 'new' + OLD = "old" + NEW = "new" class FixtureContainer(object): @@ -103,7 +116,10 @@ class FixtureContainer(object): """Initialize FixtureContainer.""" if old_fixture is not None: - self._fixtures = {BinVersionEnum.NEW: new_fixture, BinVersionEnum.OLD: old_fixture} + self._fixtures = { + BinVersionEnum.NEW: new_fixture, + BinVersionEnum.OLD: old_fixture, + } self.cur_version_cls = self._fixtures[cur_version] else: # No need to support dictionary of fixture classes if only a single version of @@ -140,7 +156,8 @@ class FixtureContainer(object): def _extract_multiversion_options( - kwargs: Dict[str, Any]) -> Tuple[Optional[List[str]], Optional[str]]: + kwargs: Dict[str, Any], +) -> Tuple[Optional[List[str]], Optional[str]]: """Pop multiversion options from kwargs dict and return them. :param kwargs: fixture kwargs @@ -168,9 +185,15 @@ class ReplSetBuilder(FixtureBuilder): REGISTERED_NAME = "ReplicaSetFixture" LATEST_MONGOD_CLASS = "MongoDFixture" - def build_fixture(self, logger: logging.Logger, job_num: int, fixturelib: Type[FixtureLib], - *args, existing_nodes: Optional[List[MongoDFixture]] = None, - **kwargs) -> ReplicaSetFixture: + def build_fixture( + self, + logger: logging.Logger, + job_num: int, + fixturelib: Type[FixtureLib], + *args, + existing_nodes: Optional[List[MongoDFixture]] = None, + **kwargs, + ) -> ReplicaSetFixture: """Build a replica set. :param logger: fixture logger @@ -184,10 +207,13 @@ class ReplSetBuilder(FixtureBuilder): self._mutate_kwargs(kwargs) mixed_bin_versions, old_bin_version = _extract_multiversion_options(kwargs) self._validate_multiversion_options(kwargs, mixed_bin_versions) - mongod_class, mongod_executables, mongod_binary_versions = self._get_mongod_assets( - kwargs, mixed_bin_versions, old_bin_version) + mongod_class, mongod_executables, mongod_binary_versions = ( + self._get_mongod_assets(kwargs, mixed_bin_versions, old_bin_version) + ) - replset = _FIXTURES[self.REGISTERED_NAME](logger, job_num, fixturelib, *args, **kwargs) + replset = _FIXTURES[self.REGISTERED_NAME]( + logger, job_num, fixturelib, *args, **kwargs + ) is_multiversion = mixed_bin_versions is not None fcv = self._get_fcv(is_multiversion, old_bin_version) @@ -202,17 +228,29 @@ class ReplSetBuilder(FixtureBuilder): return replset for node_index in range(replset.num_nodes): - node = self._new_mongod(replset, node_index, mongod_executables, mongod_class, - mongod_binary_versions[node_index], is_multiversion, - launch_mongot) + node = self._new_mongod( + replset, + node_index, + mongod_executables, + mongod_class, + mongod_binary_versions[node_index], + is_multiversion, + launch_mongot, + ) replset.install_mongod(node) if replset.start_initial_sync_node: if not replset.initial_sync_node: replset.initial_sync_node_idx = replset.num_nodes replset.initial_sync_node = self._new_mongod( - replset, replset.initial_sync_node_idx, mongod_executables, mongod_class, - BinVersionEnum.NEW, is_multiversion, launch_mongot) + replset, + replset.initial_sync_node_idx, + mongod_executables, + mongod_class, + BinVersionEnum.NEW, + is_multiversion, + launch_mongot, + ) return replset @@ -227,13 +265,16 @@ class ReplSetBuilder(FixtureBuilder): kwargs["num_nodes"] = num_nodes mongod_executable = default_if_none( - kwargs.get("mongod_executable"), config.MONGOD_EXECUTABLE, - config.DEFAULT_MONGOD_EXECUTABLE) + kwargs.get("mongod_executable"), + config.MONGOD_EXECUTABLE, + config.DEFAULT_MONGOD_EXECUTABLE, + ) kwargs["mongod_executable"] = mongod_executable @staticmethod - def _validate_multiversion_options(kwargs: Dict[str, Any], - mixed_bin_versions: Optional[List[str]]) -> None: + def _validate_multiversion_options( + kwargs: Dict[str, Any], mixed_bin_versions: Optional[List[str]] + ) -> None: """Error out if the number of binary versions does not match the number of nodes in replica set. :param kwargs: sharded cluster fixture kwargs @@ -242,18 +283,25 @@ class ReplSetBuilder(FixtureBuilder): if mixed_bin_versions is not None: num_versions = len(mixed_bin_versions) replset_config_options = kwargs.get("replset_config_options", {}) - is_config_svr = "configsvr" in replset_config_options and replset_config_options[ - "configsvr"] + is_config_svr = ( + "configsvr" in replset_config_options + and replset_config_options["configsvr"] + ) if num_versions != kwargs["num_nodes"] and not is_config_svr: - msg = ("The number of binary versions specified: {} do not match the number of" - " nodes in the replica set: {}.").format(num_versions, kwargs["num_nodes"]) + msg = ( + "The number of binary versions specified: {} do not match the number of" + " nodes in the replica set: {}." + ).format(num_versions, kwargs["num_nodes"]) raise errors.ServerFailure(msg) @classmethod def _get_mongod_assets( - cls, kwargs: Dict[str, Any], mixed_bin_versions: Optional[List[str]], - old_bin_version: Optional[str]) -> Tuple[Dict[str, str], Dict[str, str], List[str]]: + cls, + kwargs: Dict[str, Any], + mixed_bin_versions: Optional[List[str]], + old_bin_version: Optional[str], + ) -> Tuple[Dict[str, str], Dict[str, str], List[str]]: """Make dicts with mongod new/old class and executable names and binary versions. :param kwargs: sharded cluster fixture kwargs @@ -272,10 +320,8 @@ class ReplSetBuilder(FixtureBuilder): from buildscripts.resmokelib import multiversionconstants old_mongod_version = { - config.MultiversionOptions.LAST_LTS: - multiversionconstants.LAST_LTS_MONGOD_BINARY, - config.MultiversionOptions.LAST_CONTINUOUS: - multiversionconstants.LAST_CONTINUOUS_MONGOD_BINARY, + config.MultiversionOptions.LAST_LTS: multiversionconstants.LAST_LTS_MONGOD_BINARY, + config.MultiversionOptions.LAST_CONTINUOUS: multiversionconstants.LAST_CONTINUOUS_MONGOD_BINARY, }[old_bin_version] executables[BinVersionEnum.OLD] = old_mongod_version @@ -296,18 +342,22 @@ class ReplSetBuilder(FixtureBuilder): fcv = multiversionconstants.LATEST_FCV if is_multiversion: fcv = { - config.MultiversionOptions.LAST_LTS: - multiversionconstants.LAST_LTS_FCV, - config.MultiversionOptions.LAST_CONTINUOUS: - multiversionconstants.LAST_CONTINUOUS_FCV, + config.MultiversionOptions.LAST_LTS: multiversionconstants.LAST_LTS_FCV, + config.MultiversionOptions.LAST_CONTINUOUS: multiversionconstants.LAST_CONTINUOUS_FCV, }[old_bin_version] return fcv @staticmethod - def _new_mongod(replset: ReplicaSetFixture, replset_node_index: int, - executables: Dict[str, str], _class: str, cur_version: str, - is_multiversion: bool, launch_mongot: bool) -> FixtureContainer: + def _new_mongod( + replset: ReplicaSetFixture, + replset_node_index: int, + executables: Dict[str, str], + _class: str, + cur_version: str, + is_multiversion: bool, + launch_mongot: bool, + ) -> FixtureContainer: """Make a fixture container with configured mongod fixture(s) in it. In non-multiversion mode only a new mongod fixture will be in the fixture container. @@ -330,9 +380,14 @@ class ReplSetBuilder(FixtureBuilder): if is_multiversion: # We do not run old versions with feature flags enabled old_fixture = make_fixture( - _class, mongod_logger, replset.job_num, enable_feature_flags=False, - mongod_executable=executables[BinVersionEnum.OLD], mongod_options=mongod_options, - preserve_dbpath=replset.preserve_dbpath) + _class, + mongod_logger, + replset.job_num, + enable_feature_flags=False, + mongod_executable=executables[BinVersionEnum.OLD], + mongod_options=mongod_options, + preserve_dbpath=replset.preserve_dbpath, + ) # Assign the same port for old and new fixtures so upgrade/downgrade can be done without # changing the replicaset config. @@ -344,11 +399,16 @@ class ReplSetBuilder(FixtureBuilder): if is_multiversion: new_fixture_mongod_options["upgradeBackCompat"] = "" - new_fixture = make_fixture(_class, mongod_logger, replset.job_num, - mongod_executable=executables[BinVersionEnum.NEW], - mongod_options=new_fixture_mongod_options, - preserve_dbpath=replset.preserve_dbpath, port=new_fixture_port, - launch_mongot=launch_mongot) + new_fixture = make_fixture( + _class, + mongod_logger, + replset.job_num, + mongod_executable=executables[BinVersionEnum.NEW], + mongod_options=new_fixture_mongod_options, + preserve_dbpath=replset.preserve_dbpath, + port=new_fixture_port, + launch_mongot=launch_mongot, + ) return FixtureContainer(new_fixture, old_fixture, cur_version) @@ -359,7 +419,7 @@ def get_package_name(dir_path: str) -> str: :param dir_path: relative directory path :return: python package name """ - return dir_path.replace('/', '.').replace("\\", ".") + return dir_path.replace("/", ".").replace("\\", ".") class ShardedClusterBuilder(FixtureBuilder): @@ -368,8 +428,14 @@ class ShardedClusterBuilder(FixtureBuilder): REGISTERED_NAME = "ShardedClusterFixture" LATEST_MONGOS_CLASS = "_MongoSFixture" - def build_fixture(self, logger: logging.Logger, job_num: int, fixturelib: Type[FixtureLib], - *args, **kwargs) -> ShardedClusterFixture: + def build_fixture( + self, + logger: logging.Logger, + job_num: int, + fixturelib: Type[FixtureLib], + *args, + **kwargs, + ) -> ShardedClusterFixture: """Build a sharded cluster. :param logger: fixture logger @@ -382,13 +448,17 @@ class ShardedClusterBuilder(FixtureBuilder): is_multiversion = mixed_bin_versions is not None is_config_shard = kwargs["config_shard"] is not None self._validate_multiversion_options(kwargs, mixed_bin_versions) - self._validate_embedded_router_mode_options(kwargs, is_config_shard, is_multiversion) + self._validate_embedded_router_mode_options( + kwargs, is_config_shard, is_multiversion + ) - mongos_class, mongos_executables = self._get_mongos_assets(kwargs, mixed_bin_versions, - old_bin_version) + mongos_class, mongos_executables = self._get_mongos_assets( + kwargs, mixed_bin_versions, old_bin_version + ) - sharded_cluster = _FIXTURES[self.REGISTERED_NAME](logger, job_num, fixturelib, *args, - **kwargs) + sharded_cluster = _FIXTURES[self.REGISTERED_NAME]( + logger, job_num, fixturelib, *args, **kwargs + ) config_shard = kwargs["config_shard"] config_svr = None @@ -397,11 +467,18 @@ class ShardedClusterBuilder(FixtureBuilder): # currently hold collection data, a mongot enabled shared cluster doesn't couple/launch # the config server with an accompanying mongot if config_shard is None: - config_svr = self._new_configsvr(sharded_cluster, is_multiversion, old_bin_version) + config_svr = self._new_configsvr( + sharded_cluster, is_multiversion, old_bin_version + ) else: - config_svr = self._new_rs_shard(sharded_cluster, mixed_bin_versions, old_bin_version, - config_shard, kwargs["num_rs_nodes_per_shard"], - launch_mongot=False) + config_svr = self._new_rs_shard( + sharded_cluster, + mixed_bin_versions, + old_bin_version, + config_shard, + kwargs["num_rs_nodes_per_shard"], + launch_mongot=False, + ) sharded_cluster.install_configsvr(config_svr) # Persist a list of all nodes from the cluster with a boolean that indicates if that node @@ -411,9 +488,14 @@ class ShardedClusterBuilder(FixtureBuilder): launch_mongot = kwargs.get("launch_mongot") for rs_shard_index in range(kwargs["num_shards"]): if rs_shard_index != config_shard: - rs_shard = self._new_rs_shard(sharded_cluster, mixed_bin_versions, old_bin_version, - rs_shard_index, kwargs["num_rs_nodes_per_shard"], - launch_mongot) + rs_shard = self._new_rs_shard( + sharded_cluster, + mixed_bin_versions, + old_bin_version, + rs_shard_index, + kwargs["num_rs_nodes_per_shard"], + launch_mongot, + ) sharded_cluster.install_rs_shard(rs_shard) # Extend the list of nodes to be sure configsvr nodes are placed at first places. nodes.extend([(node, False) for node in rs_shard._all_mongo_d_s_t()]) @@ -424,13 +506,20 @@ class ShardedClusterBuilder(FixtureBuilder): def install_router(): if not kwargs.get("embedded_router", None): - mongos = self._new_mongos(sharded_cluster, mongos_executables, mongos_class, - mongos_index, num_routers, is_multiversion) + mongos = self._new_mongos( + sharded_cluster, + mongos_executables, + mongos_class, + mongos_index, + num_routers, + is_multiversion, + ) sharded_cluster.install_mongos(mongos) else: node = nodes.pop(0) - router_view = self._new_router_view(sharded_cluster, mongos_index, num_routers, - node[0], node[1]) + router_view = self._new_router_view( + sharded_cluster, mongos_index, num_routers, node[0], node[1] + ) sharded_cluster.install_mongos(router_view) for mongos_index in range(num_routers): @@ -449,29 +538,42 @@ class ShardedClusterBuilder(FixtureBuilder): kwargs["num_shards"] = num_shards num_rs_nodes_per_shard = kwargs.pop("num_rs_nodes_per_shard", 1) - num_rs_nodes_per_shard = num_rs_nodes_per_shard if not config.NUM_REPLSET_NODES else config.NUM_REPLSET_NODES + num_rs_nodes_per_shard = ( + num_rs_nodes_per_shard + if not config.NUM_REPLSET_NODES + else config.NUM_REPLSET_NODES + ) kwargs["num_rs_nodes_per_shard"] = num_rs_nodes_per_shard num_mongos = kwargs.pop("num_mongos", 1) kwargs["num_mongos"] = num_mongos mongos_executable = default_if_none( - kwargs.get("mongos_executable"), config.MONGOS_EXECUTABLE, - config.DEFAULT_MONGOS_EXECUTABLE) + kwargs.get("mongos_executable"), + config.MONGOS_EXECUTABLE, + config.DEFAULT_MONGOS_EXECUTABLE, + ) kwargs["mongos_executable"] = mongos_executable config_shard = pick_catalog_shard_node( - kwargs.pop("config_shard", config.CONFIG_SHARD), num_shards) + kwargs.pop("config_shard", config.CONFIG_SHARD), num_shards + ) # Currently the auto_boostrap_procedure requires us to have a config_shard - if "use_auto_bootstrap_procedure" in kwargs and kwargs[ - "use_auto_bootstrap_procedure"] and not config_shard: + if ( + "use_auto_bootstrap_procedure" in kwargs + and kwargs["use_auto_bootstrap_procedure"] + and not config_shard + ): config_shard = 0 - kwargs["embedded_router"] = kwargs.pop("embedded_router", config.EMBEDDED_ROUTER) + kwargs["embedded_router"] = kwargs.pop( + "embedded_router", config.EMBEDDED_ROUTER + ) kwargs["config_shard"] = config_shard @staticmethod - def _validate_multiversion_options(kwargs: Dict[str, Any], - mixed_bin_versions: Optional[List[str]]) -> None: + def _validate_multiversion_options( + kwargs: Dict[str, Any], mixed_bin_versions: Optional[List[str]] + ) -> None: """Error out if the number of binary versions does not match the number of nodes in sharded cluster. :param kwargs: sharded cluster fixture kwargs @@ -482,13 +584,16 @@ class ShardedClusterBuilder(FixtureBuilder): num_mongods = kwargs["num_shards"] * kwargs["num_rs_nodes_per_shard"] if len_versions != num_mongods: - msg = ("The number of binary versions specified: {} do not match the number of" - " nodes in the sharded cluster: {}.").format(len_versions, num_mongods) + msg = ( + "The number of binary versions specified: {} do not match the number of" + " nodes in the sharded cluster: {}." + ).format(len_versions, num_mongods) raise errors.ServerFailure(msg) @staticmethod - def _validate_embedded_router_mode_options(kwargs: Dict[str, Any], is_config_shard: bool, - is_multiversion: bool) -> None: + def _validate_embedded_router_mode_options( + kwargs: Dict[str, Any], is_config_shard: bool, is_multiversion: bool + ) -> None: """Raise an exception if the configuration for the sharded cluster can't support embedded_router_mode. :param kwargs: sharded cluster fixture kwargs. @@ -504,7 +609,10 @@ class ShardedClusterBuilder(FixtureBuilder): # Add the configsvr as a mongos if it is not already counted as a config shard. if not is_config_shard: num_configsvr_nodes = 1 - if "configsvr_options" in kwargs and "num_nodes" in kwargs["configsvr_options"]: + if ( + "configsvr_options" in kwargs + and "num_nodes" in kwargs["configsvr_options"] + ): num_configsvr_nodes = kwargs["configsvr_options"]["num_nodes"] num_routers += num_configsvr_nodes @@ -513,11 +621,17 @@ class ShardedClusterBuilder(FixtureBuilder): "When running in embedded router mode, num_mongos must be <= the total number of shardsvrs in the cluster." ) if is_multiversion: - raise ValueError("Embedded router mode does not support multiversion testing.") + raise ValueError( + "Embedded router mode does not support multiversion testing." + ) @classmethod - def _get_mongos_assets(cls, kwargs: Dict[str, Any], mixed_bin_versions: Optional[List[str]], - old_bin_version: Optional[str]) -> Tuple[Dict[str, str], Dict[str, str]]: + def _get_mongos_assets( + cls, + kwargs: Dict[str, Any], + mixed_bin_versions: Optional[List[str]], + old_bin_version: Optional[str], + ) -> Tuple[Dict[str, str], Dict[str, str]]: """Make dicts with mongos new/old class and executable names. :param kwargs: sharded cluster fixture kwargs @@ -533,18 +647,19 @@ class ShardedClusterBuilder(FixtureBuilder): from buildscripts.resmokelib import multiversionconstants old_mongos_version = { - config.MultiversionOptions.LAST_LTS: - multiversionconstants.LAST_LTS_MONGOS_BINARY, - config.MultiversionOptions.LAST_CONTINUOUS: - multiversionconstants.LAST_CONTINUOUS_MONGOS_BINARY, + config.MultiversionOptions.LAST_LTS: multiversionconstants.LAST_LTS_MONGOS_BINARY, + config.MultiversionOptions.LAST_CONTINUOUS: multiversionconstants.LAST_CONTINUOUS_MONGOS_BINARY, }[old_bin_version] executables[BinVersionEnum.OLD] = old_mongos_version return _class, executables @staticmethod - def _new_configsvr(sharded_cluster: ShardedClusterFixture, is_multiversion: bool, - old_bin_version: Optional[str]) -> ReplicaSetFixture: + def _new_configsvr( + sharded_cluster: ShardedClusterFixture, + is_multiversion: bool, + old_bin_version: Optional[str], + ) -> ReplicaSetFixture: """Return a replica set fixture configured as the config server. :param sharded_cluster: sharded cluster fixture we are configuring config server for @@ -562,15 +677,24 @@ class ShardedClusterBuilder(FixtureBuilder): # server nodes will always be fully upgraded before shard nodes. mixed_bin_versions = [BinVersionEnum.NEW] * 2 - return make_fixture("ReplicaSetFixture", configsvr_logger, sharded_cluster.job_num, - mixed_bin_versions=mixed_bin_versions, old_bin_version=old_bin_version, - **configsvr_kwargs) + return make_fixture( + "ReplicaSetFixture", + configsvr_logger, + sharded_cluster.job_num, + mixed_bin_versions=mixed_bin_versions, + old_bin_version=old_bin_version, + **configsvr_kwargs, + ) @staticmethod - def _new_rs_shard(sharded_cluster: ShardedClusterFixture, - mixed_bin_versions: Optional[List[str]], old_bin_version: Optional[str], - rs_shard_index: int, num_rs_nodes_per_shard: int, - launch_mongot: bool) -> ReplicaSetFixture: + def _new_rs_shard( + sharded_cluster: ShardedClusterFixture, + mixed_bin_versions: Optional[List[str]], + old_bin_version: Optional[str], + rs_shard_index: int, + num_rs_nodes_per_shard: int, + launch_mongot: bool, + ) -> ReplicaSetFixture: """Return a replica set fixture configured as a shard in a sharded cluster. :param sharded_cluster: sharded cluster fixture we are configuring config server for @@ -587,17 +711,29 @@ class ShardedClusterBuilder(FixtureBuilder): if mixed_bin_versions is not None: start_index = rs_shard_index * num_rs_nodes_per_shard - mixed_bin_versions = mixed_bin_versions[start_index:start_index + - num_rs_nodes_per_shard] + mixed_bin_versions = mixed_bin_versions[ + start_index : start_index + num_rs_nodes_per_shard + ] - return make_fixture("ReplicaSetFixture", rs_shard_logger, sharded_cluster.job_num, - num_nodes=num_rs_nodes_per_shard, mixed_bin_versions=mixed_bin_versions, - old_bin_version=old_bin_version, **rs_shard_kwargs) + return make_fixture( + "ReplicaSetFixture", + rs_shard_logger, + sharded_cluster.job_num, + num_nodes=num_rs_nodes_per_shard, + mixed_bin_versions=mixed_bin_versions, + old_bin_version=old_bin_version, + **rs_shard_kwargs, + ) @staticmethod - def _new_mongos(sharded_cluster: ShardedClusterFixture, executables: Dict[str, str], - _class: str, mongos_index: int, total: int, - is_multiversion: bool) -> FixtureContainer: + def _new_mongos( + sharded_cluster: ShardedClusterFixture, + executables: Dict[str, str], + _class: str, + mongos_index: int, + total: int, + is_multiversion: bool, + ) -> FixtureContainer: """Make a fixture container with configured mongos fixture(s) in it. In non-multiversion mode only a new mongos fixture will be in the fixture container. @@ -620,8 +756,13 @@ class ShardedClusterBuilder(FixtureBuilder): if is_multiversion: # We do not run old versions with feature flags enabled old_fixture = make_fixture( - _class, mongos_logger, sharded_cluster.job_num, enable_feature_flags=False, - mongos_executable=executables[BinVersionEnum.OLD], **mongos_kwargs) + _class, + mongos_logger, + sharded_cluster.job_num, + enable_feature_flags=False, + mongos_executable=executables[BinVersionEnum.OLD], + **mongos_kwargs, + ) # We can't restart mongos since explicit ports are not supported. new_fixture_mongos_kwargs = sharded_cluster.get_mongos_kwargs() @@ -630,23 +771,40 @@ class ShardedClusterBuilder(FixtureBuilder): if is_multiversion: new_fixture_mongos_kwargs["mongos_options"]["upgradeBackCompat"] = "" - new_fixture = make_fixture(_class, mongos_logger, sharded_cluster.job_num, - mongos_executable=executables[BinVersionEnum.NEW], - **new_fixture_mongos_kwargs) + new_fixture = make_fixture( + _class, + mongos_logger, + sharded_cluster.job_num, + mongos_executable=executables[BinVersionEnum.NEW], + **new_fixture_mongos_kwargs, + ) # Always spin up an old mongos if in multiversion mode given mongos is the last thing in the update path. - return FixtureContainer(new_fixture, old_fixture, - BinVersionEnum.OLD if is_multiversion else BinVersionEnum.NEW) + return FixtureContainer( + new_fixture, + old_fixture, + BinVersionEnum.OLD if is_multiversion else BinVersionEnum.NEW, + ) @staticmethod - def _new_router_view(sharded_cluster: ShardedClusterFixture, mongos_index: int, total: int, - mongod: MongoDFixture, is_configsvr: bool) -> _RouterView: + def _new_router_view( + sharded_cluster: ShardedClusterFixture, + mongos_index: int, + total: int, + mongod: MongoDFixture, + is_configsvr: bool, + ) -> _RouterView: """Make a fixture that allows ShardedClusterFixture to treat a shardsvr as a router.""" router_logger = sharded_cluster.get_mongos_logger(mongos_index, total) router_kwargs = {} router_kwargs["mongod"] = mongod - fix = make_fixture("_RouterView", router_logger, sharded_cluster.job_num, is_configsvr, - **router_kwargs) + fix = make_fixture( + "_RouterView", + router_logger, + sharded_cluster.job_num, + is_configsvr, + **router_kwargs, + ) return fix diff --git a/buildscripts/resmokelib/testing/fixtures/bulk_write.py b/buildscripts/resmokelib/testing/fixtures/bulk_write.py index a8fde22ef53..8fb130cacb2 100644 --- a/buildscripts/resmokelib/testing/fixtures/bulk_write.py +++ b/buildscripts/resmokelib/testing/fixtures/bulk_write.py @@ -8,40 +8,59 @@ from buildscripts.resmokelib.testing.fixtures import interface class BulkWriteFixture(interface.MultiClusterFixture): """Fixture which provides JSTests with a set of clusters to run tests against.""" - def __init__(self, logger, job_num, fixturelib, cluster_options, dbpath_prefix=None, - preserve_dbpath=False, requires_auth=False): + def __init__( + self, + logger, + job_num, + fixturelib, + cluster_options, + dbpath_prefix=None, + preserve_dbpath=False, + requires_auth=False, + ): """Initialize BulkWriteFixture with different options.""" - interface.MultiClusterFixture.__init__(self, logger, job_num, fixturelib, - dbpath_prefix=dbpath_prefix) + interface.MultiClusterFixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix + ) self.setup_complete = False self.clusters = [] # cluster_options will be used for the bulkWrite cluster. - cluster_options["settings"] = self.fixturelib.default_if_none(cluster_options["settings"], - {}) - if "preserve_dbpath" not in cluster_options["settings"]\ - or cluster_options["settings"]["preserve_dbpath"] is None: + cluster_options["settings"] = self.fixturelib.default_if_none( + cluster_options["settings"], {} + ) + if ( + "preserve_dbpath" not in cluster_options["settings"] + or cluster_options["settings"]["preserve_dbpath"] is None + ): cluster_options["settings"]["preserve_dbpath"] = preserve_dbpath # The "dbpath_prefix" needs to be under "settings" for replicasets # but also under "mongod_options" for sharded clusters. - cluster_options["settings"]["dbpath_prefix"] = os.path.join(self._dbpath_prefix, - "bulkWriteCluster") + cluster_options["settings"]["dbpath_prefix"] = os.path.join( + self._dbpath_prefix, "bulkWriteCluster" + ) if cluster_options["class"] == "ReplicaSetFixture": cluster_options["settings"]["replicaset_logging_prefix"] = "bw" cluster_options["settings"]["dbpath_prefix"] = os.path.join( - self._dbpath_prefix, "bulkWriteCluster") + self._dbpath_prefix, "bulkWriteCluster" + ) elif cluster_options["class"] == "ShardedClusterFixture": cluster_options["settings"]["cluster_logging_prefix"] = "bw" else: raise ValueError(f"Illegal fixture class: {cluster_options['class']}") self.clusters.append( - self.fixturelib.make_fixture(cluster_options["class"], self.logger, self.job_num, - **cluster_options["settings"])) + self.fixturelib.make_fixture( + cluster_options["class"], + self.logger, + self.job_num, + **cluster_options["settings"], + ) + ) # The cluster where normal writes will be executed has set options. normal_cluster_options = {} @@ -49,17 +68,25 @@ class BulkWriteFixture(interface.MultiClusterFixture): normal_cluster_options["settings"]["mongod_options"] = {} normal_cluster_options["settings"]["mongod_options"]["set_parameters"] = {} normal_cluster_options["settings"]["mongod_options"]["set_parameters"][ - "enableTestCommands"] = 1 + "enableTestCommands" + ] = 1 normal_cluster_options["settings"]["num_nodes"] = 1 normal_cluster_options["settings"]["use_replica_set_connection_string"] = True normal_cluster_options["settings"]["replicaset_logging_prefix"] = "nc" normal_cluster_options["settings"]["dbpath_prefix"] = os.path.join( - self._dbpath_prefix, "normalCluster") + self._dbpath_prefix, "normalCluster" + ) self.clusters.append( - self.fixturelib.make_fixture("ReplicaSetFixture", self.logger, self.job_num, - **normal_cluster_options["settings"], replset_name="rs1")) + self.fixturelib.make_fixture( + "ReplicaSetFixture", + self.logger, + self.job_num, + **normal_cluster_options["settings"], + replset_name="rs1", + ) + ) def pids(self): """Return: pids owned by this fixture if any.""" @@ -67,7 +94,9 @@ class BulkWriteFixture(interface.MultiClusterFixture): for cluster in self.clusters: out.extend(cluster.pids()) if not out: - self.logger.debug('No clusters when gathering multi replicaset fixture pids.') + self.logger.debug( + "No clusters when gathering multi replicaset fixture pids." + ) return out def setup(self): @@ -88,7 +117,9 @@ class BulkWriteFixture(interface.MultiClusterFixture): running_at_start = self.is_running() if not running_at_start: - self.logger.warning("All clusters were expected to be running, but weren't.") + self.logger.warning( + "All clusters were expected to be running, but weren't." + ) teardown_handler = interface.FixtureTeardownHandler(self.logger) @@ -109,13 +140,17 @@ class BulkWriteFixture(interface.MultiClusterFixture): def get_internal_connection_string(self): """Return the internal connection string to the replica set that currently starts out owning the data.""" if not self.setup_complete: - raise ValueError("Must call setup() before calling get_internal_connection_string()") + raise ValueError( + "Must call setup() before calling get_internal_connection_string()" + ) return self.clusters[0].get_internal_connection_string() def get_driver_connection_url(self): """Return the driver connection URL to the replica set that currently starts out owning the data.""" if not self.setup_complete: - raise ValueError("Must call setup() before calling get_driver_connection_url") + raise ValueError( + "Must call setup() before calling get_driver_connection_url" + ) return self.clusters[0].get_driver_connection_url() def get_node_info(self): diff --git a/buildscripts/resmokelib/testing/fixtures/external.py b/buildscripts/resmokelib/testing/fixtures/external.py index e15776b667e..7905f7fa168 100644 --- a/buildscripts/resmokelib/testing/fixtures/external.py +++ b/buildscripts/resmokelib/testing/fixtures/external.py @@ -15,8 +15,10 @@ class ExternalFixture(interface.Fixture): interface.Fixture.__init__(self, logger, job_num, fixturelib) if shell_conn_string is None: - raise ValueError("The ExternalFixture must be specified with the resmoke option" - " --shellConnString or --shellPort") + raise ValueError( + "The ExternalFixture must be specified with the resmoke option" + " --shellConnString or --shellPort" + ) self.shell_conn_string = shell_conn_string @@ -25,7 +27,9 @@ class ExternalFixture(interface.Fixture): # Reconfiguring the external fixture isn't supported so there's no reason to attempt to # parse the mongodb:// connection string the user specified via the command line into the # internal format used by the server. - raise NotImplementedError("ExternalFixture can only be used with a MongoDB connection URI") + raise NotImplementedError( + "ExternalFixture can only be used with a MongoDB connection URI" + ) def get_driver_connection_url(self): """Return the driver connection URL.""" diff --git a/buildscripts/resmokelib/testing/fixtures/fixturelib.py b/buildscripts/resmokelib/testing/fixtures/fixturelib.py index 6ca81e331c4..10f97dcf449 100644 --- a/buildscripts/resmokelib/testing/fixtures/fixturelib.py +++ b/buildscripts/resmokelib/testing/fixtures/fixturelib.py @@ -1,4 +1,5 @@ """Facade wrapping the resmokelib dependencies used by fixtures.""" + from logging import Handler, Logger from typing import Dict @@ -30,7 +31,9 @@ class FixtureLib: def new_fixture_node_logger(self, fixture_class, job_num, node_name): """Create a logger for a particular element in a multi-process fixture.""" - return logging.loggers.new_fixture_node_logger(fixture_class, job_num, node_name) + return logging.loggers.new_fixture_node_logger( + fixture_class, job_num, node_name + ) ############ # Programs # @@ -40,7 +43,9 @@ class FixtureLib: """Build fixtures by calling builder API.""" return _builder.make_fixture(class_name, logger, job_num, *args, **kwargs) - def mongod_program(self, logger, job_num, executable, process_kwargs, mongod_options): + def mongod_program( + self, logger, job_num, executable, process_kwargs, mongod_options + ): """ Return a Process instance that starts mongod arguments constructed from 'mongod_options'. @@ -49,20 +54,25 @@ class FixtureLib: @param process_kwargs - A dict of key-value pairs to pass to the process. @param mongod_options - A HistoryDict describing the various options to pass to the mongod. """ - return core.programs.mongod_program(logger, job_num, executable, process_kwargs, - mongod_options) + return core.programs.mongod_program( + logger, job_num, executable, process_kwargs, mongod_options + ) - def mongos_program(self, logger, job_num, executable=None, process_kwargs=None, - mongos_options=None): + def mongos_program( + self, logger, job_num, executable=None, process_kwargs=None, mongos_options=None + ): """Return a Process instance that starts a mongos with arguments constructed from 'kwargs'.""" - return core.programs.mongos_program(logger, job_num, executable, process_kwargs, - mongos_options) + return core.programs.mongos_program( + logger, job_num, executable, process_kwargs, mongos_options + ) - def mongot_program(self, logger, job_num, executable=None, process_kwargs=None, - mongot_options=None): + def mongot_program( + self, logger, job_num, executable=None, process_kwargs=None, mongot_options=None + ): """Return a Process instance that starts a mongot with arguments constructed from 'kwargs'.""" - return core.programs.mongot_program(logger, job_num, executable, process_kwargs, - mongot_options) + return core.programs.mongot_program( + logger, job_num, executable, process_kwargs, mongot_options + ) def generic_program(self, logger, args, process_kwargs=None, **kwargs): """Return a Process instance that starts an arbitrary executable. @@ -106,7 +116,9 @@ class FixtureLib: original_set_parameters = original.get(self.SET_PARAMETERS_KEY, {}) override_set_parameters = override.get(self.SET_PARAMETERS_KEY, {}) - merged_set_parameters = merge_dicts(original_set_parameters, override_set_parameters) + merged_set_parameters = merge_dicts( + original_set_parameters, override_set_parameters + ) original.update(override) original[self.SET_PARAMETERS_KEY] = merged_set_parameters diff --git a/buildscripts/resmokelib/testing/fixtures/interface.py b/buildscripts/resmokelib/testing/fixtures/interface.py index 691d052f5e5..0a7b360dccf 100644 --- a/buildscripts/resmokelib/testing/fixtures/interface.py +++ b/buildscripts/resmokelib/testing/fixtures/interface.py @@ -54,7 +54,9 @@ class APIVersion(object, metaclass=registry.make_registry_metaclass(_VERSIONS)): return int(version.split(".")[1]) expected = cls.FIXTURE_API_VERSION - return to_major(expected) == to_major(actual) and to_minor(expected) <= to_minor(actual) + return to_major(expected) == to_major(actual) and to_minor( + expected + ) <= to_minor(actual) _FIXTURES = {} # type: ignore @@ -91,8 +93,13 @@ class Fixture(object, metaclass=registry.make_registry_metaclass(_FIXTURES)): # AWAIT_READY_TIMEOUT_SECS = 300 - def __init__(self, logger: Logger, job_num: int, fixturelib: 'FixtureLib', - dbpath_prefix: str = None): + def __init__( + self, + logger: Logger, + job_num: int, + fixturelib: "FixtureLib", + dbpath_prefix: str = None, + ): """Initialize the fixture with a logger instance.""" self.fixturelib = fixturelib @@ -108,14 +115,19 @@ class Fixture(object, metaclass=registry.make_registry_metaclass(_FIXTURES)): # self.logger = logger self.job_num = job_num - dbpath_prefix = self.fixturelib.default_if_none(self.config.DBPATH_PREFIX, dbpath_prefix) - dbpath_prefix = self.fixturelib.default_if_none(dbpath_prefix, - self.config.DEFAULT_DBPATH_PREFIX) + dbpath_prefix = self.fixturelib.default_if_none( + self.config.DBPATH_PREFIX, dbpath_prefix + ) + dbpath_prefix = self.fixturelib.default_if_none( + dbpath_prefix, self.config.DEFAULT_DBPATH_PREFIX + ) self._dbpath_prefix = os.path.join(dbpath_prefix, "job{}".format(self.job_num)) def pids(self): """Return any pids owned by this fixture.""" - raise NotImplementedError("pids must be implemented by Fixture subclasses %s" % self) + raise NotImplementedError( + "pids must be implemented by Fixture subclasses %s" % self + ) def setup(self): """Create the fixture.""" @@ -181,12 +193,13 @@ class Fixture(object, metaclass=registry.make_registry_metaclass(_FIXTURES)): # expected by the mongo::ConnectionString class. """ raise NotImplementedError( - "get_internal_connection_string must be implemented by Fixture subclasses") + "get_internal_connection_string must be implemented by Fixture subclasses" + ) def get_shell_connection_url(self): """Return the connection string to be used by the mongo shell process executing a jstest. - Defaults to returning the driver connection url, but can be overriden to provide + Defaults to returning the driver connection url, but can be overriden to provide shell-specific options (such as using a gRPC port). https://docs.mongodb.com/manual/reference/connection-string/ """ @@ -198,10 +211,15 @@ class Fixture(object, metaclass=registry.make_registry_metaclass(_FIXTURES)): # https://docs.mongodb.com/manual/reference/connection-string/ """ raise NotImplementedError( - "get_driver_connection_url must be implemented by Fixture subclasses") + "get_driver_connection_url must be implemented by Fixture subclasses" + ) - def mongo_client(self, read_preference=pymongo.ReadPreference.PRIMARY, timeout_millis=30000, - **kwargs): + def mongo_client( + self, + read_preference=pymongo.ReadPreference.PRIMARY, + timeout_millis=30000, + **kwargs, + ): """Return a pymongo.MongoClient connecting to this fixture with specified 'read_preference'. The PyMongo driver will wait up to 'timeout_millis' milliseconds @@ -223,10 +241,15 @@ class Fixture(object, metaclass=registry.make_registry_metaclass(_FIXTURES)): # if self.config.TLS_CA_FILE: kwargs["tlsCAFile"] = self.config.TLS_CA_FILE if self.config.SHELL_TLS_CERTIFICATE_KEY_FILE: - kwargs["tlsCertificateKeyFile"] = self.config.SHELL_TLS_CERTIFICATE_KEY_FILE + kwargs["tlsCertificateKeyFile"] = ( + self.config.SHELL_TLS_CERTIFICATE_KEY_FILE + ) - return pymongo.MongoClient(host=self.get_driver_connection_url(), - read_preference=read_preference, **kwargs) + return pymongo.MongoClient( + host=self.get_driver_connection_url(), + read_preference=read_preference, + **kwargs, + ) def __str__(self): return "%s (Job #%d)" % (self.__class__.__name__, self.job_num) @@ -259,7 +282,7 @@ class _DockerComposeInterface: "_all_mongo_d_s_t_instances must be implemented by Fixture subclasses that support `docker-compose.yml` generation." ) - def all_processes(self) -> List['Process']: + def all_processes(self) -> List["Process"]: """ Return a list of all `mongo{d,s,t}` `Process` instances in the fixture. @@ -267,7 +290,8 @@ class _DockerComposeInterface: """ if not self.config.DOCKER_COMPOSE_BUILD_IMAGES: raise DockerComposeException( - "This method is reserved for `--dockerComposeBuildImages` only.") + "This method is reserved for `--dockerComposeBuildImages` only." + ) processes = [] @@ -301,7 +325,8 @@ class MultiClusterFixture(Fixture): def get_independent_clusters(self): """Return a list of the independent clusters (fixtures) that participate in this fixture.""" raise NotImplementedError( - "get_independent_clusters must be implemented by MultiClusterFixture subclasses") + "get_independent_clusters must be implemented by MultiClusterFixture subclasses" + ) class ReplFixture(Fixture): @@ -314,11 +339,15 @@ class ReplFixture(Fixture): def get_primary(self): """Return the primary of a replica set.""" - raise NotImplementedError("get_primary must be implemented by ReplFixture subclasses") + raise NotImplementedError( + "get_primary must be implemented by ReplFixture subclasses" + ) def get_secondaries(self): """Return a list containing the secondaries of a replica set.""" - raise NotImplementedError("get_secondaries must be implemented by ReplFixture subclasses") + raise NotImplementedError( + "get_secondaries must be implemented by ReplFixture subclasses" + ) def retry_until_wtimeout(self, insert_fn): """Retry until wtimeout reached. @@ -343,7 +372,9 @@ class ReplFixture(Fixture): remaining = deadline - time.time() if remaining <= 0.0: message = "Failed to connect to {} within {} minutes".format( - self.get_driver_connection_url(), ReplFixture.AWAIT_REPL_TIMEOUT_MINS) + self.get_driver_connection_url(), + ReplFixture.AWAIT_REPL_TIMEOUT_MINS, + ) self.logger.error(message) raise self.fixturelib.ServerFailure(message) except pymongo.errors.WTimeoutError: @@ -352,7 +383,8 @@ class ReplFixture(Fixture): raise self.fixturelib.ServerFailure(message) except pymongo.errors.PyMongoError as err: message = "Write operation on {} failed: {}".format( - self.get_driver_connection_url(), err) + self.get_driver_connection_url(), err + ) raise self.fixturelib.ServerFailure(message) @@ -447,12 +479,12 @@ def create_fixture_table(fixture): columns[key] = [] for node in info: attr = getattr(node, key) - str_value = str(attr) if attr is not None else '-' + str_value = str(attr) if attr is not None else "-" columns[key].append(str_value) longest[key] = max(longest[key], len(str_value)) # Filter out columns where no row has a value - columns = {k: v for k, v in columns.items() if not all(x == '-' for x in v)} + columns = {k: v for k, v in columns.items() if not all(x == "-" for x in v)} def horizontal_separator(): row = "" @@ -486,16 +518,21 @@ def create_fixture_table(fixture): return "Fixture status:\n" + table -def build_client(node, auth_options=None, read_preference=pymongo.ReadPreference.PRIMARY): +def build_client( + node, auth_options=None, read_preference=pymongo.ReadPreference.PRIMARY +): """Authenticate client for the 'authenticationDatabase' and return the client.""" if auth_options is not None: return node.mongo_client( - username=auth_options["username"], password=auth_options["password"], + username=auth_options["username"], + password=auth_options["password"], authSource=auth_options["authenticationDatabase"], - authMechanism=auth_options["authenticationMechanism"], read_preference=read_preference) + authMechanism=auth_options["authenticationMechanism"], + read_preference=read_preference, + ) else: return node.mongo_client(read_preference=read_preference) # Represents a row in a node info table. -NodeInfo = namedtuple('NodeInfo', ['full_name', 'name', 'port', 'pid', 'router_port']) +NodeInfo = namedtuple("NodeInfo", ["full_name", "name", "port", "pid", "router_port"]) diff --git a/buildscripts/resmokelib/testing/fixtures/mongot.py b/buildscripts/resmokelib/testing/fixtures/mongot.py index b132c6dac24..c9de64b00dc 100644 --- a/buildscripts/resmokelib/testing/fixtures/mongot.py +++ b/buildscripts/resmokelib/testing/fixtures/mongot.py @@ -2,10 +2,10 @@ Mongot is a MongoDB-specific process written as a wrapper around Lucene. Using Lucene, mongot indexes MongoDB databases to provide our customers with full text search capabilities. -Customers have the option of running mongot on Atlas or locally using a special "local-dev" binary of mongot. The local-dev binary allows mongot and mongod to speak directly on the localhost, rather than via proprietary network proxies configured by the Atlas Data Plane. +Customers have the option of running mongot on Atlas or locally using a special "local-dev" binary of mongot. The local-dev binary allows mongot and mongod to speak directly on the localhost, rather than via proprietary network proxies configured by the Atlas Data Plane. A resmoke suite's yml definition can enable launching mongot(s) enabled via the launch_mongot option on the ReplicaSetFixture and providing a keyfile. If enabled, the ReplicaSetFixture launches a local-dev version of mongot per mongod node. The mongot replicates directly from the co-located -mongod via a $changeStream. +mongod via a $changeStream. """ import shutil @@ -20,12 +20,17 @@ from buildscripts.resmokelib.testing.fixtures import interface class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): """Fixture which provides JSTests with a mongot to run alongside a mongod.""" - def __init__(self, logger, job_num, fixturelib, dbpath_prefix=None, mongot_options=None): + def __init__( + self, logger, job_num, fixturelib, dbpath_prefix=None, mongot_options=None + ): interface.Fixture.__init__(self, logger, job_num, fixturelib) self.mongot_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongot_options, {})) + self.fixturelib.default_if_none(mongot_options, {}) + ) # Default to command line options if the YAML configuration is not passed in. - self.mongot_executable = self.fixturelib.default_if_none(self.config.MONGOT_EXECUTABLE) + self.mongot_executable = self.fixturelib.default_if_none( + self.config.MONGOT_EXECUTABLE + ) self.port = self.mongot_options["port"] # Each mongot requires its own unique config journal to persist index definitions, replication status, etc to disk. # If dir passed to --data-dir option doesn't exist, mongot will create it @@ -38,9 +43,12 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): launcher = MongotLauncher(self.fixturelib) # Second return val is the port, which we ignore because we explicitly generated the port number in MongoDFixture # initialization and save to MongotFixture in above initialization function. - mongot, _ = launcher.launch_mongot_program(self.logger, self.job_num, - executable=self.mongot_executable, - mongot_options=self.mongot_options) + mongot, _ = launcher.launch_mongot_program( + self.logger, + self.job_num, + executable=self.mongot_executable, + mongot_options=self.mongot_options, + ) try: msg = f"Starting mongot on port { self.port } ...\n{ mongot.as_command() }" @@ -63,11 +71,10 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): """:return: pids owned by this fixture if any.""" out = [x.pid for x in [self.mongot] if x is not None] if not out: - self.logger.debug('Mongot not running when gathering mongot fixture pid.') + self.logger.debug("Mongot not running when gathering mongot fixture pid.") return out def _do_teardown(self, mode=None): - if self.config.NOOP_MONGO_D_S_PROCESSES: self.logger.info( "This is running against an External System Under Test setup with `docker-compose.yml` -- skipping teardown." @@ -81,14 +88,19 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): if mode == interface.TeardownMode.ABORT: self.logger.info( "Attempting to send SIGABRT from resmoke to mongot on port %d with pid %d...", - self.port, self.mongot.pid) + self.port, + self.mongot.pid, + ) else: - self.logger.info("Stopping mongot on port %d with pid %d...", self.port, - self.mongot.pid) + self.logger.info( + "Stopping mongot on port %d with pid %d...", self.port, self.mongot.pid + ) if not self.is_running(): exit_code = self.mongot.poll() - msg = ("mongot on port {:d} was expected to be running, but wasn't. " - "Process exited with code {:d}.").format(self.port, exit_code) + msg = ( + "mongot on port {:d} was expected to be running, but wasn't. " + "Process exited with code {:d}." + ).format(self.port, exit_code) self.logger.warning(msg) raise self.fixturelib.ServerFailure(msg) @@ -97,13 +109,19 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): # Java applications return exit code of 143 when they shut down upon receiving and obeying a SIGTERM signal, which is the desired/default mode. if exit_code == 143 or (mode is not None and exit_code == -(mode.value)): - self.logger.info("Successfully stopped the mongot on port {:d}.".format(self.port)) + self.logger.info( + "Successfully stopped the mongot on port {:d}.".format(self.port) + ) else: - self.logger.warning("Stopped the mongot on port {:d}. " - "Process exited with code {:d}.".format(self.port, exit_code)) + self.logger.warning( + "Stopped the mongot on port {:d}. " + "Process exited with code {:d}.".format(self.port, exit_code) + ) raise self.fixturelib.ServerFailure( "mongot on port {:d} with pid {:d} exited with code {:d}".format( - self.port, self.mongot.pid, exit_code)) + self.port, self.mongot.pid, exit_code + ) + ) # It is necessary for correctness purposes to delete the config journals during fixture teardown # (instead of in a hook) to ensure that there are no zombie index entries left from a previous # test that exited abruptly due to a failure. @@ -111,7 +129,9 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): try: shutil.rmtree(self.data_dir) except OSError as error: - self.logger.error("Hit OS error trying to delete mongot config journal: %s", error) + self.logger.error( + "Hit OS error trying to delete mongot config journal: %s", error + ) pass self.logger.info("Finished deleting mongot data files in fixture teardown") @@ -129,8 +149,13 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): self.logger.warning("The mongot fixture has not been set up yet.") return [] - info = interface.NodeInfo(full_name=self.logger.full_name, name=self.logger.name, - port=self.port, pid=self.mongot.pid, router_port=self.router_port) + info = interface.NodeInfo( + full_name=self.logger.full_name, + name=self.logger.name, + port=self.port, + pid=self.mongot.pid, + router_port=self.router_port, + ) return [info] def get_internal_connection_string(self): @@ -139,7 +164,11 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): def get_driver_connection_url(self): """Return the driver connection URL.""" - return "mongodb://" + self.get_internal_connection_string() + "/?directConnection=true" + return ( + "mongodb://" + + self.get_internal_connection_string() + + "/?directConnection=true" + ) def await_ready(self): """Block until the fixture can be used for testing.""" @@ -154,7 +183,8 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): if exit_code is not None: raise self.fixturelib.ServerFailure( "Could not connect to mongot on port {}, process ended" - " unexpectedly with code {}.".format(self.port, exit_code)) + " unexpectedly with code {}.".format(self.port, exit_code) + ) try: # By connecting to the host and port that mongot is listening from, @@ -167,7 +197,9 @@ class MongoTFixture(interface.Fixture, interface._DockerComposeInterface): if remaining <= 0.0: raise self.fixturelib.ServerFailure( "Failed to connect to mongot on port {} after {} seconds".format( - self.port, MongoTFixture.AWAIT_READY_TIMEOUT_SECS)) + self.port, MongoTFixture.AWAIT_READY_TIMEOUT_SECS + ) + ) self.logger.info("Waiting to connect to mongot on port %d.", self.port) time.sleep(0.1) # Wait a little bit before trying again. @@ -183,8 +215,9 @@ class MongotLauncher(object): self.fixturelib = fixturelib self.config = fixturelib.get_config() - def launch_mongot_program(self, logger, job_num, executable=None, process_kwargs=None, - mongot_options=None): + def launch_mongot_program( + self, logger, job_num, executable=None, process_kwargs=None, mongot_options=None + ): """ Return a Process instance that starts a mongot with arguments constructed from 'mongot_options'. @@ -192,16 +225,18 @@ class MongotLauncher(object): @param executable - The mongot executable to run. @param process_kwargs - A dict of key-value pairs to pass to the process. @param mongot_options - A HistoryDict describing the various options to pass to the mongot. - - Currently, this will launch a mongot with --port, --mongodHostAndPort, and --keyFile commandline - options. To support launching mongot with more startup options, those new options would need to - be added to mongot_options in MongoTFixture initialization or, if mongod needs to share/know the + + Currently, this will launch a mongot with --port, --mongodHostAndPort, and --keyFile commandline + options. To support launching mongot with more startup options, those new options would need to + be added to mongot_options in MongoTFixture initialization or, if mongod needs to share/know the mongot startup option (like in the case of keyFile), in MongoDFixture::setup_mongot(). """ - executable = self.fixturelib.default_if_none(executable, - self.config.DEFAULT_MONGOD_EXECUTABLE) + executable = self.fixturelib.default_if_none( + executable, self.config.DEFAULT_MONGOD_EXECUTABLE + ) mongot_options = self.fixturelib.default_if_none(mongot_options, {}).copy() - return self.fixturelib.mongot_program(logger, job_num, executable, process_kwargs, - mongot_options) + return self.fixturelib.mongot_program( + logger, job_num, executable, process_kwargs, mongot_options + ) diff --git a/buildscripts/resmokelib/testing/fixtures/multi_replica_set.py b/buildscripts/resmokelib/testing/fixtures/multi_replica_set.py index 92f5f88e471..bcf6887c366 100644 --- a/buildscripts/resmokelib/testing/fixtures/multi_replica_set.py +++ b/buildscripts/resmokelib/testing/fixtures/multi_replica_set.py @@ -15,28 +15,50 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): CONNECTION_STRING_DB_NAME = "config" CONNECTION_STRING_COLL_NAME = "multiReplicaSetFixture" - def __init__(self, logger, job_num, fixturelib, dbpath_prefix=None, num_replica_sets=2, - num_nodes_per_replica_set=2, common_mongod_options=None, per_mongod_options=None, - per_replica_set_options=None, persist_connection_strings=False, - **common_replica_set_options): + def __init__( + self, + logger, + job_num, + fixturelib, + dbpath_prefix=None, + num_replica_sets=2, + num_nodes_per_replica_set=2, + common_mongod_options=None, + per_mongod_options=None, + per_replica_set_options=None, + persist_connection_strings=False, + **common_replica_set_options, + ): """Initialize MultiReplicaSetFixture with different options for the replica set processes.""" - interface.MultiClusterFixture.__init__(self, logger, job_num, fixturelib, dbpath_prefix) + interface.MultiClusterFixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix + ) - self.num_replica_sets = num_replica_sets if num_replica_sets else self.config.NUM_REPLSETS + self.num_replica_sets = ( + num_replica_sets if num_replica_sets else self.config.NUM_REPLSETS + ) if self.num_replica_sets < 2: raise ValueError("num_replica_sets must be greater or equal to 2") self.num_nodes_per_replica_set = num_nodes_per_replica_set - self.common_mongod_options = self.fixturelib.default_if_none(common_mongod_options, {}) - self.per_mongod_options = self.fixturelib.default_if_none(per_mongod_options, []) + self.common_mongod_options = self.fixturelib.default_if_none( + common_mongod_options, {} + ) + self.per_mongod_options = self.fixturelib.default_if_none( + per_mongod_options, [] + ) self.common_replica_set_options = common_replica_set_options - self.per_replica_set_options = self.fixturelib.default_if_none(per_replica_set_options, []) + self.per_replica_set_options = self.fixturelib.default_if_none( + per_replica_set_options, [] + ) self.persist_connection_strings = persist_connection_strings self.auth_options = self.common_replica_set_options.get("auth_options", None) # Store this since it is needed by the ContinuousStepdown hook. - self.all_nodes_electable = self.common_replica_set_options.get("all_nodes_electable", False) + self.all_nodes_electable = self.common_replica_set_options.get( + "all_nodes_electable", False + ) self.replica_sets = [] if not self.replica_sets: @@ -52,9 +74,16 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): self.replica_sets.append( self.fixturelib.make_fixture( - "ReplicaSetFixture", self.logger, self.job_num, replset_name=rs_name, - replicaset_logging_prefix=rs_name, num_nodes=self.num_nodes_per_replica_set, - mongod_options=mongod_options, **replica_set_options)) + "ReplicaSetFixture", + self.logger, + self.job_num, + replset_name=rs_name, + replicaset_logging_prefix=rs_name, + num_nodes=self.num_nodes_per_replica_set, + mongod_options=mongod_options, + **replica_set_options, + ) + ) def pids(self): """:return: pids owned by this fixture if any.""" @@ -62,7 +91,9 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): for replica_set in self.replica_sets: out.extend(replica_set.pids()) if not out: - self.logger.debug('No replica sets when gathering multi replicaset fixture pids.') + self.logger.debug( + "No replica sets when gathering multi replicaset fixture pids." + ) return out def setup(self): @@ -76,12 +107,16 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): for replica_set in self.replica_sets: replica_set.await_ready() if self.persist_connection_strings: - docs = [{"_id": i, "connectionString": replica_set.get_driver_connection_url()} - for (i, replica_set) in enumerate(self.replica_sets)] - primary_client = interface.build_client(self.replica_sets[0].get_primary(), - self.auth_options) + docs = [ + {"_id": i, "connectionString": replica_set.get_driver_connection_url()} + for (i, replica_set) in enumerate(self.replica_sets) + ] + primary_client = interface.build_client( + self.replica_sets[0].get_primary(), self.auth_options + ) primary_coll = primary_client[self.CONNECTION_STRING_DB_NAME][ - self.CONNECTION_STRING_COLL_NAME] + self.CONNECTION_STRING_COLL_NAME + ] primary_coll.insert_many(docs) def _do_teardown(self, mode=None): @@ -90,7 +125,9 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): running_at_start = self.is_running() if not running_at_start: - self.logger.warning("All replica sets were expected to be running, but weren't.") + self.logger.warning( + "All replica sets were expected to be running, but weren't." + ) teardown_handler = interface.FixtureTeardownHandler(self.logger) @@ -126,13 +163,17 @@ class MultiReplicaSetFixture(interface.MultiClusterFixture): def get_internal_connection_string(self): """Return the internal connection string to the replica set that tests should connect to.""" if not self.replica_sets: - raise ValueError("Must call setup() before calling get_internal_connection_string()") + raise ValueError( + "Must call setup() before calling get_internal_connection_string()" + ) return self.replica_sets[0].get_internal_connection_string() def get_driver_connection_url(self): """Return the driver connection URL to the replica set that tests should connect to.""" if not self.replica_sets: - raise ValueError("Must call setup() before calling get_driver_connection_url") + raise ValueError( + "Must call setup() before calling get_driver_connection_url" + ) return self.replica_sets[0].get_driver_connection_url() def get_node_info(self): diff --git a/buildscripts/resmokelib/testing/fixtures/multi_sharded_cluster.py b/buildscripts/resmokelib/testing/fixtures/multi_sharded_cluster.py index 85731ead14c..00c68d357b0 100644 --- a/buildscripts/resmokelib/testing/fixtures/multi_sharded_cluster.py +++ b/buildscripts/resmokelib/testing/fixtures/multi_sharded_cluster.py @@ -17,23 +17,39 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): CONNECTION_STRING_DB_NAME = "config" CONNECTION_STRING_COLL_NAME = "multiShardedClusterFixture" - def __init__(self, logger, job_num, fixturelib, dbpath_prefix=None, num_sharded_clusters=2, - common_mongod_options=None, per_mongod_options=None, - per_sharded_cluster_options=None, persist_connection_strings=False, - **common_sharded_cluster_options): + def __init__( + self, + logger, + job_num, + fixturelib, + dbpath_prefix=None, + num_sharded_clusters=2, + common_mongod_options=None, + per_mongod_options=None, + per_sharded_cluster_options=None, + persist_connection_strings=False, + **common_sharded_cluster_options, + ): """Initialize MultiShardedClusterFixture with different options for the sharded cluster processes.""" - interface.MultiClusterFixture.__init__(self, logger, job_num, fixturelib, dbpath_prefix) + interface.MultiClusterFixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix + ) if num_sharded_clusters < 2: raise ValueError("num_sharded_clusters must be greater or equal to 2") self.num_sharded_clusters = num_sharded_clusters - self.common_mongod_options = self.fixturelib.default_if_none(common_mongod_options, {}) - self.per_mongod_options = self.fixturelib.default_if_none(per_mongod_options, []) + self.common_mongod_options = self.fixturelib.default_if_none( + common_mongod_options, {} + ) + self.per_mongod_options = self.fixturelib.default_if_none( + per_mongod_options, [] + ) self.common_sharded_cluster_options = common_sharded_cluster_options self.per_sharded_cluster_options = self.fixturelib.default_if_none( - per_sharded_cluster_options, []) + per_sharded_cluster_options, [] + ) self.persist_connection_strings = persist_connection_strings self.sharded_clusters = [] @@ -50,9 +66,15 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): self.sharded_clusters.append( self.fixturelib.make_fixture( - "ShardedClusterFixture", self.logger, self.job_num, - dbpath_prefix=dbpath_prefix, cluster_logging_prefix=cluster_name, - mongod_options=mongod_options, **sharded_cluster_options)) + "ShardedClusterFixture", + self.logger, + self.job_num, + dbpath_prefix=dbpath_prefix, + cluster_logging_prefix=cluster_name, + mongod_options=mongod_options, + **sharded_cluster_options, + ) + ) def pids(self): """:return: pids owned by this fixture if any.""" @@ -61,7 +83,8 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): out.extend(sharded_cluster.pids()) if not out: self.logger.debug( - 'No sharded clusters when gathering multi sharded cluster fixture pids.') + "No sharded clusters when gathering multi sharded cluster fixture pids." + ) return out def setup(self): @@ -75,10 +98,19 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): for sharded_cluster in self.sharded_clusters: sharded_cluster.await_ready() if self.persist_connection_strings: - docs = [{"_id": i, "connectionString": sharded_cluster.get_driver_connection_url()} - for (i, sharded_cluster) in enumerate(self.sharded_clusters)] - client = pymongo.MongoClient(self.sharded_clusters[0].get_driver_connection_url()) - coll = client[self.CONNECTION_STRING_DB_NAME][self.CONNECTION_STRING_COLL_NAME] + docs = [ + { + "_id": i, + "connectionString": sharded_cluster.get_driver_connection_url(), + } + for (i, sharded_cluster) in enumerate(self.sharded_clusters) + ] + client = pymongo.MongoClient( + self.sharded_clusters[0].get_driver_connection_url() + ) + coll = client[self.CONNECTION_STRING_DB_NAME][ + self.CONNECTION_STRING_COLL_NAME + ] coll.insert_many(docs) def feature_flag_present_and_enabled(self, feature_flag_name): @@ -94,7 +126,9 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): running_at_start = self.is_running() if not running_at_start: - self.logger.warning("All sharded clusters were expected to be running, but weren't.") + self.logger.warning( + "All sharded clusters were expected to be running, but weren't." + ) teardown_handler = interface.FixtureTeardownHandler(self.logger) @@ -109,7 +143,9 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): def is_running(self): """Return true if all sharded clusters are still operating.""" - return all(sharded_cluster.is_running() for sharded_cluster in self.sharded_clusters) + return all( + sharded_cluster.is_running() for sharded_cluster in self.sharded_clusters + ) def get_num_sharded_clusters(self): """Return the number of sharded clusters.""" @@ -130,13 +166,17 @@ class MultiShardedClusterFixture(interface.MultiClusterFixture): def get_internal_connection_string(self): """Return the internal connection string to the sharded cluster that tests should connect to.""" if not self.sharded_clusters: - raise ValueError("Must call setup() before calling get_internal_connection_string()") + raise ValueError( + "Must call setup() before calling get_internal_connection_string()" + ) return self.sharded_clusters[0].get_internal_connection_string() def get_driver_connection_url(self): """Return the driver connection URL to the sharded cluster that tests should connect to.""" if not self.sharded_clusters: - raise ValueError("Must call setup() before calling get_driver_connection_url") + raise ValueError( + "Must call setup() before calling get_driver_connection_url" + ) return self.sharded_clusters[0].get_driver_connection_url() def get_node_info(self): diff --git a/buildscripts/resmokelib/testing/fixtures/replicaset.py b/buildscripts/resmokelib/testing/fixtures/replicaset.py index 660395eb2ff..70817658dbb 100644 --- a/buildscripts/resmokelib/testing/fixtures/replicaset.py +++ b/buildscripts/resmokelib/testing/fixtures/replicaset.py @@ -43,30 +43,55 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface AWAIT_SHARDING_INITIALIZATION_TIMEOUT_SECS = 60 def __init__( - self, logger, job_num, fixturelib, mongod_executable=None, mongod_options=None, - dbpath_prefix=None, preserve_dbpath=False, num_nodes=2, start_initial_sync_node=False, - electable_initial_sync_node=False, write_concern_majority_journal_default=None, - auth_options=None, replset_config_options=None, voting_secondaries=True, - all_nodes_electable=False, use_replica_set_connection_string=None, linear_chain=False, - default_read_concern=None, default_write_concern=None, shard_logging_prefix=None, - replicaset_logging_prefix=None, replset_name=None, use_auto_bootstrap_procedure=None, - initial_sync_uninitialized_fcv=False, hide_initial_sync_node_from_conn_string=False, - launch_mongot=False, initial_sync_uninitialized_fcv_in_shard_svr=False): + self, + logger, + job_num, + fixturelib, + mongod_executable=None, + mongod_options=None, + dbpath_prefix=None, + preserve_dbpath=False, + num_nodes=2, + start_initial_sync_node=False, + electable_initial_sync_node=False, + write_concern_majority_journal_default=None, + auth_options=None, + replset_config_options=None, + voting_secondaries=True, + all_nodes_electable=False, + use_replica_set_connection_string=None, + linear_chain=False, + default_read_concern=None, + default_write_concern=None, + shard_logging_prefix=None, + replicaset_logging_prefix=None, + replset_name=None, + use_auto_bootstrap_procedure=None, + initial_sync_uninitialized_fcv=False, + hide_initial_sync_node_from_conn_string=False, + launch_mongot=False, + initial_sync_uninitialized_fcv_in_shard_svr=False, + ): """Initialize ReplicaSetFixture.""" - interface.ReplFixture.__init__(self, logger, job_num, fixturelib, - dbpath_prefix=dbpath_prefix) + interface.ReplFixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix + ) self.mongod_executable = mongod_executable self.mongod_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongod_options, {})) + self.fixturelib.default_if_none(mongod_options, {}) + ) self.preserve_dbpath = preserve_dbpath self.start_initial_sync_node = start_initial_sync_node self.electable_initial_sync_node = electable_initial_sync_node - self.write_concern_majority_journal_default = write_concern_majority_journal_default + self.write_concern_majority_journal_default = ( + write_concern_majority_journal_default + ) self.auth_options = auth_options self.replset_config_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(replset_config_options, {})) + self.fixturelib.default_if_none(replset_config_options, {}) + ) self.voting_secondaries = voting_secondaries self.all_nodes_electable = all_nodes_electable self.use_replica_set_connection_string = use_replica_set_connection_string @@ -77,16 +102,21 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface self.num_nodes = num_nodes self.replset_name = replset_name self.initial_sync_uninitialized_fcv = initial_sync_uninitialized_fcv - self.hide_initial_sync_node_from_conn_string = hide_initial_sync_node_from_conn_string - self.initial_sync_uninitialized_fcv_in_shard_svr = initial_sync_uninitialized_fcv_in_shard_svr + self.hide_initial_sync_node_from_conn_string = ( + hide_initial_sync_node_from_conn_string + ) + self.initial_sync_uninitialized_fcv_in_shard_svr = ( + initial_sync_uninitialized_fcv_in_shard_svr + ) # Used by the enhanced multiversion system to signify multiversion mode. # None implies no multiversion run. self.fcv = None # Used by suites that run search integration tests. self.launch_mongot = launch_mongot # Use the values given from the command line if they exist for linear_chain and num_nodes. - linear_chain_option = self.fixturelib.default_if_none(self.config.LINEAR_CHAIN, - linear_chain) + linear_chain_option = self.fixturelib.default_if_none( + self.config.LINEAR_CHAIN, linear_chain + ) self.linear_chain = linear_chain_option if linear_chain_option else linear_chain self.repl_set_config = {} @@ -96,12 +126,14 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface self.use_replica_set_connection_string = self.all_nodes_electable if self.default_write_concern is True: - self.default_write_concern = self.fixturelib.make_historic({ - "w": "majority", - # Use a "signature" value that won't typically match a value assigned in normal use. - # This way the wtimeout set by this override is distinguishable in the server logs. - "wtimeout": 5 * 60 * 1000 + 321, # 300321ms - }) + self.default_write_concern = self.fixturelib.make_historic( + { + "w": "majority", + # Use a "signature" value that won't typically match a value assigned in normal use. + # This way the wtimeout set by this override is distinguishable in the server logs. + "wtimeout": 5 * 60 * 1000 + 321, # 300321ms + } + ) # Set the default oplogSize to 511MB. self.mongod_options.setdefault("oplogSize", 511) @@ -112,13 +144,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface if "dbpath" in self.mongod_options: self._dbpath_prefix = self.mongod_options.pop("dbpath") else: - self._dbpath_prefix = os.path.join(self._dbpath_prefix, self.config.FIXTURE_SUBDIR) + self._dbpath_prefix = os.path.join( + self._dbpath_prefix, self.config.FIXTURE_SUBDIR + ) self.nodes = [] if "serverless" not in self.mongod_options: if not self.replset_name: self.replset_name = "rs" - self.replset_name = self.mongod_options.setdefault("replSet", self.replset_name) + self.replset_name = self.mongod_options.setdefault( + "replSet", self.replset_name + ) self.initial_sync_node = None self.initial_sync_node_idx = -1 self.use_auto_bootstrap_procedure = use_auto_bootstrap_procedure @@ -138,8 +174,7 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # get the auto generated replSet name and update the replSet name of the other mongods with it. self.nodes[0].setup() self.nodes[0].await_ready() - self._await_primary( - ) # Wait for writeable primary (this indicates replSet auto-intiiate finished). + self._await_primary() # Wait for writeable primary (this indicates replSet auto-intiiate finished). client = interface.build_client(self.nodes[0], self.auth_options) res = client.admin.command("hello") @@ -176,7 +211,7 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # Initiate the replica set. members = [] - for (i, node) in enumerate(self.nodes): + for i, node in enumerate(self.nodes): member_info = {"_id": i, "host": node.get_internal_connection_string()} if i > 0: if not self.all_nodes_electable: @@ -190,15 +225,20 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # FCV state, don't add it to the replica set yet. It will be added to the set later after # the shard is added to the cluster, so that if the shard is a config shard, it can transition # to being a config shard properly. - if self.initial_sync_node and not self.initial_sync_uninitialized_fcv_in_shard_svr: + if ( + self.initial_sync_node + and not self.initial_sync_uninitialized_fcv_in_shard_svr + ): initial_sync_config = self._create_initial_sync_config() members.append(initial_sync_config) repl_config = {"_id": self.replset_name, "protocolVersion": 1} client = interface.build_client(self.nodes[0], self.auth_options) - if client.local.system.replset.count_documents( - filter={}) and not self.use_auto_bootstrap_procedure: + if ( + client.local.system.replset.count_documents(filter={}) + and not self.use_auto_bootstrap_procedure + ): # Skip initializing the replset if there is an existing configuration. # Auto-bootstrapping will automatically create a configuration document but we do not # want to skip reconfiguring the replset (which adds the other nodes @@ -207,15 +247,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface return if self.write_concern_majority_journal_default is not None: - repl_config[ - "writeConcernMajorityJournalDefault"] = self.write_concern_majority_journal_default + repl_config["writeConcernMajorityJournalDefault"] = ( + self.write_concern_majority_journal_default + ) else: server_status = client.admin.command({"serverStatus": 1}) if not server_status["storageEngine"]["persistent"]: repl_config["writeConcernMajorityJournalDefault"] = False - if (self.replset_config_options.get("configsvr", False) - or (self.use_auto_bootstrap_procedure and "shardsvr" not in self.mongod_options)): + if self.replset_config_options.get("configsvr", False) or ( + self.use_auto_bootstrap_procedure and "shardsvr" not in self.mongod_options + ): repl_config["configsvr"] = True if self.replset_config_options.get("settings"): replset_settings = self.replset_config_options["settings"] @@ -252,10 +294,12 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # nodes are subsequently added to the set, since such nodes cannot set their FCV to # "latest". Therefore, we make sure the primary is "last-lts" FCV before adding in # nodes of different binary versions to the replica set. - client.admin.command({ - "setFeatureCompatibilityVersion": self.fcv, - "fromConfigServer": True, - }) + client.admin.command( + { + "setFeatureCompatibilityVersion": self.fcv, + "fromConfigServer": True, + } + ) if self.nodes[1:]: # Wait to connect to each of the secondaries before running the replSetReconfig @@ -296,8 +340,8 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface def add_initial_sync_node_to_replica_set(self): """Adds the initial sync node to the replica set. - - This is used so that we can add in the initial sync node after the setup() function at a + + This is used so that we can add in the initial sync node after the setup() function at a later time. """ @@ -328,11 +372,15 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface for node in self.nodes: pids.extend(node.pids()) if not pids: - self.logger.debug('No members running when gathering replicaset fixture pids.') + self.logger.debug( + "No members running when gathering replicaset fixture pids." + ) return pids def _add_node_to_repl_set(self, client, repl_config, member_index, members): - self.logger.info("Adding in node %d: %s", member_index, members[member_index - 1]) + self.logger.info( + "Adding in node %d: %s", member_index, members[member_index - 1] + ) repl_config["members"] = members[:member_index] self._reconfig_repl_set(client, repl_config) @@ -341,14 +389,18 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface try: # 'newlyAdded' removal reconfigs could bump the version. # Get the current version to be safe. - curr_version = client.admin.command({"replSetGetConfig": 1})['config']['version'] + curr_version = client.admin.command({"replSetGetConfig": 1})["config"][ + "version" + ] repl_config["version"] = curr_version + 1 self.logger.info("Issuing replSetReconfig command: %s", repl_config) - client.admin.command({ - "replSetReconfig": repl_config, - "maxTimeMS": self.AWAIT_REPL_TIMEOUT_MINS * 60 * 1000 - }) + client.admin.command( + { + "replSetReconfig": repl_config, + "maxTimeMS": self.AWAIT_REPL_TIMEOUT_MINS * 60 * 1000, + } + ) self.repl_set_config = repl_config break except pymongo.errors.OperationFailure as err: @@ -357,19 +409,23 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # indefinitely. # pylint: disable=too-many-boolean-expressions if err.code not in [ - ReplicaSetFixture._NEW_REPLICA_SET_CONFIGURATION_INCOMPATIBLE, - ReplicaSetFixture._CURRENT_CONFIG_NOT_COMMITTED_YET, - ReplicaSetFixture._CONFIGURATION_IN_PROGRESS, - ReplicaSetFixture._NODE_NOT_FOUND, - ReplicaSetFixture._INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - ReplicaSetFixture._INTERRUPTED_DUE_TO_STORAGE_CHANGE + ReplicaSetFixture._NEW_REPLICA_SET_CONFIGURATION_INCOMPATIBLE, + ReplicaSetFixture._CURRENT_CONFIG_NOT_COMMITTED_YET, + ReplicaSetFixture._CONFIGURATION_IN_PROGRESS, + ReplicaSetFixture._NODE_NOT_FOUND, + ReplicaSetFixture._INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + ReplicaSetFixture._INTERRUPTED_DUE_TO_STORAGE_CHANGE, ]: - msg = ("Operation failure while setting up the " - "replica set fixture: {}").format(err) + msg = ( + "Operation failure while setting up the " + "replica set fixture: {}" + ).format(err) self.logger.error(msg) raise self.fixturelib.ServerFailure(msg) - msg = ("Retrying failed attempt to add new node to fixture: {}").format(err) + msg = ("Retrying failed attempt to add new node to fixture: {}").format( + err + ) self.logger.error(msg) time.sleep(0.1) # Wait a little bit before trying again. @@ -386,13 +442,18 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface except pymongo.errors.OperationFailure as err: # Retry on NodeNotFound errors from the "replSetInitiate" command. if err.code != ReplicaSetFixture._NODE_NOT_FOUND: - msg = ("Operation failure while configuring the " - "replica set fixture: {}").format(err) + msg = ( + "Operation failure while configuring the " + "replica set fixture: {}" + ).format(err) self.logger.error(msg) raise self.fixturelib.ServerFailure(msg) - msg = "replSetInitiate failed attempt {0} of {1} with error: {2}".format( - attempt, num_initiate_attempts, err) + msg = ( + "replSetInitiate failed attempt {0} of {1} with error: {2}".format( + attempt, num_initiate_attempts, err + ) + ) self.logger.error(msg) if attempt == num_initiate_attempts: msg = "Exceeded number of retries while configuring the replica set fixture" @@ -422,8 +483,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface return len(up_to_date_nodes) == len(self.nodes) - self._await_cmd_all_nodes(check_rcmaj_optime, "waiting for last committed optime", - timeout_secs) + self._await_cmd_all_nodes( + check_rcmaj_optime, "waiting for last committed optime", timeout_secs + ) def await_ready(self): """Wait for replica set to be ready.""" @@ -442,7 +504,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface primary = self.nodes[0] client = primary.mongo_client() while True: - self.logger.info("Waiting for primary on port %d to be elected.", primary.port) + self.logger.info( + "Waiting for primary on port %d to be elected.", primary.port + ) is_master = client.admin.command("isMaster")["ismaster"] if is_master: break @@ -458,10 +522,14 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface secondaries.append(self.initial_sync_node) for secondary in secondaries: - client = secondary.mongo_client(read_preference=pymongo.ReadPreference.SECONDARY) + client = secondary.mongo_client( + read_preference=pymongo.ReadPreference.SECONDARY + ) while True: - self.logger.info("Waiting for secondary on port %d to become available.", - secondary.port) + self.logger.info( + "Waiting for secondary on port %d to become available.", + secondary.port, + ) try: is_secondary = client.admin.command("isMaster")["secondary"] if is_secondary: @@ -497,14 +565,20 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # propagate to all members and trigger a stable checkpoint on all persisted storage engines # nodes. admin = primary_client.get_database( - "admin", write_concern=pymongo.write_concern.WriteConcern(w="majority")) + "admin", write_concern=pymongo.write_concern.WriteConcern(w="majority") + ) admin.command("appendOplogNote", data={"await_stable_recovery_timestamp": 1}) for node in self.nodes: - self.logger.info("Waiting for node on port %d to have a stable recovery timestamp.", - node.port) - client = interface.build_client(node, self.auth_options, - read_preference=pymongo.ReadPreference.SECONDARY) + self.logger.info( + "Waiting for node on port %d to have a stable recovery timestamp.", + node.port, + ) + client = interface.build_client( + node, + self.auth_options, + read_preference=pymongo.ReadPreference.SECONDARY, + ) client_admin = client["admin"] @@ -513,7 +587,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # The `lastStableRecoveryTimestamp` field contains a stable timestamp guaranteed to # exist on storage engine recovery to a stable timestamp. - last_stable_recovery_timestamp = status.get("lastStableRecoveryTimestamp", None) + last_stable_recovery_timestamp = status.get( + "lastStableRecoveryTimestamp", None + ) # A missing `lastStableRecoveryTimestamp` field indicates that the storage # engine does not support "recover to a stable timestamp". @@ -525,7 +601,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface if last_stable_recovery_timestamp.time: self.logger.info( "Node on port %d now has a stable timestamp for recovery. Time: %s", - node.port, last_stable_recovery_timestamp) + node.port, + last_stable_recovery_timestamp, + ) break time.sleep(0.1) # Wait a little bit before trying again. @@ -537,16 +615,25 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface """ get_config_res = client.admin.command( - {"replSetGetConfig": 1, "commitmentStatus": True, "$_internalIncludeNewlyAdded": True}) + { + "replSetGetConfig": 1, + "commitmentStatus": True, + "$_internalIncludeNewlyAdded": True, + } + ) for member in get_config_res["config"]["members"]: if "newlyAdded" in member: self.logger.info( - "Waiting longer for 'newlyAdded' removals, " + - "member %d is still 'newlyAdded'", member["_id"]) + "Waiting longer for 'newlyAdded' removals, " + + "member %d is still 'newlyAdded'", + member["_id"], + ) return True if not get_config_res["commitmentStatus"]: - self.logger.info("Waiting longer for 'newlyAdded' removals, " + - "config is not yet committed") + self.logger.info( + "Waiting longer for 'newlyAdded' removals, " + + "config is not yet committed" + ) return True return False @@ -581,46 +668,59 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # TODO: Remove this in SERVER-80010. def _await_auto_bootstrapped_config_shard(self): connection_string = self.get_driver_connection_url() - self.logger.info("Waiting for %s to auto-bootstrap as a config shard...", connection_string) + self.logger.info( + "Waiting for %s to auto-bootstrap as a config shard...", connection_string + ) - deadline = time.time() + ReplicaSetFixture.AWAIT_SHARDING_INITIALIZATION_TIMEOUT_SECS + deadline = ( + time.time() + ReplicaSetFixture.AWAIT_SHARDING_INITIALIZATION_TIMEOUT_SECS + ) timeout_occurred = lambda: deadline - time.time() <= 0.0 while True: client = interface.build_client(self.get_primary(), self.auth_options) config_shard_count = client.get_database("config").command( - {"count": "shards", "query": {"_id": "config"}}) + {"count": "shards", "query": {"_id": "config"}} + ) - if config_shard_count['n'] == 1: + if config_shard_count["n"] == 1: break if timeout_occurred(): port = self.get_primary().port raise self.fixturelib.ServerFailure( - "mongod on port: {} failed waiting for auto-bootstrapped config shard success after {} seconds" - .format(port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS)) + "mongod on port: {} failed waiting for auto-bootstrapped config shard success after {} seconds".format( + port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS + ) + ) time.sleep(0.1) - self.logger.info("%s successfully auto-bootstrapped as a config shard...", - connection_string) + self.logger.info( + "%s successfully auto-bootstrapped as a config shard...", connection_string + ) def _check_initial_sync_node_has_uninitialized_fcv(self, initial_sync_node): sync_node_conn = initial_sync_node.mongo_client() self.logger.info("Checking that initial sync node has uninitialized fcv") try: fcv = sync_node_conn.admin.command( - {'getParameter': 1, 'featureCompatibilityVersion': 1}) + {"getParameter": 1, "featureCompatibilityVersion": 1} + ) - msg = "Initial sync node should have an uninitialized FCV, but got fcv: " + str(fcv) + msg = ( + "Initial sync node should have an uninitialized FCV, but got fcv: " + + str(fcv) + ) raise self.fixturelib.ServerFailure(msg) except pymongo.errors.OperationFailure as err: - if err.code == 258: #codeName == 'UnknownFeatureCompatibilityVersion' + if err.code == 258: # codeName == 'UnknownFeatureCompatibilityVersion' return raise def _pause_initial_sync_at_uninitialized_fcv(self, initial_sync_node): failpointOnCmd = { - 'configureFailPoint': 'initialSyncHangAfterResettingFCV', 'mode': 'alwaysOn' + "configureFailPoint": "initialSyncHangAfterResettingFCV", + "mode": "alwaysOn", } sync_node_conn = initial_sync_node.mongo_client() self.logger.info("Pausing initial sync at failpoint") @@ -629,47 +729,68 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface def _unpause_and_finish_initial_sync(self, initial_sync_node): failpoint_off_cmd = { - 'configureFailPoint': 'initialSyncHangAfterResettingFCV', 'mode': 'off' + "configureFailPoint": "initialSyncHangAfterResettingFCV", + "mode": "off", } self.logger.info("Unpausing initial sync") sync_node_conn = initial_sync_node.mongo_client() sync_node_conn.admin.command(failpoint_off_cmd) wait_for_initial_sync_finish_cmd = bson.SON( - [("replSetTest", 1), ("waitForMemberState", 2), - ("timeoutMillis", interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 * 1000)]) + [ + ("replSetTest", 1), + ("waitForMemberState", 2), + ( + "timeoutMillis", + interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 * 1000, + ), + ] + ) while True: try: self.logger.info("Waiting for initial sync to finish") sync_node_conn.admin.command(wait_for_initial_sync_finish_cmd) break except pymongo.errors.OperationFailure as err: - if err.code not in (self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self.INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self.INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): raise - msg = ("Interrupted while waiting for node to reach secondary state, retrying: {}" - ).format(err) + msg = ( + "Interrupted while waiting for node to reach secondary state, retrying: {}" + ).format(err) self.logger.error(msg) def _do_teardown(self, mode=None): - self.logger.info("Stopping all members of the replica set '%s'...", self.replset_name) + self.logger.info( + "Stopping all members of the replica set '%s'...", self.replset_name + ) running_at_start = self.is_running() if not running_at_start: - self.logger.info("All members of the replica set were expected to be running, " - "but weren't.") + self.logger.info( + "All members of the replica set were expected to be running, " + "but weren't." + ) teardown_handler = interface.FixtureTeardownHandler(self.logger) if self.initial_sync_node: if self.initial_sync_uninitialized_fcv: - self._check_initial_sync_node_has_uninitialized_fcv(self.initial_sync_node) + self._check_initial_sync_node_has_uninitialized_fcv( + self.initial_sync_node + ) self._unpause_and_finish_initial_sync(self.initial_sync_node) - teardown_handler.teardown(self.initial_sync_node, "initial sync node", mode=mode) + teardown_handler.teardown( + self.initial_sync_node, "initial sync node", mode=mode + ) # Terminate the secondaries first to reduce noise in the logs. for node in reversed(self.nodes): - teardown_handler.teardown(node, "replica set member on port %d" % node.port, mode=mode) + teardown_handler.teardown( + node, "replica set member on port %d" % node.port, mode=mode + ) if teardown_handler.was_successful(): self.logger.info("Successfully stopped all members of the replica set.") @@ -697,12 +818,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface """Return if `node` is master.""" is_master = client.admin.command("isMaster")["ismaster"] if is_master: - self.logger.info("The node on port %d is primary of replica set '%s'", node.port, - self.replset_name) + self.logger.info( + "The node on port %d is primary of replica set '%s'", + node.port, + self.replset_name, + ) return True return False - return self._await_cmd_all_nodes(is_primary, "waiting for a primary", timeout_secs) + return self._await_cmd_all_nodes( + is_primary, "waiting for a primary", timeout_secs + ) def _await_cmd_all_nodes(self, fn, msg, timeout_secs=None): """Run `fn` on all nodes until it returns a truthy value. @@ -726,13 +852,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface for node in all_nodes: now = time.time() if (now - start) >= timeout_secs: - msg = "Timed out while {} for replica set '{}'.".format(msg, self.replset_name) + msg = "Timed out while {} for replica set '{}'.".format( + msg, self.replset_name + ) self.logger.error(msg) raise self.fixturelib.ServerFailure(msg) try: if node.port not in clients: - clients[node.port] = interface.build_client(node, self.auth_options) + clients[node.port] = interface.build_client( + node, self.auth_options + ) if fn(clients[node.port], node): return node @@ -748,9 +878,10 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # Check that the fixture is still running before stepping down or killing the primary. # This ensures we still detect some cases in which the fixture has already crashed. if not self.is_running(): - raise self.fixturelib.ServerFailure("ReplicaSetFixture {} expected to be running in" - " ContinuousStepdown, but wasn't.".format( - self.replset_name)) + raise self.fixturelib.ServerFailure( + "ReplicaSetFixture {} expected to be running in" + " ContinuousStepdown, but wasn't.".format(self.replset_name) + ) # If we're running with background reconfigs, it's possible to be in a scenario # where we kill a necessary voting node (i.e. in a 5 node repl set), only 2 are @@ -770,13 +901,21 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface should_kill = kill and random.choice([True, False]) action = "Killing" if should_kill else "Terminating" - self.logger.info("%s the primary on port %d of replica set '%s'.", action, primary.port, - self.replset_name) + self.logger.info( + "%s the primary on port %d of replica set '%s'.", + action, + primary.port, + self.replset_name, + ) # We send the mongod process the signal to exit but don't immediately wait for it to # exit because clean shutdown may take a while and we want to restore write availability # as quickly as possible. - teardown_mode = interface.TeardownMode.KILL if should_kill else interface.TeardownMode.TERMINATE + teardown_mode = ( + interface.TeardownMode.KILL + if should_kill + else interface.TeardownMode.TERMINATE + ) primary.mongod.stop(mode=teardown_mode) return True @@ -809,7 +948,8 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface if chosen_index is None or max_optime is None: raise self.fixturelib.ServerFailure( "Failed to find a secondary eligible for " - f"election; index: {chosen_index}, optime: {max_optime}") + f"election; index: {chosen_index}, optime: {max_optime}" + ) return self.nodes[chosen_index] @@ -818,13 +958,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface retry_start_time = time.time() while True: - member_infos = primary_client.admin.command({"replSetGetStatus": 1})["members"] + member_infos = primary_client.admin.command({"replSetGetStatus": 1})[ + "members" + ] chosen_node = get_chosen_node_from_replsetstatus(member_infos) if chosen_node.change_version_if_needed(primary): self.logger.info( "Waiting for the chosen secondary on port %d of replica set '%s' to exit.", - chosen_node.port, self.replset_name) + chosen_node.port, + self.replset_name, + ) teardown_mode = interface.TeardownMode.TERMINATE chosen_node.mongod.stop(mode=teardown_mode) @@ -832,7 +976,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface self.logger.info( "Attempting to restart the chosen secondary on port %d of replica set '%s'.", - chosen_node.port, self.replset_name) + chosen_node.port, + self.replset_name, + ) chosen_node.setup() self.logger.info(interface.create_fixture_table(self)) @@ -844,7 +990,10 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface if time.time() - retry_start_time > retry_time_secs: raise self.fixturelib.ServerFailure( "The old primary on port {} of replica set {} did not step up in" - " {} seconds.".format(chosen_node.port, self.replset_name, retry_time_secs)) + " {} seconds.".format( + chosen_node.port, self.replset_name, retry_time_secs + ) + ) return chosen_node @@ -853,7 +1002,9 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface try: self.logger.info( "Attempting to step up the chosen secondary on port %d of replica set '%s'.", - node.port, self.replset_name) + node.port, + self.replset_name, + ) client = interface.build_client(node, auth_options) client.admin.command("replSetStepUp") return True @@ -862,8 +1013,11 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # not receiving enough votes. This can happen when the 'chosen' secondary's opTime # is behind that of other secondaries. We handle this by attempting to elect a # different secondary. - self.logger.info("Failed to step up the secondary on port %d of replica set '%s'.", - node.port, self.replset_name) + self.logger.info( + "Failed to step up the secondary on port %d of replica set '%s'.", + node.port, + self.replset_name, + ) return False except pymongo.errors.AutoReconnect: # It is possible for a replSetStepUp to fail with AutoReconnect if that node goes @@ -872,23 +1026,35 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface def restart_node(self, chosen): """Restart the new step up node.""" - self.logger.info("Waiting for the old primary on port %d of replica set '%s' to exit.", - chosen.port, self.replset_name) + self.logger.info( + "Waiting for the old primary on port %d of replica set '%s' to exit.", + chosen.port, + self.replset_name, + ) exit_code = chosen.mongod.wait() # This function is called after stop_primary() which could kill or cleanly shutdown the # process. We therefore also allow an exit code of -9. if exit_code in (0, -interface.TeardownMode.KILL.value): - self.logger.info("Successfully stopped the mongod on port {:d}.".format(chosen.port)) + self.logger.info( + "Successfully stopped the mongod on port {:d}.".format(chosen.port) + ) else: - self.logger.warning("Stopped the mongod on port {:d}. " - "Process exited with code {:d}.".format(chosen.port, exit_code)) + self.logger.warning( + "Stopped the mongod on port {:d}. " + "Process exited with code {:d}.".format(chosen.port, exit_code) + ) raise self.fixturelib.ServerFailure( "mongod on port {:d} with pid {:d} exited with code {:d}".format( - chosen.port, chosen.mongod.pid, exit_code)) + chosen.port, chosen.mongod.pid, exit_code + ) + ) - self.logger.info("Attempting to restart the old primary on port %d of replica set '%s'.", - chosen.port, self.replset_name) + self.logger.info( + "Attempting to restart the old primary on port %d of replica set '%s'.", + chosen.port, + self.replset_name, + ) # Restart the mongod on the old primary and wait until we can contact it again. Keep the # original preserve_dbpath to restore after restarting the mongod. @@ -909,15 +1075,17 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface def get_secondary_indices(self): """Return a list of secondary indices from the replica set.""" primary = self.get_primary() - return [index for index, node in enumerate(self.nodes) if node.port != primary.port] + return [ + index for index, node in enumerate(self.nodes) if node.port != primary.port + ] def get_voting_members(self): """Return the number of voting nodes in the replica set.""" primary = self.get_primary() client = primary.mongo_client() - members = client.admin.command({"replSetGetConfig": 1})['config']['members'] - voting_members = [member['host'] for member in members if member['votes'] == 1] + members = client.admin.command({"replSetGetConfig": 1})["config"]["members"] + voting_members = [member["host"] for member in members if member["votes"] == 1] return voting_members @@ -937,20 +1105,29 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface """Return options that may be passed to a mongod.""" mongod_options = self.mongod_options.copy() - mongod_options["dbpath"] = os.path.join(self._dbpath_prefix, "node{}".format(index)) - mongod_options["set_parameters"] = mongod_options.get("set_parameters", - self.fixturelib.make_historic( - {})).copy() + mongod_options["dbpath"] = os.path.join( + self._dbpath_prefix, "node{}".format(index) + ) + mongod_options["set_parameters"] = mongod_options.get( + "set_parameters", self.fixturelib.make_historic({}) + ).copy() if index == 0 and self.use_auto_bootstrap_procedure: del mongod_options["replSet"] if self.linear_chain and index > 0: self.mongod_options["set_parameters"][ - "failpoint.forceSyncSourceCandidate"] = self.fixturelib.make_historic({ + "failpoint.forceSyncSourceCandidate" + ] = self.fixturelib.make_historic( + { "mode": "alwaysOn", - "data": {"hostAndPort": self.nodes[index - 1].get_internal_connection_string()} - }) + "data": { + "hostAndPort": self.nodes[ + index - 1 + ].get_internal_connection_string() + }, + } + ) return mongod_options def get_logger_for_mongod(self, index): @@ -971,14 +1148,16 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface if self.shard_logging_prefix is not None: node_name = f"{self.shard_logging_prefix}:{node_name}" - return self.fixturelib.new_fixture_node_logger("ShardedClusterFixture", self.job_num, - node_name) + return self.fixturelib.new_fixture_node_logger( + "ShardedClusterFixture", self.job_num, node_name + ) if self.replicaset_logging_prefix is not None: node_name = f"{self.replicaset_logging_prefix}:{node_name}" - return self.fixturelib.new_fixture_node_logger(self.__class__.__name__, self.job_num, - node_name) + return self.fixturelib.new_fixture_node_logger( + self.__class__.__name__, self.job_num, node_name + ) def get_internal_connection_string(self): """Return the internal connection string.""" @@ -1010,8 +1189,12 @@ class ReplicaSetFixture(interface.ReplFixture, interface._DockerComposeInterface # anticipate the client will want to gracefully handle any failovers. conn_strs = [node.get_internal_connection_string() for node in self.nodes] if self.initial_sync_node: - conn_strs.append(self.initial_sync_node.get_internal_connection_string()) - return "mongodb://" + ",".join(conn_strs) + "/?replicaSet=" + self.replset_name + conn_strs.append( + self.initial_sync_node.get_internal_connection_string() + ) + return ( + "mongodb://" + ",".join(conn_strs) + "/?replicaSet=" + self.replset_name + ) else: # We return a direct connection to the expected pimary when only the first node is # electable because we want the client to error out if a stepdown occurs. @@ -1034,13 +1217,15 @@ def get_last_optime(client, fixturelib): optime_is_empty = False if isinstance(optime, bson.Timestamp): # PV0 - optime_is_empty = (optime == bson.Timestamp(0, 0)) + optime_is_empty = optime == bson.Timestamp(0, 0) else: # PV1 - optime_is_empty = (optime["ts"] == bson.Timestamp(0, 0) and optime["t"] == -1) + optime_is_empty = optime["ts"] == bson.Timestamp(0, 0) and optime["t"] == -1 if optime_is_empty: raise fixturelib.ServerFailure( "Uninitialized opTime being reported by {addr[0]}:{addr[1]}: {repl_set_status}".format( - addr=client.address, repl_set_status=repl_set_status)) + addr=client.address, repl_set_status=repl_set_status + ) + ) return optime diff --git a/buildscripts/resmokelib/testing/fixtures/shardedcluster.py b/buildscripts/resmokelib/testing/fixtures/shardedcluster.py index 8e5b9a4a3b0..ed4b2b33ded 100644 --- a/buildscripts/resmokelib/testing/fixtures/shardedcluster.py +++ b/buildscripts/resmokelib/testing/fixtures/shardedcluster.py @@ -60,28 +60,39 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface and all routing requests are directed to the routing ports of those nodes. TODO SERVER-86554: Support a mix of shard servers with the routerPort opened and not. """ - interface.Fixture.__init__(self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix) + interface.Fixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix + ) if "dbpath" in mongod_options: raise ValueError("Cannot specify mongod_options.dbpath") self.mongos_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongos_options, {})) + self.fixturelib.default_if_none(mongos_options, {}) + ) # The mongotHost and searchIndexManagementHostAndPort options cannot be set on mongos_options yet because # the port value is only assigned in MongoDFixture initialization, which happens later. self.launch_mongot = launch_mongot # mongod options self.mongod_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongod_options, {})) + self.fixturelib.default_if_none(mongod_options, {}) + ) self.mongod_executable = mongod_executable self.mongod_options["set_parameters"] = self.fixturelib.make_historic( - mongod_options.get("set_parameters", {})).copy() - self.mongod_options["set_parameters"]["migrationLockAcquisitionMaxWaitMS"] = \ - self.mongod_options["set_parameters"].get("migrationLockAcquisitionMaxWaitMS", 30000) + mongod_options.get("set_parameters", {}) + ).copy() + self.mongod_options["set_parameters"]["migrationLockAcquisitionMaxWaitMS"] = ( + self.mongod_options[ + "set_parameters" + ].get("migrationLockAcquisitionMaxWaitMS", 30000) + ) # Extend time for transactions by default to account for slow machines during testing. - self.mongod_options["set_parameters"]["maxTransactionLockRequestTimeoutMillis"] = \ - self.mongod_options["set_parameters"].get("maxTransactionLockRequestTimeoutMillis", 10 * 1000) + self.mongod_options["set_parameters"][ + "maxTransactionLockRequestTimeoutMillis" + ] = self.mongod_options["set_parameters"].get( + "maxTransactionLockRequestTimeoutMillis", 10 * 1000 + ) # Misc other options for the fixture. self.config_shard = config_shard @@ -94,18 +105,24 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface self.embedded_router_mode = embedded_router self.replica_endpoint_mode = replica_set_endpoint self.set_cluster_parameter = set_cluster_parameter - self.has_uninitialized_fcv_initial_sync_nodes_in_shards = has_uninitialized_fcv_initial_sync_nodes_in_shards + self.has_uninitialized_fcv_initial_sync_nodes_in_shards = ( + has_uninitialized_fcv_initial_sync_nodes_in_shards + ) self.inject_catalog_metadata = inject_catalog_metadata # Options for roles - shardsvr, configsvr. self.configsvr_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(configsvr_options, {})) + self.fixturelib.default_if_none(configsvr_options, {}) + ) self.shard_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(shard_options, {})) + self.fixturelib.default_if_none(shard_options, {}) + ) # Logging prefix options. # `cluster_logging_prefix` is the logging prefix used in cluster to cluster replication. - self.cluster_logging_prefix = "" if cluster_logging_prefix is None else f"{cluster_logging_prefix}:" + self.cluster_logging_prefix = ( + "" if cluster_logging_prefix is None else f"{cluster_logging_prefix}:" + ) self.configsvr_shard_logging_prefix = f"{self.cluster_logging_prefix}configsvr" self.rs_shard_logging_prefix = f"{self.cluster_logging_prefix}shard" self.mongos_logging_prefix = f"{self.cluster_logging_prefix}mongos" @@ -125,23 +142,30 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface "random_migrations can only be enabled when balancer is enabled (enable_balancer=True)" ) - if "failpoint.balancerShouldReturnRandomMigrations" in self.mongod_options[ - "set_parameters"]: + if ( + "failpoint.balancerShouldReturnRandomMigrations" + in self.mongod_options["set_parameters"] + ): raise ValueError( "Cannot enable random_migrations because balancerShouldReturnRandomMigrations failpoint is already present in mongod_options" ) # Enable random migrations failpoint self.mongod_options["set_parameters"][ - "failpoint.balancerShouldReturnRandomMigrations"] = {"mode": "alwaysOn"} + "failpoint.balancerShouldReturnRandomMigrations" + ] = {"mode": "alwaysOn"} # Reduce migration throttling to increase frequency of random migrations - self.mongod_options["set_parameters"][ - "balancerMigrationsThrottlingMs"] = self.mongod_options["set_parameters"].get( - "balancerMigrationsThrottlingMs", 250) # millis + self.mongod_options["set_parameters"]["balancerMigrationsThrottlingMs"] = ( + self.mongod_options[ + "set_parameters" + ].get("balancerMigrationsThrottlingMs", 250) + ) # millis - self._dbpath_prefix = os.path.join(dbpath_prefix if dbpath_prefix else self._dbpath_prefix, - self.config.FIXTURE_SUBDIR) + self._dbpath_prefix = os.path.join( + dbpath_prefix if dbpath_prefix else self._dbpath_prefix, + self.config.FIXTURE_SUBDIR, + ) self.configsvr = None self.mongos = [] @@ -160,12 +184,13 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface out.extend(self.configsvr.pids()) else: self.logger.debug( - 'Config server not running when gathering sharded cluster fixture pids.') + "Config server not running when gathering sharded cluster fixture pids." + ) if self.shards is not None: for shard in self.shards: out.extend(shard.pids()) else: - self.logger.debug('No shards when gathering sharded cluster fixture pids.') + self.logger.debug("No shards when gathering sharded cluster fixture pids.") return out def setup(self): @@ -207,8 +232,10 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface except pymongo.errors.OperationFailure as err: if err.code != self._WRITE_CONCERN_FAILED: raise err - self.logger.info("Ignoring write concern timeout for refreshLogicalSessionCacheNow " - "command and continuing to wait") + self.logger.info( + "Ignoring write concern timeout for refreshLogicalSessionCacheNow " + "command and continuing to wait" + ) target.await_last_op_committed(target.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60) def get_shard_ids(self): @@ -230,7 +257,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface # Need to get the new config shard connection string generated from the auto-bootstrap procedure if self.use_auto_bootstrap_procedure and not self.embedded_router_mode: for mongos in self.mongos: - mongos.mongos_options["configdb"] = self.configsvr.get_internal_connection_string() + mongos.mongos_options["configdb"] = ( + self.configsvr.get_internal_connection_string() + ) # We call mongos.setup() in self.await_ready() function instead of self.setup() # because mongos routers have to connect to a running cluster. @@ -239,8 +268,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface # In search enabled sharded cluster, mongos has to be spun up with a connection string to a # mongot in order to issue PlanShardedSearch commands. mongos.mongos_options["mongotHost"] = self.mongotHost - mongos.mongos_options[ - "searchIndexManagementHostAndPort"] = self.searchIndexManagementHostAndPort + mongos.mongos_options["searchIndexManagementHostAndPort"] = ( + self.searchIndexManagementHostAndPort + ) # Start up the mongos. mongos.setup() @@ -276,7 +306,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface if self.inject_catalog_metadata: csrs_client = interface.build_client(self.configsvr, self.auth_options) - sharded_cluster_util.inject_catalog_metadata_on_the_csrs(csrs_client, self.inject_catalog_metadata) + sharded_cluster_util.inject_catalog_metadata_on_the_csrs( + csrs_client, self.inject_catalog_metadata + ) self.is_ready = True @@ -285,7 +317,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface client = interface.build_client(self, self.auth_options) command_request = { "setClusterParameter": { - self.set_cluster_parameter["parameter"]: self.set_cluster_parameter["value"] + self.set_cluster_parameter["parameter"]: self.set_cluster_parameter[ + "value" + ] }, } client.admin.command(command_request) @@ -295,7 +329,8 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface # Running getClusterParameter on a router causes it to refresh its cache. for mongos in self.mongos: mongos.mongo_client().admin.command( - {"getClusterParameter": self.set_cluster_parameter["parameter"]}) + {"getClusterParameter": self.set_cluster_parameter["parameter"]} + ) # TODO SERVER-76343 remove the join_migrations parameter and the if clause depending on it. def stop_balancer(self, timeout_ms=300000, join_migrations=True): @@ -304,7 +339,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface client.admin.command({"balancerStop": 1}, maxTimeMS=timeout_ms) if join_migrations: for shard in self.shards: - shard_client = interface.build_client(shard.get_primary(), self.auth_options) + shard_client = interface.build_client( + shard.get_primary(), self.auth_options + ) shard_client.admin.command({"_shardsvrJoinMigrations": 1}) self.logger.info("Stopped the balancer") @@ -319,7 +356,7 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface csrs_client = interface.build_client(self.configsvr, self.auth_options) try: res = csrs_client.admin.command({"getParameter": 1, full_ff_name: 1}) - return bool(res[full_ff_name]['value']) + return bool(res[full_ff_name]["value"]) except pymongo.errors.OperationFailure as err: if err.code == 72: # InvalidOptions # The feature flag is not present @@ -345,13 +382,17 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface running_at_start = self.is_running() if not running_at_start: - self.logger.warning("All members of the sharded cluster were expected to be running, " - "but weren't.") + self.logger.warning( + "All members of the sharded cluster were expected to be running, " + "but weren't." + ) # If we're killing or aborting to archive data files, stopping the balancer will execute # server commands that might lead to on-disk changes from the point of failure. - if self.enable_balancer and mode not in (interface.TeardownMode.KILL, - interface.TeardownMode.ABORT): + if self.enable_balancer and mode not in ( + interface.TeardownMode.KILL, + interface.TeardownMode.ABORT, + ): self.stop_balancer() teardown_handler = interface.FixtureTeardownHandler(self.logger) @@ -377,17 +418,27 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface def is_running(self): """Return true if all nodes in the cluster are all still operating.""" - return (self.configsvr is not None and self.configsvr.is_running() - and all(shard.is_running() - for shard in self.shards if not shard.removeshard_teardown_marker) - and all(mongos.is_running() for mongos in self.mongos)) + return ( + self.configsvr is not None + and self.configsvr.is_running() + and all( + shard.is_running() + for shard in self.shards + if not shard.removeshard_teardown_marker + ) + and all(mongos.is_running() for mongos in self.mongos) + ) def get_internal_connection_string(self): """Return the internal connection string.""" if self.mongos is None: - raise ValueError("Must call setup() before calling get_internal_connection_string()") + raise ValueError( + "Must call setup() before calling get_internal_connection_string()" + ) - return ",".join([mongos.get_internal_connection_string() for mongos in self.mongos]) + return ",".join( + [mongos.get_internal_connection_string() for mongos in self.mongos] + ) def get_driver_connection_url(self): """Return the driver connection URL.""" @@ -398,16 +449,20 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface # commands. if len(self.shards) == 0: raise ValueError( - "Must call install_rs_shard() before calling get_internal_connection_string()") + "Must call install_rs_shard() before calling get_internal_connection_string()" + ) if len(self.shards) > 1: - raise ValueError("Cannot use replica set endpoint on a multi-shard cluster") + raise ValueError( + "Cannot use replica set endpoint on a multi-shard cluster" + ) return self.shards[0].get_driver_connection_url() if self.embedded_router_mode: # If the embedded router is enabled, we must have a mongos placed in a node acting as a # configsvr. - config_mongos = next((mongos for mongos in self.mongos if mongos.is_from_configsvr()), - None) + config_mongos = next( + (mongos for mongos in self.mongos if mongos.is_from_configsvr()), None + ) if config_mongos: return config_mongos.get_driver_connection_url() else: @@ -428,8 +483,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface def get_configsvr_logger(self): """Return a new logging.Logger instance used for a config server shard.""" - return self.fixturelib.new_fixture_node_logger(self.__class__.__name__, self.job_num, - self.configsvr_shard_logging_prefix) + return self.fixturelib.new_fixture_node_logger( + self.__class__.__name__, self.job_num, self.configsvr_shard_logging_prefix + ) def get_configsvr_kwargs(self): """Return args to create replicaset.ReplicaSetFixture configured as the config server.""" @@ -444,7 +500,8 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface mongod_options = self.mongod_options.copy() mongod_options = self.fixturelib.merge_mongo_option_dicts( mongod_options, - self.fixturelib.make_historic(configsvr_options.pop("mongod_options", {}))) + self.fixturelib.make_historic(configsvr_options.pop("mongod_options", {})), + ) mongod_options["configsvr"] = "" mongod_options["dbpath"] = os.path.join(self._dbpath_prefix, "config") mongod_options["replSet"] = ShardedClusterFixture._CONFIGSVR_REPLSET_NAME @@ -454,10 +511,14 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface mongod_options["routerPort"] = "" return { - "mongod_options": mongod_options, "mongod_executable": self.mongod_executable, - "preserve_dbpath": preserve_dbpath, "num_nodes": num_nodes, - "auth_options": auth_options, "replset_config_options": replset_config_options, - "shard_logging_prefix": self.configsvr_shard_logging_prefix, **configsvr_options + "mongod_options": mongod_options, + "mongod_executable": self.mongod_executable, + "preserve_dbpath": preserve_dbpath, + "num_nodes": num_nodes, + "auth_options": auth_options, + "replset_config_options": replset_config_options, + "shard_logging_prefix": self.configsvr_shard_logging_prefix, + **configsvr_options, } def install_configsvr(self, configsvr): @@ -471,8 +532,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface def get_rs_shard_logger(self, index): """Return a new logging.Logger instance used for a replica set shard.""" shard_logging_prefix = self._get_rs_shard_logging_prefix(index) - return self.fixturelib.new_fixture_node_logger(self.__class__.__name__, self.job_num, - shard_logging_prefix) + return self.fixturelib.new_fixture_node_logger( + self.__class__.__name__, self.job_num, shard_logging_prefix + ) def get_rs_shard_kwargs(self, index): """Return args to create replicaset.ReplicaSetFixture configured as a shard in a sharded cluster.""" @@ -487,9 +549,13 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface mongod_options = self.mongod_options.copy() mongod_options = self.fixturelib.merge_mongo_option_dicts( - mongod_options, self.fixturelib.make_historic(shard_options.pop("mongod_options", {}))) + mongod_options, + self.fixturelib.make_historic(shard_options.pop("mongod_options", {})), + ) mongod_options["shardsvr"] = "" - mongod_options["dbpath"] = os.path.join(self._dbpath_prefix, "shard{}".format(index)) + mongod_options["dbpath"] = os.path.join( + self._dbpath_prefix, "shard{}".format(index) + ) mongod_options["replSet"] = self._SHARD_REPLSET_NAME_PREFIX + str(index) if self.config_shard == index: @@ -502,10 +568,12 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface if "mongod_options" in configsvr_options: mongod_options = self.fixturelib.merge_mongo_option_dicts( - mongod_options, configsvr_options["mongod_options"]) + mongod_options, configsvr_options["mongod_options"] + ) if "replset_config_options" in configsvr_options: replset_config_options = self.fixturelib.merge_mongo_option_dicts( - replset_config_options, configsvr_options["replset_config_options"]) + replset_config_options, configsvr_options["replset_config_options"] + ) for option, value in configsvr_options.items(): if option in ("num_nodes", "mongod_options", "replset_config_options"): @@ -513,22 +581,28 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface if option in shard_options: if shard_options[option] != value: raise Exception( - "Conflicting values when combining shard and configsvr options") + "Conflicting values when combining shard and configsvr options" + ) else: shard_options[option] = value if self.embedded_router_mode: mongod_options["routerPort"] = "" - use_auto_bootstrap_procedure = self.use_auto_bootstrap_procedure and self.config_shard == index + use_auto_bootstrap_procedure = ( + self.use_auto_bootstrap_procedure and self.config_shard == index + ) shard_logging_prefix = self._get_rs_shard_logging_prefix(index) return { - "mongod_options": mongod_options, "mongod_executable": self.mongod_executable, - "auth_options": auth_options, "preserve_dbpath": preserve_dbpath, - "replset_config_options": replset_config_options, "shard_logging_prefix": - shard_logging_prefix, "use_auto_bootstrap_procedure": use_auto_bootstrap_procedure, - **shard_options + "mongod_options": mongod_options, + "mongod_executable": self.mongod_executable, + "auth_options": auth_options, + "preserve_dbpath": preserve_dbpath, + "replset_config_options": replset_config_options, + "shard_logging_prefix": shard_logging_prefix, + "use_auto_bootstrap_procedure": use_auto_bootstrap_procedure, + **shard_options, } def install_rs_shard(self, rs_shard): @@ -537,9 +611,14 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface def get_mongos_logger(self, index, total): """Return a new logging.Logger instance used for a mongos.""" - logger_name = self.mongos_logging_prefix if total == 1 else f"{self.mongos_logging_prefix}{index}" - return self.fixturelib.new_fixture_node_logger(self.__class__.__name__, self.job_num, - logger_name) + logger_name = ( + self.mongos_logging_prefix + if total == 1 + else f"{self.mongos_logging_prefix}{index}" + ) + return self.fixturelib.new_fixture_node_logger( + self.__class__.__name__, self.job_num, logger_name + ) def get_mongos_kwargs(self): """Return options that may be passed to a mongos.""" @@ -548,9 +627,9 @@ class ShardedClusterFixture(interface.Fixture, interface._DockerComposeInterface if self.config_shard is not None: if "set_parameters" not in mongos_options: mongos_options["set_parameters"] = {} - mongos_options["set_parameters"] = mongos_options.get("set_parameters", - self.fixturelib.make_historic( - {})).copy() + mongos_options["set_parameters"] = mongos_options.get( + "set_parameters", self.fixturelib.make_historic({}) + ).copy() return {"dbpath_prefix": self._dbpath_prefix, "mongos_options": mongos_options} def install_mongos(self, mongos): @@ -594,17 +673,22 @@ class ExternalShardedClusterFixture(external.ExternalFixture, ShardedClusterFixt """Initialize ExternalShardedClusterFixture.""" self.dummy_fixture = _builder.make_dummy_fixture(original_suite_name) self.shell_conn_string = "mongodb://" + ",".join( - [f"mongos{i}:27017" for i in range(self.dummy_fixture.num_mongos)]) + [f"mongos{i}:27017" for i in range(self.dummy_fixture.num_mongos)] + ) - external.ExternalFixture.__init__(self, logger, job_num, fixturelib, self.shell_conn_string) - ShardedClusterFixture.__init__(self, logger, job_num, fixturelib, mongod_options={}) + external.ExternalFixture.__init__( + self, logger, job_num, fixturelib, self.shell_conn_string + ) + ShardedClusterFixture.__init__( + self, logger, job_num, fixturelib, mongod_options={} + ) def setup(self): """Execute some setup before offically starting testing against this external cluster.""" client = pymongo.MongoClient(self.get_driver_connection_url()) for i in range(50): if i == 49: - raise RuntimeError('Sharded Cluster setup has timed out.') + raise RuntimeError("Sharded Cluster setup has timed out.") payload = client.admin.command({"listShards": 1}) if len(payload["shards"]) == self.dummy_fixture.num_shards: print("Sharded Cluster available.") @@ -614,7 +698,7 @@ class ExternalShardedClusterFixture(external.ExternalFixture, ShardedClusterFixt time.sleep(5) continue if len(payload["shards"]) > self.dummy_fixture.num_shards: - raise RuntimeError('More shards in cluster than expected.') + raise RuntimeError("More shards in cluster than expected.") def pids(self): """Use ExternalFixture method.""" @@ -663,7 +747,8 @@ class _RouterView(interface.Fixture): self.is_configsvr = is_configsvr if not self.port: raise ValueError( - "Mongod must be started with the --routerPort flag to support a RouterView") + "Mongod must be started with the --routerPort flag to support a RouterView" + ) def pids(self): return self.mongod.pids @@ -692,12 +777,18 @@ class _RouterView(interface.Fixture): if remaining <= 0.0: raise self.fixturelib.ServerFailure( "Failed to connect to embedded router on port {} after {} seconds".format( - self.port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS)) + self.port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS + ) + ) - self.logger.info("Waiting to connect to embedded router on port %d.", self.port) + self.logger.info( + "Waiting to connect to embedded router on port %d.", self.port + ) time.sleep(0.1) # Wait a little bit before trying again. - self.logger.info("Successfully contacted the embedded router on port %d.", self.port) + self.logger.info( + "Successfully contacted the embedded router on port %d.", self.port + ) def is_running(self): """Return true if the cluster is still operating.""" @@ -715,8 +806,16 @@ class _RouterView(interface.Fixture): class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): """Fixture which provides JSTests with a mongos to connect to.""" - def __init__(self, logger, job_num, fixturelib, dbpath_prefix, mongos_executable=None, - mongos_options=None, add_feature_flags=False): + def __init__( + self, + logger, + job_num, + fixturelib, + dbpath_prefix, + mongos_executable=None, + mongos_options=None, + add_feature_flags=False, + ): """Initialize _MongoSFixture.""" interface.Fixture.__init__(self, logger, job_num, fixturelib) @@ -725,11 +824,13 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): self.config = self.fixturelib.get_config() # Default to command line options if the YAML configuration is not passed in. - self.mongos_executable = self.fixturelib.default_if_none(mongos_executable, - self.config.MONGOS_EXECUTABLE) + self.mongos_executable = self.fixturelib.default_if_none( + mongos_executable, self.config.MONGOS_EXECUTABLE + ) self.mongos_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongos_options, {})).copy() + self.fixturelib.default_if_none(mongos_options, {}) + ).copy() if add_feature_flags: for ff in self.config.ENABLED_FEATURE_FLAGS: @@ -745,18 +846,26 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): """Set up the sharded cluster.""" if self.config.ALWAYS_USE_LOG_FILES: self.mongos_options["logpath"] = self._dbpath_prefix + "/{name}.log".format( - name=self.logger.name) + name=self.logger.name + ) self.mongos_options["logappend"] = "" launcher = MongosLauncher(self.fixturelib) - mongos, _ = launcher.launch_mongos_program(self.logger, self.job_num, - executable=self.mongos_executable, - mongos_options=self.mongos_options) + mongos, _ = launcher.launch_mongos_program( + self.logger, + self.job_num, + executable=self.mongos_executable, + mongos_options=self.mongos_options, + ) self.mongos_options["port"] = self.port try: - self.logger.info("Starting mongos on port %d...\n%s", self.port, mongos.as_command()) + self.logger.info( + "Starting mongos on port %d...\n%s", self.port, mongos.as_command() + ) mongos.start() - self.logger.info("mongos started on port %d with pid %d.", self.port, mongos.pid) + self.logger.info( + "mongos started on port %d with pid %d.", self.port, mongos.pid + ) except Exception as err: msg = "Failed to start mongos on port {:d}: {}".format(self.port, err) self.logger.exception(msg) @@ -773,7 +882,7 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): if self.mongos is not None: return [self.mongos.pid] else: - self.logger.debug('Mongos not running when gathering mongos fixture pids.') + self.logger.debug("Mongos not running when gathering mongos fixture pids.") return [] def is_from_configsvr(self): @@ -793,7 +902,8 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): if exit_code is not None: raise self.fixturelib.ServerFailure( "Could not connect to mongos on port {}, process ended" - " unexpectedly with code {}.".format(self.port, exit_code)) + " unexpectedly with code {}.".format(self.port, exit_code) + ) try: # Use a shorter connection timeout to more closely satisfy the requested deadline. @@ -805,7 +915,9 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): if remaining <= 0.0: raise self.fixturelib.ServerFailure( "Failed to connect to mongos on port {} after {} seconds".format( - self.port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS)) + self.port, interface.Fixture.AWAIT_READY_TIMEOUT_SECS + ) + ) self.logger.info("Waiting to connect to mongos on port %d.", self.port) time.sleep(0.1) # Wait a little bit before trying again. @@ -826,14 +938,19 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): if mode == interface.TeardownMode.ABORT: self.logger.info( "Attempting to send SIGABRT from resmoke to mongos on port %d with pid %d...", - self.port, self.mongos.pid) + self.port, + self.mongos.pid, + ) else: - self.logger.info("Stopping mongos on port %d with pid %d...", self.port, - self.mongos.pid) + self.logger.info( + "Stopping mongos on port %d with pid %d...", self.port, self.mongos.pid + ) if not self.is_running(): exit_code = self.mongos.poll() - msg = ("mongos on port {:d} was expected to be running, but wasn't. " - "Process exited with code {:d}").format(self.port, exit_code) + msg = ( + "mongos on port {:d} was expected to be running, but wasn't. " + "Process exited with code {:d}" + ).format(self.port, exit_code) self.logger.warning(msg) raise self.fixturelib.ServerFailure(msg) @@ -842,20 +959,30 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): # Python's subprocess module returns negative versions of system calls. if exit_code == 0 or (mode is not None and exit_code == -(mode.value)): - self.logger.info("Successfully stopped the mongos on port {:d}".format(self.port)) + self.logger.info( + "Successfully stopped the mongos on port {:d}".format(self.port) + ) else: - self.logger.warning("Stopped the mongos on port {:d}. " - "Process exited with code {:d}.".format(self.port, exit_code)) + self.logger.warning( + "Stopped the mongos on port {:d}. " + "Process exited with code {:d}.".format(self.port, exit_code) + ) raise self.fixturelib.ServerFailure( "mongos on port {:d} with pid {:d} exited with code {:d}".format( - self.port, self.mongos.pid, exit_code)) + self.port, self.mongos.pid, exit_code + ) + ) def is_running(self): """Return true if the cluster is still operating.""" return self.mongos is not None and self.mongos.poll() is None def _get_hostname(self): - return self.logger.external_sut_hostname if self.config.NOOP_MONGO_D_S_PROCESSES else 'localhost' + return ( + self.logger.external_sut_hostname + if self.config.NOOP_MONGO_D_S_PROCESSES + else "localhost" + ) def get_internal_connection_string(self): """Return the internal connection string.""" @@ -875,8 +1002,13 @@ class _MongoSFixture(interface.Fixture, interface._DockerComposeInterface): self.logger.warning("The mongos fixture has not been set up yet.") return [] - info = interface.NodeInfo(full_name=self.logger.full_name, name=self.logger.name, - port=self.port, pid=self.mongos.pid, router_port=None) + info = interface.NodeInfo( + full_name=self.logger.full_name, + name=self.logger.name, + port=self.port, + pid=self.mongos.pid, + router_port=None, + ) return [info] @@ -906,39 +1038,47 @@ class MongosLauncher(object): return DEFAULT_EVERGREEN_MONGOS_LOG_COMPONENT_VERBOSITY return DEFAULT_MONGOS_LOG_COMPONENT_VERBOSITY - def launch_mongos_program(self, logger, job_num, executable=None, process_kwargs=None, - mongos_options=None): + def launch_mongos_program( + self, logger, job_num, executable=None, process_kwargs=None, mongos_options=None + ): """Return a Process instance that starts a mongos with arguments constructed from 'kwargs'.""" - executable = self.fixturelib.default_if_none(executable, - self.config.DEFAULT_MONGOS_EXECUTABLE) + executable = self.fixturelib.default_if_none( + executable, self.config.DEFAULT_MONGOS_EXECUTABLE + ) # Apply the --setParameter command line argument. Command line options to resmoke.py override # the YAML configuration. suite_set_parameters = mongos_options.setdefault("set_parameters", {}) if self.config.MONGOS_SET_PARAMETERS is not None: - suite_set_parameters.update(yaml.safe_load(self.config.MONGOS_SET_PARAMETERS)) + suite_set_parameters.update( + yaml.safe_load(self.config.MONGOS_SET_PARAMETERS) + ) if "mongotHost" in mongos_options: suite_set_parameters["mongotHost"] = mongos_options.pop("mongotHost") - suite_set_parameters["searchIndexManagementHostAndPort"] = mongos_options.pop( - "searchIndexManagementHostAndPort") + suite_set_parameters["searchIndexManagementHostAndPort"] = ( + mongos_options.pop("searchIndexManagementHostAndPort") + ) # Set default log verbosity levels if none were specified. if "logComponentVerbosity" not in suite_set_parameters: - suite_set_parameters[ - "logComponentVerbosity"] = self.default_mongos_log_component_verbosity() + suite_set_parameters["logComponentVerbosity"] = ( + self.default_mongos_log_component_verbosity() + ) # Set default shutdown timeout millis if none was specified. if "mongosShutdownTimeoutMillisForSignaledShutdown" not in suite_set_parameters: - suite_set_parameters[ - "mongosShutdownTimeoutMillisForSignaledShutdown"] = DEFAULT_MONGOS_SHUTDOWN_TIMEOUT_MILLIS + suite_set_parameters["mongosShutdownTimeoutMillisForSignaledShutdown"] = ( + DEFAULT_MONGOS_SHUTDOWN_TIMEOUT_MILLIS + ) _add_testing_set_parameters(suite_set_parameters) - return self.fixturelib.mongos_program(logger, job_num, executable, process_kwargs, - mongos_options) + return self.fixturelib.mongos_program( + logger, job_num, executable, process_kwargs, mongos_options + ) def _add_testing_set_parameters(suite_set_parameters): @@ -949,4 +1089,6 @@ def _add_testing_set_parameters(suite_set_parameters): """ suite_set_parameters.setdefault("testingDiagnosticsEnabled", True) suite_set_parameters.setdefault("enableTestCommands", True) - suite_set_parameters.setdefault("disableTransitionFromLatestToLastContinuous", False) + suite_set_parameters.setdefault( + "disableTransitionFromLatestToLastContinuous", False + ) diff --git a/buildscripts/resmokelib/testing/fixtures/standalone.py b/buildscripts/resmokelib/testing/fixtures/standalone.py index d1d05da667c..7c92b192093 100644 --- a/buildscripts/resmokelib/testing/fixtures/standalone.py +++ b/buildscripts/resmokelib/testing/fixtures/standalone.py @@ -16,13 +16,26 @@ from buildscripts.resmokelib.testing.fixtures import interface class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): """Fixture which provides JSTests with a standalone mongod to run against.""" - def __init__(self, logger, job_num, fixturelib, mongod_executable=None, mongod_options=None, - add_feature_flags=False, dbpath_prefix=None, preserve_dbpath=False, port=None, - launch_mongot=False): + def __init__( + self, + logger, + job_num, + fixturelib, + mongod_executable=None, + mongod_options=None, + add_feature_flags=False, + dbpath_prefix=None, + preserve_dbpath=False, + port=None, + launch_mongot=False, + ): """Initialize MongoDFixture with different options for the mongod process.""" - interface.Fixture.__init__(self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix) + interface.Fixture.__init__( + self, logger, job_num, fixturelib, dbpath_prefix=dbpath_prefix + ) self.mongod_options = self.fixturelib.make_historic( - self.fixturelib.default_if_none(mongod_options, {})) + self.fixturelib.default_if_none(mongod_options, {}) + ) if "set_parameters" not in self.mongod_options: self.mongod_options["set_parameters"] = {} @@ -32,17 +45,21 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): self.mongod_options["set_parameters"][ff] = "true" if "dbpath" in self.mongod_options and dbpath_prefix is not None: - raise ValueError("Cannot specify both mongod_options.dbpath and dbpath_prefix") + raise ValueError( + "Cannot specify both mongod_options.dbpath and dbpath_prefix" + ) # Default to command line options if the YAML configuration is not passed in. - self.mongod_executable = self.fixturelib.default_if_none(mongod_executable, - self.config.MONGOD_EXECUTABLE) + self.mongod_executable = self.fixturelib.default_if_none( + mongod_executable, self.config.MONGOD_EXECUTABLE + ) # The dbpath in mongod_options takes precedence over other settings to make it easier for # users to specify a dbpath containing data to test against. if "dbpath" not in self.mongod_options: - self.mongod_options["dbpath"] = os.path.join(self._dbpath_prefix, - self.config.FIXTURE_SUBDIR) + self.mongod_options["dbpath"] = os.path.join( + self._dbpath_prefix, self.config.FIXTURE_SUBDIR + ) self._dbpath = self.mongod_options["dbpath"] if self.config.ALWAYS_USE_LOG_FILES: @@ -61,8 +78,9 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): self.mongot_port = fixturelib.get_next_port(job_num) self.mongod_options["mongotHost"] = "localhost:" + str(self.mongot_port) # In future architectures, this could change - self.mongod_options["searchIndexManagementHostAndPort"] = self.mongod_options[ - "mongotHost"] + self.mongod_options["searchIndexManagementHostAndPort"] = ( + self.mongod_options["mongotHost"] + ) else: self.launch_mongot = False # If a suite enables launching mongot, the MongoTFixture will be created in setup_mongot, @@ -75,9 +93,12 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): mongod_options["routerPort"] = self.router_port # Always log backtraces to a file in the dbpath in our testing. - backtrace_log_file_name = os.path.join(self.get_dbpath_prefix(), - uuid.uuid4().hex + ".stacktrace") - self.mongod_options["set_parameters"]["backtraceLogFile"] = backtrace_log_file_name + backtrace_log_file_name = os.path.join( + self.get_dbpath_prefix(), uuid.uuid4().hex + ".stacktrace" + ) + self.mongod_options["set_parameters"]["backtraceLogFile"] = ( + backtrace_log_file_name + ) def setup(self): """Set up the mongod.""" @@ -90,9 +111,12 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): # Second return val is the port, which we ignore because we explicitly created the port above. # The port is used to set other mongod_option's here: # https://github.com/mongodb/mongo/blob/532a6a8ae7b8e7ab5939e900759c00794862963d/buildscripts/resmokelib/testing/fixtures/replicaset.py#L136 - mongod, _ = launcher.launch_mongod_program(self.logger, self.job_num, - executable=self.mongod_executable, - mongod_options=self.mongod_options) + mongod, _ = launcher.launch_mongod_program( + self.logger, + self.job_num, + executable=self.mongod_executable, + mongod_options=self.mongod_options, + ) try: msg = f"Starting mongod on port { self.port }{(' with embedded router on port ' + str(self.router_port)) if self.router_port else ''}...\n{ mongod.as_command() }" @@ -115,7 +139,9 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): """:return: pids owned by this fixture if any.""" out = [x.pid for x in [self.mongod] if x is not None] if not out: - self.logger.debug('Mongod not running when gathering standalone fixture pid.') + self.logger.debug( + "Mongod not running when gathering standalone fixture pid." + ) return out def _handle_await_ready_retry(self, deadline): @@ -123,7 +149,9 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): if remaining <= 0.0: raise self.fixturelib.ServerFailure( "Failed to connect to mongod on port {} after {} seconds".format( - self.port, MongoDFixture.AWAIT_READY_TIMEOUT_SECS)) + self.port, MongoDFixture.AWAIT_READY_TIMEOUT_SECS + ) + ) self.logger.info("Waiting to connect to mongod on port %d.", self.port) time.sleep(0.1) # Wait a little bit before trying again. @@ -134,12 +162,15 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): mongot_options["port"] = self.mongot_port if "keyFile" not in self.mongod_options: - raise self.fixturelib.ServerFailure("Cannot launch mongot without providing a keyfile") + raise self.fixturelib.ServerFailure( + "Cannot launch mongot without providing a keyfile" + ) mongot_options["keyFile"] = self.mongod_options["keyFile"] - mongot = self.fixturelib.make_fixture("MongoTFixture", self.logger, self.job_num, - mongot_options=mongot_options) + mongot = self.fixturelib.make_fixture( + "MongoTFixture", self.logger, self.job_num, mongot_options=mongot_options + ) mongot.setup() self.mongot = mongot @@ -158,7 +189,8 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): if exit_code is not None: raise self.fixturelib.ServerFailure( "Could not connect to mongod on port {}, process ended" - " unexpectedly with code {}.".format(self.port, exit_code)) + " unexpectedly with code {}.".format(self.port, exit_code) + ) try: # Use a shorter connection timeout to more closely satisfy the requested deadline. @@ -188,14 +220,19 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): if mode == interface.TeardownMode.ABORT: self.logger.info( "Attempting to send SIGABRT from resmoke to mongod on port %d with pid %d...", - self.port, self.mongod.pid) + self.port, + self.mongod.pid, + ) else: - self.logger.info("Stopping mongod on port %d with pid %d...", self.port, - self.mongod.pid) + self.logger.info( + "Stopping mongod on port %d with pid %d...", self.port, self.mongod.pid + ) if not self.is_running(): exit_code = self.mongod.poll() - msg = ("mongod on port {:d} was expected to be running, but wasn't. " - "Process exited with code {:d}.").format(self.port, exit_code) + msg = ( + "mongod on port {:d} was expected to be running, but wasn't. " + "Process exited with code {:d}." + ).format(self.port, exit_code) self.logger.warning(msg) raise self.fixturelib.ServerFailure(msg) @@ -207,13 +244,19 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): # Python's subprocess module returns negative versions of system calls. if exit_code == 0 or (mode is not None and exit_code == -(mode.value)): - self.logger.info("Successfully stopped the mongod on port {:d}.".format(self.port)) + self.logger.info( + "Successfully stopped the mongod on port {:d}.".format(self.port) + ) else: - self.logger.warning("Stopped the mongod on port {:d}. " - "Process exited with code {:d}.".format(self.port, exit_code)) + self.logger.warning( + "Stopped the mongod on port {:d}. " + "Process exited with code {:d}.".format(self.port, exit_code) + ) raise self.fixturelib.ServerFailure( "mongod on port {:d} with pid {:d} exited with code {:d}".format( - self.port, self.mongod.pid, exit_code)) + self.port, self.mongod.pid, exit_code + ) + ) def is_running(self): """Return true if the mongod is still operating.""" @@ -229,12 +272,21 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): self.logger.warning("The mongod fixture has not been set up yet.") return [] - info = interface.NodeInfo(full_name=self.logger.full_name, name=self.logger.name, - port=self.port, pid=self.mongod.pid, router_port=self.router_port) + info = interface.NodeInfo( + full_name=self.logger.full_name, + name=self.logger.name, + port=self.port, + pid=self.mongod.pid, + router_port=self.router_port, + ) return [info] def _get_hostname(self): - return self.logger.external_sut_hostname if self.config.NOOP_MONGO_D_S_PROCESSES else 'localhost' + return ( + self.logger.external_sut_hostname + if self.config.NOOP_MONGO_D_S_PROCESSES + else "localhost" + ) def get_internal_connection_string(self): """Return the internal connection string.""" @@ -246,7 +298,11 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): def get_driver_connection_url(self): """Return the driver connection URL.""" - return "mongodb://" + self.get_internal_connection_string() + "/?directConnection=true" + return ( + "mongodb://" + + self.get_internal_connection_string() + + "/?directConnection=true" + ) # The below parameters define the default 'logComponentVerbosity' object passed to mongod processes @@ -258,16 +314,20 @@ class MongoDFixture(interface.Fixture, interface._DockerComposeInterface): # The default verbosity setting for any tests that are not started with an Evergreen task id. This # will apply to any tests run locally. DEFAULT_MONGOD_LOG_COMPONENT_VERBOSITY = { - "replication": {"rollback": 2}, "sharding": {"migration": 2, "rangeDeleter": 2}, - "transaction": 4, "tenantMigration": 4 + "replication": {"rollback": 2}, + "sharding": {"migration": 2, "rangeDeleter": 2}, + "transaction": 4, + "tenantMigration": 4, } # The default verbosity setting for any mongod processes running in Evergreen i.e. started with an # Evergreen task id. DEFAULT_EVERGREEN_MONGOD_LOG_COMPONENT_VERBOSITY = { - "replication": {"election": 4, "heartbeats": 2, "initialSync": 2, - "rollback": 2}, "sharding": {"migration": 2, "rangeDeleter": 2}, - "storage": {"recovery": 2}, "transaction": 4, "tenantMigration": 4 + "replication": {"election": 4, "heartbeats": 2, "initialSync": 2, "rollback": 2}, + "sharding": {"migration": 2, "rangeDeleter": 2}, + "storage": {"recovery": 2}, + "transaction": 4, + "tenantMigration": 4, } @@ -279,8 +339,9 @@ class MongodLauncher(object): self.fixturelib = fixturelib self.config = fixturelib.get_config() - def launch_mongod_program(self, logger, job_num, executable=None, process_kwargs=None, - mongod_options=None): + def launch_mongod_program( + self, logger, job_num, executable=None, process_kwargs=None, mongod_options=None + ): """ Return a Process instance that starts mongod arguments constructed from 'mongod_options'. @@ -289,8 +350,9 @@ class MongodLauncher(object): @param process_kwargs - A dict of key-value pairs to pass to the process. @param mongod_options - A HistoryDict describing the various options to pass to the mongod. """ - executable = self.fixturelib.default_if_none(executable, - self.config.DEFAULT_MONGOD_EXECUTABLE) + executable = self.fixturelib.default_if_none( + executable, self.config.DEFAULT_MONGOD_EXECUTABLE + ) mongod_options = self.fixturelib.default_if_none(mongod_options, {}).copy() # Apply the --setParameter command line argument. Command line options to resmoke.py override @@ -299,33 +361,41 @@ class MongodLauncher(object): suite_set_parameters = mongod_options.setdefault("set_parameters", {}) if self.config.MONGOD_SET_PARAMETERS is not None: - suite_set_parameters.update(yaml.safe_load(self.config.MONGOD_SET_PARAMETERS)) + suite_set_parameters.update( + yaml.safe_load(self.config.MONGOD_SET_PARAMETERS) + ) if "mongotHost" in mongod_options: suite_set_parameters["mongotHost"] = mongod_options.pop("mongotHost") - suite_set_parameters["searchIndexManagementHostAndPort"] = mongod_options.pop( - "searchIndexManagementHostAndPort") + suite_set_parameters["searchIndexManagementHostAndPort"] = ( + mongod_options.pop("searchIndexManagementHostAndPort") + ) # Some storage options are both a mongod option (as in config file option and its equivalent # "--xyz" command line parameter) and a "--setParameter". In case of conflict, for instance # due to the config fuzzer adding "xyz" as a "--setParameter" when the "--xyz" option is # already defined in the suite's YAML, the "--setParameter" value shall be preserved and the # "--xyz" option discarded to avoid hitting an error due to conflicting definitions. - mongod_option_and_set_parameter_conflicts = ["syncdelay", "journalCommitInterval"] + mongod_option_and_set_parameter_conflicts = [ + "syncdelay", + "journalCommitInterval", + ] for key in mongod_option_and_set_parameter_conflicts: - if (key in mongod_options and key in suite_set_parameters): + if key in mongod_options and key in suite_set_parameters: del mongod_options[key] # Set default log verbosity levels if none were specified. if "logComponentVerbosity" not in suite_set_parameters: - suite_set_parameters[ - "logComponentVerbosity"] = self.get_default_log_component_verbosity_for_mongod() + suite_set_parameters["logComponentVerbosity"] = ( + self.get_default_log_component_verbosity_for_mongod() + ) # orphanCleanupDelaySecs controls an artificial delay before cleaning up an orphaned chunk # that has migrated off of a shard, meant to allow most dependent queries on secondaries to # complete first. It defaults to 900, or 15 minutes, which is prohibitively long for tests. # Setting it in the .yml file overrides this. - if (("shardsvr" in mongod_options or "configsvr" in mongod_options) - and "orphanCleanupDelaySecs" not in suite_set_parameters): + if ( + "shardsvr" in mongod_options or "configsvr" in mongod_options + ) and "orphanCleanupDelaySecs" not in suite_set_parameters: suite_set_parameters["orphanCleanupDelaySecs"] = 1 # receiveChunkWaitForRangeDeleterTimeoutMS controls the amount of time an incoming migration @@ -333,8 +403,9 @@ class MongodLauncher(object): # default is 10 seconds, but in some slower variants this is not enough time for the range # deleter to finish so we increase it here to 90 seconds. Setting a value for this parameter # in the .yml file overrides this. - if (("shardsvr" in mongod_options or "configsvr" in mongod_options) - and "receiveChunkWaitForRangeDeleterTimeoutMS" not in suite_set_parameters): + if ( + "shardsvr" in mongod_options or "configsvr" in mongod_options + ) and "receiveChunkWaitForRangeDeleterTimeoutMS" not in suite_set_parameters: suite_set_parameters["receiveChunkWaitForRangeDeleterTimeoutMS"] = 90000 # The LogicalSessionCache does automatic background refreshes in the server. This is @@ -346,8 +417,13 @@ class MongodLauncher(object): # Set coordinateCommitReturnImmediatelyAfterPersistingDecision to false so that tests do # not need to rely on causal consistency or explicitly wait for the transaction to finish # committing. - if "coordinateCommitReturnImmediatelyAfterPersistingDecision" not in suite_set_parameters: - suite_set_parameters["coordinateCommitReturnImmediatelyAfterPersistingDecision"] = False + if ( + "coordinateCommitReturnImmediatelyAfterPersistingDecision" + not in suite_set_parameters + ): + suite_set_parameters[ + "coordinateCommitReturnImmediatelyAfterPersistingDecision" + ] = False # There's a periodic background thread that checks for and aborts expired transactions. # "transactionLifetimeLimitSeconds" specifies for how long a transaction can run before expiring @@ -368,23 +444,33 @@ class MongodLauncher(object): # the potential to mask issues such as SERVER-31609 because it allows the operationTime of # cluster to advance even if the client is blocked for other reasons. We should disable the # periodic no-op writer. Set in the .yml file to override this. - if "replSet" in mongod_options and "writePeriodicNoops" not in suite_set_parameters: + if ( + "replSet" in mongod_options + and "writePeriodicNoops" not in suite_set_parameters + ): suite_set_parameters["writePeriodicNoops"] = False # The default time for stepdown and quiesce mode in response to SIGTERM is 15 seconds. Reduce # this to 100ms for faster shutdown. On branches 4.4 and earlier, there is no quiesce mode, but # the default time for stepdown is 10 seconds. - if (("replSet" in mongod_options or "serverless" in mongod_options) - and "shutdownTimeoutMillisForSignaledShutdown" not in suite_set_parameters): + if ( + "replSet" in mongod_options or "serverless" in mongod_options + ) and "shutdownTimeoutMillisForSignaledShutdown" not in suite_set_parameters: suite_set_parameters["shutdownTimeoutMillisForSignaledShutdown"] = 100 - if "enableFlowControl" not in suite_set_parameters and self.config.FLOW_CONTROL is not None: - suite_set_parameters["enableFlowControl"] = (self.config.FLOW_CONTROL == "on") + if ( + "enableFlowControl" not in suite_set_parameters + and self.config.FLOW_CONTROL is not None + ): + suite_set_parameters["enableFlowControl"] = self.config.FLOW_CONTROL == "on" - if ("failpoint.flowControlTicketOverride" not in suite_set_parameters - and self.config.FLOW_CONTROL_TICKETS is not None): + if ( + "failpoint.flowControlTicketOverride" not in suite_set_parameters + and self.config.FLOW_CONTROL_TICKETS is not None + ): suite_set_parameters["failpoint.flowControlTicketOverride"] = { - "mode": "alwaysOn", "data": {"numTickets": self.config.FLOW_CONTROL_TICKETS} + "mode": "alwaysOn", + "data": {"numTickets": self.config.FLOW_CONTROL_TICKETS}, } _add_testing_set_parameters(suite_set_parameters) @@ -399,20 +485,31 @@ class MongodLauncher(object): if self.config.STORAGE_ENGINE == "inMemory": shortcut_opts["inMemorySizeGB"] = self.config.STORAGE_ENGINE_CACHE_SIZE - elif self.config.STORAGE_ENGINE == "wiredTiger" or self.config.STORAGE_ENGINE is None: - shortcut_opts["wiredTigerCacheSizeGB"] = self.config.STORAGE_ENGINE_CACHE_SIZE - shortcut_opts["wiredTigerCacheSizePct"] = self.config.STORAGE_ENGINE_CACHE_SIZE_PCT + elif ( + self.config.STORAGE_ENGINE == "wiredTiger" + or self.config.STORAGE_ENGINE is None + ): + shortcut_opts["wiredTigerCacheSizeGB"] = ( + self.config.STORAGE_ENGINE_CACHE_SIZE + ) + shortcut_opts["wiredTigerCacheSizePct"] = ( + self.config.STORAGE_ENGINE_CACHE_SIZE_PCT + ) # These options are just flags, so they should not take a value. - opts_without_vals = ("logappend") + opts_without_vals = "logappend" # Ensure that config servers run with journaling enabled. if "configsvr" in mongod_options: - suite_set_parameters.setdefault("reshardingMinimumOperationDurationMillis", 5000) - suite_set_parameters.setdefault("reshardingCriticalSectionTimeoutMillis", - 24 * 60 * 60 * 1000) # 24 hours suite_set_parameters.setdefault( - "reshardingDelayBeforeRemainingOperationTimeQueryMillis", 1) + "reshardingMinimumOperationDurationMillis", 5000 + ) + suite_set_parameters.setdefault( + "reshardingCriticalSectionTimeoutMillis", 24 * 60 * 60 * 1000 + ) # 24 hours + suite_set_parameters.setdefault( + "reshardingDelayBeforeRemainingOperationTimeQueryMillis", 1 + ) # Command line options override the YAML configuration. for opt_name in shortcut_opts: @@ -433,8 +530,9 @@ class MongodLauncher(object): if "configsvr" in mongod_options: mongod_options["storageEngine"] = "wiredTiger" - return self.fixturelib.mongod_program(logger, job_num, executable, process_kwargs, - mongod_options) + return self.fixturelib.mongod_program( + logger, job_num, executable, process_kwargs, mongod_options + ) def get_default_log_component_verbosity_for_mongod(self): """Return the default 'logComponentVerbosity' value to use for mongod processes.""" @@ -455,5 +553,9 @@ def _add_testing_set_parameters(suite_set_parameters): # Set it to true for now as a placeholder that will error if no further processing is done. # The placeholder is needed so older versions don't have this option won't have this value set. suite_set_parameters.setdefault("backtraceLogFile", True) - suite_set_parameters.setdefault("disableTransitionFromLatestToLastContinuous", False) - suite_set_parameters.setdefault("oplogApplicationEnforcesSteadyStateConstraints", True) + suite_set_parameters.setdefault( + "disableTransitionFromLatestToLastContinuous", False + ) + suite_set_parameters.setdefault( + "oplogApplicationEnforcesSteadyStateConstraints", True + ) diff --git a/buildscripts/resmokelib/testing/fixtures/yesfixture.py b/buildscripts/resmokelib/testing/fixtures/yesfixture.py index 2570aeb965f..12cff476d30 100644 --- a/buildscripts/resmokelib/testing/fixtures/yesfixture.py +++ b/buildscripts/resmokelib/testing/fixtures/yesfixture.py @@ -8,7 +8,9 @@ from buildscripts.resmokelib.testing.fixtures import interface class YesFixture(interface.Fixture): # pylint: disable=abstract-method """Fixture which spawns several 'yes' executables to generate lots of log messages.""" - def __init__(self, logger, job_num, fixturelib, num_instances=1, message_length=100): + def __init__( + self, logger, job_num, fixturelib, num_instances=1, message_length=100 + ): """Initialize YesFixture.""" interface.Fixture.__init__(self, logger, job_num, fixturelib) @@ -21,7 +23,7 @@ class YesFixture(interface.Fixture): # pylint: disable=abstract-method def setup(self): """Start the yes processes.""" - for (i, process) in enumerate(self.__processes): + for i, process in enumerate(self.__processes): process = self._make_process(i) self.logger.info("Starting yes process...\n%s", process.as_command()) @@ -31,8 +33,9 @@ class YesFixture(interface.Fixture): # pylint: disable=abstract-method self.__processes[i] = process def _make_process(self, index): - logger = self.fixturelib.new_fixture_node_logger(self.__class__.__name__, self.job_num, - "yes{:d}".format(index)) + logger = self.fixturelib.new_fixture_node_logger( + self.__class__.__name__, self.job_num, "yes{:d}".format(index) + ) return self.fixturelib.generic_program(logger, ["yes", self.__message]) def _do_teardown(self, mode=None): @@ -41,7 +44,8 @@ class YesFixture(interface.Fixture): # pylint: disable=abstract-method if not running_at_start: self.logger.info( - "yes processes were expected to be running in _do_teardown(), but weren't.") + "yes processes were expected to be running in _do_teardown(), but weren't." + ) else: self.logger.info("Stopping all yes processes...") @@ -57,7 +61,10 @@ class YesFixture(interface.Fixture): # pylint: disable=abstract-method if running_at_start: self.logger.info( "Successfully terminated the yes process with pid %d, exited with code" - " %d.", process.pid, exit_code) + " %d.", + process.pid, + exit_code, + ) if running_at_start: self.logger.info("Successfully stopped all yes processes.") @@ -66,4 +73,7 @@ class YesFixture(interface.Fixture): # pylint: disable=abstract-method def is_running(self): """Return true if the yes processes are running.""" - return all(process is not None and process.poll() is None for process in self.__processes) + return all( + process is not None and process.poll() is None + for process in self.__processes + ) diff --git a/buildscripts/resmokelib/testing/hook_test_archival.py b/buildscripts/resmokelib/testing/hook_test_archival.py index 3ef83b35e49..e8ef36b69bc 100644 --- a/buildscripts/resmokelib/testing/hook_test_archival.py +++ b/buildscripts/resmokelib/testing/hook_test_archival.py @@ -23,7 +23,7 @@ TRACER = trace.get_tracer("resmoke") class HookTestArchival(object): """Archive hooks and tests to S3.""" - def __init__(self, suite: Suite, hooks, archive_instance, archive_config): #pylint: disable=unused-argument + def __init__(self, suite: Suite, hooks, archive_instance, archive_config): # pylint: disable=unused-argument """Initialize HookTestArchival.""" self.archive_instance = archive_instance archive_config = utils.default_if_none(archive_config, {}) @@ -50,10 +50,10 @@ class HookTestArchival(object): self._lock = threading.Lock() def archive( - self, - logger: logging.Logger, - result: 'TestResult', - manager: 'FixtureTestCaseManager', + self, + logger: logging.Logger, + result: "TestResult", + manager: "FixtureTestCaseManager", ): """ Archive data files for hooks or tests. @@ -70,7 +70,9 @@ class HookTestArchival(object): return if result.hook and result.hook.REGISTERED_NAME in self.hooks: - test_name = "{}:{}".format(result.test.short_name(), result.hook.REGISTERED_NAME) + test_name = "{}:{}".format( + result.test.short_name(), result.hook.REGISTERED_NAME + ) should_archive = True else: test_name = result.test.test_name @@ -89,20 +91,24 @@ class HookTestArchival(object): @TRACER.start_as_current_span("hook_test_archival._archive_hook_or_test") def _archive_hook_or_test( - self, - logger: logging.Logger, - test_name: str, - test: TestCase, - manager: 'FixtureTestCaseManager', + self, + logger: logging.Logger, + test_name: str, + test: TestCase, + manager: "FixtureTestCaseManager", ): """Trigger archive of data files for a test or hook.""" archive_hook_or_test_span = trace.get_current_span() - archive_hook_or_test_span.set_attributes(attributes=test.get_test_otel_attributes()) + archive_hook_or_test_span.set_attributes( + attributes=test.get_test_otel_attributes() + ) # We can still attempt archiving even if the teardown fails. if not manager.teardown_fixture(logger, abort=True): - logger.warning("Error while aborting test fixtures; data files may be invalid.") + logger.warning( + "Error while aborting test fixtures; data files may be invalid." + ) with self._lock: # Test repeat number is how many times the particular test has been archived. if test_name not in self._tests_repeat: @@ -110,30 +116,47 @@ class HookTestArchival(object): else: self._tests_repeat[test_name] += 1 # Normalize test path from a test or hook name. - test_path = \ - test_name.replace("/", "_").replace("\\", "_").replace(".", "_").replace(":", "_") - file_name = "mongo-data-{}-{}-{}-{}.tgz".format(config.EVERGREEN_TASK_ID, test_path, - config.EVERGREEN_EXECUTION, - self._tests_repeat[test_name]) + test_path = ( + test_name.replace("/", "_") + .replace("\\", "_") + .replace(".", "_") + .replace(":", "_") + ) + file_name = "mongo-data-{}-{}-{}-{}.tgz".format( + config.EVERGREEN_TASK_ID, + test_path, + config.EVERGREEN_EXECUTION, + self._tests_repeat[test_name], + ) # Retrieve root directory for all dbPaths from fixture. input_files = test.fixture.get_path_for_archival() s3_bucket = config.ARCHIVE_BUCKET - s3_path = "{}/{}/{}/datafiles/{}".format(config.EVERGREEN_PROJECT_NAME, - config.EVERGREEN_VARIANT_NAME, - config.EVERGREEN_REVISION, file_name) + s3_path = "{}/{}/{}/datafiles/{}".format( + config.EVERGREEN_PROJECT_NAME, + config.EVERGREEN_VARIANT_NAME, + config.EVERGREEN_REVISION, + file_name, + ) display_name = "Data files {} - Execution {} Repetition {}".format( - test_name, config.EVERGREEN_EXECUTION, self._tests_repeat[test_name]) + test_name, config.EVERGREEN_EXECUTION, self._tests_repeat[test_name] + ) logger.info("Archiving data files for test %s from %s", test_name, input_files) - status, message = self.archive_instance.archive_files_to_s3(display_name, input_files, - s3_bucket, s3_path) + status, message = self.archive_instance.archive_files_to_s3( + display_name, input_files, s3_bucket, s3_path + ) if status: logger.warning("Archive failed for %s: %s", test_name, message) else: logger.info("Archive succeeded for %s: %s", test_name, message) if HANG_ANALYZER_CALLED.is_set(): - logger.info("Hang Analyzer has been called. Fixtures will not be restarted.") + logger.info( + "Hang Analyzer has been called. Fixtures will not be restarted." + ) raise errors.StopExecution( - "Hang analyzer has been called. Stopping further execution of tests.") + "Hang analyzer has been called. Stopping further execution of tests." + ) elif not manager.setup_fixture(logger): - raise errors.StopExecution("Error while restarting test fixtures after archiving.") + raise errors.StopExecution( + "Error while restarting test fixtures after archiving." + ) diff --git a/buildscripts/resmokelib/testing/hooks/add_remove_shards.py b/buildscripts/resmokelib/testing/hooks/add_remove_shards.py index b1b30786d09..7cde7797f92 100644 --- a/buildscripts/resmokelib/testing/hooks/add_remove_shards.py +++ b/buildscripts/resmokelib/testing/hooks/add_remove_shards.py @@ -23,24 +23,27 @@ from buildscripts.resmokelib.testing.retry import ( class ContinuousAddRemoveShard(interface.Hook): DESCRIPTION = ( - "Continuously adds and removes shards at regular intervals. If running with configsvr " + - "transitions, will transition in/out of config shard mode.") + "Continuously adds and removes shards at regular intervals. If running with configsvr " + + "transitions, will transition in/out of config shard mode." + ) IS_BACKGROUND = True STOPS_FIXTURE = False def __init__( - self, - hook_logger, - fixture, - auth_options=None, - random_balancer_on=True, - transition_configsvr=False, - add_remove_random_shards=False, - move_primary_comment=None, + self, + hook_logger, + fixture, + auth_options=None, + random_balancer_on=True, + transition_configsvr=False, + add_remove_random_shards=False, + move_primary_comment=None, ): - interface.Hook.__init__(self, hook_logger, fixture, ContinuousAddRemoveShard.DESCRIPTION) + interface.Hook.__init__( + self, hook_logger, fixture, ContinuousAddRemoveShard.DESCRIPTION + ) self._fixture = fixture self._add_remove_thread = None self._auth_options = auth_options @@ -122,15 +125,15 @@ class _AddRemoveShardThread(threading.Thread): ] def __init__( - self, - logger, - stepdown_lifecycle, - fixture, - auth_options, - random_balancer_on, - transition_configsvr, - add_remove_random_shards, - move_primary_comment, + self, + logger, + stepdown_lifecycle, + fixture, + auth_options, + random_balancer_on, + transition_configsvr, + add_remove_random_shards, + move_primary_comment, ): threading.Thread.__init__(self, name="AddRemoveShardThread") self.logger = logger @@ -167,9 +170,12 @@ class _AddRemoveShardThread(threading.Thread): # If running with both config transitions and random shard add/removals, pick any shard # including the config shard. Otherwise, pick any shard that is not the config shard. - shard_to_remove_and_add = (self._get_other_shard_info(None) if self._transition_configsvr - and self._current_config_mode is self.CONFIG_SHARD else - self._get_other_shard_info("config")) + shard_to_remove_and_add = ( + self._get_other_shard_info(None) + if self._transition_configsvr + and self._current_config_mode is self.CONFIG_SHARD + else self._get_other_shard_info("config") + ) return shard_to_remove_and_add["_id"], shard_to_remove_and_add["host"] def run(self): @@ -187,8 +193,11 @@ class _AddRemoveShardThread(threading.Thread): shard_id, shard_host = self._pick_shard_to_add_remove() wait_secs = random.choice(self.TRANSITION_INTERVALS) - msg = ("transition to dedicated." - if shard_id == "config" else "removing shard " + shard_id + ".") + msg = ( + "transition to dedicated." + if shard_id == "config" + else "removing shard " + shard_id + "." + ) self.logger.info(f"Waiting {wait_secs} seconds before " + msg) self.__lifecycle.wait_for_action_interval(wait_secs) @@ -213,8 +222,11 @@ class _AddRemoveShardThread(threading.Thread): # Wait a random interval before transitioning back, unless the test already ended. if not self.__lifecycle.poll_for_idle_request(): wait_secs = random.choice(self.TRANSITION_INTERVALS) - msg = ("transition to config shard." - if shard_id == "config" else "adding shard " + shard_id + ".") + msg = ( + "transition to config shard." + if shard_id == "config" + else "adding shard " + shard_id + "." + ) self.logger.info(f"Waiting {wait_secs} seconds before " + msg) self.__lifecycle.wait_for_action_interval(wait_secs) @@ -354,19 +366,35 @@ class _AddRemoveShardThread(threading.Thread): self.logger.error(msg) raise errors.ServerFailure(msg) - direct_shard_conn = pymongo.MongoClient(shard_obj.get_driver_connection_url()) + direct_shard_conn = pymongo.MongoClient( + shard_obj.get_driver_connection_url() + ) # Wait until any DDL, resharding, transactions, and migration ops are cleaned up. # TODO SERVER-90782 Change these to be assertions, rather than waiting for the collections # to be empty - if len(list(direct_shard_conn.config.system.sharding_ddl_coordinators.find())) != 0: + if ( + len( + list( + direct_shard_conn.config.system.sharding_ddl_coordinators.find() + ) + ) + != 0 + ): self.logger.info( "Waiting for config.system.sharding_ddl_coordinators to be empty before decomissioning." ) time.sleep(1) continue - if len(list(direct_shard_conn.config.localReshardingOperations.recipient.find())) != 0: + if ( + len( + list( + direct_shard_conn.config.localReshardingOperations.recipient.find() + ) + ) + != 0 + ): self.logger.info( "Waiting for config.localReshardingOperations.recipient to be empty before decomissioning." ) @@ -383,7 +411,8 @@ class _AddRemoveShardThread(threading.Thread): # TODO SERVER-91474 Wait for ongoing transactions to finish on participants if self._get_number_of_ongoing_transactions(direct_shard_conn) != 0: self.logger.info( - "Waiting for ongoing transactions to commit or abort before decomissioning.") + "Waiting for ongoing transactions to commit or abort before decomissioning." + ) time.sleep(1) continue @@ -391,13 +420,27 @@ class _AddRemoveShardThread(threading.Thread): all_dbs = direct_shard_conn.admin.command({"listDatabases": 1}) for db in all_dbs["databases"]: - if db["name"] not in ["admin", "config", "local"] and db["empty"] is False: + if ( + db["name"] not in ["admin", "config", "local"] + and db["empty"] is False + ): db_name = db["name"] all_collections = direct_shard_conn.get_database(db_name).command( - {"listCollections": 1}) + {"listCollections": 1} + ) for coll in all_collections: - if len(list(direct_shard_conn.get_database(db_name).coll.find())) != 0: - msg = "Found non-empty collection after removing shard: " + coll + if ( + len( + list( + direct_shard_conn.get_database(db_name).coll.find() + ) + ) + != 0 + ): + msg = ( + "Found non-empty collection after removing shard: " + + coll + ) self.logger.error(msg) raise errors.ServerFailure(msg) @@ -413,36 +456,42 @@ class _AddRemoveShardThread(threading.Thread): def _get_tracked_collections_on_shard(self, shard_id): return list( - self._client.config.collections.aggregate([ - { - "$lookup": { - "from": - "chunks", - "localField": - "uuid", - "foreignField": - "uuid", - "as": - "chunksOnRemovedShard", - "pipeline": [ - {"$match": {"shard": shard_id}}, - # History can be very large because we randomize migrations, so - # exclude it to reduce log spam. - {"$project": {"history": 0}}, - ], - } - }, - {"$match": {"chunksOnRemovedShard": {"$ne": []}}}, - ])) + self._client.config.collections.aggregate( + [ + { + "$lookup": { + "from": "chunks", + "localField": "uuid", + "foreignField": "uuid", + "as": "chunksOnRemovedShard", + "pipeline": [ + {"$match": {"shard": shard_id}}, + # History can be very large because we randomize migrations, so + # exclude it to reduce log spam. + {"$project": {"history": 0}}, + ], + } + }, + {"$match": {"chunksOnRemovedShard": {"$ne": []}}}, + ] + ) + ) def _get_untracked_collections_on_shard(self, source): untracked_collections = [] databases = list( - self._client.config.databases.aggregate([{ - "$match": {"primary": source}, - }])) + self._client.config.databases.aggregate( + [ + { + "$match": {"primary": source}, + } + ] + ) + ) for database in databases: - for collection in self._client.get_database(database["_id"]).list_collections(): + for collection in self._client.get_database( + database["_id"] + ).list_collections(): namespace = database["_id"] + "." + collection["name"] coll_doc = self._client.config.collections.find_one({"_id": namespace}) if not coll_doc: @@ -454,25 +503,36 @@ class _AddRemoveShardThread(threading.Thread): for collection in collections: namespace = collection["_id"] destination = self._get_other_shard_id(source) - self.logger.info("Running moveCollection for " + namespace + " to " + destination) + self.logger.info( + "Running moveCollection for " + namespace + " to " + destination + ) try: - self._client.admin.command({"moveCollection": namespace, "toShard": destination}) + self._client.admin.command( + {"moveCollection": namespace, "toShard": destination} + ) except pymongo.errors.OperationFailure as err: if not self._is_expected_move_collection_error(err, namespace): raise err - self.logger.info("Ignoring error when moving the collection '" + namespace + "': " + - str(err)) + self.logger.info( + "Ignoring error when moving the collection '" + + namespace + + "': " + + str(err) + ) if err.code == self._RESHARD_COLLECTION_IN_PROGRESS: self.logger.info( - "Skip moving the other collections since there is already a resharding " + - "operation in progress") + "Skip moving the other collections since there is already a resharding " + + "operation in progress" + ) return def _move_all_primaries_from_shard(self, databases, source): for database in databases: destination = self._get_other_shard_id(source) try: - self.logger.info("Running movePrimary for " + database + " to " + destination) + self.logger.info( + "Running movePrimary for " + database + " to " + destination + ) cmd_obj = {"movePrimary": database, "to": destination} if self._move_primary_comment: cmd_obj["comment"] = self._move_primary_comment @@ -480,10 +540,16 @@ class _AddRemoveShardThread(threading.Thread): except pymongo.errors.OperationFailure as err: if not self._is_expected_move_primary_error_code(err.code): raise err - self.logger.info("Ignoring error when moving the database '" + database + "': " + - str(err)) + self.logger.info( + "Ignoring error when moving the database '" + + database + + "': " + + str(err) + ) - def _drain_shard_for_ongoing_transition(self, num_rounds, transition_result, source): + def _drain_shard_for_ongoing_transition( + self, num_rounds, transition_result, source + ): tracked_colls = self._get_tracked_collections_on_shard(source) sharded_colls = [] tracked_unsharded_colls = [] @@ -495,19 +561,48 @@ class _AddRemoveShardThread(threading.Thread): untracked_unsharded_colls = self._get_untracked_collections_on_shard(source) if num_rounds % 10 == 0: - self.logger.info("Draining shard " + source + ": " + str({"num_rounds": num_rounds})) - self.logger.info("Sharded collections on " + source + ": " + - str({"count": len(sharded_colls), "collections": sharded_colls})) - self.logger.info("Tracked unsharded collections on " + source + ": " + str( - {"count": len(tracked_unsharded_colls), "collections": tracked_unsharded_colls})) - self.logger.info("Untracked unsharded collections on " + source + ": " + str({ - "count": len(untracked_unsharded_colls), - "collections": untracked_unsharded_colls, - })) - self.logger.info("Databases on " + source + ": " + str({ - "count": len(transition_result["dbsToMove"]), - "collections": transition_result["dbsToMove"], - })) + self.logger.info( + "Draining shard " + source + ": " + str({"num_rounds": num_rounds}) + ) + self.logger.info( + "Sharded collections on " + + source + + ": " + + str({"count": len(sharded_colls), "collections": sharded_colls}) + ) + self.logger.info( + "Tracked unsharded collections on " + + source + + ": " + + str( + { + "count": len(tracked_unsharded_colls), + "collections": tracked_unsharded_colls, + } + ) + ) + self.logger.info( + "Untracked unsharded collections on " + + source + + ": " + + str( + { + "count": len(untracked_unsharded_colls), + "collections": untracked_unsharded_colls, + } + ) + ) + self.logger.info( + "Databases on " + + source + + ": " + + str( + { + "count": len(transition_result["dbsToMove"]), + "collections": transition_result["dbsToMove"], + } + ) + ) # If random balancing is on, the balancer will also move unsharded collections (both tracked # and untracked). However, random balancing is a test-only setting. In production, users are @@ -516,7 +611,8 @@ class _AddRemoveShardThread(threading.Thread): should_move = not self._random_balancer_on or random.random() < 0.5 if should_move: self._move_all_collections_from_shard( - tracked_unsharded_colls + untracked_unsharded_colls, source) + tracked_unsharded_colls + untracked_unsharded_colls, source + ) self._move_all_primaries_from_shard(transition_result["dbsToMove"], source) def _get_balancer_status_on_shard_not_found(self, prev_round_interrupted, msg): @@ -524,9 +620,11 @@ class _AddRemoveShardThread(threading.Thread): latest_status = self._client.admin.command({"balancerStatus": 1}) except pymongo.errors.OperationFailure as balancerStatusErr: if balancerStatusErr.code in set(retryable_network_errs): - self.logger.info("Network error when running balancerStatus after " - "receiving ShardNotFound error on " + msg + ", will " - "retry. err: " + str(balancerStatusErr)) + self.logger.info( + "Network error when running balancerStatus after " + "receiving ShardNotFound error on " + msg + ", will " + "retry. err: " + str(balancerStatusErr) + ) prev_round_interrupted = False return None, prev_round_interrupted @@ -534,8 +632,10 @@ class _AddRemoveShardThread(threading.Thread): raise balancerStatusErr prev_round_interrupted = True - self.logger.info("Ignoring 'Interrupted' error when running balancerStatus " - "after receiving ShardNotFound error on " + msg) + self.logger.info( + "Ignoring 'Interrupted' error when running balancerStatus " + "after receiving ShardNotFound error on " + msg + ) return None, prev_round_interrupted return latest_status, prev_round_interrupted @@ -557,7 +657,9 @@ class _AddRemoveShardThread(threading.Thread): while True: try: if last_balancer_status is None: - last_balancer_status = self._client.admin.command({"balancerStatus": 1}) + last_balancer_status = self._client.admin.command( + {"balancerStatus": 1} + ) if self._should_wait_for_balancer_round: # TODO SERVER-90291: Remove. @@ -578,13 +680,17 @@ class _AddRemoveShardThread(threading.Thread): time.sleep(1) continue - if last_balancer_status["numBalancerRounds"] >= latest_status[ - "numBalancerRounds"]: + if ( + last_balancer_status["numBalancerRounds"] + >= latest_status["numBalancerRounds"] + ): self.logger.info( - "Waiting for a balancer round before " + msg + - ". Last round: %d, latest round: %d", + "Waiting for a balancer round before " + + msg + + ". Last round: %d, latest round: %d", last_balancer_status["numBalancerRounds"], - latest_status["numBalancerRounds"]) + latest_status["numBalancerRounds"], + ) time.sleep(1) continue @@ -592,20 +698,29 @@ class _AddRemoveShardThread(threading.Thread): self._should_wait_for_balancer_round = False if shard_id == "config": - res = self._client.admin.command({"transitionToDedicatedConfigServer": 1}) + res = self._client.admin.command( + {"transitionToDedicatedConfigServer": 1} + ) else: res = self._client.admin.command({"removeShard": shard_id}) if res["state"] == "completed": - self.logger.info("Completed " + msg + " in %0d ms", - (time.time() - start_time) * 1000) + self.logger.info( + "Completed " + msg + " in %0d ms", + (time.time() - start_time) * 1000, + ) return True elif res["state"] == "started": - if self._client.config.chunks.count_documents({"shard": shard_id}) == 0: + if ( + self._client.config.chunks.count_documents({"shard": shard_id}) + == 0 + ): self._should_wait_for_balancer_round = True elif res["state"] == "ongoing": num_draining_rounds += 1 - self._drain_shard_for_ongoing_transition(num_draining_rounds, res, shard_id) + self._drain_shard_for_ongoing_transition( + num_draining_rounds, res, shard_id + ) prev_round_interrupted = False time.sleep(1) @@ -616,16 +731,18 @@ class _AddRemoveShardThread(threading.Thread): raise errors.ServerFailure(msg) except pymongo.errors.OperationFailure as err: # Some workloads add and remove shards so removing the config shard may fail transiently. - if err.code in [self._ILLEGAL_OPERATION - ] and "would remove the last shard" in str(err): + if err.code in [ + self._ILLEGAL_OPERATION + ] and "would remove the last shard" in str(err): # Abort the transition attempt and make the hook try again later. return False # Some suites run with forced failovers, if transitioning fails with a retryable # network error, we should retry. if err.code in set(retryable_network_errs): - self.logger.info("Network error during " + msg + ", will retry. err: " + - str(err)) + self.logger.info( + "Network error during " + msg + ", will retry. err: " + str(err) + ) time.sleep(1) prev_round_interrupted = False continue @@ -633,8 +750,12 @@ class _AddRemoveShardThread(threading.Thread): # Some suites kill the primary causing the request to fail with # FailedToSatisfyReadPreference if err.code in [self._FAILED_TO_SATISFY_READ_PREFERENCE]: - self.logger.info("Primary not found when " + msg + ", will retry. err: " + - str(err)) + self.logger.info( + "Primary not found when " + + msg + + ", will retry. err: " + + str(err) + ) time.sleep(1) continue @@ -644,7 +765,10 @@ class _AddRemoveShardThread(threading.Thread): # the transition to dedicated is retried, it will fail because the shard will no longer exist. if err.code in [self._SHARD_NOT_FOUND]: latest_status, prev_round_interrupted = ( - self._get_balancer_status_on_shard_not_found(prev_round_interrupted, msg)) + self._get_balancer_status_on_shard_not_found( + prev_round_interrupted, msg + ) + ) if latest_status is None: # The balancerStatus request was interrupted, so we retry the transition # request. We will fail with ShardNotFound again, and will retry this check @@ -655,13 +779,18 @@ class _AddRemoveShardThread(threading.Thread): if last_balancer_status is None: last_balancer_status = latest_status - if (last_balancer_status["term"] != latest_status["term"] - or prev_round_interrupted): + if ( + last_balancer_status["term"] != latest_status["term"] + or prev_round_interrupted + ): self.logger.info( - "Did not find entry for " + shard_id + - " in config.shards after detecting a " + "Did not find entry for " + + shard_id + + " in config.shards after detecting a " "change in repl set term or after transition was interrutped. Assuming " - + msg + " finished on previous transition request.") + + msg + + " finished on previous transition request." + ) return True if not self._is_expected_transition_error_code(err.code): @@ -681,21 +810,32 @@ class _AddRemoveShardThread(threading.Thread): while True: try: if shard_id == "config": - self._client.admin.command({"transitionFromDedicatedConfigServer": 1}) + self._client.admin.command( + {"transitionFromDedicatedConfigServer": 1} + ) else: - original_shard_id = (shard_id if self._shard_name_suffix == 0 else - shard_id.split("_")[0]) + original_shard_id = ( + shard_id + if self._shard_name_suffix == 0 + else shard_id.split("_")[0] + ) shard_name = original_shard_id + "_" + str(self._shard_name_suffix) self.logger.info("Adding shard with new shardId: " + shard_name) - self._client.admin.command({"addShard": shard_host, "name": shard_name}) + self._client.admin.command( + {"addShard": shard_host, "name": shard_name} + ) self._shard_name_suffix = self._shard_name_suffix + 1 return except pymongo.errors.OperationFailure as err: # Some suites run with forced failovers, if transitioning fails with a retryable # network error, we should retry. if err.code in set(retryable_network_errs): - self.logger.info("Network error when " + msg + " server, will retry. err: " + - str(err)) + self.logger.info( + "Network error when " + + msg + + " server, will retry. err: " + + str(err) + ) time.sleep(1) continue @@ -703,18 +843,31 @@ class _AddRemoveShardThread(threading.Thread): # transition/addShard, addShard will fail because it will not be able to connect. The # error code returned is not retryable (it is OperationFailed), so we check the specific # error message as well. - if err.code in [self._OPERATION_FAILED] and "Connection refused" in str(err) or any( - err_name in str(err) for err_name in retryable_network_err_names): - self.logger.info("Network error adding shard when " + msg + - ", will retry. err: " + str(err)) + if ( + err.code in [self._OPERATION_FAILED] + and "Connection refused" in str(err) + or any( + err_name in str(err) for err_name in retryable_network_err_names + ) + ): + self.logger.info( + "Network error adding shard when " + + msg + + ", will retry. err: " + + str(err) + ) time.sleep(1) continue # Some suites kill the primary causing the request to fail with # FailedToSatisfyReadPreference if err.code in [self._FAILED_TO_SATISFY_READ_PREFERENCE]: - self.logger.info("Primary not found when " + msg + ", will retry. err: " + - str(err)) + self.logger.info( + "Primary not found when " + + msg + + ", will retry. err: " + + str(err) + ) time.sleep(1) continue @@ -734,7 +887,9 @@ class _AddRemoveShardThread(threading.Thread): possible_choices = [] if shard_id is not None: possible_choices = [ - shard_info for shard_info in res["shards"] if shard_info["_id"] != shard_id + shard_info + for shard_info in res["shards"] + if shard_info["_id"] != shard_id ] else: possible_choices = [shard_info for shard_info in res["shards"]] @@ -746,15 +901,20 @@ class _AddRemoveShardThread(threading.Thread): def _get_number_of_ongoing_transactions(self, shard_conn): res = list( - shard_conn.admin.aggregate([ - {"$currentOp": { - "allUsers": True, - "idleConnections": True, - "idleSessions": True, - }}, - {"$match": {"transaction": {"$exists": True}}}, - {"$count": "num_ongoing_txns"}, - ])) + shard_conn.admin.aggregate( + [ + { + "$currentOp": { + "allUsers": True, + "idleConnections": True, + "idleSessions": True, + } + }, + {"$match": {"transaction": {"$exists": True}}}, + {"$count": "num_ongoing_txns"}, + ] + ) + ) return res[0]["num_ongoing_txns"] if res else 0 def _run_post_remove_shard_checks(self, removed_shard_fixture, removed_shard_name): @@ -762,25 +922,30 @@ class _AddRemoveShardThread(threading.Thread): try: # Configsvr metadata checks: ## Check that the removed shard no longer exists on config.shards. - assert (self._client["config"]["shards"].count_documents({ - "_id": removed_shard_name - }) == 0), f"Removed shard still exists on config.shards: {removed_shard_name}" + assert ( + self._client["config"]["shards"].count_documents( + {"_id": removed_shard_name} + ) + == 0 + ), f"Removed shard still exists on config.shards: {removed_shard_name}" ## Check that no database has the removed shard as its primary shard. databasesPointingToRemovedShard = [ - doc for doc in self._client["config"]["databases"].find( - {"primary": removed_shard_name}) + doc + for doc in self._client["config"]["databases"].find( + {"primary": removed_shard_name} + ) ] assert not databasesPointingToRemovedShard, f"Found databases whose primary shard is a removed shard: {databasesPointingToRemovedShard}" ## Check that no chunk has the removed shard as its owner. chunksPointingToRemovedShard = [ doc - for doc in self._client["config"]["chunks"].find({"shard": removed_shard_name}) + for doc in self._client["config"]["chunks"].find( + {"shard": removed_shard_name} + ) ] - assert ( - not chunksPointingToRemovedShard - ), f"Found chunks whose owner is a removed shard: {chunksPointingToRemovedShard}" + assert not chunksPointingToRemovedShard, f"Found chunks whose owner is a removed shard: {chunksPointingToRemovedShard}" ## Check that all tag in config.tags refer to at least one existing shard. tagsWithoutShardPipeline = [ @@ -794,8 +959,9 @@ class _AddRemoveShardThread(threading.Thread): }, {"$match": {"shards": []}}, ] - tagsWithoutShardPipelineResultCursor = self._client["config"]["tags"].aggregate( - tagsWithoutShardPipeline) + tagsWithoutShardPipelineResultCursor = self._client["config"][ + "tags" + ].aggregate(tagsWithoutShardPipeline) tagsWithoutShardPipelineResult = [ doc for doc in tagsWithoutShardPipelineResultCursor ] @@ -806,26 +972,37 @@ class _AddRemoveShardThread(threading.Thread): # Check that there is no user data left on the removed shard. (Note: This can only be # checked on transitionToDedicatedConfigServer) - removed_shard_primary_client = removed_shard_fixture.get_primary().mongo_client() + removed_shard_primary_client = ( + removed_shard_fixture.get_primary().mongo_client() + ) dbs = removed_shard_primary_client.list_database_names() - assert all(databaseName in {"local", "admin", "config"} for databaseName in - dbs), f"Expected to not have any user database on removed shard: {dbs}" + assert all( + databaseName in {"local", "admin", "config"} for databaseName in dbs + ), f"Expected to not have any user database on removed shard: {dbs}" # Check the filtering metadata on removed shard. Expect that the shard knows that it does # not own any chunk anymore. Check on all replica set nodes. # First, await secondaries to replicate the last optime removed_shard_fixture.await_last_op_committed( - removed_shard_fixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60) - for removed_shard_node in [removed_shard_fixture.get_primary() - ] + removed_shard_fixture.get_secondaries(): - sharding_state_response = removed_shard_node.mongo_client().admin.command( - {"shardingState": 1}) + removed_shard_fixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 + ) + for removed_shard_node in [ + removed_shard_fixture.get_primary() + ] + removed_shard_fixture.get_secondaries(): + sharding_state_response = ( + removed_shard_node.mongo_client().admin.command( + {"shardingState": 1} + ) + ) for nss, metadata in sharding_state_response["versions"].items(): # placementVersion == Timestamp(0, 0) means that this shard owns no chunk for the # collection. # TODO (SERVER-90810): Re-enable this check for resharding temporary collections. - if "system.resharding" in nss or "system.buckets.resharding" in nss: + if ( + "system.resharding" in nss + or "system.buckets.resharding" in nss + ): continue assert ( @@ -833,25 +1010,31 @@ class _AddRemoveShardThread(threading.Thread): ), f"Expected remove shard's filtering information to reflect that the shard does not own any chunk for collection {nss}, but found {metadata} on node {removed_shard_node.get_driver_connection_url()}" return - except (pymongo.errors.AutoReconnect, pymongo.errors.NotPrimaryError) as err: + except ( + pymongo.errors.AutoReconnect, + pymongo.errors.NotPrimaryError, + ) as err: # The above operations run directly on a shard, so they may fail getting a # connection if the shard node is killed. self.logger.info( - "Connection error when running post removal checks, will retry. err: " + - str(err)) + "Connection error when running post removal checks, will retry. err: " + + str(err) + ) continue except pymongo.errors.OperationFailure as err: # Retry on retryable errors that might be thrown in suites with forced failovers. if err.code in set(retryable_network_errs): self.logger.info( - "Retryable error when running post removal checks, will retry. err: " + - str(err)) + "Retryable error when running post removal checks, will retry. err: " + + str(err) + ) continue if err.code in set([self._INTERRUPTED]): # Some workloads kill sessions which may interrupt the transition. self.logger.info( "Received 'Interrupted' error when running post removal checks, will retry. err: " - + str(err)) + + str(err) + ) continue raise err diff --git a/buildscripts/resmokelib/testing/hooks/aggregate_metrics_background.py b/buildscripts/resmokelib/testing/hooks/aggregate_metrics_background.py index 4096b110073..4e47a1af1b6 100644 --- a/buildscripts/resmokelib/testing/hooks/aggregate_metrics_background.py +++ b/buildscripts/resmokelib/testing/hooks/aggregate_metrics_background.py @@ -17,9 +17,12 @@ class AggregateResourceConsumptionMetricsInBackground(BGHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize AggregateResourceConsumptionMetricsInBackground.""" - description = "Run background $operationMetrics on all mongods while a test is running" - super().__init__(hook_logger, fixture, description, tests_per_cycle=None, - loop_delay_ms=1000) + description = ( + "Run background $operationMetrics on all mongods while a test is running" + ) + super().__init__( + hook_logger, fixture, description, tests_per_cycle=None, loop_delay_ms=1000 + ) def run_action(self): """Collects $operationMetrics on all non-arbiter nodes in the fixture.""" @@ -28,46 +31,67 @@ class AggregateResourceConsumptionMetricsInBackground(BGHook): # Filter out arbiters. if "arbiterOnly" in conn.admin.command({"isMaster": 1}): self.logger.info( - "Skipping background aggregation against test node: %s because it is an " + - "arbiter and has no data.", node_info.full_name) + "Skipping background aggregation against test node: %s because it is an " + + "arbiter and has no data.", + node_info.full_name, + ) return # Clear the metrics about 10% of the time. clear_metrics = random.random() < 0.1 - self.logger.info("Running $operationMetrics with {clearMetrics: %s} on host: %s", - clear_metrics, node_info.full_name) + self.logger.info( + "Running $operationMetrics with {clearMetrics: %s} on host: %s", + clear_metrics, + node_info.full_name, + ) with conn.admin.aggregate( - [{"$operationMetrics": {"clearMetrics": clear_metrics}}]) as cursor: + [{"$operationMetrics": {"clearMetrics": clear_metrics}}] + ) as cursor: for doc in cursor: try: self.verify_metrics(doc) except: self.logger.info( - "caught exception while verifying that all expected fields are in the" + - " metrics output: ", doc) + "caught exception while verifying that all expected fields are in the" + + " metrics output: ", + doc, + ) raise def verify_metrics(self, doc): """Checks whether the output from $operatiomMetrics has the schema we expect.""" top_level_fields = [ - "docBytesWritten", "docUnitsWritten", "idxEntryBytesWritten", "idxEntryUnitsWritten", - "totalUnitsWritten", "cpuNanos", "db", "primaryMetrics", "secondaryMetrics" + "docBytesWritten", + "docUnitsWritten", + "idxEntryBytesWritten", + "idxEntryUnitsWritten", + "totalUnitsWritten", + "cpuNanos", + "db", + "primaryMetrics", + "secondaryMetrics", ] read_fields = [ - "docBytesRead", "docUnitsRead", "idxEntryBytesRead", "idxEntryUnitsRead", "keysSorted", - "docUnitsReturned" + "docBytesRead", + "docUnitsRead", + "idxEntryBytesRead", + "idxEntryUnitsRead", + "keysSorted", + "docUnitsReturned", ] for key in top_level_fields: - assert key in doc, ("The metrics output is missing the property: " + key) + assert key in doc, "The metrics output is missing the property: " + key primary_metrics = doc["primaryMetrics"] for key in read_fields: assert key in primary_metrics, ( - "The metrics output is missing the property: primaryMetrics." + key) + "The metrics output is missing the property: primaryMetrics." + key + ) secondary_metrics = doc["secondaryMetrics"] for key in read_fields: assert key in secondary_metrics, ( - "The metrics output is missing the property: secondaryMetrics." + key) + "The metrics output is missing the property: secondaryMetrics." + key + ) diff --git a/buildscripts/resmokelib/testing/hooks/analyze_shard_key.py b/buildscripts/resmokelib/testing/hooks/analyze_shard_key.py index 2c08dc38167..fa18c1f0534 100644 --- a/buildscripts/resmokelib/testing/hooks/analyze_shard_key.py +++ b/buildscripts/resmokelib/testing/hooks/analyze_shard_key.py @@ -21,9 +21,17 @@ class AnalyzeShardKeysInBackground(jsfile.JSHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize AnalyzeShardKeysInBackground.""" description = "Runs running analyzeShardKey commands while a test is running" - js_filename = os.path.join("jstests", "hooks", "run_analyze_shard_key_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_analyze_shard_key_background.js" + ) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -47,7 +55,8 @@ class AnalyzeShardKeysInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background thread for analyzing shard keys.") @@ -72,5 +81,6 @@ class AnalyzeShardKeysInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background thread for analyzing shard keys.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/antithesis_logging.py b/buildscripts/resmokelib/testing/hooks/antithesis_logging.py index f570e8115ca..af51d31e007 100644 --- a/buildscripts/resmokelib/testing/hooks/antithesis_logging.py +++ b/buildscripts/resmokelib/testing/hooks/antithesis_logging.py @@ -14,7 +14,9 @@ class AntithesisLogging(interface.Hook): def __init__(self, hook_logger, fixture): """Initialize the AntithesisLogging hook.""" - interface.Hook.__init__(self, hook_logger, fixture, AntithesisLogging.DESCRIPTION) + interface.Hook.__init__( + self, hook_logger, fixture, AntithesisLogging.DESCRIPTION + ) def before_test(self, test, test_report): """Ensure the fault injector is running before a test.""" diff --git a/buildscripts/resmokelib/testing/hooks/bghook.py b/buildscripts/resmokelib/testing/hooks/bghook.py index 1b3fed3b104..67a2c2116eb 100644 --- a/buildscripts/resmokelib/testing/hooks/bghook.py +++ b/buildscripts/resmokelib/testing/hooks/bghook.py @@ -35,7 +35,9 @@ class BGJob(threading.Thread): # The configured loop delay asked us to wait before running the action again. Do # that wait, but listen to see if we finish running the test or are killed in # the meantime. - interrupted = self._interrupt_event.wait(self._loop_delay_ms / 1000.0) + interrupted = self._interrupt_event.wait( + self._loop_delay_ms / 1000.0 + ) if interrupted: self._hook.logger.info("interrupted") break @@ -57,7 +59,9 @@ class BGHook(interface.Hook): # By default, we continuously run the background hook for the duration of the suite. DEFAULT_TESTS_PER_CYCLE = math.inf - def __init__(self, hook_logger, fixture, desc, tests_per_cycle=None, loop_delay_ms=None): + def __init__( + self, hook_logger, fixture, desc, tests_per_cycle=None, loop_delay_ms=None + ): """ Initialize the background hook. @@ -72,7 +76,9 @@ class BGHook(interface.Hook): self._test_num = 0 # The number of tests we execute before restarting the background hook. - self._tests_per_cycle = self.DEFAULT_TESTS_PER_CYCLE if tests_per_cycle is None else tests_per_cycle + self._tests_per_cycle = ( + self.DEFAULT_TESTS_PER_CYCLE if tests_per_cycle is None else tests_per_cycle + ) self._loop_delay_ms = loop_delay_ms def run_action(self): @@ -98,7 +104,9 @@ class BGHook(interface.Hook): self._background_job.join() if self._background_job.err is not None: - self.logger.error("Encountered an error inside the hook: %s.", self._background_job.err) + self.logger.error( + "Encountered an error inside the hook: %s.", self._background_job.err + ) raise self._background_job.err def before_test(self, test, test_report): @@ -114,14 +122,21 @@ class BGHook(interface.Hook): def after_test(self, test, test_report): """Each test will call this after it executes. Check if the hook found an error.""" self._test_num += 1 - if self._test_num % self._tests_per_cycle != 0 and self._background_job.err is None: + if ( + self._test_num % self._tests_per_cycle != 0 + and self._background_job.err is None + ): return self._background_job.kill() self._background_job.join() if self._background_job.err is not None: - self.logger.error("Encountered an error inside the hook: %s.", self._background_job.err) + self.logger.error( + "Encountered an error inside the hook: %s.", self._background_job.err + ) raise self._background_job.err else: - self.logger.info("Reached end of cycle in the hook, killing background thread.") + self.logger.info( + "Reached end of cycle in the hook, killing background thread." + ) diff --git a/buildscripts/resmokelib/testing/hooks/change_collection_consistency.py b/buildscripts/resmokelib/testing/hooks/change_collection_consistency.py index ce4c7831c4b..c8146179531 100644 --- a/buildscripts/resmokelib/testing/hooks/change_collection_consistency.py +++ b/buildscripts/resmokelib/testing/hooks/change_collection_consistency.py @@ -11,9 +11,18 @@ class CheckReplChangeCollectionConsistency(jsfile.PerClusterDataConsistencyHook) IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize CheckReplChangeCollectionConsistency.""" description = "Check change_collection(s) of all replica set members" - js_filename = os.path.join("jstests", "hooks", "run_check_repl_change_collection.js") + js_filename = os.path.join( + "jstests", "hooks", "run_check_repl_change_collection.js" + ) jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/change_streams.py b/buildscripts/resmokelib/testing/hooks/change_streams.py index ef662dbe8f3..4c2224e6e8c 100644 --- a/buildscripts/resmokelib/testing/hooks/change_streams.py +++ b/buildscripts/resmokelib/testing/hooks/change_streams.py @@ -1,4 +1,5 @@ """Test hook to run change streams in the background.""" + import random import threading import time @@ -29,8 +30,11 @@ class RunChangeStreamsInBackground(interface.Hook): def before_suite(self, test_report): """Print the log message.""" - self.logger.info("Opening and closing change streams every %d tests. The seed is %d.", - self._every_n_tests, config.RANDOM_SEED) + self.logger.info( + "Opening and closing change streams every %d tests. The seed is %d.", + self._every_n_tests, + config.RANDOM_SEED, + ) def after_suite(self, test_report, teardown_flag=None): """Stop the background thread.""" @@ -41,7 +45,9 @@ class RunChangeStreamsInBackground(interface.Hook): """Start the background thread if it is not already started.""" if self._change_streams_thread is None: mongo_client = self._fixture.mongo_client() - self._change_streams_thread = _ChangeStreamsThread(self.logger, mongo_client) + self._change_streams_thread = _ChangeStreamsThread( + self.logger, mongo_client + ) self.logger.info("Starting the background change streams thread.") self._change_streams_thread.start() self._test_run = 0 @@ -79,12 +85,13 @@ class _ChangeStreamsThread(threading.Thread): try: change = stream.try_next() except Exception as err: # pylint: disable=broad-except - self.logger.error("Failed to get the next change from the change stream: %s", - err) + self.logger.error( + "Failed to get the next change from the change stream: %s", err + ) else: if change is None: # Since there are tests that are running under 1s, we are sleeping here for just 10ms - time.sleep(.01) + time.sleep(0.01) else: self.logger.info("Change document: %r", change) self._changes_num += 1 diff --git a/buildscripts/resmokelib/testing/hooks/cleanup.py b/buildscripts/resmokelib/testing/hooks/cleanup.py index 2b0f3cc4dfc..ff1beef5d69 100644 --- a/buildscripts/resmokelib/testing/hooks/cleanup.py +++ b/buildscripts/resmokelib/testing/hooks/cleanup.py @@ -24,7 +24,9 @@ class CleanEveryN(interface.Hook): if "detect_leaks=1" in os.getenv("ASAN_OPTIONS", ""): self.logger.info( "ASAN_OPTIONS environment variable set to detect leaks, so restarting" - " the fixture after each test instead of after every %d.", n) + " the fixture after each test instead of after every %d.", + n, + ) n = 1 self.n = n # pylint: disable=invalid-name @@ -47,8 +49,10 @@ class CleanEveryNTestCase(interface.DynamicTestCase): def run_test(self): """Execute test hook.""" try: - self.logger.info("%d tests have been run against the fixture, stopping it...", - self._hook.tests_run) + self.logger.info( + "%d tests have been run against the fixture, stopping it...", + self._hook.tests_run, + ) self._hook.tests_run = 0 self.fixture.teardown() diff --git a/buildscripts/resmokelib/testing/hooks/cleanup_concurrency_workloads.py b/buildscripts/resmokelib/testing/hooks/cleanup_concurrency_workloads.py index f7cf2cd368b..2c96ffa2a85 100644 --- a/buildscripts/resmokelib/testing/hooks/cleanup_concurrency_workloads.py +++ b/buildscripts/resmokelib/testing/hooks/cleanup_concurrency_workloads.py @@ -20,14 +20,22 @@ class CleanupConcurrencyWorkloads(interface.Hook): IS_BACKGROUND = False - def __init__(self, hook_logger, fixture, exclude_dbs=None, same_collection=False, - same_db=False): + def __init__( + self, + hook_logger, + fixture, + exclude_dbs=None, + same_collection=False, + same_db=False, + ): """Initialize CleanupConcurrencyWorkloads.""" description = "CleanupConcurrencyWorkloads drops all databases in the fixture" interface.Hook.__init__(self, hook_logger, fixture, description) protected_dbs = ["admin", "config", "local", "$external"] - self.exclude_dbs = list(set().union(protected_dbs, utils.default_if_none(exclude_dbs, []))) + self.exclude_dbs = list( + set().union(protected_dbs, utils.default_if_none(exclude_dbs, [])) + ) self.same_collection_name = None self.same_db_name = None if same_db or same_collection: @@ -40,7 +48,8 @@ class CleanupConcurrencyWorkloads(interface.Hook): def after_test(self, test, test_report): """After test cleanup.""" hook_test_case = CleanupConcurrencyWorkloadsTestCase.create_after_test( - test.logger, test, self) + test.logger, test, self + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) @@ -73,7 +82,9 @@ class CleanupConcurrencyWorkloadsTestCase(interface.DynamicTestCase): try: with_naive_retry(lambda: client.drop_database(db_name)) except: - self.logger.exception("Encountered an error while dropping database %s.", db_name) + self.logger.exception( + "Encountered an error while dropping database %s.", db_name + ) raise if self._hook.same_collection_name and same_db_name: @@ -83,11 +94,16 @@ class CleanupConcurrencyWorkloadsTestCase(interface.DynamicTestCase): self._hook.same_collection_name, ) colls = with_naive_retry(client[same_db_name].list_collection_names) - for coll in [coll for coll in colls if coll != self._hook.same_collection_name]: + for coll in [ + coll for coll in colls if coll != self._hook.same_collection_name + ]: self.logger.info("Dropping db %s collection %s", same_db_name, coll) try: with_naive_retry(lambda: client[same_db_name].drop_collection(coll)) except: - self.logger.exception("Encountered an error while dropping db % collection %s.", - same_db_name, coll) + self.logger.exception( + "Encountered an error while dropping db % collection %s.", + same_db_name, + coll, + ) raise diff --git a/buildscripts/resmokelib/testing/hooks/cluster_index_consistency.py b/buildscripts/resmokelib/testing/hooks/cluster_index_consistency.py index 075ca5b78a1..7287881725c 100644 --- a/buildscripts/resmokelib/testing/hooks/cluster_index_consistency.py +++ b/buildscripts/resmokelib/testing/hooks/cluster_index_consistency.py @@ -15,10 +15,15 @@ class CheckClusterIndexConsistency(jsfile.DataConsistencyHook): """Initialize CheckClusterIndexConsistency.""" if not isinstance(fixture, shardedcluster.ShardedClusterFixture): - raise ValueError(f"'fixture' must be an instance of ShardedClusterFixture, but got" - f" {fixture.__class__.__name__}") + raise ValueError( + f"'fixture' must be an instance of ShardedClusterFixture, but got" + f" {fixture.__class__.__name__}" + ) description = "Check index consistency across cluster" - js_filename = os.path.join("jstests", "hooks", "run_cluster_index_consistency.js") - super().__init__(hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_cluster_index_consistency.js" + ) + super().__init__( + hook_logger, fixture, js_filename, description, shell_options=shell_options + ) diff --git a/buildscripts/resmokelib/testing/hooks/cluster_parameter.py b/buildscripts/resmokelib/testing/hooks/cluster_parameter.py index 112c041a918..2d05b5acb84 100644 --- a/buildscripts/resmokelib/testing/hooks/cluster_parameter.py +++ b/buildscripts/resmokelib/testing/hooks/cluster_parameter.py @@ -7,7 +7,6 @@ from buildscripts.resmokelib.testing.hooks import interface class ClusterParameter(interface.Hook): - IS_BACKGROUND = False def __init__(self, hook_logger, rs_fixture, key=None, value=None): @@ -22,7 +21,8 @@ class ClusterParameter(interface.Hook): """Calls setClusterParameter to set the specified parameter on the fixture before running the suite.""" client = self._fixture.get_primary().mongo_client() self._original_value = client.get_database("admin").command( - {"getClusterParameter": self._key})["clusterParameters"][0] + {"getClusterParameter": self._key} + )["clusterParameters"][0] # There are extra parameters in the response that aren't part of the original value so # they must be removed. del self._original_value["_id"] @@ -42,4 +42,5 @@ class ClusterParameter(interface.Hook): } client.get_database("admin").command(command_request) self.logger.info( - f"Successfully called setClusterParameter to restor original value of {self._key}") + f"Successfully called setClusterParameter to restor original value of {self._key}" + ) diff --git a/buildscripts/resmokelib/testing/hooks/continuous_initial_sync.py b/buildscripts/resmokelib/testing/hooks/continuous_initial_sync.py index 20c0431d5be..9aaea093586 100644 --- a/buildscripts/resmokelib/testing/hooks/continuous_initial_sync.py +++ b/buildscripts/resmokelib/testing/hooks/continuous_initial_sync.py @@ -20,14 +20,20 @@ from buildscripts.resmokelib.testing.hooks import lifecycle as lifecycle_interfa class ContinuousInitialSync(interface.Hook): """Periodically initial sync nodes then step them up.""" - DESCRIPTION = ("Continuous initial sync with failover") + DESCRIPTION = "Continuous initial sync with failover" IS_BACKGROUND = True # The hook stops the fixture partially during its execution. STOPS_FIXTURE = True - def __init__(self, hook_logger, fixture, use_action_permitted_file=False, sync_interval_secs=8): + def __init__( + self, + hook_logger, + fixture, + use_action_permitted_file=False, + sync_interval_secs=8, + ): """Initialize the ContinuousInitialSync. Args: @@ -36,7 +42,9 @@ class ContinuousInitialSync(interface.Hook): use_action_permitted_file: use a file to control if the syncer thread should do a failover or initial sync sync_interval_secs: how often to trigger a new cycle """ - interface.Hook.__init__(self, hook_logger, fixture, ContinuousInitialSync.DESCRIPTION) + interface.Hook.__init__( + self, hook_logger, fixture, ContinuousInitialSync.DESCRIPTION + ) self.hook_logger = hook_logger @@ -53,10 +61,12 @@ class ContinuousInitialSync(interface.Hook): dbpath_prefix = fixture.get_dbpath_prefix() if use_action_permitted_file: - self.__action_files = lifecycle_interface.ActionFiles._make([ - os.path.join(dbpath_prefix, field) - for field in lifecycle_interface.ActionFiles._fields - ]) + self.__action_files = lifecycle_interface.ActionFiles._make( + [ + os.path.join(dbpath_prefix, field) + for field in lifecycle_interface.ActionFiles._fields + ] + ) else: self.__action_files = None @@ -66,13 +76,20 @@ class ContinuousInitialSync(interface.Hook): self._add_fixture(self._fixture) if self.__action_files is not None: - lifecycle = lifecycle_interface.FileBasedThreadLifecycle(self.__action_files) + lifecycle = lifecycle_interface.FileBasedThreadLifecycle( + self.__action_files + ) else: lifecycle_interface.FlagBasedThreadLifecycle() - self._initial_sync_thread = _InitialSyncThread(self.logger, self._rs_fixtures, - self._mongos_fixtures, self._fixture, - lifecycle, self._sync_interval_secs) + self._initial_sync_thread = _InitialSyncThread( + self.logger, + self._rs_fixtures, + self._mongos_fixtures, + self._fixture, + lifecycle, + self._sync_interval_secs, + ) self.logger.info("Starting the continuous initial syncer thread.") self._initial_sync_thread.start() @@ -119,7 +136,6 @@ class SyncerStage(Enum): class _InitialSyncThread(threading.Thread): - # Error codes, taken from mongo/base/error_codes.yml. _NODE_NOT_FOUND = 74 _NEW_REPLICA_SET_CONFIGURATION_INCOMPATIBLE = 103 @@ -128,8 +144,15 @@ class _InitialSyncThread(threading.Thread): _INTERRUPTED_DUE_TO_STORAGE_CHANGE = 355 _INTERRUPTED_DUE_TO_REPL_STATE_CHANGE = 11602 - def __init__(self, logger, rs_fixtures, mongos_fixtures, fixture, lifecycle, - sync_interval_secs): + def __init__( + self, + logger, + rs_fixtures, + mongos_fixtures, + fixture, + lifecycle, + sync_interval_secs, + ): """Initialize _InitialSyncThread.""" threading.Thread.__init__(self, name="InitialSyncThread") self.daemon = True @@ -194,7 +217,9 @@ class _InitialSyncThread(threading.Thread): elif stage == SyncerStage.INITSYNC_PRIMARY: self.logger.info("Stepping up new secondaries...") for fixture in self._rs_fixtures: - self._fail_over_to_node(fixture.get_initial_sync_node(), fixture) + self._fail_over_to_node( + fixture.get_initial_sync_node(), fixture + ) stage = SyncerStage.ORIGINAL_PRIMARY wait_secs = self._sync_interval_secs @@ -215,7 +240,9 @@ class _InitialSyncThread(threading.Thread): self._is_idle_evt.set() self.logger.info( "Syncer sleeping for {} seconds before moving to the next stage.".format( - wait_secs)) + wait_secs + ) + ) self.__lifecycle.wait_for_action_interval(wait_secs) except Exception as err: # pylint: disable=broad-except @@ -270,19 +297,24 @@ class _InitialSyncThread(threading.Thread): if not rs_fixture.is_running(): raise errors.ServerFailure( "ReplicaSetFixture with pids {} expected to be running in" - " ContinuousInitialSync, but wasn't.".format(rs_fixture.pids())) + " ContinuousInitialSync, but wasn't.".format(rs_fixture.pids()) + ) for mongos_fixture in self._mongos_fixtures: if not mongos_fixture.is_running(): - raise errors.ServerFailure("MongoSFixture with pids {} expected to be running in" - " ContinuousInitialSync, but wasn't.".format( - mongos_fixture.pids())) + raise errors.ServerFailure( + "MongoSFixture with pids {} expected to be running in" + " ContinuousInitialSync, but wasn't.".format(mongos_fixture.pids()) + ) def _add_initsync_tag(self, fixture): """Adds the 'INIT_SYNC' unique tag to the initial-sync node of the given fixture.""" sync_node = fixture.get_initial_sync_node() - self.logger.info("Adding unique tag to initial sync node on port {} in set {}".format( - sync_node.port, fixture.replset_name)) + self.logger.info( + "Adding unique tag to initial sync node on port {} in set {}".format( + sync_node.port, fixture.replset_name + ) + ) retry_time_secs = fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_MINS * 60 retry_start_time = time.time() @@ -290,8 +322,10 @@ class _InitialSyncThread(threading.Thread): while True: if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( - "Could not add unique tag to node on port {} for replica set {} in {} seconds.". - format(sync_node.port, fixture.replset_name, retry_time_secs)) + "Could not add unique tag to node on port {} for replica set {} in {} seconds.".format( + sync_node.port, fixture.replset_name, retry_time_secs + ) + ) try: primary = fixture.get_primary() client = primary.mongo_client() @@ -300,10 +334,14 @@ class _InitialSyncThread(threading.Thread): repl_config["version"] += 1 repl_config["members"][-1]["tags"]["uniqueTag"] = "INIT_SYNC_NODE" - client.admin.command({ - "replSetReconfig": repl_config, - "maxTimeMS": fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_MINS * 60 * 1000 - }) + client.admin.command( + { + "replSetReconfig": repl_config, + "maxTimeMS": fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_MINS + * 60 + * 1000, + } + ) break except pymongo.errors.AutoReconnect: @@ -319,11 +357,14 @@ class _InitialSyncThread(threading.Thread): # (potentially) higher config version. We should not receive these codes # indefinitely. # pylint: disable=too-many-boolean-expressions - if err.code not in (self._NEW_REPLICA_SET_CONFIGURATION_INCOMPATIBLE, - self._CURRENT_CONFIG_NOT_COMMITTED_YET, - self._CONFIGURATION_IN_PROGRESS, self._NODE_NOT_FOUND, - self._INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self._INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self._NEW_REPLICA_SET_CONFIGURATION_INCOMPATIBLE, + self._CURRENT_CONFIG_NOT_COMMITTED_YET, + self._CONFIGURATION_IN_PROGRESS, + self._NODE_NOT_FOUND, + self._INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self._INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): msg = ( "Operation failure while adding tag for node on port {} in fixture {}: {}" ).format(sync_node.port, fixture.replset_name, err) @@ -331,8 +372,8 @@ class _InitialSyncThread(threading.Thread): raise self.fixturelib.ServerFailure(msg) msg = ( - "Retrying failed attempt to add new node on port {} to fixture {}: {}").format( - sync_node.port, fixture.replset_name, err) + "Retrying failed attempt to add new node on port {} to fixture {}: {}" + ).format(sync_node.port, fixture.replset_name, err) self.logger.error(msg) time.sleep(0.1) continue @@ -344,8 +385,10 @@ class _InitialSyncThread(threading.Thread): method = random.choice(["logical", "fileCopyBased"]) self.logger.info( - "Restarting initial sync on node on port {} in set {} with initial sync method {}". - format(sync_node.port, fixture.replset_name, method)) + "Restarting initial sync on node on port {} in set {} with initial sync method {}".format( + sync_node.port, fixture.replset_name, method + ) + ) sync_node.teardown() sync_node.mongod_options["set_parameters"]["initialSyncMethod"] = method @@ -360,8 +403,11 @@ class _InitialSyncThread(threading.Thread): """Waits for the initial sync node to complete its transition to secondary.""" sync_node = fixture.get_initial_sync_node() - self.logger.info("Waiting for node on port {} in set {} to complete initial sync".format( - sync_node.port, fixture.replset_name)) + self.logger.info( + "Waiting for node on port {} in set {} to complete initial sync".format( + sync_node.port, fixture.replset_name + ) + ) retry_time_secs = fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_MINS * 60 retry_start_time = time.time() @@ -372,16 +418,21 @@ class _InitialSyncThread(threading.Thread): time.sleep(2) if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( - "Node on port {} of replica set {} did not finish initial sync in {} seconds.". - format(sync_node.port, fixture.replset_name, retry_time_secs)) + "Node on port {} of replica set {} did not finish initial sync in {} seconds.".format( + sync_node.port, fixture.replset_name, retry_time_secs + ) + ) def _check_initial_sync_done(self, fixture): """A one-time check for whether a node has completed initial sync and transitioned to a secondary state.""" sync_node = fixture.get_initial_sync_node() sync_node_conn = sync_node.mongo_client() - self.logger.info("Checking initial sync progress for node on port {} in set {}".format( - sync_node.port, fixture.replset_name)) + self.logger.info( + "Checking initial sync progress for node on port {} in set {}".format( + sync_node.port, fixture.replset_name + ) + ) try: state = sync_node_conn.admin.command("replSetGetStatus").get("myState") @@ -395,8 +446,11 @@ class _InitialSyncThread(threading.Thread): conn = node.mongo_client() old_primary = fixture.get_primary() - self.logger.info("Failing over to node on port {} from node on port {} in set {}".format( - node.port, old_primary.port, fixture.replset_name)) + self.logger.info( + "Failing over to node on port {} from node on port {} in set {}".format( + node.port, old_primary.port, fixture.replset_name + ) + ) retry_time_secs = fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_MINS * 60 retry_start_time = time.time() @@ -414,11 +468,20 @@ class _InitialSyncThread(threading.Thread): if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( "Node on port {} of replica set {} did not step up in {} seconds.".format( - node.port, fixture.replset_name, retry_time_secs)) + node.port, fixture.replset_name, retry_time_secs + ) + ) - cmd = bson.SON([("replSetTest", 1), ("waitForMemberState", 1), - ("timeoutMillis", - fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60)]) + cmd = bson.SON( + [ + ("replSetTest", 1), + ("waitForMemberState", 1), + ( + "timeoutMillis", + fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60, + ), + ] + ) retry_start_time = time.time() @@ -427,8 +490,10 @@ class _InitialSyncThread(threading.Thread): conn.admin.command(cmd) break except pymongo.errors.OperationFailure as err: - if err.code not in (self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self.INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self.INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): raise msg = ( "Interrupted while waiting for node on port {} in set {} to reach primary state, retrying: {}" @@ -436,13 +501,17 @@ class _InitialSyncThread(threading.Thread): self.logger.error(msg) if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( - "Node on port {} of replica set {} did not finish stepping up in {} seconds.". - format(node.port, fixture.replset_name, retry_time_secs)) + "Node on port {} of replica set {} did not finish stepping up in {} seconds.".format( + node.port, fixture.replset_name, retry_time_secs + ) + ) time.sleep(0.2) self.logger.info( - "Successfully stepped up node on port {} in set {}. Waiting for old primary on port {} to step down" - .format(node.port, fixture.replset_name, old_primary.port)) + "Successfully stepped up node on port {} in set {}. Waiting for old primary on port {} to step down".format( + node.port, fixture.replset_name, old_primary.port + ) + ) retry_start_time = time.time() @@ -455,16 +524,24 @@ class _InitialSyncThread(threading.Thread): break except pymongo.errors.AutoReconnect: pass - self.logger.info("Waiting for old primary on port {} in set {} to step down.".format( - old_primary.port, fixture.replset_name)) + self.logger.info( + "Waiting for old primary on port {} in set {} to step down.".format( + old_primary.port, fixture.replset_name + ) + ) if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( - "Old primary on port {} of replica set {} did not step down in {} seconds.". - format(node.port, fixture.replset_name, retry_time_secs)) + "Old primary on port {} of replica set {} did not step down in {} seconds.".format( + node.port, fixture.replset_name, retry_time_secs + ) + ) time.sleep(0.2) - self.logger.info("Old primary on port {} in set {} successfully stepped down".format( - old_primary.port, fixture.replset_name)) + self.logger.info( + "Old primary on port {} in set {} successfully stepped down".format( + old_primary.port, fixture.replset_name + ) + ) # It is possible for the initial sync node to have been behind every other node when it # stepped up causing them all to enter rollback. To make sure this doesn't happen we @@ -501,6 +578,10 @@ class _InitialSyncThread(threading.Thread): shell_proc.start() return_code = shell_proc.wait() if return_code: - raise errors.ServerFailure("Awaiting replication failed for {}".format(client_conn)) + raise errors.ServerFailure( + "Awaiting replication failed for {}".format(client_conn) + ) - self.logger.info("Finished WaitForReplication, no nodes should be in ROLLBACK state.") + self.logger.info( + "Finished WaitForReplication, no nodes should be in ROLLBACK state." + ) diff --git a/buildscripts/resmokelib/testing/hooks/cpp_libfuzzer.py b/buildscripts/resmokelib/testing/hooks/cpp_libfuzzer.py index 1053191a64c..9146c4a9521 100644 --- a/buildscripts/resmokelib/testing/hooks/cpp_libfuzzer.py +++ b/buildscripts/resmokelib/testing/hooks/cpp_libfuzzer.py @@ -41,8 +41,10 @@ class LibfuzzerHook(interface.Hook): self._merge_corpus(test) def _merge_corpus(self, test): - self.logger.info(f"Merge for {test.short_name()} libfuzzer test started, " - f"merging to {test.merged_corpus_directory}.") + self.logger.info( + f"Merge for {test.short_name()} libfuzzer test started, " + f"merging to {test.merged_corpus_directory}." + ) os.makedirs(test.merged_corpus_directory, exist_ok=True) default_args = [ test.program_executable, @@ -50,7 +52,9 @@ class LibfuzzerHook(interface.Hook): test.merged_corpus_directory, test.corpus_directory, ] - process = core.programs.make_process(self.logger, default_args, **test.program_options) + process = core.programs.make_process( + self.logger, default_args, **test.program_options + ) process.start() process.wait() self.logger.info(f"Merge for {test.short_name()} libfuzzer test finished.") diff --git a/buildscripts/resmokelib/testing/hooks/dbcheck_background.py b/buildscripts/resmokelib/testing/hooks/dbcheck_background.py index b08125deaed..278fd43e9aa 100644 --- a/buildscripts/resmokelib/testing/hooks/dbcheck_background.py +++ b/buildscripts/resmokelib/testing/hooks/dbcheck_background.py @@ -22,8 +22,14 @@ class RunDBCheckInBackground(jsfile.JSHook): """Initialize RunDBCheckInBackground.""" description = "Runs dbCheck on a replica set while a test is running" js_filename = os.path.join("jstests", "hooks", "run_dbcheck_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -47,7 +53,8 @@ class RunDBCheckInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background check repl dbCheck thread.") @@ -72,5 +79,6 @@ class RunDBCheckInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background check repl dbCheck thread.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/dbhash.py b/buildscripts/resmokelib/testing/hooks/dbhash.py index 9f71fc1aeb6..f98830a070d 100644 --- a/buildscripts/resmokelib/testing/hooks/dbhash.py +++ b/buildscripts/resmokelib/testing/hooks/dbhash.py @@ -15,9 +15,16 @@ class CheckReplDBHash(jsfile.PerClusterDataConsistencyHook): IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize CheckReplDBHash.""" description = "Check dbhashes of all replica set or master/slave members" js_filename = os.path.join("jstests", "hooks", "run_check_repl_dbhash.js") jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/dbhash_background.py b/buildscripts/resmokelib/testing/hooks/dbhash_background.py index c4b5e3b1ac1..304b470bad1 100644 --- a/buildscripts/resmokelib/testing/hooks/dbhash_background.py +++ b/buildscripts/resmokelib/testing/hooks/dbhash_background.py @@ -21,10 +21,20 @@ class CheckReplDBHashInBackground(jsfile.JSHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize CheckReplDBHashInBackground.""" - description = "Check dbhashes of all replica set members while a test is running" - js_filename = os.path.join("jstests", "hooks", "run_check_repl_dbhash_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + description = ( + "Check dbhashes of all replica set members while a test is running" + ) + js_filename = os.path.join( + "jstests", "hooks", "run_check_repl_dbhash_background.js" + ) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -36,11 +46,14 @@ class CheckReplDBHashInBackground(jsfile.JSHook): # replica set shards supports snapshot reads. if not client.is_mongos: server_status = client.admin.command("serverStatus") - if not server_status["storageEngine"].get("supportsSnapshotReadConcern", False): + if not server_status["storageEngine"].get( + "supportsSnapshotReadConcern", False + ): self.logger.info( "Not enabling the background check repl dbhash thread because '%s' storage" " engine doesn't support snapshot reads.", - server_status["storageEngine"]["name"]) + server_status["storageEngine"]["name"], + ) return self._background_job = _BackgroundJob("CheckReplDBHashInBackground") @@ -61,7 +74,8 @@ class CheckReplDBHashInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background check repl dbhash thread.") @@ -86,5 +100,6 @@ class CheckReplDBHashInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background check repl dbhash thread.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/drop_user_collections.py b/buildscripts/resmokelib/testing/hooks/drop_user_collections.py index 0db80612ccf..ed4a53e7150 100644 --- a/buildscripts/resmokelib/testing/hooks/drop_user_collections.py +++ b/buildscripts/resmokelib/testing/hooks/drop_user_collections.py @@ -14,5 +14,11 @@ class DropUserCollections(jsfile.JSHook): """.""" description = "Drop all user collections" js_filename = os.path.join("jstests", "hooks", "drop_user_collections.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/enable_change_stream.py b/buildscripts/resmokelib/testing/hooks/enable_change_stream.py index c5c4477fd3d..3ca707c6e2f 100644 --- a/buildscripts/resmokelib/testing/hooks/enable_change_stream.py +++ b/buildscripts/resmokelib/testing/hooks/enable_change_stream.py @@ -3,6 +3,7 @@ A hook to enable change stream in the replica set and the sharded cluster in the multi-tenant environment. """ + import os.path from time import sleep @@ -23,7 +24,9 @@ class EnableChangeStream(interface.Hook): def __init__(self, hook_logger, fixture, tenant_id=None): """Initialize the EnableChangeCollection.""" description = "Enables the change stream in the multi-tenant environment." - self._js_filename = os.path.join("jstests", "hooks", "run_enable_change_stream.js") + self._js_filename = os.path.join( + "jstests", "hooks", "run_enable_change_stream.js" + ) interface.Hook.__init__(self, hook_logger, fixture, description) self._fixture = fixture self._tenant_id = ObjectId(tenant_id) if tenant_id else None @@ -47,8 +50,11 @@ class EnableChangeStream(interface.Hook): sleep(5) def _call_js_hook(self, fixture, test, test_report): - shell_options = {"global_vars": {"TestData": {"tenantId": str(self._tenant_id)}}} + shell_options = { + "global_vars": {"TestData": {"tenantId": str(self._tenant_id)}} + } hook_test_case = jsfile.DynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, shell_options) + test.logger, test, self, self._js_filename, shell_options + ) hook_test_case.configure(fixture) hook_test_case.run_dynamic_test(test_report) diff --git a/buildscripts/resmokelib/testing/hooks/enable_spurious_write_conflicts.py b/buildscripts/resmokelib/testing/hooks/enable_spurious_write_conflicts.py index 9cee6163e5a..86715ffae96 100644 --- a/buildscripts/resmokelib/testing/hooks/enable_spurious_write_conflicts.py +++ b/buildscripts/resmokelib/testing/hooks/enable_spurious_write_conflicts.py @@ -17,20 +17,26 @@ class EnableSpuriousWriteConflicts(interface.Hook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize ToggleWriteConflicts.""" super().__init__(hook_logger, fixture, "TogglesWTWriteConflictExceptions") - self._enable_js_filename = os.path.join("jstests", "hooks", "enable_write_conflicts.js") - self._disable_js_filename = os.path.join("jstests", "hooks", "disable_write_conflicts.js") + self._enable_js_filename = os.path.join( + "jstests", "hooks", "enable_write_conflicts.js" + ) + self._disable_js_filename = os.path.join( + "jstests", "hooks", "disable_write_conflicts.js" + ) self._shell_options = shell_options def before_test(self, test, test_report): """Enable WTWriteConflictExceptions.""" hook_test_case = jsfile.DynamicJSTestCase.create_after_test( - test.logger, test, self, self._enable_js_filename, self._shell_options) + test.logger, test, self, self._enable_js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) def after_test(self, test, test_report): """Disable WTWriteConflictExceptions.""" hook_test_case = jsfile.DynamicJSTestCase.create_after_test( - test.logger, test, self, self._disable_js_filename, self._shell_options) + test.logger, test, self, self._disable_js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) diff --git a/buildscripts/resmokelib/testing/hooks/fcv_upgrade_downgrade.py b/buildscripts/resmokelib/testing/hooks/fcv_upgrade_downgrade.py index cb38c139695..9c12f1762db 100644 --- a/buildscripts/resmokelib/testing/hooks/fcv_upgrade_downgrade.py +++ b/buildscripts/resmokelib/testing/hooks/fcv_upgrade_downgrade.py @@ -23,9 +23,17 @@ class FCVUpgradeDowngradeInBackground(jsfile.JSHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize FCVUpgradeDowngradeInBackground.""" description = "Run background FCV upgrade/downgrade while a test is running" - js_filename = os.path.join("jstests", "hooks", "run_fcv_upgrade_downgrade_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_fcv_upgrade_downgrade_background.js" + ) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -49,7 +57,8 @@ class FCVUpgradeDowngradeInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background FCV upgrade/downgrade thread.") @@ -74,5 +83,6 @@ class FCVUpgradeDowngradeInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background FCV upgrade/downgrade thread.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/fuzzer_restore_settings.py b/buildscripts/resmokelib/testing/hooks/fuzzer_restore_settings.py index 00277b12f58..1d56a200dcf 100644 --- a/buildscripts/resmokelib/testing/hooks/fuzzer_restore_settings.py +++ b/buildscripts/resmokelib/testing/hooks/fuzzer_restore_settings.py @@ -18,5 +18,11 @@ class FuzzerRestoreSettings(jsfile.JSHook): """Run fuzzer cleanup.""" description = "Clean up unwanted changes from fuzzer" js_filename = os.path.join("jstests", "hooks", "run_fuzzer_restore_settings.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/generate_and_check_perf_results.py b/buildscripts/resmokelib/testing/hooks/generate_and_check_perf_results.py index 959d741e86e..f555af306ec 100644 --- a/buildscripts/resmokelib/testing/hooks/generate_and_check_perf_results.py +++ b/buildscripts/resmokelib/testing/hooks/generate_and_check_perf_results.py @@ -62,7 +62,9 @@ class GenerateAndCheckPerfResults(interface.Hook): def __init__(self, hook_logger, fixture): """Initialize GenerateAndCheckPerfResults.""" - interface.Hook.__init__(self, hook_logger, fixture, GenerateAndCheckPerfResults.DESCRIPTION) + interface.Hook.__init__( + self, hook_logger, fixture, GenerateAndCheckPerfResults.DESCRIPTION + ) self.cedar_report_file = _config.CEDAR_REPORT_FILE self.variant = _config.EVERGREEN_VARIANT_NAME self.cedar_reports: List[CedarTestReport] = [] @@ -89,10 +91,17 @@ class GenerateAndCheckPerfResults(interface.Hook): self.cedar_reports.extend(cedar_formatted_results) - self._check_pass_fail(benchmark_reports, cedar_formatted_results, test, test_report) + self._check_pass_fail( + benchmark_reports, cedar_formatted_results, test, test_report + ) - def _check_pass_fail(self, benchmark_reports: Dict[str, "_BenchmarkThreadsReport"], - cedar_formatted_results: List[CedarTestReport], test, test_report): + def _check_pass_fail( + self, + benchmark_reports: Dict[str, "_BenchmarkThreadsReport"], + cedar_formatted_results: List[CedarTestReport], + test, + test_report, + ): """Check to see if any of the reported results violate any of the thresholds set.""" if self.variant is None: self.logger.info( @@ -103,7 +112,8 @@ class GenerateAndCheckPerfResults(interface.Hook): variant_thresholds = self.performance_thresholds.get(test_name, None) if variant_thresholds is None: self.logger.info( - f"No thresholds were set for {test_name}, skipping threshold check") + f"No thresholds were set for {test_name}, skipping threshold check" + ) continue test_thresholds = variant_thresholds.get(self.variant, None) if test_thresholds is None: @@ -117,24 +127,33 @@ class GenerateAndCheckPerfResults(interface.Hook): thread_level = item["thread_level"] for metric in item["metrics"]: metrics_to_check.append( - IndividualMetricThreshold(test_name=test_name, thread_level=thread_level, - metric_name=metric["name"], value=metric["value"], - bound_direction=metric["bound_direction"])) + IndividualMetricThreshold( + test_name=test_name, + thread_level=thread_level, + metric_name=metric["name"], + value=metric["value"], + bound_direction=metric["bound_direction"], + ) + ) # Transform the reported performance results into something we can more easily use. transformed_metrics: Dict[ReportedMetric, CedarMetric] = {} for cedar_result in cedar_formatted_results: for individual_metric in cedar_result.metrics: - reported_metric = ReportedMetric(test_name=cedar_result.test_name, - thread_level=cedar_result.thread_level, - metric_name=individual_metric.name) + reported_metric = ReportedMetric( + test_name=cedar_result.test_name, + thread_level=cedar_result.thread_level, + metric_name=individual_metric.name, + ) if transformed_metrics.get(reported_metric, None) is not None: raise CedarReportError( - f"Multiple values reported for the same metric: {reported_metric}") + f"Multiple values reported for the same metric: {reported_metric}" + ) else: transformed_metrics[reported_metric] = individual_metric # Add a dynamic resmoke test to make sure that the pass/fail results are reported correctly. hook_test_case = CheckPerfResultTestCase.create_after_test( - self.logger, test, self, metrics_to_check, transformed_metrics) + self.logger, test, self, metrics_to_check, transformed_metrics + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) @@ -148,7 +167,8 @@ class GenerateAndCheckPerfResults(interface.Hook): except Exception: self.logger.exception( f"Could not load in the threshold file needed to check performance results. " - f"Trying to retrieve them from {THRESHOLD_LOCATION}.") + f"Trying to retrieve them from {THRESHOLD_LOCATION}." + ) raise ServerFailure( "Could not load the needed threshold information. Please make sure you are in the root of the mongo repo." ) @@ -164,7 +184,8 @@ class GenerateAndCheckPerfResults(interface.Hook): json.dump(dict_formatted_results, fh) def _generate_cedar_report( - self, benchmark_reports: Dict[str, "_BenchmarkThreadsReport"]) -> List[CedarTestReport]: + self, benchmark_reports: Dict[str, "_BenchmarkThreadsReport"] + ) -> List[CedarTestReport]: """Format the data to look like a cedar report.""" cedar_report = [] @@ -176,8 +197,9 @@ class GenerateAndCheckPerfResults(interface.Hook): raise CedarReportError(msg) for threads_count, thread_metrics in cedar_metrics.items(): - test_report = CedarTestReport(test_name=name, thread_level=threads_count, - metrics=thread_metrics) + test_report = CedarTestReport( + test_name=name, thread_level=threads_count, metrics=thread_metrics + ) cedar_report.append(test_report) return cedar_report @@ -191,19 +213,32 @@ class GenerateAndCheckPerfResults(interface.Hook): bm_name_obj = _BenchmarkThreadsReport.parse_bm_name(benchmark_res) if bm_name_obj.base_name not in benchmark_reports: - benchmark_reports[bm_name_obj.base_name] = _BenchmarkThreadsReport(context) - benchmark_reports[bm_name_obj.base_name].add_report(bm_name_obj, benchmark_res) + benchmark_reports[bm_name_obj.base_name] = _BenchmarkThreadsReport( + context + ) + benchmark_reports[bm_name_obj.base_name].add_report( + bm_name_obj, benchmark_res + ) return benchmark_reports class CheckPerfResultTestCase(interface.DynamicTestCase): """CheckPerfResultTestCase class.""" - def __init__(self, logger, test_name, description, base_test_name, hook, - thresholds_to_check: List["IndividualMetricThreshold"], - reported_metrics: Dict[ReportedMetric, CedarMetric]): + def __init__( + self, + logger, + test_name, + description, + base_test_name, + hook, + thresholds_to_check: List["IndividualMetricThreshold"], + reported_metrics: Dict[ReportedMetric, CedarMetric], + ): super().__init__(logger, test_name, description, base_test_name, hook) - self.thresholds_to_check: List["IndividualMetricThreshold"] = thresholds_to_check + self.thresholds_to_check: List["IndividualMetricThreshold"] = ( + thresholds_to_check + ) self.reported_metrics: Dict[ReportedMetric, CedarMetric] = reported_metrics def run_test(self): @@ -217,19 +252,26 @@ class CheckPerfResultTestCase(interface.DynamicTestCase): for metric_to_check in self.thresholds_to_check: reported_metric = self.reported_metrics.get( - ReportedMetric(test_name=metric_to_check.test_name, - thread_level=metric_to_check.thread_level, - metric_name=metric_to_check.metric_name), None) + ReportedMetric( + test_name=metric_to_check.test_name, + thread_level=metric_to_check.thread_level, + metric_name=metric_to_check.metric_name, + ), + None, + ) if reported_metric is None: self.logger.error( f"One of the expected metrics was not able to be found in the performance results generated by this task. {metric_to_check.test_name} with thread_level of {metric_to_check.thread_level} did not report a metric called {metric_to_check.metric_name}." ) any_metric_has_failed = True continue - if (metric_to_check.bound_direction == BoundDirection.UPPER - and metric_to_check.value < reported_metric.value) or ( - metric_to_check.bound_direction == BoundDirection.LOWER - and metric_to_check.value > reported_metric.value): + if ( + metric_to_check.bound_direction == BoundDirection.UPPER + and metric_to_check.value < reported_metric.value + ) or ( + metric_to_check.bound_direction == BoundDirection.LOWER + and metric_to_check.value > reported_metric.value + ): if metric_to_check.bound_direction == BoundDirection.LOWER: self.logger.error( f"Metric {metric_to_check.metric_name} in {metric_to_check.test_name} with thread_level of {metric_to_check.thread_level} has failed the threshold check. The reported value of {reported_metric.value} is lower than the set threshold of {metric_to_check.value}" @@ -247,8 +289,9 @@ class CheckPerfResultTestCase(interface.DynamicTestCase): # Capture information from a Benchmark name in a logical format. -_BenchmarkName = collections.namedtuple("_BenchmarkName", - ["base_name", "thread_count", "statistic_type"]) +_BenchmarkName = collections.namedtuple( + "_BenchmarkName", ["base_name", "thread_count", "statistic_type"] +) class _BenchmarkThreadsReport(object): @@ -301,16 +344,21 @@ class _BenchmarkThreadsReport(object): cedar_metric_type: str BENCHMARK_METRICS_TO_GATHER = { - "latency": - BenchmarkMetricInfo(local_name="cpu_time", cedar_name="latency_per_op", - cedar_metric_type="LATENCY"), - "instructions_per_iteration": - BenchmarkMetricInfo(local_name="instructions_per_iteration", - cedar_name="instructions_per_iteration", - cedar_metric_type="LATENCY"), - "cycles_per_iteration": - BenchmarkMetricInfo(local_name="cycles_per_iteration", - cedar_name="cycles_per_iteration", cedar_metric_type="LATENCY"), + "latency": BenchmarkMetricInfo( + local_name="cpu_time", + cedar_name="latency_per_op", + cedar_metric_type="LATENCY", + ), + "instructions_per_iteration": BenchmarkMetricInfo( + local_name="instructions_per_iteration", + cedar_name="instructions_per_iteration", + cedar_metric_type="LATENCY", + ), + "cycles_per_iteration": BenchmarkMetricInfo( + local_name="cycles_per_iteration", + cedar_name="cycles_per_iteration", + cedar_metric_type="LATENCY", + ), } # Map benchmark metric type to the type in Cedar @@ -342,7 +390,9 @@ class _BenchmarkThreadsReport(object): def __init__(self, context_dict): # `context_dict` was parsed from a json file and might have additional fields. - relevant = dict(filter(lambda e: e[0] in self.Context._fields, context_dict.items())) + relevant = dict( + filter(lambda e: e[0] in self.Context._fields, context_dict.items()) + ) self.context = self.Context(**relevant) # list of benchmark runs for each thread. @@ -468,16 +518,18 @@ class _BenchmarkThreadsReport(object): # cedar_type becomes `MEAN`. metric_name = f"{metric_cedar_name}_{aggregate_name}" - metric_cedar_type = self.AGGREGATE_TYPE_TO_CEDAR_METRIC_TYPE_MAP[ - aggregate_name] + metric_cedar_type = ( + self.AGGREGATE_TYPE_TO_CEDAR_METRIC_TYPE_MAP[aggregate_name] + ) else: # Call out what iteration this metric came from. For example, if we are looking at iteration 2 # and the `latency` metric, metric_name becomes `latency_2`. idx = report.get("repetition_index", 0) metric_name = f"{metric_cedar_name}_{idx}" - metric = CedarMetric(name=metric_name, type=metric_cedar_type, - value=metric_value) + metric = CedarMetric( + name=metric_name, type=metric_cedar_type, value=metric_value + ) threads = report["threads"] if threads in res: res[threads].append(metric) @@ -519,7 +571,7 @@ class _BenchmarkThreadsReport(object): statistic_type_candidate = name_str.rsplit("_", 1)[-1] # Remove the statistic type suffix from the name. if statistic_type_candidate == statistic_type: - name_str = name_str[:-len(statistic_type) - 1] + name_str = name_str[: -len(statistic_type) - 1] # Step 2: Get the thread count and name. thread_section = name_str.rsplit("/", 1)[-1] diff --git a/buildscripts/resmokelib/testing/hooks/hello_failures.py b/buildscripts/resmokelib/testing/hooks/hello_failures.py index 2506c61104b..1681d18e0f4 100644 --- a/buildscripts/resmokelib/testing/hooks/hello_failures.py +++ b/buildscripts/resmokelib/testing/hooks/hello_failures.py @@ -18,22 +18,28 @@ class HelloDelays(interface.Hook): """Initialize HelloDelays.""" description = "Sets Hello fault injections" interface.Hook.__init__(self, hook_logger, fixture, description) - self.js_filename = os.path.join("jstests", "hooks", "run_inject_hello_failures.js") - self.cleanup_js_filename = os.path.join("jstests", "hooks", "run_cleanup_hello_failures.js") + self.js_filename = os.path.join( + "jstests", "hooks", "run_inject_hello_failures.js" + ) + self.cleanup_js_filename = os.path.join( + "jstests", "hooks", "run_cleanup_hello_failures.js" + ) self.shell_options = None def before_test(self, test, test_report): """Each test will call this before it executes.""" - print('before_test hook starts injecting Hello failures') + print("before_test hook starts injecting Hello failures") hook_test_case = jsfile.DynamicJSTestCase.create_before_test( - test.logger, test, self, self.js_filename, self.shell_options) + test.logger, test, self, self.js_filename, self.shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) def after_test(self, test, test_report): """Each test will call this after it executes.""" - print('Cleanup hook is starting to remove Hello fail injections') + print("Cleanup hook is starting to remove Hello fail injections") hook_test_case = jsfile.DynamicJSTestCase.create_after_test( - test.logger, test, self, self.cleanup_js_filename, self.shell_options) + test.logger, test, self, self.cleanup_js_filename, self.shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) diff --git a/buildscripts/resmokelib/testing/hooks/initialsync.py b/buildscripts/resmokelib/testing/hooks/initialsync.py index 9b9c6bcb4ab..bab6f63ac94 100644 --- a/buildscripts/resmokelib/testing/hooks/initialsync.py +++ b/buildscripts/resmokelib/testing/hooks/initialsync.py @@ -34,8 +34,11 @@ class BackgroundInitialSync(interface.Hook): def __init__(self, hook_logger, fixture, n=DEFAULT_N, shell_options=None): """Initialize BackgroundInitialSync.""" if not isinstance(fixture, replicaset.ReplicaSetFixture): - raise ValueError("`fixture` must be an instance of ReplicaSetFixture, not {}".format( - fixture.__class__.__name__)) + raise ValueError( + "`fixture` must be an instance of ReplicaSetFixture, not {}".format( + fixture.__class__.__name__ + ) + ) description = "Background Initial Sync" interface.Hook.__init__(self, hook_logger, fixture, description) @@ -48,7 +51,8 @@ class BackgroundInitialSync(interface.Hook): def before_test(self, test, test_report): """Before test execution.""" hook_test_case = BackgroundInitialSyncTestCase.create_after_test( - test.logger, test, self, self._shell_options) + test.logger, test, self, self._shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) self.tests_run += 1 @@ -57,14 +61,26 @@ class BackgroundInitialSync(interface.Hook): class BackgroundInitialSyncTestCase(jsfile.DynamicJSTestCase): """BackgroundInitialSyncTestCase class.""" - JS_FILENAME = os.path.join("jstests", "hooks", "run_initial_sync_node_validation.js") + JS_FILENAME = os.path.join( + "jstests", "hooks", "run_initial_sync_node_validation.js" + ) INTERRUPTED_DUE_TO_REPL_STATE_CHANGE = 11602 INTERRUPTED_DUE_TO_STORAGE_CHANGE = 355 - def __init__(self, logger, test_name, description, base_test_name, hook, shell_options=None): + def __init__( + self, logger, test_name, description, base_test_name, hook, shell_options=None + ): """Initialize BackgroundInitialSyncTestCase.""" - jsfile.DynamicJSTestCase.__init__(self, logger, test_name, description, base_test_name, - hook, self.JS_FILENAME, shell_options) + jsfile.DynamicJSTestCase.__init__( + self, + logger, + test_name, + description, + base_test_name, + hook, + self.JS_FILENAME, + shell_options, + ) def run_test(self): """Execute test hook.""" @@ -75,20 +91,32 @@ class BackgroundInitialSyncTestCase(jsfile.DynamicJSTestCase): if self._hook.tests_run >= self._hook.n: self.logger.info( "%d tests have been run against the fixture, waiting for initial sync" - " node to go into SECONDARY state", self._hook.tests_run) + " node to go into SECONDARY state", + self._hook.tests_run, + ) self._hook.tests_run = 0 cmd = bson.SON( - [("replSetTest", 1), ("waitForMemberState", 2), - ("timeoutMillis", - fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 * 1000)]) + [ + ("replSetTest", 1), + ("waitForMemberState", 2), + ( + "timeoutMillis", + fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS + * 60 + * 1000, + ), + ] + ) while True: try: sync_node_conn.admin.command(cmd) break except pymongo.errors.OperationFailure as err: - if err.code not in (self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self.INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self.INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): raise msg = ( "Interrupted while waiting for node to reach secondary state, retrying: {}" @@ -104,18 +132,24 @@ class BackgroundInitialSyncTestCase(jsfile.DynamicJSTestCase): if state != 2: if self._hook.tests_run == 0: msg = "Initial sync node did not catch up after waiting 24 hours" - self.logger.exception("{0} failed: {1}".format(self._hook.description, msg)) + self.logger.exception( + "{0} failed: {1}".format(self._hook.description, msg) + ) raise errors.TestFailure(msg) self.logger.info( "Initial sync node is in state %d, not state SECONDARY (2)." - " Skipping BackgroundInitialSync hook for %s", state, self._base_test_name) + " Skipping BackgroundInitialSync hook for %s", + state, + self._base_test_name, + ) # If we have not restarted initial sync since the last time we ran the data # validation, restart initial sync with a 20% probability. if self._hook.random_restarts < 1 and random.random() < 0.2: self.logger.info( - "randomly restarting initial sync in the middle of initial sync") + "randomly restarting initial sync in the middle of initial sync" + ) self.__restart_init_sync(sync_node) self._hook.random_restarts += 1 return @@ -124,7 +158,8 @@ class BackgroundInitialSyncTestCase(jsfile.DynamicJSTestCase): # STARTUP2 state and replSetGetStatus will succeed after the next test. self.logger.info( "replSetGetStatus call failed in BackgroundInitialSync hook, skipping hook for %s", - self._base_test_name) + self._base_test_name, + ) return self._hook.random_restarts = 0 @@ -161,8 +196,11 @@ class IntermediateInitialSync(interface.Hook): def __init__(self, hook_logger, fixture, n=DEFAULT_N): """Initialize IntermediateInitialSync.""" if not isinstance(fixture, replicaset.ReplicaSetFixture): - raise ValueError("`fixture` must be an instance of ReplicaSetFixture, not {}".format( - fixture.__class__.__name__)) + raise ValueError( + "`fixture` must be an instance of ReplicaSetFixture, not {}".format( + fixture.__class__.__name__ + ) + ) description = "Intermediate Initial Sync" interface.Hook.__init__(self, hook_logger, fixture, description) @@ -185,7 +223,9 @@ class IntermediateInitialSync(interface.Hook): if not self._should_run_after_test(): return - hook_test_case = IntermediateInitialSyncTestCase.create_after_test(test.logger, test, self) + hook_test_case = IntermediateInitialSyncTestCase.create_after_test( + test.logger, test, self + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) @@ -193,12 +233,15 @@ class IntermediateInitialSync(interface.Hook): class IntermediateInitialSyncTestCase(jsfile.DynamicJSTestCase): """IntermediateInitialSyncTestCase class.""" - JS_FILENAME = os.path.join("jstests", "hooks", "run_initial_sync_node_validation.js") + JS_FILENAME = os.path.join( + "jstests", "hooks", "run_initial_sync_node_validation.js" + ) def __init__(self, logger, test_name, description, base_test_name, hook): """Initialize IntermediateInitialSyncTestCase.""" - jsfile.DynamicJSTestCase.__init__(self, logger, test_name, description, base_test_name, - hook, self.JS_FILENAME) + jsfile.DynamicJSTestCase.__init__( + self, logger, test_name, description, base_test_name, hook, self.JS_FILENAME + ) def run_test(self): """Execute test hook.""" @@ -214,19 +257,30 @@ class IntermediateInitialSyncTestCase(jsfile.DynamicJSTestCase): # Do initial sync round. self.logger.info("Waiting for initial sync node to go into SECONDARY state") cmd = bson.SON( - [("replSetTest", 1), ("waitForMemberState", 2), - ("timeoutMillis", - fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 * 1000)]) + [ + ("replSetTest", 1), + ("waitForMemberState", 2), + ( + "timeoutMillis", + fixture_interface.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS + * 60 + * 1000, + ), + ] + ) while True: try: sync_node_conn.admin.command(cmd) break except pymongo.errors.OperationFailure as err: - if err.code not in (self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self.INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self.INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): raise - msg = ("Interrupted while waiting for node to reach secondary state, retrying: {}" - ).format(err) + msg = ( + "Interrupted while waiting for node to reach secondary state, retrying: {}" + ).format(err) self.logger.error(msg) # Run data validation and dbhash checking. diff --git a/buildscripts/resmokelib/testing/hooks/interface.py b/buildscripts/resmokelib/testing/hooks/interface.py index 007ffc20747..0acaabb0c1c 100644 --- a/buildscripts/resmokelib/testing/hooks/interface.py +++ b/buildscripts/resmokelib/testing/hooks/interface.py @@ -38,7 +38,8 @@ class Hook(object, metaclass=registry.make_registry_metaclass(_HOOKS)): # pylin if self.IS_BACKGROUND is None: raise ValueError( - "Concrete Hook subclasses must override the IS_BACKGROUND class property") + "Concrete Hook subclasses must override the IS_BACKGROUND class property" + ) def before_suite(self, test_report): """Test runner calls this exactly once before they start running the suite.""" @@ -97,7 +98,9 @@ class DynamicTestCase(testcase.TestCase): # pylint: disable=abstract-method base_test_name = base_test.short_name() test_name = cls._make_test_name(base_test_name, hook) description = "{} before running '{}'".format(hook.description, base_test_name) - return cls(logger, test_name, description, base_test_name, hook, *args, **kwargs) + return cls( + logger, test_name, description, base_test_name, hook, *args, **kwargs + ) @classmethod def create_after_test(cls, logger, base_test, hook, *args, **kwargs): @@ -105,7 +108,9 @@ class DynamicTestCase(testcase.TestCase): # pylint: disable=abstract-method base_test_name = base_test.short_name() test_name = cls._make_test_name(base_test_name, hook) description = "{} after running '{}'".format(hook.description, base_test_name) - return cls(logger, test_name, description, base_test_name, hook, *args, **kwargs) + return cls( + logger, test_name, description, base_test_name, hook, *args, **kwargs + ) @staticmethod def _make_test_name(base_test_name, hook): diff --git a/buildscripts/resmokelib/testing/hooks/jsfile.py b/buildscripts/resmokelib/testing/hooks/jsfile.py index 9a27f5298c1..58826a297c8 100644 --- a/buildscripts/resmokelib/testing/hooks/jsfile.py +++ b/buildscripts/resmokelib/testing/hooks/jsfile.py @@ -12,7 +12,9 @@ class JSHook(interface.Hook): REGISTERED_NAME = registry.LEAVE_UNREGISTERED - def __init__(self, hook_logger, fixture, js_filename, description, shell_options=None): + def __init__( + self, hook_logger, fixture, js_filename, description, shell_options=None + ): """Initialize JSHook.""" interface.Hook.__init__(self, hook_logger, fixture, description) self._js_filename = js_filename @@ -31,8 +33,9 @@ class JSHook(interface.Hook): if not self._should_run_after_test(): return - hook_test_case = DynamicJSTestCase.create_after_test(test.logger, test, self, - self._js_filename, self._shell_options) + hook_test_case = DynamicJSTestCase.create_after_test( + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) @@ -68,14 +71,22 @@ class PerClusterDataConsistencyHook(DataConsistencyHook): """After test execution.""" # Break the fixture down into its participant clusters if it is a MultiClusterFixture. - clusters = [self.fixture] if not isinstance(self.fixture, MultiClusterFixture)\ - else self.fixture.get_independent_clusters() + clusters = ( + [self.fixture] + if not isinstance(self.fixture, MultiClusterFixture) + else self.fixture.get_independent_clusters() + ) for cluster in clusters: - self.logger.info("Running jsfile '%s' on '%s' with driver URL '%s'", self._js_filename, - cluster, cluster.get_driver_connection_url()) + self.logger.info( + "Running jsfile '%s' on '%s' with driver URL '%s'", + self._js_filename, + cluster, + cluster.get_driver_connection_url(), + ) hook_test_case = DynamicJSTestCase.create_after_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(cluster) hook_test_case.run_dynamic_test(test_report) @@ -83,13 +94,23 @@ class PerClusterDataConsistencyHook(DataConsistencyHook): class DynamicJSTestCase(interface.DynamicTestCase): """A dynamic TestCase that runs a JavaScript file.""" - def __init__(self, logger, test_name, description, base_test_name, hook, js_filename, - shell_options=None): + def __init__( + self, + logger, + test_name, + description, + base_test_name, + hook, + js_filename, + shell_options=None, + ): """Initialize DynamicJSTestCase.""" - interface.DynamicTestCase.__init__(self, logger, test_name, description, base_test_name, - hook) - self._js_test_builder = jstest.JSTestCaseBuilder(logger, js_filename, self.id(), - shell_options=shell_options) + interface.DynamicTestCase.__init__( + self, logger, test_name, description, base_test_name, hook + ) + self._js_test_builder = jstest.JSTestCaseBuilder( + logger, js_filename, self.id(), shell_options=shell_options + ) self._js_test_case = None def override_logger(self, new_logger): @@ -106,7 +127,9 @@ class DynamicJSTestCase(interface.DynamicTestCase): """Configure the fixture.""" super().configure(fixture, *args, **kwargs) self._js_test_builder.configure(fixture, *args, **kwargs) - self._js_test_case = self._js_test_builder.create_test_case_for_thread(self.logger) + self._js_test_case = self._js_test_builder.create_test_case_for_thread( + self.logger + ) def run_test(self): """Execute the test.""" diff --git a/buildscripts/resmokelib/testing/hooks/lifecycle.py b/buildscripts/resmokelib/testing/hooks/lifecycle.py index 15b311919a8..cf3cadc0a02 100644 --- a/buildscripts/resmokelib/testing/hooks/lifecycle.py +++ b/buildscripts/resmokelib/testing/hooks/lifecycle.py @@ -6,7 +6,9 @@ import threading import buildscripts.resmokelib.utils.filesystem as fs -ActionFiles = collections.namedtuple("ActionFiles", ["permitted", "idle_request", "idle_ack"]) +ActionFiles = collections.namedtuple( + "ActionFiles", ["permitted", "idle_request", "idle_ack"] +) class FlagBasedThreadLifecycle(object): diff --git a/buildscripts/resmokelib/testing/hooks/magic_restore.py b/buildscripts/resmokelib/testing/hooks/magic_restore.py index 8c7ae416e6c..f4475cbc9b7 100644 --- a/buildscripts/resmokelib/testing/hooks/magic_restore.py +++ b/buildscripts/resmokelib/testing/hooks/magic_restore.py @@ -51,14 +51,17 @@ class MagicRestoreEveryN(interface.Hook): if run_backup: # Collect data files from backup cursor - hook_test_case = BackupCursorTestCase.create_after_test(test.logger, test, self) + hook_test_case = BackupCursorTestCase.create_after_test( + test.logger, test, self + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) if run_restore: # Run the magic restore procedure and run a data consistency check magic_restore_test_case = MagicRestoreTestCase.create_after_test( - test.logger, test, self) + test.logger, test, self + ) magic_restore_test_case.configure(self.fixture) magic_restore_test_case.run_dynamic_test(test_report) @@ -80,8 +83,9 @@ class BackupCursorTestCase(jsfile.DynamicJSTestCase): def __init__(self, logger, test_name, description, base_test_name, hook): """Initialize BackupCursorTestCase.""" - jsfile.DynamicJSTestCase.__init__(self, logger, test_name, description, base_test_name, - hook, self.JS_FILENAME) + jsfile.DynamicJSTestCase.__init__( + self, logger, test_name, description, base_test_name, hook, self.JS_FILENAME + ) def run_test(self): """Execute test hook.""" @@ -97,8 +101,9 @@ class MagicRestoreTestCase(jsfile.DynamicJSTestCase): def __init__(self, logger, test_name, description, base_test_name, hook): """Initialize MagicRestoreTestCase.""" - jsfile.DynamicJSTestCase.__init__(self, logger, test_name, description, base_test_name, - hook, self.JS_FILENAME) + jsfile.DynamicJSTestCase.__init__( + self, logger, test_name, description, base_test_name, hook, self.JS_FILENAME + ) def run_test(self): """Execute test hook.""" diff --git a/buildscripts/resmokelib/testing/hooks/metadata_consistency.py b/buildscripts/resmokelib/testing/hooks/metadata_consistency.py index 48214191be8..e69cfc8bbc2 100644 --- a/buildscripts/resmokelib/testing/hooks/metadata_consistency.py +++ b/buildscripts/resmokelib/testing/hooks/metadata_consistency.py @@ -15,7 +15,7 @@ from buildscripts.resmokelib.testing.hooks.background_job import ( _ContinuousDynamicJSTestCase, ) -_IS_WINDOWS = (sys.platform == "win32") +_IS_WINDOWS = sys.platform == "win32" class CheckMetadataConsistencyInBackground(jsfile.PerClusterDataConsistencyHook): @@ -31,26 +31,35 @@ class CheckMetadataConsistencyInBackground(jsfile.PerClusterDataConsistencyHook) "bazel-bin/install/transport_integration_test", # Skip tests that update the internalDocumentSourceGroupMaxMemoryBytes parameter and make # checkMetadataConsistency fail with QueryExceededMemoryLimitNoDiskUseAllowed error. - "jstests/aggregation/sources/unionWith/unionWith.js" + "jstests/aggregation/sources/unionWith/unionWith.js", ] if _IS_WINDOWS: - SKIP_TESTS = [testname_utils.denormalize_test_file(path)[1] for path in SKIP_TESTS] + SKIP_TESTS = [ + testname_utils.denormalize_test_file(path)[1] for path in SKIP_TESTS + ] def __init__(self, hook_logger, fixture, shell_options=None): """Initialize CheckMetadataConsistencyInBackground.""" - if not isinstance(fixture, shardedcluster.ShardedClusterFixture) and not isinstance( - fixture, multi_sharded_cluster.MultiShardedClusterFixture): + if not isinstance( + fixture, shardedcluster.ShardedClusterFixture + ) and not isinstance(fixture, multi_sharded_cluster.MultiShardedClusterFixture): raise ValueError( f"'fixture' must be an instance of ShardedClusterFixture or MultiShardedClusterFixture, but got" - f" {fixture.__class__.__name__}") + f" {fixture.__class__.__name__}" + ) - description = "Perform consistency checks between the config database and metadata " \ - "stored/cached in the shards" - js_filename = os.path.join("jstests", "hooks", "run_check_metadata_consistency.js") - super().__init__(hook_logger, fixture, js_filename, description, - shell_options=shell_options) + description = ( + "Perform consistency checks between the config database and metadata " + "stored/cached in the shards" + ) + js_filename = os.path.join( + "jstests", "hooks", "run_check_metadata_consistency.js" + ) + super().__init__( + hook_logger, fixture, js_filename, description, shell_options=shell_options + ) self._background_job = None @@ -79,21 +88,26 @@ class CheckMetadataConsistencyInBackground(jsfile.PerClusterDataConsistencyHook) return # TODO SERVER-75675 do not skip index consistency check - shell_options = self._shell_options.copy() if self._shell_options is not None else {} + shell_options = ( + self._shell_options.copy() if self._shell_options is not None else {} + ) if "global_vars" not in shell_options: shell_options["global_vars"] = {} if "TestData" not in shell_options["global_vars"]: shell_options["global_vars"]["TestData"] = {} shell_options["global_vars"]["TestData"][ - "skipCheckingIndexesConsistentAcrossCluster"] = True + "skipCheckingIndexesConsistentAcrossCluster" + ] = True hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, shell_options) + test.logger, test, self, self._js_filename, shell_options + ) hook_test_case.configure(self.fixture) if test.test_name in self.SKIP_TESTS: - self.logger.info("Metadata consistency check explicitely disabled for %s", - test.test_name) + self.logger.info( + "Metadata consistency check explicitely disabled for %s", test.test_name + ) return self.logger.info("Resuming background metadata consistency checker thread") @@ -122,5 +136,6 @@ class CheckMetadataConsistencyInBackground(jsfile.PerClusterDataConsistencyHook) else: self.logger.error( "Encountered an error inside background metadata consistency checker thread", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/oplog.py b/buildscripts/resmokelib/testing/hooks/oplog.py index 94baea30482..c49e9f90706 100644 --- a/buildscripts/resmokelib/testing/hooks/oplog.py +++ b/buildscripts/resmokelib/testing/hooks/oplog.py @@ -11,9 +11,16 @@ class CheckReplOplogs(jsfile.PerClusterDataConsistencyHook): IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize CheckReplOplogs.""" description = "Check oplogs of all replica set members" js_filename = os.path.join("jstests", "hooks", "run_check_repl_oplogs.js") jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/orphans.py b/buildscripts/resmokelib/testing/hooks/orphans.py index 2d2d3947aec..5f156950aec 100644 --- a/buildscripts/resmokelib/testing/hooks/orphans.py +++ b/buildscripts/resmokelib/testing/hooks/orphans.py @@ -15,10 +15,15 @@ class CheckOrphansDeleted(jsfile.DataConsistencyHook): """Initialize CheckOrphansDeleted.""" if not isinstance(fixture, shardedcluster.ShardedClusterFixture): - raise ValueError(f"'fixture' must be an instance of ShardedClusterFixture, but got" - f" {fixture.__class__.__name__}") + raise ValueError( + f"'fixture' must be an instance of ShardedClusterFixture, but got" + f" {fixture.__class__.__name__}" + ) description = "Check orphan documents are eventually deleted" - js_filename = os.path.join("jstests", "hooks", "run_check_orphans_are_deleted.js") - super().__init__(hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_check_orphans_are_deleted.js" + ) + super().__init__( + hook_logger, fixture, js_filename, description, shell_options=shell_options + ) diff --git a/buildscripts/resmokelib/testing/hooks/periodic_kill_secondaries.py b/buildscripts/resmokelib/testing/hooks/periodic_kill_secondaries.py index 4af7ca6322a..84e76f17763 100644 --- a/buildscripts/resmokelib/testing/hooks/periodic_kill_secondaries.py +++ b/buildscripts/resmokelib/testing/hooks/periodic_kill_secondaries.py @@ -33,15 +33,21 @@ class PeriodicKillSecondaries(interface.Hook): def __init__(self, hook_logger, rs_fixture, period_secs=DEFAULT_PERIOD_SECS): """Initialize PeriodicKillSecondaries.""" if not isinstance(rs_fixture, replicaset.ReplicaSetFixture): - raise TypeError("{} either does not support replication or does not support writing to" - " its oplog early".format(rs_fixture.__class__.__name__)) + raise TypeError( + "{} either does not support replication or does not support writing to" + " its oplog early".format(rs_fixture.__class__.__name__) + ) if rs_fixture.num_nodes <= 1: - raise ValueError("PeriodicKillSecondaries requires the replica set to contain at least" - " one secondary") + raise ValueError( + "PeriodicKillSecondaries requires the replica set to contain at least" + " one secondary" + ) - description = ("PeriodicKillSecondaries (kills the secondary after running tests for a" - " configurable period of time)") + description = ( + "PeriodicKillSecondaries (kills the secondary after running tests for a" + " configurable period of time)" + ) interface.Hook.__init__(self, hook_logger, rs_fixture, description) self._period_secs = period_secs @@ -85,7 +91,8 @@ class PeriodicKillSecondaries(interface.Hook): def _run(self, test_report): try: hook_test_case = PeriodicKillSecondariesTestCase.create_after_test( - self.logger, self._last_test, self, test_report) + self.logger, self._last_test, self, test_report + ) hook_test_case.configure(self.fixture) hook_test_case.run_dynamic_test(test_report) finally: @@ -101,13 +108,20 @@ class PeriodicKillSecondaries(interface.Hook): client = secondary.mongo_client() try: client.admin.command( - bson.SON([("configureFailPoint", "rsSyncApplyStop"), ("mode", "alwaysOn")])) + bson.SON( + [("configureFailPoint", "rsSyncApplyStop"), ("mode", "alwaysOn")] + ) + ) except pymongo.errors.OperationFailure as err: - self.logger.exception("Unable to disable oplog application on the mongod on port %d", - secondary.port) + self.logger.exception( + "Unable to disable oplog application on the mongod on port %d", + secondary.port, + ) raise errors.ServerFailure( "Unable to disable oplog application on the mongod on port {}: {}".format( - secondary.port, err.args[0])) + secondary.port, err.args[0] + ) + ) def _disable_rssyncapplystop(self, secondary): # Disable the "rsSyncApplyStop" failpoint on the secondary to have it resume applying @@ -115,13 +129,18 @@ class PeriodicKillSecondaries(interface.Hook): client = secondary.mongo_client() try: client.admin.command( - bson.SON([("configureFailPoint", "rsSyncApplyStop"), ("mode", "off")])) + bson.SON([("configureFailPoint", "rsSyncApplyStop"), ("mode", "off")]) + ) except pymongo.errors.OperationFailure as err: - self.logger.exception("Unable to re-enable oplog application on the mongod on port %d", - secondary.port) + self.logger.exception( + "Unable to re-enable oplog application on the mongod on port %d", + secondary.port, + ) raise errors.ServerFailure( "Unable to re-enable oplog application on the mongod on port {}: {}".format( - secondary.port, err.args[0])) + secondary.port, err.args[0] + ) + ) class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): @@ -130,10 +149,13 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): INTERRUPTED_DUE_TO_REPL_STATE_CHANGE = 11602 INTERRUPTED_DUE_TO_STORAGE_CHANGE = 355 - def __init__(self, logger, test_name, description, base_test_name, hook, test_report): + def __init__( + self, logger, test_name, description, base_test_name, hook, test_report + ): """Initialize PeriodicKillSecondariesTestCase.""" - interface.DynamicTestCase.__init__(self, logger, test_name, description, base_test_name, - hook) + interface.DynamicTestCase.__init__( + self, logger, test_name, description, base_test_name, hook + ) self._test_report = test_report def run_test(self): @@ -182,7 +204,10 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): if not secondary.is_running(): raise errors.ServerFailure( "mongod on port {} was expected to be running in" - " PeriodicKillSecondaries.after_test(), but wasn't.".format(secondary.port)) + " PeriodicKillSecondaries.after_test(), but wasn't.".format( + secondary.port + ) + ) self.logger.info("Killing the secondary on port %d...", secondary.port) secondary.mongod.stop(mode=fixture.TeardownMode.KILL) @@ -206,9 +231,12 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): # potential consistency issues before we validate the config.system.preimages # collection. if "set_parameters" in secondary.mongod_options: - secondary.mongod_options["set_parameters"]["disableExpiredPreImagesRemover"] = True secondary.mongod_options["set_parameters"][ - "disableExpiredChangeCollectionRemover"] = True + "disableExpiredPreImagesRemover" + ] = True + secondary.mongod_options["set_parameters"][ + "disableExpiredChangeCollectionRemover" + ] = True else: secondary.mongod_options["set_parameters"] = { "disableExpiredPreImagesRemover": True, @@ -217,7 +245,9 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): self.logger.info( "Restarting the secondary on port %d as a replica set node with" - " its data files intact...", secondary.port) + " its data files intact...", + secondary.port, + ) # Start the 'secondary' mongod back up as part of the replica set and wait for it to # reach state SECONDARY. secondary.setup() @@ -230,23 +260,29 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): except errors.ServerFailure: raise errors.ServerFailure( "{} did not exit cleanly after reconciling the end of its oplog".format( - secondary)) + secondary + ) + ) - self.logger.info("Starting the fixture back up again with its data files intact for final" - " validation...") + self.logger.info( + "Starting the fixture back up again with its data files intact for final" + " validation..." + ) try: self.fixture.setup() self.logger.info(fixture.create_fixture_table(self.fixture)) self.fixture.await_ready() finally: - for (i, node) in enumerate(self.fixture.nodes): + for i, node in enumerate(self.fixture.nodes): node.preserve_dbpath = preserve_dbpaths[i] def _validate_collections(self, test_report): validate_test_case = validate.ValidateCollections( - self._hook.logger, self.fixture, - {'global_vars': {'TestData': {'skipEnforceFastCountOnValidate': True}}}) + self._hook.logger, + self.fixture, + {"global_vars": {"TestData": {"skipEnforceFastCountOnValidate": True}}}, + ) validate_test_case.before_suite(test_report) validate_test_case.before_test(self, test_report) validate_test_case.after_test(self, test_report) @@ -268,15 +304,19 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): def _check_pre_images_consistency(self, test_report): preimages_test_case = preimages_consistency.CheckReplPreImagesConsistency( - self._hook.logger, self.fixture) + self._hook.logger, self.fixture + ) preimages_test_case.before_suite(test_report) preimages_test_case.before_test(self, test_report) preimages_test_case.after_test(self, test_report) preimages_test_case.after_suite(test_report) def _check_change_collection_consistency(self, test_report): - change_collection_test_case = change_collection_consistency.CheckReplChangeCollectionConsistency( - self._hook.logger, self.fixture) + change_collection_test_case = ( + change_collection_consistency.CheckReplChangeCollectionConsistency( + self._hook.logger, self.fixture + ) + ) change_collection_test_case.before_suite(test_report) change_collection_test_case.before_test(self, test_report) change_collection_test_case.after_test(self, test_report) @@ -294,13 +334,20 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): self.fixture.teardown() except errors.ServerFailure: raise errors.ServerFailure( - "{} did not exit cleanly after verifying data consistency".format(self.fixture)) + "{} did not exit cleanly after verifying data consistency".format( + self.fixture + ) + ) for secondary in self.fixture.get_secondaries(): # We re-enable the removers for pre-images and change collections. These were disabled # before re-joining the replSet as a secondary during the consistency checks. - secondary.mongod_options["set_parameters"].pop("disableExpiredPreImagesRemover") - secondary.mongod_options["set_parameters"].pop("disableExpiredChangeCollectionRemover") + secondary.mongod_options["set_parameters"].pop( + "disableExpiredPreImagesRemover" + ) + secondary.mongod_options["set_parameters"].pop( + "disableExpiredChangeCollectionRemover" + ) self.logger.info("Starting the fixture back up again with no data...") self.fixture.setup() @@ -321,7 +368,9 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): self.logger.info( "Restarting the secondary on port %d as a standalone node with" - " its data files intact...", secondary.port) + " its data files intact...", + secondary.port, + ) try: secondary.setup() @@ -329,16 +378,22 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): secondary.await_ready() client = secondary.mongo_client() - oplog_truncate_after_doc = client.local["replset.oplogTruncateAfterPoint"].find_one() - recovery_timestamp_res = client.admin.command("replSetTest", - getLastStableRecoveryTimestamp=True) - latest_oplog_doc = client.local["oplog.rs"].find_one(sort=[("$natural", - pymongo.DESCENDING)]) + oplog_truncate_after_doc = client.local[ + "replset.oplogTruncateAfterPoint" + ].find_one() + recovery_timestamp_res = client.admin.command( + "replSetTest", getLastStableRecoveryTimestamp=True + ) + latest_oplog_doc = client.local["oplog.rs"].find_one( + sort=[("$natural", pymongo.DESCENDING)] + ) - self.logger.info("Checking replication invariants. oplogTruncateAfterPoint: {}," - " stable recovery timestamp: {}, latest oplog doc: {}".format( - oplog_truncate_after_doc, recovery_timestamp_res, - latest_oplog_doc)) + self.logger.info( + "Checking replication invariants. oplogTruncateAfterPoint: {}," + " stable recovery timestamp: {}, latest oplog doc: {}".format( + oplog_truncate_after_doc, recovery_timestamp_res, latest_oplog_doc + ) + ) null_ts = bson.Timestamp(0, 0) @@ -349,45 +404,61 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): latest_oplog_entry_ts = latest_oplog_doc.get("ts") if latest_oplog_entry_ts is None: raise errors.ServerFailure( - "Latest oplog entry had no 'ts' field: {}".format(latest_oplog_doc)) + "Latest oplog entry had no 'ts' field: {}".format(latest_oplog_doc) + ) # The "lastStableRecoveryTimestamp" field is present if the storage engine supports # "recover to a timestamp". If it's a null timestamp on a durable storage engine, that # means we do not yet have a stable checkpoint timestamp and must be restarting at the # top of the oplog. Since we wait for a stable recovery timestamp at test fixture setup, # we should never encounter a null timestamp here. - recovery_timestamp = recovery_timestamp_res.get("lastStableRecoveryTimestamp") + recovery_timestamp = recovery_timestamp_res.get( + "lastStableRecoveryTimestamp" + ) if recovery_timestamp == null_ts: raise errors.ServerFailure( - "Received null stable recovery timestamp {}".format(recovery_timestamp_res)) + "Received null stable recovery timestamp {}".format( + recovery_timestamp_res + ) + ) # On a storage engine that doesn't support "recover to a timestamp", we default to null. if recovery_timestamp is None: recovery_timestamp = null_ts # last stable recovery timestamp <= top of oplog if not recovery_timestamp <= latest_oplog_entry_ts: - raise errors.ServerFailure("The condition last stable recovery timestamp <= top" - " of oplog ({} <= {}) doesn't hold:" - " getLastStableRecoveryTimestamp result={}," - " latest oplog entry={}".format( - recovery_timestamp, latest_oplog_entry_ts, - recovery_timestamp_res, latest_oplog_doc)) + raise errors.ServerFailure( + "The condition last stable recovery timestamp <= top" + " of oplog ({} <= {}) doesn't hold:" + " getLastStableRecoveryTimestamp result={}," + " latest oplog entry={}".format( + recovery_timestamp, + latest_oplog_entry_ts, + recovery_timestamp_res, + latest_oplog_doc, + ) + ) try: secondary.teardown() except errors.ServerFailure: raise errors.ServerFailure( "{} did not exit cleanly after being started up as a standalone".format( - secondary)) + secondary + ) + ) except pymongo.errors.OperationFailure as err: self.logger.exception( "Failed to read the minValid document, the oplogTruncateAfterPoint document," " the last stable recovery timestamp, or the latest oplog entry from the" - " mongod on port %d", secondary.port) + " mongod on port %d", + secondary.port, + ) raise errors.ServerFailure( "Failed to read the minValid document, the oplogTruncateAfterPoint document," " the last stable recovery timestamp, or the latest oplog entry from the" - " mongod on port {}: {}".format(secondary.port, err.args[0])) + " mongod on port {}: {}".format(secondary.port, err.args[0]) + ) finally: # Set the secondary's options back to their original values. if replset_name: @@ -400,25 +471,39 @@ class PeriodicKillSecondariesTestCase(interface.DynamicTestCase): while True: try: client.admin.command( - bson.SON([ - ("replSetTest", 1), - ("waitForMemberState", 2), # 2 = SECONDARY - ("timeoutMillis", - fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60 * 1000) - ])) + bson.SON( + [ + ("replSetTest", 1), + ("waitForMemberState", 2), # 2 = SECONDARY + ( + "timeoutMillis", + fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS + * 60 + * 1000, + ), + ] + ) + ) break except pymongo.errors.OperationFailure as err: - if err.code not in (self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, - self.INTERRUPTED_DUE_TO_STORAGE_CHANGE): + if err.code not in ( + self.INTERRUPTED_DUE_TO_REPL_STATE_CHANGE, + self.INTERRUPTED_DUE_TO_STORAGE_CHANGE, + ): self.logger.exception( "mongod on port %d failed to reach state SECONDARY after %d seconds", - secondary.port, fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60) + secondary.port, + fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60, + ) raise errors.ServerFailure( - "mongod on port {} failed to reach state SECONDARY after {} seconds: {}". - format(secondary.port, - fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60, - err.args[0])) + "mongod on port {} failed to reach state SECONDARY after {} seconds: {}".format( + secondary.port, + fixture.ReplFixture.AWAIT_REPL_TIMEOUT_FOREVER_MINS * 60, + err.args[0], + ) + ) - msg = ("Interrupted while waiting for node to reach secondary state, retrying: {}" - ).format(err) + msg = ( + "Interrupted while waiting for node to reach secondary state, retrying: {}" + ).format(err) self.logger.error(msg) diff --git a/buildscripts/resmokelib/testing/hooks/preimages_consistency.py b/buildscripts/resmokelib/testing/hooks/preimages_consistency.py index 291bbc5b798..e32a1b435d8 100644 --- a/buildscripts/resmokelib/testing/hooks/preimages_consistency.py +++ b/buildscripts/resmokelib/testing/hooks/preimages_consistency.py @@ -11,9 +11,16 @@ class CheckReplPreImagesConsistency(jsfile.PerClusterDataConsistencyHook): IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize CheckReplPreImagesConsistency.""" description = "Check pre-images of all replica set members" js_filename = os.path.join("jstests", "hooks", "run_check_repl_pre_images.js") jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/reconfig_background.py b/buildscripts/resmokelib/testing/hooks/reconfig_background.py index 2f3c4c4eb05..c902bda359d 100644 --- a/buildscripts/resmokelib/testing/hooks/reconfig_background.py +++ b/buildscripts/resmokelib/testing/hooks/reconfig_background.py @@ -22,8 +22,14 @@ class DoReconfigInBackground(jsfile.JSHook): """Initialize DoReconfigInBackground.""" description = "Run reconfigs against the primary while the test is running." js_filename = os.path.join("jstests", "hooks", "run_reconfig_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -47,7 +53,8 @@ class DoReconfigInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background reconfig thread.") @@ -70,6 +77,8 @@ class DoReconfigInBackground(jsfile.JSHook): # test execution to stop. raise errors.ServerFailure(self._background_job.exc_info[1].args[0]) else: - self.logger.error("Encountered an error inside the background reconfig thread.", - exc_info=self._background_job.exc_info) + self.logger.error( + "Encountered an error inside the background reconfig thread.", + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/routing_table_consistency.py b/buildscripts/resmokelib/testing/hooks/routing_table_consistency.py index 5f2b7e100e8..274b814c3d2 100644 --- a/buildscripts/resmokelib/testing/hooks/routing_table_consistency.py +++ b/buildscripts/resmokelib/testing/hooks/routing_table_consistency.py @@ -17,13 +17,18 @@ class CheckRoutingTableConsistency(jsfile.PerClusterDataConsistencyHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize CheckRoutingTableConsistency.""" - if not isinstance(fixture, shardedcluster.ShardedClusterFixture) and not isinstance( - fixture, multi_sharded_cluster.MultiShardedClusterFixture): + if not isinstance( + fixture, shardedcluster.ShardedClusterFixture + ) and not isinstance(fixture, multi_sharded_cluster.MultiShardedClusterFixture): raise ValueError( f"'fixture' must be an instance of ShardedClusterFixture or MultiShardedClusterFixture, but got" - f" {fixture.__class__.__name__}") + f" {fixture.__class__.__name__}" + ) description = "Inspect collection and chunk metadata in config server" - js_filename = os.path.join("jstests", "hooks", "run_check_routing_table_consistency.js") - super().__init__(hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_check_routing_table_consistency.js" + ) + super().__init__( + hook_logger, fixture, js_filename, description, shell_options=shell_options + ) diff --git a/buildscripts/resmokelib/testing/hooks/run_query_stats.py b/buildscripts/resmokelib/testing/hooks/run_query_stats.py index 2323cf7eb90..3f2d7227b44 100644 --- a/buildscripts/resmokelib/testing/hooks/run_query_stats.py +++ b/buildscripts/resmokelib/testing/hooks/run_query_stats.py @@ -23,43 +23,65 @@ class RunQueryStats(Hook): Args: hook_logger: the logger instance for this hook. fixture: the target fixture (replica sets or a sharded cluster). - allow_feature_not_supported: absorb 'QueryFeatureNotAllowed' errors when calling + allow_feature_not_supported: absorb 'QueryFeatureNotAllowed' errors when calling $queryStats. This is to support fuzzer suites that may manipulate the FCV. """ description = "Read query stats data after each test." super().__init__(hook_logger, fixture, description) self.client = self.fixture.mongo_client() - self.hmac_key = binary.Binary(("0" * 32).encode('utf-8'), 8) + self.hmac_key = binary.Binary(("0" * 32).encode("utf-8"), 8) self.allow_feature_not_supported = allow_feature_not_supported def verify_query_stats(self, querystats_spec): """Verify a $queryStats call has all the right properties.""" try: - with self.client.admin.aggregate([{"$queryStats": querystats_spec}]) as cursor: + with self.client.admin.aggregate( + [{"$queryStats": querystats_spec}] + ) as cursor: for operation in cursor: assert "key" in operation assert "metrics" in operation assert "asOf" in operation except pymongo.errors.OperationFailure as err: - if self.allow_feature_not_supported and err.code in QUERY_STATS_NOT_ENABLED_CODES: - self.logger.info("Encountered an error while running $queryStats. " - "$queryStats will not be run for this test.") + if ( + self.allow_feature_not_supported + and err.code in QUERY_STATS_NOT_ENABLED_CODES + ): + self.logger.info( + "Encountered an error while running $queryStats. " + "$queryStats will not be run for this test." + ) else: raise err def after_test(self, test, test_report): self.verify_query_stats({}) self.verify_query_stats( - {"transformIdentifiers": {"algorithm": "hmac-sha-256", "hmacKey": self.hmac_key}}) + { + "transformIdentifiers": { + "algorithm": "hmac-sha-256", + "hmacKey": self.hmac_key, + } + } + ) def before_test(self, test, test_report): try: # Clear out all existing entries, then reset the size cap. - self.client.admin.command("setParameter", 1, internalQueryStatsCacheSize="0%") - self.client.admin.command("setParameter", 1, internalQueryStatsCacheSize="1%") + self.client.admin.command( + "setParameter", 1, internalQueryStatsCacheSize="0%" + ) + self.client.admin.command( + "setParameter", 1, internalQueryStatsCacheSize="1%" + ) except pymongo.errors.OperationFailure as err: - if self.allow_feature_not_supported and err.code in QUERY_STATS_NOT_ENABLED_CODES: - self.logger.info("Encountered an error while configuring the query stats store. " - "Query stats will not be collected for this test.") + if ( + self.allow_feature_not_supported + and err.code in QUERY_STATS_NOT_ENABLED_CODES + ): + self.logger.info( + "Encountered an error while configuring the query stats store. " + "Query stats will not be collected for this test." + ) else: raise err diff --git a/buildscripts/resmokelib/testing/hooks/secondary_lag.py b/buildscripts/resmokelib/testing/hooks/secondary_lag.py index e08f70715c5..df20bddc3df 100644 --- a/buildscripts/resmokelib/testing/hooks/secondary_lag.py +++ b/buildscripts/resmokelib/testing/hooks/secondary_lag.py @@ -22,8 +22,14 @@ class LagOplogApplicationInBackground(jsfile.JSHook): """Initialize LagOplogApplicationInBackground.""" description = "Lag a secondary's oplog application while the test is running." js_filename = os.path.join("jstests", "hooks", "lag_secondary_application.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -47,10 +53,13 @@ class LagOplogApplicationInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) - self.logger.info("Resuming the background secondary oplog application lag thread.") + self.logger.info( + "Resuming the background secondary oplog application lag thread." + ) self._background_job.resume(hook_test_case, test_report) def after_test(self, test, test_report): # noqa: D205,D400 @@ -60,7 +69,9 @@ class LagOplogApplicationInBackground(jsfile.JSHook): if self._background_job is None: return - self.logger.info("Pausing the background secondary oplog application lag thread.") + self.logger.info( + "Pausing the background secondary oplog application lag thread." + ) self._background_job.pause() if self._background_job.exc_info is not None: @@ -72,5 +83,6 @@ class LagOplogApplicationInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background secondary oplog application lag thread.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/shard_filtering_metadata.py b/buildscripts/resmokelib/testing/hooks/shard_filtering_metadata.py index 5c197b69ae5..7ea583d15d0 100644 --- a/buildscripts/resmokelib/testing/hooks/shard_filtering_metadata.py +++ b/buildscripts/resmokelib/testing/hooks/shard_filtering_metadata.py @@ -13,10 +13,15 @@ class CheckShardFilteringMetadata(jsfile.DataConsistencyHook): """Initialize CheckShardFilteringMetadata.""" if not isinstance(fixture, shardedcluster.ShardedClusterFixture): - raise ValueError(f"'fixture' must be an instance of ShardedClusterFixture, but got" - f" {fixture.__class__.__name__}") + raise ValueError( + f"'fixture' must be an instance of ShardedClusterFixture, but got" + f" {fixture.__class__.__name__}" + ) description = "Inspect filtering metadata on shards" - js_filename = os.path.join("jstests", "hooks", "run_check_shard_filtering_metadata.js") - super().__init__(hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_check_shard_filtering_metadata.js" + ) + super().__init__( + hook_logger, fixture, js_filename, description, shell_options=shell_options + ) diff --git a/buildscripts/resmokelib/testing/hooks/simulate_crash.py b/buildscripts/resmokelib/testing/hooks/simulate_crash.py index 289eaa46da1..20dc2ba5f50 100644 --- a/buildscripts/resmokelib/testing/hooks/simulate_crash.py +++ b/buildscripts/resmokelib/testing/hooks/simulate_crash.py @@ -66,8 +66,11 @@ class SimulateCrash(bghook.BGHook): for node in nodes_to_copy: node.mongod.pause() - self.logger.info("Starting to copy data files. DBPath: {}".format( - node.get_dbpath_prefix())) + self.logger.info( + "Starting to copy data files. DBPath: {}".format( + node.get_dbpath_prefix() + ) + ) try: for tup in os.walk(node.get_dbpath_prefix(), followlinks=True): @@ -80,9 +83,11 @@ class SimulateCrash(bghook.BGHook): continue fqfn = "/".join([tup[0], filename]) self.copy_file( - node.get_dbpath_prefix(), fqfn, - node.get_dbpath_prefix() + "/simulateCrashes/{}".format( - self.backup_num)) + node.get_dbpath_prefix(), + fqfn, + node.get_dbpath_prefix() + + "/simulateCrashes/{}".format(self.backup_num), + ) finally: node.mongod.resume() @@ -92,13 +97,15 @@ class SimulateCrash(bghook.BGHook): in_fd = os.open(fqfn, os.O_RDONLY) in_bytes = os.stat(in_fd).st_size - rel = fqfn[len(root):] + rel = fqfn[len(root) :] os.makedirs(new_root + "/journal", exist_ok=True) out_fd = os.open(new_root + rel, os.O_WRONLY | os.O_CREAT) total_bytes_sent = 0 - while (total_bytes_sent < in_bytes): - bytes_sent = os.sendfile(out_fd, in_fd, total_bytes_sent, in_bytes - total_bytes_sent) + while total_bytes_sent < in_bytes: + bytes_sent = os.sendfile( + out_fd, in_fd, total_bytes_sent, in_bytes - total_bytes_sent + ) if bytes_sent == 0: raise ValueError("Unexpectedly reached EOF copying file") total_bytes_sent += bytes_sent @@ -109,20 +116,39 @@ class SimulateCrash(bghook.BGHook): def validate_all(self): """Start a standalone node to validate all collections on the copied data files.""" for node in self.fixture.nodes: - path = node.get_dbpath_prefix() + "/simulateCrashes/{}".format(self.backup_num) - self.logger.info("Starting to validate. DBPath: {} Port: {}".format( - path, self.validate_port)) + path = node.get_dbpath_prefix() + "/simulateCrashes/{}".format( + self.backup_num + ) + self.logger.info( + "Starting to validate. DBPath: {} Port: {}".format( + path, self.validate_port + ) + ) - mdb = process.Process(self.logger, [ - node.mongod_executable, "--dbpath", path, "--port", - str(self.validate_port), "--setParameter", "enableTestCommands=1", "--setParameter", - "testingDiagnosticsEnabled=1" - ]) + mdb = process.Process( + self.logger, + [ + node.mongod_executable, + "--dbpath", + path, + "--port", + str(self.validate_port), + "--setParameter", + "enableTestCommands=1", + "--setParameter", + "testingDiagnosticsEnabled=1", + ], + ) mdb.start() - client = pymongo.MongoClient(host="localhost", port=self.validate_port, connect=True, - connectTimeoutMS=300000, serverSelectionTimeoutMS=300000, - directConnection=True) + client = pymongo.MongoClient( + host="localhost", + port=self.validate_port, + connect=True, + connectTimeoutMS=300000, + serverSelectionTimeoutMS=300000, + directConnection=True, + ) is_valid = validate(client, self.logger, self.acceptable_err_codes) mdb.stop() @@ -140,8 +166,12 @@ class SimulateCrash(bghook.BGHook): self._background_job.join() if self._background_job.err is not None and test_report.wasSuccessful(): - self.logger.error("Encountered an error inside the hook after all tests passed: %s.", - self._background_job.err) + self.logger.error( + "Encountered an error inside the hook after all tests passed: %s.", + self._background_job.err, + ) raise self._background_job.err else: - self.logger.info("Reached end of cycle in the hook, killing background thread.") + self.logger.info( + "Reached end of cycle in the hook, killing background thread." + ) diff --git a/buildscripts/resmokelib/testing/hooks/stepdown.py b/buildscripts/resmokelib/testing/hooks/stepdown.py index 8af806ba494..868fa8e782c 100644 --- a/buildscripts/resmokelib/testing/hooks/stepdown.py +++ b/buildscripts/resmokelib/testing/hooks/stepdown.py @@ -18,18 +18,31 @@ from buildscripts.resmokelib.testing.hooks import lifecycle as lifecycle_interfa class ContinuousStepdown(interface.Hook): """Regularly connect to replica sets and send a replSetStepDown command.""" - DESCRIPTION = ("Continuous stepdown (steps down the primary of replica sets at regular" - " intervals)") + DESCRIPTION = ( + "Continuous stepdown (steps down the primary of replica sets at regular" + " intervals)" + ) IS_BACKGROUND = True # The hook stops the fixture partially during its execution. STOPS_FIXTURE = True - def __init__(self, hook_logger, fixture, config_stepdown=True, shard_stepdown=True, - stepdown_interval_ms=8000, terminate=False, kill=False, randomize_kill=False, - use_action_permitted_file=False, background_reconfig=False, auth_options=None, - should_downgrade=False): + def __init__( + self, + hook_logger, + fixture, + config_stepdown=True, + shard_stepdown=True, + stepdown_interval_ms=8000, + terminate=False, + kill=False, + randomize_kill=False, + use_action_permitted_file=False, + background_reconfig=False, + auth_options=None, + should_downgrade=False, + ): """Initialize the ContinuousStepdown. Args: @@ -50,10 +63,16 @@ class ContinuousStepdown(interface.Hook): "SIGKILL" signals that are used to stop the process. On Windows, there are no signals, so we use a different means to achieve the same result as sending SIGTERM or SIGKILL. """ - interface.Hook.__init__(self, hook_logger, fixture, ContinuousStepdown.DESCRIPTION) + interface.Hook.__init__( + self, hook_logger, fixture, ContinuousStepdown.DESCRIPTION + ) self._fixture = fixture - if hasattr(fixture, "config_shard") and fixture.config_shard is not None and shard_stepdown: + if ( + hasattr(fixture, "config_shard") + and fixture.config_shard is not None + and shard_stepdown + ): # If the config server is a shard, shard_stepdown implies config_stepdown. config_stepdown = shard_stepdown @@ -81,10 +100,12 @@ class ContinuousStepdown(interface.Hook): dbpath_prefix = fixture.get_dbpath_prefix() if use_action_permitted_file: - self.__action_files = lifecycle_interface.ActionFiles._make([ - os.path.join(dbpath_prefix, field) - for field in lifecycle_interface.ActionFiles._fields - ]) + self.__action_files = lifecycle_interface.ActionFiles._make( + [ + os.path.join(dbpath_prefix, field) + for field in lifecycle_interface.ActionFiles._fields + ] + ) else: self.__action_files = None @@ -94,14 +115,25 @@ class ContinuousStepdown(interface.Hook): self._add_fixture(self._fixture) if self.__action_files is not None: - lifecycle = lifecycle_interface.FileBasedThreadLifecycle(self.__action_files) + lifecycle = lifecycle_interface.FileBasedThreadLifecycle( + self.__action_files + ) else: lifecycle = lifecycle_interface.FlagBasedThreadLifecycle() self._stepdown_thread = _StepdownThread( - self.logger, self._mongos_fixtures, self._rs_fixtures, self._stepdown_interval_secs, - self._terminate, self._kill, lifecycle, self._background_reconfig, self._fixture, - self._auth_options, self._should_downgrade) + self.logger, + self._mongos_fixtures, + self._rs_fixtures, + self._stepdown_interval_secs, + self._terminate, + self._kill, + lifecycle, + self._background_reconfig, + self._fixture, + self._auth_options, + self._should_downgrade, + ) self.logger.info("Starting the stepdown thread.") self._stepdown_thread.start() @@ -128,7 +160,8 @@ class ContinuousStepdown(interface.Hook): if not fixture.all_nodes_electable: raise ValueError( "The replica sets that are the target of the ContinuousStepdown hook must have" - " the 'all_nodes_electable' option set.") + " the 'all_nodes_electable' option set." + ) self._rs_fixtures.append(fixture) elif isinstance(fixture, shardedcluster.ShardedClusterFixture): if self._shard_stepdown: @@ -146,9 +179,20 @@ class ContinuousStepdown(interface.Hook): class _StepdownThread(threading.Thread): - def __init__(self, logger, mongos_fixtures, rs_fixtures, stepdown_interval_secs, terminate, - kill, stepdown_lifecycle, background_reconfig, fixture, auth_options=None, - should_downgrade=False): + def __init__( + self, + logger, + mongos_fixtures, + rs_fixtures, + stepdown_interval_secs, + terminate, + kill, + stepdown_lifecycle, + background_reconfig, + fixture, + auth_options=None, + should_downgrade=False, + ): """Initialize _StepdownThread.""" threading.Thread.__init__(self, name="StepdownThread") self.daemon = True @@ -200,8 +244,10 @@ class _StepdownThread(threading.Thread): # Wait until each replica set has a primary, so the test can make progress. self._await_primaries() self._last_exec = time.time() - self.logger.info("Completed stepdown of all primaries in %0d ms", - (self._last_exec - now) * 1000) + self.logger.info( + "Completed stepdown of all primaries in %0d ms", + (self._last_exec - now) * 1000, + ) found_idle_request = self.__lifecycle.poll_for_idle_request() if found_idle_request: @@ -211,7 +257,9 @@ class _StepdownThread(threading.Thread): # The 'wait_secs' is used to wait 'self._stepdown_interval_secs' from the moment # the last stepdown command was sent. now = time.time() - wait_secs = max(0, self._stepdown_interval_secs - (now - self._last_exec)) + wait_secs = max( + 0, self._stepdown_interval_secs - (now - self._last_exec) + ) self.__lifecycle.wait_for_action_interval(wait_secs) except Exception: # pylint: disable=W0703 # Proactively log the exception when it happens so it will be @@ -244,12 +292,14 @@ class _StepdownThread(threading.Thread): if not rs_fixture.is_running(): raise errors.ServerFailure( "ReplicaSetFixture with pids {} expected to be running in" - " ContinuousStepdown, but wasn't.".format(rs_fixture.pids())) + " ContinuousStepdown, but wasn't.".format(rs_fixture.pids()) + ) for mongos_fixture in self._mongos_fixtures: if not mongos_fixture.is_running(): - raise errors.ServerFailure("MongoSFixture with pids {} expected to be running in" - " ContinuousStepdown, but wasn't.".format( - mongos_fixture.pids())) + raise errors.ServerFailure( + "MongoSFixture with pids {} expected to be running in" + " ContinuousStepdown, but wasn't.".format(mongos_fixture.pids()) + ) def resume(self): """Resume the thread.""" @@ -257,7 +307,8 @@ class _StepdownThread(threading.Thread): self.logger.info( "Current statistics about which nodes have been successfully stepped up: %s", - self._step_up_stats) + self._step_up_stats, + ) def _wait(self, timeout): # Wait until stop or timeout. @@ -282,7 +333,9 @@ class _StepdownThread(threading.Thread): def _step_down(self, rs_fixture): try: - old_primary = rs_fixture.get_primary(timeout_secs=self._stepdown_interval_secs) + old_primary = rs_fixture.get_primary( + timeout_secs=self._stepdown_interval_secs + ) except errors.ServerFailure: # We ignore the ServerFailure exception because it means a primary wasn't available. # We'll try again after self._stepdown_interval_secs seconds. @@ -290,15 +343,21 @@ class _StepdownThread(threading.Thread): secondaries = rs_fixture.get_secondaries() - self.logger.info("Stepping down primary on port %d of replica set '%s'", old_primary.port, - rs_fixture.replset_name) + self.logger.info( + "Stepping down primary on port %d of replica set '%s'", + old_primary.port, + rs_fixture.replset_name, + ) if self._terminate: - if not rs_fixture.stop_primary(old_primary, self._background_reconfig, self._kill): + if not rs_fixture.stop_primary( + old_primary, self._background_reconfig, self._kill + ): return if self._should_downgrade: - new_primary = rs_fixture.change_version_and_restart_node(old_primary, - self._auth_options) + new_primary = rs_fixture.change_version_and_restart_node( + old_primary, self._auth_options + ) else: def step_up_secondary(): @@ -306,11 +365,15 @@ class _StepdownThread(threading.Thread): chosen = random.choice(secondaries) self.logger.info( "Chose secondary on port %d of replica set '%s' for step up attempt.", - chosen.port, rs_fixture.replset_name) + chosen.port, + rs_fixture.replset_name, + ) if not rs_fixture.stepup_node(chosen, self._auth_options): self.logger.info( "Attempt to step up secondary on port %d of replica set '%s' failed.", - chosen.port, rs_fixture.replset_name) + chosen.port, + rs_fixture.replset_name, + ) secondaries.remove(chosen) else: return chosen @@ -327,7 +390,9 @@ class _StepdownThread(threading.Thread): # that may depend on the health of the replica set. self.logger.info( "Successfully stepped up the secondary on port %d of replica set '%s'.", - new_primary.port, rs_fixture.replset_name) + new_primary.port, + rs_fixture.replset_name, + ) retry_time_secs = rs_fixture.AWAIT_REPL_TIMEOUT_MINS * 60 retry_start_time = time.time() while True: @@ -341,13 +406,21 @@ class _StepdownThread(threading.Thread): if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( "The old primary on port {} of replica set {} did not step down in" - " {} seconds.".format(client.port, rs_fixture.replset_name, - retry_time_secs)) - self.logger.info("Waiting for primary on port %d of replica set '%s' to step down.", - old_primary.port, rs_fixture.replset_name) + " {} seconds.".format( + client.port, rs_fixture.replset_name, retry_time_secs + ) + ) + self.logger.info( + "Waiting for primary on port %d of replica set '%s' to step down.", + old_primary.port, + rs_fixture.replset_name, + ) time.sleep(0.2) # Wait a little bit before trying again. - self.logger.info("Primary on port %d of replica set '%s' stepped down.", - old_primary.port, rs_fixture.replset_name) + self.logger.info( + "Primary on port %d of replica set '%s' stepped down.", + old_primary.port, + rs_fixture.replset_name, + ) if not secondaries: # If we failed to step up one of the secondaries, then we run the replSetStepUp to try @@ -375,12 +448,15 @@ class _StepdownThread(threading.Thread): if time.time() - retry_start_time > retry_time_secs: raise errors.ServerFailure( "The old primary on port {} of replica set {} did not step up in" - " {} seconds.".format(client.port, rs_fixture.replset_name, - retry_time_secs)) + " {} seconds.".format( + client.port, rs_fixture.replset_name, retry_time_secs + ) + ) # Bump the counter for the chosen secondary to indicate that the replSetStepUp command # executed successfully. key = "{}/{}".format( rs_fixture.replset_name, - new_primary.get_internal_connection_string() if secondaries else "none") + new_primary.get_internal_connection_string() if secondaries else "none", + ) self._step_up_stats[key] += 1 diff --git a/buildscripts/resmokelib/testing/hooks/validate.py b/buildscripts/resmokelib/testing/hooks/validate.py index a4417e2d848..4d5363faa29 100644 --- a/buildscripts/resmokelib/testing/hooks/validate.py +++ b/buildscripts/resmokelib/testing/hooks/validate.py @@ -15,9 +15,16 @@ class ValidateCollections(jsfile.PerClusterDataConsistencyHook): IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize ValidateCollections.""" description = "Full collection validation" js_filename = os.path.join("jstests", "hooks", "run_validate_collections.js") jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/validate_background.py b/buildscripts/resmokelib/testing/hooks/validate_background.py index 4e994d40db6..4878ba20c11 100644 --- a/buildscripts/resmokelib/testing/hooks/validate_background.py +++ b/buildscripts/resmokelib/testing/hooks/validate_background.py @@ -23,9 +23,17 @@ class ValidateCollectionsInBackground(jsfile.JSHook): def __init__(self, hook_logger, fixture, shell_options=None): """Initialize ValidateCollectionsInBackground.""" description = "Run background collection validation against all mongods while a test is running" - js_filename = os.path.join("jstests", "hooks", "run_validate_collections_background.js") - jsfile.JSHook.__init__(self, hook_logger, fixture, js_filename, description, - shell_options=shell_options) + js_filename = os.path.join( + "jstests", "hooks", "run_validate_collections_background.js" + ) + jsfile.JSHook.__init__( + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) self._background_job = None @@ -49,7 +57,8 @@ class ValidateCollectionsInBackground(jsfile.JSHook): return hook_test_case = _ContinuousDynamicJSTestCase.create_before_test( - test.logger, test, self, self._js_filename, self._shell_options) + test.logger, test, self, self._js_filename, self._shell_options + ) hook_test_case.configure(self.fixture) self.logger.info("Resuming the background collection validation thread.") @@ -74,5 +83,6 @@ class ValidateCollectionsInBackground(jsfile.JSHook): else: self.logger.error( "Encountered an error inside the background collection validation thread.", - exc_info=self._background_job.exc_info) + exc_info=self._background_job.exc_info, + ) raise self._background_job.exc_info[1] diff --git a/buildscripts/resmokelib/testing/hooks/validate_direct_secondary_reads.py b/buildscripts/resmokelib/testing/hooks/validate_direct_secondary_reads.py index 12ec2c81eb9..f602d397fea 100644 --- a/buildscripts/resmokelib/testing/hooks/validate_direct_secondary_reads.py +++ b/buildscripts/resmokelib/testing/hooks/validate_direct_secondary_reads.py @@ -15,9 +15,18 @@ class ValidateDirectSecondaryReads(jsfile.PerClusterDataConsistencyHook): IS_BACKGROUND = False def __init__( # pylint: disable=super-init-not-called - self, hook_logger, fixture, shell_options=None): + self, hook_logger, fixture, shell_options=None + ): """Initialize ValidateDirectSecondaryReads.""" description = "Validate direct secondary reads" - js_filename = os.path.join("jstests", "hooks", "run_validate_direct_secondary_reads.js") + js_filename = os.path.join( + "jstests", "hooks", "run_validate_direct_secondary_reads.js" + ) jsfile.JSHook.__init__( # pylint: disable=non-parent-init-called - self, hook_logger, fixture, js_filename, description, shell_options=shell_options) + self, + hook_logger, + fixture, + js_filename, + description, + shell_options=shell_options, + ) diff --git a/buildscripts/resmokelib/testing/hooks/wait_for_replication.py b/buildscripts/resmokelib/testing/hooks/wait_for_replication.py index 938be25493b..0201208bc9c 100644 --- a/buildscripts/resmokelib/testing/hooks/wait_for_replication.py +++ b/buildscripts/resmokelib/testing/hooks/wait_for_replication.py @@ -36,9 +36,15 @@ class WaitForReplication(interface.Hook): jsTestLog("Ignoring shutdown error in quiesce mode"); }}""" shell_options = {"nodb": "", "eval": js_cmds.format(client_conn)} - shell_proc = core.programs.mongo_shell_program(self.hook_logger, **shell_options) + shell_proc = core.programs.mongo_shell_program( + self.hook_logger, **shell_options + ) shell_proc.start() return_code = shell_proc.wait() if return_code: - raise errors.ServerFailure("Awaiting replication failed for {}".format(client_conn)) - self.hook_logger.info("WaitForReplication took %0.4f seconds", time.time() - start_time) + raise errors.ServerFailure( + "Awaiting replication failed for {}".format(client_conn) + ) + self.hook_logger.info( + "WaitForReplication took %0.4f seconds", time.time() - start_time + ) diff --git a/buildscripts/resmokelib/testing/job.py b/buildscripts/resmokelib/testing/job.py index 50bd7c77799..a79b7459549 100644 --- a/buildscripts/resmokelib/testing/job.py +++ b/buildscripts/resmokelib/testing/job.py @@ -36,9 +36,17 @@ TRACER = trace.get_tracer("resmoke") class Job(object): """Run tests from a queue.""" - def __init__(self, job_num: int, logger: logging.Logger, fixture: Fixture, hooks: List[Hook], - report: TestReport, archival: HookTestArchival, suite_options: config.SuiteOptions, - test_queue_logger: logging.Logger): + def __init__( + self, + job_num: int, + logger: logging.Logger, + fixture: Fixture, + hooks: List[Hook], + report: TestReport, + archival: HookTestArchival, + suite_options: config.SuiteOptions, + test_queue_logger: logging.Logger, + ): """Initialize the job with the specified fixture and hooks.""" self.logger = logger @@ -47,14 +55,17 @@ class Job(object): self.report = report self.archival = archival self.suite_options = suite_options - self.manager = FixtureTestCaseManager(test_queue_logger, self.fixture, job_num, self.report) + self.manager = FixtureTestCaseManager( + test_queue_logger, self.fixture, job_num, self.report + ) # Don't check fixture.is_running() when using hooks that kill and restart fixtures, such # as ContinuousStepdown or KillReplicator. Even if the fixture is still running as # expected, there is a race where fixture.is_running() could fail if called after the # primary was killed but before it was restarted. self._check_if_fixture_running = not any( - hasattr(hook, "STOPS_FIXTURE") and hook.STOPS_FIXTURE for hook in self.hooks) + hasattr(hook, "STOPS_FIXTURE") and hook.STOPS_FIXTURE for hook in self.hooks + ) @property def job_num(self) -> int: @@ -62,21 +73,23 @@ class Job(object): return self.manager.job_num @staticmethod - def _interrupt_all_jobs(queue: 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]', - interrupt_flag: threading.Event): + def _interrupt_all_jobs( + queue: "TestQueue[Union[QueueElemRepeatTime, QueueElem]]", + interrupt_flag: threading.Event, + ): # Set the interrupt flag so that other jobs do not start running more tests. interrupt_flag.set() # Drain the queue to unblock the main thread. Job._drain_queue(queue) def __call__( - self, - queue: 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]', - interrupt_flag: threading.Event, - parent_context: Context, - setup_flag: Optional[threading.Event] = None, - teardown_flag: Optional[threading.Event] = None, - hook_failure_flag: Optional[threading.Event] = None, + self, + queue: "TestQueue[Union[QueueElemRepeatTime, QueueElem]]", + interrupt_flag: threading.Event, + parent_context: Context, + setup_flag: Optional[threading.Event] = None, + teardown_flag: Optional[threading.Event] = None, + hook_failure_flag: Optional[threading.Event] = None, ): """Continuously execute tests from 'queue' and records their details in 'report'. @@ -100,12 +113,16 @@ class Job(object): # test_id from logkeeper for where to put the log output. We don't attempt to run # any tests. self.logger.error( - "Received a StopExecution exception when setting up the fixture: %s.", err) + "Received a StopExecution exception when setting up the fixture: %s.", + err, + ) setup_succeeded = False except: # pylint: disable=bare-except # Something unexpected happened when setting up the fixture. We don't attempt to run # any tests. - self.logger.exception("Encountered an error when setting up the fixture.") + self.logger.exception( + "Encountered an error when setting up the fixture." + ) setup_succeeded = False if not setup_succeeded: @@ -133,13 +150,17 @@ class Job(object): # executor thread that teardown has failed. This likely means resmoke.py is exiting # without having terminated all of the child processes it spawned. self.logger.error( - "Received a StopExecution exception when tearing down the fixture: %s.", err) + "Received a StopExecution exception when tearing down the fixture: %s.", + err, + ) teardown_succeeded = False except: # pylint: disable=bare-except # Something unexpected happened when tearing down the fixture. We indicate back to # the executor thread that teardown has failed. This may mean resmoke.py is exiting # without having terminated all of the child processes it spawned. - self.logger.exception("Encountered an error when tearing down the fixture.") + self.logger.exception( + "Encountered an error when tearing down the fixture." + ) teardown_succeeded = False if not teardown_succeeded: @@ -150,9 +171,13 @@ class Job(object): """Get current time to aid in the unit testing of the _run method.""" return time.time() - def _run(self, queue: 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]', - interrupt_flag: threading.Event, teardown_flag: Optional[threading.Event] = None, - hook_failure_flag: Optional[threading.Event] = None): + def _run( + self, + queue: "TestQueue[Union[QueueElemRepeatTime, QueueElem]]", + interrupt_flag: threading.Event, + teardown_flag: Optional[threading.Event] = None, + hook_failure_flag: Optional[threading.Event] = None, + ): """Call the before/after suite hooks and continuously execute tests from 'queue'.""" self._run_hooks_before_suite(hook_failure_flag) @@ -176,31 +201,48 @@ class Job(object): if self.suite_options.time_repeat_tests_secs: progress = "{} of ({}/{}/{:2.2f} min/max/time)".format( - queue_elem.repeat_num + 1, self.suite_options.num_repeat_tests_min, - self.suite_options.num_repeat_tests_max, self.suite_options.time_repeat_tests_secs) + queue_elem.repeat_num + 1, + self.suite_options.num_repeat_tests_min, + self.suite_options.num_repeat_tests_max, + self.suite_options.time_repeat_tests_secs, + ) else: - progress = "{} of {}".format(queue_elem.repeat_num + 1, - self.suite_options.num_repeat_tests) - self.logger.info(("Requeueing test %s %s, cumulative time elapsed %0.2f"), - queue_elem.testcase.test_name, progress, queue_elem.repeat_time_elapsed) + progress = "{} of {}".format( + queue_elem.repeat_num + 1, self.suite_options.num_repeat_tests + ) + self.logger.info( + ("Requeueing test %s %s, cumulative time elapsed %0.2f"), + queue_elem.testcase.test_name, + progress, + queue_elem.repeat_time_elapsed, + ) - def _requeue_test(self, queue: 'TestQueue[Union[QueueElemRepeatTime, QueueElem]]', - queue_elem: QueueElemRepeatTime, interrupt_flag: threading.Event): + def _requeue_test( + self, + queue: "TestQueue[Union[QueueElemRepeatTime, QueueElem]]", + queue_elem: QueueElemRepeatTime, + interrupt_flag: threading.Event, + ): """Requeue a test if it needs to be repeated.""" if not queue_elem.should_requeue(): return queue_elem.testcase = testcases.make_test_case( - queue_elem.testcase.REGISTERED_NAME, queue_elem.testcase.logger, - queue_elem.testcase.test_name, **queue_elem.test_config) + queue_elem.testcase.REGISTERED_NAME, + queue_elem.testcase.logger, + queue_elem.testcase.test_name, + **queue_elem.test_config, + ) if not interrupt_flag.is_set(): self._log_requeue_test(queue_elem) queue.put(queue_elem) @TRACER.start_as_current_span("job._execute_test") - def _execute_test(self, test: TestCase, hook_failure_flag: Optional[threading.Event]): + def _execute_test( + self, test: TestCase, hook_failure_flag: Optional[threading.Event] + ): """Call the before/after test hooks and execute 'test'.""" common_test_attributes = test.get_test_otel_attributes() @@ -208,13 +250,17 @@ class Job(object): execute_test_span.set_attributes(attributes=common_test_attributes) execute_test_span.set_status(StatusCode.ERROR, "fail_early") - test.configure(self.fixture, config.NUM_CLIENTS_PER_FIXTURE, config.USE_TENANT_CLIENT) + test.configure( + self.fixture, config.NUM_CLIENTS_PER_FIXTURE, config.USE_TENANT_CLIENT + ) self._run_hooks_before_tests(test, hook_failure_flag) self.report.logging_prefix = create_fixture_table(self.fixture) - with TRACER.start_as_current_span("run_test", attributes=common_test_attributes): + with TRACER.start_as_current_span( + "run_test", attributes=common_test_attributes + ): test(self.report) try: if test.propagate_error is not None: @@ -225,19 +271,27 @@ class Job(object): # part of a hook has added a failed test case to 'self.report'. Checking the individual # 'test' status ensures self._run_hooks_after_tests() is called if it is a hook's test # case that has failed and not 'test' that has failed. - if self.suite_options.fail_fast and self.report.find_test_info(test).status != "pass": - self.logger.info("%s failed, so stopping..." % (test.short_description())) + if ( + self.suite_options.fail_fast + and self.report.find_test_info(test).status != "pass" + ): + self.logger.info( + "%s failed, so stopping..." % (test.short_description()) + ) raise errors.StopExecution("%s failed" % (test.short_description())) if self._check_if_fixture_running and not self.fixture.is_running(): self.logger.error( "%s marked as a failure because the fixture crashed during the test.", - test.short_description()) - self.report.setFailure(test, return_code=2, - reason="the fixture crashed during the test") + test.short_description(), + ) + self.report.setFailure( + test, return_code=2, reason="the fixture crashed during the test" + ) # Always fail fast if the fixture fails. raise errors.StopExecution( - "%s not running after %s" % (self.fixture, test.short_description())) + "%s not running after %s" % (self.fixture, test.short_description()) + ) finally: success: bool = self.report.find_test_info(test).status == "pass" if success: @@ -255,8 +309,13 @@ class Job(object): self._run_hooks_after_tests(test, hook_failure_flag, background=False) - def _run_hook(self, hook: Hook, hook_function, test: TestCase, - hook_failure_flag: Optional[threading.Event]): + def _run_hook( + self, + hook: Hook, + hook_function, + test: TestCase, + hook_failure_flag: Optional[threading.Event], + ): """Provide helper to run hook and archival.""" try: success = False @@ -283,11 +342,15 @@ class Job(object): if hooks_failed and hook_failure_flag is not None: hook_failure_flag.set() run_hooks_before_suite_span.set_status( - StatusCode.ERROR if hooks_failed else StatusCode.OK) + StatusCode.ERROR if hooks_failed else StatusCode.OK + ) @TRACER.start_as_current_span("job._run_hooks_after_suite") - def _run_hooks_after_suite(self, teardown_flag: Optional[threading.Event], - hook_failure_flag: Optional[threading.Event]): + def _run_hooks_after_suite( + self, + teardown_flag: Optional[threading.Event], + hook_failure_flag: Optional[threading.Event], + ): """Run the after_suite method on each of the hooks.""" run_hooks_after_suite_span = trace.get_current_span() hooks_failed = True @@ -299,9 +362,12 @@ class Job(object): if hooks_failed and hook_failure_flag is not None: hook_failure_flag.set() run_hooks_after_suite_span.set_status( - StatusCode.ERROR if hooks_failed else StatusCode.OK) + StatusCode.ERROR if hooks_failed else StatusCode.OK + ) - def _run_hooks_before_tests(self, test: TestCase, hook_failure_flag: Optional[threading.Event]): + def _run_hooks_before_tests( + self, test: TestCase, hook_failure_flag: Optional[threading.Event] + ): """Run the before_test method on each of the hooks. Swallows any TestFailure exceptions if set to continue on @@ -316,14 +382,18 @@ class Job(object): raise except errors.ServerFailure: - self.logger.exception("%s marked as a failure by a hook's before_test.", - test.short_description()) + self.logger.exception( + "%s marked as a failure by a hook's before_test.", + test.short_description(), + ) self._fail_test(test, sys.exc_info(), return_code=2) raise errors.StopExecution("A hook's before_test failed") except errors.TestFailure: - self.logger.exception("%s marked as a failure by a hook's before_test.", - test.short_description()) + self.logger.exception( + "%s marked as a failure by a hook's before_test.", + test.short_description(), + ) self._fail_test(test, sys.exc_info(), return_code=1) if self.suite_options.fail_fast: raise errors.StopExecution("A hook's before_test failed") @@ -335,8 +405,12 @@ class Job(object): self.report.stopTest(test) raise - def _run_hooks_after_tests(self, test: TestCase, hook_failure_flag: Optional[threading.Event], - background: bool = False): + def _run_hooks_after_tests( + self, + test: TestCase, + hook_failure_flag: Optional[threading.Event], + background: bool = False, + ): """Run the after_test method on each of the hooks. Swallows any TestFailure exceptions if set to continue on @@ -346,23 +420,31 @@ class Job(object): @param background: whether to run background hooks. """ - suite_with_balancer = isinstance( - self.fixture, shardedcluster.ShardedClusterFixture) and self.fixture.enable_balancer + suite_with_balancer = ( + isinstance(self.fixture, shardedcluster.ShardedClusterFixture) + and self.fixture.enable_balancer + ) if not background and suite_with_balancer: try: self.logger.info("Stopping the balancer before running end-test hooks") self.fixture.stop_balancer() except: - self.logger.exception("%s failed while stopping the balancer for after-test hooks", - test.short_description()) + self.logger.exception( + "%s failed while stopping the balancer for after-test hooks", + test.short_description(), + ) self.report.setFailure( - test, return_code=2, - reason="the balancer failed to stop before running after-test hooks") + test, + return_code=2, + reason="the balancer failed to stop before running after-test hooks", + ) if self.archival: result = TestResult(test=test, hook=None, success=False) self.archival.archive(self.logger, result, self.manager) - raise errors.StopExecution("stop_balancer failed before running after test hooks") + raise errors.StopExecution( + "stop_balancer failed before running after test hooks" + ) try: for hook in self.hooks: @@ -373,17 +455,23 @@ class Job(object): raise except errors.ServerFailure: - self.logger.exception("%s marked as a failure by a hook's after_test.", - test.short_description()) - self.report.setFailure(test, return_code=2, - reason=f"The hook {hook.REGISTERED_NAME} failed.") + self.logger.exception( + "%s marked as a failure by a hook's after_test.", + test.short_description(), + ) + self.report.setFailure( + test, return_code=2, reason=f"The hook {hook.REGISTERED_NAME} failed." + ) raise errors.StopExecution("A hook's after_test failed") except errors.TestFailure: - self.logger.exception("%s marked as a failure by a hook's after_test.", - test.short_description()) - self.report.setFailure(test, return_code=1, - reason=f"The hook {hook.REGISTERED_NAME} failed.") + self.logger.exception( + "%s marked as a failure by a hook's after_test.", + test.short_description(), + ) + self.report.setFailure( + test, return_code=1, reason=f"The hook {hook.REGISTERED_NAME} failed." + ) if self.suite_options.fail_fast: raise errors.StopExecution("A hook's after_test failed") @@ -398,14 +486,19 @@ class Job(object): except: self.logger.exception( "%s failed while re-starting the balancer after end-test hooks", - test.short_description()) + test.short_description(), + ) self.report.setFailure( - test, return_code=2, - reason="the balancer failed to restart after running after test hooks") + test, + return_code=2, + reason="the balancer failed to restart after running after test hooks", + ) if self.archival: result = TestResult(test=test, hook=None, success=False) self.archival.archive(self.logger, result, self.manager) - raise errors.StopExecution("start_balancer failed after running after test hooks") + raise errors.StopExecution( + "start_balancer failed after running after test hooks" + ) def _fail_test(self, test: TestCase, exc_info, return_code=1): """Provide helper to record a test as a failure with the provided return code. @@ -436,14 +529,19 @@ class Job(object): pass -TestResult = namedtuple('TestResult', ['test', 'hook', 'success']) +TestResult = namedtuple("TestResult", ["test", "hook", "success"]) class FixtureTestCaseManager: """Class that holds information needed to create new fixture setup/teardown test cases for a single job.""" - def __init__(self, test_queue_logger: logging.Logger, fixture: Fixture, job_num: int, - report: TestReport): + def __init__( + self, + test_queue_logger: logging.Logger, + fixture: Fixture, + job_num: int, + report: TestReport, + ): """ Initialize the test case manager. @@ -464,8 +562,12 @@ class FixtureTestCaseManager: Return True if the setup was successful, False otherwise. """ - test_case = _fixture.FixtureSetupTestCase(self.test_queue_logger, self.fixture, - "job{}".format(self.job_num), self.times_set_up) + test_case = _fixture.FixtureSetupTestCase( + self.test_queue_logger, + self.fixture, + "job{}".format(self.job_num), + self.times_set_up, + ) test_case(self.report) if self.report.find_test_info(test_case).status != "pass": logger.error("The setup of %s failed.", self.fixture) @@ -480,16 +582,22 @@ class FixtureTestCaseManager: Return True if the teardown was successful, False otherwise. """ try: - test_case: Union[_fixture.FixtureAbortTestCase, _fixture.FixtureTeardownTestCase] = None + test_case: Union[ + _fixture.FixtureAbortTestCase, _fixture.FixtureTeardownTestCase + ] = None if abort: - test_case = _fixture.FixtureAbortTestCase(self.test_queue_logger, self.fixture, - "job{}".format(self.job_num), - self.times_set_up) + test_case = _fixture.FixtureAbortTestCase( + self.test_queue_logger, + self.fixture, + "job{}".format(self.job_num), + self.times_set_up, + ) self.times_set_up += 1 else: - test_case = _fixture.FixtureTeardownTestCase(self.test_queue_logger, self.fixture, - "job{}".format(self.job_num)) + test_case = _fixture.FixtureTeardownTestCase( + self.test_queue_logger, self.fixture, "job{}".format(self.job_num) + ) # Refresh the fixture table before teardown to capture changes due to # CleanEveryN and stepdown hooks. diff --git a/buildscripts/resmokelib/testing/queue_element.py b/buildscripts/resmokelib/testing/queue_element.py index 7cef4144907..5129f4135a7 100644 --- a/buildscripts/resmokelib/testing/queue_element.py +++ b/buildscripts/resmokelib/testing/queue_element.py @@ -3,8 +3,9 @@ from typing import Union -def queue_elem_factory(testcase, test_config, - suite_options) -> Union['QueueElemRepeatTime', 'QueueElem']: +def queue_elem_factory( + testcase, test_config, suite_options +) -> Union["QueueElemRepeatTime", "QueueElem"]: """ Create the appropriate queue element based on suite_options given. diff --git a/buildscripts/resmokelib/testing/report.py b/buildscripts/resmokelib/testing/report.py index bffedc345bf..59ffa406eaf 100644 --- a/buildscripts/resmokelib/testing/report.py +++ b/buildscripts/resmokelib/testing/report.py @@ -52,8 +52,10 @@ class TestReport(unittest.TestResult): # TestReports that are used when running tests need a JobLogger but combined reports don't # use the logger. - combined_report = cls(logging.loggers.ROOT_EXECUTOR_LOGGER, - _config.SuiteOptions.ALL_INHERITED.resolve()) + combined_report = cls( + logging.loggers.ROOT_EXECUTOR_LOGGER, + _config.SuiteOptions.ALL_INHERITED.resolve(), + ) combining_time = time.time() for report in reports: @@ -78,7 +80,8 @@ class TestReport(unittest.TestResult): if "AFTER_TIMEOUT" not in logger.name: logger.name = f"{logger.name}:AFTER_TIMEOUT" logger.error( - "HIT EVERGREEN TIMEOUT: Hang analyzer will kill or abort processes") + "HIT EVERGREEN TIMEOUT: Hang analyzer will kill or abort processes" + ) # Until EVG-1536 is completed, we shouldn't distinguish between failures and # interrupted tests in the report.json file. In Evergreen, the behavior to # sort tests with the "timeout" test status after tests with the "pass" test @@ -131,20 +134,28 @@ class TestReport(unittest.TestResult): self.num_dynamic += 1 # Set up the test-specific logger. - (test_logger, url_endpoint) = logging.loggers.new_test_logger(test.short_name(), - test.basename(), command, - test.logger, self.job_num, - test.id(), self.job_logger) + (test_logger, url_endpoint) = logging.loggers.new_test_logger( + test.short_name(), + test.basename(), + command, + test.logger, + self.job_num, + test.id(), + self.job_logger, + ) test_info.add_logger(test_logger) test_info.add_logger(self.job_logger) # Set up logging handlers to capture exceptions. - test_info.exception_extractors = logging.loggers.configure_exception_capture(test_logger) + test_info.exception_extractors = logging.loggers.configure_exception_capture( + test_logger + ) test_info.log_info = { "log_name": logging.loggers.get_evergreen_log_name(self.job_num, test.id()), "logs_to_merge": [logging.loggers.get_evergreen_log_name(self.job_num)], - "rendering_type": "resmoke", "version": 0 + "rendering_type": "resmoke", + "version": 0, } test_info.url_endpoint = url_endpoint if self.logging_prefix is not None: @@ -169,11 +180,14 @@ class TestReport(unittest.TestResult): with self._lock: test_info = self.find_test_info(test) test_info.end_time = time.time() - test_status = "no failures detected" if test_info.status == "pass" else "failed" + test_status = ( + "no failures detected" if test_info.status == "pass" else "failed" + ) time_taken = test_info.end_time - test_info.start_time - self.job_logger.info("%s ran in %0.2f seconds: %s.", test.basename(), time_taken, - test_status) + self.job_logger.info( + "%s ran in %0.2f seconds: %s.", test.basename(), time_taken, test_status + ) finally: # This is a failsafe. In the event that 'stopTest' fails, @@ -201,7 +215,7 @@ class TestReport(unittest.TestResult): test_info.status = "error" test_info.evergreen_status = "fail" test_info.return_code = test.return_code - test_info.error = self._exc_info_to_string(err, test).split('\n') + test_info.error = self._exc_info_to_string(err, test).split("\n") def setError(self, test, err): """Change the outcome of an existing test to an error.""" @@ -217,7 +231,7 @@ class TestReport(unittest.TestResult): test_info.status = "error" test_info.evergreen_status = "fail" test_info.return_code = 2 - test_info.error = self._exc_info_to_string(err, test).split('\n') + test_info.error = self._exc_info_to_string(err, test).split("\n") # Recompute number of success, failures, and errors. self.num_succeeded = len(self.get_successful()) @@ -286,25 +300,37 @@ class TestReport(unittest.TestResult): """Return the status and timing information of the tests that executed successfully.""" with self._lock: - return [test_info for test_info in self.test_infos if test_info.status == "pass"] + return [ + test_info for test_info in self.test_infos if test_info.status == "pass" + ] def get_failed(self): """Return the status and timing information of tests that raised a failureException.""" with self._lock: - return [test_info for test_info in self.test_infos if test_info.status == "fail"] + return [ + test_info for test_info in self.test_infos if test_info.status == "fail" + ] def get_errored(self): """Return the status and timing information of tests that raised a non-failureException.""" with self._lock: - return [test_info for test_info in self.test_infos if test_info.status == "error"] + return [ + test_info + for test_info in self.test_infos + if test_info.status == "error" + ] def get_interrupted(self): """Return the status and timing information of tests that were execution interrupted.""" with self._lock: - return [test_info for test_info in self.test_infos if test_info.status == "timeout"] + return [ + test_info + for test_info in self.test_infos + if test_info.status == "timeout" + ] def as_dict(self): """Return the test result information as a dictionary. @@ -347,11 +373,15 @@ class TestReport(unittest.TestResult): Used when combining reports instances. """ - report = cls(logging.loggers.ROOT_EXECUTOR_LOGGER, - _config.SuiteOptions.ALL_INHERITED.resolve()) + report = cls( + logging.loggers.ROOT_EXECUTOR_LOGGER, + _config.SuiteOptions.ALL_INHERITED.resolve(), + ) for result in report_dict["results"]: # By convention, dynamic tests are named ":". - is_dynamic = ":" in result["test_file"] or ":" in result.get("display_test_name", "") + is_dynamic = ":" in result["test_file"] or ":" in result.get( + "display_test_name", "" + ) test_file = result["test_file"] # Using test_file as the test id is ok here since the test id only needs to be unique # during suite execution. @@ -405,12 +435,21 @@ class TestReport(unittest.TestResult): def _log_outcome_change(self, test, outcome, reason=""): # Recreate the test logger for this test in order to append to the existing log. - (logger, - _) = logging.loggers.new_test_logger(test.short_name(), test.basename(), None, test.logger, - self.job_num, test.id(), self.job_logger) + (logger, _) = logging.loggers.new_test_logger( + test.short_name(), + test.basename(), + None, + test.logger, + self.job_num, + test.id(), + self.job_logger, + ) logger.info( 'Sometime after completion of %s, the test outcome was changed to "%s" because: %s', - test.short_description(), outcome, reason if reason else ".") + test.short_description(), + outcome, + reason if reason else ".", + ) for handler in logger.handlers: logging.flush.close_later(handler) @@ -450,13 +489,13 @@ def test_order(test_name): Investigate setup/teardown errors, then hooks, then test files. """ - if 'fixture_setup' in test_name: + if "fixture_setup" in test_name: return 1 - elif 'fixture_teardown' in test_name: + elif "fixture_teardown" in test_name: return 2 - elif 'fixture_abort' in test_name: + elif "fixture_abort" in test_name: return 3 - elif ':' in test_name: + elif ":" in test_name: return 4 else: return 5 diff --git a/buildscripts/resmokelib/testing/retry.py b/buildscripts/resmokelib/testing/retry.py index 58c293a540f..2f39a541c4b 100644 --- a/buildscripts/resmokelib/testing/retry.py +++ b/buildscripts/resmokelib/testing/retry.py @@ -83,4 +83,5 @@ def with_naive_retry(func, timeout=100, extra_retryable_error_codes=None): time.sleep(0.1) raise ExecutionTimeout( - f"Operation exceeded time limit after {timeout} seconds, last error: {last_exc}") + f"Operation exceeded time limit after {timeout} seconds, last error: {last_exc}" + ) diff --git a/buildscripts/resmokelib/testing/suite.py b/buildscripts/resmokelib/testing/suite.py index 3cdf4e8d531..8beee35f65b 100644 --- a/buildscripts/resmokelib/testing/suite.py +++ b/buildscripts/resmokelib/testing/suite.py @@ -63,7 +63,9 @@ def synchronized(method): class Suite(object): """A suite of tests of a particular kind (e.g. C++ unit tests, dbtests, jstests).""" - def __init__(self, suite_name, suite_config, suite_options=_config.SuiteOptions.ALL_INHERITED): + def __init__( + self, suite_name, suite_config, suite_options=_config.SuiteOptions.ALL_INHERITED + ): """Initialize the suite with the specified name and configuration.""" self._lock = threading.RLock() @@ -295,7 +297,9 @@ class Suite(object): active_report = _report.TestReport.combine(*self._partial_reports) # Use the current time as the time that this suite finished running. end_time = time.time() - return self._summarize_report(active_report, self._test_start_times[-1], end_time, sb) + return self._summarize_report( + active_report, self._test_start_times[-1], end_time, sb + ) def _summarize_repeated(self, sb): """Return the summary information of all executions. @@ -309,20 +313,28 @@ class Suite(object): start_times = self._test_start_times[:] end_times = self._test_end_times[:] if self._partial_reports: - end_times.append(time.time()) # Add an end time in this copy for the partial reports. + end_times.append( + time.time() + ) # Add an end time in this copy for the partial reports. total_time_taken = end_times[-1] - start_times[0] - sb.append("Executed %d times in %0.2f seconds:" % (num_iterations, total_time_taken)) + sb.append( + "Executed %d times in %0.2f seconds:" % (num_iterations, total_time_taken) + ) combined_summary = _summary.Summary(0, 0.0, 0, 0, 0, 0) for iteration in range(num_iterations): # Summarize each execution as a bulleted list of results. bulleter_sb = [] - summary = self._summarize_report(reports[iteration], start_times[iteration], - end_times[iteration], bulleter_sb) + summary = self._summarize_report( + reports[iteration], + start_times[iteration], + end_times[iteration], + bulleter_sb, + ) combined_summary = _summary.combine(combined_summary, summary) - for (i, line) in enumerate(bulleter_sb): + for i, line in enumerate(bulleter_sb): # Only bullet first line, indent others. prefix = "* " if i == 0 else " " sb.append(prefix + line) @@ -335,8 +347,12 @@ class Suite(object): Also append a summary of that execution onto the string builder 'sb'. """ - return self._summarize_report(self._reports[iteration], self._test_start_times[iteration], - self._test_end_times[iteration], sb) + return self._summarize_report( + self._reports[iteration], + self._test_start_times[iteration], + self._test_end_times[iteration], + sb, + ) def _summarize_report(self, report, start_time, end_time, sb): """Return the summary information of the execution. @@ -364,20 +380,36 @@ class Suite(object): sb.append("All %d test(s) passed in %0.2f seconds." % (num_run, time_taken)) return _summary.Summary(num_run, time_taken, num_run, 0, 0, 0) - summary = _summary.Summary(num_run, time_taken, report.num_succeeded, num_skipped, - num_failed, report.num_errored) + summary = _summary.Summary( + num_run, + time_taken, + report.num_succeeded, + num_skipped, + num_failed, + report.num_errored, + ) - sb.append("%d test(s) ran in %0.2f seconds" - " (%d succeeded, %d were skipped, %d failed, %d errored)" % summary) + sb.append( + "%d test(s) ran in %0.2f seconds" + " (%d succeeded, %d were skipped, %d failed, %d errored)" % summary + ) test_names = [] if num_failed > 0: sb.append("The following tests failed (with exit code):") - for test_info in itertools.chain(report.get_failed(), report.get_interrupted()): + for test_info in itertools.chain( + report.get_failed(), report.get_interrupted() + ): test_names.append(test_info.test_file) - sb.append(" %s (%d %s)" % (test_info.test_file, test_info.return_code, - translate_exit_code(test_info.return_code))) + sb.append( + " %s (%d %s)" + % ( + test_info.test_file, + test_info.return_code, + translate_exit_code(test_info.return_code), + ) + ) for exception_extractor in test_info.exception_extractors: for log_line in exception_extractor.get_exception(): @@ -395,8 +427,10 @@ class Suite(object): if num_failed > 0 or report.num_errored > 0: test_names.sort(key=_report.test_order) - sb.append("If you're unsure where to begin investigating these errors, " - "consider looking at tests in the following order:") + sb.append( + "If you're unsure where to begin investigating these errors, " + "consider looking at tests in the following order:" + ) for test_name in test_names: sb.append(" %s" % (test_name)) @@ -407,11 +441,15 @@ class Suite(object): """Log summary of all suites.""" sb = [] sb.append( - "Summary of all suites: %d suites ran in %0.2f seconds" % (len(suites), time_taken)) + "Summary of all suites: %d suites ran in %0.2f seconds" + % (len(suites), time_taken) + ) for suite in suites: suite_sb = [] suite.summarize(suite_sb) - sb.append(" %s: %s" % (suite.get_display_name(), "\n ".join(suite_sb))) + sb.append( + " %s: %s" % (suite.get_display_name(), "\n ".join(suite_sb)) + ) logger.info("=" * 80) logger.info("\n".join(sb)) diff --git a/buildscripts/resmokelib/testing/summary.py b/buildscripts/resmokelib/testing/summary.py index a5d439b64b0..62aca83de9a 100644 --- a/buildscripts/resmokelib/testing/summary.py +++ b/buildscripts/resmokelib/testing/summary.py @@ -4,7 +4,15 @@ import collections Summary = collections.namedtuple( "Summary", - ["num_run", "time_taken", "num_succeeded", "num_skipped", "num_failed", "num_errored"]) + [ + "num_run", + "time_taken", + "num_succeeded", + "num_skipped", + "num_failed", + "num_errored", + ], +) def combine(summary1, summary2): diff --git a/buildscripts/resmokelib/testing/symbolizer_service.py b/buildscripts/resmokelib/testing/symbolizer_service.py index 350a270737d..d1ad3d8df24 100644 --- a/buildscripts/resmokelib/testing/symbolizer_service.py +++ b/buildscripts/resmokelib/testing/symbolizer_service.py @@ -1,4 +1,5 @@ """Symbolize stacktraces inside test logs.""" + from __future__ import annotations import ast @@ -85,17 +86,29 @@ class ResmokeSymbolizerConfig(NamedTuple): class ResmokeSymbolizer: """Symbolize stacktraces inside test logs.""" - def __init__(self, config: Optional[ResmokeSymbolizerConfig] = None, - symbolizer_service: Optional[SymbolizerService] = None, - file_service: Optional[FileService] = None): + def __init__( + self, + config: Optional[ResmokeSymbolizerConfig] = None, + symbolizer_service: Optional[SymbolizerService] = None, + file_service: Optional[FileService] = None, + ): """Initialize instance.""" - self.config = config if config is not None else ResmokeSymbolizerConfig.from_resmoke_config( + self.config = ( + config + if config is not None + else ResmokeSymbolizerConfig.from_resmoke_config() ) - self.symbolizer_service = symbolizer_service if symbolizer_service is not None else SymbolizerService( + self.symbolizer_service = ( + symbolizer_service + if symbolizer_service is not None + else SymbolizerService() + ) + self.file_service = ( + file_service + if file_service is not None + else FileService(PROCESSED_FILES_LIST_FILE_PATH) ) - self.file_service = file_service if file_service is not None else FileService( - PROCESSED_FILES_LIST_FILE_PATH) def get_unsymbolized_stacktrace( self, @@ -134,7 +147,9 @@ class ResmokeSymbolizer: data = self.get_unsymbolized_stacktrace_data(test, files) self.make_symbolization_instructions_or_symbolize(test, data, files) - def get_unsymbolized_stacktrace_data(self, test: TestCase, files: list[str]) -> dict: + def get_unsymbolized_stacktrace_data( + self, test: TestCase, files: list[str] + ) -> dict: """ Reads each file containing unsymbolized stacktraces and stores its content. In each entry, the original name of the file and the test associated with the stacktrace is also stored. @@ -152,18 +167,23 @@ class ResmokeSymbolizer: with open(UNSYMBOLIZED_STACKTRACE_JSON, "r") as file: data = json.load(file) except Exception as ex: - test.logger.info(f"unable to read existing unsymbolized_stacktraces file: {ex}") + test.logger.info( + f"unable to read existing unsymbolized_stacktraces file: {ex}" + ) for f in files: unsymbolized_content_dict = {} try: with open(f, "r") as file: - unsymbolized_content = ','.join([line.rstrip('\n') for line in file]) - unsymbolized_content_dict = ast.literal_eval(unsymbolized_content) + unsymbolized_content = ",".join( + [line.rstrip("\n") for line in file] + ) + unsymbolized_content_dict = ast.literal_eval( + unsymbolized_content + ) except Exception as e: test.logger.error(e) - unsymbolized_stacktrace_details = { "name": f, "unsymbolized_stacktrace": unsymbolized_content_dict, @@ -272,7 +292,9 @@ class ResmokeSymbolizer: missing_keys = [] for entry in unsymbolized_stacktraces_info: unsymbolized_stacktrace = entry["unsymbolized_stacktrace"] - found_backtrace = self.get_value_recursively(unsymbolized_stacktrace, BACKTRACE_KEY) + found_backtrace = self.get_value_recursively( + unsymbolized_stacktrace, BACKTRACE_KEY + ) found_process_info = self.get_value_recursively( unsymbolized_stacktrace, PROCESS_INFO_KEY ) @@ -348,8 +370,10 @@ If no symbolized stacktrace is created, then most likely either: return False if self.config.client_id is None or self.config.client_secret is None: - test.logger.info("Symbolizer client secret and/or client ID are absent," - " skipping symbolization") + test.logger.info( + "Symbolizer client secret and/or client ID are absent," + " skipping symbolization" + ) return False if self.config.is_windows(): @@ -547,9 +571,13 @@ class SymbolizerService: ] with open(full_file_path) as file_obj: - symbolizer_process = subprocess.Popen(args=symbolizer_args, close_fds=True, - stdin=file_obj, stdout=subprocess.PIPE, - stderr=subprocess.STDOUT) + symbolizer_process = subprocess.Popen( + args=symbolizer_args, + close_fds=True, + stdin=file_obj, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) try: output, _ = symbolizer_process.communicate(timeout=retry_timeout_secs) diff --git a/buildscripts/resmokelib/testing/tags.py b/buildscripts/resmokelib/testing/tags.py index 2a76a9e9f2b..86e16650808 100644 --- a/buildscripts/resmokelib/testing/tags.py +++ b/buildscripts/resmokelib/testing/tags.py @@ -74,8 +74,11 @@ class TagsConfig(object): with open(filename, "w") as fstream: if preamble: print( - textwrap.fill(preamble, width=100, initial_indent="# ", subsequent_indent="# "), - file=fstream) + textwrap.fill( + preamble, width=100, initial_indent="# ", subsequent_indent="# " + ), + file=fstream, + ) # We use yaml.safe_dump() in order avoid having strings being written to the file as # "!!python/unicode ..." and instead have them written as plain 'str' instances. diff --git a/buildscripts/resmokelib/testing/testcases/benchmark_test.py b/buildscripts/resmokelib/testing/testcases/benchmark_test.py index c344532965f..35d42dbf1cd 100644 --- a/buildscripts/resmokelib/testing/testcases/benchmark_test.py +++ b/buildscripts/resmokelib/testing/testcases/benchmark_test.py @@ -13,7 +13,9 @@ class BenchmarkTestCase(interface.ProcessTestCase): def __init__(self, logger, program_executable, program_options=None): """Initialize the BenchmarkTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "Benchmark test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "Benchmark test", program_executable + ) self.validate_benchmark_options() self.bm_executable = program_executable @@ -26,11 +28,16 @@ class BenchmarkTestCase(interface.ProcessTestCase): :return: None """ - if _config.REPEAT_SUITES > 1 or _config.REPEAT_TESTS > 1 or _config.REPEAT_TESTS_SECS: + if ( + _config.REPEAT_SUITES > 1 + or _config.REPEAT_TESTS > 1 + or _config.REPEAT_TESTS_SECS + ): raise ValueError( "--repeatSuites/--repeatTests cannot be used with benchmark tests. " "Please use --benchmarkMinTimeSecs to increase the runtime of a single benchmark " - "configuration.") + "configuration." + ) def configure(self, fixture, *args, **kwargs): """Configure BenchmarkTestCase.""" @@ -44,7 +51,7 @@ class BenchmarkTestCase(interface.ProcessTestCase): "benchmark_repetitions": _config.DEFAULT_BENCHMARK_REPETITIONS, # TODO: remove the following line once we bump our Google Benchmark version to one that # contains the fix for https://github.com/google/benchmark/issues/559 . - "benchmark_color": False + "benchmark_color": False, } # 2. Override Benchmark options with options set through `program_options` in the suite @@ -54,10 +61,11 @@ class BenchmarkTestCase(interface.ProcessTestCase): # 3. Override Benchmark options with options set through resmoke's command line. resmoke_bm_options = { - "benchmark_filter": _config.BENCHMARK_FILTER, "benchmark_list_tests": - _config.BENCHMARK_LIST_TESTS, "benchmark_min_time": _config.BENCHMARK_MIN_TIME, + "benchmark_filter": _config.BENCHMARK_FILTER, + "benchmark_list_tests": _config.BENCHMARK_LIST_TESTS, + "benchmark_min_time": _config.BENCHMARK_MIN_TIME, "benchmark_out_format": _config.BENCHMARK_OUT_FORMAT, - "benchmark_repetitions": _config.BENCHMARK_REPETITIONS + "benchmark_repetitions": _config.BENCHMARK_REPETITIONS, } for key, value in list(resmoke_bm_options.items()): @@ -74,4 +82,6 @@ class BenchmarkTestCase(interface.ProcessTestCase): return self.bm_executable + ".json" def _make_process(self): - return core.programs.generic_program(self.logger, [self.bm_executable], **self.bm_options) + return core.programs.generic_program( + self.logger, [self.bm_executable], **self.bm_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/bulk_write_cluster_js_test.py b/buildscripts/resmokelib/testing/testcases/bulk_write_cluster_js_test.py index e05a680bbf4..d3792eb5b1c 100644 --- a/buildscripts/resmokelib/testing/testcases/bulk_write_cluster_js_test.py +++ b/buildscripts/resmokelib/testing/testcases/bulk_write_cluster_js_test.py @@ -11,9 +11,14 @@ class BulkWriteClusterTestCase(jsrunnerfile.JSRunnerFileTestCase): def __init__(self, logger, js_filename, shell_executable=None, shell_options=None): """Initialize the BulkWriteClusterTestCase.""" jsrunnerfile.JSRunnerFileTestCase.__init__( - self, logger, "BulkWriteCluster Test", js_filename, + self, + logger, + "BulkWriteCluster Test", + js_filename, test_runner_file="jstests/libs/bulk_write_passthrough_runner.js", - shell_executable=shell_executable, shell_options=shell_options) + shell_executable=shell_executable, + shell_options=shell_options, + ) @property def js_filename(self): @@ -22,5 +27,9 @@ class BulkWriteClusterTestCase(jsrunnerfile.JSRunnerFileTestCase): def _populate_test_data(self, test_data): test_data["jsTestFile"] = self.js_filename - test_data["bulkWriteCluster"] = self.fixture.clusters[0].get_driver_connection_url() - test_data["normalCluster"] = self.fixture.clusters[1].get_driver_connection_url() + test_data["bulkWriteCluster"] = self.fixture.clusters[ + 0 + ].get_driver_connection_url() + test_data["normalCluster"] = self.fixture.clusters[ + 1 + ].get_driver_connection_url() diff --git a/buildscripts/resmokelib/testing/testcases/cpp_integration_test.py b/buildscripts/resmokelib/testing/testcases/cpp_integration_test.py index 13d6c802dbb..3d8464510d5 100644 --- a/buildscripts/resmokelib/testing/testcases/cpp_integration_test.py +++ b/buildscripts/resmokelib/testing/testcases/cpp_integration_test.py @@ -12,7 +12,9 @@ class CPPIntegrationTestCase(interface.ProcessTestCase): def __init__(self, logger, program_executable, program_options=None): """Initialize the CPPIntegrationTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "C++ integration test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "C++ integration test", program_executable + ) self.program_executable = program_executable self.program_options = utils.default_if_none(program_options, {}).copy() @@ -21,8 +23,11 @@ class CPPIntegrationTestCase(interface.ProcessTestCase): """Configure the test case.""" interface.ProcessTestCase.configure(self, fixture, *args, **kwargs) - self.program_options["connectionString"] = self.fixture.get_internal_connection_string() + self.program_options["connectionString"] = ( + self.fixture.get_internal_connection_string() + ) def _make_process(self): - return core.programs.generic_program(self.logger, [self.program_executable], - **self.program_options) + return core.programs.generic_program( + self.logger, [self.program_executable], **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/cpp_libfuzzer_test.py b/buildscripts/resmokelib/testing/testcases/cpp_libfuzzer_test.py index cbe7e634963..c88df461776 100644 --- a/buildscripts/resmokelib/testing/testcases/cpp_libfuzzer_test.py +++ b/buildscripts/resmokelib/testing/testcases/cpp_libfuzzer_test.py @@ -13,11 +13,19 @@ class CPPLibfuzzerTestCase(interface.ProcessTestCase): REGISTERED_NAME = "cpp_libfuzzer_test" DEFAULT_TIMEOUT = datetime.timedelta(hours=1) - def __init__(self, logger, program_executable, program_options=None, runs=1000000, - corpus_directory_stem="corpora"): + def __init__( + self, + logger, + program_executable, + program_options=None, + runs=1000000, + corpus_directory_stem="corpora", + ): """Initialize the CPPLibfuzzerTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "C++ libfuzzer test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "C++ libfuzzer test", program_executable + ) self.program_executable = program_executable self.program_options = utils.default_if_none(program_options, {}).copy() @@ -25,7 +33,9 @@ class CPPLibfuzzerTestCase(interface.ProcessTestCase): self.runs = runs self.corpus_directory = f"{corpus_directory_stem}/corpus-{self.short_name()}" - self.merged_corpus_directory = f"{corpus_directory_stem}-merged/corpus-{self.short_name()}" + self.merged_corpus_directory = ( + f"{corpus_directory_stem}-merged/corpus-{self.short_name()}" + ) os.makedirs(self.corpus_directory, exist_ok=True) @@ -38,4 +48,6 @@ class CPPLibfuzzerTestCase(interface.ProcessTestCase): f"-runs={self.runs}", self.corpus_directory, ] - return core.programs.make_process(self.logger, default_args, **self.program_options) + return core.programs.make_process( + self.logger, default_args, **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/cpp_unittest.py b/buildscripts/resmokelib/testing/testcases/cpp_unittest.py index 48df94b8c66..de3b8234d0e 100644 --- a/buildscripts/resmokelib/testing/testcases/cpp_unittest.py +++ b/buildscripts/resmokelib/testing/testcases/cpp_unittest.py @@ -1,4 +1,5 @@ """The unittest.TestCase for C++ unit tests.""" + import os from buildscripts.resmokelib import config, core, utils @@ -13,7 +14,9 @@ class CPPUnitTestCase(interface.ProcessTestCase): def __init__(self, logger, program_executable, program_options=None): """Initialize the CPPUnitTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "C++ unit test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "C++ unit test", program_executable + ) self.program_executable = program_executable self.program_options = utils.default_if_none(program_options, {}).copy() @@ -26,14 +29,15 @@ class CPPUnitTestCase(interface.ProcessTestCase): if config.UNDO_RECORDER_PATH: # Record the list of failed tests so we can upload them to the Evergreen task. # Non-recorded tests rely on the core dump content to identify the test binaries. - with open("failed_recorded_tests.txt", 'a') as failure_list: + with open("failed_recorded_tests.txt", "a") as failure_list: failure_list.write(self.program_executable) failure_list.write("\n") self.logger.exception( "*** Failed test run was recorded. ***\n" "For instructions on using the recording instead of core dumps, see\n" "https://wiki.corp.mongodb.com/display/COREENG/Time+Travel+Debugging+in+MongoDB\n" - "For questions or bug reports, please reach out in #server-testing") + "For questions or bug reports, please reach out in #server-testing" + ) # Archive any available recordings if there's any failure. It's possible a problem # with the recorder will cause no recordings to be generated. @@ -41,5 +45,6 @@ class CPPUnitTestCase(interface.ProcessTestCase): raise def _make_process(self): - return core.programs.make_process(self.logger, [self.program_executable], - **self.program_options) + return core.programs.make_process( + self.logger, [self.program_executable], **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/dbtest.py b/buildscripts/resmokelib/testing/testcases/dbtest.py index f54b22d56f5..ea1db316a6e 100644 --- a/buildscripts/resmokelib/testing/testcases/dbtest.py +++ b/buildscripts/resmokelib/testing/testcases/dbtest.py @@ -13,13 +13,17 @@ class DBTestCase(interface.ProcessTestCase): REGISTERED_NAME = "db_test" - def __init__(self, logger, dbtest_suite, dbtest_executable=None, dbtest_options=None): + def __init__( + self, logger, dbtest_suite, dbtest_executable=None, dbtest_options=None + ): """Initialize the DBTestCase with the dbtest suite to run.""" interface.ProcessTestCase.__init__(self, logger, "dbtest suite", dbtest_suite) # Command line options override the YAML configuration. - self.dbtest_executable = utils.default_if_none(config.DBTEST_EXECUTABLE, dbtest_executable) + self.dbtest_executable = utils.default_if_none( + config.DBTEST_EXECUTABLE, dbtest_executable + ) self.dbtest_suite = dbtest_suite self.dbtest_options = utils.default_if_none(dbtest_options, {}).copy() @@ -29,7 +33,9 @@ class DBTestCase(interface.ProcessTestCase): interface.ProcessTestCase.configure(self, fixture, *args, **kwargs) # If a dbpath was specified, then use it as a container for all other dbpaths. - dbpath_prefix = self.dbtest_options.pop("dbpath", DBTestCase._get_dbpath_prefix()) + dbpath_prefix = self.dbtest_options.pop( + "dbpath", DBTestCase._get_dbpath_prefix() + ) dbpath = os.path.join(dbpath_prefix, "job%d" % self.fixture.job_num, "unittest") self.dbtest_options["dbpath"] = dbpath @@ -49,8 +55,12 @@ class DBTestCase(interface.ProcessTestCase): shutil.rmtree(self.dbtest_options["dbpath"], ignore_errors=True) def _make_process(self): - return core.programs.dbtest_program(self.logger, executable=self.dbtest_executable, - suites=[self.dbtest_suite], **self.dbtest_options) + return core.programs.dbtest_program( + self.logger, + executable=self.dbtest_executable, + suites=[self.dbtest_suite], + **self.dbtest_options, + ) @staticmethod def _get_dbpath_prefix(): diff --git a/buildscripts/resmokelib/testing/testcases/fixture.py b/buildscripts/resmokelib/testing/testcases/fixture.py index ea155022294..51e5e3c800f 100644 --- a/buildscripts/resmokelib/testing/testcases/fixture.py +++ b/buildscripts/resmokelib/testing/testcases/fixture.py @@ -15,8 +15,13 @@ class FixtureTestCase(interface.TestCase): # pylint: disable=abstract-method def __init__(self, logger, job_name, phase): """Initialize the FixtureTestCase.""" - interface.TestCase.__init__(self, logger, "Fixture test", "{}_fixture_{}".format( - job_name, phase), dynamic=True) + interface.TestCase.__init__( + self, + logger, + "Fixture test", + "{}_fixture_{}".format(job_name, phase), + dynamic=True, + ) self.job_name = job_name @@ -28,8 +33,9 @@ class FixtureSetupTestCase(FixtureTestCase): def __init__(self, logger, fixture, job_name, times_set_up): """Initialize the FixtureSetupTestCase.""" - specific_phase = "{phase}_{times_set_up}".format(phase=self.PHASE, - times_set_up=times_set_up) + specific_phase = "{phase}_{times_set_up}".format( + phase=self.PHASE, times_set_up=times_set_up + ) FixtureTestCase.__init__(self, logger, job_name, specific_phase) self.fixture = fixture @@ -41,18 +47,30 @@ class FixtureSetupTestCase(FixtureTestCase): self.fixture.setup() self.logger.info("Waiting for %s to be ready.", self.fixture) self.fixture.await_ready() - if (not isinstance(self.fixture, (fixture_interface.NoOpFixture, ExternalFixture)) - # Replica set with --configsvr cannot run refresh unless it is part of a sharded cluster. - and not (isinstance(self.fixture, ReplicaSetFixture) - and "configsvr" in self.fixture.mongod_options)): - self.fixture.mongo_client().admin.command({"refreshLogicalSessionCacheNow": 1}) + if ( + not isinstance( + self.fixture, (fixture_interface.NoOpFixture, ExternalFixture) + ) + # Replica set with --configsvr cannot run refresh unless it is part of a sharded cluster. + and not ( + isinstance(self.fixture, ReplicaSetFixture) + and "configsvr" in self.fixture.mongod_options + ) + ): + self.fixture.mongo_client().admin.command( + {"refreshLogicalSessionCacheNow": 1} + ) self.logger.info("Finished the setup of %s.", self.fixture) self.return_code = 0 except errors.ServerFailure as err: - self.logger.error("An error occurred during the setup of %s: %s", self.fixture, err) + self.logger.error( + "An error occurred during the setup of %s: %s", self.fixture, err + ) raise except: - self.logger.exception("An error occurred during the setup of %s.", self.fixture) + self.logger.exception( + "An error occurred during the setup of %s.", self.fixture + ) raise @@ -76,10 +94,14 @@ class FixtureTeardownTestCase(FixtureTestCase): self.logger.info("Finished the teardown of %s.", self.fixture) self.return_code = 0 except errors.ServerFailure as err: - self.logger.error("An error occurred during the teardown of %s: %s", self.fixture, err) + self.logger.error( + "An error occurred during the teardown of %s: %s", self.fixture, err + ) raise except: - self.logger.exception("An error occurred during the teardown of %s.", self.fixture) + self.logger.exception( + "An error occurred during the teardown of %s.", self.fixture + ) raise @@ -91,8 +113,9 @@ class FixtureAbortTestCase(FixtureTestCase): def __init__(self, logger, fixture, job_name, times_set_up): """Initialize the FixtureAbortTestCase.""" - specific_phase = "{phase}_{times_set_up}".format(phase=self.PHASE, - times_set_up=times_set_up) + specific_phase = "{phase}_{times_set_up}".format( + phase=self.PHASE, times_set_up=times_set_up + ) FixtureTestCase.__init__(self, logger, job_name, specific_phase) self.fixture = fixture @@ -100,8 +123,12 @@ class FixtureAbortTestCase(FixtureTestCase): """Tear down the fixture.""" try: self.return_code = 2 # Test return code of 2 is used for fixture failures. - self.logger.info("Aborting the fixture %s due to test failure.", self.fixture) - self.fixture.teardown(finished=False, mode=fixture_interface.TeardownMode.ABORT) + self.logger.info( + "Aborting the fixture %s due to test failure.", self.fixture + ) + self.fixture.teardown( + finished=False, mode=fixture_interface.TeardownMode.ABORT + ) self.logger.info("Finished aborting %s.", self.fixture) self.return_code = 0 except errors.ServerFailure: diff --git a/buildscripts/resmokelib/testing/testcases/fsm_workload_test.py b/buildscripts/resmokelib/testing/testcases/fsm_workload_test.py index dc3bf724145..635d9c96005 100644 --- a/buildscripts/resmokelib/testing/testcases/fsm_workload_test.py +++ b/buildscripts/resmokelib/testing/testcases/fsm_workload_test.py @@ -14,13 +14,20 @@ class _SingleFSMWorkloadTestCase(jsrunnerfile.JSRunnerFileTestCase): REGISTERED_NAME = registry.LEAVE_UNREGISTERED - def __init__(self, logger, test_name, test_id, shell_executable=None, shell_options=None): + def __init__( + self, logger, test_name, test_id, shell_executable=None, shell_options=None + ): """Initialize the _SingleFSMWorkloadTestCase with the FSM workload file.""" jsrunnerfile.JSRunnerFileTestCase.__init__( - self, logger, "FSM workload", test_name, + self, + logger, + "FSM workload", + test_name, test_runner_file="jstests/concurrency/fsm_libs/resmoke_runner.js", - shell_executable=shell_executable, shell_options=shell_options) + shell_executable=shell_executable, + shell_options=shell_options, + ) self._id = test_id def configure(self, fixture, *args, **kwargs): @@ -34,10 +41,22 @@ class _FSMWorkloadTestCaseBuilder(interface.TestCaseFactory): _COUNTER_LOCK = threading.Lock() _COUNTER = 0 - def __init__(self, logger, fsm_workload_group, test_name, test_id, shell_executable=None, - shell_options=None, same_db=False, same_collection=False, db_name_prefix=None): + def __init__( + self, + logger, + fsm_workload_group, + test_name, + test_id, + shell_executable=None, + shell_options=None, + same_db=False, + same_collection=False, + db_name_prefix=None, + ): """Initialize the _FSMWorkloadTestCaseBuilder.""" - interface.TestCaseFactory.__init__(self, _SingleFSMWorkloadTestCase, shell_options) + interface.TestCaseFactory.__init__( + self, _SingleFSMWorkloadTestCase, shell_options + ) self.logger = logger self.fsm_workload_group = fsm_workload_group self.test_name = test_name @@ -88,8 +107,9 @@ class _FSMWorkloadTestCaseBuilder(interface.TestCaseFactory): return test_case._make_process() # pylint: disable=protected-access def create_test_case(self, logger, shell_options) -> _SingleFSMWorkloadTestCase: - test_case = _SingleFSMWorkloadTestCase(logger, self.test_name, self.test_id, - self.shell_executable, shell_options) + test_case = _SingleFSMWorkloadTestCase( + logger, self.test_name, self.test_id, self.shell_executable, shell_options + ) test_case.configure(self.fixture) return test_case @@ -100,18 +120,35 @@ class FSMWorkloadTestCase(jstest.MultiClientsTestCase): REGISTERED_NAME = "fsm_workload_test" TEST_KIND = "FSM workload" - def __init__(self, logger, selected_tests, shell_executable=None, shell_options=None, - same_db=False, same_collection=False, db_name_prefix=None): + def __init__( + self, + logger, + selected_tests, + shell_executable=None, + shell_options=None, + same_db=False, + same_collection=False, + db_name_prefix=None, + ): """Initialize the FSMWorkloadTestCase with the FSM workload file.""" fsm_workload_group = self.get_workload_group(selected_tests) test_name = self.get_workload_uid(selected_tests) test_id = uuid.uuid4() - factory = _FSMWorkloadTestCaseBuilder(logger, fsm_workload_group, test_name, test_id, - shell_executable, shell_options, same_db, - same_collection, db_name_prefix) - jstest.MultiClientsTestCase.__init__(self, logger, self.TEST_KIND, test_name, test_id, - factory) + factory = _FSMWorkloadTestCaseBuilder( + logger, + fsm_workload_group, + test_name, + test_id, + shell_executable, + shell_options, + same_db, + same_collection, + db_name_prefix, + ) + jstest.MultiClientsTestCase.__init__( + self, logger, self.TEST_KIND, test_name, test_id, factory + ) @staticmethod def get_workload_group(selected_tests): diff --git a/buildscripts/resmokelib/testing/testcases/gennylib_test.py b/buildscripts/resmokelib/testing/testcases/gennylib_test.py index 18598043b2e..fe8ecb60ae4 100644 --- a/buildscripts/resmokelib/testing/testcases/gennylib_test.py +++ b/buildscripts/resmokelib/testing/testcases/gennylib_test.py @@ -12,7 +12,9 @@ class GennyLibTestCase(interface.ProcessTestCase): REGISTERED_NAME = "gennylib_test" - def __init__(self, logger, test_name, program_executable=None, verbatim_arguments=None): + def __init__( + self, logger, test_name, program_executable=None, verbatim_arguments=None + ): """ Initialize the GennyLibTestCase with the executable to run. @@ -21,7 +23,9 @@ class GennyLibTestCase(interface.ProcessTestCase): and CTest can't use resmoke's version of "=" separated keyword arguments. """ - interface.ProcessTestCase.__init__(self, logger, "gennylib test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "gennylib test", program_executable + ) self.program_executable = program_executable self.program_options = {} @@ -36,6 +40,8 @@ class GennyLibTestCase(interface.ProcessTestCase): self.program_options["process_kwargs"] = process_kwargs def _make_process(self): - return core.programs.generic_program(self.logger, - [self.program_executable] + self.verbatim_arguments, - **self.program_options) + return core.programs.generic_program( + self.logger, + [self.program_executable] + self.verbatim_arguments, + **self.program_options, + ) diff --git a/buildscripts/resmokelib/testing/testcases/gennytest.py b/buildscripts/resmokelib/testing/testcases/gennytest.py index 4e7634482a4..70970c66d96 100644 --- a/buildscripts/resmokelib/testing/testcases/gennytest.py +++ b/buildscripts/resmokelib/testing/testcases/gennytest.py @@ -12,11 +12,17 @@ class GennyTestCase(interface.ProcessTestCase): REGISTERED_NAME = "genny_test" - def __init__(self, logger, genny_workload, genny_executable=None, genny_options=None): + def __init__( + self, logger, genny_workload, genny_executable=None, genny_options=None + ): """Init the GennyTestCase with the genny workload to run.""" - interface.ProcessTestCase.__init__(self, logger, "Genny workload", genny_workload) + interface.ProcessTestCase.__init__( + self, logger, "Genny workload", genny_workload + ) - self.genny_executable = utils.default_if_none(config.GENNY_EXECUTABLE, genny_executable) + self.genny_executable = utils.default_if_none( + config.GENNY_EXECUTABLE, genny_executable + ) self.genny_options = utils.default_if_none(genny_options, {}).copy() self.genny_options["workload-file"] = genny_workload @@ -37,4 +43,6 @@ class GennyTestCase(interface.ProcessTestCase): self.genny_options.setdefault("metrics-output-file", output_file) def _make_process(self): - return core.programs.genny_program(self.logger, self.genny_executable, **self.genny_options) + return core.programs.genny_program( + self.logger, self.genny_executable, **self.genny_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/interface.py b/buildscripts/resmokelib/testing/testcases/interface.py index 61aaabe4ec6..c279281e174 100644 --- a/buildscripts/resmokelib/testing/testcases/interface.py +++ b/buildscripts/resmokelib/testing/testcases/interface.py @@ -2,6 +2,7 @@ This is used to perform the actual test case. """ + import glob import os import os.path @@ -15,14 +16,16 @@ from buildscripts.resmokelib.utils import registry _TEST_CASES: Dict[str, Callable] = {} # type: ignore -def make_test_case(test_kind, *args, **kwargs) -> 'TestCase': +def make_test_case(test_kind, *args, **kwargs) -> "TestCase": """Provide factory function for creating TestCase instances.""" if test_kind not in _TEST_CASES: raise ValueError("Unknown test kind '%s'" % test_kind) return _TEST_CASES[test_kind](*args, **kwargs) -class TestCase(unittest.TestCase, metaclass=registry.make_registry_metaclass(_TEST_CASES)): # pylint: disable=invalid-metaclass +class TestCase( + unittest.TestCase, metaclass=registry.make_registry_metaclass(_TEST_CASES) +): # pylint: disable=invalid-metaclass """A test case to execute.""" REGISTERED_NAME = registry.LEAVE_UNREGISTERED @@ -146,7 +149,7 @@ class UndoDBUtilsMixin: # But that should be rare and there's no harm in having more recordings stored. for recording in glob.glob(program_executable + "*.undo"): self.logger.info("Keeping recording %s", recording) - os.rename(recording, recording + '.tokeep') + os.rename(recording, recording + ".tokeep") class ProcessTestCase(TestCase, UndoDBUtilsMixin): @@ -160,8 +163,9 @@ class ProcessTestCase(TestCase, UndoDBUtilsMixin): except self.failureException: raise except: - self.logger.exception("Encountered an error running %s %s", self.test_kind, - self.basename()) + self.logger.exception( + "Encountered an error running %s %s", self.test_kind, self.basename() + ) raise def as_command(self): @@ -170,10 +174,14 @@ class ProcessTestCase(TestCase, UndoDBUtilsMixin): def _execute(self, process): """Run the specified process.""" - self.logger.info("Starting %s...\n%s", self.short_description(), process.as_command()) + self.logger.info( + "Starting %s...\n%s", self.short_description(), process.as_command() + ) process.start() - self.logger.info("%s started with pid %s.", self.short_description(), process.pid) + self.logger.info( + "%s started with pid %s.", self.short_description(), process.pid + ) self.return_code = process.wait() if self.return_code != 0: @@ -183,36 +191,48 @@ class ProcessTestCase(TestCase, UndoDBUtilsMixin): def _make_process(self): """Return a new Process instance that could be used to run the test or log the command.""" - raise NotImplementedError("_make_process must be implemented by TestCase subclasses") + raise NotImplementedError( + "_make_process must be implemented by TestCase subclasses" + ) class TestCaseFactory: def __init__(self, factory_class, shell_options): if not issubclass(factory_class, TestCase): - raise TypeError("factory_class should be a subclass of Interface.TestCase", - factory_class) + raise TypeError( + "factory_class should be a subclass of Interface.TestCase", + factory_class, + ) self._factory_class = factory_class self.shell_options = shell_options def create_test_case(self, logger, shell_options) -> TestCase: raise NotImplementedError( - "create_test_case must be implemented by TestCaseFactory subclasses") + "create_test_case must be implemented by TestCaseFactory subclasses" + ) - def create_test_case_for_thread(self, logger, num_clients=1, thread_id=0, - tenant_id=None) -> TestCase: + def create_test_case_for_thread( + self, logger, num_clients=1, thread_id=0, tenant_id=None + ) -> TestCase: """Create and configure a TestCase to be run in a separate thread.""" - shell_options = self._get_shell_options_for_thread(num_clients, thread_id, tenant_id) + shell_options = self._get_shell_options_for_thread( + num_clients, thread_id, tenant_id + ) test_case = self.create_test_case(logger, shell_options) return test_case def configure(self, fixture, *args, **kwargs): """Configure the test case factory.""" - raise NotImplementedError("configure must be implemented by TestCaseFactory subclasses") + raise NotImplementedError( + "configure must be implemented by TestCaseFactory subclasses" + ) def make_process(self): """Make a process for a TestCase.""" - raise NotImplementedError("make_process must be implemented by TestCaseFactory subclasses") + raise NotImplementedError( + "make_process must be implemented by TestCaseFactory subclasses" + ) def _get_shell_options_for_thread(self, num_clients, thread_id, tenant_id): """Get shell_options with an initialized TestData object for given thread.""" diff --git a/buildscripts/resmokelib/testing/testcases/json_schema_test.py b/buildscripts/resmokelib/testing/testcases/json_schema_test.py index 426250e339f..c1dd37c6305 100644 --- a/buildscripts/resmokelib/testing/testcases/json_schema_test.py +++ b/buildscripts/resmokelib/testing/testcases/json_schema_test.py @@ -8,13 +8,20 @@ class JSONSchemaTestCase(jsrunnerfile.JSRunnerFileTestCase): REGISTERED_NAME = "json_schema_test" - def __init__(self, logger, json_filename, shell_executable=None, shell_options=None): + def __init__( + self, logger, json_filename, shell_executable=None, shell_options=None + ): """Initialize the JSONSchemaTestCase with the JSON test file.""" jsrunnerfile.JSRunnerFileTestCase.__init__( - self, logger, "JSON Schema test", json_filename, + self, + logger, + "JSON Schema test", + json_filename, test_runner_file="jstests/libs/json_schema_test_runner.js", - shell_executable=shell_executable, shell_options=shell_options) + shell_executable=shell_executable, + shell_options=shell_options, + ) @property def json_filename(self): diff --git a/buildscripts/resmokelib/testing/testcases/jsrunnerfile.py b/buildscripts/resmokelib/testing/testcases/jsrunnerfile.py index bd51b700836..346a718039b 100644 --- a/buildscripts/resmokelib/testing/testcases/jsrunnerfile.py +++ b/buildscripts/resmokelib/testing/testcases/jsrunnerfile.py @@ -10,14 +10,23 @@ class JSRunnerFileTestCase(interface.ProcessTestCase): REGISTERED_NAME = registry.LEAVE_UNREGISTERED - def __init__(self, logger, test_kind, test_name, test_runner_file, shell_executable=None, - shell_options=None): + def __init__( + self, + logger, + test_kind, + test_name, + test_runner_file, + shell_executable=None, + shell_options=None, + ): """Initialize the JSRunnerFileTestCase with the 'test_name' file.""" interface.ProcessTestCase.__init__(self, logger, test_kind, test_name) # Command line options override the YAML configuration. - self.shell_executable = utils.default_if_none(config.MONGO_EXECUTABLE, shell_executable) + self.shell_executable = utils.default_if_none( + config.MONGO_EXECUTABLE, shell_executable + ) self.shell_options = utils.default_if_none(shell_options, {}).copy() self.test_runner_file = test_runner_file @@ -44,6 +53,10 @@ class JSRunnerFileTestCase(interface.ProcessTestCase): def _make_process(self): return core.programs.mongo_shell_program( - self.logger, executable=self.shell_executable, + self.logger, + executable=self.shell_executable, connection_string=self.fixture.get_shell_connection_url(), - filename=self.test_runner_file, test_filename=self.test_name, **self.shell_options) + filename=self.test_runner_file, + test_filename=self.test_name, + **self.shell_options, + ) diff --git a/buildscripts/resmokelib/testing/testcases/jstest.py b/buildscripts/resmokelib/testing/testcases/jstest.py index 2b08c82c322..777c989c6de 100644 --- a/buildscripts/resmokelib/testing/testcases/jstest.py +++ b/buildscripts/resmokelib/testing/testcases/jstest.py @@ -20,12 +20,16 @@ class _SingleJSTestCase(interface.ProcessTestCase): REGISTERED_NAME = registry.LEAVE_UNREGISTERED - def __init__(self, logger, js_filename, _id, shell_executable=None, shell_options=None): + def __init__( + self, logger, js_filename, _id, shell_executable=None, shell_options=None + ): """Initialize the _SingleJSTestCase with the JS file to run.""" interface.ProcessTestCase.__init__(self, logger, "JSTest", js_filename) # Command line options override the YAML configuration. - self.shell_executable = utils.default_if_none(config.MONGO_EXECUTABLE, shell_executable) + self.shell_executable = utils.default_if_none( + config.MONGO_EXECUTABLE, shell_executable + ) self.js_filename = js_filename self._id = _id @@ -46,18 +50,27 @@ class _SingleJSTestCase(interface.ProcessTestCase): data_dir = self._get_data_dir(global_vars) # Set MongoRunner.dataPath if overridden at command line or not specified in YAML. - if config.DBPATH_PREFIX is not None or "MongoRunner.dataPath" not in global_vars: + if ( + config.DBPATH_PREFIX is not None + or "MongoRunner.dataPath" not in global_vars + ): # dataPath property is the dataDir property with a trailing slash. data_path = os.path.join(data_dir, "") else: - data_path = os.path.join(os.path.abspath(global_vars["MongoRunner.dataPath"]), "") + data_path = os.path.join( + os.path.abspath(global_vars["MongoRunner.dataPath"]), "" + ) global_vars["MongoRunner.dataDir"] = data_dir global_vars["MongoRunner.dataPath"] = data_path test_data = global_vars.get("TestData", {}).copy() - test_data["minPort"] = core.network.PortAllocator.min_test_port(self.fixture.job_num) - test_data["maxPort"] = core.network.PortAllocator.max_test_port(self.fixture.job_num) + test_data["minPort"] = core.network.PortAllocator.min_test_port( + self.fixture.job_num + ) + test_data["maxPort"] = core.network.PortAllocator.max_test_port( + self.fixture.job_num + ) test_data["peerPids"] = self.fixture.pids() test_data["alwaysUseLogFiles"] = config.ALWAYS_USE_LOG_FILES test_data["ignoreUnterminatedProcesses"] = False @@ -65,8 +78,9 @@ class _SingleJSTestCase(interface.ProcessTestCase): # The tests in 'timeseries' directory need to use a different logic for implicity sharding # the collection. Make sure that we consider both unix and windows directory structures. - test_data["implicitlyShardOnCreateCollectionOnly"] = "/timeseries/" in self.js_filename or \ - "\\timeseries\\" in self.js_filename + test_data["implicitlyShardOnCreateCollectionOnly"] = ( + "/timeseries/" in self.js_filename or "\\timeseries\\" in self.js_filename + ) global_vars["TestData"] = test_data self.shell_options["global_vars"] = global_vars @@ -81,10 +95,12 @@ class _SingleJSTestCase(interface.ProcessTestCase): process_kwargs = copy.deepcopy(self.shell_options.get("process_kwargs", {})) - if process_kwargs \ - and "env_vars" in process_kwargs \ - and "KRB5_CONFIG" in process_kwargs["env_vars"] \ - and "KRB5CCNAME" not in process_kwargs["env_vars"]: + if ( + process_kwargs + and "env_vars" in process_kwargs + and "KRB5_CONFIG" in process_kwargs["env_vars"] + and "KRB5CCNAME" not in process_kwargs["env_vars"] + ): # Use a job-specific credential cache for JavaScript tests involving Kerberos. krb5_dir = os.path.join(data_dir, "krb5") @@ -100,26 +116,40 @@ class _SingleJSTestCase(interface.ProcessTestCase): def _get_data_dir(self, global_vars): """Return the value that mongo shell should set for the MongoRunner.dataDir property.""" # Command line options override the YAML configuration. - data_dir_prefix = utils.default_if_none(config.DBPATH_PREFIX, - global_vars.get("MongoRunner.dataDir")) - data_dir_prefix = utils.default_if_none(data_dir_prefix, config.DEFAULT_DBPATH_PREFIX) + data_dir_prefix = utils.default_if_none( + config.DBPATH_PREFIX, global_vars.get("MongoRunner.dataDir") + ) + data_dir_prefix = utils.default_if_none( + data_dir_prefix, config.DEFAULT_DBPATH_PREFIX + ) return os.path.abspath( - os.path.join(data_dir_prefix, "job%d" % self.fixture.job_num, - config.MONGO_RUNNER_SUBDIR)) + os.path.join( + data_dir_prefix, + "job%d" % self.fixture.job_num, + config.MONGO_RUNNER_SUBDIR, + ) + ) def _make_process(self): return core.programs.mongo_shell_program( - self.logger, executable=self.shell_executable, filename=self.js_filename, - connection_string=self.fixture.get_shell_connection_url(), **self.shell_options) + self.logger, + executable=self.shell_executable, + filename=self.js_filename, + connection_string=self.fixture.get_shell_connection_url(), + **self.shell_options, + ) class JSTestCaseBuilder(interface.TestCaseFactory): """Build the real TestCase in the JSTestCase wrapper.""" - def __init__(self, logger, js_filename, test_id, shell_executable=None, shell_options=None): + def __init__( + self, logger, js_filename, test_id, shell_executable=None, shell_options=None + ): """Initialize the JSTestCase with the JS file to run.""" - self.test_case_template = _SingleJSTestCase(logger, js_filename, test_id, shell_executable, - shell_options) + self.test_case_template = _SingleJSTestCase( + logger, js_filename, test_id, shell_executable, shell_options + ) interface.TestCaseFactory.__init__(self, _SingleJSTestCase, shell_options) def configure(self, fixture, *args, **kwargs): @@ -133,9 +163,13 @@ class JSTestCaseBuilder(interface.TestCaseFactory): return self.test_case_template._make_process() # pylint: disable=protected-access def create_test_case(self, logger, shell_options): - test_case = _SingleJSTestCase(logger, self.test_case_template.js_filename, - self.test_case_template._id, - self.test_case_template.shell_executable, shell_options) + test_case = _SingleJSTestCase( + logger, + self.test_case_template.js_filename, + self.test_case_template._id, + self.test_case_template.shell_executable, + shell_options, + ) test_case.configure(self.test_case_template.fixture) return test_case @@ -170,8 +204,13 @@ class MultiClientsTestCase(interface.TestCase, interface.UndoDBUtilsMixin): self._factory = factory def configure( # pylint: disable=arguments-differ,keyword-arg-before-vararg - self, fixture, num_clients=DEFAULT_CLIENT_NUM, use_tenant_client=False, *args, - **kwargs): + self, + fixture, + num_clients=DEFAULT_CLIENT_NUM, + use_tenant_client=False, + *args, + **kwargs, + ): """Configure the test case and its factory.""" super().configure(fixture, *args, **kwargs) self.num_clients = num_clients @@ -184,8 +223,9 @@ class MultiClientsTestCase(interface.TestCase, interface.UndoDBUtilsMixin): def _run_single_copy(self): tenant_id = str(ObjectId()) if self.use_tenant_client else None - test_case = self._factory.create_test_case_for_thread(self.logger, num_clients=1, - thread_id=0, tenant_id=tenant_id) + test_case = self._factory.create_test_case_for_thread( + self.logger, num_clients=1, thread_id=0, tenant_id=tenant_id + ) try: test_case.run_test() @@ -201,18 +241,24 @@ class MultiClientsTestCase(interface.TestCase, interface.UndoDBUtilsMixin): # If there are multiple clients, make a new thread for each client. for thread_id in range(self.num_clients): tenant_id = str(ObjectId()) if self.use_tenant_client else None - logger = logging.loggers.new_test_thread_logger(self.logger, self.test_kind, - str(thread_id), tenant_id) + logger = logging.loggers.new_test_thread_logger( + self.logger, self.test_kind, str(thread_id), tenant_id + ) test_case = self._factory.create_test_case_for_thread( - logger, num_clients=self.num_clients, thread_id=thread_id, tenant_id=tenant_id) + logger, + num_clients=self.num_clients, + thread_id=thread_id, + tenant_id=tenant_id, + ) test_cases.append(test_case) thread = self.ThreadWithException(target=test_case.run_test) threads.append(thread) thread.start() except: - self.logger.exception("Encountered an error starting threads for jstest %s.", - self.basename()) + self.logger.exception( + "Encountered an error starting threads for jstest %s.", self.basename() + ) raise finally: for thread in threads: @@ -226,12 +272,15 @@ class MultiClientsTestCase(interface.TestCase, interface.UndoDBUtilsMixin): return_code = test_case.return_code self.return_code = return_code - for (thread_id, thread) in enumerate(threads): + for thread_id, thread in enumerate(threads): if thread.exc_info is not None: if not isinstance(thread.exc_info[1], self.failureException): self.logger.error( - "Encountered an error inside thread %d running jstest %s.", thread_id, - self.basename(), exc_info=thread.exc_info) + "Encountered an error inside thread %d running jstest %s.", + thread_id, + self.basename(), + exc_info=thread.exc_info, + ) raise thread.exc_info[1] def run_test(self): @@ -255,7 +304,8 @@ class MultiClientsTestCase(interface.TestCase, interface.UndoDBUtilsMixin): if return_code not in (252, 253, 0): self.propagate_error = errors.UnsafeExitError( f"Mongo shell exited with code {return_code} while running jstest {self.basename()}." - " Further test execution may be unsafe.") + " Further test execution may be unsafe." + ) raise self.propagate_error # pylint: disable=raising-bad-type @@ -269,8 +319,12 @@ class JSTestCase(MultiClientsTestCase): """Initialize the TestCase for running JS files.""" test_id = uuid.uuid4() - factory = JSTestCaseBuilder(logger, js_filename, test_id, shell_executable, shell_options) - MultiClientsTestCase.__init__(self, logger, self.TEST_KIND, js_filename, test_id, factory) + factory = JSTestCaseBuilder( + logger, js_filename, test_id, shell_executable, shell_options + ) + MultiClientsTestCase.__init__( + self, logger, self.TEST_KIND, js_filename, test_id, factory + ) class AllVersionsJSTestCase(JSTestCase): diff --git a/buildscripts/resmokelib/testing/testcases/mongos_test.py b/buildscripts/resmokelib/testing/testcases/mongos_test.py index d0b3fdd74b1..a353049b7e1 100644 --- a/buildscripts/resmokelib/testing/testcases/mongos_test.py +++ b/buildscripts/resmokelib/testing/testcases/mongos_test.py @@ -12,10 +12,13 @@ class MongosTestCase(interface.ProcessTestCase): def __init__(self, logger, mongos_options): """Initialize the mongos test and saves the options.""" - self.mongos_executable = utils.default_if_none(config.MONGOS_EXECUTABLE, - config.DEFAULT_MONGOS_EXECUTABLE) + self.mongos_executable = utils.default_if_none( + config.MONGOS_EXECUTABLE, config.DEFAULT_MONGOS_EXECUTABLE + ) # Use the executable as the test name. - interface.ProcessTestCase.__init__(self, logger, "mongos test", self.mongos_executable) + interface.ProcessTestCase.__init__( + self, logger, "mongos test", self.mongos_executable + ) self.options = mongos_options.copy() def configure(self, fixture, *args, **kwargs): @@ -27,5 +30,9 @@ class MongosTestCase(interface.ProcessTestCase): self.options["test"] = "" def _make_process(self): - return core.programs.mongos_program(self.logger, self.fixture.job_num, - executable=self.mongos_executable, **self.options) + return core.programs.mongos_program( + self.logger, + self.fixture.job_num, + executable=self.mongos_executable, + **self.options, + ) diff --git a/buildscripts/resmokelib/testing/testcases/mql_model_haskell_test.py b/buildscripts/resmokelib/testing/testcases/mql_model_haskell_test.py index c3edfeba065..f3b25e6470e 100644 --- a/buildscripts/resmokelib/testing/testcases/mql_model_haskell_test.py +++ b/buildscripts/resmokelib/testing/testcases/mql_model_haskell_test.py @@ -16,25 +16,39 @@ class MqlModelHaskellTestCase(interface.ProcessTestCase): def __init__(self, logger, json_filename, mql_executable=None): """Initialize the MqlModelHaskellTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "MQL Haskell Model test", json_filename) + interface.ProcessTestCase.__init__( + self, logger, "MQL Haskell Model test", json_filename + ) self.json_test_file = json_filename # Determine the top level directory where we start a search for a mql binary - self.top_level_dirname = os.path.join(os.path.normpath(json_filename).split(os.sep)[0], "") + self.top_level_dirname = os.path.join( + os.path.normpath(json_filename).split(os.sep)[0], "" + ) # Our haskell cabal build produces binaries in an unique directory # .../dist-sandbox-/... # so we use a glob pattern to fish out the binary - mql_executable = utils.default_if_none(mql_executable, "mql-model/dist/dist*/build/mql/mql") + mql_executable = utils.default_if_none( + mql_executable, "mql-model/dist/dist*/build/mql/mql" + ) execs = globstar.glob(mql_executable) if len(execs) != 1: - raise errors.StopExecution("There must be a single mql binary in {}".format(execs)) + raise errors.StopExecution( + "There must be a single mql binary in {}".format(execs) + ) self.program_executable = execs[0] def _make_process(self): - return core.programs.make_process(self.logger, [ - self.program_executable, "--test", self.json_test_file, "--prefix", - self.top_level_dirname - ]) + return core.programs.make_process( + self.logger, + [ + self.program_executable, + "--test", + self.json_test_file, + "--prefix", + self.top_level_dirname, + ], + ) diff --git a/buildscripts/resmokelib/testing/testcases/mql_model_mongod_test.py b/buildscripts/resmokelib/testing/testcases/mql_model_mongod_test.py index 2c6332dd344..b9ace7a7600 100644 --- a/buildscripts/resmokelib/testing/testcases/mql_model_mongod_test.py +++ b/buildscripts/resmokelib/testing/testcases/mql_model_mongod_test.py @@ -11,13 +11,20 @@ class MqlModelMongodTestCase(jsrunnerfile.JSRunnerFileTestCase): REGISTERED_NAME = "mql_model_mongod_test" - def __init__(self, logger, json_filename, shell_executable=None, shell_options=None): + def __init__( + self, logger, json_filename, shell_executable=None, shell_options=None + ): """Initialize the MqlModelMongodTestCase with the JSON test file.""" jsrunnerfile.JSRunnerFileTestCase.__init__( - self, logger, "MQL MongoD Model test", json_filename, + self, + logger, + "MQL MongoD Model test", + json_filename, test_runner_file="jstests/libs/mql_model_mongod_test_runner.js", - shell_executable=shell_executable, shell_options=shell_options) + shell_executable=shell_executable, + shell_options=shell_options, + ) @property def json_filename(self): diff --git a/buildscripts/resmokelib/testing/testcases/multi_stmt_txn_test.py b/buildscripts/resmokelib/testing/testcases/multi_stmt_txn_test.py index c922fbe34a9..141a7d0d6a7 100644 --- a/buildscripts/resmokelib/testing/testcases/multi_stmt_txn_test.py +++ b/buildscripts/resmokelib/testing/testcases/multi_stmt_txn_test.py @@ -8,12 +8,23 @@ class MultiStmtTxnTestCase(jsrunnerfile.JSRunnerFileTestCase): REGISTERED_NAME = "multi_stmt_txn_passthrough" - def __init__(self, logger, multi_stmt_txn_test_file, shell_executable=None, shell_options=None): + def __init__( + self, + logger, + multi_stmt_txn_test_file, + shell_executable=None, + shell_options=None, + ): """Initilize MultiStmtTxnTestCase.""" jsrunnerfile.JSRunnerFileTestCase.__init__( - self, logger, "Multi-statement Transaction Passthrough", multi_stmt_txn_test_file, + self, + logger, + "Multi-statement Transaction Passthrough", + multi_stmt_txn_test_file, test_runner_file="jstests/libs/txns/txn_passthrough_runner.js", - shell_executable=shell_executable, shell_options=shell_options) + shell_executable=shell_executable, + shell_options=shell_options, + ) @property def multi_stmt_txn_test_file(self): @@ -23,5 +34,6 @@ class MultiStmtTxnTestCase(jsrunnerfile.JSRunnerFileTestCase): def _populate_test_data(self, test_data): test_data["multiStmtTxnTestFile"] = self.multi_stmt_txn_test_file test_data["peerPids"] = self.fixture.pids() - test_data["implicitlyShardOnCreateCollectionOnly"] = "/timeseries/" in self.test_name or \ - "\\timeseries\\" in self.test_name + test_data["implicitlyShardOnCreateCollectionOnly"] = ( + "/timeseries/" in self.test_name or "\\timeseries\\" in self.test_name + ) diff --git a/buildscripts/resmokelib/testing/testcases/pretty_printer_testcase.py b/buildscripts/resmokelib/testing/testcases/pretty_printer_testcase.py index c133dd17f31..f8ba9f17fa1 100644 --- a/buildscripts/resmokelib/testing/testcases/pretty_printer_testcase.py +++ b/buildscripts/resmokelib/testing/testcases/pretty_printer_testcase.py @@ -12,11 +12,14 @@ class PrettyPrinterTestCase(interface.ProcessTestCase): def __init__(self, logger, program_executable, program_options=None): """Initialize the PrettyPrinterTestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "pretty printer test", program_executable) + interface.ProcessTestCase.__init__( + self, logger, "pretty printer test", program_executable + ) self.program_executable = program_executable self.program_options = utils.default_if_none(program_options, {}).copy() def _make_process(self): - return core.programs.make_process(self.logger, [self.program_executable], - **self.program_options) + return core.programs.make_process( + self.logger, [self.program_executable], **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/pytest.py b/buildscripts/resmokelib/testing/testcases/pytest.py index 46cd7d5afa8..dc8c3e91781 100644 --- a/buildscripts/resmokelib/testing/testcases/pytest.py +++ b/buildscripts/resmokelib/testing/testcases/pytest.py @@ -1,4 +1,5 @@ """The unittest.TestCase for Python unittests.""" + import os import sys @@ -17,7 +18,8 @@ class PyTestCase(interface.ProcessTestCase): def _make_process(self): return core.programs.generic_program( - self.logger, [sys.executable, "-m", "unittest", self.test_module_name]) + self.logger, [sys.executable, "-m", "unittest", self.test_module_name] + ) @property def test_module_name(self): diff --git a/buildscripts/resmokelib/testing/testcases/sdam_json_test.py b/buildscripts/resmokelib/testing/testcases/sdam_json_test.py index 516fd8dac1a..74251196a7c 100644 --- a/buildscripts/resmokelib/testing/testcases/sdam_json_test.py +++ b/buildscripts/resmokelib/testing/testcases/sdam_json_test.py @@ -1,4 +1,5 @@ """The unittest.TestCase for Server Discovery and Monitoring JSON tests.""" + import os import os.path @@ -14,7 +15,9 @@ class SDAMJsonTestCase(interface.ProcessTestCase): def __init__(self, logger, json_test_file, program_options=None): """Initialize the TestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "SDAM Json Test", json_test_file) + interface.ProcessTestCase.__init__( + self, logger, "SDAM Json Test", json_test_file + ) self.program_executable = self._find_executable() self.json_test_file = os.path.normpath(json_test_file) @@ -26,11 +29,15 @@ class SDAMJsonTestCase(interface.ProcessTestCase): binary += ".exe" if not os.path.isfile(binary): - raise errors.StopExecution(f"Failed to locate sdam_json_test binary at {binary}") + raise errors.StopExecution( + f"Failed to locate sdam_json_test binary at {binary}" + ) return binary def _make_process(self): command_line = [self.program_executable] command_line += ["--source-dir", self.TEST_DIR] command_line += ["-f", self.json_test_file] - return core.programs.make_process(self.logger, command_line, **self.program_options) + return core.programs.make_process( + self.logger, command_line, **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/server_selection_json_test.py b/buildscripts/resmokelib/testing/testcases/server_selection_json_test.py index c8fe6f1dece..8ad804670f3 100644 --- a/buildscripts/resmokelib/testing/testcases/server_selection_json_test.py +++ b/buildscripts/resmokelib/testing/testcases/server_selection_json_test.py @@ -1,4 +1,5 @@ """The unittest.TestCase for Server Selection JSON tests.""" + import os import os.path @@ -10,12 +11,15 @@ class ServerSelectionJsonTestCase(interface.ProcessTestCase): """Server Selection JSON test case.""" REGISTERED_NAME = "server_selection_json_test" - TEST_DIR = os.path.normpath("src/mongo/client/sdam/json_tests/server_selection_tests") + TEST_DIR = os.path.normpath( + "src/mongo/client/sdam/json_tests/server_selection_tests" + ) def __init__(self, logger, json_test_file, program_options=None): """Initialize the TestCase with the executable to run.""" - interface.ProcessTestCase.__init__(self, logger, "Server Selection Json Test", - json_test_file) + interface.ProcessTestCase.__init__( + self, logger, "Server Selection Json Test", json_test_file + ) self.program_executable = self._find_executable() self.json_test_file = os.path.normpath(json_test_file) @@ -28,11 +32,14 @@ class ServerSelectionJsonTestCase(interface.ProcessTestCase): if not os.path.isfile(binary): raise errors.StopExecution( - f"Failed to locate server_selection_json_test binary at {binary}") + f"Failed to locate server_selection_json_test binary at {binary}" + ) return binary def _make_process(self): command_line = [self.program_executable] command_line += ["--source-dir", self.TEST_DIR] command_line += ["-f", self.json_test_file] - return core.programs.make_process(self.logger, command_line, **self.program_options) + return core.programs.make_process( + self.logger, command_line, **self.program_options + ) diff --git a/buildscripts/resmokelib/testing/testcases/sleeptest.py b/buildscripts/resmokelib/testing/testcases/sleeptest.py index e62fa9698ed..f303965e255 100644 --- a/buildscripts/resmokelib/testing/testcases/sleeptest.py +++ b/buildscripts/resmokelib/testing/testcases/sleeptest.py @@ -15,8 +15,9 @@ class SleepTestCase(interface.TestCase): sleep_duration_secs = int(sleep_duration_secs) - interface.TestCase.__init__(self, logger, "Sleep", - "{:d} seconds".format(sleep_duration_secs)) + interface.TestCase.__init__( + self, logger, "Sleep", "{:d} seconds".format(sleep_duration_secs) + ) self.__sleep_duration_secs = sleep_duration_secs diff --git a/buildscripts/resmokelib/testing/testcases/tla_plus_test.py b/buildscripts/resmokelib/testing/testcases/tla_plus_test.py index b3013f5e4df..45bc5a56fda 100644 --- a/buildscripts/resmokelib/testing/testcases/tla_plus_test.py +++ b/buildscripts/resmokelib/testing/testcases/tla_plus_test.py @@ -19,8 +19,10 @@ class TLAPlusTestCase(interface.ProcessTestCase): java_binary is the full path to the "java" program, or None. """ - message = f"Path '{model_config_file}' doesn't" \ - f" match **//MC.cfg" + message = ( + f"Path '{model_config_file}' doesn't" + f" match **//MC.cfg" + ) # spec_dir should be like src/mongo/tla_plus/MongoReplReconfig. spec_dir, filename = os.path.split(model_config_file) @@ -29,7 +31,7 @@ class TLAPlusTestCase(interface.ProcessTestCase): # working_dir is like src/mongo/tla_plus. self.working_dir, specname = os.path.split(spec_dir) - if not specname or filename != f'MC{specname}.cfg': + if not specname or filename != f"MC{specname}.cfg": raise ValueError(message) self.java_binary = java_binary @@ -41,5 +43,8 @@ class TLAPlusTestCase(interface.ProcessTestCase): if self.java_binary is not None: process_kwargs["env_vars"] = {"JAVA_BINARY": self.java_binary} - return core.programs.generic_program(self.logger, ["sh", "model-check.sh", self.test_name], - process_kwargs=process_kwargs) + return core.programs.generic_program( + self.logger, + ["sh", "model-check.sh", self.test_name], + process_kwargs=process_kwargs, + ) diff --git a/buildscripts/resmokelib/undodb/__init__.py b/buildscripts/resmokelib/undodb/__init__.py index c65ad58acc3..c59c5aede9a 100644 --- a/buildscripts/resmokelib/undodb/__init__.py +++ b/buildscripts/resmokelib/undodb/__init__.py @@ -45,8 +45,13 @@ class UndoDbPlugin(PluginInterface): """ parser = subparsers.add_parser(_COMMAND, help=_HELP) # Accept arbitrary args like 'resmoke.py undodb foobar', but ignore them. - parser.add_argument("--fetch", '-f', action="store", type=str, - help="Fetch UndoDB recordings archive with the given Evergreen task ID") + parser.add_argument( + "--fetch", + "-f", + action="store", + type=str, + help="Fetch UndoDB recordings archive with the given Evergreen task ID", + ) parser.add_argument("args", nargs="*") def parse(self, subcommand, parser, parsed_args, **kwargs): diff --git a/buildscripts/resmokelib/undodb/fetch.py b/buildscripts/resmokelib/undodb/fetch.py index dc8f9af2f9f..979369b6b17 100644 --- a/buildscripts/resmokelib/undodb/fetch.py +++ b/buildscripts/resmokelib/undodb/fetch.py @@ -1,4 +1,5 @@ """Subcommand for fetching UndoDB recordings from Evergreen.""" + import os import tarfile import tempfile @@ -35,7 +36,9 @@ class Fetch(Subcommand): :return: None """ if self._ticket: - raise NotImplementedError("Fetching recordings from JIRA tickets not yet implemented") + raise NotImplementedError( + "Fetching recordings from JIRA tickets not yet implemented" + ) assert self._task_id @@ -57,7 +60,9 @@ class Fetch(Subcommand): _cleanup(local_file) -def _find_undodb_artifact_url(artifacts: List[evergreen.task.Artifact]) -> Optional[str]: +def _find_undodb_artifact_url( + artifacts: List[evergreen.task.Artifact], +) -> Optional[str]: for artifact in artifacts: if artifact.name.startswith("UndoDB Recordings - Execution "): return artifact.url diff --git a/buildscripts/resmokelib/utils/__init__.py b/buildscripts/resmokelib/utils/__init__.py index 15ccfcb2528..20e449db63a 100644 --- a/buildscripts/resmokelib/utils/__init__.py +++ b/buildscripts/resmokelib/utils/__init__.py @@ -80,7 +80,9 @@ def load_yaml(value): try: return yaml.safe_load(value) except yaml.YAMLError as err: - raise ValueError("Attempted to parse invalid YAML value '%s': %s" % (value, err)) + raise ValueError( + "Attempted to parse invalid YAML value '%s': %s" % (value, err) + ) def get_task_name_without_suffix(task_name, variant_name): @@ -90,7 +92,7 @@ def get_task_name_without_suffix(task_name, variant_name): Example: "noPassthrough_0_enterprise-rhel-8-64-bit-dynamic-required" -> "noPassthrough" """ task_name = task_name if task_name else "" - return re.sub(fr"(_[0-9]+)?(_{variant_name})?$", "", task_name) + return re.sub(rf"(_[0-9]+)?(_{variant_name})?$", "", task_name) def pick_catalog_shard_node(config_shard, num_shards): @@ -107,6 +109,6 @@ def pick_catalog_shard_node(config_shard, num_shards): config_shard_index = int(config_shard) if config_shard_index < 0 or config_shard_index >= num_shards: - raise ValueError("Config shard value must be in range 0..num_shards-1 or \"any\"") + raise ValueError('Config shard value must be in range 0..num_shards-1 or "any"') return config_shard_index diff --git a/buildscripts/resmokelib/utils/archival.py b/buildscripts/resmokelib/utils/archival.py index c851bad8afe..c2034d10c4a 100644 --- a/buildscripts/resmokelib/utils/archival.py +++ b/buildscripts/resmokelib/utils/archival.py @@ -18,13 +18,22 @@ _IS_WINDOWS = sys.platform in ("win32", "cygwin") if _IS_WINDOWS: import ctypes -UploadArgs = collections.namedtuple("UploadArgs", [ - "archival_file", "display_name", "local_file", "content_type", "s3_bucket", "s3_path", - "delete_file" -]) +UploadArgs = collections.namedtuple( + "UploadArgs", + [ + "archival_file", + "display_name", + "local_file", + "content_type", + "s3_bucket", + "s3_path", + "delete_file", + ], +) -ArchiveArgs = collections.namedtuple("ArchiveArgs", - ["archival_file", "display_name", "remote_file"]) +ArchiveArgs = collections.namedtuple( + "ArchiveArgs", ["archival_file", "display_name", "remote_file"] +) def file_list_size(files): @@ -62,7 +71,8 @@ def free_space(path): dirname = os.path.dirname(path) free_bytes = ctypes.c_ulonglong(0) ctypes.windll.kernel32.GetDiskFreeSpaceExW( - ctypes.c_wchar_p(dirname), None, None, ctypes.pointer(free_bytes)) + ctypes.c_wchar_p(dirname), None, None, ctypes.pointer(free_bytes) + ) return free_bytes.value stat = os.statvfs(path) return stat.f_bavail * stat.f_bsize @@ -86,8 +96,14 @@ def remove_file(file_name): class Archival(object): """Class to support file archival to S3.""" - def __init__(self, logger, archival_json_file="archive.json", limit_size_mb=0, limit_files=0, - s3_client=None): + def __init__( + self, + logger, + archival_json_file="archive.json", + limit_size_mb=0, + limit_files=0, + s3_client=None, + ): """Initialize Archival.""" self.archival_json_file = archival_json_file @@ -103,9 +119,11 @@ class Archival(object): # Start the worker thread to update the 'archival_json_file'. self._archive_file_queue = queue.Queue() - self._archive_file_worker = threading.Thread(target=self._update_archive_file_wkr, - args=(self._archive_file_queue, - logger), name="archive_file_worker") + self._archive_file_worker = threading.Thread( + target=self._update_archive_file_wkr, + args=(self._archive_file_queue, logger), + name="archive_file_worker", + ) self._archive_file_worker.setDaemon(True) self._archive_file_worker.start() if not s3_client: @@ -116,8 +134,10 @@ class Archival(object): # Start the worker thread which uploads the archive. self._upload_queue = queue.Queue() self._upload_worker = threading.Thread( - target=self._upload_to_s3_wkr, args=(self._upload_queue, self._archive_file_queue, - logger, self.s3_client), name="upload_worker") + target=self._upload_to_s3_wkr, + args=(self._upload_queue, self._archive_file_queue, logger, self.s3_client), + name="upload_worker", + ) self._upload_worker.setDaemon(True) self._upload_worker.start() @@ -126,6 +146,7 @@ class Archival(object): # Since boto3 is a 3rd party module, we import locally. import boto3 import botocore.session + botocore.session.Session() if sys.platform in ("win32", "cygwin"): @@ -134,19 +155,20 @@ class Archival(object): # This is due to the backwards breaking changed python introduced https://bugs.python.org/issue36264 botocore_session = botocore.session.Session( session_vars={ - 'config_file': ( + "config_file": ( None, - 'AWS_CONFIG_FILE', - os.path.join(os.environ['HOME'], '.aws', 'config'), + "AWS_CONFIG_FILE", + os.path.join(os.environ["HOME"], ".aws", "config"), None, ), - 'credentials_file': ( + "credentials_file": ( None, - 'AWS_SHARED_CREDENTIALS_FILE', - os.path.join(os.environ['HOME'], '.aws', 'credentials'), + "AWS_SHARED_CREDENTIALS_FILE", + os.path.join(os.environ["HOME"], ".aws", "credentials"), None, ), - }) + } + ) boto3.setup_default_session(botocore_session=botocore_session) return boto3.client("s3") @@ -167,13 +189,18 @@ class Archival(object): message = "No input_files specified" elif self.limit_size_mb and self.size_mb >= self.limit_size_mb: status = 1 - message = "Files not archived, {}MB size limit reached".format(self.limit_size_mb) + message = "Files not archived, {}MB size limit reached".format( + self.limit_size_mb + ) elif self.limit_files and self.num_files >= self.limit_files: status = 1 - message = "Files not archived, {} file limit reached".format(self.limit_files) + message = "Files not archived, {} file limit reached".format( + self.limit_files + ) else: - status, message, file_size_mb = self._archive_files(display_name, input_files, - s3_bucket, s3_path) + status, message, file_size_mb = self._archive_files( + display_name, input_files, s3_bucket, s3_path + ) if status == 0: self.num_files += 1 @@ -193,11 +220,15 @@ class Archival(object): work_queue.task_done() break archival_record = { - "name": archive_args.display_name, "link": archive_args.remote_file, - "visibility": "private" + "name": archive_args.display_name, + "link": archive_args.remote_file, + "visibility": "private", } - logger.debug("Updating archive file %s with %s", archive_args.archival_file, - archival_record) + logger.debug( + "Updating archive file %s with %s", + archive_args.archival_file, + archival_record, + ) archival_json.append(archival_record) with open(archive_args.archival_file, "w") as archival_fh: json.dump(archival_json, archival_fh) @@ -214,15 +245,27 @@ class Archival(object): archive_file_work_queue.put(None) break extra_args = {"ContentType": upload_args.content_type, "ACL": "public-read"} - logger.debug("Uploading to S3 %s to bucket %s path %s", upload_args.local_file, - upload_args.s3_bucket, upload_args.s3_path) + logger.debug( + "Uploading to S3 %s to bucket %s path %s", + upload_args.local_file, + upload_args.s3_bucket, + upload_args.s3_path, + ) upload_completed = False try: - s3_client.upload_file(upload_args.local_file, upload_args.s3_bucket, - upload_args.s3_path, ExtraArgs=extra_args) + s3_client.upload_file( + upload_args.local_file, + upload_args.s3_bucket, + upload_args.s3_path, + ExtraArgs=extra_args, + ) upload_completed = True - logger.debug("Upload to S3 completed for %s to bucket %s path %s", - upload_args.local_file, upload_args.s3_bucket, upload_args.s3_path) + logger.debug( + "Upload to S3 completed for %s to bucket %s path %s", + upload_args.local_file, + upload_args.s3_bucket, + upload_args.s3_path, + ) except Exception as err: # pylint: disable=broad-except logger.exception("Upload to S3 error %s", err) @@ -231,11 +274,15 @@ class Archival(object): if status: logger.error("Upload to S3 delete file error %s", message) - remote_file = "https://s3.amazonaws.com/{}/{}".format(upload_args.s3_bucket, - upload_args.s3_path) + remote_file = "https://s3.amazonaws.com/{}/{}".format( + upload_args.s3_bucket, upload_args.s3_path + ) if upload_completed: archive_file_work_queue.put( - ArchiveArgs(upload_args.archival_file, upload_args.display_name, remote_file)) + ArchiveArgs( + upload_args.archival_file, upload_args.display_name, remote_file + ) + ) work_queue.task_done() @@ -256,15 +303,21 @@ class Archival(object): status = 0 size_mb = 0 - if 'test_archival' in config.INTERNAL_PARAMS: + if "test_archival" in config.INTERNAL_PARAMS: message = "'test_archival' specified. Skipping tar/gzip." - with open(os.path.join(config.DBPATH_PREFIX, "test_archival.txt"), "a") as test_file: + with open( + os.path.join(config.DBPATH_PREFIX, "test_archival.txt"), "a" + ) as test_file: for input_file in input_files: # If a resmoke fixture is used, the input_file will be the source of the data # files. If mongorunner is used, input_file/mongorunner will be the source # of the data files. - if os.path.isdir(os.path.join(input_file, config.MONGO_RUNNER_SUBDIR)): - input_file = os.path.join(input_file, config.MONGO_RUNNER_SUBDIR) + if os.path.isdir( + os.path.join(input_file, config.MONGO_RUNNER_SUBDIR) + ): + input_file = os.path.join( + input_file, config.MONGO_RUNNER_SUBDIR + ) # Each node contains one directory for its data files. Here we write out # the names of those directories. In the unit test for archival, we will @@ -281,7 +334,9 @@ class Archival(object): if file_list_size(input_files) > free_space(temp_file): status, message = remove_file(temp_file) if status: - self.logger.warning("Removing tarfile due to insufficient space - %s", message) + self.logger.warning( + "Removing tarfile due to insufficient space - %s", message + ) return 1, "Insufficient space for {}".format(message), 0 try: @@ -291,18 +346,29 @@ class Archival(object): tar_handle.add(input_file) except (IOError, OSError, tarfile.TarError) as err: message = "{}; Unable to add {} to archive file: {}".format( - message, input_file, err) + message, input_file, err + ) except (IOError, OSError, tarfile.TarError) as err: status, message = remove_file(temp_file) if status: - self.logger.warning("Removing tarfile due to creation failure - %s", message) + self.logger.warning( + "Removing tarfile due to creation failure - %s", message + ) return 1, str(err), 0 # Round up the size of the archive. size_mb = int(math.ceil(float(file_list_size(temp_file)) / (1024 * 1024))) self._upload_queue.put( - UploadArgs(self.archival_json_file, display_name, temp_file, "application/x-gzip", - s3_bucket, s3_path, True)) + UploadArgs( + self.archival_json_file, + display_name, + temp_file, + "application/x-gzip", + s3_bucket, + s3_path, + True, + ) + ) return status, message, size_mb @@ -311,11 +377,17 @@ class Archival(object): if thread.is_alive() and not expected_alive: self.logger.warning( "The %s thread did not complete, some files might not have been uploaded" - " to S3 or archived to %s.", thread.name, self.archival_json_file) + " to S3 or archived to %s.", + thread.name, + self.archival_json_file, + ) elif not thread.is_alive() and expected_alive: self.logger.warning( "The %s thread is no longer running, some files might not have been uploaded" - " to S3 or archived to %s.", thread.name, self.archival_json_file) + " to S3 or archived to %s.", + thread.name, + self.archival_json_file, + ) def exit(self, timeout=30): """Wait for worker threads to finish.""" @@ -330,8 +402,12 @@ class Archival(object): self._archive_file_worker.join(timeout=timeout) self.check_thread(self._archive_file_worker, False) - self.logger.info("Total tar/gzip archive time is %0.2f seconds, for %d file(s) %d MB", - self.archive_time, self.num_files, self.size_mb) + self.logger.info( + "Total tar/gzip archive time is %0.2f seconds, for %d file(s) %d MB", + self.archive_time, + self.num_files, + self.size_mb, + ) def files_archived_num(self): """Return the number of the archived files.""" diff --git a/buildscripts/resmokelib/utils/autoloader.py b/buildscripts/resmokelib/utils/autoloader.py index ab1ace84798..943839c5096 100644 --- a/buildscripts/resmokelib/utils/autoloader.py +++ b/buildscripts/resmokelib/utils/autoloader.py @@ -18,5 +18,5 @@ def load_all_modules(name, path): _autoloader.load_all_modules(name=__name__, path=__path__) """ - for (_, module, _) in pkgutil.walk_packages(path=path): + for _, module, _ in pkgutil.walk_packages(path=path): importlib.import_module("." + module, package=name) diff --git a/buildscripts/resmokelib/utils/check_has_tag.py b/buildscripts/resmokelib/utils/check_has_tag.py index ad90e258f4a..a6a27d6bd5f 100755 --- a/buildscripts/resmokelib/utils/check_has_tag.py +++ b/buildscripts/resmokelib/utils/check_has_tag.py @@ -9,16 +9,16 @@ import jscomment try: if len(sys.argv) != 3: print( - 'This program checks if a javascript test has specific tag (e.g.: @tag=[name] in the comment)' + "This program checks if a javascript test has specific tag (e.g.: @tag=[name] in the comment)" ) - print('It returns result via exit code:') - print(' 0 if script has specified tag') - print(' 1 if script does not have specified tag') - print(' 2 if script was invoked incorrectly') - print(' 3 if any error happened during check') - print('Usage:') - print(' check_has_tag.py ') - print('Notice: is a regex, not search string') + print("It returns result via exit code:") + print(" 0 if script has specified tag") + print(" 1 if script does not have specified tag") + print(" 2 if script was invoked incorrectly") + print(" 3 if any error happened during check") + print("Usage:") + print(" check_has_tag.py ") + print("Notice: is a regex, not search string") sys.exit(2) else: tags = jscomment.get_tags(sys.argv[1]) diff --git a/buildscripts/resmokelib/utils/dictionary.py b/buildscripts/resmokelib/utils/dictionary.py index f1124705ccf..b8c5a9b22dc 100644 --- a/buildscripts/resmokelib/utils/dictionary.py +++ b/buildscripts/resmokelib/utils/dictionary.py @@ -1,4 +1,5 @@ """Utility functions for working with Dict-type structures.""" + from typing import MutableMapping diff --git a/buildscripts/resmokelib/utils/evergreen_conn.py b/buildscripts/resmokelib/utils/evergreen_conn.py index bd40267372a..71844ec10c5 100644 --- a/buildscripts/resmokelib/utils/evergreen_conn.py +++ b/buildscripts/resmokelib/utils/evergreen_conn.py @@ -1,4 +1,5 @@ """Helper functions to interact with evergreen.""" + import os import pathlib from collections import deque @@ -76,20 +77,29 @@ def get_evergreen_api(evergreen_config=None): LOGGER.error( "Could not connect to Evergreen with any .evergreen.yml files available on this system", - config_file_candidates=possible_configs) + config_file_candidates=possible_configs, + ) raise last_ex -def get_buildvariant_name(config: SetupMultiversionConfig, edition, platform, architecture, - major_minor_version): +def get_buildvariant_name( + config: SetupMultiversionConfig, + edition, + platform, + architecture, + major_minor_version, +): """Return Evergreen buildvariant name.""" buildvariant_name = "" evergreen_buildvariants = config.evergreen_buildvariants for buildvariant in evergreen_buildvariants: - if (buildvariant.edition == edition and buildvariant.platform == platform - and buildvariant.architecture == architecture): + if ( + buildvariant.edition == edition + and buildvariant.platform == platform + and buildvariant.architecture == architecture + ): versions = buildvariant.versions if major_minor_version in versions: buildvariant_name = buildvariant.name @@ -109,8 +119,10 @@ def get_patch_module_diffs(evg_api: RetryingEvergreenApi, version_id): except requests.exceptions.HTTPError as err: err_res = err.response if err_res.status_code == 400: - LOGGER.debug("Not a patch build task, skipping applying patch", - version_id_of_task=version_id) + LOGGER.debug( + "Not a patch build task, skipping applying patch", + version_id_of_task=version_id, + ) return None else: raise @@ -130,12 +142,20 @@ def get_patch_module_diffs(evg_api: RetryingEvergreenApi, version_id): def get_generic_buildvariant_name(config: SetupMultiversionConfig, major_minor_version): """Return Evergreen buildvariant name for generic platform.""" - LOGGER.info("Falling back to generic architecture.", edition=GENERIC_EDITION, - platform=GENERIC_PLATFORM, architecture=GENERIC_ARCHITECTURE) + LOGGER.info( + "Falling back to generic architecture.", + edition=GENERIC_EDITION, + platform=GENERIC_PLATFORM, + architecture=GENERIC_ARCHITECTURE, + ) generic_buildvariant_name = get_buildvariant_name( - config=config, edition=GENERIC_EDITION, platform=GENERIC_PLATFORM, - architecture=GENERIC_ARCHITECTURE, major_minor_version=major_minor_version) + config=config, + edition=GENERIC_EDITION, + platform=GENERIC_PLATFORM, + architecture=GENERIC_ARCHITECTURE, + major_minor_version=major_minor_version, + ) if not generic_buildvariant_name: raise EvergreenConnError("Generic architecture buildvariant not found.") @@ -143,7 +163,9 @@ def get_generic_buildvariant_name(config: SetupMultiversionConfig, major_minor_v return generic_buildvariant_name -def get_evergreen_version(evg_api: RetryingEvergreenApi, evg_ref: str) -> Optional[Version]: +def get_evergreen_version( + evg_api: RetryingEvergreenApi, evg_ref: str +) -> Optional[Version]: """Return evergreen version by reference (commit_hash or evergreen_version_id).""" from buildscripts.resmokelib import multiversionconstants @@ -151,7 +173,9 @@ def get_evergreen_version(evg_api: RetryingEvergreenApi, evg_ref: str) -> Option evg_refs = [evg_ref] # Evergreen reference as {project_name}_{commit_hash} evg_refs.extend( - f"{proj.replace('-', '_')}_{evg_ref}" for proj in multiversionconstants.EVERGREEN_PROJECTS) + f"{proj.replace('-', '_')}_{evg_ref}" + for proj in multiversionconstants.EVERGREEN_PROJECTS + ) for ref in evg_refs: try: @@ -159,20 +183,28 @@ def get_evergreen_version(evg_api: RetryingEvergreenApi, evg_ref: str) -> Option except HTTPError: continue else: - LOGGER.debug("Found evergreen version.", - evergreen_version=f"{EVERGREEN_HOST}/version/{evg_version.version_id}") + LOGGER.debug( + "Found evergreen version.", + evergreen_version=f"{EVERGREEN_HOST}/version/{evg_version.version_id}", + ) return evg_version return None -def get_evergreen_versions(evg_api: RetryingEvergreenApi, evg_project: str) -> Iterator[Version]: +def get_evergreen_versions( + evg_api: RetryingEvergreenApi, evg_project: str +) -> Iterator[Version]: """Return the list of evergreen versions by evergreen project name.""" return evg_api.versions_by_project(evg_project) -def get_compile_artifact_urls(evg_api: RetryingEvergreenApi, evg_version: Version, - buildvariant_name, ignore_failed_push=False): +def get_compile_artifact_urls( + evg_api: RetryingEvergreenApi, + evg_version: Version, + buildvariant_name, + ignore_failed_push=False, +): """Return compile urls from buildvariant in Evergreen version.""" try: build_id = evg_version.build_variants_map[buildvariant_name] @@ -180,7 +212,9 @@ def get_compile_artifact_urls(evg_api: RetryingEvergreenApi, evg_version: Versio raise EvergreenConnError(f"Buildvariant {buildvariant_name} not found.") evg_build = evg_api.build_by_id(build_id) - LOGGER.debug("Found evergreen build.", evergreen_build=f"{EVERGREEN_HOST}/build/{build_id}") + LOGGER.debug( + "Found evergreen build.", evergreen_build=f"{EVERGREEN_HOST}/build/{build_id}" + ) evg_tasks: Deque[Union[Task, str]] = deque(evg_build.get_tasks()) tasks_wrapper = _filter_successful_tasks(evg_api, evg_tasks) LOGGER.info( @@ -202,8 +236,12 @@ def get_compile_artifact_urls(evg_api: RetryingEvergreenApi, evg_version: Versio class _MultiversionTasks(object): """Tasks relevant for multiversion setup.""" - def __init__(self, symbols: Union[Task, None], binary: Union[Task, None], - push: Union[Task, None]): + def __init__( + self, + symbols: Union[Task, None], + binary: Union[Task, None], + push: Union[Task, None], + ): """Init function.""" self.symbols_task = symbols self.binary_task = binary @@ -220,9 +258,11 @@ def _get_multiversion_urls(tasks_wrapper: _MultiversionTasks): required_tasks = [binary, push] if push is not None else [binary] if all(task and task.status == "success" for task in required_tasks): - LOGGER.info("Required evergreen task(s) were successful.", - required_tasks=f"{required_tasks}", - task_id=f"{EVERGREEN_HOST}/task/{required_tasks[0].task_id}") + LOGGER.info( + "Required evergreen task(s) were successful.", + required_tasks=f"{required_tasks}", + task_id=f"{EVERGREEN_HOST}/task/{required_tasks[0].task_id}", + ) evg_artifacts = binary.artifacts for artifact in evg_artifacts: compile_artifact_urls[artifact.name] = artifact.url @@ -231,24 +271,31 @@ def _get_multiversion_urls(tasks_wrapper: _MultiversionTasks): for artifact in symbols.artifacts: compile_artifact_urls[artifact.name] = artifact.url elif symbols and symbols.task_id: - LOGGER.warning("debug symbol archive was unsuccessful", - archive_symbols_task=f"{EVERGREEN_HOST}/task/{symbols.task_id}") + LOGGER.warning( + "debug symbol archive was unsuccessful", + archive_symbols_task=f"{EVERGREEN_HOST}/task/{symbols.task_id}", + ) # Tack on the project id for generating a friendly decompressed name for the artifacts. compile_artifact_urls["project_identifier"] = binary.project_identifier elif all(task for task in required_tasks): - LOGGER.warning("Required Evergreen task(s) were not successful.", - required_tasks=f"{required_tasks}", - task_id=f"{EVERGREEN_HOST}/task/{required_tasks[0].task_id}") + LOGGER.warning( + "Required Evergreen task(s) were not successful.", + required_tasks=f"{required_tasks}", + task_id=f"{EVERGREEN_HOST}/task/{required_tasks[0].task_id}", + ) else: - LOGGER.error("There are no `compile` and/or 'push' tasks in the evergreen build") + LOGGER.error( + "There are no `compile` and/or 'push' tasks in the evergreen build" + ) return compile_artifact_urls -def _filter_successful_tasks(evg_api: RetryingEvergreenApi, - evg_tasks: Deque[Union[Task, str]]) -> _MultiversionTasks: +def _filter_successful_tasks( + evg_api: RetryingEvergreenApi, evg_tasks: Deque[Union[Task, str]] +) -> _MultiversionTasks: """ We want to filter successful tasks in order by variant then by dependent tasks to find the compile tasks. @@ -273,9 +320,15 @@ def _filter_successful_tasks(evg_api: RetryingEvergreenApi, # Only set the compile task if there isn't one already, otherwise # newer tasks like "archive_dist_test_debug" take precedence. # Use `get_execution_or_self` to prevent grabbing an unfinished restarted executed task. - if evg_task.display_name in ( - "compile", "archive_dist_test", - "archive_dist_test_future_git_tag_multiversion") and compile_task is None: + if ( + evg_task.display_name + in ( + "compile", + "archive_dist_test", + "archive_dist_test_future_git_tag_multiversion", + ) + and compile_task is None + ): compile_task = evg_task.get_execution_or_self(0) # archive_dist_test_debug might not be in the dep chain # it should always be in the same build variant as the compile task @@ -283,9 +336,14 @@ def _filter_successful_tasks(evg_api: RetryingEvergreenApi, elif evg_task.display_name == "push": push_task = evg_task.get_execution_or_self(0) - elif evg_task.display_name in ("archive_dist_test_debug", - "archive_dist_test_debug_future_git_tag_multiversion" - ) and archive_symbols_task is None: + elif ( + evg_task.display_name + in ( + "archive_dist_test_debug", + "archive_dist_test_debug_future_git_tag_multiversion", + ) + and archive_symbols_task is None + ): archive_symbols_task = evg_task.get_execution_or_self(0) if compile_task and push_task and archive_symbols_task: break @@ -293,4 +351,6 @@ def _filter_successful_tasks(evg_api: RetryingEvergreenApi, dependent_tasks = evg_task.depends_on if evg_task.depends_on else [] for dep_task in dependent_tasks: evg_tasks.append(dep_task["id"]) - return _MultiversionTasks(symbols=archive_symbols_task, binary=compile_task, push=push_task) + return _MultiversionTasks( + symbols=archive_symbols_task, binary=compile_task, push=push_task + ) diff --git a/buildscripts/resmokelib/utils/external_suite.py b/buildscripts/resmokelib/utils/external_suite.py index 4b890e5c8a1..532115037e4 100755 --- a/buildscripts/resmokelib/utils/external_suite.py +++ b/buildscripts/resmokelib/utils/external_suite.py @@ -19,7 +19,8 @@ INCOMPATIBLE_HOOKS = [ def delete_archival(suite): """Remove archival for External Suites.""" logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "`archive` is not supported for external suites and will be removed if it exists.") + "`archive` is not supported for external suites and will be removed if it exists." + ) suite.pop("archive", None) suite.get("executor", {}).pop("archive", None) @@ -27,19 +28,25 @@ def delete_archival(suite): def make_hooks_compatible(suite): """Make hooks compatible for external suites.""" logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "Some hooks are automatically disabled for external suites: %s", INCOMPATIBLE_HOOKS) + "Some hooks are automatically disabled for external suites: %s", + INCOMPATIBLE_HOOKS, + ) logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "The `AntithesisLogging` hook is automatically added for external suites.") + "The `AntithesisLogging` hook is automatically added for external suites." + ) if suite.get("executor", {}).get("hooks", None): # it's either a list of strings, or a list of dicts, each with key 'class' if isinstance(suite["executor"]["hooks"][0], str): suite["executor"]["hooks"] = ["AntithesisLogging"] + [ - hook for hook in suite["executor"]["hooks"] if hook not in INCOMPATIBLE_HOOKS + hook + for hook in suite["executor"]["hooks"] + if hook not in INCOMPATIBLE_HOOKS ] elif isinstance(suite["executor"]["hooks"][0], dict): suite["executor"]["hooks"] = [{"class": "AntithesisLogging"}] + [ hook - for hook in suite["executor"]["hooks"] if hook["class"] not in INCOMPATIBLE_HOOKS + for hook in suite["executor"]["hooks"] + if hook["class"] not in INCOMPATIBLE_HOOKS ] else: raise RuntimeError( @@ -52,35 +59,43 @@ def update_test_data(suite): logging.loggers.ROOT_EXECUTOR_LOGGER.warning( "`useActionPermittedFile` is incompatible with external suites and will always be set to `False`." ) - suite.setdefault("executor", {}).setdefault( - "config", {}).setdefault("shell_options", {}).setdefault("global_vars", {}).setdefault( - "TestData", {}).update({"useActionPermittedFile": False}) + suite.setdefault("executor", {}).setdefault("config", {}).setdefault( + "shell_options", {} + ).setdefault("global_vars", {}).setdefault("TestData", {}).update( + {"useActionPermittedFile": False} + ) def update_shell(suite): """Update shell for when running external suites.""" logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "`jsTestLog` is a no-op on external suites to reduce logging.") - suite.setdefault("executor", {}).setdefault("config", {}).setdefault("shell_options", - {}).setdefault("eval", "") - suite["executor"]["config"]["shell_options"]["eval"] += "jsTestLog = Function.prototype;" + "`jsTestLog` is a no-op on external suites to reduce logging." + ) + suite.setdefault("executor", {}).setdefault("config", {}).setdefault( + "shell_options", {} + ).setdefault("eval", "") + suite["executor"]["config"]["shell_options"]["eval"] += ( + "jsTestLog = Function.prototype;" + ) def update_exclude_tags(suite): """Update the exclude tags to exclude external suite incompatible tests.""" logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "The `antithesis_incompatible` tagged tests will be excluded for external suites.") - suite.setdefault('selector', {}) - if not suite.get('selector').get('exclude_with_any_tags'): - suite['selector']['exclude_with_any_tags'] = ["antithesis_incompatible"] + "The `antithesis_incompatible` tagged tests will be excluded for external suites." + ) + suite.setdefault("selector", {}) + if not suite.get("selector").get("exclude_with_any_tags"): + suite["selector"]["exclude_with_any_tags"] = ["antithesis_incompatible"] else: - suite['selector']['exclude_with_any_tags'].append('antithesis_incompatible') + suite["selector"]["exclude_with_any_tags"].append("antithesis_incompatible") def make_external(suite): """Modify suite in-place to be external compatible.""" logging.loggers.ROOT_EXECUTOR_LOGGER.warning( - "This suite is being converted to an 'External Suite': %s", suite) + "This suite is being converted to an 'External Suite': %s", suite + ) delete_archival(suite) make_hooks_compatible(suite) update_test_data(suite) diff --git a/buildscripts/resmokelib/utils/file_span_exporter.py b/buildscripts/resmokelib/utils/file_span_exporter.py index 6e18e8ac8dc..7872dd3aab1 100644 --- a/buildscripts/resmokelib/utils/file_span_exporter.py +++ b/buildscripts/resmokelib/utils/file_span_exporter.py @@ -16,16 +16,17 @@ logger = getLogger(__name__) class FileSpanExporter(SpanExporter): """ FileSpanExporter is an implementation of :class:`SpanExporter` that sends spans to files in directory. - + These files are in JSON format by default. """ def __init__( - self, - directory: str, - pretty_print=False, - service_name: Optional[str] = None, - formatter: Callable[[ReadableSpan], str] = lambda span: span.to_json() + linesep, + self, + directory: str, + pretty_print=False, + service_name: Optional[str] = None, + formatter: Callable[[ReadableSpan], str] = lambda span: span.to_json() + + linesep, ): self.formatter = formatter self.service_name = service_name @@ -56,7 +57,7 @@ class FileSpanExporter(SpanExporter): if self.pretty_print: json.dump(message, file, indent=2) else: - json.dump(message, file, indent=None, separators=(',', ':')) + json.dump(message, file, indent=None, separators=(",", ":")) except: logger.exception("Failed to write OTEL metrics to file %s", file_name) return SpanExportResult.FAILURE diff --git a/buildscripts/resmokelib/utils/history.py b/buildscripts/resmokelib/utils/history.py index fdbe78c9492..4805c84cff1 100644 --- a/buildscripts/resmokelib/utils/history.py +++ b/buildscripts/resmokelib/utils/history.py @@ -69,7 +69,9 @@ class Historic(ABC, metaclass=registry.make_registry_metaclass(_HISTORICS, type( def unsubscribe(self, subscriber): """Allow a subscriber to unsubscribe from notifications.""" - self._subscribers = [sub for sub in self._subscribers if sub.obj is not subscriber] + self._subscribers = [ + sub for sub in self._subscribers if sub.obj is not subscriber + ] def notify_subscriber_write(self): """Notify the subscribers that a write has happened.""" @@ -117,8 +119,8 @@ class Historic(ABC, metaclass=registry.make_registry_metaclass(_HISTORICS, type( class Subscriber: """Class representing the subscriber to a Historic.""" - obj: 'typing.Any' - key: 'typing.Any' + obj: "typing.Any" + key: "typing.Any" # 1. We only allow immutable types or types that have special logic @@ -145,7 +147,9 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto super(HistoryDict, self).__init__() if filename is not None and yaml_string is not None: - raise ValueError("Cannot construct HistoryDict from both yaml object and file.") + raise ValueError( + "Cannot construct HistoryDict from both yaml object and file." + ) self._history_store = defaultdict(list) self._value_store = dict() @@ -162,7 +166,8 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto schema_version = raw_dict["SchemaVersion"] if schema_version != SCHEMA_VERSION: raise ValueError( - f"Invalid schema version. Expected {SCHEMA_VERSION} but found {schema_version}.") + f"Invalid schema version. Expected {SCHEMA_VERSION} but found {schema_version}." + ) history_dict = raw_dict["History"] for key in history_dict: for raw_access in history_dict[key]: @@ -205,7 +210,7 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto output = "\n".join(processed) # Make sure SchemaVersion is at the top. - output = f"SchemaVersion: \"{SCHEMA_VERSION}\"\n" + output + output = f'SchemaVersion: "{SCHEMA_VERSION}"\n' + output if filename is not None: with open(filename, "w") as fp: fp.write(output) @@ -219,11 +224,13 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto for key in self._value_store: our_writes = [ - access.value_written for access in self._history_store[key] + access.value_written + for access in self._history_store[key] if access.type == AccessType.WRITE ] their_writes = [ - access.value_written for access in other_dict._history_store[key] # pylint: disable=protected-access + access.value_written + for access in other_dict._history_store[key] # pylint: disable=protected-access if access.type == AccessType.WRITE ] if not our_writes == their_writes: @@ -235,7 +242,10 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto storable_dict = {} for key, value in self._value_store.items(): storable_dict[key] = storable_dict_from_historic(value) - return {"object_class": HistoryDict.REGISTERED_NAME, "object_value": storable_dict} + return { + "object_class": HistoryDict.REGISTERED_NAME, + "object_value": storable_dict, + } @staticmethod def from_storable_dict(raw_dict): @@ -246,7 +256,9 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto def from_python_obj(obj): """Convert from a python object, overrides Historic.""" if not isinstance(obj, dict): - raise ValueError("HistoryDict can only be converted from dict python objects.") + raise ValueError( + "HistoryDict can only be converted from dict python objects." + ) history_dict = HistoryDict() for key, value in obj.items(): history_dict[key] = make_historic(value) @@ -283,8 +295,10 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto value = make_historic(value) if not isinstance(value, ALLOWED_TYPES): - raise ValueError(f"HistoryDict cannot store type {type(value)}." - " Please use a different type or create a Historic wrapper.") + raise ValueError( + f"HistoryDict cannot store type {type(value)}." + " Please use a different type or create a Historic wrapper." + ) self._record_write(key, value) self._value_store[key] = value if isinstance(value, HistoryDict): @@ -318,21 +332,29 @@ class HistoryDict(MutableMapping, Historic): # pylint: disable=too-many-ancesto def __repr__(self): # eval(repr(self)) isn't valid, but this is at least useful for debugging. - return f'{self.__class__.__name__}({repr(self._value_store)})' + return f"{self.__class__.__name__}({repr(self._value_store)})" def _record_write(self, key, value): written = None if type(value) in ALLOWED_TYPES and value is not Historic: written = value - cur_access = Access(type=AccessType.WRITE, location=_get_location(), value_written=written, - time=self._global_time) + cur_access = Access( + type=AccessType.WRITE, + location=_get_location(), + value_written=written, + time=self._global_time, + ) self._history_store[key].append(cur_access) self._global_time += 1 def _record_delete(self, key): - cur_access = Access(type=AccessType.DELETE, location=_get_location(), value_written=None, - time=self._global_time) + cur_access = Access( + type=AccessType.DELETE, + location=_get_location(), + value_written=None, + time=self._global_time, + ) self._history_store[key].append(cur_access) self._global_time += 1 @@ -364,10 +386,10 @@ class AccessType(Enum): class Access: """Class representing an access to store in the dict's history.""" - type: 'AccessType' + type: "AccessType" time: int - location: ['traceback.FrameSummary'] = field(default_factory=list) - value_written: 'typing.Any' = None + location: ["traceback.FrameSummary"] = field(default_factory=list) + value_written: "typing.Any" = None def as_dict(self): """Convert this class into a dict (accounting for AccessType).""" @@ -380,10 +402,13 @@ class Access: def from_dict(raw_dict): """Retrieve this class from a dict (accounting for AccessType).""" return Access( - type=AccessType[raw_dict["type"]], time=raw_dict["time"], + type=AccessType[raw_dict["type"]], + time=raw_dict["time"], location=raw_dict["location"] if "location" in raw_dict else list(), value_written=copy.deepcopy(raw_dict["value_written"]) - if "value_written" in raw_dict else None) + if "value_written" in raw_dict + else None, + ) def _get_location(): @@ -399,7 +424,7 @@ class PipeLiteral(str): def pipe_literal_representer(dumper, data): """Create a representer for pipe literals, used internally for pyyaml.""" - return dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|') + return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|") yaml.add_representer(PipeLiteral, pipe_literal_representer) diff --git a/buildscripts/resmokelib/utils/jscomment.py b/buildscripts/resmokelib/utils/jscomment.py index bad08f1293d..c2680f591f3 100644 --- a/buildscripts/resmokelib/utils/jscomment.py +++ b/buildscripts/resmokelib/utils/jscomment.py @@ -28,31 +28,45 @@ def get_tags(pathname): */ """ - with io.open(pathname, 'r', encoding='utf-8') as fp: + with io.open(pathname, "r", encoding="utf-8") as fp: match = _JSTEST_TAGS_RE.match(fp.read()) if match: try: # TODO: it might be worth supporting the block (indented) style of YAML lists in # addition to the flow (bracketed) style tags = yaml.safe_load(_strip_jscomments(match.group(1))) - if not isinstance(tags, list) and all(isinstance(tag, str) for tag in tags): - raise TypeError("Expected a list of string tags, but got '%s'" % (tags)) + if not isinstance(tags, list) and all( + isinstance(tag, str) for tag in tags + ): + raise TypeError( + "Expected a list of string tags, but got '%s'" % (tags) + ) for tag in tags: - if '//' in tag: - raise ValueError(("Found a JS line comment '%s'. "\ - "Use '#' YAML style comments instead in a tags array %s") - % (tag, pathname)) + if "//" in tag: + raise ValueError( + ( + "Found a JS line comment '%s'. " + "Use '#' YAML style comments instead in a tags array %s" + ) + % (tag, pathname) + ) - if ' ' in tag: - raise ValueError(("Found an empty space in tag '%s'. "\ - "This is not permitted and may indicate a missing comma in %s") - % (tag, pathname)) + if " " in tag: + raise ValueError( + ( + "Found an empty space in tag '%s'. " + "This is not permitted and may indicate a missing comma in %s" + ) + % (tag, pathname) + ) return tags except yaml.YAMLError as err: raise ValueError( - "File '%s' contained invalid tags (expected YAML): %s" % (pathname, err)) + "File '%s' contained invalid tags (expected YAML): %s" + % (pathname, err) + ) return [] diff --git a/buildscripts/resmokelib/utils/queue.py b/buildscripts/resmokelib/utils/queue.py index aa396fd43dc..ad150951f0c 100644 --- a/buildscripts/resmokelib/utils/queue.py +++ b/buildscripts/resmokelib/utils/queue.py @@ -13,7 +13,7 @@ from typing import Generic, TypeVar # Exception that is raised when get_nowait() is called on an empty Queue. Empty = _queue.Empty -T = TypeVar('T') +T = TypeVar("T") class Queue(_queue.Queue, Generic[T]): diff --git a/buildscripts/resmokelib/utils/registry.py b/buildscripts/resmokelib/utils/registry.py index d6bfdfb5aae..95fec16da4c 100644 --- a/buildscripts/resmokelib/utils/registry.py +++ b/buildscripts/resmokelib/utils/registry.py @@ -75,7 +75,9 @@ def make_registry_metaclass(registry_store, base_metaclass=None): if name_to_register in registry_store: raise ValueError( "The name %s is already registered; a different value for the" - " 'REGISTERED_NAME' attribute must be chosen" % (registered_name)) + " 'REGISTERED_NAME' attribute must be chosen" + % (registered_name) + ) registry_store[name_to_register] = cls return cls diff --git a/buildscripts/s3_binary/download.py b/buildscripts/s3_binary/download.py index 1018783f74c..b0f9603e87a 100644 --- a/buildscripts/s3_binary/download.py +++ b/buildscripts/s3_binary/download.py @@ -22,7 +22,8 @@ def _verify_s3_hash(s3_path: str, local_path: str, expected_hash: str) -> None: hash_string = _sha256_file(local_path) if hash_string != expected_hash: raise ValueError( - f"Hash mismatch for {s3_path}, expected {expected_hash} but got {hash_string}") + f"Hash mismatch for {s3_path}, expected {expected_hash} but got {hash_string}" + ) def _download_path_with_retry(*args, **kwargs): @@ -39,8 +40,8 @@ def _download_path_with_retry(*args, **kwargs): def download_s3_binary( - s3_path: str, - local_path: str = None, + s3_path: str, + local_path: str = None, ) -> None: if local_path is None: local_path = s3_path.split("/")[-1] diff --git a/buildscripts/s3_binary/hashes.py b/buildscripts/s3_binary/hashes.py index 2b379f321fb..defc06c33d2 100644 --- a/buildscripts/s3_binary/hashes.py +++ b/buildscripts/s3_binary/hashes.py @@ -1,40 +1,21 @@ S3_SHA256_HASHES = { - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin": - "a8c966e9ae6983b1e1c0116313ff523a862076d81b20add23da825b58610c1b3", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin-amd64": - "5c77f33f91dd3df119d192175100cb5b50302eb7ee37859cbab79e10a76ccce8", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin-arm64": - "d1ca9911cc19e1f17483f93956908334f2b7f3dd13f20853417b68fc3c3eb370", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-linux-amd64": - "6539c12842ad76966f3d493e8f80d67caa84ec4a000e220d5459833c967c12bc", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-linux-arm64": - "54f85ef4c23393f835252cc882e5fea596e8ef3c4c2056b059f8067cd19f0351", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-windows-amd64.exe": - "023734f33ed6b9c6d65468fe20bb2c5fb32473ccb8aca2fc5bf1521e61ce1622", - "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-windows-arm64.exe": - "99ea5997df128b33c34ba93bad26882af4aabf8c26d50e704b9b651d291fae76", - "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-darwin-amd64": - "854c9583efc166602276802658cef3f224d60898cfaa60630b33d328db3b0de2", - "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-darwin-arm64": - "31b1bfe20d7d5444be217af78f94c5c43799cdf847c6ce69794b7bf3319c5364", - "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-linux-amd64": - "3305e287b3fcc68b9a35fd8515ee617452cd4e018f9e6886b6c7cdbcba8710d4", - "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-linux-arm64": - "0b5a2a717ac4fc911e1fec8d92af71dbb4fe95b10e5213da0cc3d56cea64a328", - "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-windows-amd64.exe": - "58d41ce53257c5594c9bc86d769f580909269f68de114297f46284fbb9023dcf", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-aarch64-apple-darwin.tar.gz": - "b94562393a4bf23f1a48521f5495a8e48de885b7c173bd7ea8206d6d09921633", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-aarch64-unknown-linux-musl.tar.gz": - "73df3729a3381d0918e4640aac4b2653c542f74c7b7843dee8310e2c877e6f2e", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-powerpc64le-unknown-linux-gnu.tar.gz": - "6eedb853553ee52309e9519af775b3359a12227ec342404b6a033308cdd48b1b", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-s390x-unknown-linux-gnu.tar.gz": - "b4f93af861c1b3e1956df08e0d9f20b7e55cd7beb37c9df09b659908e920ebe6", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-apple-darwin.tar.gz": - "34aa37643e30dcb81a3c0e011c3a8df552465ea7580ba92ca727a3b7c6de25d1", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-pc-windows-msvc.zip": - "9d10e1282c5f695b2130cf593d55e37266513fc6d497edc4a30a6ed6d8ba4067", - "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-unknown-linux-musl.tar.gz": - "39a1cd878962ebc88322b4f6d33cae2292454563028f93a3f1f8ce58e3025b07", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin": "a8c966e9ae6983b1e1c0116313ff523a862076d81b20add23da825b58610c1b3", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin-amd64": "5c77f33f91dd3df119d192175100cb5b50302eb7ee37859cbab79e10a76ccce8", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-darwin-arm64": "d1ca9911cc19e1f17483f93956908334f2b7f3dd13f20853417b68fc3c3eb370", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-linux-amd64": "6539c12842ad76966f3d493e8f80d67caa84ec4a000e220d5459833c967c12bc", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-linux-arm64": "54f85ef4c23393f835252cc882e5fea596e8ef3c4c2056b059f8067cd19f0351", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-windows-amd64.exe": "023734f33ed6b9c6d65468fe20bb2c5fb32473ccb8aca2fc5bf1521e61ce1622", + "https://mdb-build-public.s3.amazonaws.com/bazelisk-binaries/v1.26.0/bazelisk-windows-arm64.exe": "99ea5997df128b33c34ba93bad26882af4aabf8c26d50e704b9b651d291fae76", + "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-darwin-amd64": "854c9583efc166602276802658cef3f224d60898cfaa60630b33d328db3b0de2", + "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-darwin-arm64": "31b1bfe20d7d5444be217af78f94c5c43799cdf847c6ce69794b7bf3319c5364", + "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-linux-amd64": "3305e287b3fcc68b9a35fd8515ee617452cd4e018f9e6886b6c7cdbcba8710d4", + "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-linux-arm64": "0b5a2a717ac4fc911e1fec8d92af71dbb4fe95b10e5213da0cc3d56cea64a328", + "https://mdb-build-public.s3.amazonaws.com/buildozer/v7.3.1/buildozer-windows-amd64.exe": "58d41ce53257c5594c9bc86d769f580909269f68de114297f46284fbb9023dcf", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-aarch64-apple-darwin.tar.gz": "b94562393a4bf23f1a48521f5495a8e48de885b7c173bd7ea8206d6d09921633", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-aarch64-unknown-linux-musl.tar.gz": "73df3729a3381d0918e4640aac4b2653c542f74c7b7843dee8310e2c877e6f2e", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-powerpc64le-unknown-linux-gnu.tar.gz": "6eedb853553ee52309e9519af775b3359a12227ec342404b6a033308cdd48b1b", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-s390x-unknown-linux-gnu.tar.gz": "b4f93af861c1b3e1956df08e0d9f20b7e55cd7beb37c9df09b659908e920ebe6", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-apple-darwin.tar.gz": "34aa37643e30dcb81a3c0e011c3a8df552465ea7580ba92ca727a3b7c6de25d1", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-pc-windows-msvc.zip": "9d10e1282c5f695b2130cf593d55e37266513fc6d497edc4a30a6ed6d8ba4067", + "https://mdb-build-public.s3.amazonaws.com/ruff/0.6.9/ruff-x86_64-unknown-linux-musl.tar.gz": "39a1cd878962ebc88322b4f6d33cae2292454563028f93a3f1f8ce58e3025b07", } diff --git a/buildscripts/s3_binary/upload.py b/buildscripts/s3_binary/upload.py index 95554963122..43955568d85 100644 --- a/buildscripts/s3_binary/upload.py +++ b/buildscripts/s3_binary/upload.py @@ -24,7 +24,9 @@ def _upload(local_source_directory: str, s3_destination_directory: str) -> None: files_to_upload = [] for file in pathlib.Path(local_source_directory).iterdir(): files_to_upload.append(file) - print("Please authenticate with an account that can upload to the s3 bucket mdb-build-public") + print( + "Please authenticate with an account that can upload to the s3 bucket mdb-build-public" + ) subprocess.check_call(["aws", "configure", "sso", "--profile", "devprod-build"]) s3_destination_directory = s3_destination_directory.rstrip("/") + "/" @@ -49,25 +51,37 @@ def _upload(local_source_directory: str, s3_destination_directory: str) -> None: print("Storing hashes in buildscripts/s3_binary/hashes.py...") for file in files_to_upload: - https_path = (re.sub(r"s3://(.*?)/(.*)", r"https://\1.s3.amazonaws.com/\2", - s3_destination_directory) + file.name) + https_path = ( + re.sub( + r"s3://(.*?)/(.*)", + r"https://\1.s3.amazonaws.com/\2", + s3_destination_directory, + ) + + file.name + ) S3_SHA256_HASHES[https_path] = _sha256_file(file) with open("buildscripts/s3_binary/hashes.py", "w", encoding="utf-8") as hash_file: - hash_dict = (pformat(S3_SHA256_HASHES, indent=4).replace("'", '"').replace("}", "").replace( - "{", "")) + hash_dict = ( + pformat(S3_SHA256_HASHES, indent=4) + .replace("'", '"') + .replace("}", "") + .replace("{", "") + ) hash_file.write(f"S3_SHA256_HASHES = {{\n {hash_dict}\n}}\n") print(f"Uploading to {s3_destination_directory}...") - result = subprocess.check_call([ - "aws", - "s3", - "cp", - "--recursive", - "--profile=devprod-build", - local_source_directory, - s3_destination_directory, - ]) + result = subprocess.check_call( + [ + "aws", + "s3", + "cp", + "--recursive", + "--profile=devprod-build", + local_source_directory, + s3_destination_directory, + ] + ) return False diff --git a/buildscripts/sbom/config.py b/buildscripts/sbom/config.py index c72c2057c0d..1746f12b9bc 100644 --- a/buildscripts/sbom/config.py +++ b/buildscripts/sbom/config.py @@ -72,7 +72,9 @@ VERSION_PATTERN_REPL = [ # 'asio-1-34-2' pkg:github/chriskohlhoff/asio # 'cares-1_27_0' pkg:github/c-ares/c-ares [ - re.compile(rf"^[a-z]+-{RE_VER_NUM}[_-]{RE_VER_NUM}[_-]{RE_VER_NUM}{RE_VER_LBL}$"), + re.compile( + rf"^[a-z]+-{RE_VER_NUM}[_-]{RE_VER_NUM}[_-]{RE_VER_NUM}{RE_VER_LBL}$" + ), r"\1.\2.\3", ], # 'pcre2-10.40' pkg:github/pcre2project/pcre2 @@ -149,7 +151,9 @@ def process_component_special_cases( # "https://s3.amazonaws.com/boxes.10gen.com/build/windows_cyrus_sasl-2.1.28.zip", # Rather than add the complexity of Bazel queries to this script, we just search the text. - versions["import_script"] = get_version_sasl_from_workspace(repo_root + "/WORKSPACE.bazel") + versions["import_script"] = get_version_sasl_from_workspace( + repo_root + "/WORKSPACE.bazel" + ) logger.info( f"VERSION SPECIAL CASE: {component_key}: Found version '{versions['import_script']}' in 'WORKSPACE.bazel' file" ) @@ -169,4 +173,5 @@ def process_component_special_cases( f"VERSION SPECIAL CASE: {component_key}: Found version '{versions['import_script']}' in 'RELEASE_INFO' file" ) + # endregion special component use-case functions diff --git a/buildscripts/sbom/endorctl_utils.py b/buildscripts/sbom/endorctl_utils.py index 8de36797e64..783dcca028d 100644 --- a/buildscripts/sbom/endorctl_utils.py +++ b/buildscripts/sbom/endorctl_utils.py @@ -156,7 +156,12 @@ class EndorCtl: """https://docs.endorlabs.com/endorctl/""" try: - command = [self.endorctl_path, command, subcommand, f"--namespace={self.namespace}"] + command = [ + self.endorctl_path, + command, + subcommand, + f"--namespace={self.namespace}", + ] if self.config_path: command.append(f"--config-path={self.config_path}") @@ -198,12 +203,18 @@ class EndorCtl: tries = 0 while True: tries += 1 - result = self._call_endorctl("api", "list", resource=resource, filter=filter, **kwargs) + result = self._call_endorctl( + "api", "list", resource=resource, filter=filter, **kwargs + ) # The expected output of 'endorctl api list' is: { "list": { "objects": [...] } } # We want to just return the objects. In case we get an empty list, return a list # with a single None to avoid having to handle index errors downstream. - if result and result["list"].get("objects") and len(result["list"]["objects"]) > 0: + if ( + result + and result["list"].get("objects") + and len(result["list"]["objects"]) > 0 + ): return result["list"]["objects"] elif retry: logger.info( @@ -286,7 +297,9 @@ class EndorCtl: resource_description = ( f"{resource_kind} with filter '{filter}' in namespace '{self.namespace}'" ) - scan_result = self.get_resources(resource_kind, filter=filter, retry=retry, page_size=1)[0] + scan_result = self.get_resources( + resource_kind, filter=filter, retry=retry, page_size=1 + )[0] self._check_resource(scan_result, resource_description) uuid = scan_result.get("uuid") start_time = scan_result["spec"].get("start_time") @@ -402,7 +415,9 @@ class EndorCtl: ] # Export SBOM - sbom = self.export_sbom(package_version_uuids=package_version_uuids, app_name=app_name) + sbom = self.export_sbom( + package_version_uuids=package_version_uuids, app_name=app_name + ) print( f"Retrieved: CycloneDX SBOM for PackageVersion(s), name: {package_version_names}, uuid: {package_version_uuids}" ) @@ -429,7 +444,9 @@ class EndorCtl: repository_version_uuid = repository_version["uuid"] repository_version_ref = repository_version["spec"]["version"]["ref"] repository_version_sha = repository_version["spec"]["version"]["sha"] - repository_version_scan_object_status = repository_version["scan_object"]["status"] + repository_version_scan_object_status = repository_version["scan_object"][ + "status" + ] if repository_version_scan_object_status != "STATUS_SCANNED": logger.warning( f"RepositoryVersion (uuid: {repository_version_uuid}, ref: {repository_version_ref}, sha: {repository_version_sha}) scan status is '{repository_version_scan_object_status}' (expected 'STATUS_SCANNED')" @@ -437,7 +454,10 @@ class EndorCtl: # ScanResult: search for a completed scan filter_str = endor_filter.scan_result( - EndorContextType.MAIN, project_uuid, repository_version_ref, repository_version_sha + EndorContextType.MAIN, + project_uuid, + repository_version_ref, + repository_version_sha, ) scan_result = self.get_scan_result(filter_str, retry=False) project_uuid = scan_result["meta"]["parent_uuid"] @@ -449,13 +469,17 @@ class EndorCtl: else: context_type = EndorContextType.REF context_id = branch - filter_str = endor_filter.package_version(context_type, context_id, project_uuid) + filter_str = endor_filter.package_version( + context_type, context_id, project_uuid + ) package_version = self.get_package_versions(filter_str)[0] package_version_name = package_version["meta"]["name"] package_version_uuid = package_version["uuid"] # Export SBOM - sbom = self.export_sbom(package_version_uuid=package_version_uuid, app_name=app_name) + sbom = self.export_sbom( + package_version_uuid=package_version_uuid, app_name=app_name + ) logger.info( f"SBOM: Retrieved CycloneDX SBOM for PackageVersion, name: {package_version_name}, uuid {package_version_uuid}" ) diff --git a/buildscripts/sbom/generate_sbom.py b/buildscripts/sbom/generate_sbom.py index a0ea8fd5694..e601bd2a09a 100755 --- a/buildscripts/sbom/generate_sbom.py +++ b/buildscripts/sbom/generate_sbom.py @@ -60,7 +60,9 @@ script_directory = script_path.parent # Regex for validation REGEX_COMMIT_SHA = r"^[0-9a-fA-F]{40}$" REGEX_GIT_BRANCH = r"^[a-zA-Z0-9_.\-/]+$" -REGEX_GITHUB_URL = r"^(https://github.com/)([a-zA-Z0-9-]{1,39}/[a-zA-Z0-9-_.]{1,100})(\.git)$" +REGEX_GITHUB_URL = ( + r"^(https://github.com/)([a-zA-Z0-9-]{1,39}/[a-zA-Z0-9-_.]{1,100})(\.git)$" +) REGEX_RELEASE_BRANCH = r"^v\d\.\d$" REGEX_RELEASE_TAG = r"^r\d\.\d.\d(-\w*)?$" @@ -165,9 +167,13 @@ class GitInfo: # filter tags for latest release e.g., r8.2.1 release_tags = [] filtered_tags = [ - tag for tag in self._repo.tags if re.fullmatch(REGEX_RELEASE_TAG, tag.name) + tag + for tag in self._repo.tags + if re.fullmatch(REGEX_RELEASE_TAG, tag.name) ] - logging.info(f"GIT: Parsing {len(filtered_tags)} release tags for match to commit") + logging.info( + f"GIT: Parsing {len(filtered_tags)} release tags for match to commit" + ) for tag in filtered_tags: if tag.commit == self.commit: release_tags.append(tag.name) @@ -235,7 +241,9 @@ def is_valid_purl(purl: str) -> bool: """Validate a GitHub or Generic PURL""" for purl_type, regex in REGEX_PURL.items(): if regex.match(purl): - logger.debug(f"PURL: {purl} matched PURL type '{purl_type}' regex '{regex.pattern}'") + logger.debug( + f"PURL: {purl} matched PURL type '{purl_type}' regex '{regex.pattern}'" + ) return True return False @@ -245,7 +253,8 @@ def sbom_components_to_dict(sbom: dict, with_version: bool = False) -> dict: components = sbom["components"] if with_version: components_dict = { - urllib.parse.unquote(component["bom-ref"]): component for component in components + urllib.parse.unquote(component["bom-ref"]): component + for component in components } else: components_dict = { @@ -283,7 +292,9 @@ def read_sbom_json_file(file_path: str) -> dict: logger.error(f"Error loading SBOM file from {file_path}") logger.error(e) else: - logger.info(f"SBOM loaded from {file_path} with {len(result['components'])} components") + logger.info( + f"SBOM loaded from {file_path} with {len(result['components'])} components" + ) return result @@ -335,7 +346,9 @@ def set_component_version( component["cpe"] = component["cpe"].replace("{{VERSION}}", cpe_version) -def set_dependency_version(dependencies: list, meta_bom_ref: str, purl_version: str) -> None: +def set_dependency_version( + dependencies: list, meta_bom_ref: str, purl_version: str +) -> None: """Update the appropriate dependency version fields in the metadata SBOM""" r = 0 d = 0 @@ -350,7 +363,9 @@ def set_dependency_version(dependencies: list, meta_bom_ref: str, purl_version: ) d += 1 - logger.debug(f"set_dependency_version: '{meta_bom_ref}' updated {r} refs and {d} dependsOn") + logger.debug( + f"set_dependency_version: '{meta_bom_ref}' updated {r} refs and {d} dependsOn" + ) def get_subfolders_dict(folder_path: str = ".") -> dict: @@ -408,7 +423,10 @@ def del_component_priority_version_source(component: dict) -> None: # Reverse iterate properties list to safely modify in situ if "properties" in component: for i in range(len(component["properties"]) - 1, -1, -1): - if component["properties"][i].get("name") == "generate_sbom:priority_version_source": + if ( + component["properties"][i].get("name") + == "generate_sbom:priority_version_source" + ): logger.debug( f"PRIORITY VERSION SOURCE: {component['bom-ref']}: Removing priority version source from SBOM metadata." ) @@ -461,7 +479,9 @@ def main() -> None: type=str, ) endor.add_argument( - "--namespace", help="Endor Labs namespace (Default: mongodb.{git org})", type=str + "--namespace", + help="Endor Labs namespace (Default: mongodb.{git org})", + type=str, ) endor.add_argument( "--target", @@ -476,7 +496,9 @@ def main() -> None: type=str, ) - target = parser.add_argument_group("Target values. Apply only if --target is not 'project'") + target = parser.add_argument_group( + "Target values. Apply only if --target is not 'project'" + ) exclusive_target = target.add_mutually_exclusive_group() exclusive_target.add_argument( "--commit", @@ -526,7 +548,9 @@ def main() -> None: default=None, type=str, ) - parser.add_argument("--debug", help="Set logging level to DEBUG", action="store_true") + parser.add_argument( + "--debug", help="Set logging level to DEBUG", action="store_true" + ) # endregion define args @@ -589,7 +613,9 @@ def main() -> None: # region export Endor Labs SBOM print_banner(f"Exporting Endor Labs SBOM for {target} {getattr(git_info, target)}") - endorctl = EndorCtl(namespace, retry_limit, sleep_duration, endorctl_path, config_path) + endorctl = EndorCtl( + namespace, retry_limit, sleep_duration, endorctl_path, config_path + ) if target == "commit": endor_bom = endorctl.get_sbom_for_commit(git_info.project, git_info.commit) elif target == "branch": @@ -602,7 +628,9 @@ def main() -> None: if not endor_bom: logger.error("Empty result for Endor SBOM!") if target == "commit": - logger.error("Check Endor Labs for any unanticipated issues with the target PR scan.") + logger.error( + "Check Endor Labs for any unanticipated issues with the target PR scan." + ) else: logger.error("Check Endor Labs for status of the target monitoring scan.") sys.exit(1) @@ -633,7 +661,9 @@ def main() -> None: component["bom-ref"] = component["bom-ref"].replace(old, new) component["purl"] = component["purl"].replace(old, new) - logger.info(f"Endor Labs SBOM pre-processed with {len(endor_bom['components'])} components") + logger.info( + f"Endor Labs SBOM pre-processed with {len(endor_bom['components'])} components" + ) # endregion Pre-process Endor Labs SBOM @@ -735,7 +765,9 @@ def main() -> None: ) # Set main component version - set_component_version(meta_bom["metadata"]["component"], version, purl_version, cpe_version) + set_component_version( + meta_bom["metadata"]["component"], version, purl_version, cpe_version + ) # Run through 'dependency' objects to set main component version set_dependency_version(meta_bom["dependencies"], meta_bom_ref, purl_version) @@ -745,7 +777,9 @@ def main() -> None: # region Parse metadata SBOM components - third_party_folders = get_subfolders_dict(git_info.repo_root.as_posix() + "/src/third_party") + third_party_folders = get_subfolders_dict( + git_info.repo_root.as_posix() + "/src/third_party" + ) # pre-exclude 'scripts' folder del third_party_folders["scripts"] @@ -786,9 +820,13 @@ def main() -> None: if import_script_path: import_script = Path(import_script_path) if import_script.exists(): - versions["import_script"] = get_version_from_import_script(import_script_path) + versions["import_script"] = get_version_from_import_script( + import_script_path + ) if versions["import_script"]: - versions["import_script"] = versions["import_script"].replace("release-", "") + versions["import_script"] = versions["import_script"].replace( + "release-", "" + ) if versions["import_script"]: logger.debug( f"VERSION IMPORT SCRIPT: {component_key}: Found version '{versions['import_script']}' in import script '{import_script_path}'" @@ -839,13 +877,18 @@ def main() -> None: # For the standard workflow, we favor the pre-set priority version source, # followed by Endor Labs version, followed by import script, followed by hard coded - if versions["priority_version_source"] and versions["priority_version_source"] in versions: + if ( + versions["priority_version_source"] + and versions["priority_version_source"] in versions + ): version = versions[versions["priority_version_source"]] logger.info( f"VERSION: {component_key}: Using priority_version_source '{priority_version_source}' from metadata file." ) else: - version = versions["endor"] or versions["import_script"] or versions["metadata"] + version = ( + versions["endor"] or versions["import_script"] or versions["metadata"] + ) ############## Assign Version ############### if version: @@ -854,7 +897,9 @@ def main() -> None: ## Special case for FireFox ## # The CPE for FireFox ESR needs the 'esr' removed from the version, as it is specified in another section if component["bom-ref"].startswith("pkg:deb/debian/firefox-esr@"): - set_component_version(component, version, cpe_version=version.replace("esr", "")) + set_component_version( + component, version, cpe_version=version.replace("esr", "") + ) else: semver = get_semver_from_release_version(version) set_component_version(component, semver, version, semver) @@ -934,7 +979,8 @@ def main() -> None: # Have the SBOM app version changed? sbom_app_version_changed = ( - prev_bom["metadata"]["component"]["version"] != meta_bom["metadata"]["component"]["version"] + prev_bom["metadata"]["component"]["version"] + != meta_bom["metadata"]["component"]["version"] ) logger.info(f"SUMMARY: MongoDB version changed: {sbom_app_version_changed}") @@ -949,7 +995,9 @@ def main() -> None: # Components in prev SBOM but not in generated SBOM prev_components = sbom_components_to_dict(prev_bom, with_version=False) meta_components = sbom_components_to_dict(meta_bom, with_version=False) - prev_components_diff = list(set(prev_components.keys()) - set(meta_components.keys())) + prev_components_diff = list( + set(prev_components.keys()) - set(meta_components.keys()) + ) if prev_components_diff: logger.info( "SBOM_DIFF: Components in previous SBOM and not in generated SBOM: " @@ -957,7 +1005,9 @@ def main() -> None: ) # Components in generated SBOM but not in prev SBOM - meta_components_diff = list(set(meta_components.keys()) - set(prev_components.keys())) + meta_components_diff = list( + set(meta_components.keys()) - set(prev_components.keys()) + ) if meta_components_diff: logger.info( "SBOM_DIFF: Components in generated SBOM and not in previous SBOM: " @@ -982,7 +1032,9 @@ def main() -> None: # Only update the timestamp if something has changed if sbom_app_version_changed or sbom_components_changed: meta_bom["metadata"]["timestamp"] = ( - datetime.now(timezone.utc).isoformat(timespec="seconds").replace("+00:00", "Z") + datetime.now(timezone.utc) + .isoformat(timespec="seconds") + .replace("+00:00", "Z") ) else: meta_bom["metadata"]["timestamp"] = prev_bom["metadata"]["timestamp"] @@ -1005,7 +1057,9 @@ def main() -> None: print_banner("COMPLETED") if not os.getenv("CI"): - print("Be sure to add the SBOM to your next commit if the file content has changed.") + print( + "Be sure to add the SBOM to your next commit if the file content has changed." + ) # endregion Finalize SBOM diff --git a/buildscripts/sbom/sbom_files_pr.py b/buildscripts/sbom/sbom_files_pr.py index 34deda5e2ca..558ba676d3c 100644 --- a/buildscripts/sbom/sbom_files_pr.py +++ b/buildscripts/sbom/sbom_files_pr.py @@ -20,7 +20,9 @@ from github import ( SBOM_FILES = ["sbom.json", "README.third_party.md"] -def get_repository(github_owner, github_repo, app_id, _private_key) -> Repository.Repository: +def get_repository( + github_owner, github_repo, app_id, _private_key +) -> Repository.Repository: """ Gets the mongo github repository """ @@ -52,7 +54,8 @@ def create_branch(base_branch, new_branch) -> None: """ try: print( - f"Attempting to create branch '{new_branch}' with base branch '{base_branch}'.") + f"Attempting to create branch '{new_branch}' with base branch '{base_branch}'." + ) ref = f"refs/heads/{new_branch}" base_repo_branch = repo.get_branch(base_branch) sha = base_repo_branch.commit.sha @@ -83,16 +86,18 @@ if __name__ == "__main__": description="This script checks for changes to SBOM and related files and creats a PR if files have been updated.", ) parser.add_argument( - "--github-owner", help="GitHub org/owner (e.g., 10gen).", type=str) + "--github-owner", help="GitHub org/owner (e.g., 10gen).", type=str + ) parser.add_argument( - "--github-repo", help="GitHub repository name (e.g., mongo).", type=str) - parser.add_argument( - "--base-branch", help="base branch to merge into.", type=str) - parser.add_argument( - "--new-branch", help="New branch for the PR.", type=str) + "--github-repo", help="GitHub repository name (e.g., mongo).", type=str + ) + parser.add_argument("--base-branch", help="base branch to merge into.", type=str) + parser.add_argument("--new-branch", help="New branch for the PR.", type=str) parser.add_argument("--pr-title", help="Title for the PR.", type=str) parser.add_argument( - "--saved-warnings", help="Path to file to include as text in PR message.", type=str + "--saved-warnings", + help="Path to file to include as text in PR message.", + type=str, ) parser.add_argument( "--app-id", @@ -115,19 +120,20 @@ if __name__ == "__main__": # Replace spaces with newline, if applicable private_key = ( - args.private_key[:31] + args.private_key[31:- - 29].replace(" ", "\n") + args.private_key[-29:] + args.private_key[:31] + + args.private_key[31:-29].replace(" ", "\n") + + args.private_key[-29:] ) - repo = get_repository( - args.github_owner, args.github_repo, args.app_id, private_key) + repo = get_repository(args.github_owner, args.github_repo, args.app_id, private_key) print("repo: ", repo) HAS_UPDATE = False for file_path in SBOM_FILES: original_file = repo.get_contents( - file_path, ref=f"refs/heads/{args.base_branch}") + file_path, ref=f"refs/heads/{args.base_branch}" + ) print("original_file: ", original_file) original_content = original_file.decoded_content.decode() try: @@ -140,9 +146,9 @@ if __name__ == "__main__": PATTERN = r'{"name":"EndorLabsInc","version":".*"}' REPL = r'{"name":"EndorLabsInc","version":""}' original_content_compare = re.sub( - PATTERN, REPL, "".join(original_content.split())) - new_content_compare = re.sub( - PATTERN, REPL, "".join(new_content.split())) + PATTERN, REPL, "".join(original_content.split()) + ) + new_content_compare = re.sub(PATTERN, REPL, "".join(new_content.split())) if original_content_compare != new_content_compare: create_branch(args.base_branch, args.new_branch) @@ -178,7 +184,9 @@ if __name__ == "__main__": if HAS_UPDATE: # Get open PR or create new PR pull_requests = repo.get_pulls( - state="open", head=f"{args.github_owner}:{args.new_branch}", base=args.base_branch + state="open", + head=f"{args.github_owner}:{args.new_branch}", + base=args.base_branch, ) if pull_requests.totalCount: pull_request = pull_requests[0] @@ -200,7 +208,9 @@ if __name__ == "__main__": print("pull_request: ", pull_request) if args.saved_warnings: - pr_comment = "The following warnings were output by the SBOM generation script:\n" + pr_comment = ( + "The following warnings were output by the SBOM generation script:\n" + ) if os.path.isfile(args.saved_warnings): pr_comment += read_text_file(args.saved_warnings) comment = pull_request.create_issue_comment(pr_comment) diff --git a/buildscripts/sbom_linter.py b/buildscripts/sbom_linter.py index 4b3bc30141e..e4da91857a7 100644 --- a/buildscripts/sbom_linter.py +++ b/buildscripts/sbom_linter.py @@ -13,8 +13,12 @@ except ImportError: print("'jsonschema' not found. Continuing without it.") jsonschema = None -BOM_SCHEMA_LOCATION = os.path.join("buildscripts", "tests", "sbom_linter", "bom-1.5.schema.json") -SPDX_SCHEMA_LOCATION = os.path.join("buildscripts", "tests", "sbom_linter", "spdx.schema.json") +BOM_SCHEMA_LOCATION = os.path.join( + "buildscripts", "tests", "sbom_linter", "bom-1.5.schema.json" +) +SPDX_SCHEMA_LOCATION = os.path.join( + "buildscripts", "tests", "sbom_linter", "spdx.schema.json" +) SPDX_SCHEMA_REF = "spdx.schema.json" # directory to scan for third party libraries @@ -39,8 +43,12 @@ SCHEMA_MATCH_FAILURE = "File did not match the CycloneDX schema" MISSING_VERSION_IN_SBOM_COMPONENT_ERROR = "Component must include a version." MISSING_VERSION_IN_IMPORT_FILE_ERROR = "Missing version in the import file: " MISSING_LICENSE_IN_SBOM_COMPONENT_ERROR = "Component must include a license." -COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR = "Could not find or read the import script file" -VERSION_MISMATCH_ERROR = "Version mismatch (may simply be an artifact of SBOM automation): " +COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR = ( + "Could not find or read the import script file" +) +VERSION_MISMATCH_ERROR = ( + "Version mismatch (may simply be an artifact of SBOM automation): " +) # A class for managing error messages for components @@ -103,7 +111,9 @@ def get_script_version( try: file = open(script_path, "r") except OSError: - error_manager.append_full_error_message(COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR) + error_manager.append_full_error_message( + COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR + ) return result with file: @@ -141,20 +151,26 @@ def validate_license(component: dict, error_manager: ErrorManager) -> None: valid_license = True if not valid_license: - licensing_validate = get_spdx_licensing().validate(expression, validate=True) + licensing_validate = get_spdx_licensing().validate( + expression, validate=True + ) # ExpressionInfo( # original_expression='', # normalized_expression='', # errors=[], # invalid_symbols=[] # ) - valid_license = not licensing_validate.errors or not licensing_validate.invalid_symbols + valid_license = ( + not licensing_validate.errors or not licensing_validate.invalid_symbols + ) if not valid_license: error_manager.append_full_error_message(licensing_validate) return -def validate_evidence(component: dict, third_party_libs: set, error_manager: ErrorManager) -> None: +def validate_evidence( + component: dict, third_party_libs: set, error_manager: ErrorManager +) -> None: if component.get("scope") == "required": if "evidence" not in component or "occurrences" not in component["evidence"]: error_manager.append_full_error_message(MISSING_EVIDENCE_ERROR) @@ -179,7 +195,9 @@ def validate_properties(component: dict, error_manager: ErrorManager) -> None: if script_path: script_path_is_file = os.path.isfile(script_path) if not script_path_is_file: - error_manager.append_full_error_message(COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR) + error_manager.append_full_error_message( + COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR + ) # Only look for VERSION if the import script is a shell script file elif script_path.endswith(".sh"): script_version = get_script_version(script_path, "VERSION", error_manager) @@ -193,7 +211,9 @@ def validate_properties(component: dict, error_manager: ErrorManager) -> None: return -def validate_component(component: dict, third_party_libs: set, error_manager: ErrorManager) -> None: +def validate_component( + component: dict, third_party_libs: set, error_manager: ErrorManager +) -> None: error_manager.update_component_attribute(component["name"]) if "scope" not in component: error_manager.append_full_error_message("component must include a scope.") @@ -207,7 +227,9 @@ def validate_component(component: dict, third_party_libs: set, error_manager: Er error_manager.update_component_attribute("") -def validate_location(component: dict, third_party_libs: set, error_manager: ErrorManager) -> None: +def validate_location( + component: dict, third_party_libs: set, error_manager: ErrorManager +) -> None: if "evidence" in component: if "occurrences" not in component["evidence"]: error_manager.append_full_error_message( @@ -220,7 +242,9 @@ def validate_location(component: dict, third_party_libs: set, error_manager: Err location = occurrence["location"] if not os.path.exists(location) and not SKIP_FILE_CHECKING: - error_manager.append_full_error_message("location does not exist in repo.") + error_manager.append_full_error_message( + "location does not exist in repo." + ) if location.startswith(THIRD_PARTY_LOCATION_PREFIX): lib = location.removeprefix(THIRD_PARTY_LOCATION_PREFIX) @@ -283,7 +307,9 @@ def main() -> int: help="Whether to apply formatting to the output file.", ) parser.add_argument( - "--input-file", default="sbom.json", help="The input CycloneDX file to format and lint." + "--input-file", + default="sbom.json", + help="The input CycloneDX file to format and lint.", ) parser.add_argument( "--output-file", diff --git a/buildscripts/setup_multiversion_mongodb.py b/buildscripts/setup_multiversion_mongodb.py index 3debea3c537..e17801c01d3 100755 --- a/buildscripts/setup_multiversion_mongodb.py +++ b/buildscripts/setup_multiversion_mongodb.py @@ -9,4 +9,5 @@ if __name__ == "__main__": "\n" "The latter ensures the tool is in the global PATH. See installation instructions for `pipx`\n" "here if you don't have it:\n" - "https://github.com/pypa/pipx#on-linux-install-via-pip-requires-pip-190-or-later") + "https://github.com/pypa/pipx#on-linux-install-via-pip-requires-pip-190-or-later" + ) diff --git a/buildscripts/sign_macos_binaries_for_testing.py b/buildscripts/sign_macos_binaries_for_testing.py index 7775d6c606f..0217f4fc693 100644 --- a/buildscripts/sign_macos_binaries_for_testing.py +++ b/buildscripts/sign_macos_binaries_for_testing.py @@ -2,7 +2,7 @@ Signs all of the known testing binaries with insecure development entitlements. Specifically the `Get Task Allow` is what we are looking for. -Adding the `Get Task Allow` entitlement allows us to attach to +Adding the `Get Task Allow` entitlement allows us to attach to the mongo processes and get core dumps/debug in any way we need. You can view some more documentation on this topic here: https://developer.apple.com/documentation/bundleresources/entitlements/com_apple_security_cs_debugger#discussion @@ -25,7 +25,9 @@ def main(): build_bin_dir = os.path.join("build", "install", "bin") binary_directories = [MULTIVERSION_BIN_DIR, LOCAL_BIN_DIR, build_bin_dir] - entitlements_file = os.path.abspath(os.path.join("etc", "macos_dev_entitlements.xml")) + entitlements_file = os.path.abspath( + os.path.join("etc", "macos_dev_entitlements.xml") + ) assert os.path.exists(entitlements_file), f"{entitlements_file} does not exist" for binary_dir in binary_directories: @@ -38,11 +40,19 @@ def main(): continue print(f"Signing {binary}") - subprocess.run([ - "/usr/bin/codesign", "-s", "-", "-f", "--entitlements", entitlements_file, - binary_path - ], check=True) + subprocess.run( + [ + "/usr/bin/codesign", + "-s", + "-", + "-f", + "--entitlements", + entitlements_file, + binary_path, + ], + check=True, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/sync_repo_with_copybara.py b/buildscripts/sync_repo_with_copybara.py index 17451d8a98c..879fe861255 100644 --- a/buildscripts/sync_repo_with_copybara.py +++ b/buildscripts/sync_repo_with_copybara.py @@ -1,4 +1,5 @@ """Module for syncing a repo with Copybara and setting up configurations.""" + from __future__ import annotations import argparse @@ -108,8 +109,9 @@ def run_command(command): """ try: - return subprocess.run(command, shell=True, check=True, text=True, - capture_output=True).stdout + return subprocess.run( + command, shell=True, check=True, text=True, capture_output=True + ).stdout except subprocess.CalledProcessError as e: print(f"Error while executing: '{command}'.\n{e}\nStandard Error: {e.stderr}") raise @@ -126,14 +128,15 @@ def create_mongodb_bot_gitconfig(): gitconfig_path = os.path.expanduser("~/mongodb-bot.gitconfig") - with open(gitconfig_path, 'w') as file: + with open(gitconfig_path, "w") as file: file.write(content) print("mongodb-bot.gitconfig file created.") -def get_installation_access_token(app_id: int, private_key: str, - installation_id: int) -> Optional[str]: # noqa: D407,D413 +def get_installation_access_token( + app_id: int, private_key: str, installation_id: int +) -> Optional[str]: # noqa: D407,D413 """ Obtain an installation access token using JWT. @@ -166,7 +169,8 @@ def send_failure_message_to_slack(expansions): error_msg = ( "Evergreen task '* Copybara Sync Between Repos' failed\n" "See troubleshooting doc .\n" - f"See task log here: .") + f"See task log here: ." + ) evg_api = RetryingEvergreenApi.get_api(config_file=".evergreen.yml") evg_api.send_slack_message( @@ -186,13 +190,14 @@ def check_destination_branch_exists(copybara_config: CopybaraConfig) -> bool: - bool: `True` if the branch exists in the destination repository, `False` otherwise. """ - command = ( - f"git ls-remote {copybara_config.destination.git_url} {copybara_config.destination.branch}") + command = f"git ls-remote {copybara_config.destination.git_url} {copybara_config.destination.branch}" output = run_command(command) return copybara_config.destination.branch in output -def find_matching_commit(dir_source_repo: str, dir_destination_repo: str) -> Optional[str]: +def find_matching_commit( + dir_source_repo: str, dir_destination_repo: str +) -> Optional[str]: """ Finds a matching commit in the destination repository based on the commit hash from the source repository. @@ -245,8 +250,8 @@ def has_only_destination_repo_remote(repo_name: str): Returns bool: True if the repository only contains the destination repository remote URL, False otherwise. """ - git_config_path = os.path.join('.git', 'config') - with open(git_config_path, 'r') as f: + git_config_path = os.path.join(".git", "config") + with open(git_config_path, "r") as f: config_content = f.read() # Define a regular expression pattern to match the '{owner}/{repo}.git' @@ -261,8 +266,11 @@ def has_only_destination_repo_remote(repo_name: str): return False -def push_branch_to_destination_repo(destination_repo_dir: str, copybara_config: CopybaraConfig, - branching_off_commit: str): +def push_branch_to_destination_repo( + destination_repo_dir: str, + copybara_config: CopybaraConfig, + branching_off_commit: str, +): """ Pushes a new branch to the remote repository after ensuring it branches off the public repository. @@ -279,13 +287,16 @@ def push_branch_to_destination_repo(destination_repo_dir: str, copybara_config: # Check the current repo has only destination repository remote. if not has_only_destination_repo_remote(copybara_config.destination.repo_name): - raise Exception(f"{destination_repo_dir} git repo has not only the destination repo remote") + raise Exception( + f"{destination_repo_dir} git repo has not only the destination repo remote" + ) # Confirm the top commit is matching the found commit before pushing new_branch_top_commit = run_command('git log --pretty=format:"%H" -1') if not new_branch_top_commit == branching_off_commit: raise Exception( - "The new branch top commit does not match the branching_off_commit. Aborting push.") + "The new branch top commit does not match the branching_off_commit. Aborting push." + ) # Confirming whether the commit exists in the destination repository to ensure # we are not pushing anything that isn't already in the destination repository. @@ -294,7 +305,8 @@ def push_branch_to_destination_repo(destination_repo_dir: str, copybara_config: # Push the new branch to the destination repository run_command( - f"git push {copybara_config.destination.git_url} {copybara_config.destination.branch}") + f"git push {copybara_config.destination.git_url} {copybara_config.destination.branch}" + ) def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None: @@ -311,7 +323,9 @@ def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None: try: # Create a unique directory based on the current timestamp. working_dir = os.path.join( - original_dir, "make_branch_attempt_" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S")) + original_dir, + "make_branch_attempt_" + datetime.now().strftime("%Y_%m_%d_%H_%M_%S"), + ) os.makedirs(working_dir, exist_ok=True) os.chdir(working_dir) @@ -319,13 +333,18 @@ def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None: cloned_source_repo_dir = os.path.join(working_dir, "source-repo") cloned_destination_repo_dir = os.path.join(working_dir, "destination-repo") - run_command(f"git clone -b {copybara_config.source.branch}" - f" {copybara_config.source.git_url} {cloned_source_repo_dir}") run_command( - f"git clone {copybara_config.destination.git_url} {cloned_destination_repo_dir}") + f"git clone -b {copybara_config.source.branch}" + f" {copybara_config.source.git_url} {cloned_source_repo_dir}" + ) + run_command( + f"git clone {copybara_config.destination.git_url} {cloned_destination_repo_dir}" + ) # Find matching commits to branching off - commit = find_matching_commit(cloned_source_repo_dir, cloned_destination_repo_dir) + commit = find_matching_commit( + cloned_source_repo_dir, cloned_destination_repo_dir + ) if commit is not None: # Delete the cloned_source_repo_dir folder shutil.rmtree(cloned_source_repo_dir) @@ -334,10 +353,14 @@ def create_branch_from_matching_commit(copybara_config: CopybaraConfig) -> None: # Once a matching commit is found, create a new branch based on it. os.chdir(cloned_destination_repo_dir) - run_command(f"git checkout -b {copybara_config.destination.branch} {commit}") + run_command( + f"git checkout -b {copybara_config.destination.branch} {commit}" + ) # Push the new branch to the remote repository - push_branch_to_destination_repo(cloned_destination_repo_dir, copybara_config, commit) + push_branch_to_destination_repo( + cloned_destination_repo_dir, copybara_config, commit + ) else: print( f"Could not find matching commits between {copybara_config.destination.repo_name}/master" @@ -356,13 +379,17 @@ def main(): """Clone the Copybara repo, build its Docker image, and set up and run migrations.""" parser = argparse.ArgumentParser() - parser.add_argument("--expansions-file", "-e", default="../expansions.yml", - help="Location of expansions file generated by evergreen.") + parser.add_argument( + "--expansions-file", + "-e", + default="../expansions.yml", + help="Location of expansions file generated by evergreen.", + ) args = parser.parse_args() # Check if the copybara directory already exists - if os.path.exists('copybara'): + if os.path.exists("copybara"): print("Copybara directory already exists.") else: run_command("git clone https://github.com/10gen/copybara.git") @@ -430,11 +457,15 @@ def main(): else: if not check_destination_branch_exists(copybara_config): create_branch_from_matching_commit(copybara_config) - print(f"New branch named '{copybara_config.destination.branch}' has been created" - f" for the '{copybara_config.destination.repo_name}' repo") + print( + f"New branch named '{copybara_config.destination.branch}' has been created" + f" for the '{copybara_config.destination.repo_name}' repo" + ) else: - print(f"The branch named '{copybara_config.destination.branch}' already exists" - f" in the '{copybara_config.destination.repo_name}' repo.") + print( + f"The branch named '{copybara_config.destination.branch}' already exists" + f" in the '{copybara_config.destination.repo_name}' repo." + ) # Set up the Docker command and execute it docker_cmd = [ @@ -461,8 +492,10 @@ def main(): "Updates were rejected because the remote contains work that you do", ] - if any(acceptable_message in error_message - for acceptable_message in acceptable_error_messages): + if any( + acceptable_message in error_message + for acceptable_message in acceptable_error_messages + ): return # Send a failure message to #devprod-build-automation if the Copybara sync task fails. diff --git a/buildscripts/tests/burn_in_tests_end2end/test_burn_in_tests_end2end.py b/buildscripts/tests/burn_in_tests_end2end/test_burn_in_tests_end2end.py index d82d9db3528..7402d171c81 100644 --- a/buildscripts/tests/burn_in_tests_end2end/test_burn_in_tests_end2end.py +++ b/buildscripts/tests/burn_in_tests_end2end/test_burn_in_tests_end2end.py @@ -11,11 +11,13 @@ import buildscripts.burn_in_tests as under_test class TestBurnInTestsEnd2End(unittest.TestCase): @classmethod def setUpClass(cls): - subprocess.run([ - sys.executable, - "buildscripts/burn_in_tests.py", - "generate-test-membership-map-file-for-ci", - ]) + subprocess.run( + [ + sys.executable, + "buildscripts/burn_in_tests.py", + "generate-test-membership-map-file-for-ci", + ] + ) @classmethod def tearDownClass(cls): @@ -23,12 +25,16 @@ class TestBurnInTestsEnd2End(unittest.TestCase): os.remove(under_test.BURN_IN_TEST_MEMBERSHIP_FILE) def test_valid_yaml_output(self): - process = subprocess.run([ - sys.executable, - "buildscripts/burn_in_tests.py", - "run", - "--yaml", - ], text=True, capture_output=True) + process = subprocess.run( + [ + sys.executable, + "buildscripts/burn_in_tests.py", + "run", + "--yaml", + ], + text=True, + capture_output=True, + ) output = process.stdout self.assertEqual(0, process.returncode) diff --git a/buildscripts/tests/ciconfig/test_evergreen.py b/buildscripts/tests/ciconfig/test_evergreen.py index 3a942aa01c4..ecc592ef779 100644 --- a/buildscripts/tests/ciconfig/test_evergreen.py +++ b/buildscripts/tests/ciconfig/test_evergreen.py @@ -15,7 +15,9 @@ class TestEvergreenProjectConfig(unittest.TestCase): @classmethod def setUpClass(cls): env = os.environ.copy() - cls.conf = _evergreen.parse_evergreen_file(TEST_FILE_PATH, evergreen_binary=None) + cls.conf = _evergreen.parse_evergreen_file( + TEST_FILE_PATH, evergreen_binary=None + ) # Assert there is no leakage of env variables from this function assert env == os.environ @@ -88,39 +90,51 @@ class TestTask(unittest.TestCase): def test_suite_to_resmoke_args_map_for_non_gen_task(self): suite_and_task = "jstestfuzz" - task_commands = [{ - "func": "run tests", - "vars": {"resmoke_args": "--arg=val"}, - }] + task_commands = [ + { + "func": "run tests", + "vars": {"resmoke_args": "--arg=val"}, + } + ] task_dict = {"name": suite_and_task, "commands": task_commands} task = _evergreen.Task(task_dict) - self.assertEqual({suite_and_task: f"--suites={suite_and_task} --arg=val"}, - task.suite_to_resmoke_args_map) + self.assertEqual( + {suite_and_task: f"--suites={suite_and_task} --arg=val"}, + task.suite_to_resmoke_args_map, + ) def test_suite_to_resmoke_args_map_for_gen_task(self): suite = "jsCore" - task_commands = [{ - "func": "generate resmoke tasks", - "vars": {"resmoke_args": "--installDir=/bin"}, - }] + task_commands = [ + { + "func": "generate resmoke tasks", + "vars": {"resmoke_args": "--installDir=/bin"}, + } + ] task_dict = {"name": f"{suite}_gen", "commands": task_commands} task = _evergreen.Task(task_dict) - self.assertEqual({suite: f"--suites={suite} --installDir=/bin"}, - task.suite_to_resmoke_args_map) + self.assertEqual( + {suite: f"--suites={suite} --installDir=/bin"}, + task.suite_to_resmoke_args_map, + ) def test_suite_to_resmoke_args_map_for_gen_task_with_suite(self): suite = "core" - task_commands = [{ - "func": "generate resmoke tasks", - "vars": {"suite": suite, "resmoke_args": "--installDir=/bin"}, - }] + task_commands = [ + { + "func": "generate resmoke tasks", + "vars": {"suite": suite, "resmoke_args": "--installDir=/bin"}, + } + ] task_dict = {"name": "jsCore", "commands": task_commands} task = _evergreen.Task(task_dict) - self.assertEqual({suite: f"--suites={suite} --installDir=/bin"}, - task.suite_to_resmoke_args_map) + self.assertEqual( + {suite: f"--suites={suite} --installDir=/bin"}, + task.suite_to_resmoke_args_map, + ) def test_suite_to_resmoke_args_map_for_initialize_multiversion_tasks_task(self): task_commands = [ @@ -143,23 +157,22 @@ class TestTask(unittest.TestCase): task_dict = {"name": "multiversion_sanity_check_gen", "commands": task_commands} task = _evergreen.Task(task_dict) - self.assertEqual({ - "multiversion_sanity_check_last_continuous_new_new_old": - "--suites=multiversion_sanity_check_last_continuous_new_new_old --installDir=/bin", - "multiversion_sanity_check_last_continuous_new_old_new": - "--suites=multiversion_sanity_check_last_continuous_new_old_new --installDir=/bin", - "multiversion_sanity_check_last_continuous_old_new_new": - "--suites=multiversion_sanity_check_last_continuous_old_new_new --installDir=/bin", - "multiversion_sanity_check_last_lts_new_new_old": - "--suites=multiversion_sanity_check_last_lts_new_new_old --installDir=/bin", - "multiversion_sanity_check_last_lts_new_old_new": - "--suites=multiversion_sanity_check_last_lts_new_old_new --installDir=/bin", - "multiversion_sanity_check_last_lts_old_new_new": - "--suites=multiversion_sanity_check_last_lts_old_new_new --installDir=/bin", - }, task.suite_to_resmoke_args_map) + self.assertEqual( + { + "multiversion_sanity_check_last_continuous_new_new_old": "--suites=multiversion_sanity_check_last_continuous_new_new_old --installDir=/bin", + "multiversion_sanity_check_last_continuous_new_old_new": "--suites=multiversion_sanity_check_last_continuous_new_old_new --installDir=/bin", + "multiversion_sanity_check_last_continuous_old_new_new": "--suites=multiversion_sanity_check_last_continuous_old_new_new --installDir=/bin", + "multiversion_sanity_check_last_lts_new_new_old": "--suites=multiversion_sanity_check_last_lts_new_new_old --installDir=/bin", + "multiversion_sanity_check_last_lts_new_old_new": "--suites=multiversion_sanity_check_last_lts_new_old_new --installDir=/bin", + "multiversion_sanity_check_last_lts_old_new_new": "--suites=multiversion_sanity_check_last_lts_old_new_new --installDir=/bin", + }, + task.suite_to_resmoke_args_map, + ) def test_is_run_tests_task(self): - task_commands = [{"func": "run tests", "vars": {"resmoke_args": "--suites=core"}}] + task_commands = [ + {"func": "run tests", "vars": {"resmoke_args": "--suites=core"}} + ] task_dict = {"name": "jsCore", "commands": task_commands} task = _evergreen.Task(task_dict) @@ -168,7 +181,9 @@ class TestTask(unittest.TestCase): self.assertFalse(task.is_initialize_multiversion_tasks_task) def test_run_tests_command(self): - task_commands = [{"func": "run tests", "vars": {"resmoke_args": "--suites=core"}}] + task_commands = [ + {"func": "run tests", "vars": {"resmoke_args": "--suites=core"}} + ] task_dict = {"name": "jsCore", "commands": task_commands} task = _evergreen.Task(task_dict) @@ -176,10 +191,12 @@ class TestTask(unittest.TestCase): def test_is_generate_resmoke_task(self): task_name = "core" - task_commands = [{ - "func": "generate resmoke tasks", - "vars": {"task": task_name, "resmoke_args": "--installDir=/bin"} - }] + task_commands = [ + { + "func": "generate resmoke tasks", + "vars": {"task": task_name, "resmoke_args": "--installDir=/bin"}, + } + ] task_dict = {"name": "jsCore", "commands": task_commands} task = _evergreen.Task(task_dict) @@ -188,9 +205,12 @@ class TestTask(unittest.TestCase): self.assertFalse(task.is_initialize_multiversion_tasks_task) def test_generate_resmoke_tasks_command(self): - task_commands = [{ - "func": "generate resmoke tasks", "vars": {"resmoke_args": "--installDir=/bin"} - }] + task_commands = [ + { + "func": "generate resmoke tasks", + "vars": {"resmoke_args": "--installDir=/bin"}, + } + ] task_dict = {"name": "jsCore_gen", "commands": task_commands} task = _evergreen.Task(task_dict) @@ -212,10 +232,12 @@ class TestTask(unittest.TestCase): }, {"func": "generate resmoke tasks"}, ] - task = _evergreen.Task({ - "name": "multiversion_sanity_check_gen", - "commands": task_commands, - }) + task = _evergreen.Task( + { + "name": "multiversion_sanity_check_gen", + "commands": task_commands, + } + ) self.assertTrue(task.is_initialize_multiversion_tasks_task) self.assertTrue(task.is_generate_resmoke_task) @@ -236,12 +258,16 @@ class TestTask(unittest.TestCase): }, {"func": "generate resmoke tasks"}, ] - task = _evergreen.Task({ - "name": "multiversion_sanity_check_gen", - "commands": task_commands, - }) + task = _evergreen.Task( + { + "name": "multiversion_sanity_check_gen", + "commands": task_commands, + } + ) - self.assertDictEqual(task_commands[0], task.initialize_multiversion_tasks_command) + self.assertDictEqual( + task_commands[0], task.initialize_multiversion_tasks_command + ) self.assertEqual("multiversion_sanity_check", task.generated_task_name) def test_get_resmoke_command_vars_from_run_tests_command(self): @@ -254,7 +280,9 @@ class TestTask(unittest.TestCase): def test_get_resmoke_command_vars_from_generate_resmoke_tasks_command(self): resmoke_command_vars = {"suite": "core"} - task_commands = [{"func": "generate resmoke tasks", "vars": resmoke_command_vars}] + task_commands = [ + {"func": "generate resmoke tasks", "vars": resmoke_command_vars} + ] task_dict = {"name": "jsCore", "commands": task_commands} task = _evergreen.Task(task_dict) @@ -297,10 +325,12 @@ class TestTask(unittest.TestCase): def test_generate_resmoke_tasks_command_with_suite(self): task_name = "jsCore_gen" suite_name = "core" - task_commands = [{ - "func": "generate resmoke tasks", - "vars": {"suite": suite_name, "resmoke_args": "--installDir=/bin"} - }] + task_commands = [ + { + "func": "generate resmoke tasks", + "vars": {"suite": suite_name, "resmoke_args": "--installDir=/bin"}, + } + ] task_dict = {"name": task_name, "commands": task_commands} task = _evergreen.Task(task_dict) @@ -308,60 +338,75 @@ class TestTask(unittest.TestCase): self.assertEqual("jsCore", task.generated_task_name) def test_get_suite_names_from_non_gen_task_name(self): - task = _evergreen.Task({ - "name": "task_name", - "commands": [{"func": "run tests"}], - }) + task = _evergreen.Task( + { + "name": "task_name", + "commands": [{"func": "run tests"}], + } + ) self.assertEqual(["task_name"], task.get_suite_names()) def test_get_suite_names_from_non_gen_task_suite_var(self): - task = _evergreen.Task({ - "name": "task_name", - "commands": [{ - "func": "run tests", - "vars": {"suite": "suite_var"}, - }], - }) + task = _evergreen.Task( + { + "name": "task_name", + "commands": [ + { + "func": "run tests", + "vars": {"suite": "suite_var"}, + } + ], + } + ) self.assertEqual(["suite_var"], task.get_suite_names()) def test_get_suite_names_from_gen_task_name(self): - task = _evergreen.Task({ - "name": "task_name_gen", - "commands": [{"func": "generate resmoke tasks"}], - }) + task = _evergreen.Task( + { + "name": "task_name_gen", + "commands": [{"func": "generate resmoke tasks"}], + } + ) self.assertEqual(["task_name"], task.get_suite_names()) def test_get_suite_names_from_gen_task_suite_var(self): - task = _evergreen.Task({ - "name": "task_name_gen", - "commands": [{ - "func": "generate resmoke tasks", - "vars": {"suite": "suite_var"}, - }], - }) + task = _evergreen.Task( + { + "name": "task_name_gen", + "commands": [ + { + "func": "generate resmoke tasks", + "vars": {"suite": "suite_var"}, + } + ], + } + ) self.assertEqual(["suite_var"], task.get_suite_names()) def test_get_suite_names_from_init_multiversion_task(self): - task = _evergreen.Task({ - "name": - "task_name_multiversion_gen", - "commands": [ - { - "func": "initialize multiversion tasks", - "vars": { - "suite_last_continuous": "last_continuous", - "suite_last_lts": "last_lts", + task = _evergreen.Task( + { + "name": "task_name_multiversion_gen", + "commands": [ + { + "func": "initialize multiversion tasks", + "vars": { + "suite_last_continuous": "last_continuous", + "suite_last_lts": "last_lts", + }, }, - }, - {"func": "generate resmoke tasks"}, - ], - }) + {"func": "generate resmoke tasks"}, + ], + } + ) - self.assertEqual(["suite_last_continuous", "suite_last_lts"], task.get_suite_names()) + self.assertEqual( + ["suite_last_continuous", "suite_last_lts"], task.get_suite_names() + ) def test_generate_task_name_non_gen_tasks(self): task_name = "jsCore" @@ -386,8 +431,14 @@ class TestTaskGroup(unittest.TestCase): def test_from_list(self): task_group_dict = { - "name": "my_group", "max_hosts": 3, "tasks": ["task1", "task2"], "setup_task": [], - "teardown_task": [], "setup_group": [], "teardown_group": [], "timeout": [] + "name": "my_group", + "max_hosts": 3, + "tasks": ["task1", "task2"], + "setup_task": [], + "teardown_task": [], + "setup_group": [], + "teardown_group": [], + "timeout": [], } task_group = _evergreen.TaskGroup(task_group_dict) @@ -401,7 +452,9 @@ class TestVariant(unittest.TestCase): @classmethod def setUpClass(cls): - cls.conf = _evergreen.parse_evergreen_file(TEST_FILE_PATH, evergreen_binary=None) + cls.conf = _evergreen.parse_evergreen_file( + TEST_FILE_PATH, evergreen_binary=None + ) def test_from_dict(self): task = _evergreen.Task({"name": "compile"}) @@ -447,12 +500,16 @@ class TestVariant(unittest.TestCase): def test_expansion(self): variant_ubuntu = self.conf.get_variant("ubuntu") - self.assertEqual("--param=value --ubuntu", variant_ubuntu.expansion("test_flags")) + self.assertEqual( + "--param=value --ubuntu", variant_ubuntu.expansion("test_flags") + ) self.assertEqual(None, variant_ubuntu.expansion("not_a_valid_expansion_name")) def test_expansions(self): variant_ubuntu = self.conf.get_variant("ubuntu") - self.assertEqual({"test_flags": "--param=value --ubuntu"}, variant_ubuntu.expansions) + self.assertEqual( + {"test_flags": "--param=value --ubuntu"}, variant_ubuntu.expansions + ) def test_modules(self): variant_ubuntu = self.conf.get_variant("ubuntu") @@ -493,7 +550,11 @@ class TestVariant(unittest.TestCase): variant_ubuntu = self.conf.get_variant("ubuntu") self.assertEqual(5, len(variant_ubuntu.tasks)) for task_name in [ - "compile", "passing_test", "failing_test", "timeout_test", "resmoke_task" + "compile", + "passing_test", + "failing_test", + "timeout_test", + "resmoke_task", ]: task = variant_ubuntu.get_task(task_name) self.assertIsNotNone(task) @@ -502,39 +563,42 @@ class TestVariant(unittest.TestCase): # Check combined_suite_to_resmoke_args_map when test_flags is set on the variant. resmoke_task = variant_ubuntu.get_task("resmoke_task") - self.assertEqual({ - "resmoke_task": - "--suites=resmoke_task --storageEngine=wiredTiger --param=value --ubuntu" - }, resmoke_task.combined_suite_to_resmoke_args_map) + self.assertEqual( + { + "resmoke_task": "--suites=resmoke_task --storageEngine=wiredTiger --param=value --ubuntu" + }, + resmoke_task.combined_suite_to_resmoke_args_map, + ) # Check combined_suite_to_resmoke_args_map when the task doesn't have resmoke_args. passing_task = variant_ubuntu.get_task("passing_test") - self.assertEqual({"passing_test": "--suites=passing_test --param=value --ubuntu"}, - passing_task.combined_suite_to_resmoke_args_map) + self.assertEqual( + {"passing_test": "--suites=passing_test --param=value --ubuntu"}, + passing_task.combined_suite_to_resmoke_args_map, + ) # Check combined_suite_to_resmoke_args_map when test_flags is not set on the variant. variant_debian = self.conf.get_variant("debian") resmoke_task = variant_debian.get_task("resmoke_task") - self.assertEqual({"resmoke_task": "--suites=resmoke_task --storageEngine=wiredTiger"}, - resmoke_task.combined_suite_to_resmoke_args_map) + self.assertEqual( + {"resmoke_task": "--suites=resmoke_task --storageEngine=wiredTiger"}, + resmoke_task.combined_suite_to_resmoke_args_map, + ) # Check combined_suite_to_resmoke_args_map for "initialize multiversion tasks" task. variant_debian = self.conf.get_variant("debian") resmoke_task = variant_debian.get_task("resmoke_multiversion_task_gen") - self.assertEqual({ - "multiversion_sanity_check_last_continuous_new_new_old": - "--suites=multiversion_sanity_check_last_continuous_new_new_old --storageEngine=wiredTiger", - "multiversion_sanity_check_last_continuous_new_old_new": - "--suites=multiversion_sanity_check_last_continuous_new_old_new --storageEngine=wiredTiger", - "multiversion_sanity_check_last_continuous_old_new_new": - "--suites=multiversion_sanity_check_last_continuous_old_new_new --storageEngine=wiredTiger", - "multiversion_sanity_check_last_lts_new_new_old": - "--suites=multiversion_sanity_check_last_lts_new_new_old --storageEngine=wiredTiger", - "multiversion_sanity_check_last_lts_new_old_new": - "--suites=multiversion_sanity_check_last_lts_new_old_new --storageEngine=wiredTiger", - "multiversion_sanity_check_last_lts_old_new_new": - "--suites=multiversion_sanity_check_last_lts_old_new_new --storageEngine=wiredTiger", - }, resmoke_task.combined_suite_to_resmoke_args_map) + self.assertEqual( + { + "multiversion_sanity_check_last_continuous_new_new_old": "--suites=multiversion_sanity_check_last_continuous_new_new_old --storageEngine=wiredTiger", + "multiversion_sanity_check_last_continuous_new_old_new": "--suites=multiversion_sanity_check_last_continuous_new_old_new --storageEngine=wiredTiger", + "multiversion_sanity_check_last_continuous_old_new_new": "--suites=multiversion_sanity_check_last_continuous_old_new_new --storageEngine=wiredTiger", + "multiversion_sanity_check_last_lts_new_new_old": "--suites=multiversion_sanity_check_last_lts_new_new_old --storageEngine=wiredTiger", + "multiversion_sanity_check_last_lts_new_old_new": "--suites=multiversion_sanity_check_last_lts_new_old_new --storageEngine=wiredTiger", + "multiversion_sanity_check_last_lts_old_new_new": "--suites=multiversion_sanity_check_last_lts_old_new_new --storageEngine=wiredTiger", + }, + resmoke_task.combined_suite_to_resmoke_args_map, + ) # Check for tasks included in task_groups variant_amazon = self.conf.get_variant("amazon") diff --git a/buildscripts/tests/patch_builds/test_change_data.py b/buildscripts/tests/patch_builds/test_change_data.py index 35551fafbeb..098d48674c5 100644 --- a/buildscripts/tests/patch_builds/test_change_data.py +++ b/buildscripts/tests/patch_builds/test_change_data.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts.patch_builds.change_data.py.""" + from __future__ import absolute_import import os @@ -38,28 +39,43 @@ class TestFindChangedFilesInRepos(unittest.TestCase): second_repo_file_changes = [ os.path.join("jstests", "test2.js"), ] - changed_files_mock.side_effect = [first_repo_file_changes, second_repo_file_changes] + changed_files_mock.side_effect = [ + first_repo_file_changes, + second_repo_file_changes, + ] self.assertEqual(3, len(under_test.find_changed_files_in_repos(repos_mock))) class TestGenerateRevisionMap(unittest.TestCase): def test_mongo_revisions_is_mapped_correctly(self): - mock_repo_list = [create_mock_repo(os.getcwd()), create_mock_repo("/path/to/enterprise")] + mock_repo_list = [ + create_mock_repo(os.getcwd()), + create_mock_repo("/path/to/enterprise"), + ] revision_data = {"mongo": "revision1234", "enterprise": "revision5678"} revision_map = under_test.generate_revision_map(mock_repo_list, revision_data) - self.assertEqual(revision_map[mock_repo_list[0].git_dir], revision_data["mongo"]) - self.assertEqual(revision_map[mock_repo_list[1].git_dir], revision_data["enterprise"]) + self.assertEqual( + revision_map[mock_repo_list[0].git_dir], revision_data["mongo"] + ) + self.assertEqual( + revision_map[mock_repo_list[1].git_dir], revision_data["enterprise"] + ) def test_missing_revisions_are_not_returned(self): - mock_repo_list = [create_mock_repo(os.getcwd()), create_mock_repo("/path/to/enterprise")] + mock_repo_list = [ + create_mock_repo(os.getcwd()), + create_mock_repo("/path/to/enterprise"), + ] revision_data = {"mongo": "revision1234"} revision_map = under_test.generate_revision_map(mock_repo_list, revision_data) - self.assertEqual(revision_map[mock_repo_list[0].git_dir], revision_data["mongo"]) + self.assertEqual( + revision_map[mock_repo_list[0].git_dir], revision_data["mongo"] + ) self.assertEqual(len(revision_map), 1) def test_missing_repos_are_not_returned(self): @@ -68,5 +84,7 @@ class TestGenerateRevisionMap(unittest.TestCase): revision_map = under_test.generate_revision_map(mock_repo_list, revision_data) - self.assertEqual(revision_map[mock_repo_list[0].git_dir], revision_data["mongo"]) + self.assertEqual( + revision_map[mock_repo_list[0].git_dir], revision_data["mongo"] + ) self.assertEqual(len(revision_map), 1) diff --git a/buildscripts/tests/resmoke_end2end/test_resmoke.py b/buildscripts/tests/resmoke_end2end/test_resmoke.py index 9147c3b051f..2fe29180f9d 100644 --- a/buildscripts/tests/resmoke_end2end/test_resmoke.py +++ b/buildscripts/tests/resmoke_end2end/test_resmoke.py @@ -51,7 +51,10 @@ class _ResmokeSelftest(unittest.TestCase): def execute_resmoke(self, resmoke_args, **kwargs): # pylint: disable=unused-argument resmoke_process = core.programs.make_process( self.logger, - [sys.executable, "buildscripts/resmoke.py"] + self.resmoke_const_args + resmoke_args) + [sys.executable, "buildscripts/resmoke.py"] + + self.resmoke_const_args + + resmoke_args, + ) resmoke_process.start() self.resmoke_process = resmoke_process @@ -88,7 +91,9 @@ class TestArchivalOnFailure(_ResmokeSelftest): # test archival archival_dirs_to_expect = 4 # 2 tests * 2 nodes - self.assert_dir_file_count(self.test_dir, self.archival_file, archival_dirs_to_expect) + self.assert_dir_file_count( + self.test_dir, self.archival_file, archival_dirs_to_expect + ) def test_archival_on_task_failure_no_passthrough(self): # The --originSuite argument is to trick the resmoke local invocation into passing @@ -108,7 +113,9 @@ class TestArchivalOnFailure(_ResmokeSelftest): # test archival archival_dirs_to_expect = 8 # (2 tests + 2 stacktrace files) * 2 nodes - self.assert_dir_file_count(self.test_dir, self.archival_file, archival_dirs_to_expect) + self.assert_dir_file_count( + self.test_dir, self.archival_file, archival_dirs_to_expect + ) def test_no_archival_locally(self): # archival should not happen if --taskId is not set. @@ -140,10 +147,14 @@ class TestTimeout(_ResmokeSelftest): rmtree(self.test_dir_inner, ignore_errors=True) def signal_resmoke(self): - hang_analyzer_options = f"-o=file -o=stdout -m=contains -p=python -d={self.resmoke_process.pid}" + hang_analyzer_options = ( + f"-o=file -o=stdout -m=contains -p=python -d={self.resmoke_process.pid}" + ) signal_resmoke_process = core.programs.make_process( - self.logger, [sys.executable, "buildscripts/resmoke.py", "hang-analyzer" - ] + hang_analyzer_options.split()) + self.logger, + [sys.executable, "buildscripts/resmoke.py", "hang-analyzer"] + + hang_analyzer_options.split(), + ) signal_resmoke_process.start() # Wait for resmoke_process to be killed by 'run-timeout' so this doesn't hang. @@ -170,7 +181,9 @@ class TestTimeout(_ResmokeSelftest): started_polling_datetime = datetime.datetime.now() while not os.path.isfile(sentinel_path): time.sleep(0.1) - if datetime.datetime.now() - started_polling_datetime > datetime.timedelta(minutes=5): + if datetime.datetime.now() - started_polling_datetime > datetime.timedelta( + minutes=5 + ): self.fail("SUT is not available within 99 seconds; aborting test") # Kill resmoke: @@ -193,10 +206,14 @@ class TestTimeout(_ResmokeSelftest): self.execute_resmoke(resmoke_args, sentinel_file="timeout0") archival_dirs_to_expect = 4 # 2 tests * 2 mongod - self.assert_dir_file_count(self.test_dir, self.archival_file, archival_dirs_to_expect) + self.assert_dir_file_count( + self.test_dir, self.archival_file, archival_dirs_to_expect + ) analysis_pids_to_expect = 6 # 2 tests * (2 mongod + 1 mongo) - self.assert_dir_file_count(self.test_dir, self.analysis_file, analysis_pids_to_expect) + self.assert_dir_file_count( + self.test_dir, self.analysis_file, analysis_pids_to_expect + ) def test_task_timeout_no_passthrough(self): # The --originSuite argument is to trick the resmoke local invocation into passing @@ -215,10 +232,14 @@ class TestTimeout(_ResmokeSelftest): self.execute_resmoke(resmoke_args, sentinel_file="timeout1") archival_dirs_to_expect = 8 # (2 tests + 2 stacktrace files) * 2 nodes - self.assert_dir_file_count(self.test_dir, self.archival_file, archival_dirs_to_expect) + self.assert_dir_file_count( + self.test_dir, self.archival_file, archival_dirs_to_expect + ) analysis_pids_to_expect = 6 # 2 tests * (2 mongod + 1 mongo) - self.assert_dir_file_count(self.test_dir, self.analysis_file, analysis_pids_to_expect) + self.assert_dir_file_count( + self.test_dir, self.analysis_file, analysis_pids_to_expect + ) # Test scenarios where an resmoke-launched process launches resmoke. def test_nested_timeout(self): @@ -237,12 +258,20 @@ class TestTimeout(_ResmokeSelftest): self.execute_resmoke(resmoke_args, sentinel_file="inner_level_timeout") - archival_dirs_to_expect = 4 # ((2 tests + 2 stacktrace files) * 2 nodes) / 2 data_file directories - self.assert_dir_file_count(self.test_dir, self.archival_file, archival_dirs_to_expect) - self.assert_dir_file_count(self.test_dir_inner, self.archival_file, archival_dirs_to_expect) + archival_dirs_to_expect = ( + 4 # ((2 tests + 2 stacktrace files) * 2 nodes) / 2 data_file directories + ) + self.assert_dir_file_count( + self.test_dir, self.archival_file, archival_dirs_to_expect + ) + self.assert_dir_file_count( + self.test_dir_inner, self.archival_file, archival_dirs_to_expect + ) analysis_pids_to_expect = 6 # 2 tests * (2 mongod + 1 mongo) - self.assert_dir_file_count(self.test_dir, self.analysis_file, analysis_pids_to_expect) + self.assert_dir_file_count( + self.test_dir, self.analysis_file, analysis_pids_to_expect + ) class TestTestSelection(_ResmokeSelftest): @@ -252,7 +281,9 @@ class TestTestSelection(_ResmokeSelftest): def execute_resmoke(self, resmoke_args): # pylint: disable=arguments-differ resmoke_process = core.programs.make_process( - self.logger, [sys.executable, "buildscripts/resmoke.py", "run"] + resmoke_args) + self.logger, + [sys.executable, "buildscripts/resmoke.py", "run"] + resmoke_args, + ) resmoke_process.start() return resmoke_process @@ -273,24 +304,34 @@ class TestTestSelection(_ResmokeSelftest): # Tests a suite that excludes a missing file self.assertEqual( 0, - self.execute_resmoke([ - f"--reportFile={self.report_file}", "--repeatTests=2", - f"--suites={self.suites_root}/resmoke_missing_test.yml", - f"{self.testfiles_root}/one.js", f"{self.testfiles_root}/one.js", - f"{self.testfiles_root}/one.js" - ]).wait()) + self.execute_resmoke( + [ + f"--reportFile={self.report_file}", + "--repeatTests=2", + f"--suites={self.suites_root}/resmoke_missing_test.yml", + f"{self.testfiles_root}/one.js", + f"{self.testfiles_root}/one.js", + f"{self.testfiles_root}/one.js", + ] + ).wait(), + ) self.assertEqual(6 * [f"{self.testfiles_root}/one.js"], self.get_tests_run()) def test_positional_arguments(self): self.assertEqual( 0, - self.execute_resmoke([ - f"--reportFile={self.report_file}", "--repeatTests=2", - f"--suites={self.suites_root}/resmoke_no_mongod.yml", - f"{self.testfiles_root}/one.js", f"{self.testfiles_root}/one.js", - f"{self.testfiles_root}/one.js" - ]).wait()) + self.execute_resmoke( + [ + f"--reportFile={self.report_file}", + "--repeatTests=2", + f"--suites={self.suites_root}/resmoke_no_mongod.yml", + f"{self.testfiles_root}/one.js", + f"{self.testfiles_root}/one.js", + f"{self.testfiles_root}/one.js", + ] + ).wait(), + ) self.assertEqual(6 * [f"{self.testfiles_root}/one.js"], self.get_tests_run()) @@ -299,34 +340,49 @@ class TestTestSelection(_ResmokeSelftest): self.assertEqual( 0, - self.execute_resmoke([ - f"--reportFile={self.report_file}", "--repeatTests=2", - f"--suites={self.suites_root}/resmoke_no_mongod.yml", - f"--replay={self.test_dir}/replay" - ]).wait()) + self.execute_resmoke( + [ + f"--reportFile={self.report_file}", + "--repeatTests=2", + f"--suites={self.suites_root}/resmoke_no_mongod.yml", + f"--replay={self.test_dir}/replay", + ] + ).wait(), + ) self.assertEqual(6 * [f"{self.testfiles_root}/two.js"], self.get_tests_run()) def test_suite_file(self): self.assertEqual( 0, - self.execute_resmoke([ - f"--reportFile={self.report_file}", "--repeatTests=2", - f"--suites={self.suites_root}/resmoke_no_mongod.yml" - ]).wait()) + self.execute_resmoke( + [ + f"--reportFile={self.report_file}", + "--repeatTests=2", + f"--suites={self.suites_root}/resmoke_no_mongod.yml", + ] + ).wait(), + ) - self.assertEqual(2 * [f"{self.testfiles_root}/one.js", f"{self.testfiles_root}/two.js"], - self.get_tests_run()) + self.assertEqual( + 2 * [f"{self.testfiles_root}/one.js", f"{self.testfiles_root}/two.js"], + self.get_tests_run(), + ) def test_at_sign_as_replay_file(self): self.create_file_in_test_dir("replay", f"{self.testfiles_root}/two.js\n" * 3) self.assertEqual( 0, - self.execute_resmoke([ - f"--reportFile={self.report_file}", "--repeatTests=2", - f"--suites={self.suites_root}/resmoke_no_mongod.yml", f"@{self.test_dir}/replay" - ]).wait()) + self.execute_resmoke( + [ + f"--reportFile={self.report_file}", + "--repeatTests=2", + f"--suites={self.suites_root}/resmoke_no_mongod.yml", + f"@{self.test_dir}/replay", + ] + ).wait(), + ) self.assertEqual(6 * [f"{self.testfiles_root}/two.js"], self.get_tests_run()) @@ -337,17 +393,23 @@ class TestTestSelection(_ResmokeSelftest): self.assertEqual( 2, self.execute_resmoke( - [f"--replay={self.test_dir}/replay", f"{self.testfiles_root}/one.js"]).wait()) + [f"--replay={self.test_dir}/replay", f"{self.testfiles_root}/one.js"] + ).wait(), + ) # When multiple positional arguments are presented, they're all treated as test files. self.assertEqual( 2, - self.execute_resmoke([f"@{self.test_dir}/replay", - f"{self.testfiles_root}/one.js"]).wait()) + self.execute_resmoke( + [f"@{self.test_dir}/replay", f"{self.testfiles_root}/one.js"] + ).wait(), + ) self.assertEqual( 2, - self.execute_resmoke([f"{self.testfiles_root}/one.js", - f"@{self.test_dir}/replay"]).wait()) + self.execute_resmoke( + [f"{self.testfiles_root}/one.js", f"@{self.test_dir}/replay"] + ).wait(), + ) class TestSetParameters(_ResmokeSelftest): @@ -372,7 +434,8 @@ class TestSetParameters(_ResmokeSelftest): pass suite["executor"]["config"]["shell_options"]["global_vars"]["TestData"][ - "outputLocation"] = self.shell_output_file + "outputLocation" + ] = self.shell_output_file with open(os.path.normpath(suite_output_path), "w") as fd: yaml.dump(suite, fd, default_flow_style=False) @@ -390,19 +453,25 @@ class TestSetParameters(_ResmokeSelftest): self.logger.info( "Running test. Template suite: %s Rewritten suite: %s Resmoke Args: %s Test output file: %s.", - suite_template, suite_file, resmoke_args, self.shell_output_file) + suite_template, + suite_file, + resmoke_args, + self.shell_output_file, + ) resmoke_process = core.programs.make_process( self.logger, - [sys.executable, "buildscripts/resmoke.py", "run", f"--suites={suite_file}" - ] + resmoke_args) + [sys.executable, "buildscripts/resmoke.py", "run", f"--suites={suite_file}"] + + resmoke_args, + ) resmoke_process.start() return resmoke_process def test_suite_set_parameters(self): self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters.yml", []).wait() + f"{self.suites_root}/resmoke_selftest_set_parameters.yml", [] + ).wait() set_params = self.parse_output_json() self.assertEqual("1", set_params["enableTestCommands"]) @@ -412,8 +481,10 @@ class TestSetParameters(_ResmokeSelftest): def test_cli_set_parameters(self): self.generate_suite_and_execute_resmoke( f"{self.suites_root}/resmoke_selftest_set_parameters.yml", - ["""--mongodSetParameter={"enableFlowControl": false, "flowControlMaxSamples": 500}""" - ]).wait() + [ + """--mongodSetParameter={"enableFlowControl": false, "flowControlMaxSamples": 500}""" + ], + ).wait() set_params = self.parse_output_json() self.assertEqual("1", set_params["enableTestCommands"]) @@ -423,7 +494,8 @@ class TestSetParameters(_ResmokeSelftest): def test_override_set_parameters(self): self.generate_suite_and_execute_resmoke( f"{self.suites_root}/resmoke_selftest_set_parameters.yml", - ["""--mongodSetParameter={"testingDiagnosticsEnabled": true}"""]).wait() + ["""--mongodSetParameter={"testingDiagnosticsEnabled": true}"""], + ).wait() set_params = self.parse_output_json() self.assertEqual("true", set_params["testingDiagnosticsEnabled"]) @@ -431,10 +503,12 @@ class TestSetParameters(_ResmokeSelftest): def test_merge_cli_set_parameters(self): self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters.yml", + [ """--mongodSetParameter={"enableFlowControl": false}""", - """--mongodSetParameter={"flowControlMaxSamples": 500}""" - ]).wait() + """--mongodSetParameter={"flowControlMaxSamples": 500}""", + ], + ).wait() set_params = self.parse_output_json() self.assertEqual("false", set_params["testingDiagnosticsEnabled"]) @@ -446,52 +520,72 @@ class TestSetParameters(_ResmokeSelftest): self.assertEqual( 2, self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters.yml", + [ """--mongodSetParameter={"enableFlowControl": false}""", - """--mongodSetParameter={"enableFlowControl": true}""" - ]).wait()) + """--mongodSetParameter={"enableFlowControl": true}""", + ], + ).wait(), + ) def test_mongos_set_parameter(self): self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters_sharding.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters_sharding.yml", + [ """--mongosSetParameter={"maxTimeMSForHedgedReads": 100}""", - """--mongosSetParameter={"mongosShutdownTimeoutMillisForSignaledShutdown": 1000}""" - ]).wait() + """--mongosSetParameter={"mongosShutdownTimeoutMillisForSignaledShutdown": 1000}""", + ], + ).wait() set_params = self.parse_output_json() self.assertEqual("100", set_params["maxTimeMSForHedgedReads"]) - self.assertEqual("1000", set_params["mongosShutdownTimeoutMillisForSignaledShutdown"]) + self.assertEqual( + "1000", set_params["mongosShutdownTimeoutMillisForSignaledShutdown"] + ) def test_merge_error_cli_mongos_set_parameter(self): self.assertEqual( 2, self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters_sharding.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters_sharding.yml", + [ """--mongosSetParameter={"maxTimeMSForHedgedReads": 100}""", - """--mongosSetParameter={"maxTimeMSForHedgedReads": 1000}""" - ]).wait()) + """--mongosSetParameter={"maxTimeMSForHedgedReads": 1000}""", + ], + ).wait(), + ) def test_allow_duplicate_set_parameter_values(self): self.assertEqual( 0, self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters.yml", + [ """--mongodSetParameter={"enableFlowControl": false}""", - """--mongodSetParameter={"enableFlowControl": false}""" - ]).wait()) + """--mongodSetParameter={"enableFlowControl": false}""", + ], + ).wait(), + ) self.assertEqual( 0, self.generate_suite_and_execute_resmoke( - f"{self.suites_root}/resmoke_selftest_set_parameters.yml", [ + f"{self.suites_root}/resmoke_selftest_set_parameters.yml", + [ """--mongodSetParameter={"mirrorReads": {samplingRate: 1.0}}""", - """--mongodSetParameter={"mirrorReads": {samplingRate: 1.0}}""" - ]).wait()) + """--mongodSetParameter={"mirrorReads": {samplingRate: 1.0}}""", + ], + ).wait(), + ) def execute_resmoke(resmoke_args): - return subprocess.run([sys.executable, "buildscripts/resmoke.py", "run"] + resmoke_args, - text=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + return subprocess.run( + [sys.executable, "buildscripts/resmoke.py", "run"] + resmoke_args, + text=True, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) class TestExceptionExtraction(unittest.TestCase): @@ -577,14 +671,15 @@ class TestSetShellSeed(unittest.TestCase): resmoke_args = [ "--suites=buildscripts/tests/resmoke_end2end/suites/resmoke_set_shellseed.yml", "buildscripts/tests/resmoke_end2end/testfiles/random_with_seed.js", - f"--shellSeed={test_seed}" + f"--shellSeed={test_seed}", ] seed = self.execute_resmoke_and_get_seed(resmoke_args) self.assertEqual( - seed, test_seed, msg= - "The found random seed does not match the seed passed with the --shellSeed resmoke argument." + seed, + test_seed, + msg="The found random seed does not match the seed passed with the --shellSeed resmoke argument.", ) def test_random_shell_seed(self): @@ -600,7 +695,9 @@ class TestSetShellSeed(unittest.TestCase): random_seeds.add(seed) self.assertTrue( - len(random_seeds) > 1, msg="Resmoke generated the same random seed 10 times in a row.") + len(random_seeds) > 1, + msg="Resmoke generated the same random seed 10 times in a row.", + ) # In resmoke we expect certain parts of the evergreen config to be a certain way @@ -616,8 +713,10 @@ class TestEvergreenYML(unittest.TestCase): suite_config = suitesconfig.get_suite(suite_name).get_config() expected_selector = ["jstestfuzz/out/*.js"] self.assertEqual( - suite_config["selector"]["roots"], expected_selector, - msg=f"The jstestfuzz selector for {suite_name} did not match 'jstestfuzz/out/*.js'") + suite_config["selector"]["roots"], + expected_selector, + msg=f"The jstestfuzz selector for {suite_name} did not match 'jstestfuzz/out/*.js'", + ) # This test asserts that the jstestfuzz tasks uploads the the URL we expect it to # If the remote url changes, also change it in the _log_local_resmoke_invocation method @@ -631,13 +730,16 @@ class TestEvergreenYML(unittest.TestCase): continue remote_url = item["params"]["remote_file"] - if remote_url == "${project}/${build_variant}/${revision}/jstestfuzz/${task_id}-${execution}.tgz": + if ( + remote_url + == "${project}/${build_variant}/${revision}/jstestfuzz/${task_id}-${execution}.tgz" + ): contains_correct_url = True break self.assertTrue( - contains_correct_url, msg= - "The 'run jstestfuzz' function in evergreen did not contain the remote_url that was expected" + contains_correct_url, + msg="The 'run jstestfuzz' function in evergreen did not contain the remote_url that was expected", ) # This tasks asserts that the way implicit multiversion tasks are defined has not changed @@ -652,8 +754,11 @@ class TestEvergreenYML(unittest.TestCase): if func is not None: implicit_multiversion_count += 1 - self.assertNotEqual(0, implicit_multiversion_count, - msg="Could not find any implicit multiversion tasks in evergreen") + self.assertNotEqual( + 0, + implicit_multiversion_count, + msg="Could not find any implicit multiversion tasks in evergreen", + ) # This tasks asserts that the way jstestfuzz tasks are defined has not changed # It also asserts that the selector for jstestfuzz tasks always points to jstestfuzz/out/*.js @@ -663,31 +768,43 @@ class TestEvergreenYML(unittest.TestCase): jstestfuzz_count = 0 for task in self.evg_conf.tasks: generate_func = task.find_func_command("generate resmoke tasks") - if (generate_func is None - or get_dict_value(generate_func, ["vars", "is_jstestfuzz"]) != "true"): + if ( + generate_func is None + or get_dict_value(generate_func, ["vars", "is_jstestfuzz"]) != "true" + ): continue jstestfuzz_count += 1 self.validate_jstestfuzz_selector(task.get_suite_names()) - self.assertNotEqual(0, jstestfuzz_count, msg="Could not find any jstestfuzz tasks") + self.assertNotEqual( + 0, jstestfuzz_count, msg="Could not find any jstestfuzz tasks" + ) class TestMultiversionConfig(unittest.TestCase): def test_valid_yaml(self): file_name = "multiversion-config.yml" - subprocess.run([ - sys.executable, "buildscripts/resmoke.py", "multiversion-config", - "--config-file-output", file_name - ], check=True) + subprocess.run( + [ + sys.executable, + "buildscripts/resmoke.py", + "multiversion-config", + "--config-file-output", + file_name, + ], + check=True, + ) with open(file_name, "r") as file: file_contents = file.read() try: yaml.safe_load(file_contents) except Exception: - self.fail(msg="`resmoke.py multiversion-config` does not output valid yaml.") + self.fail( + msg="`resmoke.py multiversion-config` does not output valid yaml." + ) os.remove(file_name) @@ -697,5 +814,9 @@ class TestCoreAnalyzerFunctions(unittest.TestCase): task_name = "test_tast_name" execution = "0" generated_task_name = get_generated_task_name(task_name, execution) - self.assertEquals(matches_generated_task_pattern(task_name, generated_task_name), execution) - self.assertIsNone(matches_generated_task_pattern("not_same_task", generated_task_name)) + self.assertEquals( + matches_generated_task_pattern(task_name, generated_task_name), execution + ) + self.assertIsNone( + matches_generated_task_pattern("not_same_task", generated_task_name) + ) diff --git a/buildscripts/tests/resmoke_proxy/test_resmoke_proxy.py b/buildscripts/tests/resmoke_proxy/test_resmoke_proxy.py index dbd1bf476c7..4f1a29a7eb7 100644 --- a/buildscripts/tests/resmoke_proxy/test_resmoke_proxy.py +++ b/buildscripts/tests/resmoke_proxy/test_resmoke_proxy.py @@ -1,4 +1,5 @@ """Unit tests for resmoke_proxy.py.""" + import unittest from unittest.mock import MagicMock @@ -10,7 +11,14 @@ from buildscripts.resmoke_proxy import resmoke_proxy as under_test class TestResmokeProxy(unittest.TestCase): def test_list_tests_can_handle_strings_and_lists(self): mock_suite = MagicMock( - tests=["test0", "test1", ["test2a", "tests2b", "test2c"], "test3", ["test4a"]]) + tests=[ + "test0", + "test1", + ["test2a", "tests2b", "test2c"], + "test3", + ["test4a"], + ] + ) resmoke_proxy = under_test.ResmokeProxyService() resmoke_proxy._suite_config = MagicMock() diff --git a/buildscripts/tests/resmoke_validation/test_find_suites.py b/buildscripts/tests/resmoke_validation/test_find_suites.py index 4180745f1f0..3cf3da948d7 100644 --- a/buildscripts/tests/resmoke_validation/test_find_suites.py +++ b/buildscripts/tests/resmoke_validation/test_find_suites.py @@ -7,7 +7,11 @@ class TestFindSuites(unittest.TestCase): def test_find_suites(self): jstests = glob.glob("jstests/core/*.js") resmoke_process = subprocess.run( - ["python3", "buildscripts/resmoke.py", "find-suites", jstests[0]]) + ["python3", "buildscripts/resmoke.py", "find-suites", jstests[0]] + ) - self.assertEqual(0, resmoke_process.returncode, - msg="find-suites subcommand did not execute successfully.") + self.assertEqual( + 0, + resmoke_process.returncode, + msg="find-suites subcommand did not execute successfully.", + ) diff --git a/buildscripts/tests/resmoke_validation/test_generated_matrix_suites.py b/buildscripts/tests/resmoke_validation/test_generated_matrix_suites.py index d43f1fad6be..b4c8de59848 100644 --- a/buildscripts/tests/resmoke_validation/test_generated_matrix_suites.py +++ b/buildscripts/tests/resmoke_validation/test_generated_matrix_suites.py @@ -20,27 +20,34 @@ class ValidateGeneratedSuites(unittest.TestCase): try: suite = self.matrix_suite_config.get_config_obj_and_verify(suite_name) self.assertIsNotNone( - suite, msg= - f"{suite_name} was not found. This means either MatrixSuiteConfig.get_named_suites() " - + - "or MatrixSuiteConfig.get_config_obj_and_verify() are not working as intended.") + suite, + msg=f"{suite_name} was not found. This means either MatrixSuiteConfig.get_named_suites() " + + "or MatrixSuiteConfig.get_config_obj_and_verify() are not working as intended.", + ) except Exception as ex: self.fail(repr(ex)) def test_stray_generated_files(self): suite_names = set(self.matrix_suite_config.get_named_suites()) - suites_dir = os.path.join(self.matrix_suite_config.get_suites_dir(), "generated_suites") + suites_dir = os.path.join( + self.matrix_suite_config.get_suites_dir(), "generated_suites" + ) generated_files = os.listdir(suites_dir) for filename in generated_files: (suite_name, ext) = os.path.splitext(filename) - self.assertEqual(ext, ".yml", - msg=f"{filename} has the wrong file extension, expected `.yml`") - expected_mapping_file = os.path.join(self.matrix_suite_config.get_suites_dir(), - f"mappings/{suite_name}.yml") + self.assertEqual( + ext, + ".yml", + msg=f"{filename} has the wrong file extension, expected `.yml`", + ) + expected_mapping_file = os.path.join( + self.matrix_suite_config.get_suites_dir(), f"mappings/{suite_name}.yml" + ) self.assertIn( - suite_name, suite_names, msg= - f"{filename} does not have a correlated mapping file . Make a mapping file or delete it." + suite_name, + suite_names, + msg=f"{filename} does not have a correlated mapping file . Make a mapping file or delete it." f"You have a generated file {filename} that does not have a corresponding mapping file {expected_mapping_file}. " - + - "If you have added a non matrix suite to resmokeconfig/matrix_suites/generated_suites, move it to the resmokeconfig/suites." - + " If you have removed the mapping file be sure to remove the generated file.") + + "If you have added a non matrix suite to resmokeconfig/matrix_suites/generated_suites, move it to the resmokeconfig/suites." + + " If you have removed the mapping file be sure to remove the generated file.", + ) diff --git a/buildscripts/tests/resmoke_validation/test_jstest_tags.py b/buildscripts/tests/resmoke_validation/test_jstest_tags.py index c42d6241dea..fdf4dfc70bb 100644 --- a/buildscripts/tests/resmoke_validation/test_jstest_tags.py +++ b/buildscripts/tests/resmoke_validation/test_jstest_tags.py @@ -36,8 +36,8 @@ class JstestTagRule: class FeatureFlagIncompatibleTagRule(JstestTagRule): def __init__(self): super().__init__( - failure_message= - "The following tags are not allowed for feature flags that default to true") + failure_message="The following tags are not allowed for feature flags that default to true" + ) self.disallowed_tags = { f"{flag}_incompatible" for flag in get_all_feature_flags_turned_on_by_default() @@ -50,8 +50,12 @@ class FeatureFlagIncompatibleTagRule(JstestTagRule): class RequiresFcvTagRule(JstestTagRule): def __init__(self): super().__init__( - failure_message="The following tags reference FCV version that is not available") - self.allowed_tags = [*REQUIRES_FCV_TAGS_LESS_THAN_LATEST, REQUIRES_FCV_TAG_LATEST] + failure_message="The following tags reference FCV version that is not available" + ) + self.allowed_tags = [ + *REQUIRES_FCV_TAGS_LESS_THAN_LATEST, + REQUIRES_FCV_TAG_LATEST, + ] def _tag_failed(self, file: str, tag: str) -> bool: return tag.startswith("requires_fcv_") and tag not in self.allowed_tags diff --git a/buildscripts/tests/resmoke_validation/test_matrix_suite_generation.py b/buildscripts/tests/resmoke_validation/test_matrix_suite_generation.py index dc0bdb0186c..8375a8ef794 100644 --- a/buildscripts/tests/resmoke_validation/test_matrix_suite_generation.py +++ b/buildscripts/tests/resmoke_validation/test_matrix_suite_generation.py @@ -18,13 +18,15 @@ class TestSuiteGeneration(unittest.TestCase): def verify_suite_generation(self): tested_suite = "test_matrix_suite" - generated_suite_path = self.matrix_suite_config.get_generated_suite_path(tested_suite) + generated_suite_path = self.matrix_suite_config.get_generated_suite_path( + tested_suite + ) if os.path.exists(generated_suite_path): os.remove(generated_suite_path) with self.assertRaises( - InvalidMatrixSuiteError, msg= - f"{tested_suite} suite should have failed because the generated suite does not exist." + InvalidMatrixSuiteError, + msg=f"{tested_suite} suite should have failed because the generated suite does not exist.", ): self.matrix_suite_config.get_config_obj_and_verify(tested_suite) @@ -38,14 +40,17 @@ class TestSuiteGeneration(unittest.TestCase): def verify_altered_generated_suite(self): tested_suite = "test_matrix_suite" - generated_suite_path = self.matrix_suite_config.get_generated_suite_path(tested_suite) + generated_suite_path = self.matrix_suite_config.get_generated_suite_path( + tested_suite + ) self.matrix_suite_config.generate_matrix_suite_file(tested_suite) with open(generated_suite_path, "a") as file: file.write("test change") with self.assertRaises( - InvalidMatrixSuiteError, msg= - f"{tested_suite} suite should have failed because the generated suite was edited."): + InvalidMatrixSuiteError, + msg=f"{tested_suite} suite should have failed because the generated suite was edited.", + ): self.matrix_suite_config.get_config_obj_and_verify(tested_suite) # restore original file back @@ -59,15 +64,27 @@ class TestSuiteGeneration(unittest.TestCase): def run_generated_suite(self): tested_suite = "test_matrix_suite" - generated_suite_path = self.matrix_suite_config.get_generated_suite_path(tested_suite) + generated_suite_path = self.matrix_suite_config.get_generated_suite_path( + tested_suite + ) self.matrix_suite_config.generate_matrix_suite_file(tested_suite) resmoke_process = subprocess.run( - ["python3", "buildscripts/resmoke.py", "run", "--suites", generated_suite_path]) + [ + "python3", + "buildscripts/resmoke.py", + "run", + "--suites", + generated_suite_path, + ] + ) - self.assertEqual(0, resmoke_process.returncode, - msg="Generated resmoke suite did not execute successfully.") + self.assertEqual( + 0, + resmoke_process.returncode, + msg="Generated resmoke suite did not execute successfully.", + ) def test_everything_sequentially(self): self.verify_suite_generation() diff --git a/buildscripts/tests/resmoke_validation/test_suites_configurations.py b/buildscripts/tests/resmoke_validation/test_suites_configurations.py index a612af4a5bf..7c65f3666df 100644 --- a/buildscripts/tests/resmoke_validation/test_suites_configurations.py +++ b/buildscripts/tests/resmoke_validation/test_suites_configurations.py @@ -20,4 +20,6 @@ class TestSuitesConfigurations(unittest.TestCase): if err.filename in config.EXTERNAL_SUITE_SELECTORS: continue except Exception as ex: - self.fail(f"While validating `{suite.get_name()}` suite got an error: {str(ex)}") + self.fail( + f"While validating `{suite.get_name()}` suite got an error: {str(ex)}" + ) diff --git a/buildscripts/tests/resmokelib/core/test_pipe.py b/buildscripts/tests/resmokelib/core/test_pipe.py index 954f3a9979e..f069e799c21 100644 --- a/buildscripts/tests/resmokelib/core/test_pipe.py +++ b/buildscripts/tests/resmokelib/core/test_pipe.py @@ -21,8 +21,9 @@ class TestLoggerPipe(unittest.TestCase): logger = logging.Logger("for_testing") logger.log = mock.MagicMock() - logger_pipe = _pipe.LoggerPipe(logger=logger, level=cls.LOG_LEVEL, - pipe_out=io.BytesIO(output)) + logger_pipe = _pipe.LoggerPipe( + logger=logger, level=cls.LOG_LEVEL, pipe_out=io.BytesIO(output) + ) logger_pipe.wait_until_started() logger_pipe.wait_until_finished() diff --git a/buildscripts/tests/resmokelib/core/test_programs.py b/buildscripts/tests/resmokelib/core/test_programs.py index 465f0522e30..ff62dbef8c7 100644 --- a/buildscripts/tests/resmokelib/core/test_programs.py +++ b/buildscripts/tests/resmokelib/core/test_programs.py @@ -8,15 +8,17 @@ from buildscripts.resmokelib.core.programs import _format_shell_vars class ResmokeProgramsTestCase(unittest.TestCase): def test_format_shell_vars_with_dot(self): string_builder = [] - with_dot = {'a.b': 'c'} - _format_shell_vars(string_builder, ['dummy_key'], with_dot) - expected = ['dummy_key = new Object()', 'dummy_key["a.b"] = "c"'] + with_dot = {"a.b": "c"} + _format_shell_vars(string_builder, ["dummy_key"], with_dot) + expected = ["dummy_key = new Object()", 'dummy_key["a.b"] = "c"'] self.assertEqual(string_builder, expected) string_builder = [] - without_dot = {'a': {'b': 'c'}} - _format_shell_vars(string_builder, ['dummy_key'], without_dot) + without_dot = {"a": {"b": "c"}} + _format_shell_vars(string_builder, ["dummy_key"], without_dot) expected = [ - 'dummy_key = new Object()', 'dummy_key["a"] = new Object()', 'dummy_key["a"]["b"] = "c"' + "dummy_key = new Object()", + 'dummy_key["a"] = new Object()', + 'dummy_key["a"]["b"] = "c"', ] self.assertEqual(string_builder, expected) diff --git a/buildscripts/tests/resmokelib/hang_analyzer/test_process_list.py b/buildscripts/tests/resmokelib/hang_analyzer/test_process_list.py index d36e4f28df9..d1fd0ec7c8c 100644 --- a/buildscripts/tests/resmokelib/hang_analyzer/test_process_list.py +++ b/buildscripts/tests/resmokelib/hang_analyzer/test_process_list.py @@ -26,42 +26,56 @@ class TestGetProcesses(unittest.TestCase): (2, "mongo"), (3, "python"), (4, "mongod"), - (5, "java") # this should be ignored. + (5, "java"), # this should be ignored. ] process_ids = None - interesting_processes = ['python', 'mongo', 'mongod'] + interesting_processes = ["python", "mongo", "mongod"] process_match = "exact" logger = Mock() - processes = get_processes(process_ids, interesting_processes, process_match, logger) + processes = get_processes( + process_ids, interesting_processes, process_match, logger + ) - self.assertCountEqual(processes, [ - Pinfo(name="python", pidv=[1, 3]), - Pinfo(name="mongo", pidv=[2]), - Pinfo(name="mongod", pidv=[4]) - ]) + self.assertCountEqual( + processes, + [ + Pinfo(name="python", pidv=[1, 3]), + Pinfo(name="mongo", pidv=[2]), + Pinfo(name="mongod", pidv=[4]), + ], + ) @patch(ns("os.getpid")) @patch(ns("_get_lister")) def test_interesting_processes_and_process_ids(self, lister_mock, os_mock): os_mock.return_value = -1 - lister_mock.return_value.dump_processes.return_value = [(1, "python"), (2, "mongo"), - (3, "python"), (4, "mongod"), - (5, "java")] + lister_mock.return_value.dump_processes.return_value = [ + (1, "python"), + (2, "mongo"), + (3, "python"), + (4, "mongod"), + (5, "java"), + ] process_ids = [1, 2, 5] - interesting_processes = ['python', 'mongo', 'mongod'] + interesting_processes = ["python", "mongo", "mongod"] process_match = "exact" logger = Mock() - processes = get_processes(process_ids, interesting_processes, process_match, logger) + processes = get_processes( + process_ids, interesting_processes, process_match, logger + ) - self.assertCountEqual(processes, [ - Pinfo(name="python", pidv=[1]), - Pinfo(name="mongo", pidv=[2]), - Pinfo(name="java", pidv=[5]), - ]) + self.assertCountEqual( + processes, + [ + Pinfo(name="python", pidv=[1]), + Pinfo(name="mongo", pidv=[2]), + Pinfo(name="java", pidv=[5]), + ], + ) @patch(ns("os.getpid")) @patch(ns("_get_lister")) @@ -73,23 +87,28 @@ class TestGetProcesses(unittest.TestCase): (3, "python3"), (4, "mongod"), (5, "python"), - (5, "java") # this should be ignored. + (5, "java"), # this should be ignored. ] process_ids = None - interesting_processes = ['python', 'mongo', 'mongod'] + interesting_processes = ["python", "mongo", "mongod"] process_match = "contains" logger = Mock() - processes = get_processes(process_ids, interesting_processes, process_match, logger) + processes = get_processes( + process_ids, interesting_processes, process_match, logger + ) - self.assertCountEqual(processes, [ - Pinfo(name="python", pidv=[5]), - Pinfo(name="python2", pidv=[1]), - Pinfo(name="python3", pidv=[3]), - Pinfo(name="mongo", pidv=[2]), - Pinfo(name="mongod", pidv=[4]) - ]) + self.assertCountEqual( + processes, + [ + Pinfo(name="python", pidv=[5]), + Pinfo(name="python2", pidv=[1]), + Pinfo(name="python3", pidv=[3]), + Pinfo(name="mongo", pidv=[2]), + Pinfo(name="mongod", pidv=[4]), + ], + ) @patch(ns("os.getpid")) @patch(ns("_get_lister")) @@ -112,10 +131,15 @@ class TestGetProcesses(unittest.TestCase): process_match = "exact" logger = Mock() - processes = get_processes(process_ids, interesting_processes, process_match, logger) + processes = get_processes( + process_ids, interesting_processes, process_match, logger + ) - self.assertCountEqual(processes, [ - Pinfo(name="python", pidv=[1, 3]), - Pinfo(name="mongo", pidv=[2]), - Pinfo(name="mongod", pidv=[4, 5]) - ]) + self.assertCountEqual( + processes, + [ + Pinfo(name="python", pidv=[1, 3]), + Pinfo(name="mongo", pidv=[2]), + Pinfo(name="mongod", pidv=[4, 5]), + ], + ) diff --git a/buildscripts/tests/resmokelib/logging/test_buildlogger.py b/buildscripts/tests/resmokelib/logging/test_buildlogger.py index 8167bbe6327..8fefeb224cd 100644 --- a/buildscripts/tests/resmokelib/logging/test_buildlogger.py +++ b/buildscripts/tests/resmokelib/logging/test_buildlogger.py @@ -34,15 +34,28 @@ class TestLogsSplitter(unittest.TestCase): def test_split_max_size_larger(self): logs = self.__generate_logs(size=31) max_size = 30 - self.assertEqual([logs[0:-1], logs[-1:]], - buildlogger._LogsSplitter.split_logs(logs, max_size)) + self.assertEqual( + [logs[0:-1], logs[-1:]], + buildlogger._LogsSplitter.split_logs(logs, max_size), + ) logs = self.__generate_logs(size=149) max_size = 19 - self.assertEqual([ - logs[0:3], logs[3:6], logs[6:9], logs[9:12], logs[12:15], logs[15:18], logs[18:21], - logs[21:24], logs[24:27], logs[27:] - ], buildlogger._LogsSplitter.split_logs(logs, max_size)) + self.assertEqual( + [ + logs[0:3], + logs[3:6], + logs[6:9], + logs[9:12], + logs[12:15], + logs[15:18], + logs[18:21], + logs[21:24], + logs[24:27], + logs[27:], + ], + buildlogger._LogsSplitter.split_logs(logs, max_size), + ) def check_split_sizes(self, splits, max_size): for split in splits: diff --git a/buildscripts/tests/resmokelib/logging/test_loggers.py b/buildscripts/tests/resmokelib/logging/test_loggers.py index dadebd2291f..d1f5fe66347 100644 --- a/buildscripts/tests/resmokelib/logging/test_loggers.py +++ b/buildscripts/tests/resmokelib/logging/test_loggers.py @@ -57,14 +57,23 @@ class TestLoggers(unittest.TestCase): loggers.BUILDLOGGER_SERVER.get_test_log_url.return_value = "dummy_url" mock_parent = MagicMock() - (logger, url) = loggers.new_test_logger("dummy_shortname", "dummy_basename", - "dummy_command", mock_parent, 88, 99, MagicMock()) + (logger, url) = loggers.new_test_logger( + "dummy_shortname", + "dummy_basename", + "dummy_command", + mock_parent, + 88, + 99, + MagicMock(), + ) self.assertEqual(logger.handlers[0], mock_handler) self.assertEqual(logger.parent, mock_parent) self.assertEqual(url, "dummy_url") def test_test_thread_logger(self): - logger = loggers.new_test_thread_logger("dummy_parent", "dummy_kind", "dummy_id") + logger = loggers.new_test_thread_logger( + "dummy_parent", "dummy_kind", "dummy_id" + ) self.assertEqual(logger.parent, "dummy_parent") def test_hook_logger(self): diff --git a/buildscripts/tests/resmokelib/multiversion/test_multiversion_service.py b/buildscripts/tests/resmokelib/multiversion/test_multiversion_service.py index f6f1e665a40..c9aa733038d 100644 --- a/buildscripts/tests/resmokelib/multiversion/test_multiversion_service.py +++ b/buildscripts/tests/resmokelib/multiversion/test_multiversion_service.py @@ -1,4 +1,5 @@ """Unit tests for multiversion_service.py.""" + from unittest import TestCase from packaging.version import Version @@ -17,7 +18,9 @@ class TestTagStr(TestCase): class TestGetVersion(TestCase): def test_version_should_be_extracted(self): - mongo_version = under_test.MongoVersion(mongo_version="6.0.0-rc5-18-gbcdfaa9035b") + mongo_version = under_test.MongoVersion( + mongo_version="6.0.0-rc5-18-gbcdfaa9035b" + ) self.assertEqual(mongo_version.get_version(), Version("6.0")) @@ -34,13 +37,35 @@ class TestCalculateFcvConstants(TestCase): mongo_releases = under_test.MongoReleases( **{ "featureCompatibilityVersions": [ - "4.0", "4.2", "4.4", "4.7", "4.8", "4.9", "5.0", "5.1", "5.2", "5.3", "6.0", - "100.0" + "4.0", + "4.2", + "4.4", + "4.7", + "4.8", + "4.9", + "5.0", + "5.1", + "5.2", + "5.3", + "6.0", + "100.0", ], "longTermSupportReleases": ["4.0", "4.2", "4.4", "5.0"], - "eolVersions": - ["2.0", "2.2", "2.4", "2.6", "3.0", "3.2", "3.4", "3.6", "4.0", "5.1", "5.2"], - }) + "eolVersions": [ + "2.0", + "2.2", + "2.4", + "2.6", + "3.0", + "3.2", + "3.4", + "3.6", + "4.0", + "5.1", + "5.2", + ], + } + ) multiversion_service = under_test.MultiversionService( mongo_version=mongo_version, @@ -52,26 +77,67 @@ class TestCalculateFcvConstants(TestCase): self.assertEqual(version_constants.latest, Version("6.0")) self.assertEqual(version_constants.last_continuous, Version("5.3")) self.assertEqual(version_constants.last_lts, Version("5.0")) - self.assertEqual(version_constants.requires_fcv_tag_list, - [Version(v) for v in ["5.1", "5.2", "5.3", "6.0"]]) - self.assertEqual(version_constants.requires_fcv_tag_list_continuous, [Version("6.0")]) - self.assertEqual(version_constants.fcvs_less_than_latest, [ - Version(v) - for v in ["4.0", "4.2", "4.4", "4.7", "4.8", "4.9", "5.0", "5.1", "5.2", "5.3"] - ]) + self.assertEqual( + version_constants.requires_fcv_tag_list, + [Version(v) for v in ["5.1", "5.2", "5.3", "6.0"]], + ) + self.assertEqual( + version_constants.requires_fcv_tag_list_continuous, [Version("6.0")] + ) + self.assertEqual( + version_constants.fcvs_less_than_latest, + [ + Version(v) + for v in [ + "4.0", + "4.2", + "4.4", + "4.7", + "4.8", + "4.9", + "5.0", + "5.1", + "5.2", + "5.3", + ] + ], + ) def test_fcv_constants_should_be_accurate_for_future_git_tag(self): mongo_version = under_test.MongoVersion(mongo_version="100.0") mongo_releases = under_test.MongoReleases( **{ "featureCompatibilityVersions": [ - "4.0", "4.2", "4.4", "4.7", "4.8", "4.9", "5.0", "5.1", "5.2", "5.3", "6.0", - "6.1", "100.0" + "4.0", + "4.2", + "4.4", + "4.7", + "4.8", + "4.9", + "5.0", + "5.1", + "5.2", + "5.3", + "6.0", + "6.1", + "100.0", ], "longTermSupportReleases": ["4.0", "4.2", "4.4", "5.0", "6.0"], - "eolVersions": - ["2.0", "2.2", "2.4", "2.6", "3.0", "3.2", "3.4", "3.6", "4.0", "5.1", "5.2"], - }) + "eolVersions": [ + "2.0", + "2.2", + "2.4", + "2.6", + "3.0", + "3.2", + "3.4", + "3.6", + "4.0", + "5.1", + "5.2", + ], + } + ) multiversion_service = under_test.MultiversionService( mongo_version=mongo_version, @@ -83,10 +149,30 @@ class TestCalculateFcvConstants(TestCase): self.assertEqual(version_constants.latest, Version("100.0")) self.assertEqual(version_constants.last_continuous, Version("6.1")) self.assertEqual(version_constants.last_lts, Version("6.0")) - self.assertEqual(version_constants.requires_fcv_tag_list, - [Version(v) for v in ["6.1", "100.0"]]) - self.assertEqual(version_constants.requires_fcv_tag_list_continuous, [Version("100.0")]) - self.assertEqual(version_constants.fcvs_less_than_latest, [ - Version(v) for v in - ["4.0", "4.2", "4.4", "4.7", "4.8", "4.9", "5.0", "5.1", "5.2", "5.3", "6.0", "6.1"] - ]) + self.assertEqual( + version_constants.requires_fcv_tag_list, + [Version(v) for v in ["6.1", "100.0"]], + ) + self.assertEqual( + version_constants.requires_fcv_tag_list_continuous, [Version("100.0")] + ) + self.assertEqual( + version_constants.fcvs_less_than_latest, + [ + Version(v) + for v in [ + "4.0", + "4.2", + "4.4", + "4.7", + "4.8", + "4.9", + "5.0", + "5.1", + "5.2", + "5.3", + "6.0", + "6.1", + ] + ], + ) diff --git a/buildscripts/tests/resmokelib/powercycle/test_remote_operations.py b/buildscripts/tests/resmokelib/powercycle/test_remote_operations.py index 89f37a0e9d2..23ac015ab0e 100755 --- a/buildscripts/tests/resmokelib/powercycle/test_remote_operations.py +++ b/buildscripts/tests/resmokelib/powercycle/test_remote_operations.py @@ -23,11 +23,13 @@ class RemoteOperationsTestCase(unittest.TestCase): self.temp_remote_dir = tempfile.mkdtemp() self.rop = rop.RemoteOperations(user_host="localhost") self.rop_use_shell = rop.RemoteOperations(user_host="localhost", use_shell=True) - self.rop_sh_shell_binary = rop.RemoteOperations(user_host="localhost", - shell_binary="/bin/sh") + self.rop_sh_shell_binary = rop.RemoteOperations( + user_host="localhost", shell_binary="/bin/sh" + ) self.rop_ssh_opts = rop.RemoteOperations( user_host="localhost", - ssh_connection_options="-v -o ConnectTimeout=10 -o ConnectionAttempts=10") + ssh_connection_options="-v -o ConnectTimeout=10 -o ConnectionAttempts=10", + ) def tearDown(self): shutil.rmtree(self.temp_local_dir, ignore_errors=True) @@ -37,7 +39,6 @@ class RemoteOperationsTestCase(unittest.TestCase): class RemoteOperationConnection(RemoteOperationsTestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def runTest(self): - self.assertTrue(self.rop.access_established()) ret, buff = self.rop.access_info() self.assertEqual(0, ret) @@ -51,8 +52,9 @@ class RemoteOperationConnection(RemoteOperationsTestCase): # Valid host with invalid ssh options ssh_connection_options = "-o invalid" - remote_op = rop.RemoteOperations(user_host="localhost", - ssh_connection_options=ssh_connection_options) + remote_op = rop.RemoteOperations( + user_host="localhost", ssh_connection_options=ssh_connection_options + ) ret, buff = remote_op.access_info() self.assertFalse(remote_op.access_established()) self.assertNotEqual(0, ret) @@ -67,8 +69,9 @@ class RemoteOperationConnection(RemoteOperationsTestCase): # Valid host with valid ssh options ssh_connection_options = "-v -o ConnectTimeout=10 -o ConnectionAttempts=10" - remote_op = rop.RemoteOperations(user_host="localhost", - ssh_connection_options=ssh_connection_options) + remote_op = rop.RemoteOperations( + user_host="localhost", ssh_connection_options=ssh_connection_options + ) ret, buff = remote_op.access_info() self.assertTrue(remote_op.access_established()) self.assertEqual(0, ret) @@ -83,9 +86,11 @@ class RemoteOperationConnection(RemoteOperationsTestCase): ssh_connection_options = "-v -o ConnectTimeout=10 -o ConnectionAttempts=10" ssh_options = "-t" - remote_op = rop.RemoteOperations(user_host="localhost", - ssh_connection_options=ssh_connection_options, - ssh_options=ssh_options) + remote_op = rop.RemoteOperations( + user_host="localhost", + ssh_connection_options=ssh_connection_options, + ssh_options=ssh_options, + ) ret, buff = remote_op.access_info() self.assertTrue(remote_op.access_established()) self.assertEqual(0, ret) @@ -95,7 +100,6 @@ class RemoteOperationConnection(RemoteOperationsTestCase): class RemoteOperationShell(RemoteOperationsTestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def runTest(self): - # Shell connect ret, buff = self.rop.shell("uname") self.assertEqual(0, ret) @@ -145,32 +149,41 @@ class RemoteOperationShell(RemoteOperationsTestCase): self.assertIsNotNone(buff) # Multiple commands with escaped single quotes - ret, buff = self.rop.shell("echo \"hello \'dolly\'\"; pwd; echo \"goodbye \'charlie\'\"") + ret, buff = self.rop.shell( + "echo \"hello 'dolly'\"; pwd; echo \"goodbye 'charlie'\"" + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) ret, buff = self.rop_use_shell.shell( - "echo \"hello \'dolly\'\"; pwd; echo \"goodbye \'charlie\'\"") + "echo \"hello 'dolly'\"; pwd; echo \"goodbye 'charlie'\"" + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) # Command with escaped double quotes - ret, buff = self.rop.shell("echo \"hello there\" | grep \"hello\"") + ret, buff = self.rop.shell('echo "hello there" | grep "hello"') self.assertEqual(0, ret) self.assertIsNotNone(buff) - ret, buff = self.rop_use_shell.shell("echo \"hello there\" | grep \"hello\"") + ret, buff = self.rop_use_shell.shell('echo "hello there" | grep "hello"') self.assertEqual(0, ret) self.assertIsNotNone(buff) # Command with directory and pipe - ret, buff = self.rop.shell("touch {dir}/{file}; ls {dir} | grep {file}".format( - file=time.time(), dir="/tmp")) + ret, buff = self.rop.shell( + "touch {dir}/{file}; ls {dir} | grep {file}".format( + file=time.time(), dir="/tmp" + ) + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) - ret, buff = self.rop_use_shell.shell("touch {dir}/{file}; ls {dir} | grep {file}".format( - file=time.time(), dir="/tmp")) + ret, buff = self.rop_use_shell.shell( + "touch {dir}/{file}; ls {dir} | grep {file}".format( + file=time.time(), dir="/tmp" + ) + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) @@ -178,7 +191,6 @@ class RemoteOperationShell(RemoteOperationsTestCase): class RemoteOperationCopyTo(RemoteOperationsTestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def runTest(self): - # Copy to remote l_temp_path = tempfile.mkstemp(dir=self.temp_local_dir)[1] l_temp_file = os.path.basename(l_temp_path) @@ -205,7 +217,9 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): l_temp_path = tempfile.mkstemp(dir=self.temp_local_dir)[1] l_temp_file = os.path.basename(l_temp_path) - ret, buff = self.rop_ssh_opts.operation("copy_to", l_temp_path, self.temp_remote_dir) + ret, buff = self.rop_ssh_opts.operation( + "copy_to", l_temp_path, self.temp_remote_dir + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) self.assertTrue(os.path.isfile(r_temp_path)) @@ -221,7 +235,9 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): self.assertEqual(0, ret) self.assertIsNotNone(buff) for i in range(num_files): - r_temp_path = os.path.join(self.temp_remote_dir, os.path.basename(l_temp_files[i])) + r_temp_path = os.path.join( + self.temp_remote_dir, os.path.basename(l_temp_files[i]) + ) self.assertTrue(os.path.isfile(r_temp_path)) num_files = 3 @@ -230,11 +246,15 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): l_temp_path = tempfile.mkstemp(dir=self.temp_local_dir)[1] l_temp_file = os.path.basename(l_temp_path) l_temp_files.append(l_temp_path) - ret, buff = self.rop_use_shell.copy_to(" ".join(l_temp_files), self.temp_remote_dir) + ret, buff = self.rop_use_shell.copy_to( + " ".join(l_temp_files), self.temp_remote_dir + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) for i in range(num_files): - r_temp_path = os.path.join(self.temp_remote_dir, os.path.basename(l_temp_files[i])) + r_temp_path = os.path.join( + self.temp_remote_dir, os.path.basename(l_temp_files[i]) + ) self.assertTrue(os.path.isfile(r_temp_path)) # Copy to remote without directory @@ -257,7 +277,9 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): os.remove(r_temp_path) # Copy to remote with space in file name, note it must be quoted. - l_temp_path = tempfile.mkstemp(dir=self.temp_local_dir, prefix="filename with space")[1] + l_temp_path = tempfile.mkstemp( + dir=self.temp_local_dir, prefix="filename with space" + )[1] l_temp_file = os.path.basename(l_temp_path) ret, buff = self.rop.copy_to("'{}'".format(l_temp_path)) self.assertEqual(0, ret) @@ -266,7 +288,9 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): self.assertTrue(os.path.isfile(r_temp_path)) os.remove(r_temp_path) - l_temp_path = tempfile.mkstemp(dir=self.temp_local_dir, prefix="filename with space")[1] + l_temp_path = tempfile.mkstemp( + dir=self.temp_local_dir, prefix="filename with space" + )[1] l_temp_file = os.path.basename(l_temp_path) ret, buff = self.rop_use_shell.copy_to("'{}'".format(l_temp_path)) self.assertEqual(0, ret) @@ -298,7 +322,6 @@ class RemoteOperationCopyTo(RemoteOperationsTestCase): class RemoteOperationCopyFrom(RemoteOperationsTestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def runTest(self): - # Copy from remote r_temp_path = tempfile.mkstemp(dir=self.temp_remote_dir)[1] r_temp_file = os.path.basename(r_temp_path) @@ -342,7 +365,9 @@ class RemoteOperationCopyFrom(RemoteOperationsTestCase): os.remove(r_temp_file) # Copy from remote with space in file name, note it must be quoted. - r_temp_path = tempfile.mkstemp(dir=self.temp_remote_dir, prefix="filename with space")[1] + r_temp_path = tempfile.mkstemp( + dir=self.temp_remote_dir, prefix="filename with space" + )[1] r_temp_file = os.path.basename(r_temp_path) ret, buff = self.rop.copy_from("'{}'".format(r_temp_path)) self.assertEqual(0, ret) @@ -371,7 +396,9 @@ class RemoteOperationCopyFrom(RemoteOperationsTestCase): r_temp_path = tempfile.mkstemp(dir=self.temp_remote_dir)[1] r_temp_file = os.path.basename(r_temp_path) r_temp_files.append(r_temp_path) - ret, buff = self.rop_use_shell.copy_from(" ".join(r_temp_files), self.temp_local_dir) + ret, buff = self.rop_use_shell.copy_from( + " ".join(r_temp_files), self.temp_local_dir + ) self.assertEqual(0, ret) self.assertIsNotNone(buff) for i in range(num_files): @@ -391,7 +418,9 @@ class RemoteOperationCopyFrom(RemoteOperationsTestCase): self.assertEqual(0, ret) self.assertIsNotNone(buff) for i in range(num_files): - l_temp_path = os.path.join(self.temp_local_dir, os.path.basename(r_temp_files[i])) + l_temp_path = os.path.join( + self.temp_local_dir, os.path.basename(r_temp_files[i]) + ) self.assertTrue(os.path.isfile(l_temp_path)) num_files = 3 @@ -405,11 +434,15 @@ class RemoteOperationCopyFrom(RemoteOperationsTestCase): self.assertEqual(0, ret) self.assertIsNotNone(buff) for i in range(num_files): - l_temp_path = os.path.join(self.temp_local_dir, os.path.basename(r_temp_files[i])) + l_temp_path = os.path.join( + self.temp_local_dir, os.path.basename(r_temp_files[i]) + ) self.assertTrue(os.path.isfile(l_temp_path)) # Local directory does not exist. - self.assertRaises(ValueError, lambda: self.rop_use_shell.copy_from(r_temp_path, "bad_dir")) + self.assertRaises( + ValueError, lambda: self.rop_use_shell.copy_from(r_temp_path, "bad_dir") + ) # Valid scp options r_temp_path = tempfile.mkstemp(dir=self.temp_remote_dir)[1] diff --git a/buildscripts/tests/resmokelib/run/test_auto_kill_rogue_process.py b/buildscripts/tests/resmokelib/run/test_auto_kill_rogue_process.py index 85b901185a8..b07ed99ce04 100644 --- a/buildscripts/tests/resmokelib/run/test_auto_kill_rogue_process.py +++ b/buildscripts/tests/resmokelib/run/test_auto_kill_rogue_process.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/run/list_tags.py.""" + # pylint: disable=protected-access import logging import os @@ -23,15 +24,15 @@ class MockTestRunner(TestRunner): class TestDetectRogueProcess(unittest.TestCase): def setUp(self) -> None: - self.command = [sys.executable, '-c', "import time; time.sleep(5)"] - if sys.platform.lower() == 'win32': + self.command = [sys.executable, "-c", "import time; time.sleep(5)"] + if sys.platform.lower() == "win32": self.sigkill_return = fixture_interface.TeardownMode.TERMINATE.value else: self.sigkill_return = -fixture_interface.TeardownMode.KILL.value - if not os.environ.get('RESMOKE_PARENT_PROCESS'): - os.environ['RESMOKE_PARENT_PROCESS'] = str(os.getpid()) - os.environ['RESMOKE_PARENT_CTIME'] = str(psutil.Process().create_time()) + if not os.environ.get("RESMOKE_PARENT_PROCESS"): + os.environ["RESMOKE_PARENT_PROCESS"] = str(os.getpid()) + os.environ["RESMOKE_PARENT_CTIME"] = str(psutil.Process().create_time()) # TODO: SERVER-90631 reenable this test # This works locally which is what we care about but it unclear why it is failing remotly @@ -41,7 +42,7 @@ class TestDetectRogueProcess(unittest.TestCase): reason="TODO: SERVER-90631 reenable this test on macos", ) def test_warn(self): - buildscripts.resmokelib.config.AUTO_KILL = 'warn' + buildscripts.resmokelib.config.AUTO_KILL = "warn" buildscripts.resmokelib.config.SHELL_CONN_STRING = None test_runner = MockTestRunner("test") @@ -52,11 +53,11 @@ class TestDetectRogueProcess(unittest.TestCase): except errors.ResmokeError: self.fail("Detected processes when there should be none.") - tmp_ctime = os.environ['RESMOKE_PARENT_CTIME'] - os.environ['RESMOKE_PARENT_CTIME'] = str("rogue_process") + tmp_ctime = os.environ["RESMOKE_PARENT_CTIME"] + os.environ["RESMOKE_PARENT_CTIME"] = str("rogue_process") proc = process.Process(logging.getLogger(), self.command) proc.start() - os.environ['RESMOKE_PARENT_CTIME'] = tmp_ctime + os.environ["RESMOKE_PARENT_CTIME"] = tmp_ctime with self.assertRaises(errors.ResmokeError): test_runner._check_for_mongo_processes() @@ -65,8 +66,7 @@ class TestDetectRogueProcess(unittest.TestCase): proc.wait() def test_on(self): - - buildscripts.resmokelib.config.AUTO_KILL = 'on' + buildscripts.resmokelib.config.AUTO_KILL = "on" buildscripts.resmokelib.config.SHELL_CONN_STRING = None test_runner = MockTestRunner("test") @@ -74,11 +74,11 @@ class TestDetectRogueProcess(unittest.TestCase): test_runner._check_for_mongo_processes() - tmp_ctime = os.environ['RESMOKE_PARENT_CTIME'] - os.environ['RESMOKE_PARENT_CTIME'] = str("rogue_process") + tmp_ctime = os.environ["RESMOKE_PARENT_CTIME"] + os.environ["RESMOKE_PARENT_CTIME"] = str("rogue_process") proc = process.Process(logging.getLogger(), self.command) proc.start() - os.environ['RESMOKE_PARENT_CTIME'] = tmp_ctime + os.environ["RESMOKE_PARENT_CTIME"] = tmp_ctime test_runner._check_for_mongo_processes() @@ -90,7 +90,7 @@ class TestDetectRogueProcess(unittest.TestCase): ) def test_off(self): - buildscripts.resmokelib.config.AUTO_KILL = 'off' + buildscripts.resmokelib.config.AUTO_KILL = "off" buildscripts.resmokelib.config.SHELL_CONN_STRING = None test_runner = MockTestRunner("test") @@ -110,8 +110,8 @@ class TestDetectRogueProcess(unittest.TestCase): proc.wait() def test_shell_constring(self): - buildscripts.resmokelib.config.AUTO_KILL = 'warn' - buildscripts.resmokelib.config.SHELL_CONN_STRING = '127.0.0.1:27000' + buildscripts.resmokelib.config.AUTO_KILL = "warn" + buildscripts.resmokelib.config.SHELL_CONN_STRING = "127.0.0.1:27000" test_runner = MockTestRunner("test") test_runner._setup_logging() diff --git a/buildscripts/tests/resmokelib/run/test_generate_multiversion_exclude_tags.py b/buildscripts/tests/resmokelib/run/test_generate_multiversion_exclude_tags.py index 0e7fc7256ca..2ce6a0aaffd 100644 --- a/buildscripts/tests/resmokelib/run/test_generate_multiversion_exclude_tags.py +++ b/buildscripts/tests/resmokelib/run/test_generate_multiversion_exclude_tags.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/run/generate_multiversion_exclude_tags.py.""" + import os import unittest from tempfile import TemporaryDirectory @@ -28,16 +29,18 @@ class TestGenerateExcludeYaml(unittest.TestCase): def patch_and_run(self, latest, old, old_bin_version): """Helper to patch and run the test.""" mock_multiversion_methods = { - 'get_backports_required_hash_for_shell_version': MagicMock(), - 'get_old_yaml': MagicMock(return_value=old) + "get_backports_required_hash_for_shell_version": MagicMock(), + "get_old_yaml": MagicMock(return_value=old), } - with patch.multiple('buildscripts.resmokelib.run.generate_multiversion_exclude_tags', - **mock_multiversion_methods): + with patch.multiple( + "buildscripts.resmokelib.run.generate_multiversion_exclude_tags", + **mock_multiversion_methods, + ): with patch( - 'buildscripts.resmokelib.run.generate_multiversion_exclude_tags.read_yaml_file', - return_value=latest) as mock_read_yaml: - + "buildscripts.resmokelib.run.generate_multiversion_exclude_tags.read_yaml_file", + return_value=latest, + ) as mock_read_yaml: output = os.path.join(self._tmpdir.name, EXCLUDE_TAGS_FILE) under_test.generate_exclude_yaml( old_bin_version=old_bin_version, @@ -47,30 +50,51 @@ class TestGenerateExcludeYaml(unittest.TestCase): mock_read_yaml.assert_called_once() mock_multiversion_methods[ - 'get_backports_required_hash_for_shell_version'].assert_called_once() - mock_multiversion_methods['get_old_yaml'].assert_called_once() + "get_backports_required_hash_for_shell_version" + ].assert_called_once() + mock_multiversion_methods["get_old_yaml"].assert_called_once() def test_create_yaml_suite1(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, } old_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], 'suites': - {'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}]} - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': {'jstests/fake_file1.js': ['suite1_backport_required_multiversion']} + "selector": { + "js_test": { + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"] + } } } @@ -79,29 +103,49 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_create_yaml_suite1_and_suite2(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}], - 'suite2': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ], + "suite2": [ + {"ticket": "fake_ticket1", "test_file": "jstests/fake_file1.js"} + ], + }, + }, } old_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], 'suites': - {'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}]} - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': { - 'jstests/fake_file1.js': [ - 'suite1_backport_required_multiversion', - 'suite2_backport_required_multiversion' + "selector": { + "js_test": { + "jstests/fake_file1.js": [ + "suite1_backport_required_multiversion", + "suite2_backport_required_multiversion", ] } } @@ -112,25 +156,41 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_both_all_are_none(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': None, 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": None, + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, } old_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': None, 'suites': { - 'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": None, + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': {'jstests/fake_file1.js': ['suite1_backport_required_multiversion']} + "selector": { + "js_test": { + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"] + } } } @@ -139,28 +199,43 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_old_all_is_none(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, } old_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': None, 'suites': { - 'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": None, + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': { - 'jstests/fake_file1.js': ['suite1_backport_required_multiversion'], - 'jstests/fake_file0.js': ['backport_required_multiversion'] + "selector": { + "js_test": { + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"], + "jstests/fake_file0.js": ["backport_required_multiversion"], } } } @@ -170,28 +245,46 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_create_yaml_suite1_and_all(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}, - {'ticket': 'fake_ticket4', 'test_file': 'jstests/fake_file4.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"}, + {"ticket": "fake_ticket4", "test_file": "jstests/fake_file4.js"}, + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, } old_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], 'suites': - {'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}]} - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': { - 'jstests/fake_file1.js': ['suite1_backport_required_multiversion'], - 'jstests/fake_file4.js': ['backport_required_multiversion'] + "selector": { + "js_test": { + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"], + "jstests/fake_file4.js": ["backport_required_multiversion"], } } } @@ -201,25 +294,45 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_last_continuous(self): latest_yaml = { - 'last-continuous': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - }, 'last-lts': None + "last-continuous": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, + "last-lts": None, } old_yaml = { - 'last-continuous': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], 'suites': - {'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}]} - }, 'last-lts': None + "last-continuous": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, + "last-lts": None, } expected = { - 'selector': { - 'js_test': {'jstests/fake_file1.js': ['suite1_backport_required_multiversion']} + "selector": { + "js_test": { + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"] + } } } @@ -228,28 +341,46 @@ class TestGenerateExcludeYaml(unittest.TestCase): def test_old_last_continuous_is_empty(self): latest_yaml = { - 'last-continuous': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}, - {'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } - }, 'last-lts': None + "last-continuous": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket1", + "test_file": "jstests/fake_file1.js", + }, + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + ] + }, + }, + "last-lts": None, } old_yaml = { - 'last-continuous': {'all': None, 'suites': {}}, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}], 'suites': - {'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}]} - } + "last-continuous": {"all": None, "suites": {}}, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"} + ], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, + }, } expected = { - 'selector': { - 'js_test': { - 'jstests/fake_file0.js': ['backport_required_multiversion'], - 'jstests/fake_file1.js': ['suite1_backport_required_multiversion'], - 'jstests/fake_file2.js': ['suite1_backport_required_multiversion'] + "selector": { + "js_test": { + "jstests/fake_file0.js": ["backport_required_multiversion"], + "jstests/fake_file1.js": ["suite1_backport_required_multiversion"], + "jstests/fake_file2.js": ["suite1_backport_required_multiversion"], } } } @@ -260,27 +391,41 @@ class TestGenerateExcludeYaml(unittest.TestCase): # Can delete after backporting the changed yml syntax. def test_not_backported(self): latest_yaml = { - 'last-continuous': None, 'last-lts': { - 'all': [{'ticket': 'fake_ticket0', 'test_file': 'jstests/fake_file0.js'}, - {'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}], - 'suites': { - 'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}, - {'ticket': 'fake_ticket3', 'test_file': 'jstests/fake_file3.js'}] - } - } + "last-continuous": None, + "last-lts": { + "all": [ + {"ticket": "fake_ticket0", "test_file": "jstests/fake_file0.js"}, + {"ticket": "fake_ticket1", "test_file": "jstests/fake_file1.js"}, + ], + "suites": { + "suite1": [ + { + "ticket": "fake_ticket2", + "test_file": "jstests/fake_file2.js", + }, + { + "ticket": "fake_ticket3", + "test_file": "jstests/fake_file3.js", + }, + ] + }, + }, } old_yaml = { - 'all': [{'ticket': 'fake_ticket1', 'test_file': 'jstests/fake_file1.js'}], 'suites': { - 'suite1': [{'ticket': 'fake_ticket2', 'test_file': 'jstests/fake_file2.js'}] - } + "all": [{"ticket": "fake_ticket1", "test_file": "jstests/fake_file1.js"}], + "suites": { + "suite1": [ + {"ticket": "fake_ticket2", "test_file": "jstests/fake_file2.js"} + ] + }, } expected = { - 'selector': { - 'js_test': { - 'jstests/fake_file0.js': ['backport_required_multiversion'], - 'jstests/fake_file3.js': ['suite1_backport_required_multiversion'] + "selector": { + "js_test": { + "jstests/fake_file0.js": ["backport_required_multiversion"], + "jstests/fake_file3.js": ["suite1_backport_required_multiversion"], } } } diff --git a/buildscripts/tests/resmokelib/run/test_list_tags.py b/buildscripts/tests/resmokelib/run/test_list_tags.py index ef7d8b33980..296998d0f85 100644 --- a/buildscripts/tests/resmokelib/run/test_list_tags.py +++ b/buildscripts/tests/resmokelib/run/test_list_tags.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/run/list_tags.py.""" + import unittest from buildscripts.resmokelib.run import list_tags @@ -8,17 +9,20 @@ def _get_suite(tags_blocks): block = "" for tags_block in tags_blocks: block = f"{block} {tags_block}" + "\n" - return (("test_kind: js_test\n" - "\n" - "selector:\n" - " roots: []\n") + block + ("\n" - "executor:\n" - " config: {}\n" - " fixture:\n" - " class: MongoDFixture\n" - " mongod_options:\n" - " set_parameters:\n" - " enableTestCommands: 1\n")) + return ( + ("test_kind: js_test\n" "\n" "selector:\n" " roots: []\n") + + block + + ( + "\n" + "executor:\n" + " config: {}\n" + " fixture:\n" + " class: MongoDFixture\n" + " mongod_options:\n" + " set_parameters:\n" + " enableTestCommands: 1\n" + ) + ) class TestParseTagsBlocks(unittest.TestCase): @@ -34,53 +38,71 @@ class TestParseTagsBlocks(unittest.TestCase): self.assertCountEqual(tags_blocks, result) def test_two_tags_blocks(self): - tags_blocks = [("exclude_with_any_tags:\n" - " - dummy_tag_1\n" - " - dummy_tag_2\n" - " - dummy_tag_3"), - ("include_with_any_tags:\n" - " - dummy_tag_4\n" - " - dummy_tag_5\n" - " - dummy_tag_6")] + tags_blocks = [ + ( + "exclude_with_any_tags:\n" + " - dummy_tag_1\n" + " - dummy_tag_2\n" + " - dummy_tag_3" + ), + ( + "include_with_any_tags:\n" + " - dummy_tag_4\n" + " - dummy_tag_5\n" + " - dummy_tag_6" + ), + ] suite = _get_suite(tags_blocks) result = list_tags.parse_tags_blocks(suite) self.assertCountEqual(tags_blocks, result) def test_tags_block_with_tags_and_above_comments(self): - tags_blocks = [("exclude_with_any_tags:\n" - " # comment\n" - " - dummy_tag_1\n" - " # comment line 1\n" - " # comment line 2\n" - " - dummy_tag_2\n" - " #################\n" - " # fancy comment #\n" - " #################\n" - " - dummy_tag_3")] + tags_blocks = [ + ( + "exclude_with_any_tags:\n" + " # comment\n" + " - dummy_tag_1\n" + " # comment line 1\n" + " # comment line 2\n" + " - dummy_tag_2\n" + " #################\n" + " # fancy comment #\n" + " #################\n" + " - dummy_tag_3" + ) + ] suite = _get_suite(tags_blocks) result = list_tags.parse_tags_blocks(suite) self.assertCountEqual(tags_blocks, result) def test_tags_block_with_tags_and_inline_comments(self): - tags_blocks = [("exclude_with_any_tags:\n" - " - dummy_tag_1 # inline comment\n" - " - dummy_tag_2 # another one\n" - " - dummy_tag_3 # and another one")] + tags_blocks = [ + ( + "exclude_with_any_tags:\n" + " - dummy_tag_1 # inline comment\n" + " - dummy_tag_2 # another one\n" + " - dummy_tag_3 # and another one" + ) + ] suite = _get_suite(tags_blocks) result = list_tags.parse_tags_blocks(suite) self.assertCountEqual(tags_blocks, result) def test_tags_block_with_tags_and_both_comments(self): - tags_blocks = [("exclude_with_any_tags:\n" - " # above comment\n" - " - dummy_tag_1 # inline comment\n" - " # above comment line 1\n" - " # above comment line 2\n" - " - dummy_tag_2 # another one inline\n" - " #######################\n" - " # above fancy comment #\n" - " #######################\n" - " - dummy_tag_3 # and another one inline")] + tags_blocks = [ + ( + "exclude_with_any_tags:\n" + " # above comment\n" + " - dummy_tag_1 # inline comment\n" + " # above comment line 1\n" + " # above comment line 2\n" + " - dummy_tag_2 # another one inline\n" + " #######################\n" + " # above fancy comment #\n" + " #######################\n" + " - dummy_tag_3 # and another one inline" + ) + ] suite = _get_suite(tags_blocks) result = list_tags.parse_tags_blocks(suite) self.assertCountEqual(tags_blocks, result) @@ -97,10 +119,12 @@ class TestSplitIntoTags(unittest.TestCase): self.assertCountEqual([[""]], result) def test_block_with_tags_no_comments(self): - block = ("exclude_with_any_tags:\n" - " - dummy_tag_1\n" - " - dummy_tag_2\n" - " - dummy_tag_3") + block = ( + "exclude_with_any_tags:\n" + " - dummy_tag_1\n" + " - dummy_tag_2\n" + " - dummy_tag_3" + ) expected = [ ["- dummy_tag_1"], ["- dummy_tag_2"], @@ -110,16 +134,18 @@ class TestSplitIntoTags(unittest.TestCase): self.assertCountEqual(expected, result) def test_block_with_tags_and_above_comments(self): - block = ("exclude_with_any_tags:\n" - " # comment\n" - " - dummy_tag_1\n" - " # comment line 1\n" - " # comment line 2\n" - " - dummy_tag_2\n" - " #################\n" - " # fancy comment #\n" - " #################\n" - " - dummy_tag_3") + block = ( + "exclude_with_any_tags:\n" + " # comment\n" + " - dummy_tag_1\n" + " # comment line 1\n" + " # comment line 2\n" + " - dummy_tag_2\n" + " #################\n" + " # fancy comment #\n" + " #################\n" + " - dummy_tag_3" + ) expected = [ [ "# comment", @@ -141,10 +167,12 @@ class TestSplitIntoTags(unittest.TestCase): self.assertCountEqual(expected, result) def test_block_with_tags_and_inline_comments(self): - block = (("exclude_with_any_tags:\n" - " - dummy_tag_1 # inline comment\n" - " - dummy_tag_2 # another one\n" - " - dummy_tag_3 # and another one")) + block = ( + "exclude_with_any_tags:\n" + " - dummy_tag_1 # inline comment\n" + " - dummy_tag_2 # another one\n" + " - dummy_tag_3 # and another one" + ) expected = [ ["- dummy_tag_1 # inline comment"], ["- dummy_tag_2 # another one"], @@ -154,16 +182,18 @@ class TestSplitIntoTags(unittest.TestCase): self.assertCountEqual(expected, result) def test_block_with_tags_and_both_comments(self): - block = ("exclude_with_any_tags:\n" - " # above comment\n" - " - dummy_tag_1 # inline comment\n" - " # above comment line 1\n" - " # above comment line 2\n" - " - dummy_tag_2 # another one inline\n" - " #######################\n" - " # above fancy comment #\n" - " #######################\n" - " - dummy_tag_3 # and another one inline") + block = ( + "exclude_with_any_tags:\n" + " # above comment\n" + " - dummy_tag_1 # inline comment\n" + " # above comment line 1\n" + " # above comment line 2\n" + " - dummy_tag_2 # another one inline\n" + " #######################\n" + " # above fancy comment #\n" + " #######################\n" + " - dummy_tag_3 # and another one inline" + ) expected = [ [ "# above comment", @@ -228,8 +258,9 @@ class TestGetTagDoc(unittest.TestCase): "# above comment line 2", "- dummy_tag # inline comment", ] - expected = ("dummy_tag", ("above comment line 1\n" - "above comment line 2\n" - "inline comment")) + expected = ( + "dummy_tag", + ("above comment line 1\n" "above comment line 2\n" "inline comment"), + ) result = list_tags.get_tag_doc(tag_block) self.assertEqual(expected, result) diff --git a/buildscripts/tests/resmokelib/setup_multiversion/test_setup_multiversion.py b/buildscripts/tests/resmokelib/setup_multiversion/test_setup_multiversion.py index da6a5858716..9153fc586c0 100644 --- a/buildscripts/tests/resmokelib/setup_multiversion/test_setup_multiversion.py +++ b/buildscripts/tests/resmokelib/setup_multiversion/test_setup_multiversion.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/setup_multiversion/setup_multiversion.py.""" + import unittest from argparse import Namespace @@ -17,86 +18,86 @@ from buildscripts.resmokelib.utils import evergreen_conn class TestInferPlatform(unittest.TestCase): @patch("platform.system") def test_infer_platform_darwin(self, mock_system): - mock_system.return_value = 'Darwin' - pltf = infer_platform('base', "4.2") - self.assertEqual(pltf, 'osx') - pltf = infer_platform('enterprise', "4.2") - self.assertEqual(pltf, 'osx') - pltf = infer_platform('base', "4.0") - self.assertEqual(pltf, 'osx') + mock_system.return_value = "Darwin" + pltf = infer_platform("base", "4.2") + self.assertEqual(pltf, "osx") + pltf = infer_platform("enterprise", "4.2") + self.assertEqual(pltf, "osx") + pltf = infer_platform("base", "4.0") + self.assertEqual(pltf, "osx") pltf = infer_platform(None, "4.2") - self.assertEqual(pltf, 'osx') - pltf = infer_platform('base', None) - self.assertEqual(pltf, 'osx') + self.assertEqual(pltf, "osx") + pltf = infer_platform("base", None) + self.assertEqual(pltf, "osx") pltf = infer_platform(None, None) - self.assertEqual(pltf, 'osx') + self.assertEqual(pltf, "osx") @patch("platform.system") def test_infer_platform_windows(self, mock_system): - mock_system.return_value = 'Windows' - pltf = infer_platform('base', "4.2") - self.assertEqual(pltf, 'windows_x86_64-2012plus') - pltf = infer_platform('enterprise', "4.2") - self.assertEqual(pltf, 'windows') - pltf = infer_platform('base', "4.0") - self.assertEqual(pltf, 'windows') + mock_system.return_value = "Windows" + pltf = infer_platform("base", "4.2") + self.assertEqual(pltf, "windows_x86_64-2012plus") + pltf = infer_platform("enterprise", "4.2") + self.assertEqual(pltf, "windows") + pltf = infer_platform("base", "4.0") + self.assertEqual(pltf, "windows") pltf = infer_platform(None, "4.2") - self.assertEqual(pltf, 'windows') - pltf = infer_platform('base', None) - self.assertEqual(pltf, 'windows') + self.assertEqual(pltf, "windows") + pltf = infer_platform("base", None) + self.assertEqual(pltf, "windows") pltf = infer_platform(None, None) - self.assertEqual(pltf, 'windows') + self.assertEqual(pltf, "windows") @patch("distro.minor_version") @patch("distro.major_version") @patch("distro.id") @patch("platform.system") def test_infer_platform_linux(self, mock_system, mock_id, mock_major, mock_minor): - mock_system.return_value = 'Linux' - mock_id.return_value = 'ubuntu' - mock_major.return_value = '18' - mock_minor.return_value = '04' - pltf = infer_platform('base', "4.2") - self.assertEqual(pltf, 'ubuntu1804') - pltf = infer_platform('enterprise', "4.2") - self.assertEqual(pltf, 'ubuntu1804') - pltf = infer_platform('base', "4.0") - self.assertEqual(pltf, 'ubuntu1804') + mock_system.return_value = "Linux" + mock_id.return_value = "ubuntu" + mock_major.return_value = "18" + mock_minor.return_value = "04" + pltf = infer_platform("base", "4.2") + self.assertEqual(pltf, "ubuntu1804") + pltf = infer_platform("enterprise", "4.2") + self.assertEqual(pltf, "ubuntu1804") + pltf = infer_platform("base", "4.0") + self.assertEqual(pltf, "ubuntu1804") pltf = infer_platform(None, "4.2") - self.assertEqual(pltf, 'ubuntu1804') - pltf = infer_platform('base', None) - self.assertEqual(pltf, 'ubuntu1804') + self.assertEqual(pltf, "ubuntu1804") + pltf = infer_platform("base", None) + self.assertEqual(pltf, "ubuntu1804") pltf = infer_platform(None, None) - self.assertEqual(pltf, 'ubuntu1804') + self.assertEqual(pltf, "ubuntu1804") - mock_id.return_value = 'rhel' - mock_major.return_value = '8' - mock_minor.return_value = '0' - pltf = infer_platform('base', "4.2") - self.assertEqual(pltf, 'rhel80') - pltf = infer_platform('enterprise', "4.2") - self.assertEqual(pltf, 'rhel80') - pltf = infer_platform('base', "4.0") - self.assertEqual(pltf, 'rhel80') + mock_id.return_value = "rhel" + mock_major.return_value = "8" + mock_minor.return_value = "0" + pltf = infer_platform("base", "4.2") + self.assertEqual(pltf, "rhel80") + pltf = infer_platform("enterprise", "4.2") + self.assertEqual(pltf, "rhel80") + pltf = infer_platform("base", "4.0") + self.assertEqual(pltf, "rhel80") pltf = infer_platform(None, "4.2") - self.assertEqual(pltf, 'rhel80') - pltf = infer_platform('base', None) - self.assertEqual(pltf, 'rhel80') + self.assertEqual(pltf, "rhel80") + pltf = infer_platform("base", None) + self.assertEqual(pltf, "rhel80") pltf = infer_platform(None, None) - self.assertEqual(pltf, 'rhel80') + self.assertEqual(pltf, "rhel80") @patch("distro.id") @patch("platform.system") def test_infer_platform_others(self, mock_system, mock_id): - mock_system.return_value = 'Java' - self.assertRaises(ValueError, infer_platform, 'enterprise', "4.2") - self.assertRaises(ValueError, infer_platform, 'base', None) + mock_system.return_value = "Java" + self.assertRaises(ValueError, infer_platform, "enterprise", "4.2") + self.assertRaises(ValueError, infer_platform, "base", None) self.assertRaises(ValueError, infer_platform, None, "4.2") self.assertRaises(ValueError, infer_platform, None, None) - mock_system.return_value = 'Linux' - mock_id.return_value = 'debian' - self.assertRaises(ValueError, infer_platform, 'enterprise', "4.2") - self.assertRaises(ValueError, infer_platform, 'base', None) + mock_system.return_value = "Linux" + mock_id.return_value = "debian" + self.assertRaises(ValueError, infer_platform, "enterprise", "4.2") + self.assertRaises(ValueError, infer_platform, "base", None) self.assertRaises(ValueError, infer_platform, None, "4.2") self.assertRaises(ValueError, infer_platform, None, None) @@ -112,7 +113,8 @@ class TestSetupMultiversionBase(unittest.TestCase): "evergreen_projects": [ "mongodb-mongo-master", "mongodb-mongo-v4.4", - ], "evergreen_buildvariants": [ + ], + "evergreen_buildvariants": [ { "name": self.buildvariant_name, "edition": edition, @@ -125,7 +127,7 @@ class TestSetupMultiversionBase(unittest.TestCase): "platform": evergreen_conn.GENERIC_PLATFORM, "architecture": evergreen_conn.GENERIC_ARCHITECTURE, }, - ] + ], } download_options = _DownloadOptions(db=True, ds=False, da=False, dv=False) @@ -143,8 +145,9 @@ class TestSetupMultiversionBase(unittest.TestCase): download_options=download_options, debug=False, ) - with patch("buildscripts.resmokelib.setup_multiversion.config.SetupMultiversionConfig" - ) as mock_config: + with patch( + "buildscripts.resmokelib.setup_multiversion.config.SetupMultiversionConfig" + ) as mock_config: mock_config.return_value = SetupMultiversionConfig(raw_yaml_config) self.setup_multiversion = SetupMultiversion(**vars(options)) @@ -178,8 +181,9 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi.versions_by_project") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_no_compile_artifacts(self, mock_get_compile_artifact_urls, mock_versions_by_project, - mock_version): + def test_no_compile_artifacts( + self, mock_get_compile_artifact_urls, mock_versions_by_project, mock_version + ): mock_version.build_variants_map = {self.buildvariant_name: "build_id"} mock_versions_by_project.return_value = iter([mock_version]) mock_get_compile_artifact_urls.return_value = {} @@ -190,11 +194,11 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi.versions_by_project") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_urls_found_on_last_version(self, mock_get_compile_artifact_urls, - mock_versions_by_project, mock_version): + def test_urls_found_on_last_version( + self, mock_get_compile_artifact_urls, mock_versions_by_project, mock_version + ): expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" } mock_version.build_variants_map = {self.buildvariant_name: "build_id"} @@ -208,12 +212,15 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi.versions_by_project") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_urls_found_on_not_last_version(self, mock_get_compile_artifact_urls, - mock_versions_by_project, mock_version, - mock_expected_version): + def test_urls_found_on_not_last_version( + self, + mock_get_compile_artifact_urls, + mock_versions_by_project, + mock_version, + mock_expected_version, + ): expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" } mock_version.build_variants_map = {self.buildvariant_name: "build_id"} @@ -221,11 +228,22 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): evg_versions = [mock_version for _ in range(3)] evg_versions.append(mock_expected_version) mock_versions_by_project.return_value = iter(evg_versions) - mock_get_compile_artifact_urls.side_effect = lambda evg_api, evg_version, buildvariant_name, ignore_failed_push: { - (self.setup_multiversion.evg_api, mock_version, self.buildvariant_name, False): {}, - (self.setup_multiversion.evg_api, mock_expected_version, self.buildvariant_name, False): - expected_urls, - }[evg_api, evg_version, buildvariant_name, ignore_failed_push] + mock_get_compile_artifact_urls.side_effect = ( + lambda evg_api, evg_version, buildvariant_name, ignore_failed_push: { + ( + self.setup_multiversion.evg_api, + mock_version, + self.buildvariant_name, + False, + ): {}, + ( + self.setup_multiversion.evg_api, + mock_expected_version, + self.buildvariant_name, + False, + ): expected_urls, + }[evg_api, evg_version, buildvariant_name, ignore_failed_push] + ) urlinfo = self.setup_multiversion.get_latest_urls("4.4") self.assertEqual(urlinfo.urls, expected_urls) @@ -234,16 +252,19 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi.versions_by_project") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_start_from_revision(self, mock_get_compile_artifact_urls, mock_versions_by_project, - mock_version, mock_expected_version): + def test_start_from_revision( + self, + mock_get_compile_artifact_urls, + mock_versions_by_project, + mock_version, + mock_expected_version, + ): start_from_revision = "90f767adbb1901d007ee4dd8714f53402d893669" unexpected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/project/build_variant/revision/binaries/unexpected.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/project/build_variant/revision/binaries/unexpected.tgz" } expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/project/build_variant/90f767adbb1901d007ee4dd8714f53402d893669/binaries/expected.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/project/build_variant/90f767adbb1901d007ee4dd8714f53402d893669/binaries/expected.tgz" } mock_version.build_variants_map = {self.buildvariant_name: "build_id"} @@ -254,12 +275,22 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): evg_versions.append(mock_expected_version) mock_versions_by_project.return_value = iter(evg_versions) - mock_get_compile_artifact_urls.side_effect = lambda evg_api, evg_version, buildvariant_name, ignore_failed_push: { - (self.setup_multiversion.evg_api, mock_version, self.buildvariant_name, False): - unexpected_urls, - (self.setup_multiversion.evg_api, mock_expected_version, self.buildvariant_name, False): - expected_urls, - }[evg_api, evg_version, buildvariant_name, ignore_failed_push] + mock_get_compile_artifact_urls.side_effect = ( + lambda evg_api, evg_version, buildvariant_name, ignore_failed_push: { + ( + self.setup_multiversion.evg_api, + mock_version, + self.buildvariant_name, + False, + ): unexpected_urls, + ( + self.setup_multiversion.evg_api, + mock_expected_version, + self.buildvariant_name, + False, + ): expected_urls, + }[evg_api, evg_version, buildvariant_name, ignore_failed_push] + ) urlinfo = self.setup_multiversion.get_latest_urls("master", start_from_revision) self.assertEqual(urlinfo.urls, expected_urls) @@ -268,18 +299,25 @@ class TestSetupMultiversionGetLatestUrls(TestSetupMultiversionBase): class TestSetupMultiversionGetUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_evergreen_version") - @patch("buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit") + @patch( + "buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit" + ) @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_urls_by_binary_version_found(self, mock_get_compile_artifact_urls, - mock_get_git_tag_and_commit, mock_get_evergreen_version, - mock_version): + def test_urls_by_binary_version_found( + self, + mock_get_compile_artifact_urls, + mock_get_git_tag_and_commit, + mock_get_evergreen_version, + mock_version, + ): expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" } - mock_get_git_tag_and_commit.return_value = ("r4.4.1", - "90f767adbb1901d007ee4dd8714f53402d893669") + mock_get_git_tag_and_commit.return_value = ( + "r4.4.1", + "90f767adbb1901d007ee4dd8714f53402d893669", + ) mock_version.build_variants_map = {self.buildvariant_name: "build_id"} mock_version.project_identifier = "mongodb-mongo-v4.4" mock_get_evergreen_version.return_value = mock_version @@ -291,11 +329,11 @@ class TestSetupMultiversionGetUrls(TestSetupMultiversionBase): @patch("evergreen.version.Version") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_evergreen_version") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_urls_by_commit_hash_found(self, mock_get_compile_artifact_urls, - mock_get_evergreen_version, mock_version): + def test_urls_by_commit_hash_found( + self, mock_get_compile_artifact_urls, mock_get_evergreen_version, mock_version + ): expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz" } mock_version.build_variants_map = {self.buildvariant_name: "build_id"} @@ -303,17 +341,28 @@ class TestSetupMultiversionGetUrls(TestSetupMultiversionBase): mock_get_evergreen_version.return_value = mock_version mock_get_compile_artifact_urls.return_value = expected_urls - urlinfo = self.setup_multiversion.get_urls("90f767adbb1901d007ee4dd8714f53402d893669") + urlinfo = self.setup_multiversion.get_urls( + "90f767adbb1901d007ee4dd8714f53402d893669" + ) self.assertEqual(urlinfo.urls, expected_urls) @patch("evergreen.version.Version") @patch("buildscripts.resmokelib.utils.evergreen_conn.get_evergreen_version") - @patch("buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit") + @patch( + "buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit" + ) @patch("buildscripts.resmokelib.utils.evergreen_conn.get_compile_artifact_urls") - def test_urls_not_found(self, mock_get_compile_artifact_urls, mock_get_git_tag_and_commit, - mock_get_evergreen_version, mock_version): - mock_get_git_tag_and_commit.return_value = ("r4.4.1", - "90f767adbb1901d007ee4dd8714f53402d893669") + def test_urls_not_found( + self, + mock_get_compile_artifact_urls, + mock_get_git_tag_and_commit, + mock_get_evergreen_version, + mock_version, + ): + mock_get_git_tag_and_commit.return_value = ( + "r4.4.1", + "90f767adbb1901d007ee4dd8714f53402d893669", + ) mock_version.version_id = "dummy-version-id" mock_version.build_variants_map = {self.buildvariant_name: "build_id"} mock_version.project_identifier = "mongodb-mongo-v4.4" @@ -325,10 +374,16 @@ class TestSetupMultiversionGetUrls(TestSetupMultiversionBase): self.assertEqual(urlinfo.evg_version_id, mock_version.version_id) @patch("buildscripts.resmokelib.utils.evergreen_conn.get_evergreen_version") - @patch("buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit") - def test_evg_version_not_found(self, mock_get_git_tag_and_commit, mock_get_evergreen_version): - mock_get_git_tag_and_commit.return_value = ("r4.4.1", - "90f767adbb1901d007ee4dd8714f53402d893669") + @patch( + "buildscripts.resmokelib.setup_multiversion.github_conn.get_git_tag_and_commit" + ) + def test_evg_version_not_found( + self, mock_get_git_tag_and_commit, mock_get_evergreen_version + ): + mock_get_git_tag_and_commit.return_value = ( + "r4.4.1", + "90f767adbb1901d007ee4dd8714f53402d893669", + ) mock_get_evergreen_version.return_value = None urlinfo = self.setup_multiversion.get_urls("4.4.1") diff --git a/buildscripts/tests/resmokelib/test_multiversionconstants_location.py b/buildscripts/tests/resmokelib/test_multiversionconstants_location.py index 0e3396083fd..cc657da969d 100644 --- a/buildscripts/tests/resmokelib/test_multiversionconstants_location.py +++ b/buildscripts/tests/resmokelib/test_multiversionconstants_location.py @@ -1,18 +1,25 @@ """Unit tests to ensure buildscripts/resmokelib/multiversionconstants.py location for the one-click repro tool.""" + import importlib import unittest class TestMultiversionconstantsLocation(unittest.TestCase): def test_multiversionconstants_location(self): - multiversionconstants_module_name = "buildscripts.resmokelib.multiversionconstants" + multiversionconstants_module_name = ( + "buildscripts.resmokelib.multiversionconstants" + ) try: - under_test_module = importlib.import_module(multiversionconstants_module_name) + under_test_module = importlib.import_module( + multiversionconstants_module_name + ) except ImportError: - self.fail(f"Failed to import `{multiversionconstants_module_name}` module. One-click" - f" repro tool (https://github.com/10gen/db-contrib-tools) requires this" - f" module. If the module was changed, one-click repro tool should also" - f" be updated. Please reach out in #server-testing slack channel.") + self.fail( + f"Failed to import `{multiversionconstants_module_name}` module. One-click" + f" repro tool (https://github.com/10gen/db-contrib-tools) requires this" + f" module. If the module was changed, one-click repro tool should also" + f" be updated. Please reach out in #server-testing slack channel." + ) else: expected_consts = ["LAST_LTS_FCV", "LAST_CONTINUOUS_FCV"] for const in expected_consts: @@ -21,4 +28,5 @@ class TestMultiversionconstantsLocation(unittest.TestCase): f"`{const}` constant is not found in `{multiversionconstants_module_name}`" f" module. One-click repro tool (https://github.com/10gen/db-contrib-tools)" f" uses this constant. If the module was changed, one-click repro tool" - f" should also be updated. Please reach out in #server-testing slack channel.") + f" should also be updated. Please reach out in #server-testing slack channel.", + ) diff --git a/buildscripts/tests/resmokelib/test_parser.py b/buildscripts/tests/resmokelib/test_parser.py index 8250cccc10b..eea2851c565 100644 --- a/buildscripts/tests/resmokelib/test_parser.py +++ b/buildscripts/tests/resmokelib/test_parser.py @@ -10,311 +10,404 @@ class TestLocalCommandLine(unittest.TestCase): """Unit tests for the to_local_args() function.""" def test_keeps_any_positional_arguments(self): - cmdline = to_local_args([ - "run", - "test_file1.js", - "test_file2.js", - "test_file3.js", - "test_file4.js", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "test_file1.js", + "test_file2.js", + "test_file3.js", + "test_file4.js", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "test_file1.js", - "test_file2.js", - "test_file3.js", - "test_file4.js", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "test_file1.js", + "test_file2.js", + "test_file3.js", + "test_file4.js", + ], + ) def test_keeps_continue_on_failure_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--continueOnFailure", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--continueOnFailure", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--continueOnFailure", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--continueOnFailure", + ], + ) def test_keeps_exclude_with_any_tags_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--excludeWithAnyTags=tag1,tag2,tag4", - "--excludeWithAnyTags=tag3,tag5", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--excludeWithAnyTags=tag1,tag2,tag4", + "--excludeWithAnyTags=tag3,tag5", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--excludeWithAnyTags=tag1,tag2,tag4", - "--excludeWithAnyTags=tag3,tag5", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--excludeWithAnyTags=tag1,tag2,tag4", + "--excludeWithAnyTags=tag3,tag5", + ], + ) def test_keeps_include_with_any_tags_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--includeWithAnyTags=tag1,tag2,tag4", - "--includeWithAnyTags=tag3,tag5", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--includeWithAnyTags=tag1,tag2,tag4", + "--includeWithAnyTags=tag3,tag5", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--includeWithAnyTags=tag1,tag2,tag4", - "--includeWithAnyTags=tag3,tag5", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--includeWithAnyTags=tag1,tag2,tag4", + "--includeWithAnyTags=tag3,tag5", + ], + ) def test_keeps_num_clients_per_fixture_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--numClientsPerFixture=10", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--numClientsPerFixture=10", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--numClientsPerFixture=10", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--numClientsPerFixture=10", + ], + ) def test_keeps_repeat_options(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--repeatSuites=1000", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--repeatSuites=1000", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--repeatSuites=1000", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--repeatSuites=1000", + ], + ) - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--repeatTests=1000", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--repeatTests=1000", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--repeatTests=1000", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--repeatTests=1000", + ], + ) - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--repeatTestsMax=1000", - "--repeatTestsMin=20", - "--repeatTestsSecs=300", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--repeatTestsMax=1000", + "--repeatTestsMin=20", + "--repeatTestsSecs=300", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--repeatTestsMax=1000", - "--repeatTestsMin=20", - "--repeatTestsSecs=300.0", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--repeatTestsMax=1000", + "--repeatTestsMin=20", + "--repeatTestsSecs=300.0", + ], + ) def test_keeps_shuffle_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--shuffle", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--shuffle", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--shuffle", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--shuffle", + ], + ) def test_keeps_storage_engine_cache_size_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--storageEngineCacheSizeGB=1", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--storageEngineCacheSizeGB=1", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--storageEngineCacheSizeGB=1", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--storageEngineCacheSizeGB=1", + ], + ) def test_origin_suite_option_replaces_suite_option(self): - cmdline = to_local_args([ - # We intentionally say --suite rather than --suites here to protect against this command - # line option from becoming ambiguous if more similarly named command line options are - # added in the future. - "run", - "--suite=part_of_my_suite", - "--originSuite=my_entire_suite", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + # We intentionally say --suite rather than --suites here to protect against this command + # line option from becoming ambiguous if more similarly named command line options are + # added in the future. + "run", + "--suite=part_of_my_suite", + "--originSuite=my_entire_suite", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, - ["run", "--suites=my_entire_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, + ["run", "--suites=my_entire_suite", "--storageEngine=my_storage_engine"], + ) def test_removes_archival_options(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--archiveLimitMb=100", - "--archiveLimitTests=10", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--archiveLimitMb=100", + "--archiveLimitTests=10", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_removes_evergreen_options(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--buildId=some_build_id", - "--distroId=some_distro_id", - "--executionNumber=1", - "--gitRevision=c0de", - "--patchBuild", - "--projectName=some_project", - "--revisionOrderId=20", - "--taskName=some_task", - "--taskId=some_task_id", - "--variantName=some_variant", - "--versionId=some_version_id", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--buildId=some_build_id", + "--distroId=some_distro_id", + "--executionNumber=1", + "--gitRevision=c0de", + "--patchBuild", + "--projectName=some_project", + "--revisionOrderId=20", + "--taskName=some_task", + "--taskId=some_task_id", + "--variantName=some_variant", + "--versionId=some_version_id", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_removes_log_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--log=buildlogger", - "--buildloggerUrl=some_url", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--log=buildlogger", + "--buildloggerUrl=some_url", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_removes_report_file_options(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--reportFile=report.json", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--reportFile=report.json", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_removes_stagger_jobs_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--staggerJobs=on", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--staggerJobs=on", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_removes_tag_file_option(self): - cmdline = to_local_args([ - "run", - "--suites=my_suite", - "--tagFile=etc/test_retrial.yml", - "--storageEngine=my_storage_engine", - ]) + cmdline = to_local_args( + [ + "run", + "--suites=my_suite", + "--tagFile=etc/test_retrial.yml", + "--storageEngine=my_storage_engine", + ] + ) - self.assertEqual(cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"]) + self.assertEqual( + cmdline, ["run", "--suites=my_suite", "--storageEngine=my_storage_engine"] + ) def test_accepts_space_delimited_args(self): - cmdline = to_local_args([ - "run", - "--suites", - "my_suite", - "--tagFile=etc/test_retrial.yml", - "--storageEngine", - "my_storage_engine", - "--includeWithAnyTags", - "tag1,tag2,tag4", - "--includeWithAnyTags", - "tag3,tag5", - ]) + cmdline = to_local_args( + [ + "run", + "--suites", + "my_suite", + "--tagFile=etc/test_retrial.yml", + "--storageEngine", + "my_storage_engine", + "--includeWithAnyTags", + "tag1,tag2,tag4", + "--includeWithAnyTags", + "tag3,tag5", + ] + ) - self.assertEqual(cmdline, [ - "run", - "--suites=my_suite", - "--storageEngine=my_storage_engine", - "--includeWithAnyTags=tag1,tag2,tag4", - "--includeWithAnyTags=tag3,tag5", - ]) + self.assertEqual( + cmdline, + [ + "run", + "--suites=my_suite", + "--storageEngine=my_storage_engine", + "--includeWithAnyTags=tag1,tag2,tag4", + "--includeWithAnyTags=tag3,tag5", + ], + ) class TestParseArgs(unittest.TestCase): """Unit tests for the parse() function.""" def test_files_at_end(self): - _, args = parse([ - "run", - "--suites=my_suite1,my_suite2", - "test_file1.js", - "test_file2.js", - "test_file3.js", - ]) + _, args = parse( + [ + "run", + "--suites=my_suite1,my_suite2", + "test_file1.js", + "test_file2.js", + "test_file3.js", + ] + ) - self.assertEqual(args.test_files, [ - "test_file1.js", - "test_file2.js", - "test_file3.js", - ]) + self.assertEqual( + args.test_files, + [ + "test_file1.js", + "test_file2.js", + "test_file3.js", + ], + ) # suites get split up when config.py gets populated self.assertEqual(args.suite_files, "my_suite1,my_suite2") def test_files_in_the_middle(self): - _, args = parse([ - "run", - "--storageEngine=my_storage_engine", - "test_file1.js", - "test_file2.js", - "test_file3.js", - "--suites=my_suite1", - ]) + _, args = parse( + [ + "run", + "--storageEngine=my_storage_engine", + "test_file1.js", + "test_file2.js", + "test_file3.js", + "--suites=my_suite1", + ] + ) - self.assertEqual(args.test_files, [ - "test_file1.js", - "test_file2.js", - "test_file3.js", - ]) + self.assertEqual( + args.test_files, + [ + "test_file1.js", + "test_file2.js", + "test_file3.js", + ], + ) self.assertEqual(args.suite_files, "my_suite1") @@ -322,13 +415,13 @@ class TestParseCommandLine(unittest.TestCase): """Unit tests for the parse_command_line() function.""" def test_find_suites(self): - subcommand_obj = parse_command_line(['find-suites']) - self.assertTrue(hasattr(subcommand_obj, 'execute')) + subcommand_obj = parse_command_line(["find-suites"]) + self.assertTrue(hasattr(subcommand_obj, "execute")) def test_list_suites(self): - subcommand_obj = parse_command_line(['list-suites']) - self.assertTrue(hasattr(subcommand_obj, 'execute')) + subcommand_obj = parse_command_line(["list-suites"]) + self.assertTrue(hasattr(subcommand_obj, "execute")) def test_run(self): - subcommand_obj = parse_command_line(['run', '--suite=my_suite', 'my_test.js']) - self.assertTrue(hasattr(subcommand_obj, 'execute')) + subcommand_obj = parse_command_line(["run", "--suite=my_suite", "my_test.js"]) + self.assertTrue(hasattr(subcommand_obj, "execute")) diff --git a/buildscripts/tests/resmokelib/test_selector.py b/buildscripts/tests/resmokelib/test_selector.py index 4d577bcdc6f..ea669df5a99 100644 --- a/buildscripts/tests/resmokelib/test_selector.py +++ b/buildscripts/tests/resmokelib/test_selector.py @@ -70,10 +70,14 @@ class TestExpressions(unittest.TestCase): tags_nomatch_3 = [tag2, "other_tag_2"] tags_nomatch_4 = [tag2] tags_nomatch_5 = ["other_tag_2"] - expression = selector.make_expression({"$allOf": [ - {"$anyOf": [tag1, tag2]}, - tag3, - ]}) + expression = selector.make_expression( + { + "$allOf": [ + {"$anyOf": [tag1, tag2]}, + tag3, + ] + } + ) self.assertIsInstance(expression, selector._AllOfExpression) self.assertTrue(expression(tags_match_1)) self.assertTrue(expression(tags_match_2)) @@ -102,39 +106,52 @@ class TestTestFileExplorer(unittest.TestCase): def test_fnmatchcase(self): pattern = "dir*/file.js" - self.assertTrue(self.test_file_explorer.fnmatchcase("directory/file.js", pattern)) + self.assertTrue( + self.test_file_explorer.fnmatchcase("directory/file.js", pattern) + ) self.assertFalse(self.test_file_explorer.fnmatchcase("other/file.js", pattern)) def test_parse_tag_files_single_file(self): - tests = (os.path.join(FIXTURE_PREFIX, "one.js"), os.path.join(FIXTURE_PREFIX, "two.js"), - os.path.join(FIXTURE_PREFIX, "three.js")) + tests = ( + os.path.join(FIXTURE_PREFIX, "one.js"), + os.path.join(FIXTURE_PREFIX, "two.js"), + os.path.join(FIXTURE_PREFIX, "three.js"), + ) expected = collections.defaultdict(list) expected[tests[0]] = ["tag1", "tag2", "tag3"] expected[tests[1]] = ["tag1", "tag2"] tags = self.test_file_explorer.parse_tag_files( - "js_test", [os.path.join(FIXTURE_PREFIX, "tag_file1.yml")]) + "js_test", [os.path.join(FIXTURE_PREFIX, "tag_file1.yml")] + ) # defaultdict isn't == comparable for test in tests: self.assertEqual(tags[test], expected[test]) expected[tests[1]] = ["tag1", "tag2", "tag4"] tags = self.test_file_explorer.parse_tag_files( - "js_test", [os.path.join(FIXTURE_PREFIX, "tag_file2.yml")], tags) + "js_test", [os.path.join(FIXTURE_PREFIX, "tag_file2.yml")], tags + ) for test in tests: self.assertEqual(tags[test], expected[test]) def test_parse_tag_files_multiple_file(self): - tests = (os.path.join(FIXTURE_PREFIX, "one.js"), os.path.join(FIXTURE_PREFIX, "two.js"), - os.path.join(FIXTURE_PREFIX, "three.js")) + tests = ( + os.path.join(FIXTURE_PREFIX, "one.js"), + os.path.join(FIXTURE_PREFIX, "two.js"), + os.path.join(FIXTURE_PREFIX, "three.js"), + ) expected = collections.defaultdict(list) expected[tests[0]] = ["tag1", "tag2", "tag3"] expected[tests[1]] = ["tag1", "tag2", "tag4"] - tags = self.test_file_explorer.parse_tag_files("js_test", [ - os.path.join(FIXTURE_PREFIX, "tag_file1.yml"), - os.path.join(FIXTURE_PREFIX, "tag_file2.yml") - ]) + tags = self.test_file_explorer.parse_tag_files( + "js_test", + [ + os.path.join(FIXTURE_PREFIX, "tag_file1.yml"), + os.path.join(FIXTURE_PREFIX, "tag_file2.yml"), + ], + ) # defaultdict isn't == comparable for test in tests: self.assertEqual(tags[test], expected[test]) @@ -149,16 +166,27 @@ class MockTestFileExplorer(object): def __init__(self): self.files = [ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js", "build/testA", "build/testB", "build/testC", "dbtest", - "dbtest.exe" + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + "build/testA", + "build/testB", + "build/testC", + "dbtest", + "dbtest.exe", ] self.tags = { - "dir/subdir1/test11.js": ["tag1", "tag2"], "dir/subdir1/test12.js": ["tag3"], - "dir/subdir2/test21.js": ["tag2", "tag4"], "dir/subdir3/a/test3a1.js": ["tag4", "tag5"] + "dir/subdir1/test11.js": ["tag1", "tag2"], + "dir/subdir1/test12.js": ["tag3"], + "dir/subdir2/test21.js": ["tag2", "tag4"], + "dir/subdir3/a/test3a1.js": ["tag4", "tag5"], } self.binary = MockTestFileExplorer.BINARY - self.jstest_tag_file = {"dir/subdir1/test11.js": "tagA", "dir/subdir3/a/test3a1.js": "tagB"} + self.jstest_tag_file = { + "dir/subdir1/test11.js": "tagA", + "dir/subdir3/a/test3a1.js": "tagB", + } def is_glob_pattern(self, pattern): return globstar.is_glob_pattern(pattern) @@ -203,14 +231,18 @@ class TestTestList(unittest.TestCase): def test_roots(self): roots = ["a", "b"] - test_list = selector._TestList(self.test_file_explorer, roots, tests_are_files=False) + test_list = selector._TestList( + self.test_file_explorer, roots, tests_are_files=False + ) selected, excluded = test_list.get_tests() self.assertEqual(roots, selected) self.assertEqual([], excluded) def test_roots_normpath(self): roots = ["dir/a/abc.js", "dir/b/xyz.js"] - test_list = selector._TestList(self.test_file_explorer, roots, tests_are_files=False) + test_list = selector._TestList( + self.test_file_explorer, roots, tests_are_files=False + ) selected, excluded = test_list.get_tests() for root_file, selected_file in zip(roots, selected): self.assertEqual(os.path.normpath(root_file), selected_file) @@ -227,17 +259,23 @@ class TestTestList(unittest.TestCase): def test_roots_with_unmatching_glob(self): glob_roots = ["dir/unknown_subdir1/*.js"] with self.assertRaisesRegex( - errors.SuiteSelectorConfigurationError, - re.escape("Pattern(s) and/or filename(s) in `roots`" - " do not match any existing test files: ['dir/unknown_subdir1/*.js']")): + errors.SuiteSelectorConfigurationError, + re.escape( + "Pattern(s) and/or filename(s) in `roots`" + " do not match any existing test files: ['dir/unknown_subdir1/*.js']" + ), + ): selector._TestList(self.test_file_explorer, glob_roots) def test_roots_unknown_file(self): roots = ["dir/subdir1/unknown"] with self.assertRaisesRegex( - errors.SuiteSelectorConfigurationError, - re.escape("Pattern(s) and/or filename(s) in `roots`" - " do not match any existing test files: ['dir/subdir1/unknown']")): + errors.SuiteSelectorConfigurationError, + re.escape( + "Pattern(s) and/or filename(s) in `roots`" + " do not match any existing test files: ['dir/subdir1/unknown']" + ), + ): selector._TestList(self.test_file_explorer, roots, tests_are_files=True) def test_include_files(self): @@ -252,9 +290,12 @@ class TestTestList(unittest.TestCase): roots = ["dir/subdir1/*.js", "dir/subdir2/test21.*"] test_list = selector._TestList(self.test_file_explorer, roots) with self.assertRaisesRegex( - errors.SuiteSelectorConfigurationError, - re.escape("Pattern(s) and/or filename(s) in `include_files`" - " do not match any existing test files: ['dir/subdir2/test26.js']")): + errors.SuiteSelectorConfigurationError, + re.escape( + "Pattern(s) and/or filename(s) in `include_files`" + " do not match any existing test files: ['dir/subdir2/test26.js']" + ), + ): test_list.include_files(["dir/subdir2/test26.js"]) def test_exclude_files(self): @@ -284,7 +325,14 @@ class TestTestList(unittest.TestCase): roots = ["dir/subdir1/*.js", "dir/subdir2/test21.*"] test_list = selector._TestList(self.test_file_explorer, roots) expression = selector.make_expression( - {"$anyOf": [{"$allOf": ["tag1", "tag2"]}, "tag3", {"$allOf": ["tag5", "tag6"]}]}) + { + "$anyOf": [ + {"$allOf": ["tag1", "tag2"]}, + "tag3", + {"$allOf": ["tag5", "tag6"]}, + ] + } + ) def get_tags(test_file): return self.test_file_explorer.jstest_tags(test_file) @@ -302,39 +350,64 @@ class TestTestList(unittest.TestCase): selected, excluded = test_list.get_tests() self.assertEqual(["dir/subdir3/a/test3a1.js"], selected) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js"], excluded) + ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js"], + excluded, + ) # 1 pattern and 0 matching test_list = selector._TestList(self.test_file_explorer, roots) test_list.include_any_pattern(["dir/*4/a/*"]) selected, excluded = test_list.get_tests() self.assertEqual([], selected) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], excluded) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) # 3 patterns and 1 matching test_list = selector._TestList(self.test_file_explorer, roots) - test_list.include_any_pattern(["dir/*3/a/*", "notmaching/*", "notmatching2/*.js"]) + test_list.include_any_pattern( + ["dir/*3/a/*", "notmaching/*", "notmatching2/*.js"] + ) selected, excluded = test_list.get_tests() self.assertEqual(["dir/subdir3/a/test3a1.js"], selected) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js"], excluded) + ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js"], + excluded, + ) # 3 patterns and 0 matching test_list = selector._TestList(self.test_file_explorer, roots) - test_list.include_any_pattern(["dir2/*3/a/*", "notmaching/*", "notmatching2/*.js"]) + test_list.include_any_pattern( + ["dir2/*3/a/*", "notmaching/*", "notmatching2/*.js"] + ) selected, excluded = test_list.get_tests() self.assertEqual([], selected) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], excluded) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) # 3 patterns and 3 matching test_list = selector._TestList(self.test_file_explorer, roots) - test_list.include_any_pattern(["dir/*1/*11*", "dir/subdir3/**", "dir/subdir2/*.js"]) + test_list.include_any_pattern( + ["dir/*1/*11*", "dir/subdir3/**", "dir/subdir2/*.js"] + ) selected, excluded = test_list.get_tests() self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir2/test21.js", "dir/subdir3/a/test3a1.js"], - selected) + [ + "dir/subdir1/test11.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual(["dir/subdir1/test12.js"], excluded) def test_include_tests(self): @@ -348,7 +421,9 @@ class TestTestList(unittest.TestCase): def test_tests_are_not_files(self): roots = ["a", "b"] - test_list = selector._TestList(self.test_file_explorer, roots, tests_are_files=False) + test_list = selector._TestList( + self.test_file_explorer, roots, tests_are_files=False + ) with self.assertRaises(TypeError): test_list.include_files([]) with self.assertRaises(TypeError): @@ -365,8 +440,9 @@ class TestSelectorConfig(unittest.TestCase): selector._SelectorConfig(include_tags="tag1", exclude_tags="tag2") def test_multi_jstest_selector_config(self): - sc = selector._MultiJSTestSelectorConfig(roots=["test1", "test2"], group_size=1234, - group_count_multiplier=5678) + sc = selector._MultiJSTestSelectorConfig( + roots=["test1", "test2"], group_size=1234, group_count_multiplier=5678 + ) self.assertEqual(sc.group_size, 1234) self.assertEqual(sc.group_count_multiplier, 5678) @@ -378,55 +454,85 @@ class TestSelector(unittest.TestCase): def test_select_all(self): config = selector._SelectorConfig( - roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"]) + roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"] + ) selected, excluded = self.selector.select(config) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], selected) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual([], excluded) def test_select_exclude_files(self): config = selector._SelectorConfig( roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - exclude_files=["dir/subdir2/test21.js"]) + exclude_files=["dir/subdir2/test21.js"], + ) selected, excluded = self.selector.select(config) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], - selected) + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual(["dir/subdir2/test21.js"], excluded) def test_select_include_files(self): config = selector._SelectorConfig( roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - include_files=["dir/subdir2/test21.js"]) + include_files=["dir/subdir2/test21.js"], + ) selected, excluded = self.selector.select(config) self.assertEqual(["dir/subdir2/test21.js"], selected) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], - excluded) + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) def test_select_include_tags(self): config = selector._SelectorConfig( roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - include_tags="tag1") + include_tags="tag1", + ) selected, excluded = self.selector.select(config) self.assertEqual([], selected) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], excluded) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) def test_select_include_any_tags(self): config = selector._SelectorConfig( roots=["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - include_with_any_tags=["tag1"]) + include_with_any_tags=["tag1"], + ) selected, excluded = self.selector.select(config) self.assertEqual([], selected) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], excluded) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) class TestMultiJSSelector(unittest.TestCase): @@ -436,30 +542,41 @@ class TestMultiJSSelector(unittest.TestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def test_multi_js_test_selector_normal(self): - config = selector._MultiJSTestSelectorConfig(roots=["dir/**/*.js"], group_size=3, - group_count_multiplier=2) + config = selector._MultiJSTestSelectorConfig( + roots=["dir/**/*.js"], group_size=3, group_count_multiplier=2 + ) selected, _ = self.selector.select(config) total = 0 for group in selected[:-1]: - self.assertEqual(len(group), 3, "{} did not have 3 unique tests".format(group)) + self.assertEqual( + len(group), 3, "{} did not have 3 unique tests".format(group) + ) total += 3 self.assertLessEqual( - len(selected[-1]), 3, - "Last selected group did not have 3 or fewer tests: {}".format(selected[-1])) + len(selected[-1]), + 3, + "Last selected group did not have 3 or fewer tests: {}".format( + selected[-1] + ), + ) total += len(selected[-1]) - self.assertEqual(total, MockTestFileExplorer.NUM_JS_FILES * config.group_count_multiplier, - "The total number of workloads is incorrect") + self.assertEqual( + total, + MockTestFileExplorer.NUM_JS_FILES * config.group_count_multiplier, + "The total number of workloads is incorrect", + ) @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def test_multi_js_test_selector_one_group(self): """Test we return only one group if the group size equals number of files.""" num_files = MockTestFileExplorer.NUM_JS_FILES - config = selector._MultiJSTestSelectorConfig(roots=["dir/**/*.js"], group_size=num_files, - group_count_multiplier=9999999) + config = selector._MultiJSTestSelectorConfig( + roots=["dir/**/*.js"], group_size=num_files, group_count_multiplier=9999999 + ) selected, _ = self.selector.select(config) self.assertEqual(len(selected), 1) self.assertEqual(len(selected[0]), num_files) @@ -476,8 +593,9 @@ class TestFilterTests(unittest.TestCase): def test_cpp_all(self): config = {"root": "integrationtest.txt"} - selected, excluded = selector.filter_tests("cpp_integration_test", config, - self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_integration_test", config, self.test_file_explorer + ) self.assertEqual(["build/testA", "build/testB"], selected) self.assertEqual([], excluded) @@ -485,22 +603,28 @@ class TestFilterTests(unittest.TestCase): # When roots are specified for cpp tests they override all filtering since # 'roots' are populated with the command line arguments. config = {"include_files": "unknown_file", "roots": ["build/testC"]} - selected, excluded = selector.filter_tests("cpp_unit_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_unit_test", config, self.test_file_explorer + ) self.assertEqual(["build/testC"], selected) self.assertEqual([], excluded) - selected, excluded = selector.filter_tests("cpp_integration_test", config, - self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_integration_test", config, self.test_file_explorer + ) self.assertEqual(["build/testC"], selected) self.assertEqual([], excluded) def test_cpp_expand_roots(self): config = {"root": "integrationtest.txt", "roots": ["build/test*"]} - selected, excluded = selector.filter_tests("cpp_integration_test", config, - self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_integration_test", config, self.test_file_explorer + ) self.assertEqual(["build/testA", "build/testB", "build/testC"], selected) self.assertEqual([], excluded) - selected, excluded = selector.filter_tests("cpp_unit_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_unit_test", config, self.test_file_explorer + ) self.assertEqual(["build/testA", "build/testB", "build/testC"], selected) self.assertEqual([], excluded) @@ -508,8 +632,11 @@ class TestFilterTests(unittest.TestCase): buildscripts.resmokelib.config.INCLUDE_WITH_ANY_TAGS = ["tag1"] try: selector_config = {"root": "unittest.txt"} - selected, excluded = selector.filter_tests("cpp_unit_test", selector_config, - test_file_explorer=self.test_file_explorer) + selected, excluded = selector.filter_tests( + "cpp_unit_test", + selector_config, + test_file_explorer=self.test_file_explorer, + ) self.assertEqual([], selected) self.assertEqual(["build/testA", "build/testB"], excluded) finally: @@ -518,33 +645,51 @@ class TestFilterTests(unittest.TestCase): def test_jstest_include_tags(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "include_tags": "tag1" + "include_tags": "tag1", } - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) self.assertEqual(["dir/subdir1/test11.js"], selected) self.assertEqual( - ["dir/subdir1/test12.js", "dir/subdir2/test21.js", "dir/subdir3/a/test3a1.js"], - excluded) + [ + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) def test_jstest_exclude_tags(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "exclude_tags": "tag1" + "exclude_tags": "tag1", } - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) self.assertEqual( - ["dir/subdir1/test12.js", "dir/subdir2/test21.js", "dir/subdir3/a/test3a1.js"], - selected) + [ + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual(["dir/subdir1/test11.js"], excluded) def test_jstest_exclude_with_any_tags(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "exclude_with_any_tags": ["tag2"] + "exclude_with_any_tags": ["tag2"], } - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) self.assertEqual(["dir/subdir1/test11.js", "dir/subdir2/test21.js"], excluded) - self.assertEqual(["dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], selected) + self.assertEqual( + ["dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], selected + ) @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def test_filter_temporarily_disabled_tests(self): @@ -552,78 +697,117 @@ class TestFilterTests(unittest.TestCase): test_file_explorer = MockTestFileExplorer() test_file_explorer.tags = { "dir/subdir1/test11.js": ["tag1", "tag2", "__TEMPORARILY_DISABLED__"], - "dir/subdir1/test12.js": ["tag3"], "dir/subdir2/test21.js": ["tag2", "tag4"] + "dir/subdir1/test12.js": ["tag3"], + "dir/subdir2/test21.js": ["tag2", "tag4"], } config = {"roots": ["dir/subdir1/*.js", "dir/subdir2/*.js"]} - selected, excluded = selector.filter_tests("js_test", config, test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, test_file_explorer + ) self.assertEqual(["dir/subdir1/test11.js"], excluded) self.assertEqual(["dir/subdir1/test12.js", "dir/subdir2/test21.js"], selected) def test_jstest_include(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "include_files": ["dir/subdir1/*.js"], "exclude_tags": "tag1" + "include_files": ["dir/subdir1/*.js"], + "exclude_tags": "tag1", } - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) self.assertEqual(["dir/subdir1/test11.js", "dir/subdir1/test12.js"], selected) - self.assertEqual(["dir/subdir2/test21.js", "dir/subdir3/a/test3a1.js"], excluded) + self.assertEqual( + ["dir/subdir2/test21.js", "dir/subdir3/a/test3a1.js"], excluded + ) def test_jstest_all(self): - config = {"roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"]} - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) - self.assertEqual([ - "dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir2/test21.js", - "dir/subdir3/a/test3a1.js" - ], selected) + config = { + "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"] + } + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) + self.assertEqual( + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir2/test21.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual([], excluded) def test_jstest_include_with_any_tags(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "include_with_any_tags": ["tag2"] + "include_with_any_tags": ["tag2"], } - selected, excluded = selector.filter_tests("js_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "js_test", config, self.test_file_explorer + ) self.assertEqual(["dir/subdir1/test11.js", "dir/subdir2/test21.js"], selected) - self.assertEqual(["dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], excluded) + self.assertEqual( + ["dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], excluded + ) def test_jstest_unknown_file(self): config = {"roots": ["dir/subdir1/*.js", "dir/subdir1/unknown"]} with self.assertRaisesRegex( - errors.SuiteSelectorConfigurationError, - re.escape("Pattern(s) and/or filename(s) in `roots`" - " do not match any existing test files: ['dir/subdir1/unknown']")): + errors.SuiteSelectorConfigurationError, + re.escape( + "Pattern(s) and/or filename(s) in `roots`" + " do not match any existing test files: ['dir/subdir1/unknown']" + ), + ): selector.filter_tests("js_test", config, self.test_file_explorer) def test_json_schema_exclude_files(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "exclude_files": ["dir/subdir2/test21.js"] + "exclude_files": ["dir/subdir2/test21.js"], } - selected, excluded = selector.filter_tests("json_schema_test", config, - self.test_file_explorer) + selected, excluded = selector.filter_tests( + "json_schema_test", config, self.test_file_explorer + ) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], - selected) + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir3/a/test3a1.js", + ], + selected, + ) self.assertEqual(["dir/subdir2/test21.js"], excluded) def test_json_schema_include_files(self): config = { "roots": ["dir/subdir1/*.js", "dir/subdir2/*.js", "dir/subdir3/a/*.js"], - "include_files": ["dir/subdir2/test21.js"] + "include_files": ["dir/subdir2/test21.js"], } - selected, excluded = selector.filter_tests("json_schema_test", config, - self.test_file_explorer) + selected, excluded = selector.filter_tests( + "json_schema_test", config, self.test_file_explorer + ) self.assertEqual(["dir/subdir2/test21.js"], selected) self.assertEqual( - ["dir/subdir1/test11.js", "dir/subdir1/test12.js", "dir/subdir3/a/test3a1.js"], - excluded) + [ + "dir/subdir1/test11.js", + "dir/subdir1/test12.js", + "dir/subdir3/a/test3a1.js", + ], + excluded, + ) @unittest.skipUnless( os.path.exists(MockTestFileExplorer.BINARY), - "{} not built".format(MockTestFileExplorer.BINARY)) + "{} not built".format(MockTestFileExplorer.BINARY), + ) def test_db_tests_all(self): config = {"binary": self.test_file_explorer.binary} - selected, excluded = selector.filter_tests("db_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "db_test", config, self.test_file_explorer + ) self.assertEqual(["dbtestA", "dbtestB", "dbtestC"], selected) self.assertEqual([], excluded) @@ -631,18 +815,27 @@ class TestFilterTests(unittest.TestCase): # When roots are specified for db_tests they override all filtering since # 'roots' are populated with the command line arguments. config = { - "binary": self.test_file_explorer.binary, "include_suites": ["dbtestB"], - "roots": ["dbtestOverride"] + "binary": self.test_file_explorer.binary, + "include_suites": ["dbtestB"], + "roots": ["dbtestOverride"], } - selected, excluded = selector.filter_tests("db_test", config, self.test_file_explorer) + selected, excluded = selector.filter_tests( + "db_test", config, self.test_file_explorer + ) self.assertEqual(["dbtestOverride"], selected) self.assertEqual([], excluded) @unittest.skipUnless( os.path.exists(MockTestFileExplorer.BINARY), - "{} not built".format(MockTestFileExplorer.BINARY)) + "{} not built".format(MockTestFileExplorer.BINARY), + ) def test_db_tests_include_suites(self): - config = {"binary": self.test_file_explorer.binary, "include_suites": ["dbtestB"]} - selected, excluded = selector.filter_tests("db_test", config, self.test_file_explorer) + config = { + "binary": self.test_file_explorer.binary, + "include_suites": ["dbtestB"], + } + selected, excluded = selector.filter_tests( + "db_test", config, self.test_file_explorer + ) self.assertEqual(["dbtestB"], selected) self.assertEqual(["dbtestA", "dbtestC"], excluded) diff --git a/buildscripts/tests/resmokelib/test_suitesconfig.py b/buildscripts/tests/resmokelib/test_suitesconfig.py index e6a566869e5..39e57a9c30c 100644 --- a/buildscripts/tests/resmokelib/test_suitesconfig.py +++ b/buildscripts/tests/resmokelib/test_suitesconfig.py @@ -14,18 +14,23 @@ RESMOKELIB = "buildscripts.resmokelib" class TestSuitesConfig(unittest.TestCase): @mock.patch(RESMOKELIB + ".testing.suite.Suite") @mock.patch(RESMOKELIB + ".suitesconfig.get_named_suites") - def test_no_suites_matching_test_kind(self, mock_get_named_suites, mock_suite_class): + def test_no_suites_matching_test_kind( + self, mock_get_named_suites, mock_suite_class + ): all_suites = ["core", "replica_sets_jscore_passthrough"] mock_get_named_suites.return_value = all_suites - membership_map = suitesconfig.create_test_membership_map(test_kind="nonexistent_test") + membership_map = suitesconfig.create_test_membership_map( + test_kind="nonexistent_test" + ) self.assertEqual(membership_map, {}) self.assertEqual(mock_suite_class.call_count, 2) @mock.patch(RESMOKELIB + ".testing.suite.Suite.tests") @mock.patch(RESMOKELIB + ".suitesconfig.get_named_suites") - def test_multiple_suites_matching_single_test_kind(self, mock_get_named_suites, - mock_suite_get_tests): + def test_multiple_suites_matching_single_test_kind( + self, mock_get_named_suites, mock_suite_get_tests + ): all_suites = ["core", "replica_sets_jscore_passthrough"] mock_get_named_suites.return_value = all_suites @@ -36,13 +41,15 @@ class TestSuitesConfig(unittest.TestCase): @mock.patch(RESMOKELIB + ".testing.suite.Suite.tests") @mock.patch(RESMOKELIB + ".suitesconfig.get_named_suites") - def test_multiple_suites_matching_multiple_test_kinds(self, mock_get_named_suites, - mock_suite_get_tests): + def test_multiple_suites_matching_multiple_test_kinds( + self, mock_get_named_suites, mock_suite_get_tests + ): all_suites = ["core", "concurrency"] mock_get_named_suites.return_value = all_suites mock_suite_get_tests.__get__ = mock.Mock(return_value=["test1", "test2"]) membership_map = suitesconfig.create_test_membership_map( - test_kind=("fsm_workload_test", "js_test")) + test_kind=("fsm_workload_test", "js_test") + ) self.assertEqual(membership_map, dict(test1=all_suites, test2=all_suites)) diff --git a/buildscripts/tests/resmokelib/test_undodb.py b/buildscripts/tests/resmokelib/test_undodb.py index afe69623883..4390241a3d3 100644 --- a/buildscripts/tests/resmokelib/test_undodb.py +++ b/buildscripts/tests/resmokelib/test_undodb.py @@ -1,4 +1,5 @@ """Fetch subcommand unittest.""" + import unittest from mock import MagicMock, patch @@ -14,15 +15,22 @@ class TestFetch(unittest.TestCase): @patch("buildscripts.resmokelib.undodb.fetch.urlopen") @patch("buildscripts.resmokelib.undodb.fetch.copyfileobj") @patch("tarfile.open") - def test_fetch(self, tarfile_open_mock, copyfileobj_mock, urlopen_mock, get_api_mock): + def test_fetch( + self, tarfile_open_mock, copyfileobj_mock, urlopen_mock, get_api_mock + ): api_mock = MagicMock() get_api_mock.return_value = api_mock - api_mock.task_by_id.return_value = evergreen.task.Task({ - "artifacts": [{ - "name": "UndoDB Recordings - Execution 1", - "url": "fake://somewhere.over/the/rainbow.tgz", - }] - }, api_mock) + api_mock.task_by_id.return_value = evergreen.task.Task( + { + "artifacts": [ + { + "name": "UndoDB Recordings - Execution 1", + "url": "fake://somewhere.over/the/rainbow.tgz", + } + ] + }, + api_mock, + ) subcommand = fetch.Fetch("task_id") subcommand.execute() diff --git a/buildscripts/tests/resmokelib/testing/fixtures/test_api_adherence.py b/buildscripts/tests/resmokelib/testing/fixtures/test_api_adherence.py index 60d2e7736cb..b7fe7d36c70 100644 --- a/buildscripts/tests/resmokelib/testing/fixtures/test_api_adherence.py +++ b/buildscripts/tests/resmokelib/testing/fixtures/test_api_adherence.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.fixtures.interface module.""" + import ast import os import unittest @@ -30,7 +31,10 @@ class AdherenceChecker(ast.NodeVisitor): self.allowed_imports = allowed_imports def check_breakage(self, module): - if module.split('.')[0] == self.disallowed_root and module not in self.allowed_imports: + if ( + module.split(".")[0] == self.disallowed_root + and module not in self.allowed_imports + ): self.breakages.append(module) def visit_Import(self, node): # pylint: disable=invalid-name @@ -47,7 +51,9 @@ class TestFixtureAPIAdherence(unittest.TestCase): def test_api_adherence(self): (_, _, filenames) = next(os.walk(FIXTURE_PATH)) pathnames = [ - os.path.join(FIXTURE_PATH, file) for file in filenames if file not in IGNORED_FILES + os.path.join(FIXTURE_PATH, file) + for file in filenames + if file not in IGNORED_FILES ] for path in pathnames: self._check_file(path) @@ -59,6 +65,7 @@ class TestFixtureAPIAdherence(unittest.TestCase): msg = ( f"File {pathname} imports the following modules that possibly break the fixture API: {checker.breakages}. " f"Only files from {ALLOWED_IMPORTS} may be imported. If making an API-breaking change, please add to " - "fixturelib.py and increment the API version. For statements of form \"from x import y\", please ensure " - "that x is the pathname of the module being imported and that y is not a file.") + 'fixturelib.py and increment the API version. For statements of form "from x import y", please ensure ' + "that x is the pathname of the module being imported and that y is not a file." + ) self.assertFalse(len(checker.breakages), msg) diff --git a/buildscripts/tests/resmokelib/testing/fixtures/test_builder.py b/buildscripts/tests/resmokelib/testing/fixtures/test_builder.py index e3c7ed63fe7..5f1f5569e80 100644 --- a/buildscripts/tests/resmokelib/testing/fixtures/test_builder.py +++ b/buildscripts/tests/resmokelib/testing/fixtures/test_builder.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.fixtures._builder module.""" + # pylint: disable=protected-access,invalid-name import unittest from unittest.mock import MagicMock @@ -39,34 +40,50 @@ class TestBuildShardedCluster(unittest.TestCase): def test_build_sharded_cluster_simple(self): parser.set_run_options() fixture_config = {"mongod_options": {SET_PARAMS: {"enableTestCommands": 1}}} - sharded_cluster = under_test.make_fixture(self.fixture_class_name, self.mock_logger, - self.job_num, **fixture_config) + sharded_cluster = under_test.make_fixture( + self.fixture_class_name, self.mock_logger, self.job_num, **fixture_config + ) self.assertEqual(len(sharded_cluster.configsvr.nodes), 1) self.assertEqual(len(sharded_cluster.shards), 1) self.assertEqual(len(sharded_cluster.shards[0].nodes), 1) self.assertEqual(len(sharded_cluster.mongos), 1) from buildscripts.resmokelib import multiversionconstants - self.assertEqual(sharded_cluster.shards[0].fcv, multiversionconstants.LATEST_FCV) + + self.assertEqual( + sharded_cluster.shards[0].fcv, multiversionconstants.LATEST_FCV + ) def test_build_sharded_cluster_with_feature_flags(self): ff_name = "featureFlagDummy" parser.set_run_options(f"--additionalFeatureFlags={ff_name}") fixture_config = {"mongod_options": {SET_PARAMS: {"enableTestCommands": 1}}} - sharded_cluster = under_test.make_fixture(self.fixture_class_name, self.mock_logger, - self.job_num, **fixture_config) + sharded_cluster = under_test.make_fixture( + self.fixture_class_name, self.mock_logger, self.job_num, **fixture_config + ) self.assertEqual(len(sharded_cluster.configsvr.nodes), 1) self.assertEqual(len(sharded_cluster.shards), 1) self.assertEqual(len(sharded_cluster.shards[0].nodes), 1) self.assertEqual(len(sharded_cluster.mongos), 1) from buildscripts.resmokelib import multiversionconstants - self.assertEqual(sharded_cluster.shards[0].fcv, multiversionconstants.LATEST_FCV) + + self.assertEqual( + sharded_cluster.shards[0].fcv, multiversionconstants.LATEST_FCV + ) # feature flags are set - self.assertIn(ff_name, sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS][ff_name]) - self.assertIn(ff_name, sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS][ff_name]) + self.assertIn( + ff_name, sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS][ff_name] + ) + self.assertIn( + ff_name, sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS][ff_name] + ) self.assertIn(ff_name, sharded_cluster.mongos[0].mongos_options[SET_PARAMS]) self.assertTrue(sharded_cluster.mongos[0].mongos_options[SET_PARAMS][ff_name]) @@ -80,8 +97,9 @@ class TestBuildShardedCluster(unittest.TestCase): "mixed_bin_versions": "new_old_old_new", "old_bin_version": "last_lts", } - sharded_cluster = under_test.make_fixture(self.fixture_class_name, self.mock_logger, - self.job_num, **fixture_config) + sharded_cluster = under_test.make_fixture( + self.fixture_class_name, self.mock_logger, self.job_num, **fixture_config + ) self.assertEqual(len(sharded_cluster.configsvr.nodes), 2) self.assertEqual(len(sharded_cluster.shards), 2) @@ -90,26 +108,45 @@ class TestBuildShardedCluster(unittest.TestCase): self.assertEqual(len(sharded_cluster.mongos), 1) from buildscripts.resmokelib import multiversionconstants + # configsvr nodes are always latest - self.assertEqual(sharded_cluster.configsvr.nodes[0].mongod_executable, - config.DEFAULT_MONGOD_EXECUTABLE) - self.assertEqual(sharded_cluster.configsvr.nodes[1].mongod_executable, - config.DEFAULT_MONGOD_EXECUTABLE) + self.assertEqual( + sharded_cluster.configsvr.nodes[0].mongod_executable, + config.DEFAULT_MONGOD_EXECUTABLE, + ) + self.assertEqual( + sharded_cluster.configsvr.nodes[1].mongod_executable, + config.DEFAULT_MONGOD_EXECUTABLE, + ) # 1st repl set nodes are latest and last-lts (new_old) - self.assertEqual(sharded_cluster.shards[0].nodes[0].mongod_executable, - config.DEFAULT_MONGOD_EXECUTABLE) - self.assertEqual(sharded_cluster.shards[0].nodes[1].mongod_executable, - multiversionconstants.LAST_LTS_MONGOD_BINARY) - self.assertEqual(sharded_cluster.shards[0].fcv, multiversionconstants.LAST_LTS_FCV) + self.assertEqual( + sharded_cluster.shards[0].nodes[0].mongod_executable, + config.DEFAULT_MONGOD_EXECUTABLE, + ) + self.assertEqual( + sharded_cluster.shards[0].nodes[1].mongod_executable, + multiversionconstants.LAST_LTS_MONGOD_BINARY, + ) + self.assertEqual( + sharded_cluster.shards[0].fcv, multiversionconstants.LAST_LTS_FCV + ) # 2st repl set nodes are last-lts and latest (old_new) - self.assertEqual(sharded_cluster.shards[1].nodes[0].mongod_executable, - multiversionconstants.LAST_LTS_MONGOD_BINARY) - self.assertEqual(sharded_cluster.shards[1].nodes[1].mongod_executable, - config.DEFAULT_MONGOD_EXECUTABLE) - self.assertEqual(sharded_cluster.shards[0].fcv, multiversionconstants.LAST_LTS_FCV) + self.assertEqual( + sharded_cluster.shards[1].nodes[0].mongod_executable, + multiversionconstants.LAST_LTS_MONGOD_BINARY, + ) + self.assertEqual( + sharded_cluster.shards[1].nodes[1].mongod_executable, + config.DEFAULT_MONGOD_EXECUTABLE, + ) + self.assertEqual( + sharded_cluster.shards[0].fcv, multiversionconstants.LAST_LTS_FCV + ) # mongos is last-lts - self.assertEqual(sharded_cluster.mongos[0].mongos_executable, - multiversionconstants.LAST_LTS_MONGOS_BINARY) + self.assertEqual( + sharded_cluster.mongos[0].mongos_executable, + multiversionconstants.LAST_LTS_MONGOS_BINARY, + ) def test_build_sharded_cluster_multiversion_with_feature_flags(self): ff_name = "featureFlagDummy" @@ -122,8 +159,9 @@ class TestBuildShardedCluster(unittest.TestCase): "mixed_bin_versions": "new_old_old_new", "old_bin_version": "last_lts", } - sharded_cluster = under_test.make_fixture(self.fixture_class_name, self.mock_logger, - self.job_num, **fixture_config) + sharded_cluster = under_test.make_fixture( + self.fixture_class_name, self.mock_logger, self.job_num, **fixture_config + ) self.assertEqual(len(sharded_cluster.configsvr.nodes), 2) self.assertEqual(len(sharded_cluster.shards), 2) @@ -131,15 +169,35 @@ class TestBuildShardedCluster(unittest.TestCase): self.assertEqual(len(sharded_cluster.shards[1].nodes), 2) self.assertEqual(len(sharded_cluster.mongos), 1) # feature flags are set on new versions - self.assertIn(ff_name, sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS][ff_name]) - self.assertIn(ff_name, sharded_cluster.configsvr.nodes[1].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.configsvr.nodes[1].mongod_options[SET_PARAMS][ff_name]) - self.assertIn(ff_name, sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS][ff_name]) - self.assertIn(ff_name, sharded_cluster.shards[1].nodes[1].mongod_options[SET_PARAMS]) - self.assertTrue(sharded_cluster.shards[1].nodes[1].mongod_options[SET_PARAMS][ff_name]) + self.assertIn( + ff_name, sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.configsvr.nodes[0].mongod_options[SET_PARAMS][ff_name] + ) + self.assertIn( + ff_name, sharded_cluster.configsvr.nodes[1].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.configsvr.nodes[1].mongod_options[SET_PARAMS][ff_name] + ) + self.assertIn( + ff_name, sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.shards[0].nodes[0].mongod_options[SET_PARAMS][ff_name] + ) + self.assertIn( + ff_name, sharded_cluster.shards[1].nodes[1].mongod_options[SET_PARAMS] + ) + self.assertTrue( + sharded_cluster.shards[1].nodes[1].mongod_options[SET_PARAMS][ff_name] + ) # feature flags are NOT set on old versions - self.assertNotIn(ff_name, sharded_cluster.shards[0].nodes[1].mongod_options[SET_PARAMS]) - self.assertNotIn(ff_name, sharded_cluster.shards[1].nodes[0].mongod_options[SET_PARAMS]) + self.assertNotIn( + ff_name, sharded_cluster.shards[0].nodes[1].mongod_options[SET_PARAMS] + ) + self.assertNotIn( + ff_name, sharded_cluster.shards[1].nodes[0].mongod_options[SET_PARAMS] + ) self.assertNotIn(ff_name, sharded_cluster.mongos[0].mongos_options[SET_PARAMS]) diff --git a/buildscripts/tests/resmokelib/testing/fixtures/test_fixturelib.py b/buildscripts/tests/resmokelib/testing/fixtures/test_fixturelib.py index 6b9495c5207..acae0ae135d 100644 --- a/buildscripts/tests/resmokelib/testing/fixtures/test_fixturelib.py +++ b/buildscripts/tests/resmokelib/testing/fixtures/test_fixturelib.py @@ -12,14 +12,17 @@ class TestMergeMongoOptionDicts(unittest.TestCase): def test_merge_empty(self): original = { - "dbpath": "value0", self.under_test.SET_PARAMETERS_KEY: { + "dbpath": "value0", + self.under_test.SET_PARAMETERS_KEY: { "param1": "value1", "param2": "value2", - } + }, } override = {} - merged = self.under_test.merge_mongo_option_dicts(copy.deepcopy(original), override) + merged = self.under_test.merge_mongo_option_dicts( + copy.deepcopy(original), override + ) self.assertDictEqual(merged, original) @@ -27,8 +30,13 @@ class TestMergeMongoOptionDicts(unittest.TestCase): non_param1_key = "non_param1" non_param2_key = "non_param2" original = { - non_param1_key: "value0", non_param2_key: {"nested_param1": "value0", }, - self.under_test.SET_PARAMETERS_KEY: {"param1": "value1", } + non_param1_key: "value0", + non_param2_key: { + "nested_param1": "value0", + }, + self.under_test.SET_PARAMETERS_KEY: { + "param1": "value1", + }, } override = { @@ -39,26 +47,38 @@ class TestMergeMongoOptionDicts(unittest.TestCase): self.under_test.merge_mongo_option_dicts(original, override) expected = { - non_param1_key: "value1", non_param2_key: "value1", - self.under_test.SET_PARAMETERS_KEY: {"param1": "value1", } + non_param1_key: "value1", + non_param2_key: "value1", + self.under_test.SET_PARAMETERS_KEY: { + "param1": "value1", + }, } self.assertEqual(original, expected) def test_merge_params(self): original = { - "dbpath": "value", self.under_test.SET_PARAMETERS_KEY: { + "dbpath": "value", + self.under_test.SET_PARAMETERS_KEY: { "param1": "value", - "param2": {"param3": "value", }, - } + "param2": { + "param3": "value", + }, + }, } - override = {self.under_test.SET_PARAMETERS_KEY: {"param2": {"param3": {"param4": "value"}}}} + override = { + self.under_test.SET_PARAMETERS_KEY: { + "param2": {"param3": {"param4": "value"}} + } + } self.under_test.merge_mongo_option_dicts(original, override) expected = { - "dbpath": "value", self.under_test.SET_PARAMETERS_KEY: { - "param1": "value", "param2": {"param3": {"param4": "value"}} - } + "dbpath": "value", + self.under_test.SET_PARAMETERS_KEY: { + "param1": "value", + "param2": {"param3": {"param4": "value"}}, + }, } self.assertDictEqual(original, expected) diff --git a/buildscripts/tests/resmokelib/testing/fixtures/test_interface.py b/buildscripts/tests/resmokelib/testing/fixtures/test_interface.py index 4a0862270c3..165147be30b 100644 --- a/buildscripts/tests/resmokelib/testing/fixtures/test_interface.py +++ b/buildscripts/tests/resmokelib/testing/fixtures/test_interface.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.fixtures.interface module.""" + import logging import unittest @@ -20,7 +21,9 @@ class TestFixture(unittest.TestCase): class TestFixtureTeardownHandler(unittest.TestCase): def test_teardown_ok(self): - handler = interface.FixtureTeardownHandler(logging.getLogger("handler_unittests")) + handler = interface.FixtureTeardownHandler( + logging.getLogger("handler_unittests") + ) # Before any teardown. self.assertTrue(handler.was_successful()) self.assertIsNone(handler.get_error_message()) @@ -32,7 +35,9 @@ class TestFixtureTeardownHandler(unittest.TestCase): self.assertIsNone(handler.get_error_message()) def test_teardown_error(self): - handler = interface.FixtureTeardownHandler(logging.getLogger("handler_unittests")) + handler = interface.FixtureTeardownHandler( + logging.getLogger("handler_unittests") + ) # Failing teardown. ko_fixture = UnitTestFixture(should_raise=True) handler.teardown(ko_fixture, "ko") diff --git a/buildscripts/tests/resmokelib/testing/hooks/test_generate_and_check_perf_results.py b/buildscripts/tests/resmokelib/testing/hooks/test_generate_and_check_perf_results.py index 56195ed31e8..2d1eebc367c 100755 --- a/buildscripts/tests/resmokelib/testing/hooks/test_generate_and_check_perf_results.py +++ b/buildscripts/tests/resmokelib/testing/hooks/test_generate_and_check_perf_results.py @@ -51,41 +51,89 @@ _BM_REPORT_2 = { } _BM_REPORT_WITH_INSTRUCTIONS_1 = { - "name": "BM_Name1/arg1/arg with space", "run_type": "iteration", "repetition_index": 0, - "threads": 1, "iterations": 1000, "real_time": 1204, "cpu_time": 1305, "bytes_per_second": 1406, - "items_per_second": 1507, "custom_counter_1": 1608, "instructions_per_iteration": 101 + "name": "BM_Name1/arg1/arg with space", + "run_type": "iteration", + "repetition_index": 0, + "threads": 1, + "iterations": 1000, + "real_time": 1204, + "cpu_time": 1305, + "bytes_per_second": 1406, + "items_per_second": 1507, + "custom_counter_1": 1608, + "instructions_per_iteration": 101, } _BM_REPORT_WITH_INSTRUCTIONS_2 = { - "name": "BM_Name1/arg1/arg with space", "run_type": "iteration", "repetition_index": 0, - "threads": 2, "iterations": 1000, "real_time": 1202, "cpu_time": 1303, "bytes_per_second": 1404, - "items_per_second": 1505, "custom_counter_1": 1606, "instructions_per_iteration": 100 + "name": "BM_Name1/arg1/arg with space", + "run_type": "iteration", + "repetition_index": 0, + "threads": 2, + "iterations": 1000, + "real_time": 1202, + "cpu_time": 1303, + "bytes_per_second": 1404, + "items_per_second": 1505, + "custom_counter_1": 1606, + "instructions_per_iteration": 100, } _BM_REPORT_WITH_INSTRUCTIONS_MEAN = { - "name": "BM_Name1/arg1/arg with space_mean", "run_type": "aggregate", "repetition_index": 0, - "threads": 2, "iterations": 1000, "real_time": 1202, "cpu_time": 1303, "bytes_per_second": 1404, - "items_per_second": 1505, "custom_counter_1": 1606, "instructions_per_iteration": 100, - "aggregate_name": "mean" + "name": "BM_Name1/arg1/arg with space_mean", + "run_type": "aggregate", + "repetition_index": 0, + "threads": 2, + "iterations": 1000, + "real_time": 1202, + "cpu_time": 1303, + "bytes_per_second": 1404, + "items_per_second": 1505, + "custom_counter_1": 1606, + "instructions_per_iteration": 100, + "aggregate_name": "mean", } _BM_REPORT_WITH_CYCLES_1 = { - "name": "BM_Name1/arg1/arg with space", "run_type": "iteration", "repetition_index": 0, - "threads": 1, "iterations": 1000, "real_time": 1204, "cpu_time": 1305, "bytes_per_second": 1406, - "items_per_second": 1507, "custom_counter_1": 1608, "cycles_per_iteration": 101 + "name": "BM_Name1/arg1/arg with space", + "run_type": "iteration", + "repetition_index": 0, + "threads": 1, + "iterations": 1000, + "real_time": 1204, + "cpu_time": 1305, + "bytes_per_second": 1406, + "items_per_second": 1507, + "custom_counter_1": 1608, + "cycles_per_iteration": 101, } _BM_REPORT_WITH_CYCLES_2 = { - "name": "BM_Name1/arg1/arg with space", "run_type": "iteration", "repetition_index": 0, - "threads": 2, "iterations": 1000, "real_time": 1202, "cpu_time": 1303, "bytes_per_second": 1404, - "items_per_second": 1505, "custom_counter_1": 1606, "cycles_per_iteration": 100 + "name": "BM_Name1/arg1/arg with space", + "run_type": "iteration", + "repetition_index": 0, + "threads": 2, + "iterations": 1000, + "real_time": 1202, + "cpu_time": 1303, + "bytes_per_second": 1404, + "items_per_second": 1505, + "custom_counter_1": 1606, + "cycles_per_iteration": 100, } _BM_REPORT_WITH_CYCLES_MEAN = { - "name": "BM_Name1/arg1/arg with space_mean", "run_type": "aggregate", "repetition_index": 0, - "threads": 2, "iterations": 1000, "real_time": 1202, "cpu_time": 1303, "bytes_per_second": 1404, - "items_per_second": 1505, "custom_counter_1": 1606, "cycles_per_iteration": 100, - "aggregate_name": "mean" + "name": "BM_Name1/arg1/arg with space_mean", + "run_type": "aggregate", + "repetition_index": 0, + "threads": 2, + "iterations": 1000, + "real_time": 1202, + "cpu_time": 1303, + "bytes_per_second": 1404, + "items_per_second": 1505, + "custom_counter_1": 1606, + "cycles_per_iteration": 100, + "aggregate_name": "mean", } _BM_MEAN_REPORT = { @@ -128,26 +176,26 @@ _BM_MULTITHREAD_MEDIAN_REPORT = { } _BM_FULL_REPORT = { - "context": - _BM_CONTEXT, "benchmarks": [ - _BM_REPORT_1, - _BM_REPORT_2, - _BM_MEAN_REPORT, - _BM_MULTITHREAD_REPORT, - _BM_MULTITHREAD_MEDIAN_REPORT, - ] + "context": _BM_CONTEXT, + "benchmarks": [ + _BM_REPORT_1, + _BM_REPORT_2, + _BM_MEAN_REPORT, + _BM_MULTITHREAD_REPORT, + _BM_MULTITHREAD_MEDIAN_REPORT, + ], } _BM_FULL_REPORT_WITH_DUPS = { - "context": - _BM_CONTEXT, "benchmarks": [ - _BM_REPORT_1, - _BM_REPORT_1, - _BM_REPORT_2, - _BM_MEAN_REPORT, - _BM_MULTITHREAD_REPORT, - _BM_MULTITHREAD_MEDIAN_REPORT, - ] + "context": _BM_CONTEXT, + "benchmarks": [ + _BM_REPORT_1, + _BM_REPORT_1, + _BM_REPORT_2, + _BM_MEAN_REPORT, + _BM_MULTITHREAD_REPORT, + _BM_MULTITHREAD_MEDIAN_REPORT, + ], } # 12/31/2999 @ 11:59pm (UTC) @@ -158,7 +206,6 @@ _END_TIME = 32503680000 class GenerateAndCheckPerfResultsFixture(unittest.TestCase): - # Mock the hook's parent class because we're testing only functionality of this hook and # not anything related to or inherit from the parent class. @mock.patch("buildscripts.resmokelib.testing.hooks.interface.Hook", autospec=True) @@ -174,7 +221,9 @@ class GenerateAndCheckPerfResultsFixture(unittest.TestCase): class TestGenerateAndCheckPerfResults(GenerateAndCheckPerfResultsFixture): def test_generate_cedar_report(self): - report = self.cbr_hook._generate_cedar_report(self.cbr_hook._parse_report(_BM_FULL_REPORT)) + report = self.cbr_hook._generate_cedar_report( + self.cbr_hook._parse_report(_BM_FULL_REPORT) + ) self.assertEqual(len(report), 2) self.assertEqual(report[0].thread_level, 1) @@ -193,18 +242,23 @@ class TestGenerateAndCheckPerfResults(GenerateAndCheckPerfResultsFixture): class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): def test_thread_from_name(self): - name_obj = self.bm_threads_report.parse_bm_name({"name": "BM_Name/arg name:100/threads:10"}) + name_obj = self.bm_threads_report.parse_bm_name( + {"name": "BM_Name/arg name:100/threads:10"} + ) self.assertEqual(name_obj.thread_count, "10") self.assertEqual(name_obj.statistic_type, None) self.assertEqual(name_obj.base_name, "BM_Name/arg name:100") name_obj = self.bm_threads_report.parse_bm_name( - {"name": "BM_Name/arg name:100/threads:10_mean", "aggregate_name": "mean"}) + {"name": "BM_Name/arg name:100/threads:10_mean", "aggregate_name": "mean"} + ) self.assertEqual(name_obj.thread_count, "10") self.assertEqual(name_obj.statistic_type, "mean") self.assertEqual(name_obj.base_name, "BM_Name/arg name:100") - name_obj = self.bm_threads_report.parse_bm_name({"name": "BM_Name/threads:abcd"}) + name_obj = self.bm_threads_report.parse_bm_name( + {"name": "BM_Name/threads:abcd"} + ) self.assertEqual(name_obj.thread_count, "abcd") self.assertEqual(name_obj.statistic_type, None) self.assertEqual(name_obj.base_name, "BM_Name") @@ -215,33 +269,42 @@ class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): self.assertEqual(name_obj.base_name, "BM_Name") name_obj = self.bm_threads_report.parse_bm_name( - {"name": "BM_Name/1/eeee_mean", "aggregate_name": "mean"}) + {"name": "BM_Name/1/eeee_mean", "aggregate_name": "mean"} + ) self.assertEqual(name_obj.thread_count, "1") self.assertEqual(name_obj.statistic_type, "mean") self.assertEqual(name_obj.base_name, "BM_Name/1/eeee") - name_obj = self.bm_threads_report.parse_bm_name({"name": "BM_Name/arg name:100"}) + name_obj = self.bm_threads_report.parse_bm_name( + {"name": "BM_Name/arg name:100"} + ) self.assertEqual(name_obj.thread_count, "1") self.assertEqual(name_obj.statistic_type, None) self.assertEqual(name_obj.base_name, "BM_Name/arg name:100") - name_obj = self.bm_threads_report.parse_bm_name({"name": "BM_baseline_match_simple/0"}) + name_obj = self.bm_threads_report.parse_bm_name( + {"name": "BM_baseline_match_simple/0"} + ) self.assertEqual(name_obj.thread_count, "1") self.assertEqual(name_obj.statistic_type, None) self.assertEqual(name_obj.base_name, "BM_baseline_match_simple/0") name_obj = self.bm_threads_report.parse_bm_name( - {"name": "BM_baseline_match_simple/0_mean", "aggregate_name": "mean"}) + {"name": "BM_baseline_match_simple/0_mean", "aggregate_name": "mean"} + ) self.assertEqual(name_obj.thread_count, "1") self.assertEqual(name_obj.statistic_type, "mean") self.assertEqual(name_obj.base_name, "BM_baseline_match_simple/0") def test_generate_multithread_cedar_metrics(self): self.bm_threads_report.add_report( - self.bm_threads_report.parse_bm_name(_BM_MULTITHREAD_REPORT), _BM_MULTITHREAD_REPORT) + self.bm_threads_report.parse_bm_name(_BM_MULTITHREAD_REPORT), + _BM_MULTITHREAD_REPORT, + ) self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_MULTITHREAD_MEDIAN_REPORT), - _BM_MULTITHREAD_MEDIAN_REPORT) + _BM_MULTITHREAD_MEDIAN_REPORT, + ) self.assertEqual(len(self.bm_threads_report.thread_benchmark_map.keys()), 1) cedar_metrics = self.bm_threads_report.generate_cedar_metrics() @@ -254,11 +317,14 @@ class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): def test_generate_single_thread_cedar_metrics(self): self.bm_threads_report.add_report( - self.bm_threads_report.parse_bm_name(_BM_REPORT_1), _BM_REPORT_1) + self.bm_threads_report.parse_bm_name(_BM_REPORT_1), _BM_REPORT_1 + ) self.bm_threads_report.add_report( - self.bm_threads_report.parse_bm_name(_BM_REPORT_2), _BM_REPORT_2) + self.bm_threads_report.parse_bm_name(_BM_REPORT_2), _BM_REPORT_2 + ) self.bm_threads_report.add_report( - self.bm_threads_report.parse_bm_name(_BM_MEAN_REPORT), _BM_MEAN_REPORT) + self.bm_threads_report.parse_bm_name(_BM_MEAN_REPORT), _BM_MEAN_REPORT + ) self.assertEqual(len(self.bm_threads_report.thread_benchmark_map.keys()), 1) cedar_metrics = self.bm_threads_report.generate_cedar_metrics() @@ -272,13 +338,16 @@ class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): def test_generate_cedar_report_with_instructions(self): self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_INSTRUCTIONS_1), - _BM_REPORT_WITH_INSTRUCTIONS_1) + _BM_REPORT_WITH_INSTRUCTIONS_1, + ) self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_INSTRUCTIONS_2), - _BM_REPORT_WITH_INSTRUCTIONS_2) + _BM_REPORT_WITH_INSTRUCTIONS_2, + ) self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_INSTRUCTIONS_MEAN), - _BM_REPORT_WITH_INSTRUCTIONS_MEAN) + _BM_REPORT_WITH_INSTRUCTIONS_MEAN, + ) self.assertEqual(len(self.bm_threads_report.thread_benchmark_map.keys()), 1) cedar_metrics = self.bm_threads_report.generate_cedar_metrics() @@ -292,13 +361,16 @@ class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): def test_generate_cedar_report_with_cycles(self): self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_CYCLES_1), - _BM_REPORT_WITH_CYCLES_1) + _BM_REPORT_WITH_CYCLES_1, + ) self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_CYCLES_2), - _BM_REPORT_WITH_CYCLES_2) + _BM_REPORT_WITH_CYCLES_2, + ) self.bm_threads_report.add_report( self.bm_threads_report.parse_bm_name(_BM_REPORT_WITH_CYCLES_MEAN), - _BM_REPORT_WITH_CYCLES_MEAN) + _BM_REPORT_WITH_CYCLES_MEAN, + ) self.assertEqual(len(self.bm_threads_report.thread_benchmark_map.keys()), 1) cedar_metrics = self.bm_threads_report.generate_cedar_metrics() @@ -313,64 +385,112 @@ class TestBenchmarkThreadsReport(GenerateAndCheckPerfResultsFixture): class TestCheckPerfResultTestCase(unittest.TestCase): def test_all_metrics_pass(self): thresholds_to_check: List[cbr.IndividualMetricThreshold] = [ - cbr.IndividualMetricThreshold(metric_name="latency", thread_level=1, - test_name="fake-test", value=10, bound_direction="upper") + cbr.IndividualMetricThreshold( + metric_name="latency", + thread_level=1, + test_name="fake-test", + value=10, + bound_direction="upper", + ) ] reported_metrics: Dict[cbr.ReportedMetric, CedarMetric] = { - cbr.ReportedMetric(test_name="fake-test", thread_level=1, metric_name="latency"): - CedarMetric(name="latency", type="LATENCY", value=1) + cbr.ReportedMetric( + test_name="fake-test", thread_level=1, metric_name="latency" + ): CedarMetric(name="latency", type="LATENCY", value=1) } test_case = cbr.CheckPerfResultTestCase( - logging.getLogger("hook_logger"), "my-test", None, None, None, thresholds_to_check, - reported_metrics) + logging.getLogger("hook_logger"), + "my-test", + None, + None, + None, + thresholds_to_check, + reported_metrics, + ) # We want to make sure this can run without any exceptions. test_case.run_test() def test_a_metric_fails(self): thresholds_to_check: List[cbr.IndividualMetricThreshold] = [ - cbr.IndividualMetricThreshold(metric_name="latency", thread_level=1, - test_name="fake-test", value=10, bound_direction="upper") + cbr.IndividualMetricThreshold( + metric_name="latency", + thread_level=1, + test_name="fake-test", + value=10, + bound_direction="upper", + ) ] reported_metrics: Dict[cbr.ReportedMetric, CedarMetric] = { - cbr.ReportedMetric(test_name="fake-test", thread_level=1, metric_name="latency"): - CedarMetric(name="latency", type="LATENCY", value=11) + cbr.ReportedMetric( + test_name="fake-test", thread_level=1, metric_name="latency" + ): CedarMetric(name="latency", type="LATENCY", value=11) } test_case = cbr.CheckPerfResultTestCase( - logging.getLogger("hook_logger"), "my-test", None, None, None, thresholds_to_check, - reported_metrics) + logging.getLogger("hook_logger"), + "my-test", + None, + None, + None, + thresholds_to_check, + reported_metrics, + ) with self.assertRaisesRegex(ServerFailure, "threshold check"): test_case.run_test() def test_metric_doesnt_exist(self): thresholds_to_check: List[cbr.IndividualMetricThreshold] = [ - cbr.IndividualMetricThreshold(metric_name="latency", thread_level=1, - test_name="fake-test", value=10, bound_direction="upper") + cbr.IndividualMetricThreshold( + metric_name="latency", + thread_level=1, + test_name="fake-test", + value=10, + bound_direction="upper", + ) ] reported_metrics: Dict[cbr.ReportedMetric, CedarMetric] = { - cbr.ReportedMetric(test_name="fake-test", thread_level=1, metric_name="instructions"): - CedarMetric(name="instructions", type="LATENCY", value=1) + cbr.ReportedMetric( + test_name="fake-test", thread_level=1, metric_name="instructions" + ): CedarMetric(name="instructions", type="LATENCY", value=1) } test_case = cbr.CheckPerfResultTestCase( - logging.getLogger("hook_logger"), "my-test", None, None, None, thresholds_to_check, - reported_metrics) + logging.getLogger("hook_logger"), + "my-test", + None, + None, + None, + thresholds_to_check, + reported_metrics, + ) with self.assertRaisesRegex(ServerFailure, "threshold check"): test_case.run_test() def test_thread_level_doesnt_exist(self): thresholds_to_check: List[cbr.IndividualMetricThreshold] = [ - cbr.IndividualMetricThreshold(metric_name="latency", thread_level=1, - test_name="fake-test", value=10, bound_direction="upper") + cbr.IndividualMetricThreshold( + metric_name="latency", + thread_level=1, + test_name="fake-test", + value=10, + bound_direction="upper", + ) ] reported_metrics: Dict[cbr.ReportedMetric, CedarMetric] = { - cbr.ReportedMetric(test_name="fake-test", thread_level=12, metric_name="latency"): - CedarMetric(name="latency", type="LATENCY", value=1) + cbr.ReportedMetric( + test_name="fake-test", thread_level=12, metric_name="latency" + ): CedarMetric(name="latency", type="LATENCY", value=1) } test_case = cbr.CheckPerfResultTestCase( - logging.getLogger("hook_logger"), "my-test", None, None, None, thresholds_to_check, - reported_metrics) + logging.getLogger("hook_logger"), + "my-test", + None, + None, + None, + thresholds_to_check, + reported_metrics, + ) with self.assertRaisesRegex(ServerFailure, "threshold check"): test_case.run_test() diff --git a/buildscripts/tests/resmokelib/testing/hooks/test_lifecycle.py b/buildscripts/tests/resmokelib/testing/hooks/test_lifecycle.py index 9ecae41f574..8d574b09a5f 100644 --- a/buildscripts/tests/resmokelib/testing/hooks/test_lifecycle.py +++ b/buildscripts/tests/resmokelib/testing/hooks/test_lifecycle.py @@ -52,8 +52,9 @@ class TestFlagBasedThreadLifecycle(unittest.TestCase): class TestFileBasedThreadLifecycle(unittest.TestCase): - - ACTION_FILES = lifecycle_interface.ActionFiles._make(lifecycle_interface.ActionFiles._fields) + ACTION_FILES = lifecycle_interface.ActionFiles._make( + lifecycle_interface.ActionFiles._fields + ) def test_still_idle_after_test_starts(self): lifecycle = lifecycle_interface.FileBasedThreadLifecycle(self.ACTION_FILES) @@ -96,7 +97,9 @@ class TestFileBasedThreadLifecycle(unittest.TestCase): @mock.patch("threading.Condition") @mock.patch("os.path") - def test_thread_waits_until_permitted_file_exists(self, mock_os_path, MockCondition): # pylint: disable=invalid-name + def test_thread_waits_until_permitted_file_exists( + self, mock_os_path, MockCondition + ): # pylint: disable=invalid-name lifecycle = lifecycle_interface.FileBasedThreadLifecycle(self.ACTION_FILES) lifecycle.mark_test_started() @@ -104,7 +107,7 @@ class TestFileBasedThreadLifecycle(unittest.TestCase): if filename == "permitted": return permitted_file_exists - self.fail("Mock called with unexpected filename: %s" % (filename, )) + self.fail("Mock called with unexpected filename: %s" % (filename,)) mock_os_path.isfile = mock_does_permitted_file_exists @@ -121,7 +124,9 @@ class TestFileBasedThreadLifecycle(unittest.TestCase): @mock.patch("threading.Condition") @mock.patch("os.path") - def test_waiting_for_action_permitted_is_interruptible(self, mock_os_path, MockCondition): # pylint: disable=invalid-name + def test_waiting_for_action_permitted_is_interruptible( + self, mock_os_path, MockCondition + ): # pylint: disable=invalid-name lifecycle = lifecycle_interface.FileBasedThreadLifecycle(self.ACTION_FILES) lifecycle.mark_test_started() diff --git a/buildscripts/tests/resmokelib/testing/hooks/test_stepdown.py b/buildscripts/tests/resmokelib/testing/hooks/test_stepdown.py index 4cfbe3d83ce..d731e19eaa1 100644 --- a/buildscripts/tests/resmokelib/testing/hooks/test_stepdown.py +++ b/buildscripts/tests/resmokelib/testing/hooks/test_stepdown.py @@ -14,9 +14,13 @@ from buildscripts.resmokelib.testing.hooks import stepdown as _stepdown class TestStepdownThread(unittest.TestCase): @mock.patch("buildscripts.resmokelib.testing.fixtures.replicaset.ReplicaSetFixture") - @mock.patch("buildscripts.resmokelib.testing.fixtures.shardedcluster.ShardedClusterFixture") - @mock.patch("buildscripts.resmokelib.testing.hooks.stepdown._StepdownThread.is_alive", - mock.Mock(return_value=True)) + @mock.patch( + "buildscripts.resmokelib.testing.fixtures.shardedcluster.ShardedClusterFixture" + ) + @mock.patch( + "buildscripts.resmokelib.testing.hooks.stepdown._StepdownThread.is_alive", + mock.Mock(return_value=True), + ) def test_pause_throws_error(self, shardcluster_fixture, rs_fixture): stepdown_thread = _stepdown._StepdownThread( logger=logging.getLogger("hook_logger"), diff --git a/buildscripts/tests/resmokelib/testing/test_executor.py b/buildscripts/tests/resmokelib/testing/test_executor.py index 89b2037e89d..d367548d271 100644 --- a/buildscripts/tests/resmokelib/testing/test_executor.py +++ b/buildscripts/tests/resmokelib/testing/test_executor.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.executor module.""" + import logging import unittest @@ -53,17 +54,21 @@ class TestCreateJobs(unittest.TestCase): class TestCreateQueueElemForTestName(unittest.TestCase): @mock.patch(ns("testcases.make_test_case")) @mock.patch(ns("queue_elem_factory")) - def test_queue_elem_created_for_test_name(self, queue_elem_mock, make_test_case_mock): + def test_queue_elem_created_for_test_name( + self, queue_elem_mock, make_test_case_mock + ): num_tests = 1 test_config = {} suite = mock_suite(num_tests) ut_executor = UnitTestExecutor(suite, test_config) - queue_elem = ut_executor._create_queue_elem_for_test_name('test_name') + queue_elem = ut_executor._create_queue_elem_for_test_name("test_name") self.assertEqual(queue_elem_mock.return_value, queue_elem) - make_test_case_mock.assert_called_with(suite.test_kind, ut_executor.test_queue_logger, - 'test_name', **test_config) - queue_elem_mock.assert_called_with(make_test_case_mock.return_value, test_config, - suite.options) + make_test_case_mock.assert_called_with( + suite.test_kind, ut_executor.test_queue_logger, "test_name", **test_config + ) + queue_elem_mock.assert_called_with( + make_test_case_mock.return_value, test_config, suite.options + ) class TestMakeTestQueue(unittest.TestCase): diff --git a/buildscripts/tests/resmokelib/testing/test_job.py b/buildscripts/tests/resmokelib/testing/test_job.py index b026f2849bf..11eceb06a0a 100644 --- a/buildscripts/tests/resmokelib/testing/test_job.py +++ b/buildscripts/tests/resmokelib/testing/test_job.py @@ -17,7 +17,6 @@ from buildscripts.resmokelib.utils import queue as _queue class TestJob(unittest.TestCase): - TESTS = ["jstests/core/and.js", "jstests/core/or.js"] @staticmethod @@ -35,8 +34,12 @@ class TestJob(unittest.TestCase): return interrupt_flag @staticmethod - def get_suite_options(num_repeat_tests=None, time_repeat_tests_secs=None, - num_repeat_tests_min=None, num_repeat_tests_max=None): + def get_suite_options( + num_repeat_tests=None, + time_repeat_tests_secs=None, + num_repeat_tests_min=None, + num_repeat_tests_max=None, + ): suite_options = mock.Mock() suite_options.num_repeat_tests = num_repeat_tests suite_options.time_repeat_tests_secs = time_repeat_tests_secs @@ -73,13 +76,19 @@ class TestJob(unittest.TestCase): time_repeat_tests_secs = 10 expected_tests_run = self.expected_run_num(time_repeat_tests_secs, increment) queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertEqual(job_object.total_test_num, expected_tests_run * len(self.TESTS)) + self.assertEqual( + job_object.total_test_num, expected_tests_run * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], expected_tests_run) @@ -89,14 +98,20 @@ class TestJob(unittest.TestCase): num_repeat_tests_max = 100 expected_tests_run = self.expected_run_num(time_repeat_tests_secs, increment) queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs, - num_repeat_tests_max=num_repeat_tests_max) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs, + num_repeat_tests_max=num_repeat_tests_max, + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertLess(job_object.total_test_num, num_repeat_tests_max * len(self.TESTS)) + self.assertLess( + job_object.total_test_num, num_repeat_tests_max * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], expected_tests_run) @@ -106,14 +121,20 @@ class TestJob(unittest.TestCase): num_repeat_tests_min = 1 expected_tests_run = self.expected_run_num(time_repeat_tests_secs, increment) queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs, - num_repeat_tests_min=num_repeat_tests_min) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs, + num_repeat_tests_min=num_repeat_tests_min, + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertGreater(job_object.total_test_num, num_repeat_tests_min * len(self.TESTS)) + self.assertGreater( + job_object.total_test_num, num_repeat_tests_min * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], expected_tests_run) @@ -124,16 +145,24 @@ class TestJob(unittest.TestCase): num_repeat_tests_max = 100 expected_tests_run = self.expected_run_num(time_repeat_tests_secs, increment) queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs, - num_repeat_tests_min=num_repeat_tests_min, - num_repeat_tests_max=num_repeat_tests_max) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs, + num_repeat_tests_min=num_repeat_tests_min, + num_repeat_tests_max=num_repeat_tests_max, + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertGreater(job_object.total_test_num, num_repeat_tests_min * len(self.TESTS)) - self.assertLess(job_object.total_test_num, num_repeat_tests_max * len(self.TESTS)) + self.assertGreater( + job_object.total_test_num, num_repeat_tests_min * len(self.TESTS) + ) + self.assertLess( + job_object.total_test_num, num_repeat_tests_max * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], expected_tests_run) @@ -143,15 +172,21 @@ class TestJob(unittest.TestCase): num_repeat_tests_min = 3 num_repeat_tests_max = 100 queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs, - num_repeat_tests_min=num_repeat_tests_min, - num_repeat_tests_max=num_repeat_tests_max) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs, + num_repeat_tests_min=num_repeat_tests_min, + num_repeat_tests_max=num_repeat_tests_max, + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertEqual(job_object.total_test_num, num_repeat_tests_min * len(self.TESTS)) + self.assertEqual( + job_object.total_test_num, num_repeat_tests_min * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], num_repeat_tests_min) @@ -160,17 +195,25 @@ class TestJob(unittest.TestCase): time_repeat_tests_secs = 30 num_repeat_tests_min = 1 num_repeat_tests_max = 10 - expected_time_repeat_tests = self.expected_run_num(time_repeat_tests_secs, increment) + expected_time_repeat_tests = self.expected_run_num( + time_repeat_tests_secs, increment + ) queue = _queue.Queue() - suite_options = self.get_suite_options(time_repeat_tests_secs=time_repeat_tests_secs, - num_repeat_tests_min=num_repeat_tests_min, - num_repeat_tests_max=num_repeat_tests_max) + suite_options = self.get_suite_options( + time_repeat_tests_secs=time_repeat_tests_secs, + num_repeat_tests_min=num_repeat_tests_min, + num_repeat_tests_max=num_repeat_tests_max, + ) mock_time = MockTime(increment) job_object = UnitJob(suite_options) - self.queue_tests(self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options) + self.queue_tests( + self.TESTS, queue, queue_element.QueueElemRepeatTime, suite_options + ) job_object._get_time = mock_time.time job_object._run(queue, self.mock_interrupt_flag()) - self.assertEqual(job_object.total_test_num, num_repeat_tests_max * len(self.TESTS)) + self.assertEqual( + job_object.total_test_num, num_repeat_tests_max * len(self.TESTS) + ) for test in self.TESTS: self.assertEqual(job_object.tests[test], num_repeat_tests_max) self.assertLess(job_object.tests[test], expected_time_repeat_tests) @@ -193,8 +236,16 @@ class MockTime(object): class UnitJob(job.Job): def __init__(self, suite_options): - super(UnitJob, self).__init__(0, logging.getLogger("job_unittest"), None, [], None, None, - suite_options, logging.getLogger("job_unittest")) + super(UnitJob, self).__init__( + 0, + logging.getLogger("job_unittest"), + None, + [], + None, + None, + suite_options, + logging.getLogger("job_unittest"), + ) self.total_test_num = 0 self.tests = {} @@ -210,8 +261,16 @@ class TestFixtureSetupAndTeardown(unittest.TestCase): def setUp(self): logger = logging.getLogger("job_unittest") - self.__job_object = job.Job(job_num=0, logger=logger, fixture=None, hooks=[], report=None, - archival=None, suite_options=None, test_queue_logger=logger) + self.__job_object = job.Job( + job_num=0, + logger=logger, + fixture=None, + hooks=[], + report=None, + archival=None, + suite_options=None, + test_queue_logger=logger, + ) self.__context = Context(trace_id=0, span_id=0, is_remote=False) # Initialize the Job instance such that its setup_fixture() and teardown_fixture() methods @@ -226,7 +285,9 @@ class TestFixtureSetupAndTeardown(unittest.TestCase): setup_flag = threading.Event() teardown_flag = threading.Event() - self.__job_object(queue, interrupt_flag, self.__context, setup_flag, teardown_flag) + self.__job_object( + queue, interrupt_flag, self.__context, setup_flag, teardown_flag + ) self.assertEqual(setup_succeeded, not interrupt_flag.is_set()) self.assertEqual(setup_succeeded, not setup_flag.is_set()) @@ -244,13 +305,17 @@ class TestFixtureSetupAndTeardown(unittest.TestCase): self.__assert_when_run_tests(setup_succeeded=False) def test_setup_raises_logging_config_exception(self): - self.__job_object.manager.setup_fixture.side_effect = errors.LoggerRuntimeConfigError( - "Logging configuration error intentionally raised in unit test") + self.__job_object.manager.setup_fixture.side_effect = ( + errors.LoggerRuntimeConfigError( + "Logging configuration error intentionally raised in unit test" + ) + ) self.__assert_when_run_tests(setup_succeeded=False) def test_setup_raises_unexpected_exception(self): self.__job_object.manager.setup_fixture.side_effect = Exception( - "Generic error intentionally raised in unit test") + "Generic error intentionally raised in unit test" + ) self.__assert_when_run_tests(setup_succeeded=False) def test_teardown_returns_failure(self): @@ -258,13 +323,17 @@ class TestFixtureSetupAndTeardown(unittest.TestCase): self.__assert_when_run_tests(teardown_succeeded=False) def test_teardown_raises_logging_config_exception(self): - self.__job_object.manager.teardown_fixture.side_effect = errors.LoggerRuntimeConfigError( - "Logging configuration error intentionally raised in unit test") + self.__job_object.manager.teardown_fixture.side_effect = ( + errors.LoggerRuntimeConfigError( + "Logging configuration error intentionally raised in unit test" + ) + ) self.__assert_when_run_tests(teardown_succeeded=False) def test_teardown_raises_unexpected_exception(self): self.__job_object.manager.teardown_fixture.side_effect = Exception( - "Generic error intentionally raised in unit test") + "Generic error intentionally raised in unit test" + ) self.__assert_when_run_tests(teardown_succeeded=False) @@ -274,17 +343,25 @@ class TestNoOpFixtureSetupAndTeardown(unittest.TestCase): def setUp(self): self.logger = logging.getLogger("job_unittest") fixturelib = FixtureLib() - self.__noop_fixture = _fixtures.NoOpFixture(logger=self.logger, job_num=0, - fixturelib=fixturelib) + self.__noop_fixture = _fixtures.NoOpFixture( + logger=self.logger, job_num=0, fixturelib=fixturelib + ) self.__noop_fixture.setup = mock.Mock() self.__noop_fixture.teardown = mock.Mock() test_report = mock.Mock() test_report.find_test_info().status = "pass" - self.__job_object = job.Job(job_num=0, logger=self.logger, fixture=self.__noop_fixture, - hooks=[], report=test_report, archival=None, suite_options=None, - test_queue_logger=self.logger) + self.__job_object = job.Job( + job_num=0, + logger=self.logger, + fixture=self.__noop_fixture, + hooks=[], + report=test_report, + archival=None, + suite_options=None, + test_queue_logger=self.logger, + ) def test_setup_called_for_noop_fixture(self): self.assertTrue(self.__job_object.manager.setup_fixture(self.logger)) diff --git a/buildscripts/tests/resmokelib/testing/test_queue_element.py b/buildscripts/tests/resmokelib/testing/test_queue_element.py index 33dfd0bf05b..858e8d2299d 100644 --- a/buildscripts/tests/resmokelib/testing/test_queue_element.py +++ b/buildscripts/tests/resmokelib/testing/test_queue_element.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.executor module.""" + import unittest import mock diff --git a/buildscripts/tests/resmokelib/testing/test_suite.py b/buildscripts/tests/resmokelib/testing/test_suite.py index 685e0ecc5f4..23206541e85 100644 --- a/buildscripts/tests/resmokelib/testing/test_suite.py +++ b/buildscripts/tests/resmokelib/testing/test_suite.py @@ -1,4 +1,5 @@ """Unit tests for the resmokelib.testing.suite module.""" + import unittest from mock import MagicMock @@ -48,4 +49,6 @@ class TestNumJobsToStart(unittest.TestCase): num_repeat = 2 under_test._config.JOBS = 100 under_test._config.REPEAT_TESTS = num_repeat - self.assertEqual(self.num_tests * num_repeat, self.suite.get_num_jobs_to_start()) + self.assertEqual( + self.num_tests * num_repeat, self.suite.get_num_jobs_to_start() + ) diff --git a/buildscripts/tests/resmokelib/testing/test_symbolizer_service.py b/buildscripts/tests/resmokelib/testing/test_symbolizer_service.py index 5fd7cd6baf3..dffe7a4bb6b 100644 --- a/buildscripts/tests/resmokelib/testing/test_symbolizer_service.py +++ b/buildscripts/tests/resmokelib/testing/test_symbolizer_service.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/testing/symbolizer_service.py.""" + import os import unittest from pathlib import Path @@ -10,7 +11,8 @@ from buildscripts.resmokelib.testing import symbolizer_service as under_test def mock_resmoke_symbolizer_config(): config_mock: under_test.ResmokeSymbolizerConfig = MagicMock( - spec_set=under_test.ResmokeSymbolizerConfig) + spec_set=under_test.ResmokeSymbolizerConfig + ) config_mock.evg_task_id = "evg_task_id" config_mock.client_id = "client_id" config_mock.client_secret = "client_secret" @@ -23,10 +25,14 @@ class TestResmokeSymbolizer(unittest.TestCase): def setUp(self) -> None: self.config_mock = mock_resmoke_symbolizer_config() self.symbolizer_service_mock: under_test.SymbolizerService = MagicMock( - spec_set=under_test.SymbolizerService) - self.file_service_mock: under_test.FileService = MagicMock(spec_set=under_test.FileService) + spec_set=under_test.SymbolizerService + ) + self.file_service_mock: under_test.FileService = MagicMock( + spec_set=under_test.FileService + ) self.resmoke_symbolizer = under_test.ResmokeSymbolizer( - self.config_mock, self.symbolizer_service_mock, self.file_service_mock) + self.config_mock, self.symbolizer_service_mock, self.file_service_mock + ) def test_symbolize_test_logs_should_not_symbolize(self): self.config_mock.is_windows.return_value = True @@ -73,7 +79,9 @@ class TestResmokeSymbolizer(unittest.TestCase): def test_get_stacktrace_dir_returns_dir(self): dbpath = "dbpath" - test = MagicMock(fixture=MagicMock(get_dbpath_prefix=MagicMock(return_value=dbpath))) + test = MagicMock( + fixture=MagicMock(get_dbpath_prefix=MagicMock(return_value=dbpath)) + ) self.file_service_mock.check_path_exists.return_value = True ret = self.resmoke_symbolizer.get_stacktrace_dir(test) @@ -176,12 +184,16 @@ class TestFileService(unittest.TestCase): fstream.write("stacktrace") self.assertEqual( - set(self.file_service.filter_out_empty_files(abs_file_paths)), set(abs_file_paths)) + set(self.file_service.filter_out_empty_files(abs_file_paths)), + set(abs_file_paths), + ) def test_do_not_panic_when_file_does_not_exist(self): non_existing_files = ["this-does-not-exist.file", "my.cat"] # non-existing files should be filtered out, instead of causing errors - self.assertListEqual(self.file_service.filter_out_empty_files(non_existing_files), []) + self.assertListEqual( + self.file_service.filter_out_empty_files(non_existing_files), [] + ) def test_filter_out_empty_files_if_partly_empty(self): with TemporaryDirectory() as tmpdir: @@ -219,7 +231,10 @@ class TestFileService(unittest.TestCase): def test_filter_out_already_processed_files(self): processed_files = ["processed-file.stacktrace"] files = [ - "file.stacktrace", "other-file.stacktrace", "another-file.stacktrace", *processed_files + "file.stacktrace", + "other-file.stacktrace", + "another-file.stacktrace", + *processed_files, ] self.file_service.add_to_processed_files(processed_files) filtered = self.file_service.filter_out_already_processed_files(files) diff --git a/buildscripts/tests/resmokelib/testing/testcases/test_pytest.py b/buildscripts/tests/resmokelib/testing/testcases/test_pytest.py index 0e5bb8302ce..e60fb753506 100644 --- a/buildscripts/tests/resmokelib/testing/testcases/test_pytest.py +++ b/buildscripts/tests/resmokelib/testing/testcases/test_pytest.py @@ -1,4 +1,5 @@ """Unit tests for the buildscripts.resmokelib.testing.testcases.pytest module.""" + import logging import sys import unittest diff --git a/buildscripts/tests/resmokelib/utils/test_archival.py b/buildscripts/tests/resmokelib/utils/test_archival.py index e6bee0b473a..72031577c03 100644 --- a/buildscripts/tests/resmokelib/utils/test_archival.py +++ b/buildscripts/tests/resmokelib/utils/test_archival.py @@ -84,20 +84,28 @@ class ArchivalFileTests(ArchivalTestCase): input_files = "no_file" s3_path = self.s3_path("unittest/no_file.tgz", False) self.assertRaises( - OSError, lambda: self.archive.archive_files_to_s3(display_name, input_files, self. - bucket, s3_path)) + OSError, + lambda: self.archive.archive_files_to_s3( + display_name, input_files, self.bucket, s3_path + ), + ) # Invalid input_files in a list input_files = ["no_file", "no_file2"] s3_path = self.s3_path("unittest/no_files.tgz", False) self.assertRaises( - OSError, lambda: self.archive.archive_files_to_s3(display_name, input_files, self. - bucket, s3_path)) + OSError, + lambda: self.archive.archive_files_to_s3( + display_name, input_files, self.bucket, s3_path + ), + ) # No files display_name = "Unittest no files" s3_path = self.s3_path("unittest/no_files.tgz") - status, message = self.archive.archive_files_to_s3(display_name, [], self.bucket, s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, [], self.bucket, s3_path + ) self.assertEqual(1, status, message) def test_files(self): @@ -105,16 +113,18 @@ class ArchivalFileTests(ArchivalTestCase): display_name = "Unittest valid file" temp_file = tempfile.mkstemp(dir=self.temp_dir)[1] s3_path = self.s3_path("unittest/valid_file.tgz") - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) # 2 valid files display_name = "Unittest 2 valid files" temp_file2 = tempfile.mkstemp(dir=self.temp_dir)[1] s3_path = self.s3_path("unittest/2valid_files.tgz") - status, message = self.archive.archive_files_to_s3(display_name, [temp_file, temp_file2], - self.bucket, s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, [temp_file, temp_file2], self.bucket, s3_path + ) self.assertEqual(0, status, message) def test_empty_directory(self): @@ -122,15 +132,17 @@ class ArchivalFileTests(ArchivalTestCase): display_name = "Unittest valid directory no files" temp_dir = tempfile.mkdtemp(dir=self.temp_dir) s3_path = self.s3_path("unittest/valid_directory.tgz") - status, message = self.archive.archive_files_to_s3(display_name, temp_dir, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_dir, self.bucket, s3_path + ) self.assertEqual(0, status, message) display_name = "Unittest valid directories no files" temp_dir2 = tempfile.mkdtemp(dir=self.temp_dir) s3_path = self.s3_path("unittest/valid_directories.tgz") - status, message = self.archive.archive_files_to_s3(display_name, [temp_dir, temp_dir2], - self.bucket, s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, [temp_dir, temp_dir2], self.bucket, s3_path + ) self.assertEqual(0, status, message) def test_directory(self): @@ -140,8 +152,9 @@ class ArchivalFileTests(ArchivalTestCase): # Create 10 empty files for _ in range(10): tempfile.mkstemp(dir=temp_dir) - status, message = self.archive.archive_files_to_s3(display_name, temp_dir, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_dir, self.bucket, s3_path + ) self.assertEqual(0, status, message) display_name = "Unittest 2 valid directory files" @@ -150,8 +163,9 @@ class ArchivalFileTests(ArchivalTestCase): # Create 10 empty files for _ in range(10): tempfile.mkstemp(dir=temp_dir2) - status, message = self.archive.archive_files_to_s3(display_name, [temp_dir, temp_dir2], - self.bucket, s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, [temp_dir, temp_dir2], self.bucket, s3_path + ) self.assertEqual(0, status, message) @@ -162,25 +176,27 @@ class ArchivalLimitSizeTests(ArchivalTestCase): @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def test_limit_size(self): - # Files within limit size display_name = "Unittest under limit size" temp_file = tempfile.mkstemp(dir=self.temp_dir)[1] create_random_file(temp_file, 3) s3_path = self.s3_path("unittest/valid_limit_size.tgz") - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) # Note the size limit is enforced after the file uploaded. Subsequent # uploads will not be permitted, once the limit has been reached. - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) # Files beyond limit size display_name = "Unittest over limit size" - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(1, status, message) @@ -190,51 +206,57 @@ class ArchivalLimitFileTests(ArchivalTestCase): return archival.Archival(cls.logger, limit_files=3, s3_client=cls.s3_client) def test_limit_file(self): - # Files within limit number display_name = "Unittest under limit number" temp_file = tempfile.mkstemp(dir=self.temp_dir)[1] s3_path = self.s3_path("unittest/valid_limit_number.tgz") - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) # Files beyond limit number display_name = "Unittest over limit number" - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(1, status, message) class ArchivalLimitTests(ArchivalTestCase): @classmethod def create_archival(cls): - return archival.Archival(cls.logger, limit_size_mb=3, limit_files=3, - s3_client=cls.s3_client) + return archival.Archival( + cls.logger, limit_size_mb=3, limit_files=3, s3_client=cls.s3_client + ) @unittest.skip("Known broken. SERVER-48969 tracks re-enabling.") def test_limits(self): - # Files within limits display_name = "Unittest under limits" temp_file = tempfile.mkstemp(dir=self.temp_dir)[1] create_random_file(temp_file, 1) s3_path = self.s3_path("unittest/valid_limits.tgz") - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(0, status, message) # Files beyond limits display_name = "Unittest over limits" - status, message = self.archive.archive_files_to_s3(display_name, temp_file, self.bucket, - s3_path) + status, message = self.archive.archive_files_to_s3( + display_name, temp_file, self.bucket, s3_path + ) self.assertEqual(1, status, message) diff --git a/buildscripts/tests/resmokelib/utils/test_evergreen_conn.py b/buildscripts/tests/resmokelib/utils/test_evergreen_conn.py index f5816d786c5..a9ad666ac06 100644 --- a/buildscripts/tests/resmokelib/utils/test_evergreen_conn.py +++ b/buildscripts/tests/resmokelib/utils/test_evergreen_conn.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/resmokelib/utils/evergreen_conn.py.""" + import unittest from mock import patch @@ -48,8 +49,12 @@ class TestGetBuildvariantName(unittest.TestCase): major_minor_version = "4.0" buildvariant_name = evergreen_conn.get_buildvariant_name( - config=self.config, edition=edition, platform=platform, architecture=architecture, - major_minor_version=major_minor_version) + config=self.config, + edition=edition, + platform=platform, + architecture=architecture, + major_minor_version=major_minor_version, + ) self.assertEqual(buildvariant_name, "macos-4.0") def test_any_version(self): @@ -59,8 +64,12 @@ class TestGetBuildvariantName(unittest.TestCase): major_minor_version = "any" buildvariant_name = evergreen_conn.get_buildvariant_name( - config=self.config, edition=edition, platform=platform, architecture=architecture, - major_minor_version=major_minor_version) + config=self.config, + edition=edition, + platform=platform, + architecture=architecture, + major_minor_version=major_minor_version, + ) self.assertEqual(buildvariant_name, "macos-any") def test_buildvariant_not_found(self): @@ -70,35 +79,45 @@ class TestGetBuildvariantName(unittest.TestCase): major_minor_version = "any" buildvariant_name = evergreen_conn.get_buildvariant_name( - config=self.config, edition=edition, platform=platform, architecture=architecture, - major_minor_version=major_minor_version) + config=self.config, + edition=edition, + platform=platform, + architecture=architecture, + major_minor_version=major_minor_version, + ) self.assertEqual(buildvariant_name, "") class TestGetGenericBuildvariantName(unittest.TestCase): def setUp(self): raw_yaml = { - "evergreen_buildvariants": [{ - "name": "generic-buildvariant-name", - "edition": evergreen_conn.GENERIC_EDITION, - "platform": evergreen_conn.GENERIC_PLATFORM, - "architecture": evergreen_conn.GENERIC_ARCHITECTURE, - "versions": ["3.4", "3.6", "4.0"], - }, ] + "evergreen_buildvariants": [ + { + "name": "generic-buildvariant-name", + "edition": evergreen_conn.GENERIC_EDITION, + "platform": evergreen_conn.GENERIC_PLATFORM, + "architecture": evergreen_conn.GENERIC_ARCHITECTURE, + "versions": ["3.4", "3.6", "4.0"], + }, + ] } self.config = SetupMultiversionConfig(raw_yaml) def test_buildvariant_found(self): major_minor_version = "4.0" generic_buildvariant_name = evergreen_conn.get_generic_buildvariant_name( - config=self.config, major_minor_version=major_minor_version) + config=self.config, major_minor_version=major_minor_version + ) self.assertEqual(generic_buildvariant_name, "generic-buildvariant-name") def test_buildvarinat_not_found(self): major_minor_version = "4.2" - self.assertRaises(evergreen_conn.EvergreenConnError, - evergreen_conn.get_generic_buildvariant_name, self.config, - major_minor_version) + self.assertRaises( + evergreen_conn.EvergreenConnError, + evergreen_conn.get_generic_buildvariant_name, + self.config, + major_minor_version, + ) class TestGetEvergreenProjectAndVersion(unittest.TestCase): @@ -131,7 +150,9 @@ class TestGetEvergreenProjectAndVersion(unittest.TestCase): raise HTTPError() mock_evg_api.version_by_id.side_effect = version_by_id_side_effect - evg_version = evergreen_conn.get_evergreen_version(mock_evg_api, evergreen_version_id) + evg_version = evergreen_conn.get_evergreen_version( + mock_evg_api, evergreen_version_id + ) self.assertEqual(mock_version, evg_version) self.assertEqual(mock_version.version_id, evergreen_version_id) @@ -148,9 +169,13 @@ class TestGetCompileArtifactUrls(unittest.TestCase): def test_buildvariant_not_found(self, mock_evg_api, mock_version): buildvariant_name = "test" mock_version.build_variants_map = {"not-test": "build_id"} - self.assertRaises(evergreen_conn.EvergreenConnError, - evergreen_conn.get_compile_artifact_urls, mock_evg_api, mock_version, - buildvariant_name) + self.assertRaises( + evergreen_conn.EvergreenConnError, + evergreen_conn.get_compile_artifact_urls, + mock_evg_api, + mock_version, + buildvariant_name, + ) @patch("evergreen.task.Artifact") @patch("evergreen.task.Task") @@ -158,16 +183,20 @@ class TestGetCompileArtifactUrls(unittest.TestCase): @patch("evergreen.build.Build") @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi") - def test_urls_found(self, mock_evg_api, mock_version, mock_build, mock_compile_task, - mock_push_task, mock_artifact): - + def test_urls_found( + self, + mock_evg_api, + mock_version, + mock_build, + mock_compile_task, + mock_push_task, + mock_artifact, + ): mock_compile_task.project_identifier = "dummy project id" expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz", - "project_identifier": - "dummy project id" + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz", + "project_identifier": "dummy project id", } mock_evg_api.build_by_id.return_value = mock_build mock_artifact.name = "Binaries" @@ -181,7 +210,9 @@ class TestGetCompileArtifactUrls(unittest.TestCase): mock_push_task.get_execution_or_self.return_value = mock_push_task mock_build.get_tasks.return_value = [mock_compile_task, mock_push_task] - urls = evergreen_conn.get_compile_artifact_urls(mock_evg_api, mock_version, "test") + urls = evergreen_conn.get_compile_artifact_urls( + mock_evg_api, mock_version, "test" + ) self.assertEqual(urls, expected_urls) @patch("evergreen.task.Artifact") @@ -192,19 +223,23 @@ class TestGetCompileArtifactUrls(unittest.TestCase): @patch("evergreen.build.Build") @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi") - def test_child_urls_found(self, mock_evg_api, mock_version, mock_build, mock_compile_task, - mock_push_task, mock_child_task, mock_compile_artifact, - mock_child_task_artifact): - + def test_child_urls_found( + self, + mock_evg_api, + mock_version, + mock_build, + mock_compile_task, + mock_push_task, + mock_child_task, + mock_compile_artifact, + mock_child_task_artifact, + ): mock_compile_task.project_identifier = "dummy project id" expected_urls = { - "Binaries": - "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz", - "Symbols": - "yeet_skeert", - "project_identifier": - "dummy project id", + "Binaries": "https://mciuploads.s3.amazonaws.com/mongodb-mongo-master/ubuntu1804/90f767adbb1901d007ee4dd8714f53402d893669/binaries/mongo-mongodb_mongo_master_ubuntu1804_90f767adbb1901d007ee4dd8714f53402d893669_20_11_30_03_14_30.tgz", + "Symbols": "yeet_skeert", + "project_identifier": "dummy project id", } mock_evg_api.build_by_id.return_value = mock_build mock_evg_api.task_by_id.return_value = mock_child_task @@ -227,7 +262,9 @@ class TestGetCompileArtifactUrls(unittest.TestCase): mock_build.get_tasks.return_value = [mock_compile_task, mock_push_task] - urls = evergreen_conn.get_compile_artifact_urls(mock_evg_api, mock_version, "test") + urls = evergreen_conn.get_compile_artifact_urls( + mock_evg_api, mock_version, "test" + ) self.assertEqual(urls, expected_urls) @patch("evergreen.task.Task") @@ -235,8 +272,9 @@ class TestGetCompileArtifactUrls(unittest.TestCase): @patch("evergreen.build.Build") @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi") - def test_push_task_failed(self, mock_evg_api, mock_version, mock_build, mock_compile_task, - mock_push_task): + def test_push_task_failed( + self, mock_evg_api, mock_version, mock_build, mock_compile_task, mock_push_task + ): mock_evg_api.build_by_id.return_value = mock_build mock_compile_task.display_name = "compile" mock_compile_task.status = "success" @@ -246,20 +284,26 @@ class TestGetCompileArtifactUrls(unittest.TestCase): mock_push_task.get_execution_or_self.return_value = mock_push_task mock_build.get_tasks.return_value = [mock_compile_task, mock_push_task] - urls = evergreen_conn.get_compile_artifact_urls(mock_evg_api, mock_version, "test") + urls = evergreen_conn.get_compile_artifact_urls( + mock_evg_api, mock_version, "test" + ) self.assertEqual(urls, {}) @patch("evergreen.task.Task") @patch("evergreen.build.Build") @patch("evergreen.version.Version") @patch("evergreen.api.EvergreenApi") - def test_no_push_task(self, mock_evg_api, mock_version, mock_build, mock_compile_task): + def test_no_push_task( + self, mock_evg_api, mock_version, mock_build, mock_compile_task + ): mock_evg_api.build_by_id.return_value = mock_build mock_compile_task.display_name = "compile" mock_compile_task.status = "success" mock_build.get_tasks.return_value = [mock_compile_task] - urls = evergreen_conn.get_compile_artifact_urls(mock_evg_api, mock_version, "test") + urls = evergreen_conn.get_compile_artifact_urls( + mock_evg_api, mock_version, "test" + ) self.assertEqual(urls, {}) @patch("evergreen.build.Build") @@ -269,5 +313,7 @@ class TestGetCompileArtifactUrls(unittest.TestCase): mock_evg_api.build_by_id.return_value = mock_build mock_build.get_tasks.return_value = [] - urls = evergreen_conn.get_compile_artifact_urls(mock_evg_api, mock_version, "test") + urls = evergreen_conn.get_compile_artifact_urls( + mock_evg_api, mock_version, "test" + ) self.assertEqual(urls, {}) diff --git a/buildscripts/tests/sbom_linter/test_sbom.py b/buildscripts/tests/sbom_linter/test_sbom.py index 373dccf1996..b12aa724be5 100644 --- a/buildscripts/tests/sbom_linter/test_sbom.py +++ b/buildscripts/tests/sbom_linter/test_sbom.py @@ -26,7 +26,9 @@ class TestSbom(unittest.TestCase): def tearDown(self): shutil.rmtree(self.output_dir) - def assert_message_in_errors(self, error_manager: sbom_linter.ErrorManager, message: str): + def assert_message_in_errors( + self, error_manager: sbom_linter.ErrorManager, message: str + ): if not error_manager.find_message_in_errors(message): error_manager.print_errors() self.fail(f"Could not find error message matching: {message}") @@ -34,7 +36,9 @@ class TestSbom(unittest.TestCase): def test_valid_sbom(self): test_file = os.path.join(self.input_dir, "valid_sbom.json") third_party_libs = {"librdkafka", "protobuf"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, True) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, True + ) if not error_manager.zero_error(): error_manager.print_errors() self.assertTrue(error_manager.zero_error()) @@ -42,41 +46,57 @@ class TestSbom(unittest.TestCase): def test_undefined_dep(self): test_file = os.path.join(self.input_dir, "valid_sbom.json") third_party_libs = {"librdkafka", "protobuf", "extra_dep"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) - self.assert_message_in_errors(error_manager, sbom_linter.UNDEFINED_THIRD_PARTY_ERROR) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) + self.assert_message_in_errors( + error_manager, sbom_linter.UNDEFINED_THIRD_PARTY_ERROR + ) def test_missing_purl_or_cpe(self): test_file = os.path.join(self.input_dir, "sbom_missing_purl.json") third_party_libs = {"librdkafka", "protobuf"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors(error_manager, sbom_linter.MISSING_PURL_CPE_ERROR) def test_missing_evidence(self): test_file = os.path.join(self.input_dir, "sbom_missing_evidence.json") third_party_libs = {"librdkafka", "protobuf"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors(error_manager, sbom_linter.MISSING_EVIDENCE_ERROR) def test_missing_team_responsible(self): test_file = os.path.join(self.input_dir, "sbom_missing_team.json") third_party_libs = {"librdkafka", "protobuf"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors(error_manager, sbom_linter.MISSING_TEAM_ERROR) def test_format(self): test_file = os.path.join(self.input_dir, "sbom_invalid_format.json") output_file = os.path.join(self.output_dir, "new_valid_sbom1.json") third_party_libs = {"librdkafka", "protobuf"} - error_manager = sbom_linter.lint_sbom(test_file, output_file, third_party_libs, True) + error_manager = sbom_linter.lint_sbom( + test_file, output_file, third_party_libs, True + ) self.assert_message_in_errors(error_manager, sbom_linter.FORMATTING_ERROR) - error_manager = sbom_linter.lint_sbom(output_file, output_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + output_file, output_file, third_party_libs, False + ) self.assertTrue(error_manager.zero_error()) def test_missing_version(self): test_file = os.path.join(self.input_dir, "sbom_missing_version.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors( error_manager, sbom_linter.MISSING_VERSION_IN_SBOM_COMPONENT_ERROR ) @@ -84,7 +104,9 @@ class TestSbom(unittest.TestCase): def test_missing_version_in_import_file(self): test_file = os.path.join(self.input_dir, "sbom_script_missing_version.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors( error_manager, sbom_linter.MISSING_VERSION_IN_IMPORT_FILE_ERROR ) @@ -92,7 +114,9 @@ class TestSbom(unittest.TestCase): def test_missing_import_file(self): test_file = os.path.join(self.input_dir, "sbom_script_file_missing.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors( error_manager, sbom_linter.COULD_NOT_FIND_OR_READ_SCRIPT_FILE_ERROR ) @@ -100,7 +124,9 @@ class TestSbom(unittest.TestCase): def test_pedigree_version_match(self): test_file = os.path.join(self.input_dir, "sbom_pedigree_version_match.json") third_party_libs = {"kafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) if not error_manager.zero_error(): error_manager.print_errors() self.assertTrue(error_manager.zero_error()) @@ -108,13 +134,17 @@ class TestSbom(unittest.TestCase): def test_schema_match_failure(self): test_file = os.path.join(self.input_dir, "sbom_component_name_missing.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors(error_manager, sbom_linter.SCHEMA_MATCH_FAILURE) def test_component_empty_version(self): test_file = os.path.join(self.input_dir, "sbom_component_empty_version.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors( error_manager, sbom_linter.MISSING_VERSION_IN_SBOM_COMPONENT_ERROR ) @@ -122,7 +152,9 @@ class TestSbom(unittest.TestCase): def test_missing_license(self): test_file = os.path.join(self.input_dir, "sbom_missing_license.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) self.assert_message_in_errors( error_manager, sbom_linter.MISSING_LICENSE_IN_SBOM_COMPONENT_ERROR ) @@ -130,14 +162,18 @@ class TestSbom(unittest.TestCase): def test_invalid_license_expression(self): test_file = os.path.join(self.input_dir, "sbom_invalid_license_expression.json") third_party_libs = {"librdkafka"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) # print(error_manager.errors) self.assert_message_in_errors(error_manager, "ExpressionInfo") def test_named_license(self): test_file = os.path.join(self.input_dir, "sbom_named_license.json") third_party_libs = {"murmurhash3"} - error_manager = sbom_linter.lint_sbom(test_file, test_file, third_party_libs, False) + error_manager = sbom_linter.lint_sbom( + test_file, test_file, third_party_libs, False + ) if not error_manager.zero_error(): error_manager.print_errors() self.assertTrue(error_manager.zero_error()) diff --git a/buildscripts/tests/test_burn_in_tests.py b/buildscripts/tests/test_burn_in_tests.py index 6caceff17de..356e8954392 100644 --- a/buildscripts/tests/test_burn_in_tests.py +++ b/buildscripts/tests/test_burn_in_tests.py @@ -81,7 +81,9 @@ class TestRepeatConfig(unittest.TestCase): self.assertEqual(repeat_config, repeat_config.validate()) def test_validate_with_both_repeat_options_specified(self): - repeat_config = under_test.RepeatConfig(repeat_tests_secs=10, repeat_tests_num=5) + repeat_config = under_test.RepeatConfig( + repeat_tests_secs=10, repeat_tests_num=5 + ) with self.assertRaises(ValueError): repeat_config.validate() @@ -93,8 +95,9 @@ class TestRepeatConfig(unittest.TestCase): repeat_config.validate() def test_validate_with_repeat_min_greater_than_max(self): - repeat_config = under_test.RepeatConfig(repeat_tests_max=10, repeat_tests_min=100, - repeat_tests_secs=15) + repeat_config = under_test.RepeatConfig( + repeat_tests_max=10, repeat_tests_min=100, repeat_tests_secs=15 + ) with self.assertRaises(ValueError): repeat_config.validate() @@ -136,8 +139,9 @@ class TestRepeatConfig(unittest.TestCase): self.assertNotIn("--repeatSuites", repeat_options) def test_get_resmoke_repeat_options_secs_min_max(self): - repeat_config = under_test.RepeatConfig(repeat_tests_secs=5, repeat_tests_min=2, - repeat_tests_max=2) + repeat_config = under_test.RepeatConfig( + repeat_tests_secs=5, repeat_tests_min=2, repeat_tests_max=2 + ) repeat_options = repeat_config.generate_resmoke_options() self.assertIn("--repeatTestsSecs=5", repeat_options) @@ -177,16 +181,21 @@ class TestSetResmokeCmd(unittest.TestCase): repeat_config = under_test.RepeatConfig() resmoke_cmds = under_test._set_resmoke_cmd(repeat_config, []) - self.assertListEqual(resmoke_cmds, - [sys.executable, "buildscripts/resmoke.py", "run", '--repeatSuites=2']) + self.assertListEqual( + resmoke_cmds, + [sys.executable, "buildscripts/resmoke.py", "run", "--repeatSuites=2"], + ) def test__set_resmoke_cmd_no_opts(self): repeat_config = under_test.RepeatConfig() resmoke_args = ["arg1", "arg2"] resmoke_cmd = under_test._set_resmoke_cmd(repeat_config, resmoke_args) - expected_resmoke_cmd = [sys.executable, 'buildscripts/resmoke.py', 'run' - ] + resmoke_args + ['--repeatSuites=2'] + expected_resmoke_cmd = ( + [sys.executable, "buildscripts/resmoke.py", "run"] + + resmoke_args + + ["--repeatSuites=2"] + ) self.assertListEqual(expected_resmoke_cmd, resmoke_cmd) @@ -195,41 +204,59 @@ class TestSetResmokeCmd(unittest.TestCase): resmoke_args = ["arg1", "arg2"] resmoke_cmd = under_test._set_resmoke_cmd(repeat_config, resmoke_args) - expected_resmoke_cmd = [sys.executable, 'buildscripts/resmoke.py', 'run' - ] + resmoke_args + ['--repeatSuites=3'] + expected_resmoke_cmd = ( + [sys.executable, "buildscripts/resmoke.py", "run"] + + resmoke_args + + ["--repeatSuites=3"] + ) self.assertListEqual(expected_resmoke_cmd, resmoke_cmd) class RunTests(unittest.TestCase): - @patch(ns('subprocess.check_call')) + @patch(ns("subprocess.check_call")) def test_run_tests_no_tests(self, check_call_mock): tests_by_task = {} - resmoke_cmd = ["python", "buildscripts/resmoke.py", "run", "--continueOnFailure"] + resmoke_cmd = [ + "python", + "buildscripts/resmoke.py", + "run", + "--continueOnFailure", + ] under_test.run_tests(tests_by_task, resmoke_cmd) check_call_mock.assert_not_called() - @patch(ns('subprocess.check_call')) + @patch(ns("subprocess.check_call")) def test_run_tests_some_test(self, check_call_mock): n_tasks = 3 tests_by_task = create_tests_by_task_mock(n_tasks, 5) - resmoke_cmd = ["python", "buildscripts/resmoke.py", "run", "--continueOnFailure"] + resmoke_cmd = [ + "python", + "buildscripts/resmoke.py", + "run", + "--continueOnFailure", + ] under_test.run_tests(tests_by_task, resmoke_cmd) self.assertEqual(n_tasks, check_call_mock.call_count) - @patch(ns('sys.exit')) - @patch(ns('subprocess.check_call')) + @patch(ns("sys.exit")) + @patch(ns("subprocess.check_call")) def test_run_tests_tests_resmoke_failure(self, check_call_mock, exit_mock): error_code = 42 n_tasks = 3 tests_by_task = create_tests_by_task_mock(n_tasks, 5) - resmoke_cmd = ["python", "buildscripts/resmoke.py", "run", "--continueOnFailure"] + resmoke_cmd = [ + "python", + "buildscripts/resmoke.py", + "run", + "--continueOnFailure", + ] check_call_mock.side_effect = subprocess.CalledProcessError(error_code, "err1") - exit_mock.side_effect = ValueError('exiting') + exit_mock.side_effect = ValueError("exiting") with self.assertRaises(ValueError): under_test.run_tests(tests_by_task, resmoke_cmd) @@ -239,8 +266,11 @@ class RunTests(unittest.TestCase): MEMBERS_MAP = { - "test1.js": ["suite1", "suite2"], "test2.js": ["suite1", "suite3"], "test3.js": [], - "test4.js": ["suite1", "suite2", "suite3"], "test5.js": ["suite2"] + "test1.js": ["suite1", "suite2"], + "test2.js": ["suite1", "suite3"], + "test3.js": [], + "test4.js": ["suite1", "suite2", "suite3"], + "test5.js": ["suite2"], } SUITE1 = Mock() @@ -279,7 +309,9 @@ class CreateExecutorList(unittest.TestCase): @patch(RESMOKELIB + ".testing.suite.Suite") @patch(RESMOKELIB + ".suitesconfig.get_named_suites") - def test_create_executor_list_runs_core_suite(self, mock_get_named_suites, mock_suite_class): + def test_create_executor_list_runs_core_suite( + self, mock_get_named_suites, mock_suite_class + ): mock_get_named_suites.return_value = ["core"] under_test.create_executor_list([], []) @@ -287,8 +319,9 @@ class CreateExecutorList(unittest.TestCase): @patch(RESMOKELIB + ".testing.suite.Suite") @patch(RESMOKELIB + ".suitesconfig.get_named_suites") - def test_create_executor_list_ignores_dbtest_suite(self, mock_get_named_suites, - mock_suite_class): + def test_create_executor_list_ignores_dbtest_suite( + self, mock_get_named_suites, mock_suite_class + ): mock_get_named_suites.return_value = ["dbtest"] under_test.create_executor_list([], []) @@ -300,7 +333,9 @@ def create_variant_task_mock(task_name, suite_name, distro="distro"): variant_task.name = task_name variant_task.generated_task_name = task_name variant_task.get_suite_names.return_value = [suite_name] - variant_task.combined_suite_to_resmoke_args_map = {suite_name: f"--suites={suite_name}"} + variant_task.combined_suite_to_resmoke_args_map = { + suite_name: f"--suites={suite_name}" + } variant_task.run_on = [distro] return variant_task @@ -341,8 +376,9 @@ class TestCreateTaskList(unittest.TestCase): } exclude_tasks = [] - task_list = under_test.create_task_list(evg_conf_mock, variant, tests_by_suite, - exclude_tasks) + task_list = under_test.create_task_list( + evg_conf_mock, variant, tests_by_suite, exclude_tasks + ) self.assertIn("task 1", task_list) self.assertIn("task 2", task_list) @@ -361,8 +397,9 @@ class TestCreateTaskList(unittest.TestCase): } exclude_tasks = [] - task_list = under_test.create_task_list(evg_conf_mock, variant, tests_by_suite, - exclude_tasks) + task_list = under_test.create_task_list( + evg_conf_mock, variant, tests_by_suite, exclude_tasks + ) self.assertIn("task 1", task_list) task_info = task_list["task 1"] @@ -385,8 +422,9 @@ class TestCreateTaskList(unittest.TestCase): } exclude_tasks = ["task 2"] - task_list = under_test.create_task_list(evg_conf_mock, variant, tests_by_suite, - exclude_tasks) + task_list = under_test.create_task_list( + evg_conf_mock, variant, tests_by_suite, exclude_tasks + ) self.assertIn("task 1", task_list) self.assertNotIn("task 2", task_list) @@ -419,7 +457,9 @@ class TestCreateTestsByTask(unittest.TestCase): evg_conf_mock.get_variant.return_value = None with self.assertRaises(ValueError): - under_test.create_tests_by_task(variant, evg_conf_mock, set(), "install-dir/bin") + under_test.create_tests_by_task( + variant, evg_conf_mock, set(), "install-dir/bin" + ) class TestLocalFileChangeDetector(unittest.TestCase): @@ -481,7 +521,7 @@ class TestLocalFileChangeDetector(unittest.TestCase): class TestYamlBurnInExecutor(unittest.TestCase): - @patch('sys.stdout', new_callable=StringIO) + @patch("sys.stdout", new_callable=StringIO) def test_found_tasks_should_be_reported_as_yaml(self, stdout): n_tasks = 5 n_tests = 3 @@ -493,4 +533,6 @@ class TestYamlBurnInExecutor(unittest.TestCase): yaml_raw = stdout.getvalue() results = yaml.safe_load(yaml_raw) self.assertEqual(n_tasks, len(results["discovered_tasks"])) - self.assertEqual(n_tests, len(results["discovered_tasks"][0]["suites"][0]["test_list"])) + self.assertEqual( + n_tests, len(results["discovered_tasks"][0]["suites"][0]["test_list"]) + ) diff --git a/buildscripts/tests/test_debugsymb_mapper.py b/buildscripts/tests/test_debugsymb_mapper.py index bb794c18fe9..df2ab481f19 100644 --- a/buildscripts/tests/test_debugsymb_mapper.py +++ b/buildscripts/tests/test_debugsymb_mapper.py @@ -30,11 +30,14 @@ class TestGetBuildId(TestCmdOutputExtractor): " GA$3h864 0x00000010\tOPEN\n" " Applies to region from 0xb71 to 0xb71 (.annobin_init.c.hot)\n" " GA$3h864 0x00000010\tOPEN\n" - " Applies to region from 0xb71 to 0xb71 (.annobin_init.c.hot)") + " Applies to region from 0xb71 to 0xb71 (.annobin_init.c.hot)" + ) self.cmd_client_mock.run.return_value = readelf_output build_id_output = self.cmd_output_extractor.get_build_id("path/to/bin") - self.assertEqual(build_id_output.build_id, "74c2322104428836f3d94af6cd7471ee7cb5c4ee") + self.assertEqual( + build_id_output.build_id, "74c2322104428836f3d94af6cd7471ee7cb5c4ee" + ) self.assertEqual(build_id_output.cmd_output, readelf_output) def test_get_build_id_raises_error(self): @@ -48,16 +51,20 @@ class TestGetBuildId(TestCmdOutputExtractor): " GNU 0x00000014\tNT_GNU_BUILD_ID (unique build ID bitstring)\n" " Build ID: 74c2322104428836f3d94af6cd7471ee7cb5c4ee\n" "\n" - "Displaying notes found in: .gnu.build.attributes.hot") + "Displaying notes found in: .gnu.build.attributes.hot" + ) self.cmd_client_mock.run.return_value = readelf_output - self.assertRaises(ValueError, self.cmd_output_extractor.get_build_id, "path/to/bin") + self.assertRaises( + ValueError, self.cmd_output_extractor.get_build_id, "path/to/bin" + ) def test_get_build_id_returns_none(self): readelf_output = ( "Displaying notes found in: .note.gnu.build-id\n" " Owner Data size\tDescription\n" - " GNU 0x00000014\tNT_GNU_BUILD_ID (unique build ID bitstring)") + " GNU 0x00000014\tNT_GNU_BUILD_ID (unique build ID bitstring)" + ) self.cmd_client_mock.run.return_value = readelf_output build_id_output = self.cmd_output_extractor.get_build_id("path/to/bin") @@ -68,21 +75,23 @@ class TestGetBuildId(TestCmdOutputExtractor): class TestGetBinVersion(TestCmdOutputExtractor): def test_get_bin_version_returns_version(self): # Newer versions command output - version_cmd_output = ('db version v4.4.14-25-gb0475e2\n' - 'Build Info: {\n' - ' "version": "4.4.14-25-gb0475e2",\n' - ' "gitVersion": "b0475e2657c3351b25499971d3340f054ea85b98",\n' - ' "openSSLVersion": "OpenSSL 1.1.1 11 Sep 2018",\n' - ' "modules": [\n' - ' "enterprise"\n' - ' ],\n' - ' "allocator": "tcmalloc",\n' - ' "environment": {\n' - ' "distmod": "ubuntu1804",\n' - ' "distarch": "x86_64",\n' - ' "target_arch": "x86_64"\n' - ' }\n' - '}') + version_cmd_output = ( + "db version v4.4.14-25-gb0475e2\n" + "Build Info: {\n" + ' "version": "4.4.14-25-gb0475e2",\n' + ' "gitVersion": "b0475e2657c3351b25499971d3340f054ea85b98",\n' + ' "openSSLVersion": "OpenSSL 1.1.1 11 Sep 2018",\n' + ' "modules": [\n' + ' "enterprise"\n' + " ],\n" + ' "allocator": "tcmalloc",\n' + ' "environment": {\n' + ' "distmod": "ubuntu1804",\n' + ' "distarch": "x86_64",\n' + ' "target_arch": "x86_64"\n' + " }\n" + "}" + ) self.cmd_client_mock.run.return_value = version_cmd_output bin_version_output = self.cmd_output_extractor.get_bin_version("path/to/bin") @@ -91,15 +100,17 @@ class TestGetBinVersion(TestCmdOutputExtractor): def test_get_bin_version_unsupported_output(self): # Versions prior to 5.0 are not supported - version_cmd_output = ('db version v4.2.20-7-g5a81409\n' - 'git version: 5a81409faf16f30f1189af6367eb3ceee50a02b5\n' - 'OpenSSL version: OpenSSL 1.1.1 11 Sep 2018\n' - 'allocator: tcmalloc\n' - 'modules: enterprise \n' - 'build environment:\n' - ' distmod: ubuntu1804\n' - ' distarch: x86_64\n' - ' target_arch: x86_64') + version_cmd_output = ( + "db version v4.2.20-7-g5a81409\n" + "git version: 5a81409faf16f30f1189af6367eb3ceee50a02b5\n" + "OpenSSL version: OpenSSL 1.1.1 11 Sep 2018\n" + "allocator: tcmalloc\n" + "modules: enterprise \n" + "build environment:\n" + " distmod: ubuntu1804\n" + " distarch: x86_64\n" + " target_arch: x86_64" + ) self.cmd_client_mock.run.return_value = version_cmd_output bin_version_output = self.cmd_output_extractor.get_bin_version("path/to/bin") diff --git a/buildscripts/tests/test_errorcodes.py b/buildscripts/tests/test_errorcodes.py index d13e0036cc1..75ee3b1096f 100644 --- a/buildscripts/tests/test_errorcodes.py +++ b/buildscripts/tests/test_errorcodes.py @@ -1,4 +1,5 @@ """Unit tests for the selected_tests script.""" + import unittest from buildscripts import errorcodes @@ -6,7 +7,7 @@ from buildscripts import errorcodes # Debugging errorcodes.list_files = True -TESTDATA_DIR = './buildscripts/tests/data/errorcodes/' +TESTDATA_DIR = "./buildscripts/tests/data/errorcodes/" class TestErrorcodes(unittest.TestCase): @@ -23,12 +24,16 @@ class TestErrorcodes(unittest.TestCase): def accumulate_files(code): captured_error_codes.append(code) - errorcodes.parse_source_files(accumulate_files, TESTDATA_DIR + 'regex_matching/') + errorcodes.parse_source_files( + accumulate_files, TESTDATA_DIR + "regex_matching/" + ) self.assertEqual(32, len(captured_error_codes)) def test_dup_checking(self): """Test dup checking.""" - assertions, errors, _ = errorcodes.read_error_codes(TESTDATA_DIR + 'dup_checking/') + assertions, errors, _ = errorcodes.read_error_codes( + TESTDATA_DIR + "dup_checking/" + ) # `assertions` is every use of an error code. Duplicates are included. self.assertEqual(4, len(assertions)) self.assertEqual([1, 2, 3, 2], list(map(lambda x: int(x.code), assertions))) @@ -40,7 +45,7 @@ class TestErrorcodes(unittest.TestCase): def test_generate_next_code(self): """Test `get_next_code`.""" - _, _, seen = errorcodes.read_error_codes(TESTDATA_DIR + 'generate_next_code/') + _, _, seen = errorcodes.read_error_codes(TESTDATA_DIR + "generate_next_code/") generator = errorcodes.get_next_code(seen) self.assertEqual(21, next(generator)) self.assertEqual(22, next(generator)) @@ -53,7 +58,9 @@ class TestErrorcodes(unittest.TestCase): `server_ticket` is passed in. But it maybe makes sense for the test to do so in case a future patch changes that relationship. """ - _, _, seen = errorcodes.read_error_codes(TESTDATA_DIR + 'generate_next_server_code/') + _, _, seen = errorcodes.read_error_codes( + TESTDATA_DIR + "generate_next_server_code/" + ) print("Seen: " + str(seen)) generator = errorcodes.get_next_code(seen, server_ticket=12301) self.assertEqual(1230101, next(generator)) @@ -62,7 +69,7 @@ class TestErrorcodes(unittest.TestCase): def test_ticket_coersion(self): """Test `coerce_to_number`.""" self.assertEqual(0, errorcodes.coerce_to_number(0)) - self.assertEqual(1234, errorcodes.coerce_to_number('1234')) - self.assertEqual(1234, errorcodes.coerce_to_number('server-1234')) - self.assertEqual(1234, errorcodes.coerce_to_number('SERVER-1234')) - self.assertEqual(-1, errorcodes.coerce_to_number('not a ticket')) + self.assertEqual(1234, errorcodes.coerce_to_number("1234")) + self.assertEqual(1234, errorcodes.coerce_to_number("server-1234")) + self.assertEqual(1234, errorcodes.coerce_to_number("SERVER-1234")) + self.assertEqual(-1, errorcodes.coerce_to_number("not a ticket")) diff --git a/buildscripts/tests/test_evergreen_activate_gen_tasks.py b/buildscripts/tests/test_evergreen_activate_gen_tasks.py index bf14d682b41..31252ee176c 100644 --- a/buildscripts/tests/test_evergreen_activate_gen_tasks.py +++ b/buildscripts/tests/test_evergreen_activate_gen_tasks.py @@ -1,4 +1,5 @@ """Unit tests for the generate_resmoke_suite script.""" + # pylint: disable=invalid-name import unittest @@ -17,7 +18,7 @@ def build_mock_task_list(num_tasks): return [build_mock_task(f"task_{i}", f"id_{i}") for i in range(num_tasks)] -class MockVariantData(): +class MockVariantData: """An object to help create a mock evg api.""" def __init__(self, build_id, variant_name, task_list): @@ -55,70 +56,96 @@ def build_mock_evg_api(variant_data_list): class TestActivateTask(unittest.TestCase): def test_task_with_display_name_is_activated(self): - expansions = under_test.EvgExpansions(**{ - "build_id": "build_id", - "version_id": "version_id", - "task_name": "task_3_gen", - }) + expansions = under_test.EvgExpansions( + **{ + "build_id": "build_id", + "version_id": "version_id", + "task_name": "task_3_gen", + } + ) mock_task_list = build_mock_task_list(5) mock_evg_api = build_mock_evg_api( - [MockVariantData("build_id", "non-burn-in-bv", mock_task_list)]) + [MockVariantData("build_id", "non-burn-in-bv", mock_task_list)] + ) under_test.activate_task(expansions, mock_evg_api) mock_evg_api.configure_task.assert_called_with("id_3", activated=True) def test_task_with_no_matching_name(self): - expansions = under_test.EvgExpansions(**{ - "build_id": "build_id", - "version_id": "version_id", - "task_name": "not_an_existing_task", - }) + expansions = under_test.EvgExpansions( + **{ + "build_id": "build_id", + "version_id": "version_id", + "task_name": "not_an_existing_task", + } + ) mock_task_list = build_mock_task_list(5) mock_evg_api = build_mock_evg_api( - [MockVariantData("build_id", "non-burn-in-bv", mock_task_list)]) + [MockVariantData("build_id", "non-burn-in-bv", mock_task_list)] + ) under_test.activate_task(expansions, mock_evg_api) mock_evg_api.configure_task.assert_not_called() def test_burn_in_tags_tasks_are_activated(self): - expansions = under_test.EvgExpansions(**{ - "build_id": "build_id", - "version_id": "version_id", - "task_name": "burn_in_tags_gen", - }) + expansions = under_test.EvgExpansions( + **{ + "build_id": "build_id", + "version_id": "version_id", + "task_name": "burn_in_tags_gen", + } + ) mock_task_list_2 = build_mock_task_list(5) mock_task_list_2.append(build_mock_task("burn_in_tests", "burn_in_tests_id_2")) mock_task_list_3 = build_mock_task_list(5) mock_task_list_3.append(build_mock_task("burn_in_tests", "burn_in_tests_id_3")) - mock_evg_api = build_mock_evg_api([ - MockVariantData("1", "variant1-generated-by-burn-in-tags", mock_task_list_2), - MockVariantData("2", "variant2-generated-by-burn-in-tags", mock_task_list_3) - ]) + mock_evg_api = build_mock_evg_api( + [ + MockVariantData( + "1", "variant1-generated-by-burn-in-tags", mock_task_list_2 + ), + MockVariantData( + "2", "variant2-generated-by-burn-in-tags", mock_task_list_3 + ), + ] + ) under_test.activate_task(expansions, mock_evg_api) - mock_evg_api.configure_task.assert_has_calls([ - mock.call("burn_in_tests_id_2", activated=True), - mock.call("burn_in_tests_id_3", activated=True) - ]) + mock_evg_api.configure_task.assert_has_calls( + [ + mock.call("burn_in_tests_id_2", activated=True), + mock.call("burn_in_tests_id_3", activated=True), + ] + ) def test_burn_in_tags_task_skips_non_existing_build_variant(self): - expansions = under_test.EvgExpansions(**{ - "build_id": "build_id", - "version_id": "version_id", - "task_name": "burn_in_tags_gen", - }) + expansions = under_test.EvgExpansions( + **{ + "build_id": "build_id", + "version_id": "version_id", + "task_name": "burn_in_tags_gen", + } + ) mock_task_list_1 = build_mock_task_list(5) - mock_task_list_1.append(build_mock_task("burn_in_tags_gen", "burn_in_tags_gen_id_1")) + mock_task_list_1.append( + build_mock_task("burn_in_tags_gen", "burn_in_tags_gen_id_1") + ) mock_task_list_2 = build_mock_task_list(5) mock_task_list_2.append(build_mock_task("burn_in_tests", "burn_in_tests_id_2")) - mock_evg_api = build_mock_evg_api([ - MockVariantData("1", "variant1-non-burn-in", mock_task_list_1), - MockVariantData("2", "variant2-generated-by-burn-in-tags", mock_task_list_2) - ]) + mock_evg_api = build_mock_evg_api( + [ + MockVariantData("1", "variant1-non-burn-in", mock_task_list_1), + MockVariantData( + "2", "variant2-generated-by-burn-in-tags", mock_task_list_2 + ), + ] + ) under_test.activate_task(expansions, mock_evg_api) - mock_evg_api.configure_task.assert_called_once_with("burn_in_tests_id_2", activated=True) + mock_evg_api.configure_task.assert_called_once_with( + "burn_in_tests_id_2", activated=True + ) diff --git a/buildscripts/tests/test_evergreen_resmoke_job_count.py b/buildscripts/tests/test_evergreen_resmoke_job_count.py index 06815340c6c..b6a32e31661 100644 --- a/buildscripts/tests/test_evergreen_resmoke_job_count.py +++ b/buildscripts/tests/test_evergreen_resmoke_job_count.py @@ -13,11 +13,15 @@ class DetermineJobsTest(unittest.TestCase): regex = "regexthatmatches" mytask_factor = 0.5 regex_factor = 0.25 - task_factors = [{"task": mytask, "factor": mytask_factor}, - {"task": "regex.*", "factor": regex_factor}] + task_factors = [ + {"task": mytask, "factor": mytask_factor}, + {"task": "regex.*", "factor": regex_factor}, + ] def test_determine_jobs_no_matching_task(self): - jobs = under_test.determine_jobs("_no_match_", "_no_variant_", "_no_distro_", 0, 1) + jobs = under_test.determine_jobs( + "_no_match_", "_no_variant_", "_no_distro_", 0, 1 + ) self.assertEqual(self.cpu_count, jobs) def test_determine_jobs_matching_variant(self): @@ -83,12 +87,18 @@ class DetermineJobsTest(unittest.TestCase): under_test.SYS_PLATFORM = "myplatform" mytask_factor_min = 0.5 regex_factor_min = 0.25 - task_factors1 = [{"task": "mytask", "factor": mytask_factor_min + .5}, - {"task": "regex.*", "factor": regex_factor_min + .5}] - task_factors2 = [{"task": "mytask", "factor": mytask_factor_min + .25}, - {"task": "regex.*", "factor": regex_factor_min + .25}] - task_factors3 = [{"task": "mytask", "factor": mytask_factor_min}, - {"task": "regex.*", "factor": regex_factor_min}] + task_factors1 = [ + {"task": "mytask", "factor": mytask_factor_min + 0.5}, + {"task": "regex.*", "factor": regex_factor_min + 0.5}, + ] + task_factors2 = [ + {"task": "mytask", "factor": mytask_factor_min + 0.25}, + {"task": "regex.*", "factor": regex_factor_min + 0.25}, + ] + task_factors3 = [ + {"task": "mytask", "factor": mytask_factor_min}, + {"task": "regex.*", "factor": regex_factor_min}, + ] under_test.VARIANT_TASK_FACTOR_OVERRIDES = {"myvariant": task_factors1} under_test.MACHINE_TASK_FACTOR_OVERRIDES = {"mymachine": task_factors2} under_test.PLATFORM_TASK_FACTOR_OVERRIDES = {"myplatform": task_factors3} @@ -99,15 +109,21 @@ class DetermineJobsTest(unittest.TestCase): def test_determine_jobs_factor(self): factor = 0.4 - jobs = under_test.determine_jobs("_no_match_", "_no_variant_", "_no_distro_", 0, factor) + jobs = under_test.determine_jobs( + "_no_match_", "_no_variant_", "_no_distro_", 0, factor + ) self.assertEqual(int(round(self.cpu_count * factor)), jobs) def test_determine_jobs_jobs_max(self): jobs_max = 3 - jobs = under_test.determine_jobs("_no_match_", "_no_variant_", "_no_distro_", jobs_max, 1) + jobs = under_test.determine_jobs( + "_no_match_", "_no_variant_", "_no_distro_", jobs_max, 1 + ) self.assertEqual(min(jobs_max, jobs), jobs) jobs_max = 30 - jobs = under_test.determine_jobs("_no_match_", "_no_variant_", "_no_distro_", jobs_max, 1) + jobs = under_test.determine_jobs( + "_no_match_", "_no_variant_", "_no_distro_", jobs_max, 1 + ) self.assertEqual(min(jobs_max, jobs), jobs) def test_determine_jobs_with_global_specification(self): @@ -120,7 +136,9 @@ class DetermineJobsTest(unittest.TestCase): } variant = "a_build_variant" distro = "a_distro" - job_count_matching = under_test.determine_jobs(task, variant, distro, jobs_max=jobs_default) + job_count_matching = under_test.determine_jobs( + task, variant, distro, jobs_max=jobs_default + ) self.assertEqual(jobs_default * target_factor, job_count_matching) def test_determine_jobs_without_global_specification(self): @@ -134,5 +152,7 @@ class DetermineJobsTest(unittest.TestCase): variant = "a_build_variant" distro = "a_distro" - job_count_matching = under_test.determine_jobs(task, variant, distro, jobs_max=jobs_default) + job_count_matching = under_test.determine_jobs( + task, variant, distro, jobs_max=jobs_default + ) self.assertEqual(jobs_default, job_count_matching) diff --git a/buildscripts/tests/test_evergreen_task_tags.py b/buildscripts/tests/test_evergreen_task_tags.py index f9da42e80da..9c92f6aa2b6 100644 --- a/buildscripts/tests/test_evergreen_task_tags.py +++ b/buildscripts/tests/test_evergreen_task_tags.py @@ -27,7 +27,9 @@ class TestGetAllTaskTags(unittest.TestCase): def test_with_some_tags(self): task_prefixes = ["b", "a", "q", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] evg_config_mock = MagicMock(tasks=task_list_mock) tag_list = ett.get_all_task_tags(evg_config_mock) @@ -49,7 +51,9 @@ class TestGetTasksWithTag(unittest.TestCase): def test_with_one_tag_each(self): task_prefixes = ["b", "a", "b", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] for index, task in enumerate(task_list_mock): task.name = "task " + str(index) evg_config_mock = MagicMock(tasks=task_list_mock) @@ -63,7 +67,9 @@ class TestGetTasksWithTag(unittest.TestCase): def test_with_two_tags(self): task_prefixes = ["b", "a", "b", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] for index, task in enumerate(task_list_mock): task.name = "task " + str(index) evg_config_mock = MagicMock(tasks=task_list_mock) @@ -77,7 +83,9 @@ class TestGetTasksWithTag(unittest.TestCase): def test_with_two_tags_no_results(self): task_prefixes = ["b", "a", "b", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] for index, task in enumerate(task_list_mock): task.name = "task " + str(index) evg_config_mock = MagicMock(tasks=task_list_mock) @@ -88,7 +96,9 @@ class TestGetTasksWithTag(unittest.TestCase): def test_with_one_filter(self): task_prefixes = ["b", "a", "b", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] for index, task in enumerate(task_list_mock): task.name = "task " + str(index) task_list_mock[0].tags = ["b 0"] @@ -101,7 +111,9 @@ class TestGetTasksWithTag(unittest.TestCase): def test_with_two_filter(self): task_prefixes = ["b", "a", "b", "v"] n_tags = 3 - task_list_mock = [MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes] + task_list_mock = [ + MagicMock(tags=gen_tag_set(prefix, n_tags)) for prefix in task_prefixes + ] for index, task in enumerate(task_list_mock): task.name = "task " + str(index) task_list_mock[0].tags = ["b 0"] diff --git a/buildscripts/tests/test_evergreen_task_timeout.py b/buildscripts/tests/test_evergreen_task_timeout.py index 72c6c5f8cc1..dbbfe164596 100644 --- a/buildscripts/tests/test_evergreen_task_timeout.py +++ b/buildscripts/tests/test_evergreen_task_timeout.py @@ -1,4 +1,5 @@ """Unit tests for the evergreen_task_timeout script.""" + import unittest from datetime import timedelta from unittest.mock import MagicMock @@ -52,16 +53,20 @@ class TestTimeoutOverrides(unittest.TestCase): def test_looking_up_a_duplicate_override_should_raise_error(self): timeout_overrides = under_test.TimeoutOverrides( overrides={ - "bv": [{ - "task": "task_name", - "exec_timeout": 42, - "idle_timeout": 10, - }, { - "task": "task_name", - "exec_timeout": 314, - "idle_timeout": 20, - }] - }) + "bv": [ + { + "task": "task_name", + "exec_timeout": 42, + "idle_timeout": 10, + }, + { + "task": "task_name", + "exec_timeout": 314, + "idle_timeout": 20, + }, + ] + } + ) with self.assertRaises(ValueError): self.assertIsNone(timeout_overrides.lookup_exec_override("bv", "task_name")) @@ -83,10 +88,13 @@ class TestTimeoutOverrides(unittest.TestCase): "exec_timeout": 42, }, ] - }) + } + ) - self.assertEqual(42 * 60, - timeout_overrides.lookup_exec_override("bv", "task_name").total_seconds()) + self.assertEqual( + 42 * 60, + timeout_overrides.lookup_exec_override("bv", "task_name").total_seconds(), + ) def test_looking_up_an_idle_override_should_work(self): timeout_overrides = under_test.TimeoutOverrides( @@ -102,15 +110,27 @@ class TestTimeoutOverrides(unittest.TestCase): "idle_timeout": 10, }, ] - }) + } + ) - self.assertEqual(10 * 60, - timeout_overrides.lookup_idle_override("bv", "task_name").total_seconds()) + self.assertEqual( + 10 * 60, + timeout_overrides.lookup_idle_override("bv", "task_name").total_seconds(), + ) class TestDetermineExecTimeout(unittest.TestCase): - def _validate_exec_timeout(self, idle_timeout, exec_timeout, historic_timeout, evg_alias, - build_variant, display_name, timeout_override, expected_timeout): + def _validate_exec_timeout( + self, + idle_timeout, + exec_timeout, + historic_timeout, + evg_alias, + build_variant, + display_name, + timeout_override, + expected_timeout, + ): task_name = "task_name" variant = build_variant overrides = {} @@ -121,114 +141,206 @@ class TestDetermineExecTimeout(unittest.TestCase): orchestrator = under_test.TaskTimeoutOrchestrator( timeout_service=MagicMock(spec_set=TimeoutService), - timeout_overrides=mock_timeout_overrides, evg_project_config=MagicMock( + timeout_overrides=mock_timeout_overrides, + evg_project_config=MagicMock( spec_set=EvergreenProjectConfig, - get_variant=MagicMock(return_value=MagicMock(display_name=display_name)))) + get_variant=MagicMock( + return_value=MagicMock(display_name=display_name) + ), + ), + ) actual_timeout = orchestrator.determine_exec_timeout( - task_name, variant, idle_timeout, exec_timeout, evg_alias, historic_timeout) + task_name, variant, idle_timeout, exec_timeout, evg_alias, historic_timeout + ) self.assertEqual(actual_timeout, expected_timeout) def test_timeout_used_if_specified(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=timedelta(seconds=42), - historic_timeout=None, evg_alias=None, build_variant="variant", - display_name="not required", timeout_override=None, - expected_timeout=timedelta(seconds=42)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=timedelta(seconds=42), + historic_timeout=None, + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=timedelta(seconds=42), + ) def test_default_is_returned_with_no_timeout(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, historic_timeout=None, - evg_alias=None, build_variant="variant", - display_name="not required", timeout_override=None, - expected_timeout=under_test.DEFAULT_NON_REQUIRED_BUILD_TIMEOUT) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=None, + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=under_test.DEFAULT_NON_REQUIRED_BUILD_TIMEOUT, + ) def test_default_is_returned_with_timeout_at_zero(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=timedelta(seconds=0), - historic_timeout=None, evg_alias=None, build_variant="variant", - display_name="not required", timeout_override=None, - expected_timeout=under_test.DEFAULT_NON_REQUIRED_BUILD_TIMEOUT) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=timedelta(seconds=0), + historic_timeout=None, + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=under_test.DEFAULT_NON_REQUIRED_BUILD_TIMEOUT, + ) def test_default_required_returned_on_required_variants(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, historic_timeout=None, - evg_alias=None, build_variant="variant-required", - display_name="! required", timeout_override=None, - expected_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=None, + evg_alias=None, + build_variant="variant-required", + display_name="! required", + timeout_override=None, + expected_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT, + ) def test_override_on_required_should_use_override(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, historic_timeout=None, - evg_alias=None, build_variant="variant-required", - display_name="! required", timeout_override=3 * 60, - expected_timeout=timedelta(minutes=3 * 60)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=None, + evg_alias=None, + build_variant="variant-required", + display_name="! required", + timeout_override=3 * 60, + expected_timeout=timedelta(minutes=3 * 60), + ) def test_task_specific_timeout(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=timedelta(seconds=0), - historic_timeout=None, evg_alias=None, build_variant="variant", - display_name="not required", timeout_override=60, - expected_timeout=timedelta(minutes=60)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=timedelta(seconds=0), + historic_timeout=None, + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=60, + expected_timeout=timedelta(minutes=60), + ) def test_commit_queue_items_use_commit_queue_timeout(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, historic_timeout=None, - evg_alias=under_test.COMMIT_QUEUE_ALIAS, - build_variant="variant", display_name="not required", - timeout_override=None, - expected_timeout=under_test.COMMIT_QUEUE_TIMEOUT) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=None, + evg_alias=under_test.COMMIT_QUEUE_ALIAS, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=under_test.COMMIT_QUEUE_TIMEOUT, + ) def test_use_idle_timeout_if_greater_than_exec_timeout(self): self._validate_exec_timeout( - idle_timeout=timedelta(hours=2), exec_timeout=timedelta(minutes=10), - historic_timeout=None, evg_alias=None, build_variant="variant", - display_name="not required", timeout_override=None, expected_timeout=timedelta(hours=2)) + idle_timeout=timedelta(hours=2), + exec_timeout=timedelta(minutes=10), + historic_timeout=None, + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=timedelta(hours=2), + ) def test_historic_timeout_should_be_used_if_given(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, - historic_timeout=timedelta(minutes=15), evg_alias=None, - build_variant="variant", display_name="not required", - timeout_override=None, expected_timeout=timedelta(minutes=15)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=timedelta(minutes=15), + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=timedelta(minutes=15), + ) def test_commit_queue_should_override_historic_timeouts(self): self._validate_exec_timeout( - idle_timeout=None, exec_timeout=None, historic_timeout=timedelta(minutes=15), - evg_alias=under_test.COMMIT_QUEUE_ALIAS, build_variant="variant", - display_name="not required", timeout_override=None, - expected_timeout=under_test.COMMIT_QUEUE_TIMEOUT) + idle_timeout=None, + exec_timeout=None, + historic_timeout=timedelta(minutes=15), + evg_alias=under_test.COMMIT_QUEUE_ALIAS, + build_variant="variant", + display_name="not required", + timeout_override=None, + expected_timeout=under_test.COMMIT_QUEUE_TIMEOUT, + ) def test_override_should_override_historic_timeouts(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, - historic_timeout=timedelta(minutes=15), evg_alias=None, - build_variant="variant", display_name="not required", - timeout_override=33, expected_timeout=timedelta(minutes=33)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=timedelta(minutes=15), + evg_alias=None, + build_variant="variant", + display_name="not required", + timeout_override=33, + expected_timeout=timedelta(minutes=33), + ) def test_historic_timeout_should_not_be_overridden_by_required_bv(self): - self._validate_exec_timeout(idle_timeout=None, exec_timeout=None, - historic_timeout=timedelta(minutes=15), evg_alias=None, - build_variant="variant-required", display_name="! required", - timeout_override=None, expected_timeout=timedelta(minutes=15)) + self._validate_exec_timeout( + idle_timeout=None, + exec_timeout=None, + historic_timeout=timedelta(minutes=15), + evg_alias=None, + build_variant="variant-required", + display_name="! required", + timeout_override=None, + expected_timeout=timedelta(minutes=15), + ) def test_historic_timeout_should_not_be_increase_required_bv_timeout(self): self._validate_exec_timeout( - idle_timeout=None, exec_timeout=None, - historic_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT + timedelta(minutes=30), - evg_alias=None, build_variant="variant-required", display_name="! required", - timeout_override=None, expected_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT) + idle_timeout=None, + exec_timeout=None, + historic_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT + + timedelta(minutes=30), + evg_alias=None, + build_variant="variant-required", + display_name="! required", + timeout_override=None, + expected_timeout=under_test.DEFAULT_REQUIRED_BUILD_TIMEOUT, + ) class TestDetermineIdleTimeout(unittest.TestCase): - def _validate_idle_timeout(self, idle_timeout, historic_timeout, build_variant, - timeout_override, expected_timeout): + def _validate_idle_timeout( + self, + idle_timeout, + historic_timeout, + build_variant, + timeout_override, + expected_timeout, + ): task_name = "task_name" overrides = {} if timeout_override is not None: - overrides[build_variant] = [{"task": task_name, "idle_timeout": timeout_override}] + overrides[build_variant] = [ + {"task": task_name, "idle_timeout": timeout_override} + ] mock_timeout_overrides = under_test.TimeoutOverrides(overrides=overrides) orchestrator = under_test.TaskTimeoutOrchestrator( timeout_service=MagicMock(spec_set=TimeoutService), timeout_overrides=mock_timeout_overrides, - evg_project_config=MagicMock(spec_set=EvergreenProjectConfig)) + evg_project_config=MagicMock(spec_set=EvergreenProjectConfig), + ) - actual_timeout = orchestrator.determine_idle_timeout(task_name, build_variant, idle_timeout, - historic_timeout) + actual_timeout = orchestrator.determine_idle_timeout( + task_name, build_variant, idle_timeout, historic_timeout + ) self.assertEqual(actual_timeout, expected_timeout) @@ -260,11 +372,19 @@ class TestDetermineIdleTimeout(unittest.TestCase): ) def test_historic_timeout_should_be_used_if_given(self): - self._validate_idle_timeout(idle_timeout=None, historic_timeout=timedelta(minutes=15), - build_variant="variant", timeout_override=None, - expected_timeout=timedelta(minutes=15)) + self._validate_idle_timeout( + idle_timeout=None, + historic_timeout=timedelta(minutes=15), + build_variant="variant", + timeout_override=None, + expected_timeout=timedelta(minutes=15), + ) def test_override_should_override_historic_timeout(self): - self._validate_idle_timeout(idle_timeout=None, historic_timeout=timedelta(minutes=15), - build_variant="variant", timeout_override=30, - expected_timeout=timedelta(minutes=30)) + self._validate_idle_timeout( + idle_timeout=None, + historic_timeout=timedelta(minutes=15), + build_variant="variant", + timeout_override=30, + expected_timeout=timedelta(minutes=30), + ) diff --git a/buildscripts/tests/test_exception_exctractor.py b/buildscripts/tests/test_exception_exctractor.py index bb9639dd9c9..006502b76aa 100644 --- a/buildscripts/tests/test_exception_exctractor.py +++ b/buildscripts/tests/test_exception_exctractor.py @@ -62,10 +62,17 @@ class TestExceptionExtractor(unittest.TestCase): assert not exception_extractor.get_exception() def test_successful_extraction_truncate_first(self): - logs = ["START"] + ["not captured" - ] + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + ["END"] - expected_exception = ["[LAST Part of Exception]" - ] + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + ["END"] + logs = ( + ["START"] + + ["not captured"] + + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + + ["END"] + ) + expected_exception = ( + ["[LAST Part of Exception]"] + + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + + ["END"] + ) exception_extractor = self.get_exception_extractor() for log in logs: exception_extractor.process_log_line(log) @@ -74,10 +81,17 @@ class TestExceptionExtractor(unittest.TestCase): assert exception_extractor.get_exception() == expected_exception def test_successful_extraction_truncate_last(self): - logs = ["START"] + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + ["not captured" - ] + ["END"] - expected_exception = ["[FIRST Part of Exception]" - ] + ["START"] + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + logs = ( + ["START"] + + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + + ["not captured"] + + ["END"] + ) + expected_exception = ( + ["[FIRST Part of Exception]"] + + ["START"] + + ["captured"] * (under_test.MAX_EXCEPTION_LENGTH - 1) + ) exception_extractor = self.get_exception_extractor(under_test.Truncate.LAST) for log in logs: exception_extractor.process_log_line(log) diff --git a/buildscripts/tests/test_feature_flag_tags_check.py b/buildscripts/tests/test_feature_flag_tags_check.py index dc3b296daf4..ec49ef85f62 100644 --- a/buildscripts/tests/test_feature_flag_tags_check.py +++ b/buildscripts/tests/test_feature_flag_tags_check.py @@ -19,14 +19,19 @@ class TestFindTestsInGitDiff(unittest.TestCase): def test_get_tests_missing_fcv_tag_no_tag(self): tests = ["dummy_jstest_file.js"] - with patch.object(feature_flag_tags_check.jscomment, "get_tags", return_value=[]): + with patch.object( + feature_flag_tags_check.jscomment, "get_tags", return_value=[] + ): result = feature_flag_tags_check.get_tests_missing_fcv_tag(tests) self.assertCountEqual(tests, result) def test_get_tests_missing_fcv_tag_have_tag(self): tests = ["dummy_jstest_file.js"] - with patch.object(feature_flag_tags_check.jscomment, "get_tags", - return_value=[self.requires_fcv_tag]): + with patch.object( + feature_flag_tags_check.jscomment, + "get_tags", + return_value=[self.requires_fcv_tag], + ): result = feature_flag_tags_check.get_tests_missing_fcv_tag(tests) self.assertCountEqual([], result) diff --git a/buildscripts/tests/test_generate_sbom.py b/buildscripts/tests/test_generate_sbom.py index d6d17e4f687..d116aeaf457 100644 --- a/buildscripts/tests/test_generate_sbom.py +++ b/buildscripts/tests/test_generate_sbom.py @@ -31,7 +31,9 @@ class TestEndorctl(unittest.TestCase): logger = logging.getLogger("generate_sbom") logger.setLevel(logging.INFO) - e = EndorCtl(namespace="mongodb.10gen", endorctl_path="this_path_does_not_exist") + e = EndorCtl( + namespace="mongodb.10gen", endorctl_path="this_path_does_not_exist" + ) result = e.get_sbom_for_project("https://github.com/10gen/mongo.git") self.assertRaises(FileNotFoundError) self.assertIsNone(result, None) @@ -177,7 +179,9 @@ class TestConfigRegex(unittest.TestCase): print("\nTesting Invalid PURLs:") for purl in invalid_purls: with self.subTest(purl=purl): - self.assertFalse(is_valid_purl(purl), f"Expected '{purl}' to be invalid") + self.assertFalse( + is_valid_purl(purl), f"Expected '{purl}' to be invalid" + ) __unittest = True diff --git a/buildscripts/tests/test_jepsen_report.py b/buildscripts/tests/test_jepsen_report.py index 00ef273dfc7..d328846a10a 100644 --- a/buildscripts/tests/test_jepsen_report.py +++ b/buildscripts/tests/test_jepsen_report.py @@ -1,4 +1,5 @@ """Tests for jepsen report generator.""" + import os import random import textwrap @@ -116,34 +117,36 @@ Everything looks good! ヽ(‘ー`)ノ """) return { - 'expected': - ParserOutput({ - 'success': successful_tests, - 'unknown': indeterminate_tests, - 'crashed': crashed_tests, - 'failed': failed_tests, - }), 'corpus': - corpus + "expected": ParserOutput( + { + "success": successful_tests, + "unknown": indeterminate_tests, + "crashed": crashed_tests, + "failed": failed_tests, + } + ), + "corpus": corpus, } def test_parser(self): """Test with embedded corpus.""" out = parse(_CORPUS) - self.assertEqual(len(out['success']), 28) - self.assertEqual(len(out['unknown']), 0) - self.assertEqual(len(out['crashed']), 2) - self.assertEqual(len(out['failed']), 0) + self.assertEqual(len(out["success"]), 28) + self.assertEqual(len(out["unknown"]), 0) + self.assertEqual(len(out["crashed"]), 2) + self.assertEqual(len(out["failed"]), 0) def test_parser2(self): """Test with jepsen.log file.""" - with open(os.path.join(os.path.dirname(__file__), - "test_jepsen_report_corpus.log.txt")) as fh: + with open( + os.path.join(os.path.dirname(__file__), "test_jepsen_report_corpus.log.txt") + ) as fh: corpus = fh.read().splitlines() out = parse(corpus) - self.assertEqual(len(out['success']), 29) - self.assertEqual(len(out['unknown']), 0) - self.assertEqual(len(out['crashed']), 1) - self.assertEqual(len(out['failed']), 0) + self.assertEqual(len(out["success"]), 29) + self.assertEqual(len(out["unknown"]), 0) + self.assertEqual(len(out["crashed"]), 1) + self.assertEqual(len(out["failed"]), 0) def test_generated_corpus(self): """Generate 100 corpuses and test them.""" @@ -152,25 +155,27 @@ Everything looks good! ヽ(‘ー`)ノ def _test_generated_corpus(self): gen = self._corpus_generator() - corpus = gen['corpus'].splitlines() + corpus = gen["corpus"].splitlines() out = parse(corpus) - self.assertDictEqual(out, gen['expected']) + self.assertDictEqual(out, gen["expected"]) - @patch('buildscripts.jepsen_report._try_find_log_file') - @patch('buildscripts.jepsen_report._get_log_lines') - @patch('buildscripts.jepsen_report._put_report') + @patch("buildscripts.jepsen_report._try_find_log_file") + @patch("buildscripts.jepsen_report._get_log_lines") + @patch("buildscripts.jepsen_report._put_report") def test_main(self, mock_put_report, mock_get_log_lines, mock_try_find_log_file): """Test main function.""" gen = self._corpus_generator() - corpus = gen['corpus'].splitlines() + corpus = gen["corpus"].splitlines() mock_get_log_lines.return_value = corpus def _try_find_log_file(_store, _test): if _try_find_log_file.counter == 0: _try_find_log_file.counter += 1 with open( - os.path.join( - os.path.dirname(__file__), "test_jepsen_report_corpus.log.txt")) as fh: + os.path.join( + os.path.dirname(__file__), "test_jepsen_report_corpus.log.txt" + ) + ) as fh: return fh.read() return "" @@ -178,19 +183,24 @@ Everything looks good! ヽ(‘ー`)ノ mock_try_find_log_file.side_effect = _try_find_log_file runner = CliRunner() - result = runner.invoke(main, - ["--start_time=0", "--end_time=10", "--elapsed=10", "test.log"]) - num_tests = len(gen['expected']['success']) + len(gen['expected']['unknown']) + len( - gen['expected']['crashed']) + len(gen['expected']['failed']) - num_fails = num_tests - len(gen['expected']['success']) + result = runner.invoke( + main, ["--start_time=0", "--end_time=10", "--elapsed=10", "test.log"] + ) + num_tests = ( + len(gen["expected"]["success"]) + + len(gen["expected"]["unknown"]) + + len(gen["expected"]["crashed"]) + + len(gen["expected"]["failed"]) + ) + num_fails = num_tests - len(gen["expected"]["success"]) callee_dict = mock_put_report.call_args[0][0] - self.assertEqual(callee_dict['failures'], num_fails) - self.assertEqual(len(callee_dict['results']), num_tests) - mock_get_log_lines.assert_called_once_with('test.log') - if gen['expected']['crashed']: + self.assertEqual(callee_dict["failures"], num_fails) + self.assertEqual(len(callee_dict["results"]), num_tests) + mock_get_log_lines.assert_called_once_with("test.log") + if gen["expected"]["crashed"]: self.assertEqual(result.exit_code, 2) - elif gen['expected']['unknown'] or gen['expected']['failure']: + elif gen["expected"]["unknown"] or gen["expected"]["failure"]: self.assertEqual(result.exit_code, 1) else: self.assertEqual(result.exit_code, 0) diff --git a/buildscripts/tests/test_mongosymb.py b/buildscripts/tests/test_mongosymb.py index 5e06313cbd9..c8335737928 100644 --- a/buildscripts/tests/test_mongosymb.py +++ b/buildscripts/tests/test_mongosymb.py @@ -1,4 +1,5 @@ """Unit tests for buildscripts/mongosymb.py.""" + import unittest from buildscripts import mongosymb as under_test @@ -12,7 +13,9 @@ class TestGetVersion(unittest.TestCase): } } version = under_test.get_version(trace_doc) - self.assertEqual(version, "6.0.0-alpha0-37-ge1d28c1-patch-6257e60a32f417196bc25169") + self.assertEqual( + version, "6.0.0-alpha0-37-ge1d28c1-patch-6257e60a32f417196bc25169" + ) def test_get_version_without_patch(self): trace_doc = {"processInfo": {"mongodbVersion": "6.1.0-alpha-504-g0c8a142"}} diff --git a/buildscripts/tests/test_packager.py b/buildscripts/tests/test_packager.py index 975a85dee3d..fb9f27a5103 100644 --- a/buildscripts/tests/test_packager.py +++ b/buildscripts/tests/test_packager.py @@ -1,4 +1,5 @@ """Unit tests for the packager script.""" + from dataclasses import dataclass from unittest import TestCase diff --git a/buildscripts/tests/test_powercycle_sentinel.py b/buildscripts/tests/test_powercycle_sentinel.py index 0ef0588ebc9..8e27ad4c2e0 100644 --- a/buildscripts/tests/test_powercycle_sentinel.py +++ b/buildscripts/tests/test_powercycle_sentinel.py @@ -1,4 +1,5 @@ """Unit tests for powercycle_sentinel.py.""" + import unittest from datetime import datetime, timedelta, timezone from unittest.mock import Mock @@ -11,11 +12,14 @@ from evergreen import EvergreenApi, Task def make_task_mock(evg_api, task_id, start_time, finish_time): - return Task({ - "task_id": task_id, - "start_time": start_time, - "finish_time": finish_time, - }, evg_api) + return Task( + { + "task_id": task_id, + "start_time": start_time, + "finish_time": finish_time, + }, + evg_api, + ) class TestWatchTasks(unittest.TestCase): @@ -28,25 +32,33 @@ class TestWatchTasks(unittest.TestCase): task_1 = make_task_mock(evg_api, task_ids[0], now, now) task_2 = make_task_mock(evg_api, task_ids[1], now, now) evg_api.task_by_id = Mock( - side_effect=(lambda task_id: { - "1": task_1, - "2": task_2, - }[task_id])) + side_effect=( + lambda task_id: { + "1": task_1, + "2": task_2, + }[task_id] + ) + ) long_running_task_ids = watch_tasks(task_ids, evg_api, 0) self.assertEqual([], long_running_task_ids) def test_found_long_running_tasks(self): evg_api = EvergreenApi() task_ids = ["1", "2"] - exec_timeout_seconds_ago = (datetime.now(timezone.utc) - - timedelta(hours=POWERCYCLE_TASK_EXEC_TIMEOUT_SECS)).isoformat() + exec_timeout_seconds_ago = ( + datetime.now(timezone.utc) + - timedelta(hours=POWERCYCLE_TASK_EXEC_TIMEOUT_SECS) + ).isoformat() now = datetime.now(timezone.utc).isoformat() task_1 = make_task_mock(evg_api, task_ids[0], exec_timeout_seconds_ago, now) task_2 = make_task_mock(evg_api, task_ids[1], exec_timeout_seconds_ago, None) evg_api.task_by_id = Mock( - side_effect=(lambda task_id: { - "1": task_1, - "2": task_2, - }[task_id])) + side_effect=( + lambda task_id: { + "1": task_1, + "2": task_2, + }[task_id] + ) + ) long_running_task_ids = watch_tasks(task_ids, evg_api, 0) self.assertEqual([task_2.task_id], long_running_task_ids) diff --git a/buildscripts/tests/test_simple_report.py b/buildscripts/tests/test_simple_report.py index 8c3c323a351..b258f91fe10 100644 --- a/buildscripts/tests/test_simple_report.py +++ b/buildscripts/tests/test_simple_report.py @@ -36,14 +36,23 @@ class TestSimpleReport(unittest.TestCase): @patch(ns("try_combine_reports")) @patch(ns("_clean_log_file")) @patch(ns("put_report")) - def _test_trivial_report(self, mock_put_report, mock_clean_log_file, _mock_try_combine_reports): + def _test_trivial_report( + self, mock_put_report, mock_clean_log_file, _mock_try_combine_reports + ): exit_code = self.rng.randint(0, 254) print(f"Trying exit code: {exit_code}") mock_clean_log_file.return_value = "I'm a little test log, short and stdout." runner = CliRunner() result = runner.invoke( buildscripts.simple_report.main, - ["--test-name", "potato", "--log-file", "test.log", "--exit-code", str(exit_code)], + [ + "--test-name", + "potato", + "--log-file", + "test.log", + "--exit-code", + str(exit_code), + ], ) report = mock_put_report.call_args[0][0] results = mock_put_report.call_args[0][0]["results"] diff --git a/buildscripts/tests/test_sync_repo_with_copybara.py b/buildscripts/tests/test_sync_repo_with_copybara.py index 030c418787b..4d28283d406 100644 --- a/buildscripts/tests/test_sync_repo_with_copybara.py +++ b/buildscripts/tests/test_sync_repo_with_copybara.py @@ -8,7 +8,7 @@ from buildscripts import sync_repo_with_copybara @unittest.skipIf( - sys.platform in ('win32', 'darwin'), + sys.platform in ("win32", "darwin"), reason="No need to run this unittest on windows or macos", ) class TestBranchFunctions(unittest.TestCase): @@ -23,17 +23,19 @@ class TestBranchFunctions(unittest.TestCase): os.makedirs(mongodb_mongo_dir, exist_ok=True) # Create .git directory - git_dir = os.path.join(mongodb_mongo_dir, '.git') + git_dir = os.path.join(mongodb_mongo_dir, ".git") os.makedirs(git_dir, exist_ok=True) # Write contents to .git/config - config_path = os.path.join(git_dir, 'config') - with open(config_path, 'w') as f: + config_path = os.path.join(git_dir, "config") + with open(config_path, "w") as f: # Write contents to .git/config f.write(config_content) @staticmethod - def create_mock_repo_commits(repo_directory, num_commits, private_commit_hashes=None): + def create_mock_repo_commits( + repo_directory, num_commits, private_commit_hashes=None + ): """ Create mock commits in a Git repository. @@ -43,16 +45,18 @@ class TestBranchFunctions(unittest.TestCase): :return: A list of commit hashes generated for the new commits. """ os.chdir(repo_directory) - sync_repo_with_copybara.run_command('git init') - sync_repo_with_copybara.run_command('git config --local user.email "test@example.com"') + sync_repo_with_copybara.run_command("git init") + sync_repo_with_copybara.run_command( + 'git config --local user.email "test@example.com"' + ) sync_repo_with_copybara.run_command('git config --local user.name "Test User"') # Used to store commit hashes commit_hashes = [] for i in range(num_commits): - with open("test.txt", 'a') as f: + with open("test.txt", "a") as f: f.write(str(i)) - sync_repo_with_copybara.run_command('git add test.txt') + sync_repo_with_copybara.run_command("git add test.txt") commit_message = f"test commit {i}" # If there are private commit hashes need to be added in public repo commits, include them in the commit message if private_commit_hashes: @@ -61,7 +65,8 @@ class TestBranchFunctions(unittest.TestCase): # Get the current commit hash sync_repo_with_copybara.run_command(f'git commit -m "{commit_message}"') commit_hashes.append( - sync_repo_with_copybara.run_command('git log --pretty=format:\"%H\" -1')) + sync_repo_with_copybara.run_command('git log --pretty=format:"%H" -1') + ) return commit_hashes @staticmethod @@ -85,19 +90,23 @@ class TestBranchFunctions(unittest.TestCase): os.chdir(mock_10gen_dir) # Create a mock private repository and get all commit hashes private_hashes = TestBranchFunctions.create_mock_repo_commits( - mock_10gen_dir, num_commits) + mock_10gen_dir, num_commits + ) # Create a mock public repository and pass the list of private commit hashes if matched_public_commits != 0: public_hashes = TestBranchFunctions.create_mock_repo_commits( - mock_mongodb_dir, matched_public_commits, private_hashes) + mock_mongodb_dir, matched_public_commits, private_hashes + ) else: public_hashes = TestBranchFunctions.create_mock_repo_commits( - mock_mongodb_dir, num_commits) + mock_mongodb_dir, num_commits + ) os.chdir(tmpdir) - result = sync_repo_with_copybara.find_matching_commit(mock_10gen_dir, - mock_mongodb_dir) + result = sync_repo_with_copybara.find_matching_commit( + mock_10gen_dir, mock_mongodb_dir + ) # Check if the commit in the search result matches the last commit in the public repository if result == public_hashes[-1]: @@ -105,7 +114,9 @@ class TestBranchFunctions(unittest.TestCase): else: assert result is None except Exception as err: - print(f"{test_name}: FAIL!\n Exception occurred: {err}\n {traceback.format_exc()}") + print( + f"{test_name}: FAIL!\n Exception occurred: {err}\n {traceback.format_exc()}" + ) return False def test_no_search(self): @@ -136,7 +147,9 @@ class TestBranchFunctions(unittest.TestCase): branch="v7.3", ), ) - result = sync_repo_with_copybara.check_destination_branch_exists(copybara_config) + result = sync_repo_with_copybara.check_destination_branch_exists( + copybara_config + ) self.assertTrue(result, f"{test_name}: SUCCESS!") def test_branch_not_exists(self): @@ -149,7 +162,9 @@ class TestBranchFunctions(unittest.TestCase): branch="..invalid-therefore-impossible-to-create-branch-name", ), ) - result = sync_repo_with_copybara.check_destination_branch_exists(copybara_config) + result = sync_repo_with_copybara.check_destination_branch_exists( + copybara_config + ) self.assertFalse(result, f"{test_name}: SUCCESS!") def test_only_mongodb_mongo_repo(self): @@ -168,9 +183,13 @@ class TestBranchFunctions(unittest.TestCase): try: # Check if the repository is only the MongoDB official repository - result = sync_repo_with_copybara.has_only_destination_repo_remote("mongodb/mongo") + result = sync_repo_with_copybara.has_only_destination_repo_remote( + "mongodb/mongo" + ) except Exception as err: - print(f"{test_name}: FAIL!\n Exception occurred: {err}\n {traceback.format_exc()}") + print( + f"{test_name}: FAIL!\n Exception occurred: {err}\n {traceback.format_exc()}" + ) self.fail(f"{test_name}: FAIL!") return @@ -210,8 +229,10 @@ class TestBranchFunctions(unittest.TestCase): branching_off_commit="", ) except Exception as err: - if (str(err) == - f"{mongodb_mongo_dir} git repo has not only the destination repo remote"): + if ( + str(err) + == f"{mongodb_mongo_dir} git repo has not only the destination repo remote" + ): return self.fail(f"{test_name}: FAIL!") @@ -255,9 +276,10 @@ class TestBranchFunctions(unittest.TestCase): invalid_branching_off_commit, ) except Exception as err: - if str( - err - ) == "The new branch top commit does not match the branching_off_commit. Aborting push.": + if ( + str(err) + == "The new branch top commit does not match the branching_off_commit. Aborting push." + ): return self.fail(f"{test_name}: FAIL!") diff --git a/buildscripts/tests/test_todo_check.py b/buildscripts/tests/test_todo_check.py index ad7c83ebb12..6ad138460ce 100644 --- a/buildscripts/tests/test_todo_check.py +++ b/buildscripts/tests/test_todo_check.py @@ -252,7 +252,9 @@ class TestWalkFs(unittest.TestCase): write_file(os.path.join(tmpdir, "file1.txt"), expected_files["file1.txt"]) os.makedirs(os.path.join(tmpdir, "dir0", "dir1")) write_file( - os.path.join(tmpdir, "dir0", "dir1", "file2.txt"), expected_files["file2.txt"]) + os.path.join(tmpdir, "dir0", "dir1", "file2.txt"), + expected_files["file2.txt"], + ) seen_files = {} diff --git a/buildscripts/tests/test_validate_commit_message.py b/buildscripts/tests/test_validate_commit_message.py index fc6fce554a6..93c14abff82 100644 --- a/buildscripts/tests/test_validate_commit_message.py +++ b/buildscripts/tests/test_validate_commit_message.py @@ -1,4 +1,5 @@ """Unit tests for the evergreen_task_timeout script.""" + import unittest from buildscripts.validate_commit_message import STATUS_ERROR, STATUS_OK, main @@ -8,7 +9,7 @@ class ValidateCommitMessageTest(unittest.TestCase): def test_valid(self): messages = [ "SERVER-44338", - "Revert \"SERVER-60", + 'Revert "SERVER-60', "Import wiredtiger: 58115abb6fbb3c1cc7bfd087d41a47347bce9a69 from branch mongodb-4.4", 'Revert "Import wiredtiger: 58115abb6fbb3c1cc7bfd087d41a47347bce9a69 from branch mongodb-4.4"', ] diff --git a/buildscripts/tests/test_validate_mongocryptd.py b/buildscripts/tests/test_validate_mongocryptd.py index 3d82a25c890..bb6f6a3ee11 100644 --- a/buildscripts/tests/test_validate_mongocryptd.py +++ b/buildscripts/tests/test_validate_mongocryptd.py @@ -20,17 +20,27 @@ class TestCanValidationBeSkipped(unittest.TestCase): def test_non_existing_variant_can_be_skipped(self): mock_evg_config = MagicMock() mock_evg_config.get_variant.return_value = None - self.assertTrue(under_test.can_validation_be_skipped(mock_evg_config, "variant")) + self.assertTrue( + under_test.can_validation_be_skipped(mock_evg_config, "variant") + ) def test_variant_with_no_push_task_can_be_skipped(self): mock_evg_config = MagicMock() mock_evg_config.get_variant.return_value.task_names = ["task 1", "task 2"] - self.assertTrue(under_test.can_validation_be_skipped(mock_evg_config, "variant")) + self.assertTrue( + under_test.can_validation_be_skipped(mock_evg_config, "variant") + ) def test_variant_with_push_task_cannot_be_skipped(self): mock_evg_config = MagicMock() - mock_evg_config.get_variant.return_value.task_names = ["task 1", "push", "task 2"] - self.assertFalse(under_test.can_validation_be_skipped(mock_evg_config, "variant")) + mock_evg_config.get_variant.return_value.task_names = [ + "task 1", + "push", + "task 2", + ] + self.assertFalse( + under_test.can_validation_be_skipped(mock_evg_config, "variant") + ) class TestReadVariableFromYml(unittest.TestCase): @@ -58,4 +68,6 @@ class TestReadVariableFromYml(unittest.TestCase): } yaml_mock.safe_load.return_value = mock_nodes - self.assertEqual(expected_value, under_test.read_variable_from_yml("filename", search_key)) + self.assertEqual( + expected_value, under_test.read_variable_from_yml("filename", search_key) + ) diff --git a/buildscripts/tests/timeouts/test_timeout.py b/buildscripts/tests/timeouts/test_timeout.py index 45332543084..9c41b1ca2a0 100644 --- a/buildscripts/tests/timeouts/test_timeout.py +++ b/buildscripts/tests/timeouts/test_timeout.py @@ -1,4 +1,5 @@ """Unit tests for timeout.py.""" + import unittest from buildscripts.timeouts import timeout as under_test @@ -8,7 +9,9 @@ from buildscripts.timeouts import timeout as under_test class CalculateTimeoutTest(unittest.TestCase): def test_min_timeout(self): - self.assertEqual(under_test.MIN_TIMEOUT_SECONDS, under_test.calculate_timeout(15, 1)) + self.assertEqual( + under_test.MIN_TIMEOUT_SECONDS, under_test.calculate_timeout(15, 1) + ) def test_over_timeout_by_one_minute(self): self.assertEqual(360, under_test.calculate_timeout(301, 1)) @@ -19,38 +22,49 @@ class CalculateTimeoutTest(unittest.TestCase): def test_scaling_factor(self): avg_runtime = 30 scaling_factor = 10 - self.assertEqual(avg_runtime * scaling_factor + 60, - under_test.calculate_timeout(avg_runtime, scaling_factor)) + self.assertEqual( + avg_runtime * scaling_factor + 60, + under_test.calculate_timeout(avg_runtime, scaling_factor), + ) class TimeoutEstimateTest(unittest.TestCase): def test_too_high_a_timeout_raises_errors(self): timeout_est = under_test.TimeoutEstimate( - max_test_runtime=5, expected_task_runtime=under_test.MAX_EXPECTED_TIMEOUT) + max_test_runtime=5, expected_task_runtime=under_test.MAX_EXPECTED_TIMEOUT + ) with self.assertRaises(ValueError): timeout_est.generate_timeout_cmd(is_patch=True, repeat_factor=1) def test_is_specified_should_return_true_when_a_test_runtime_is_specified(self): - timeout_est = under_test.TimeoutEstimate(max_test_runtime=3.14, expected_task_runtime=None) + timeout_est = under_test.TimeoutEstimate( + max_test_runtime=3.14, expected_task_runtime=None + ) self.assertTrue(timeout_est.is_specified()) def test_is_specified_should_return_true_when_a_task_runtime_is_specified(self): - timeout_est = under_test.TimeoutEstimate(max_test_runtime=None, expected_task_runtime=3.14) + timeout_est = under_test.TimeoutEstimate( + max_test_runtime=None, expected_task_runtime=3.14 + ) self.assertTrue(timeout_est.is_specified()) def test_is_specified_should_return_false_when_no_data_is_specified(self): - timeout_est = under_test.TimeoutEstimate(max_test_runtime=None, expected_task_runtime=None) + timeout_est = under_test.TimeoutEstimate( + max_test_runtime=None, expected_task_runtime=None + ) self.assertFalse(timeout_est.is_specified()) class TestGenerateTimeoutCmd(unittest.TestCase): def test_evg_config_does_not_fails_if_test_timeout_too_high_on_mainline(self): - timeout = under_test.TimeoutEstimate(max_test_runtime=under_test.MAX_EXPECTED_TIMEOUT + 1, - expected_task_runtime=None) + timeout = under_test.TimeoutEstimate( + max_test_runtime=under_test.MAX_EXPECTED_TIMEOUT + 1, + expected_task_runtime=None, + ) time_cmd = timeout.generate_timeout_cmd(is_patch=False, repeat_factor=1) @@ -58,7 +72,9 @@ class TestGenerateTimeoutCmd(unittest.TestCase): def test_evg_config_does_not_fails_if_task_timeout_too_high_on_mainline(self): timeout = under_test.TimeoutEstimate( - expected_task_runtime=under_test.MAX_EXPECTED_TIMEOUT + 1, max_test_runtime=None) + expected_task_runtime=under_test.MAX_EXPECTED_TIMEOUT + 1, + max_test_runtime=None, + ) time_cmd = timeout.generate_timeout_cmd(is_patch=False, repeat_factor=1) @@ -94,7 +110,9 @@ class TestTimeoutInfo(unittest.TestCase): def test_both_timeouts_set(self): timeout = 3 exec_timeout = 5 - timeout_info = under_test.TimeoutInfo.overridden(exec_timeout=exec_timeout, timeout=timeout) + timeout_info = under_test.TimeoutInfo.overridden( + exec_timeout=exec_timeout, timeout=timeout + ) cmd = timeout_info.cmd.as_dict() diff --git a/buildscripts/tests/timeouts/test_timeout_service.py b/buildscripts/tests/timeouts/test_timeout_service.py index db737356df9..91906dcb93f 100644 --- a/buildscripts/tests/timeouts/test_timeout_service.py +++ b/buildscripts/tests/timeouts/test_timeout_service.py @@ -1,4 +1,5 @@ """Unit tests for timeout_service.py.""" + import random import unittest from unittest.mock import MagicMock, patch @@ -21,11 +22,16 @@ def ns(relative_name): # pylint: disable=invalid-name def build_mock_service(resmoke_proxy=None): return under_test.TimeoutService( - resmoke_proxy=resmoke_proxy if resmoke_proxy else MagicMock(spec_set=ResmokeProxyService)) + resmoke_proxy=resmoke_proxy + if resmoke_proxy + else MagicMock(spec_set=ResmokeProxyService) + ) def tst_stat_mock(file, duration, pass_count): - return MagicMock(test_name=file, avg_duration_pass=duration, num_pass=pass_count, hooks=[]) + return MagicMock( + test_name=file, avg_duration_pass=duration, num_pass=pass_count, hooks=[] + ) def tst_runtime_mock(file, duration, pass_count): @@ -51,9 +57,12 @@ class TestGetTimeoutEstimate(unittest.TestCase): @patch(ns("HistoricTaskData.from_s3")) def test_too_many_tests_missing_history_should_cause_a_default_timeout( - self, from_s3_mock: MagicMock): + self, from_s3_mock: MagicMock + ): test_stats = [ - HistoricTestInfo(test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[]) + HistoricTestInfo( + test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[] + ) for i in range(23) ] from_s3_mock.return_value = HistoricTaskData(test_stats) @@ -76,15 +85,22 @@ class TestGetTimeoutEstimate(unittest.TestCase): @patch(ns("HistoricTaskData.from_s3")) def test_too_many_tests_with_zero_runtime_history_should_cause_a_default_timeout( - self, from_s3_mock: MagicMock): + self, from_s3_mock: MagicMock + ): test_stats = [ - HistoricTestInfo(test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[]) + HistoricTestInfo( + test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[] + ) for i in range(23) ] - test_stats.extend([ - HistoricTestInfo(test_name=f"zero_{i}.js", avg_duration=0.0, num_pass=1, hooks=[]) - for i in range(7) - ]) + test_stats.extend( + [ + HistoricTestInfo( + test_name=f"zero_{i}.js", avg_duration=0.0, num_pass=1, hooks=[] + ) + for i in range(7) + ] + ) from_s3_mock.return_value = HistoricTaskData(test_stats) mock_resmoke_proxy = MagicMock(spec_set=ResmokeProxyService) mock_resmoke_proxy.list_tests.return_value = [ts.test_name for ts in test_stats] @@ -103,9 +119,12 @@ class TestGetTimeoutEstimate(unittest.TestCase): @patch(ns("HistoricTaskData.from_s3")) def test_enough_history_but_some_tests_missing_history_should_cause_custom_task_and_default_test_timeout( - self, from_s3_mock: MagicMock): + self, from_s3_mock: MagicMock + ): test_stats = [ - HistoricTestInfo(test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[]) + HistoricTestInfo( + test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[] + ) for i in range(25) ] from_s3_mock.return_value = HistoricTaskData(test_stats) @@ -130,15 +149,22 @@ class TestGetTimeoutEstimate(unittest.TestCase): @patch(ns("HistoricTaskData.from_s3")) def test_enough_history_but_some_tests_with_zero_runtime_should_cause_custom_task_and_default_test_timeout( - self, from_s3_mock: MagicMock): + self, from_s3_mock: MagicMock + ): test_stats = [ - HistoricTestInfo(test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[]) + HistoricTestInfo( + test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[] + ) for i in range(25) ] - test_stats.extend([ - HistoricTestInfo(test_name=f"zero_{i}.js", avg_duration=0.0, num_pass=1, hooks=[]) - for i in range(5) - ]) + test_stats.extend( + [ + HistoricTestInfo( + test_name=f"zero_{i}.js", avg_duration=0.0, num_pass=1, hooks=[] + ) + for i in range(5) + ] + ) from_s3_mock.return_value = HistoricTaskData(test_stats) mock_resmoke_proxy = MagicMock(spec_set=ResmokeProxyService) mock_resmoke_proxy.list_tests.return_value = [ts.test_name for ts in test_stats] @@ -158,10 +184,13 @@ class TestGetTimeoutEstimate(unittest.TestCase): self.assertEqual(54360, timeout.calculate_task_timeout(1)) @patch(ns("HistoricTaskData.from_s3")) - def test_all_tests_with_runtime_history_should_use_custom_timeout(self, - from_s3_mock: MagicMock): + def test_all_tests_with_runtime_history_should_use_custom_timeout( + self, from_s3_mock: MagicMock + ): test_stats = [ - HistoricTestInfo(test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[]) + HistoricTestInfo( + test_name=f"test_{i}.js", avg_duration=600.0, num_pass=1, hooks=[] + ) for i in range(30) ] from_s3_mock.return_value = HistoricTaskData(test_stats) @@ -187,18 +216,21 @@ class TestGetTaskHookOverhead(unittest.TestCase): def test_no_stats_should_return_zero(self): timeout_service = build_mock_service() - overhead = timeout_service.get_task_hook_overhead("suite", is_asan=False, test_count=30, - historic_stats=None) + overhead = timeout_service.get_task_hook_overhead( + "suite", is_asan=False, test_count=30, historic_stats=None + ) self.assertEqual(0.0, overhead) def test_stats_with_no_clean_every_n_should_return_zero(self): timeout_service = build_mock_service() test_stats = HistoricTaskData.from_stats_list( - [tst_stat_mock(f"test_{i}.js", 60, 1) for i in range(30)]) + [tst_stat_mock(f"test_{i}.js", 60, 1) for i in range(30)] + ) - overhead = timeout_service.get_task_hook_overhead("suite", is_asan=False, test_count=30, - historic_stats=test_stats) + overhead = timeout_service.get_task_hook_overhead( + "suite", is_asan=False, test_count=30, historic_stats=test_stats + ) self.assertEqual(0.0, overhead) @@ -206,16 +238,21 @@ class TestGetTaskHookOverhead(unittest.TestCase): test_count = 30 runtime = 25 timeout_service = build_mock_service() - test_stat_list = [tst_stat_mock(f"test_{i}.js", 60, 1) for i in range(test_count)] - test_stat_list.extend([ - tst_stat_mock(f"test_{i}:{under_test.CLEAN_EVERY_N_HOOK}", runtime, 1) - for i in range(10) - ]) + test_stat_list = [ + tst_stat_mock(f"test_{i}.js", 60, 1) for i in range(test_count) + ] + test_stat_list.extend( + [ + tst_stat_mock(f"test_{i}:{under_test.CLEAN_EVERY_N_HOOK}", runtime, 1) + for i in range(10) + ] + ) random.shuffle(test_stat_list) test_stats = HistoricTaskData.from_stats_list(test_stat_list) overhead = timeout_service.get_task_hook_overhead( - "suite", is_asan=True, test_count=test_count, historic_stats=test_stats) + "suite", is_asan=True, test_count=test_count, historic_stats=test_stats + ) self.assertEqual(runtime * test_count, overhead) @@ -254,7 +291,9 @@ class TestLookupHistoricStats(unittest.TestCase): self.assertIsNone(stats) @patch(ns("HistoricTaskData.from_s3")) - def test_stats_from_evergreen_should_return_the_stats(self, from_s3_mock: MagicMock): + def test_stats_from_evergreen_should_return_the_stats( + self, from_s3_mock: MagicMock + ): test_stats = [tst_stat_mock(f"test_{i}.js", 60, 1) for i in range(100)] from_s3_mock.return_value = HistoricTaskData(test_stats) timeout_service = build_mock_service() @@ -285,12 +324,15 @@ class TestGetCleanEveryNCadence(unittest.TestCase): mock_resmoke_proxy = MagicMock() mock_resmoke_proxy.read_suite_config.return_value = { "executor": { - "hooks": [{ - "class": "hook1", - }, { - "class": under_test.CLEAN_EVERY_N_HOOK, - "n": expected_n, - }] + "hooks": [ + { + "class": "hook1", + }, + { + "class": under_test.CLEAN_EVERY_N_HOOK, + "n": expected_n, + }, + ] } } timeout_service = build_mock_service(resmoke_proxy=mock_resmoke_proxy) @@ -303,11 +345,14 @@ class TestGetCleanEveryNCadence(unittest.TestCase): mock_resmoke_proxy = MagicMock() mock_resmoke_proxy.read_suite_config.return_value = { "executor": { - "hooks": [{ - "class": "hook1", - }, { - "class": under_test.CLEAN_EVERY_N_HOOK, - }] + "hooks": [ + { + "class": "hook1", + }, + { + "class": under_test.CLEAN_EVERY_N_HOOK, + }, + ] } } timeout_service = build_mock_service(resmoke_proxy=mock_resmoke_proxy) @@ -319,9 +364,13 @@ class TestGetCleanEveryNCadence(unittest.TestCase): def test_clean_every_n_cadence_no_hook_config(self): mock_resmoke_proxy = MagicMock() mock_resmoke_proxy.read_suite_config.return_value = { - "executor": {"hooks": [{ - "class": "hook1", - }, ]} + "executor": { + "hooks": [ + { + "class": "hook1", + }, + ] + } } timeout_service = build_mock_service(resmoke_proxy=mock_resmoke_proxy) @@ -334,28 +383,53 @@ class TestHaveEnoughHistoricStats(unittest.TestCase): def test_should_return_true_when_number_of_tests_equals_zero(self): timeout_service = build_mock_service() self.assertTrue( - timeout_service._have_enough_historic_stats(num_tests=0, num_tests_missing_data=0)) + timeout_service._have_enough_historic_stats( + num_tests=0, num_tests_missing_data=0 + ) + ) - def test_should_return_true_when_number_of_tests_with_historic_data_more_than_threshold(self): + def test_should_return_true_when_number_of_tests_with_historic_data_more_than_threshold( + self, + ): timeout_service = build_mock_service() self.assertTrue( - timeout_service._have_enough_historic_stats(num_tests=100, num_tests_missing_data=19)) + timeout_service._have_enough_historic_stats( + num_tests=100, num_tests_missing_data=19 + ) + ) self.assertTrue( - timeout_service._have_enough_historic_stats(num_tests=100, num_tests_missing_data=0)) + timeout_service._have_enough_historic_stats( + num_tests=100, num_tests_missing_data=0 + ) + ) def test_should_return_false_when_number_of_tests_with_historic_data_less_or_equal_to_threshold( - self): + self, + ): timeout_service = build_mock_service() self.assertFalse( - timeout_service._have_enough_historic_stats(num_tests=100, num_tests_missing_data=20)) + timeout_service._have_enough_historic_stats( + num_tests=100, num_tests_missing_data=20 + ) + ) self.assertFalse( - timeout_service._have_enough_historic_stats(num_tests=100, num_tests_missing_data=21)) + timeout_service._have_enough_historic_stats( + num_tests=100, num_tests_missing_data=21 + ) + ) self.assertFalse( - timeout_service._have_enough_historic_stats(num_tests=100, num_tests_missing_data=100)) + timeout_service._have_enough_historic_stats( + num_tests=100, num_tests_missing_data=100 + ) + ) def test_exception_raised_when_number_of_tests_less_than_zero(self): timeout_service = build_mock_service() with self.assertRaises(ValueError): - timeout_service._have_enough_historic_stats(num_tests=-1, num_tests_missing_data=0) + timeout_service._have_enough_historic_stats( + num_tests=-1, num_tests_missing_data=0 + ) with self.assertRaises(ValueError): - timeout_service._have_enough_historic_stats(num_tests=-100, num_tests_missing_data=0) + timeout_service._have_enough_historic_stats( + num_tests=-100, num_tests_missing_data=0 + ) diff --git a/buildscripts/tests/util/test_fileops.py b/buildscripts/tests/util/test_fileops.py index cfd7f9dd498..743ae8e2359 100644 --- a/buildscripts/tests/util/test_fileops.py +++ b/buildscripts/tests/util/test_fileops.py @@ -1,4 +1,5 @@ """Unit tests for fileops.py.""" + import unittest from unittest.mock import patch diff --git a/buildscripts/tests/util/test_read_config.py b/buildscripts/tests/util/test_read_config.py index a9c51764e41..6e345920826 100644 --- a/buildscripts/tests/util/test_read_config.py +++ b/buildscripts/tests/util/test_read_config.py @@ -19,17 +19,25 @@ class TestGetConfigValue(unittest.TestCase): self.assertEqual("default", value) def test_exception_throw_for_missing_required(self): - self.assertRaises(KeyError, read_config.get_config_value, "missing", {}, {}, required=True) + self.assertRaises( + KeyError, read_config.get_config_value, "missing", {}, {}, required=True + ) def test_config_file_value_is_used(self): - value = read_config.get_config_value("option", {}, {"option": "value 0"}, default="default", - required=True) + value = read_config.get_config_value( + "option", {}, {"option": "value 0"}, default="default", required=True + ) self.assertEqual("value 0", value) def test_cmdline_value_is_used(self): cmdline_mock = mock.Mock cmdline_mock.option = "cmdline value" - value = read_config.get_config_value("option", cmdline_mock, {"option": "value 0"}, - default="default", required=True) + value = read_config.get_config_value( + "option", + cmdline_mock, + {"option": "value 0"}, + default="default", + required=True, + ) self.assertEqual("cmdline value", value) diff --git a/buildscripts/tests/util/test_taskname.py b/buildscripts/tests/util/test_taskname.py index d5b6d264bf5..b5571d94789 100644 --- a/buildscripts/tests/util/test_taskname.py +++ b/buildscripts/tests/util/test_taskname.py @@ -9,21 +9,29 @@ import buildscripts.util.taskname as under_test class TestNameTask(unittest.TestCase): def test_name_task_with_width_one(self): - self.assertEqual("name_3_var", under_test.name_generated_task("name", 3, 10, "var")) + self.assertEqual( + "name_3_var", under_test.name_generated_task("name", 3, 10, "var") + ) def test_name_task_with_width_four(self): - self.assertEqual("task_3141_var", under_test.name_generated_task("task", 3141, 5000, "var")) + self.assertEqual( + "task_3141_var", under_test.name_generated_task("task", 3141, 5000, "var") + ) class TestRemoveGenSuffix(unittest.TestCase): def test_removes_gen_suffix(self): input_task_name = "sharding_auth_auditg_gen" - self.assertEqual("sharding_auth_auditg", under_test.remove_gen_suffix(input_task_name)) + self.assertEqual( + "sharding_auth_auditg", under_test.remove_gen_suffix(input_task_name) + ) def test_doesnt_remove_non_gen_suffix(self): input_task_name = "sharded_multi_stmt_txn_jscore_passthroug" - self.assertEqual("sharded_multi_stmt_txn_jscore_passthroug", - under_test.remove_gen_suffix(input_task_name)) + self.assertEqual( + "sharded_multi_stmt_txn_jscore_passthroug", + under_test.remove_gen_suffix(input_task_name), + ) class TestDetermineTaskBaseName(unittest.TestCase): diff --git a/buildscripts/tests/util/test_testname.py b/buildscripts/tests/util/test_testname.py index d4d7facb642..28e2e2242eb 100644 --- a/buildscripts/tests/util/test_testname.py +++ b/buildscripts/tests/util/test_testname.py @@ -46,12 +46,15 @@ class NormalizeTestFileTest(unittest.TestCase): def test_windows_file_is_normalized(self): windows_file = "test\\found\\under\\windows.exe" self.assertEqual( - testname_utils.normalize_test_file(windows_file), "test/found/under/windows") + testname_utils.normalize_test_file(windows_file), "test/found/under/windows" + ) def test_windows_file_with_non_exe_ext(self): windows_file = "test\\found\\under\\windows.sh" self.assertEqual( - testname_utils.normalize_test_file(windows_file), "test/found/under/windows.sh") + testname_utils.normalize_test_file(windows_file), + "test/found/under/windows.sh", + ) def test_unix_files_are_not_changed(self): unix_file = "test/found/under/unix" diff --git a/buildscripts/tests/util/test_teststats.py b/buildscripts/tests/util/test_teststats.py index 191e4b0b020..bb6ec8392ac 100644 --- a/buildscripts/tests/util/test_teststats.py +++ b/buildscripts/tests/util/test_teststats.py @@ -16,17 +16,21 @@ _DATE = datetime.datetime(2018, 7, 15) class NormalizeTestNameTest(unittest.TestCase): def test_unix_names(self): - self.assertEqual("/home/user/test.js", under_test.normalize_test_name("/home/user/test.js")) + self.assertEqual( + "/home/user/test.js", under_test.normalize_test_name("/home/user/test.js") + ) def test_windows_names(self): - self.assertEqual("/home/user/test.js", - under_test.normalize_test_name("\\home\\user\\test.js")) + self.assertEqual( + "/home/user/test.js", + under_test.normalize_test_name("\\home\\user\\test.js"), + ) class TestHistoricTestInfo(unittest.TestCase): def test_total_test_runtime_not_passing_test_no_hooks(self): test_info = under_test.HistoricTestInfo( - test_name='jstests/test.js', + test_name="jstests/test.js", num_pass=0, avg_duration=0.0, hooks=[], @@ -36,12 +40,12 @@ class TestHistoricTestInfo(unittest.TestCase): def test_total_test_runtime_not_passing_test_with_hooks(self): test_info = under_test.HistoricTestInfo( - test_name='jstests/test.js', + test_name="jstests/test.js", num_pass=0, avg_duration=0.0, hooks=[ under_test.HistoricHookInfo( - hook_id='test:hook', + hook_id="test:hook", num_pass=10, avg_duration=5.0, ), @@ -52,7 +56,7 @@ class TestHistoricTestInfo(unittest.TestCase): def test_total_test_runtime_passing_test_no_hooks(self): test_info = under_test.HistoricTestInfo( - test_name='jstests/test.js', + test_name="jstests/test.js", num_pass=10, avg_duration=23.0, hooks=[], @@ -62,12 +66,12 @@ class TestHistoricTestInfo(unittest.TestCase): def test_total_test_runtime_passing_test_with_hooks(self): test_info = under_test.HistoricTestInfo( - test_name='jstests/test.js', + test_name="jstests/test.js", num_pass=10, avg_duration=23.0, hooks=[ under_test.HistoricHookInfo( - hook_id='test:hook', + hook_id="test:hook", num_pass=10, avg_duration=5.0, ), @@ -146,7 +150,7 @@ class TestHistoricTaskData(unittest.TestCase): avg_duration_pass=duration, ) - @patch.object(Session, 'get') + @patch.object(Session, "get") def test_get_stats_from_s3_returns_data(self, mock_get): mock_response = MagicMock() mock_response.json.return_value = [ @@ -167,31 +171,38 @@ class TestHistoricTaskData(unittest.TestCase): ] mock_get.return_value = mock_response - result = under_test.HistoricTaskData.get_stats_from_s3("project", "task", "variant") + result = under_test.HistoricTaskData.get_stats_from_s3( + "project", "task", "variant" + ) - self.assertEqual(result, [ - under_test.HistoricalTestInformation( - test_name="jstests/noPassthroughWithMongod/geo_near_random1.js", - num_pass=74, - num_fail=0, - avg_duration_pass=23.16216216216216, - max_duration_pass=27.123, - ), - under_test.HistoricalTestInformation( - test_name="shell_advance_cluster_time:ValidateCollections", - num_pass=74, - num_fail=0, - avg_duration_pass=1.662162162162162, - max_duration_pass=100.0987, - ), - ]) + self.assertEqual( + result, + [ + under_test.HistoricalTestInformation( + test_name="jstests/noPassthroughWithMongod/geo_near_random1.js", + num_pass=74, + num_fail=0, + avg_duration_pass=23.16216216216216, + max_duration_pass=27.123, + ), + under_test.HistoricalTestInformation( + test_name="shell_advance_cluster_time:ValidateCollections", + num_pass=74, + num_fail=0, + avg_duration_pass=1.662162162162162, + max_duration_pass=100.0987, + ), + ], + ) - @patch.object(Session, 'get') + @patch.object(Session, "get") def test_get_stats_from_s3_json_decode_error(self, mock_get): mock_response = MagicMock() mock_response.json.side_effect = JSONDecodeError("msg", "doc", 0) mock_get.return_value = mock_response - result = under_test.HistoricTaskData.get_stats_from_s3("project", "task", "variant") + result = under_test.HistoricTaskData.get_stats_from_s3( + "project", "task", "variant" + ) self.assertEqual(result, []) diff --git a/buildscripts/timeouts/timeout.py b/buildscripts/timeouts/timeout.py index 9030e9546b4..335803bcb1d 100644 --- a/buildscripts/timeouts/timeout.py +++ b/buildscripts/timeouts/timeout.py @@ -1,4 +1,5 @@ """Timeout information for generating tasks.""" + from __future__ import annotations import math @@ -17,7 +18,9 @@ MAX_EXPECTED_TIMEOUT = int(timedelta(hours=48).total_seconds()) DEFAULT_SCALING_FACTOR = 3.0 -def calculate_timeout(avg_runtime: float, scaling_factor: Optional[float] = None) -> int: +def calculate_timeout( + avg_runtime: float, scaling_factor: Optional[float] = None +) -> int: """ Determine how long a runtime to set based on average runtime and a scaling factor. @@ -54,10 +57,13 @@ class TimeoutEstimate(NamedTuple): def is_specified(self) -> bool: """Determine if any specific timeout value has been specified.""" - return self.max_test_runtime is not None or self.expected_task_runtime is not None + return ( + self.max_test_runtime is not None or self.expected_task_runtime is not None + ) - def calculate_test_timeout(self, repeat_factor: int, - scaling_factor: Optional[float] = None) -> Optional[int]: + def calculate_test_timeout( + self, repeat_factor: int, scaling_factor: Optional[float] = None + ) -> Optional[int]: """ Calculate the timeout to use for tests. @@ -68,14 +74,21 @@ class TimeoutEstimate(NamedTuple): if self.max_test_runtime is None: return None - timeout = calculate_timeout(self.max_test_runtime, scaling_factor) * repeat_factor - LOGGER.debug("Setting timeout", timeout=timeout, max_runtime=self.max_test_runtime, - repeat_factor=repeat_factor, scaling_factor=(scaling_factor - or DEFAULT_SCALING_FACTOR)) + timeout = ( + calculate_timeout(self.max_test_runtime, scaling_factor) * repeat_factor + ) + LOGGER.debug( + "Setting timeout", + timeout=timeout, + max_runtime=self.max_test_runtime, + repeat_factor=repeat_factor, + scaling_factor=(scaling_factor or DEFAULT_SCALING_FACTOR), + ) return timeout - def calculate_task_timeout(self, repeat_factor: int, - scaling_factor: Optional[float] = None) -> Optional[int]: + def calculate_task_timeout( + self, repeat_factor: int, scaling_factor: Optional[float] = None + ) -> Optional[int]: """ Calculate the timeout to use for tasks. @@ -86,16 +99,28 @@ class TimeoutEstimate(NamedTuple): if self.expected_task_runtime is None: return None - exec_timeout = calculate_timeout(self.expected_task_runtime, - scaling_factor) * repeat_factor + AVG_TASK_SETUP_TIME - LOGGER.debug("Setting exec_timeout", exec_timeout=exec_timeout, - suite_runtime=self.expected_task_runtime, repeat_factor=repeat_factor, - scaling_factor=(scaling_factor or DEFAULT_SCALING_FACTOR)) + exec_timeout = ( + calculate_timeout(self.expected_task_runtime, scaling_factor) + * repeat_factor + + AVG_TASK_SETUP_TIME + ) + LOGGER.debug( + "Setting exec_timeout", + exec_timeout=exec_timeout, + suite_runtime=self.expected_task_runtime, + repeat_factor=repeat_factor, + scaling_factor=(scaling_factor or DEFAULT_SCALING_FACTOR), + ) return exec_timeout def generate_timeout_cmd( - self, is_patch: bool, repeat_factor: int, test_timeout_factor: Optional[float] = None, - task_timeout_factor: Optional[float] = None, use_default: bool = False) -> TimeoutInfo: + self, + is_patch: bool, + repeat_factor: int, + test_timeout_factor: Optional[float] = None, + task_timeout_factor: Optional[float] = None, + use_default: bool = False, + ) -> TimeoutInfo: """ Create the timeout info to use to create a timeout shrub command. @@ -113,16 +138,22 @@ class TimeoutEstimate(NamedTuple): test_timeout = self.calculate_test_timeout(repeat_factor, test_timeout_factor) task_timeout = self.calculate_task_timeout(repeat_factor, task_timeout_factor) - if is_patch and (test_timeout > MAX_EXPECTED_TIMEOUT - or task_timeout > MAX_EXPECTED_TIMEOUT): + if is_patch and ( + test_timeout > MAX_EXPECTED_TIMEOUT or task_timeout > MAX_EXPECTED_TIMEOUT + ): frameinfo = getframeinfo(currentframe()) LOGGER.error( "This task looks like it is expected to run far longer than normal. This is " "likely due to setting the suite 'repeat' value very high. If you are sure " "this is something you want to do, comment this check out in your patch build " - "and resubmit", repeat_value=repeat_factor, timeout=test_timeout, - exec_timeout=task_timeout, code_file=frameinfo.filename, code_line=frameinfo.lineno, - max_timeout=MAX_EXPECTED_TIMEOUT) + "and resubmit", + repeat_value=repeat_factor, + timeout=test_timeout, + exec_timeout=task_timeout, + code_file=frameinfo.filename, + code_line=frameinfo.lineno, + max_timeout=MAX_EXPECTED_TIMEOUT, + ) raise ValueError("Failing due to expected runtime.") return TimeoutInfo.overridden(timeout=test_timeout, exec_timeout=task_timeout) @@ -165,7 +196,9 @@ class TimeoutInfo(object): def cmd(self): """Create a command that sets timeouts as specified.""" if not self.use_defaults: - return timeout_update(exec_timeout_secs=self.exec_timeout, timeout_secs=self.timeout) + return timeout_update( + exec_timeout_secs=self.exec_timeout, timeout_secs=self.timeout + ) return None diff --git a/buildscripts/timeouts/timeout_service.py b/buildscripts/timeouts/timeout_service.py index 3ef6afe3259..5b63c1de849 100644 --- a/buildscripts/timeouts/timeout_service.py +++ b/buildscripts/timeouts/timeout_service.py @@ -1,4 +1,5 @@ """Service for determining task timeouts.""" + from typing import Any, Dict, NamedTuple, Optional import inject @@ -52,7 +53,9 @@ class TimeoutService: """ historic_stats = self.lookup_historic_stats(timeout_params) if not historic_stats: - LOGGER.warning("Missing historic runtime information, using default timeout") + LOGGER.warning( + "Missing historic runtime information, using default timeout" + ) return TimeoutEstimate.no_timeouts() test_set = { @@ -60,13 +63,17 @@ class TimeoutService: for test in self.resmoke_proxy.list_tests(timeout_params.suite_name) } test_runtimes = [ - stat for stat in historic_stats.get_tests_runtimes() if stat.test_name in test_set + stat + for stat in historic_stats.get_tests_runtimes() + if stat.test_name in test_set ] test_runtime_set = {test.test_name for test in test_runtimes} num_tests_missing_historic_data = 0 for test in test_set: if test not in test_runtime_set: - LOGGER.warning("Could not find historic runtime information for test", test=test) + LOGGER.warning( + "Could not find historic runtime information for test", test=test + ) num_tests_missing_historic_data += 1 total_runtime = 0.0 @@ -81,7 +88,9 @@ class TimeoutService: num_tests_missing_historic_data += 1 total_num_tests = len(test_set) - if not self._have_enough_historic_stats(total_num_tests, num_tests_missing_historic_data): + if not self._have_enough_historic_stats( + total_num_tests, num_tests_missing_historic_data + ): LOGGER.warning( "Not enough historic runtime information, using default timeout", total_num_tests=total_num_tests, @@ -91,7 +100,11 @@ class TimeoutService: return TimeoutEstimate.no_timeouts() hook_overhead = self.get_task_hook_overhead( - timeout_params.suite_name, timeout_params.is_asan, total_num_tests, historic_stats) + timeout_params.suite_name, + timeout_params.is_asan, + total_num_tests, + historic_stats, + ) total_runtime += hook_overhead if num_tests_missing_historic_data > 0: @@ -100,12 +113,21 @@ class TimeoutService: "At least one test misses historic runtime information, using default idle timeout", num_tests_missing_historic_data=num_tests_missing_historic_data, ) - return TimeoutEstimate.only_task_timeout(expected_task_runtime=total_runtime) + return TimeoutEstimate.only_task_timeout( + expected_task_runtime=total_runtime + ) - return TimeoutEstimate(max_test_runtime=max_runtime, expected_task_runtime=total_runtime) + return TimeoutEstimate( + max_test_runtime=max_runtime, expected_task_runtime=total_runtime + ) - def get_task_hook_overhead(self, suite_name: str, is_asan: bool, test_count: int, - historic_stats: Optional[HistoricTaskData]) -> float: + def get_task_hook_overhead( + self, + suite_name: str, + is_asan: bool, + test_count: int, + historic_stats: Optional[HistoricTaskData], + ) -> float: """ Add how much overhead task-level hooks each suite should account for. @@ -125,15 +147,23 @@ class TimeoutService: return 0.0 clean_every_n_cadence = self._get_clean_every_n_cadence(suite_name, is_asan) - avg_clean_every_n_runtime = historic_stats.get_avg_hook_runtime(CLEAN_EVERY_N_HOOK) - LOGGER.debug("task hook overhead", cadence=clean_every_n_cadence, - runtime=avg_clean_every_n_runtime, is_asan=is_asan) + avg_clean_every_n_runtime = historic_stats.get_avg_hook_runtime( + CLEAN_EVERY_N_HOOK + ) + LOGGER.debug( + "task hook overhead", + cadence=clean_every_n_cadence, + runtime=avg_clean_every_n_runtime, + is_asan=is_asan, + ) if avg_clean_every_n_runtime != 0: n_expected_runs = test_count / clean_every_n_cadence return n_expected_runs * avg_clean_every_n_runtime return 0.0 - def lookup_historic_stats(self, timeout_params: TimeoutParams) -> Optional[HistoricTaskData]: + def lookup_historic_stats( + self, timeout_params: TimeoutParams + ) -> Optional[HistoricTaskData]: """ Lookup historic test results stats for the given task. @@ -142,24 +172,36 @@ class TimeoutService: """ try: LOGGER.info( - "Getting historic runtime information", evg_project=timeout_params.evg_project, - build_variant=timeout_params.build_variant, task_name=timeout_params.task_name) + "Getting historic runtime information", + evg_project=timeout_params.evg_project, + build_variant=timeout_params.build_variant, + task_name=timeout_params.task_name, + ) evg_stats = HistoricTaskData.from_s3( - timeout_params.evg_project, timeout_params.task_name, timeout_params.build_variant) + timeout_params.evg_project, + timeout_params.task_name, + timeout_params.build_variant, + ) if not evg_stats: LOGGER.warning("No historic runtime information available") return None - LOGGER.info("Found historic runtime information", - evg_stats=evg_stats.historic_test_results) + LOGGER.info( + "Found historic runtime information", + evg_stats=evg_stats.historic_test_results, + ) return evg_stats except Exception as err: # If we have any trouble getting the historic runtime information, log the issue, but # don't fall back to default timeouts instead of failing. - LOGGER.warning("Error querying history runtime information from evergreen: %s", err) + LOGGER.warning( + "Error querying history runtime information from evergreen: %s", err + ) return None @staticmethod - def _have_enough_historic_stats(num_tests: int, num_tests_missing_data: int) -> bool: + def _have_enough_historic_stats( + num_tests: int, num_tests_missing_data: int + ) -> bool: """ Check whether the required number of stats threshold is met. @@ -171,7 +213,9 @@ class TimeoutService: raise ValueError("Number of tests cannot be less than 0") if num_tests == 0: return True - return (num_tests - num_tests_missing_data) / num_tests > REQUIRED_STATS_THRESHOLD + return ( + num_tests - num_tests_missing_data + ) / num_tests > REQUIRED_STATS_THRESHOLD def _get_clean_every_n_cadence(self, suite_name: str, is_asan: bool) -> int: """ @@ -193,15 +237,20 @@ class TimeoutService: return clean_every_n_cadence - def _get_hook_config(self, suite_name: str, hook_name: str) -> Optional[Dict[str, Any]]: + def _get_hook_config( + self, suite_name: str, hook_name: str + ) -> Optional[Dict[str, Any]]: """ Get the configuration for the given hook. :param hook_name: Name of hook to query. :return: Configuration for hook, if it exists. """ - hooks_config = self.resmoke_proxy.read_suite_config(suite_name).get("executor", - {}).get("hooks") + hooks_config = ( + self.resmoke_proxy.read_suite_config(suite_name) + .get("executor", {}) + .get("hooks") + ) if hooks_config: for hook in hooks_config: if hook.get("class") == hook_name: diff --git a/buildscripts/todo_check.py b/buildscripts/todo_check.py index a07acf31cdd..0f150863a22 100755 --- a/buildscripts/todo_check.py +++ b/buildscripts/todo_check.py @@ -1,5 +1,6 @@ #!/usr/bin/env python3 """Check for TODOs in the source code.""" + import os import re import sys @@ -14,7 +15,7 @@ from evergreen import RetryingEvergreenApi EVG_CONFIG_FILE = "./.evergreen.yml" BASE_SEARCH_DIR = "." IGNORED_PATHS = [".git"] -ISSUE_RE = re.compile('(BUILD|SERVER|WT|PM|TOOLS|TIG|PERF|BF)-[0-9]+') +ISSUE_RE = re.compile("(BUILD|SERVER|WT|PM|TOOLS|TIG|PERF|BF)-[0-9]+") class Todo(NamedTuple): @@ -79,8 +80,9 @@ class TodoChecker: def __init__(self) -> None: """Initialize a new TODO checker.""" - self.found_todos = FoundTodos(no_tickets=[], with_tickets=defaultdict(list), - by_file=defaultdict(list)) + self.found_todos = FoundTodos( + no_tickets=[], with_tickets=defaultdict(list), by_file=defaultdict(list) + ) def check_file(self, file_name: str, file_contents: Iterable[str]) -> None: """ @@ -168,7 +170,7 @@ class TodoChecker: while ticket: found_any = self.report_on_ticket(ticket) or found_any rest_index = commit_message.find(ticket) - commit_message = commit_message[rest_index + len(ticket):] + commit_message = commit_message[rest_index + len(ticket) :] ticket = Todo.get_issue_key_from_line(commit_message) print(f"Checking complete - todos found: {found_any}") @@ -209,14 +211,27 @@ def get_summary_for_patch(version_id: str) -> str: @click.command() -@click.option("--ticket", help="Only report on TODOs associated with given Jira ticket.") -@click.option("--base-dir", default=BASE_SEARCH_DIR, help="Base directory to search in.") -@click.option("--commit-message", - help="For commit-queue execution only, ensure no TODOs for this commit") -@click.option("--patch-build", type=str, - help="For patch build execution only, check for any TODOs from patch description") -def main(ticket: Optional[str], base_dir: str, commit_message: Optional[str], - patch_build: Optional[str]): +@click.option( + "--ticket", help="Only report on TODOs associated with given Jira ticket." +) +@click.option( + "--base-dir", default=BASE_SEARCH_DIR, help="Base directory to search in." +) +@click.option( + "--commit-message", + help="For commit-queue execution only, ensure no TODOs for this commit", +) +@click.option( + "--patch-build", + type=str, + help="For patch build execution only, check for any TODOs from patch description", +) +def main( + ticket: Optional[str], + base_dir: str, + commit_message: Optional[str], + patch_build: Optional[str], +): """ Search for and report on TODO comments in the code base. @@ -282,11 +297,15 @@ def main(ticket: Optional[str], base_dir: str, commit_message: Optional[str], found_todos = todo_checker.report_on_all_tickets() if found_todos: - print("TODOs that reference a Jira ticket associated with the current commit should not " - "remain in the code after the commit is merged. A TODO referencing a ticket that has " - "been closed and that solved the TODO's purpose can be confusing.") - print("To fix this error resolve any TODOs that reference the Jira ticket this commit is" - "associated with.") + print( + "TODOs that reference a Jira ticket associated with the current commit should not " + "remain in the code after the commit is merged. A TODO referencing a ticket that has " + "been closed and that solved the TODO's purpose can be confusing." + ) + print( + "To fix this error resolve any TODOs that reference the Jira ticket this commit is" + "associated with." + ) sys.exit(1) sys.exit(0) diff --git a/buildscripts/toolchains.py b/buildscripts/toolchains.py index a49e829576d..c8dae1db60f 100755 --- a/buildscripts/toolchains.py +++ b/buildscripts/toolchains.py @@ -26,23 +26,25 @@ from typing import ( import yaml __all__ = [ - 'DEFAULT_DATA_FILE', - 'Toolchain', - 'ToolchainConfig', - 'ToolchainDataException', - 'ToolchainReleaseName', - 'ToolchainVersionName', - 'Toolchains', + "DEFAULT_DATA_FILE", + "Toolchain", + "ToolchainConfig", + "ToolchainDataException", + "ToolchainReleaseName", + "ToolchainVersionName", + "Toolchains", ] -DEFAULT_DATA_FILE: pathlib.Path = pathlib.Path(__file__).parent / '../etc/toolchains.yaml' +DEFAULT_DATA_FILE: pathlib.Path = ( + pathlib.Path(__file__).parent / "../etc/toolchains.yaml" +) class ToolchainVersionName(str, enum.Enum): """Represents a "named" toolchain version, such as "stable" or "testing".""" - STABLE = 'stable' - TESTING = 'testing' + STABLE = "stable" + TESTING = "testing" # pylint: disable=invalid-str-returned def __str__(self) -> str: @@ -52,9 +54,9 @@ class ToolchainVersionName(str, enum.Enum): class ToolchainReleaseName(str, enum.Enum): """Represents a "named" toolchain release, such as "rollback" or "current".""" - ROLLBACK = 'rollback' - CURRENT = 'current' - LATEST = 'latest' + ROLLBACK = "rollback" + CURRENT = "current" + LATEST = "latest" # pylint: disable=invalid-str-returned def __str__(self) -> str: @@ -65,34 +67,43 @@ class ToolchainDistroName(Tuple[str, ...], enum.Enum): """Represents a distribution for which the toolchain is built.""" AMAZON1_2012 = ( - 'amazon1-2012', - 'linux-64-amzn', + "amazon1-2012", + "linux-64-amzn", ) - AMAZON1_2018 = ('amazon1-2018', ) - AMAZON2 = ('amazon2', ) - ARCHLINUX = ('archlinux', ) - CENTOS6 = ('centos6', ) - DEBIAN8 = ('debian81', ) - DEBIAN9 = ('debian92', ) - DEBIAN10 = ('debian10', ) - DEBIAN11 = ('debian11', ) - DEBIAN12 = ('debian12', ) - MACOS1012 = ('macos-1012', ) - MACOS1014 = ('macos-1014', ) - MACOS1100 = ('macos-1100', ) - RHEL6 = ('rhel6', 'rhel62', 'rhel67') - RHEL7 = ('rhel7', 'rhel70', 'rhel71', 'rhel72', 'rhel76', 'ubi7') - RHEL8 = ('rhel8', 'rhel80', 'rhel81', 'rhel82', 'rhel83', 'rhel84', 'rhel88', 'ubi8') - SUSE12 = ('suse12', 'suse12-sp5') - SUSE15 = ('suse15', 'suse15-sp0', 'suse15-sp2') - UBUNTU1404 = ('ubuntu1404', ) - UBUNTU1604 = ('ubuntu1604', ) - UBUNTU1804 = ('ubuntu1804', ) - UBUNTU2004 = ('ubuntu2004', ) - DEFAULT = ('default', ) + AMAZON1_2018 = ("amazon1-2018",) + AMAZON2 = ("amazon2",) + ARCHLINUX = ("archlinux",) + CENTOS6 = ("centos6",) + DEBIAN8 = ("debian81",) + DEBIAN9 = ("debian92",) + DEBIAN10 = ("debian10",) + DEBIAN11 = ("debian11",) + DEBIAN12 = ("debian12",) + MACOS1012 = ("macos-1012",) + MACOS1014 = ("macos-1014",) + MACOS1100 = ("macos-1100",) + RHEL6 = ("rhel6", "rhel62", "rhel67") + RHEL7 = ("rhel7", "rhel70", "rhel71", "rhel72", "rhel76", "ubi7") + RHEL8 = ( + "rhel8", + "rhel80", + "rhel81", + "rhel82", + "rhel83", + "rhel84", + "rhel88", + "ubi8", + ) + SUSE12 = ("suse12", "suse12-sp5") + SUSE15 = ("suse15", "suse15-sp0", "suse15-sp2") + UBUNTU1404 = ("ubuntu1404",) + UBUNTU1604 = ("ubuntu1604",) + UBUNTU1804 = ("ubuntu1804",) + UBUNTU2004 = ("ubuntu2004",) + DEFAULT = ("default",) @classmethod - def from_str(cls, text: str) -> 'ToolchainDistroName': + def from_str(cls, text: str) -> "ToolchainDistroName": """Return the enumeration object matching a given string.""" for distro in cls: @@ -109,14 +120,14 @@ class ToolchainDistroName(Tuple[str, ...], enum.Enum): class ToolchainArchName(Tuple[str, ...], enum.Enum): """Represents an architecture for which the toolchain is built.""" - ARM64 = ('arm64', 'aarch64') - PPC64LE = ('ppc64le', 'power8') - S390X = ('s390x', 'zSeries') - X86_64 = ('x86_64', ) - DEFAULT = ('', ) + ARM64 = ("arm64", "aarch64") + PPC64LE = ("ppc64le", "power8") + S390X = ("s390x", "zSeries") + X86_64 = ("x86_64",) + DEFAULT = ("",) @classmethod - def from_str(cls, text: str) -> 'ToolchainArchName': + def from_str(cls, text: str) -> "ToolchainArchName": """Return the enumeratrion object matching a given string.""" for arch in cls: @@ -137,7 +148,9 @@ class ToolchainDataException(Exception): class ToolchainPlatform: """Represents a platform for which the toolchain is built.""" - def __init__(self, distro_id: str, arch: Optional[ToolchainArchName] = None) -> None: + def __init__( + self, distro_id: str, arch: Optional[ToolchainArchName] = None + ) -> None: """Parse a distro_id into a full toolchain platform.""" self._distro_id: str = distro_id @@ -151,12 +164,13 @@ class ToolchainPlatform: self._arch_span: Tuple[int, int] = self._find_arch_span() def _split_distro_id(self, start: int = 0) -> Tuple[str, str]: - return self._distro_id[start:].split('-', 1)[0], self._distro_id[start:].split('.')[0] + return self._distro_id[start:].split("-", 1)[0], self._distro_id[start:].split( + "." + )[0] def _find_distro_length(self) -> int: for distro in ToolchainDistroName: for name in distro.value: - if not name: continue @@ -170,19 +184,18 @@ class ToolchainPlatform: for arch in ToolchainArchName: for name in arch.value: - if not name: continue iter_start = self._distro_length + 1 while iter_start < len(self._distro_id): - if name.lower() in self._split_distro_id(self._distro_length): arch_span = (iter_start, len(name)) iter_start += len(name) - if iter_start < len(self._distro_id) and self._distro_id[iter_start] in ('-', - '.'): + if iter_start < len(self._distro_id) and self._distro_id[ + iter_start + ] in ("-", "."): iter_start += 1 if arch_span is None: @@ -221,7 +234,8 @@ class ToolchainPlatform: self._arch = ToolchainArchName.X86_64 else: self._arch = ToolchainArchName.from_str( - self.distro_id[arch_span[0]:arch_span[0] + arch_span[1]]) + self.distro_id[arch_span[0] : arch_span[0] + arch_span[1]] + ) return self._arch @@ -232,9 +246,9 @@ class ToolchainPlatform: if self._tag is None: arch_span: Tuple[int, int] = self._arch_span if arch_span[0] + arch_span[1] + 1 < len(self._distro_id): - self._tag = self._distro_id[arch_span[0] + arch_span[1] + 1:] + self._tag = self._distro_id[arch_span[0] + arch_span[1] + 1 :] else: - self._tag = '' + self._tag = "" if self._tag: return self._tag @@ -260,11 +274,12 @@ class ToolchainConfig: """Construct a toolchain configuration from a data file.""" try: - with open(data_file.absolute(), 'r', encoding='utf-8') as yaml_stream: + with open(data_file.absolute(), "r", encoding="utf-8") as yaml_stream: self._data = yaml.safe_load(yaml_stream) except yaml.YAMLError as parent_exc: raise ToolchainDataException( - f"Could not read toolchain data file: `{data_file}'") from parent_exc + f"Could not read toolchain data file: `{data_file}'" + ) from parent_exc self._platform: ToolchainPlatform = platform @@ -272,14 +287,14 @@ class ToolchainConfig: def base_path(self) -> pathlib.Path: """Return the base (installed) path for toolchain releases.""" - return pathlib.Path(self._data['toolchains']['base_path']) + return pathlib.Path(self._data["toolchains"]["base_path"]) @property def all_releases(self) -> Dict[str, Dict[str, str]]: """Return all known releases in the data file.""" try: - return self._data['toolchains']['releases'] + return self._data["toolchains"]["releases"] except (KeyError, TypeError): return {} @@ -299,16 +314,18 @@ class ToolchainConfig: platform_section = self.all_releases[str(self._platform)] elif f"{self._platform.distro}.{self._platform.arch}" in self.all_releases: platform_section = self.all_releases[ - f"{self._platform.distro}.{self._platform.arch}"] + f"{self._platform.distro}.{self._platform.arch}" + ] elif f"{self._platform.distro}.{self._platform.tag}" in self.all_releases: platform_section = self.all_releases[ - f"{self._platform.distro}.{self._platform.tag}"] + f"{self._platform.distro}.{self._platform.tag}" + ] elif f"{self._platform.distro}" in self.all_releases: platform_section = self.all_releases[f"{self._platform.distro}"] if not platform_section: try: - platform_section = self.all_releases['default'] + platform_section = self.all_releases["default"] except KeyError: return {} @@ -318,29 +335,34 @@ class ToolchainConfig: def versions(self) -> List[str]: """Return all known versions in the data file.""" - return self._data['toolchains']['versions'] + return self._data["toolchains"]["versions"] @property def aliases(self) -> Dict[str, str]: """Return all known version aliases in the data file.""" - return self._data['toolchains']['version_aliases'] + return self._data["toolchains"]["version_aliases"] @property def revisions_dir(self) -> Optional[pathlib.Path]: """Return the legacy revisions directory for toolchain releases.""" - warnings.warn(("This is legacy toolchain usage. " - f"Call {self.__class__.__name__}.releases_dir() instead."), - DeprecationWarning, stacklevel=2) + warnings.warn( + ( + "This is legacy toolchain usage. " + f"Call {self.__class__.__name__}.releases_dir() instead." + ), + DeprecationWarning, + stacklevel=2, + ) - return self.base_path.joinpath('revisions') + return self.base_path.joinpath("revisions") @property def releases_dir(self) -> pathlib.Path: """Return the directory where toolchain releases are installed.""" - return self.base_path.joinpath('releases') + return self.base_path.joinpath("releases") def search_releases(self, release_name: str) -> Optional[str]: """Search configured releases for a given release name.""" @@ -383,7 +405,9 @@ class Toolchain: return path - def exec_path(self, version: Union[ToolchainVersionName, str]) -> Optional[pathlib.Path]: + def exec_path( + self, version: Union[ToolchainVersionName, str] + ) -> Optional[pathlib.Path]: """Return a path to a specific toolchain version.""" install_path = self.install_path @@ -396,7 +420,8 @@ class Toolchain: version = self._config.aliases[version] except KeyError: raise ValueError( - f"Toolchain version `{version}' not defined in data file") from None + f"Toolchain version `{version}' not defined in data file" + ) from None return install_path / version @@ -415,19 +440,25 @@ class Toolchains(Mapping[Union[ToolchainReleaseName, str], Toolchain]): releases_dir: Optional[pathlib.Path] = self._config.releases_dir if releases_dir and releases_dir.exists(): - release_dirs.extend([path for path in releases_dir.iterdir() if path.is_dir()]) + release_dirs.extend( + [path for path in releases_dir.iterdir() if path.is_dir()] + ) with warnings.catch_warnings(): warnings.simplefilter("ignore") revisions_dir: Optional[pathlib.Path] = self._config.revisions_dir if revisions_dir and revisions_dir.exists(): - release_dirs.extend([path for path in revisions_dir.iterdir() if path.is_dir()]) + release_dirs.extend( + [path for path in revisions_dir.iterdir() if path.is_dir()] + ) if release_dirs: return [ - path.name for path in sorted(release_dirs, key=lambda path: path.stat().st_mtime, - reverse=True) + path.name + for path in sorted( + release_dirs, key=lambda path: path.stat().st_mtime, reverse=True + ) ] return [] @@ -446,11 +477,10 @@ class Toolchains(Mapping[Union[ToolchainReleaseName, str], Toolchain]): """Return a list of all configured toolchain releases.""" configured: Set[Union[str, None]] = { - self._config.search_releases(name.value) - for name in ToolchainReleaseName + self._config.search_releases(name.value) for name in ToolchainReleaseName } - configured.add(self._config.search_releases('default')) + configured.add(self._config.search_releases("default")) return [release for release in configured if release is not None] @@ -466,7 +496,7 @@ class Toolchains(Mapping[Union[ToolchainReleaseName, str], Toolchain]): latest_symlink: Optional[pathlib.Path] = None try: - latest_symlink = self._config.releases_dir.joinpath('latest') + latest_symlink = self._config.releases_dir.joinpath("latest") except AttributeError: latest_symlink = None @@ -534,18 +564,26 @@ class _FormatterClass: here for NicerHelpFormatter to inherit from it and prevent the error. """ - def __call__(self, _: str) -> argparse.HelpFormatter: - ... + def __call__(self, _: str) -> argparse.HelpFormatter: ... # pylint: disable=protected-access class NicerHelpFormatter(argparse.HelpFormatter, _FormatterClass): """A HelpFormatter with nicer output than the default.""" - def __init__(self, prog: str, indent_increment: int = 2, max_help_position: int = 32, - width=None) -> None: - super().__init__(prog=prog, indent_increment=indent_increment, - max_help_position=max_help_position, width=width) + def __init__( + self, + prog: str, + indent_increment: int = 2, + max_help_position: int = 32, + width=None, + ) -> None: + super().__init__( + prog=prog, + indent_increment=indent_increment, + max_help_position=max_help_position, + width=width, + ) def __call__(self, prog: str) -> argparse.HelpFormatter: return NicerHelpFormatter(prog) @@ -561,7 +599,7 @@ class NicerHelpFormatter(argparse.HelpFormatter, _FormatterClass): return "" if not action.option_strings: default = self._get_default_metavar_for_optional(action) - metavar, = self._metavar_formatter(action, default)(1) + (metavar,) = self._metavar_formatter(action, default)(1) return metavar parts: List[str] = [] @@ -577,10 +615,11 @@ class NicerHelpFormatter(argparse.HelpFormatter, _FormatterClass): return f"{' '.join(parts)} {args_string}" - return ' '.join(parts) + return " ".join(parts) def _iter_indented_subactions( - self, action: argparse.Action) -> Generator[argparse.Action, None, None]: + self, action: argparse.Action + ) -> Generator[argparse.Action, None, None]: if isinstance(action, (argparse._SubParsersAction, DictChoiceAction)): try: get_subactions = action._get_subactions @@ -593,8 +632,9 @@ class NicerHelpFormatter(argparse.HelpFormatter, _FormatterClass): for subaction in super()._iter_indented_subactions(action): yield subaction - def _metavar_formatter(self, action: argparse.Action, - default_metavar: str) -> Callable[[int], Tuple[str, ...]]: + def _metavar_formatter( + self, action: argparse.Action, default_metavar: str + ) -> Callable[[int], Tuple[str, ...]]: if action.metavar is not None: result = action.metavar elif action.choices is not None: @@ -610,7 +650,7 @@ class NicerHelpFormatter(argparse.HelpFormatter, _FormatterClass): def _format(tuple_size: int) -> Tuple[str, ...]: if isinstance(result, tuple): return result - return (result, ) * tuple_size + return (result,) * tuple_size return _format @@ -620,8 +660,9 @@ class DictChoiceAction(argparse._StoreAction): """An action with nicer per-choice formatting.""" class _ChoicesPseudoAction(argparse.Action): - def __init__(self, name: str, aliases: List[str], help: Optional[str] = None) -> None: - + def __init__( + self, name: str, aliases: List[str], help: Optional[str] = None + ) -> None: metavar = dest = name if aliases: metavar += f" {' | '.join(aliases)}" @@ -632,22 +673,41 @@ class DictChoiceAction(argparse._StoreAction): super().__init__(option_strings=[], dest=dest, help=help, metavar=metavar) - def __call__(self, parser: argparse.ArgumentParser, namespace: argparse.Namespace, - values: Union[str, Sequence[Any], None], - option_string: Optional[str] = None) -> None: - + def __call__( + self, + parser: argparse.ArgumentParser, + namespace: argparse.Namespace, + values: Union[str, Sequence[Any], None], + option_string: Optional[str] = None, + ) -> None: parser.print_help() parser.exit() - def __init__(self, option_strings: List[str], dest: str, nargs: Optional[int] = None, - const: Optional[Any] = None, default: Optional[Any] = None, - type: Optional[type] = None, choices: Optional[Dict[str, str]] = None, - required: bool = False, help: Optional[str] = None, - metavar: Optional[str] = None) -> None: - - super().__init__(option_strings=option_strings, dest=dest, nargs=nargs, const=const, - default=default, type=type, choices=choices, required=required, help=help, - metavar=metavar) + def __init__( + self, + option_strings: List[str], + dest: str, + nargs: Optional[int] = None, + const: Optional[Any] = None, + default: Optional[Any] = None, + type: Optional[type] = None, + choices: Optional[Dict[str, str]] = None, + required: bool = False, + help: Optional[str] = None, + metavar: Optional[str] = None, + ) -> None: + super().__init__( + option_strings=option_strings, + dest=dest, + nargs=nargs, + const=const, + default=default, + type=type, + choices=choices, + required=required, + help=help, + metavar=metavar, + ) self.choices: Dict[str, str] = {} if choices: self.choices = choices @@ -663,27 +723,48 @@ class DictChoiceAction(argparse._StoreAction): class NicerArgumentParser(argparse.ArgumentParser): """An argument parser with nicer help output.""" - def __init__(self, prog: Optional[str] = None, usage: Optional[str] = None, - description: Optional[str] = None, epilog: Optional[str] = None, - parents: Optional[List[argparse.ArgumentParser]] = None, prefix_chars: str = '-', - fromfile_prefix_chars: Optional[str] = None, - argument_default: Optional[Any] = None, conflict_handler: str = 'error', - add_help: bool = True, allow_abbrev: bool = True) -> None: + def __init__( + self, + prog: Optional[str] = None, + usage: Optional[str] = None, + description: Optional[str] = None, + epilog: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + prefix_chars: str = "-", + fromfile_prefix_chars: Optional[str] = None, + argument_default: Optional[Any] = None, + conflict_handler: str = "error", + add_help: bool = True, + allow_abbrev: bool = True, + ) -> None: """Initialize a NicerParser.""" - super().__init__(prog=prog, usage=usage, description=description, epilog=epilog, - parents=parents if parents else [], formatter_class=NicerHelpFormatter, - prefix_chars=prefix_chars, fromfile_prefix_chars=fromfile_prefix_chars, - argument_default=argument_default, conflict_handler=conflict_handler, - add_help=add_help, allow_abbrev=allow_abbrev) - self._optionals.title = 'Options' - self._positionals.title = 'Queries' + super().__init__( + prog=prog, + usage=usage, + description=description, + epilog=epilog, + parents=parents if parents else [], + formatter_class=NicerHelpFormatter, + prefix_chars=prefix_chars, + fromfile_prefix_chars=fromfile_prefix_chars, + argument_default=argument_default, + conflict_handler=conflict_handler, + add_help=add_help, + allow_abbrev=allow_abbrev, + ) + self._optionals.title = "Options" + self._positionals.title = "Queries" def format_help(self) -> str: formatter = self._get_formatter() formatter.add_text(self.description) - formatter.add_usage(self.usage, self._actions, self._mutually_exclusive_groups, - prefix='Usage:\n ') + formatter.add_usage( + self.usage, + self._actions, + self._mutually_exclusive_groups, + prefix="Usage:\n ", + ) for action_group in self._action_groups: formatter.start_section(action_group.title) @@ -696,94 +777,170 @@ class NicerArgumentParser(argparse.ArgumentParser): return formatter.format_help() -if __name__ == '__main__': +if __name__ == "__main__": parser = NicerArgumentParser( - description='Tool for querying information about mongodbtoolchain.', add_help=False) + description="Tool for querying information about mongodbtoolchain.", + add_help=False, + ) - parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, - help='Show this help message and exit') - parser.add_argument('-f', '--from-file', help='Specify a toolchain data file', metavar='FILE', - type=str, default=str(DEFAULT_DATA_FILE)) - parser.add_argument('-d', '--distro-id', help='Evergreen distro_id', type=str, required=True) - parser.add_argument('-a', '--arch', help='Host architecture', type=str) + parser.add_argument( + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + parser.add_argument( + "-f", + "--from-file", + help="Specify a toolchain data file", + metavar="FILE", + type=str, + default=str(DEFAULT_DATA_FILE), + ) + parser.add_argument( + "-d", "--distro-id", help="Evergreen distro_id", type=str, required=True + ) + parser.add_argument("-a", "--arch", help="Host architecture", type=str) - subparsers = parser.add_subparsers(title='Commands', dest='command', required=True) + subparsers = parser.add_subparsers(title="Commands", dest="command", required=True) show_parser = subparsers.add_parser( - 'show', description='Shows general toolchain collection info.', add_help=False, - help='Show general toolchain collection info') - config_parser = subparsers.add_parser('config', - description='Shows toolchain configuration info.', - add_help=False, help='Show toolchain configuration info') - platform_parser = subparsers.add_parser('platform', - description='Shows component parts of a distro_id.', - add_help=False, help='Show parts of a distro_id') - toolchain_parser = subparsers.add_parser('toolchain', - description='Shows specific toolchain info.', - add_help=False, help='Show specific toolchain info') + "show", + description="Shows general toolchain collection info.", + add_help=False, + help="Show general toolchain collection info", + ) + config_parser = subparsers.add_parser( + "config", + description="Shows toolchain configuration info.", + add_help=False, + help="Show toolchain configuration info", + ) + platform_parser = subparsers.add_parser( + "platform", + description="Shows component parts of a distro_id.", + add_help=False, + help="Show parts of a distro_id", + ) + toolchain_parser = subparsers.add_parser( + "toolchain", + description="Shows specific toolchain info.", + add_help=False, + help="Show specific toolchain info", + ) - show_parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, - help='Show this help message and exit') show_parser.add_argument( - 'query', action=DictChoiceAction, type=str, choices={ - 'available': 'All installed toolchains', - 'configured': 'Toolchains configured for the distro_id', - 'latest': 'The most recent installed toolchain', - }) + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + show_parser.add_argument( + "query", + action=DictChoiceAction, + type=str, + choices={ + "available": "All installed toolchains", + "configured": "Toolchains configured for the distro_id", + "latest": "The most recent installed toolchain", + }, + ) - config_parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, - help='Show this help message and exit') config_parser.add_argument( - 'query', action=DictChoiceAction, type=str, choices={ - 'base_path': 'Toolchain base execution path', - 'releases': 'All defined release names', - 'versions': 'All defined version names', - 'aliases': 'All defined aliases for version names', - }) + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + config_parser.add_argument( + "query", + action=DictChoiceAction, + type=str, + choices={ + "base_path": "Toolchain base execution path", + "releases": "All defined release names", + "versions": "All defined version names", + "aliases": "All defined aliases for version names", + }, + ) - platform_parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, - help='Show this help message and exit') platform_parser.add_argument( - 'query', action=DictChoiceAction, type=str, choices={ - 'distro': 'Show the "distro" component of the distro_id', - 'arch': 'Show the "arch" component of the distro_id', - 'tag': 'Show the information tag component of the distro_id', - }) + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + platform_parser.add_argument( + "query", + action=DictChoiceAction, + type=str, + choices={ + "distro": 'Show the "distro" component of the distro_id', + "arch": 'Show the "arch" component of the distro_id', + "tag": "Show the information tag component of the distro_id", + }, + ) - toolchain_parser.add_argument('-h', '--help', action='help', default=argparse.SUPPRESS, - help='Show this help message and exit') - toolchain_parser.add_argument('-v', '--toolchain-version', help='Toolchain version', type=str, - default=str(ToolchainVersionName.STABLE)) - toolchain_parser.add_argument('-r', '--release', help="Toolchain release", type=str, - default=str(ToolchainReleaseName.CURRENT)) toolchain_parser.add_argument( - 'query', action=DictChoiceAction, type=str, choices={ - 'install_path': 'Toolchain installation path', - 'exec_path': 'Toolchain execution path', - }) + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit", + ) + toolchain_parser.add_argument( + "-v", + "--toolchain-version", + help="Toolchain version", + type=str, + default=str(ToolchainVersionName.STABLE), + ) + toolchain_parser.add_argument( + "-r", + "--release", + help="Toolchain release", + type=str, + default=str(ToolchainReleaseName.CURRENT), + ) + toolchain_parser.add_argument( + "query", + action=DictChoiceAction, + type=str, + choices={ + "install_path": "Toolchain installation path", + "exec_path": "Toolchain execution path", + }, + ) parsed_args = parser.parse_args() obj: Optional[object] = None # Set up the objects required for each command - toolchain_platform = ToolchainPlatform(distro_id=parsed_args.distro_id, arch=parsed_args.arch) - if parsed_args.command == 'platform': + toolchain_platform = ToolchainPlatform( + distro_id=parsed_args.distro_id, arch=parsed_args.arch + ) + if parsed_args.command == "platform": obj = toolchain_platform - elif parsed_args.command in ('show', 'config', 'toolchain'): + elif parsed_args.command in ("show", "config", "toolchain"): try: toolchain_config = ToolchainConfig( - pathlib.Path(parsed_args.from_file), platform=toolchain_platform) + pathlib.Path(parsed_args.from_file), platform=toolchain_platform + ) except ToolchainDataException as exc: print(exc, file=sys.stderr) sys.exit(1) toolchains = Toolchains(config=toolchain_config) - if parsed_args.command == 'show': + if parsed_args.command == "show": obj = toolchains - elif parsed_args.command == 'config': + elif parsed_args.command == "config": obj = toolchain_config - elif parsed_args.command == 'toolchain': + elif parsed_args.command == "toolchain": obj = toolchains[parsed_args.release] else: print(f"Unknown command: {parsed_args.command}", file=sys.stderr) @@ -793,7 +950,7 @@ if __name__ == '__main__': output: Any attribute = getattr(obj, parsed_args.query) if callable(attribute): - if attribute.__name__ == 'exec_path': + if attribute.__name__ == "exec_path": output = attribute(parsed_args.toolchain_version) else: output = attribute() @@ -805,7 +962,7 @@ if __name__ == '__main__': if isinstance(output, (tuple, list)): output = str.join(" ", output) elif isinstance(output, dict): - output = '\n'.join([f"{k}: {v}" for k, v in output.items()]) + output = "\n".join([f"{k}: {v}" for k, v in output.items()]) elif not isinstance(output, str): output = str(output) diff --git a/buildscripts/unittest_grouper.py b/buildscripts/unittest_grouper.py index 4d803f24990..67c1ea8906e 100644 --- a/buildscripts/unittest_grouper.py +++ b/buildscripts/unittest_grouper.py @@ -67,7 +67,9 @@ def download_buildozer(download_location: str = "./"): operating_system = determine_platform() architechture = determine_architecture() if operating_system == "windows" and architechture == "arm64": - raise RuntimeError("There are no published arm windows releases for buildifier.") + raise RuntimeError( + "There are no published arm windows releases for buildifier." + ) extension = ".exe" if operating_system == "windows" else "" binary_name = f"buildozer-{operating_system}-{architechture}{extension}" @@ -167,7 +169,8 @@ def validate_bazel_groups(generate_report, fix): bazel_bin, "query", 'kind(extract_debug, attr(tags, "[\[ ]mongo_unittest[,\]]", //src/...))', - ] + query_opts, + ] + + query_opts, capture_output=True, text=True, check=True, @@ -194,7 +197,8 @@ def validate_bazel_groups(generate_report, fix): bazel_bin, "query", f'kind(extract_debug, attr(tags, "[\[ ]mongo_unittest_{group}_group[,\]]", //src/...))', - ] + query_opts, + ] + + query_opts, capture_output=True, text=True, check=True, @@ -210,35 +214,45 @@ def validate_bazel_groups(generate_report, fix): if groups[group] != group_tests: for test in group_tests: if test not in bazel_unittests: - failures.append([ - test + " tag", - f"{test} not a 'mongo_unittest' but has 'mongo_unittest_{group}_group' tag.", - ]) + failures.append( + [ + test + " tag", + f"{test} not a 'mongo_unittest' but has 'mongo_unittest_{group}_group' tag.", + ] + ) print(failures[-1][1]) if fix: - buildozer_update_cmds += [[ - f"remove tags mongo_unittest_{group}_group", test - ]] + buildozer_update_cmds += [ + [f"remove tags mongo_unittest_{group}_group", test] + ] for test in groups[group]: if test not in group_tests: failures.append( - [test + " tag", f"{test} missing 'mongo_unittest_{group}_group'"]) + [ + test + " tag", + f"{test} missing 'mongo_unittest_{group}_group'", + ] + ) print(failures[-1][1]) if fix: - buildozer_update_cmds += [[f"add tags mongo_unittest_{group}_group", test]] + buildozer_update_cmds += [ + [f"add tags mongo_unittest_{group}_group", test] + ] for test in group_tests: if test not in groups[group]: - failures.append([ - test + " tag", - f"{test} is tagged in the wrong group.", - ]) + failures.append( + [ + test + " tag", + f"{test} is tagged in the wrong group.", + ] + ) print(failures[-1][1]) if fix: - buildozer_update_cmds += [[ - f"remove tags mongo_unittest_{group}_group", test - ]] + buildozer_update_cmds += [ + [f"remove tags mongo_unittest_{group}_group", test] + ] if fix: for cmd in buildozer_update_cmds: diff --git a/buildscripts/util/buildozer_utils.py b/buildscripts/util/buildozer_utils.py index 0f01d35fc10..3efa25a8e22 100644 --- a/buildscripts/util/buildozer_utils.py +++ b/buildscripts/util/buildozer_utils.py @@ -3,7 +3,9 @@ from typing import List def _bd_command(cmd: str, labels: List[str]): - print(f"buildozer '{cmd}' " + " ".join(labels), ) + print( + f"buildozer '{cmd}' " + " ".join(labels), + ) p = subprocess.run( f"buildozer '{cmd}' " + " ".join(labels), capture_output=True, @@ -33,7 +35,9 @@ def bd_new(package: str, rule_kind: str, rule_name: str) -> None: _bd_command(f"new {rule_kind} {rule_name}", [package]) -def bd_comment(labels: List[str], comment: str, attr: str = "", value: str = "") -> None: +def bd_comment( + labels: List[str], comment: str, attr: str = "", value: str = "" +) -> None: _bd_command(f"comment {attr} {value} {comment}", labels) diff --git a/buildscripts/util/cedar_report.py b/buildscripts/util/cedar_report.py index a0c9ae782a4..548aea9579f 100644 --- a/buildscripts/util/cedar_report.py +++ b/buildscripts/util/cedar_report.py @@ -36,7 +36,9 @@ class CedarTestReport: return { "info": { "test_name": self.test_name, - "args": {"thread_level": self.thread_level, }, + "args": { + "thread_level": self.thread_level, + }, }, "metrics": [metric.as_dict() for metric in self.metrics], } diff --git a/buildscripts/util/codeowners_utils.py b/buildscripts/util/codeowners_utils.py index d9cf8f406da..4f17c1cfcec 100644 --- a/buildscripts/util/codeowners_utils.py +++ b/buildscripts/util/codeowners_utils.py @@ -26,7 +26,9 @@ def process_owners(cur_dir: str) -> Tuple[Dict[re.Pattern, List[str]], bool]: no_parent_owners = False if "options" in contents: options = contents["options"] - no_parent_owners = "no_parent_owners" in options and options["no_parent_owners"] + no_parent_owners = ( + "no_parent_owners" in options and options["no_parent_owners"] + ) filters = {} for file_filter in contents["filters"]: @@ -50,7 +52,9 @@ def process_owners(cur_dir: str) -> Tuple[Dict[re.Pattern, List[str]], bool]: class Owners: def __init__(self): - self.co_jira_map = yaml.safe_load(open("buildscripts/util/co_jira_map.yml", "r")) + self.co_jira_map = yaml.safe_load( + open("buildscripts/util/co_jira_map.yml", "r") + ) def get_codeowners(self, file_path: str) -> List[str]: cur_dir = os.path.dirname(file_path) @@ -74,6 +78,7 @@ class Owners: def get_jira_team_owner(self, file_path: str) -> List[str]: return [ - jira_team for codeowner in self.get_codeowners(file_path) + jira_team + for codeowner in self.get_codeowners(file_path) for jira_team in self.get_jira_team_from_codeowner(codeowner) ] diff --git a/buildscripts/util/fileops.py b/buildscripts/util/fileops.py index 0c765adac8b..3e29fd9b723 100644 --- a/buildscripts/util/fileops.py +++ b/buildscripts/util/fileops.py @@ -41,7 +41,9 @@ def write_file(path: str, contents: str) -> None: file_handle.write(contents) -def write_file_to_dir(directory: str, file: str, contents: str, overwrite: bool = True) -> None: +def write_file_to_dir( + directory: str, file: str, contents: str, overwrite: bool = True +) -> None: """ Write the contents provided to the file in the given directory. diff --git a/buildscripts/util/oauth.py b/buildscripts/util/oauth.py index c91d05e4f68..7e307de627a 100644 --- a/buildscripts/util/oauth.py +++ b/buildscripts/util/oauth.py @@ -45,13 +45,13 @@ class Configs: SCOPE = "kanopy+openid+profile" def __init__( - self, - client_credentials_scope: str = None, - client_credentials_user_name: str = None, - auth_domain: str = None, - client_id: str = None, - redirect_port: int = None, - scope: str = None, + self, + client_credentials_scope: str = None, + client_credentials_user_name: str = None, + auth_domain: str = None, + client_id: str = None, + redirect_port: int = None, + scope: str = None, ): """Initialize configs instance.""" @@ -80,7 +80,9 @@ class OAuthCredentials(BaseModel): return self.created_time + timedelta(seconds=self.expires_in) < datetime.now() @classmethod - def get_existing_credentials_from_file(cls, file_path: str) -> Optional[OAuthCredentials]: + def get_existing_credentials_from_file( + cls, file_path: str + ) -> Optional[OAuthCredentials]: """ Try to get OAuth credentials from a file location. @@ -90,8 +92,13 @@ class OAuthCredentials(BaseModel): """ try: creds = OAuthCredentials(**read_yaml_file(file_path)) - if (creds.access_token and creds.created_time and creds.expires_in and creds.user_name - and not creds.are_expired()): + if ( + creds.access_token + and creds.created_time + and creds.expires_in + and creds.user_name + and not creds.are_expired() + ): return creds else: return None @@ -111,13 +118,13 @@ class _RedirectServer(HTTPServer): code_verifier: str def __init__( - self, - server_address: Tuple[str, int], - handler: Callable[..., BaseHTTPRequestHandler], - redirect_uri: str, - auth_domain: str, - client_id: str, - code_verifier: str, + self, + server_address: Tuple[str, int], + handler: Callable[..., BaseHTTPRequestHandler], + redirect_uri: str, + auth_domain: str, + client_id: str, + code_verifier: str, ): self.redirect_uri = redirect_uri self.auth_domain = auth_domain @@ -169,11 +176,14 @@ class _Handler(BaseHTTPRequestHandler): expires_in = resp.get("expires_in") if not access_token or not expires_in: - raise ValueError("Could not get access token or expires_in data about access token") + raise ValueError( + "Could not get access token or expires_in data about access token" + ) headers = {"Authorization": f"Bearer {access_token}"} - resp = requests.get(f"https://{self.server.auth_domain}/v1/userinfo", - headers=headers).json() + resp = requests.get( + f"https://{self.server.auth_domain}/v1/userinfo", headers=headers + ).json() split_username = resp["preferred_username"].split("@") @@ -199,7 +209,9 @@ class PKCEOauthTools: redirect_uri: str scope: str - def __init__(self, auth_domain: str, client_id: str, redirect_port: int, scope: str): + def __init__( + self, auth_domain: str, client_id: str, redirect_port: int, scope: str + ): """ Create a new PKCEOauth tools instance. @@ -225,15 +237,17 @@ class PKCEOauthTools: state = "".join(choice(ascii_lowercase) for i in range(10)) - authorization_url = (f"https://{self.auth_domain}/v1/authorize?" - f"scope={self.scope}&" - f"response_type=code&" - f"response_mode=query&" - f"client_id={self.client_id}&" - f"code_challenge={code_challenge}&" - f"state={state}&" - f"code_challenge_method=S256&" - f"redirect_uri={self.redirect_uri}") + authorization_url = ( + f"https://{self.auth_domain}/v1/authorize?" + f"scope={self.scope}&" + f"response_type=code&" + f"response_mode=query&" + f"client_id={self.client_id}&" + f"code_challenge={code_challenge}&" + f"state={state}&" + f"code_challenge_method=S256&" + f"redirect_uri={self.redirect_uri}" + ) httpd = _RedirectServer( ("", self.redirect_port), @@ -253,12 +267,15 @@ class PKCEOauthTools: if not httpd.pkce_credentials: raise ValueError( "Could not retrieve Okta credentials to talk to Kanopy with. " - "Please sign out of Okta in your browser and try runnning this script again") + "Please sign out of Okta in your browser and try runnning this script again" + ) return httpd.pkce_credentials -def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAuthCredentials: +def get_oauth_credentials( + configs: Configs, print_auth_url: bool = False +) -> OAuthCredentials: """ Run the OAuth workflow to get credentials for a human user. @@ -276,8 +293,9 @@ def get_oauth_credentials(configs: Configs, print_auth_url: bool = False) -> OAu return credentials -def get_client_cred_oauth_credentials(client_id: str, client_secret: str, - configs: Configs) -> OAuthCredentials: +def get_client_cred_oauth_credentials( + client_id: str, client_secret: str, configs: Configs +) -> OAuthCredentials: """ Run the OAuth workflow to get credentials for a machine user. @@ -298,7 +316,9 @@ def get_client_cred_oauth_credentials(client_id: str, client_secret: str, expires_in = token.get("expires_in") if not access_token or not expires_in: - raise ValueError("Could not get access token or expires_in data about access token") + raise ValueError( + "Could not get access token or expires_in data about access token" + ) return OAuthCredentials( access_token=access_token, diff --git a/buildscripts/util/read_config.py b/buildscripts/util/read_config.py index 26fc19cecdc..9f234933489 100644 --- a/buildscripts/util/read_config.py +++ b/buildscripts/util/read_config.py @@ -3,7 +3,9 @@ import yaml -def get_config_value(attrib, cmd_line_options, config_file_data, required=False, default=None): +def get_config_value( + attrib, cmd_line_options, config_file_data, required=False, default=None +): """ Get the configuration value to use. diff --git a/buildscripts/util/runcommand.py b/buildscripts/util/runcommand.py index ec6d157d54d..a9aa2b13fe7 100644 --- a/buildscripts/util/runcommand.py +++ b/buildscripts/util/runcommand.py @@ -11,7 +11,9 @@ from . import fileops class RunCommand(object): """Class to abstract executing a subprocess.""" - def __init__(self, string=None, output_file=None, append_file=False, propagate_signals=True): + def __init__( + self, string=None, output_file=None, append_file=False, propagate_signals=True + ): """Initialize the RunCommand object.""" self._command = string if string else "" self.output_file = output_file diff --git a/buildscripts/util/teststats.py b/buildscripts/util/teststats.py index a13af5d42c0..f1cb5cbbd0a 100644 --- a/buildscripts/util/teststats.py +++ b/buildscripts/util/teststats.py @@ -99,7 +99,9 @@ class HistoricHookInfo(NamedTuple): avg_duration: float @classmethod - def from_test_stats(cls, test_stats: HistoricalTestInformation) -> "HistoricHookInfo": + def from_test_stats( + cls, test_stats: HistoricalTestInformation + ) -> "HistoricHookInfo": """Create an instance from a test_stats object.""" return cls( hook_id=test_stats.test_name, @@ -129,8 +131,9 @@ class HistoricTestInfo(NamedTuple): hooks: List[HistoricHookInfo] @classmethod - def from_test_stats(cls, test_stats: HistoricalTestInformation, - hooks: List[HistoricHookInfo]) -> "HistoricTestInfo": + def from_test_stats( + cls, test_stats: HistoricalTestInformation, hooks: List[HistoricHookInfo] + ) -> "HistoricTestInfo": """Create an instance from a test_stats object.""" return cls( test_name=test_stats.test_name, @@ -143,8 +146,9 @@ class HistoricTestInfo(NamedTuple): """Get the normalized version of the test name.""" return normalize_test_name(self.test_name) - def total_hook_runtime(self, - predicate: Optional[Callable[[HistoricHookInfo], bool]] = None) -> float: + def total_hook_runtime( + self, predicate: Optional[Callable[[HistoricHookInfo], bool]] = None + ) -> float: """Get the average runtime of all the hooks associated with this test.""" def default_predicate(_) -> bool: @@ -152,15 +156,21 @@ class HistoricTestInfo(NamedTuple): if not predicate: predicate = default_predicate - return sum([ - hook.avg_duration * (hook.num_pass // self.num_pass if self.num_pass else 1) - for hook in self.hooks if predicate(hook) - ]) + return sum( + [ + hook.avg_duration + * (hook.num_pass // self.num_pass if self.num_pass else 1) + for hook in self.hooks + if predicate(hook) + ] + ) def total_test_runtime(self) -> float: """Get the average runtime of this test and it's non-task level hooks.""" if self.num_pass > 0: - return self.avg_duration + self.total_hook_runtime(lambda h: not h.is_task_level_hook()) + return self.avg_duration + self.total_hook_runtime( + lambda h: not h.is_task_level_hook() + ) return 0.0 def get_hook_overhead(self) -> float: @@ -176,7 +186,9 @@ class HistoricTaskData(object): self.historic_test_results = historic_test_results @staticmethod - def get_stats_from_s3(project: str, task: str, variant: str) -> List[HistoricalTestInformation]: + def get_stats_from_s3( + project: str, task: str, variant: str + ) -> List[HistoricalTestInformation]: """ Retrieve test stats from s3 for a given task. @@ -212,7 +224,8 @@ class HistoricTaskData(object): @classmethod def from_stats_list( - cls, historical_test_data: List[HistoricalTestInformation]) -> "HistoricTaskData": + cls, historical_test_data: List[HistoricalTestInformation] + ) -> "HistoricTaskData": """ Build historic task data from a list of historic stats. @@ -220,21 +233,29 @@ class HistoricTaskData(object): :return: Historic task data from the list of stats. """ hooks = defaultdict(list) - for hook in [stat for stat in historical_test_data if is_resmoke_hook(stat.test_name)]: + for hook in [ + stat for stat in historical_test_data if is_resmoke_hook(stat.test_name) + ]: historical_hook = HistoricHookInfo.from_test_stats(hook) hooks[historical_hook.test_name()].append(historical_hook) - return cls([ - HistoricTestInfo.from_test_stats(stat, - hooks[get_short_name_from_test_file(stat.test_name)]) - for stat in historical_test_data if not is_resmoke_hook(stat.test_name) - ]) + return cls( + [ + HistoricTestInfo.from_test_stats( + stat, hooks[get_short_name_from_test_file(stat.test_name)] + ) + for stat in historical_test_data + if not is_resmoke_hook(stat.test_name) + ] + ) def get_tests_runtimes(self) -> List[TestRuntime]: """Return the list of (test_file, runtime_in_secs) tuples ordered by decreasing runtime.""" tests = [ - TestRuntime(test_name=test_stats.normalized_test_name(), - runtime=test_stats.total_test_runtime()) + TestRuntime( + test_name=test_stats.normalized_test_name(), + runtime=test_stats.total_test_runtime(), + ) for test_stats in self.historic_test_results ] return sorted(tests, key=lambda x: x.runtime, reverse=True) @@ -242,8 +263,13 @@ class HistoricTaskData(object): def get_avg_hook_runtime(self, hook_name: str) -> float: """Get the average runtime for the specified hook.""" hook_instances = list( - chain.from_iterable([[hook for hook in test.hooks if hook.hook_name() == hook_name] - for test in self.historic_test_results])) + chain.from_iterable( + [ + [hook for hook in test.hooks if hook.hook_name() == hook_name] + for test in self.historic_test_results + ] + ) + ) if not hook_instances: return 0 diff --git a/buildscripts/utils.py b/buildscripts/utils.py index 815ccdc1c87..e0284a6f6cd 100644 --- a/buildscripts/utils.py +++ b/buildscripts/utils.py @@ -56,9 +56,14 @@ def get_git_version(): def get_git_describe(): """Return 'git describe --abbrev=7'.""" with open(os.devnull, "r+") as devnull: - proc = subprocess.Popen("git describe --abbrev=7", stdout=subprocess.PIPE, stderr=devnull, - stdin=devnull, shell=True) - return proc.communicate()[0].strip().decode('utf-8') + proc = subprocess.Popen( + "git describe --abbrev=7", + stdout=subprocess.PIPE, + stderr=devnull, + stdin=devnull, + shell=True, + ) + return proc.communicate()[0].strip().decode("utf-8") def execsys(args): @@ -96,7 +101,7 @@ def replace_with_repr(unicode_error): # fashion. This codec error handler will substitute the # repr() of the offending bytes into the decoded string # at the position they occurred - offender = unicode_error.object[unicode_error.start:unicode_error.end] + offender = unicode_error.object[unicode_error.start : unicode_error.end] return (str(repr(offender).strip("'").strip('"')), unicode_error.end) diff --git a/buildscripts/validate_commit_message.py b/buildscripts/validate_commit_message.py index 9c3d7fcaa27..39ac4a951c4 100755 --- a/buildscripts/validate_commit_message.py +++ b/buildscripts/validate_commit_message.py @@ -27,6 +27,7 @@ # it in the license file. # """Validate that the commit message is ok.""" + import argparse import logging import re @@ -42,7 +43,8 @@ def main(argv=None): """Execute Main function to validate commit messages.""" parser = argparse.ArgumentParser( usage="Validate the commit message. " - "It validates the latest message when no arguments are provided.") + "It validates the latest message when no arguments are provided." + ) parser.add_argument( "message", metavar="commit message", diff --git a/buildscripts/validate_file_size.py b/buildscripts/validate_file_size.py index 9b9a0e3b5c3..b64d65465ca 100644 --- a/buildscripts/validate_file_size.py +++ b/buildscripts/validate_file_size.py @@ -16,7 +16,8 @@ def main(): print( f"WARNING! {file_name} is {file_size_in_bytes} bytes, exceeding threshold" f" {FILE_SIZE_THRESHOLD_IN_BYTES} bytes, file upload may fail due to network issues, or Evergreen" - f" may reject very large yaml sizes") + f" may reject very large yaml sizes" + ) else: print( f"{file_name} is {file_size_in_bytes} bytes, below threshold {FILE_SIZE_THRESHOLD_IN_BYTES} bytes" @@ -25,5 +26,5 @@ def main(): print(f"{file_path} does not exist") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/validate_mongocryptd.py b/buildscripts/validate_mongocryptd.py index d30fa003887..b9015591f74 100644 --- a/buildscripts/validate_mongocryptd.py +++ b/buildscripts/validate_mongocryptd.py @@ -26,6 +26,7 @@ # it in the license file. # """Validate that mongocryptd push tasks are correct in etc/evergreen.yml.""" + from __future__ import absolute_import, print_function, unicode_literals import argparse @@ -77,7 +78,7 @@ def read_variable_from_yml(filename, variable_name): :param variable_name: Variable to read from file. :return: Value of variable or None. """ - with open(filename, 'r') as fh: + with open(filename, "r") as fh: nodes = yaml.safe_load(fh) variables = nodes["variables"] @@ -92,17 +93,20 @@ def main(): # type: () -> None """Execute Main Entry point.""" - parser = argparse.ArgumentParser(description='MongoDB CryptD Check Tool.') + parser = argparse.ArgumentParser(description="MongoDB CryptD Check Tool.") - parser.add_argument('file', type=str, help="etc/evergreen.yml file") - parser.add_argument('--variant', type=str, help="Build variant to check for") + parser.add_argument("file", type=str, help="etc/evergreen.yml file") + parser.add_argument("--variant", type=str, help="Build variant to check for") args = parser.parse_args() expected_variants = read_variable_from_yml(args.file, MONGOCRYPTD_VARIANTS) if not expected_variants: - print("ERROR: Could not find node %s in file '%s'" % (MONGOCRYPTD_VARIANTS, args.file), - file=sys.stderr) + print( + "ERROR: Could not find node %s in file '%s'" + % (MONGOCRYPTD_VARIANTS, args.file), + file=sys.stderr, + ) sys.exit(1) evg_config = parse_evergreen_file(args.file) @@ -111,15 +115,20 @@ def main(): sys.exit(0) if args.variant not in expected_variants: - print("ERROR: Expected to find variant %s in list %s" % (args.variant, expected_variants), - file=sys.stderr) print( - "ERROR: Please add the build variant %s to the %s list in '%s'" % - (args.variant, MONGOCRYPTD_VARIANTS, args.file), file=sys.stderr) + "ERROR: Expected to find variant %s in list %s" + % (args.variant, expected_variants), + file=sys.stderr, + ) + print( + "ERROR: Please add the build variant %s to the %s list in '%s'" + % (args.variant, MONGOCRYPTD_VARIANTS, args.file), + file=sys.stderr, + ) sys.exit(1) sys.exit(0) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/buildscripts/yaml_key_value.py b/buildscripts/yaml_key_value.py index 2a03de84328..2e2f170e844 100755 --- a/buildscripts/yaml_key_value.py +++ b/buildscripts/yaml_key_value.py @@ -17,9 +17,15 @@ def main(): """Execute Main program.""" parser = optparse.OptionParser(description=__doc__) - parser.add_option("--yamlFile", dest="yaml_file", default=None, help="YAML file to read") - parser.add_option("--yamlKey", dest="yaml_key", default=None, - help="Top level YAML key to provide the value") + parser.add_option( + "--yamlFile", dest="yaml_file", default=None, help="YAML file to read" + ) + parser.add_option( + "--yamlKey", + dest="yaml_key", + default=None, + help="Top level YAML key to provide the value", + ) (options, _) = parser.parse_args() if not options.yaml_file: diff --git a/evergreen/download_db_contrib_tool.py b/evergreen/download_db_contrib_tool.py index e502cc13216..d7be8e4c210 100644 --- a/evergreen/download_db_contrib_tool.py +++ b/evergreen/download_db_contrib_tool.py @@ -14,9 +14,7 @@ sys.path.append(str(mongo_path)) from buildscripts.util.expansions import get_expansion DB_CONTRIB_TOOL_VERSION = "v2.1.0" -RELEASE_URL = ( - f"https://mdb-build-public.s3.amazonaws.com/db-contrib-tool-binaries/{DB_CONTRIB_TOOL_VERSION}/" -) +RELEASE_URL = f"https://mdb-build-public.s3.amazonaws.com/db-contrib-tool-binaries/{DB_CONTRIB_TOOL_VERSION}/" def get_binary_name() -> str: @@ -43,7 +41,9 @@ def get_binary_name() -> str: operating_system = f"rhel{major_version}" - binary_name = f"db-contrib-tool_{DB_CONTRIB_TOOL_VERSION}_{operating_system}_{machine}" + binary_name = ( + f"db-contrib-tool_{DB_CONTRIB_TOOL_VERSION}_{operating_system}_{machine}" + ) if operating_system == "windows": binary_name = f"{binary_name}.exe" diff --git a/evergreen/functions/binaries_extract.py b/evergreen/functions/binaries_extract.py index b52c2526883..69664a93c88 100644 --- a/evergreen/functions/binaries_extract.py +++ b/evergreen/functions/binaries_extract.py @@ -32,15 +32,29 @@ import sys parser = argparse.ArgumentParser() -parser.add_argument('--change-dir', type=str, action='store', - help="The directory to change into to perform the extraction.") -parser.add_argument('--extraction-command', type=str, action='store', - help="The command to use for the extraction.") -parser.add_argument('--tarball', type=str, action='store', - help="The tarball to perform the extraction on.") parser.add_argument( - '--move-output', type=str, action='append', help= - "Move an extracted entry to a new location after extraction. Format is colon separated, e.g. '--move-output=file/to/move:path/to/destination'. Can accept glob like wildcards." + "--change-dir", + type=str, + action="store", + help="The directory to change into to perform the extraction.", +) +parser.add_argument( + "--extraction-command", + type=str, + action="store", + help="The command to use for the extraction.", +) +parser.add_argument( + "--tarball", + type=str, + action="store", + help="The tarball to perform the extraction on.", +) +parser.add_argument( + "--move-output", + type=str, + action="append", + help="Move an extracted entry to a new location after extraction. Format is colon separated, e.g. '--move-output=file/to/move:path/to/destination'. Can accept glob like wildcards.", ) args = parser.parse_args() @@ -53,33 +67,37 @@ else: working_dir = None tarball = pathlib.Path(args.tarball).as_posix() -shell = os.environ.get('SHELL', '/bin/bash') +shell = os.environ.get("SHELL", "/bin/bash") -if sys.platform == 'win32': - proc = subprocess.run(['C:/cygwin/bin/cygpath.exe', '-w', shell], text=True, - capture_output=True) +if sys.platform == "win32": + proc = subprocess.run( + ["C:/cygwin/bin/cygpath.exe", "-w", shell], text=True, capture_output=True + ) bash = pathlib.Path(proc.stdout.strip()) - cmd = [bash.as_posix(), '-c', f"{args.extraction_command} {tarball}"] + cmd = [bash.as_posix(), "-c", f"{args.extraction_command} {tarball}"] else: - cmd = [shell, '-c', f"{args.extraction_command} {tarball}"] + cmd = [shell, "-c", f"{args.extraction_command} {tarball}"] print(f"Extracting: {' '.join(cmd)}") -proc = subprocess.run(cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - cwd=working_dir) +proc = subprocess.run( + cmd, text=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, cwd=working_dir +) print(proc.stdout) if args.move_output: for arg in args.move_output: try: - src, dst = arg.split(':') + src, dst = arg.split(":") print(f"Moving {src} to {dst}...") files_to_move = glob.glob(src, recursive=True) for file in files_to_move: result_dst = shutil.move(file, dst) print(f"Moved {file} to {result_dst}") except ValueError as exc: - print(f"Bad format, needs to be glob like paths in the from 'src:dst', got: {arg}") + print( + f"Bad format, needs to be glob like paths in the from 'src:dst', got: {arg}" + ) raise exc sys.exit(proc.returncode) diff --git a/evergreen/functions/poetry_lock_check.py b/evergreen/functions/poetry_lock_check.py index d779180dad8..c50098a3778 100644 --- a/evergreen/functions/poetry_lock_check.py +++ b/evergreen/functions/poetry_lock_check.py @@ -7,9 +7,7 @@ It will return non zero if poetry.lock and pyproject.toml are not synced import platform import subprocess -POETRY_LOCK_V200 = ( - """# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.""" -) +POETRY_LOCK_V200 = """# This file is automatically @generated by Poetry 2.0.0 and should not be changed by hand.""" extras = [] if platform.machine() in set(["s390x", "ppc64le"]) and ".el9" not in platform.release(): diff --git a/evergreen/functions/security_reporting_scripts/upload_to_google_drive.py b/evergreen/functions/security_reporting_scripts/upload_to_google_drive.py index b0967284039..e7de9d12ac6 100644 --- a/evergreen/functions/security_reporting_scripts/upload_to_google_drive.py +++ b/evergreen/functions/security_reporting_scripts/upload_to_google_drive.py @@ -82,7 +82,9 @@ def upload( print(f"Failed to authenticate with Google API: {e}") sys.exit(1) - folder_id = releases_folder_id if triggered_by_tag.lower() == "true" else test_folder_id + folder_id = ( + releases_folder_id if triggered_by_tag.lower() == "true" else test_folder_id + ) if upload_file_name is None: input_file_name_str = str(input_file.resolve().name) diff --git a/evergreen/functions/upload_sbom_via_silkbomb.py b/evergreen/functions/upload_sbom_via_silkbomb.py index 868d9f5d98a..2b0bd0355cd 100644 --- a/evergreen/functions/upload_sbom_via_silkbomb.py +++ b/evergreen/functions/upload_sbom_via_silkbomb.py @@ -12,12 +12,16 @@ app = typer.Typer( ) -def get_changed_files_from_latest_commit(local_repo_path: str, branch_name: str = "master") -> dict: +def get_changed_files_from_latest_commit( + local_repo_path: str, branch_name: str = "master" +) -> dict: try: repo = Repo(local_repo_path) if branch_name not in repo.heads: - raise ValueError(f"Branch '{branch_name}' does not exist in the repository.") + raise ValueError( + f"Branch '{branch_name}' does not exist in the repository." + ) last_commit = repo.heads[branch_name].commit title = last_commit.summary @@ -104,7 +108,9 @@ def upload_sbom_via_silkbomb( try: print(f"Running command: {' '.join(command)}") - subprocess.run(command, check=True, text=True, capture_output=True, timeout=timeout_seconds) + subprocess.run( + command, check=True, text=True, capture_output=True, timeout=timeout_seconds + ) print("Updated sbom.json file upload via Silkbomb successful!") except FileNotFoundError as e: print(f"Error: '{container_command}' command not found.") @@ -126,18 +132,31 @@ def upload_sbom_via_silkbomb( def run( github_org: Annotated[ str, - typer.Option(..., envvar="GITHUB_ORG", help="Name of the github organization (e.g. 10gen)"), + typer.Option( + ..., + envvar="GITHUB_ORG", + help="Name of the github organization (e.g. 10gen)", + ), ], github_repo: Annotated[ - str, typer.Option(..., envvar="GITHUB_REPO", help="Repo name in 'owner/repo' format.") + str, + typer.Option( + ..., envvar="GITHUB_REPO", help="Repo name in 'owner/repo' format." + ), ], local_repo_path: Annotated[ str, - typer.Option(..., envvar="LOCAL_REPO_PATH", help="Path to the local git repository."), + typer.Option( + ..., envvar="LOCAL_REPO_PATH", help="Path to the local git repository." + ), ], branch_name: Annotated[ str, - typer.Option(..., envvar="BRANCH_NAME", help="The head branch (e.g., the PR branch name)."), + typer.Option( + ..., + envvar="BRANCH_NAME", + help="The head branch (e.g., the PR branch name).", + ), ], sbom_repo_path: Annotated[ str, @@ -159,33 +178,50 @@ def run( container_command: Annotated[ str, typer.Option( - ..., envvar="CONTAINER_COMMAND", help="Container engine to use ('podman' or 'docker')." + ..., + envvar="CONTAINER_COMMAND", + help="Container engine to use ('podman' or 'docker').", ), ] = "podman", container_image: Annotated[ - str, typer.Option(..., envvar="CONTAINER_IMAGE", help="Silkbomb container image.") + str, + typer.Option(..., envvar="CONTAINER_IMAGE", help="Silkbomb container image."), ] = "901841024863.dkr.ecr.us-east-1.amazonaws.com/release-infrastructure/silkbomb:2.0", creds_file: Annotated[ pathlib.Path, typer.Option( - ..., envvar="CONTAINER_ENV_FILES", help="Path for the temporary credentials file." + ..., + envvar="CONTAINER_ENV_FILES", + help="Path for the temporary credentials file.", ), ] = pathlib.Path("kondukto_credentials.env"), workdir: Annotated[ - str, typer.Option(..., envvar="WORKING_DIR", help="Path for the container volumes.") + str, + typer.Option(..., envvar="WORKING_DIR", help="Path for the container volumes."), ] = "/workdir", dry_run: Annotated[ - bool, typer.Option("--dry-run/--run", help="Check for changes without uploading.") + bool, + typer.Option("--dry-run/--run", help="Check for changes without uploading."), ] = True, check_sbom_file_change: Annotated[ - bool, typer.Option("--check-sbom-file-change", help="Check for changes to the SBOM file.") + bool, + typer.Option( + "--check-sbom-file-change", help="Check for changes to the SBOM file." + ), ] = False, ): if requester != "commit" and not dry_run: - print(f"Skipping: Run can only be triggered for 'commit', but requester was '{requester}'.") + print( + f"Skipping: Run can only be triggered for 'commit', but requester was '{requester}'." + ) sys.exit(0) - major_branches = ["v7.0", "v8.0", "v8.1", "master"] # Only major branches that MongoDB supports + major_branches = [ + "v7.0", + "v8.0", + "v8.1", + "master", + ] # Only major branches that MongoDB supports if False and branch_name not in major_branches: print(f"Skipping: Branch '{branch_name}' is not a major branch. Exiting.") sys.exit(0) @@ -199,7 +235,9 @@ def run( try: sbom_file_changed = True if check_sbom_file_change: - commit_changed_files = get_changed_files_from_latest_commit(repo_path, branch_name) + commit_changed_files = get_changed_files_from_latest_commit( + repo_path, branch_name + ) if commit_changed_files: print( f"Latest commit '{commit_changed_files['title']}' ({commit_changed_files['hash']}) in branch '{branch_name}' has the following changed files:" diff --git a/evergreen/generate_clang_tidy_report.py b/evergreen/generate_clang_tidy_report.py index 5fa3897164b..6db8cd50eb7 100644 --- a/evergreen/generate_clang_tidy_report.py +++ b/evergreen/generate_clang_tidy_report.py @@ -12,7 +12,10 @@ if clang_tidy: failures = [] for root, _, files in os.walk("bazel-bin"): for name in files: - if name.endswith(".clang-tidy.status") and "mongo_tidy_checks/tests/" not in root: + if ( + name.endswith(".clang-tidy.status") + and "mongo_tidy_checks/tests/" not in root + ): with open(os.path.join(root, name)) as f: if f.read().strip() == "1": tokens = name.split(".") @@ -29,7 +32,9 @@ if clang_tidy: filename = os.path.basename(log_file) parts = filename.split(".") if len(parts) < 5: - raise ValueError(f"Unexpected status file format: {filename}") + raise ValueError( + f"Unexpected status file format: {filename}" + ) source_file = ".".join(parts[:2]) target_name = parts[2] target_dir = re.search( @@ -41,7 +46,9 @@ if clang_tidy: failures.append( [ os.path.join( - re.sub("^.*/bazel_clang_tidy_src/", "src/", root, 1), + re.sub( + "^.*/bazel_clang_tidy_src/", "src/", root, 1 + ), source_file, ), content, @@ -60,4 +67,4 @@ if clang_tidy: else: report = make_report("bazel build --config=clang-tidy //src/mongo/...", "", 0) try_combine_reports(report) - put_report(report) \ No newline at end of file + put_report(report) diff --git a/evergreen/lint_fuzzer_sanity_patch.py b/evergreen/lint_fuzzer_sanity_patch.py index 4c0bf00c2e7..5af4963283c 100644 --- a/evergreen/lint_fuzzer_sanity_patch.py +++ b/evergreen/lint_fuzzer_sanity_patch.py @@ -9,7 +9,9 @@ from pathlib import Path # Get relative imports to work when the package is not installed on the PYTHONPATH. if __name__ == "__main__" and __package__ is None: - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__))))) + sys.path.append( + os.path.dirname(os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + ) # pylint: disable=wrong-import-position from buildscripts import simple_report @@ -20,7 +22,10 @@ from buildscripts.linter.filediff import gather_changed_files_for_lint def is_js_file(filename: str) -> bool: # return True - return (filename.startswith("jstests") or filename.startswith("src/mongo/db/modules/enterprise/jstests")) and filename.endswith(".js") + return ( + filename.startswith("jstests") + or filename.startswith("src/mongo/db/modules/enterprise/jstests") + ) and filename.endswith(".js") diffed_files = [Path(f) for f in gather_changed_files_for_lint(is_js_file)] @@ -43,11 +48,23 @@ for file in diffed_files: OUTPUT_FULL_DIR = Path(os.getcwd()) / OUTPUT_DIR INPUT_FULL_DIR = Path(os.getcwd()) / INPUT_DIR -subprocess.run([ - "./src/scripts/npm_run.sh", "jstestfuzz", "--", "--jsTestsDir", INPUT_FULL_DIR, "--out", - OUTPUT_FULL_DIR, "--numSourceFiles", - str(min(num_changed_files, 250)), "--numGeneratedFiles", "250" -], check=True, cwd="jstestfuzz") +subprocess.run( + [ + "./src/scripts/npm_run.sh", + "jstestfuzz", + "--", + "--jsTestsDir", + INPUT_FULL_DIR, + "--out", + OUTPUT_FULL_DIR, + "--numSourceFiles", + str(min(num_changed_files, 250)), + "--numGeneratedFiles", + "250", + ], + check=True, + cwd="jstestfuzz", +) def _parse_jsfile(jsfile: Path) -> simple_report.Result: @@ -57,9 +74,12 @@ def _parse_jsfile(jsfile: Path) -> simple_report.Result: """ print(f"Trying to parse jsfile {jsfile}") start_time = time.time() - proc = subprocess.run(["./src/scripts/npm_run.sh", "parse-jsfiles", "--", - str(jsfile)], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, - cwd="jstestfuzz") + proc = subprocess.run( + ["./src/scripts/npm_run.sh", "parse-jsfiles", "--", str(jsfile)], + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + cwd="jstestfuzz", + ) end_time = time.time() status = "pass" if proc.returncode == 0 else "fail" npm_run_output = proc.stdout.decode("UTF-8") @@ -68,8 +88,14 @@ def _parse_jsfile(jsfile: Path) -> simple_report.Result: else: print(f"Failed to parsed jsfile {jsfile}") print(npm_run_output) - return simple_report.Result(status=status, exit_code=proc.returncode, start=start_time, - end=end_time, test_file=jsfile.name, log_raw=npm_run_output) + return simple_report.Result( + status=status, + exit_code=proc.returncode, + start=start_time, + end=end_time, + test_file=jsfile.name, + log_raw=npm_run_output, + ) report = simple_report.Report(failures=0, results=[]) diff --git a/evergreen/macos_notary.py b/evergreen/macos_notary.py index eef7663e026..842f98e7e24 100644 --- a/evergreen/macos_notary.py +++ b/evergreen/macos_notary.py @@ -7,7 +7,7 @@ import sys import urllib.request import zipfile -if platform.system().lower() != 'darwin': +if platform.system().lower() != "darwin": print("Not a macos system, skipping macos signing.") sys.exit(0) @@ -15,62 +15,73 @@ if len(sys.argv) < 2: print("Must provide at least 1 archive to sign.") sys.exit(1) -supported_archs = { - 'arm64': 'arm64', - 'x86_64': 'amd64' -} +supported_archs = {"arm64": "arm64", "x86_64": "amd64"} arch = platform.uname().machine.lower() if arch not in supported_archs: print(f"Unsupported platform uname arch: {arch}, must be {supported_archs.keys()}") sys.exit(1) -macnotary_name = f'darwin_{supported_archs[arch]}' +macnotary_name = f"darwin_{supported_archs[arch]}" -if os.environ['project'] in ['mongodb-mongo-master-nightly', 'mongo-release']: - signing_type = 'notarizeAndSign' +if os.environ["project"] in ["mongodb-mongo-master-nightly", "mongo-release"]: + signing_type = "notarizeAndSign" else: - signing_type = 'sign' + signing_type = "sign" -macnotary_url = f'https://macos-notary-1628249594.s3.amazonaws.com/releases/client/latest/{macnotary_name}.zip' -print(f'Fetching macnotary tool from: {macnotary_url}') -local_filename, headers = urllib.request.urlretrieve(macnotary_url, f'{macnotary_name}.zip') -with zipfile.ZipFile(f'{macnotary_name}.zip') as zipf: +macnotary_url = f"https://macos-notary-1628249594.s3.amazonaws.com/releases/client/latest/{macnotary_name}.zip" +print(f"Fetching macnotary tool from: {macnotary_url}") +local_filename, headers = urllib.request.urlretrieve( + macnotary_url, f"{macnotary_name}.zip" +) +with zipfile.ZipFile(f"{macnotary_name}.zip") as zipf: zipf.extractall() -st = os.stat(f'{macnotary_name}/macnotary') -os.chmod(f'{macnotary_name}/macnotary', st.st_mode | stat.S_IEXEC) +st = os.stat(f"{macnotary_name}/macnotary") +os.chmod(f"{macnotary_name}/macnotary", st.st_mode | stat.S_IEXEC) failed = False archives = sys.argv[1:] for archive in archives: archive_base, archive_ext = os.path.splitext(archive) - unsigned_archive = f'{archive_base}_unsigned{archive_ext}' + unsigned_archive = f"{archive_base}_unsigned{archive_ext}" shutil.move(archive, unsigned_archive) signing_cmd = [ - f'./{macnotary_name}/macnotary', - '-f', f'{unsigned_archive}', - '-m', f'{signing_type}', - '-u', 'https://dev.macos-notary.build.10gen.cc/api', - '-k', 'server', - '--entitlements', 'etc/macos_entitlements.xml', - '--verify', - "--timeout", "30", - '-b', 'server.mongodb.com', - '-i', f'{os.environ["task_id"]}', - '-c', f'{os.environ["project"]}', - '-o', f'{archive}' + f"./{macnotary_name}/macnotary", + "-f", + f"{unsigned_archive}", + "-m", + f"{signing_type}", + "-u", + "https://dev.macos-notary.build.10gen.cc/api", + "-k", + "server", + "--entitlements", + "etc/macos_entitlements.xml", + "--verify", + "--timeout", + "30", + "-b", + "server.mongodb.com", + "-i", + f'{os.environ["task_id"]}', + "-c", + f'{os.environ["project"]}', + "-o", + f"{archive}", ] signing_env = os.environ.copy() - signing_env['MACOS_NOTARY_SECRET'] = os.environ["macos_notarization_secret"] - print(' '.join(signing_cmd)) - p = subprocess.Popen(signing_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=signing_env) + signing_env["MACOS_NOTARY_SECRET"] = os.environ["macos_notarization_secret"] + print(" ".join(signing_cmd)) + p = subprocess.Popen( + signing_cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, env=signing_env + ) print(f"Signing tool completed with exitcode: {p.returncode}") - for line in iter(p.stdout.readline, b''): + for line in iter(p.stdout.readline, b""): print(f'macnotary: {line.decode("utf-8").strip()}') p.wait() @@ -82,4 +93,3 @@ for archive in archives: if failed: exit(1) - diff --git a/evergreen/spawnhost/download_archive_dist_test_debug.py b/evergreen/spawnhost/download_archive_dist_test_debug.py index f0b6e09fe9d..6c73811b8ab 100644 --- a/evergreen/spawnhost/download_archive_dist_test_debug.py +++ b/evergreen/spawnhost/download_archive_dist_test_debug.py @@ -34,7 +34,9 @@ def main(): host = evg_api.host_by_id(instance_id) task_id = host.json["provision_options"]["task_id"] - compile_tasks = evergreen_conn._filter_successful_tasks(evg_api, collections.deque([task_id])) + compile_tasks = evergreen_conn._filter_successful_tasks( + evg_api, collections.deque([task_id]) + ) debugsymbols_task = compile_tasks.symbols_task if debugsymbols_task is None: raise RuntimeError("Could not find debugsymbols task") @@ -58,7 +60,9 @@ def main(): continue ext = artifact.name.split(".")[-1] - urlretrieve(artifact.url, f"{output_dir}/debugsymbols-manually-downloaded.{ext}") + urlretrieve( + artifact.url, f"{output_dir}/debugsymbols-manually-downloaded.{ext}" + ) return raise RuntimeError("Error occured while trying to download debugsymbols.") diff --git a/jstests/auth/lib/automated_idp_authn_simulator_azure.py b/jstests/auth/lib/automated_idp_authn_simulator_azure.py index 9758855d678..9edad37d17b 100644 --- a/jstests/auth/lib/automated_idp_authn_simulator_azure.py +++ b/jstests/auth/lib/automated_idp_authn_simulator_azure.py @@ -7,6 +7,7 @@ Given a device authorization endpoint, a username, a user code and a file with n will simulate automatically logging in as a human would. """ + import argparse import os import json @@ -21,13 +22,14 @@ from selenium.webdriver.firefox.options import Options from selenium.webdriver.support import expected_conditions as EC from selenium.webdriver.support.ui import WebDriverWait + def authenticate_azure(activation_endpoint, userCode, username, test_credentials): # Install GeckoDriver if needed. geckodriver_autoinstaller.install() # Launch headless Firefox to the device authorization endpoint. firefox_options = Options() - firefox_options.add_argument('-headless') + firefox_options.add_argument("-headless") driver = webdriver.Firefox(options=firefox_options) driver.get(activation_endpoint) @@ -38,7 +40,9 @@ def authenticate_azure(activation_endpoint, userCode, username, test_credentials EC.presence_of_element_located((By.XPATH, "//input[@name='otc']")) ) next_button = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, "//input[@type='submit'][@value='Next']")) + EC.presence_of_element_located( + (By.XPATH, "//input[@type='submit'][@value='Next']") + ) ) # Enter usercode. @@ -50,9 +54,11 @@ def authenticate_azure(activation_endpoint, userCode, username, test_credentials EC.presence_of_element_located((By.XPATH, "//input[@name='loginfmt']")) ) next_button = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, "//input[@type='submit'][@value='Next']")) + EC.presence_of_element_located( + (By.XPATH, "//input[@type='submit'][@value='Next']") + ) ) - + # Enter username. username_input_box.send_keys(username) next_button.click() @@ -74,18 +80,23 @@ def authenticate_azure(activation_endpoint, userCode, username, test_credentials verify_button = None try: password_input_box = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.ID, "passwordEntry"))) + EC.presence_of_element_located((By.ID, "passwordEntry")) + ) except: password_input_box = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.ID, "i0118"))) + EC.presence_of_element_located((By.ID, "i0118")) + ) try: verify_button = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, - "//button[@data-testid='primaryButton']"))) + EC.presence_of_element_located( + (By.XPATH, "//button[@data-testid='primaryButton']") + ) + ) except: verify_button = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.ID, "idSIButton9"))) + EC.presence_of_element_located((By.ID, "idSIButton9")) + ) # Enter password. password_input_box.send_keys(test_credentials[username]) @@ -93,34 +104,57 @@ def authenticate_azure(activation_endpoint, userCode, username, test_credentials # Assert 'Are you trying to sign in to OIDC_EVG_TESTING?' message. continue_button = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, "//input[@type='submit'][@value='Continue']")) + EC.presence_of_element_located( + (By.XPATH, "//input[@type='submit'][@value='Continue']") + ) ) continue_button.click() # Assert that the landing page contains the "You have signed in to the OIDC_EVG_TESTING application on your device" text, indicating successful auth. landing_header = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, "//p[@id='message'][@class='text-block-body no-margin-top']")) + EC.presence_of_element_located( + (By.XPATH, "//p[@id='message'][@class='text-block-body no-margin-top']") + ) ) - assert landing_header is not None and "You have signed in" in landing_header.text - - except Exception as e: + assert ( + landing_header is not None and "You have signed in" in landing_header.text + ) + + except Exception as e: print("Error: ", e) print("Traceback: ", traceback.format_exc()) print("HTML Source: ", driver.page_source) raise else: - print('Success') + print("Success") finally: driver.quit() -def main(): - parser = argparse.ArgumentParser(description='Azure Automated Authentication Simulator') - parser.add_argument('-e', '--activationEndpoint', type=str, help="Endpoint to start activation at") - parser.add_argument('-c', '--userCode', type=str, help="Code to be added in the endpoint to authenticate") - parser.add_argument('-u', '--username', type=str, help="Username to authenticate as") - parser.add_argument('-s', '--setupFile', type=str, help="File containing information generated during test setup, relative to home directory") +def main(): + parser = argparse.ArgumentParser( + description="Azure Automated Authentication Simulator" + ) + + parser.add_argument( + "-e", "--activationEndpoint", type=str, help="Endpoint to start activation at" + ) + parser.add_argument( + "-c", + "--userCode", + type=str, + help="Code to be added in the endpoint to authenticate", + ) + parser.add_argument( + "-u", "--username", type=str, help="Username to authenticate as" + ) + parser.add_argument( + "-s", + "--setupFile", + type=str, + help="File containing information generated during test setup, relative to home directory", + ) args = parser.parse_args() @@ -134,8 +168,12 @@ def main(): for i in range(num_retries): try: - authenticate_azure(args.activationEndpoint, args.userCode, args.username, - setup_information) + authenticate_azure( + args.activationEndpoint, + args.userCode, + args.username, + setup_information, + ) success = True break except Exception as e: @@ -148,8 +186,10 @@ def main(): else: print(f"Authentication with Azure failed after {num_retries} attempts") + authenticate_azure( + args.activationEndpoint, args.userCode, args.username, setup_information + ) - authenticate_azure(args.activationEndpoint, args.userCode, args.username, setup_information) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/jstests/auth/lib/automated_idp_authn_simulator_okta.py b/jstests/auth/lib/automated_idp_authn_simulator_okta.py index 324e1dfd6d0..fd76a7d90a1 100644 --- a/jstests/auth/lib/automated_idp_authn_simulator_okta.py +++ b/jstests/auth/lib/automated_idp_authn_simulator_okta.py @@ -7,6 +7,7 @@ Given a device authorization endpoint, a username, and a file with necessary set will simulate automatically logging in as a human would. """ + import argparse import os import json @@ -23,9 +24,7 @@ from selenium.webdriver.support.ui import WebDriverWait def get_input_box_with_label(driver, label_to_match, timeout): caps = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - label_xpath = ( - f"//label[contains(translate(., '{caps}', '{caps.lower()}'), '{label_to_match.lower()}')]" - ) + label_xpath = f"//label[contains(translate(., '{caps}', '{caps.lower()}'), '{label_to_match.lower()}')]" label = WebDriverWait(driver, timeout).until( EC.presence_of_element_located((By.XPATH, label_xpath)) ) @@ -43,17 +42,21 @@ def authenticate_okta(activation_endpoint, userCode, username, test_credentials) # Launch headless Firefox to the device authorization endpoint. firefox_options = Options() - firefox_options.add_argument('-headless') + firefox_options.add_argument("-headless") driver = webdriver.Firefox(options=firefox_options) driver.get(activation_endpoint) try: # Wait for activation code input box and next button to load and click. - activationCode_input_box = get_input_box_with_label(driver, "Activation Code", 30) - next_button = WebDriverWait(driver, 30).until( - EC.element_to_be_clickable((By.XPATH, "//input[@class='button button-primary'][@value='Next']")) + activationCode_input_box = get_input_box_with_label( + driver, "Activation Code", 30 ) - + next_button = WebDriverWait(driver, 30).until( + EC.element_to_be_clickable( + (By.XPATH, "//input[@class='button button-primary'][@value='Next']") + ) + ) + # Enter user activation code. activationCode_input_box.send_keys(userCode) next_button.click() @@ -65,7 +68,7 @@ def authenticate_okta(activation_endpoint, userCode, username, test_credentials) (By.XPATH, "//input[@class='button button-primary'][@value='Next']") ) ) - + # Enter username. username_input_box.send_keys(username) next_button_username.click() @@ -84,26 +87,48 @@ def authenticate_okta(activation_endpoint, userCode, username, test_credentials) # Assert that the landing page contains the "Device activated" text, indicating successful auth. landing_header = WebDriverWait(driver, 30).until( - EC.presence_of_element_located((By.XPATH, "//h2[@class='okta-form-title o-form-head'][contains(text(), 'Device activated')]")) + EC.presence_of_element_located( + ( + By.XPATH, + "//h2[@class='okta-form-title o-form-head'][contains(text(), 'Device activated')]", + ) + ) ) assert landing_header is not None - + except Exception as e: print("Error: ", e) print("Traceback: ", traceback.format_exc()) print("HTML Source: ", driver.page_source) else: - print('Success') + print("Success") finally: driver.quit() -def main(): - parser = argparse.ArgumentParser(description='Okta Automated Authentication Simulator') - parser.add_argument('-e', '--activationEndpoint', type=str, help="Endpoint to start activation at") - parser.add_argument('-c', '--userCode', type=str, help="Code to be added in the endpoint to authenticate") - parser.add_argument('-u', '--username', type=str, help="Username to authenticate as") - parser.add_argument('-s', '--setupFile', type=str, help="File containing information generated during test setup, relative to home directory") +def main(): + parser = argparse.ArgumentParser( + description="Okta Automated Authentication Simulator" + ) + + parser.add_argument( + "-e", "--activationEndpoint", type=str, help="Endpoint to start activation at" + ) + parser.add_argument( + "-c", + "--userCode", + type=str, + help="Code to be added in the endpoint to authenticate", + ) + parser.add_argument( + "-u", "--username", type=str, help="Username to authenticate as" + ) + parser.add_argument( + "-s", + "--setupFile", + type=str, + help="File containing information generated during test setup, relative to home directory", + ) args = parser.parse_args() @@ -112,7 +137,10 @@ def main(): assert args.username in setup_information assert setup_information[args.username] - authenticate_okta(args.activationEndpoint, args.userCode, args.username, setup_information) + authenticate_okta( + args.activationEndpoint, args.userCode, args.username, setup_information + ) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/jstests/noPassthrough/libs/configExpand/reflect.py b/jstests/noPassthrough/libs/configExpand/reflect.py index 2852e3d1355..e4fc3d2eebf 100644 --- a/jstests/noPassthrough/libs/configExpand/reflect.py +++ b/jstests/noPassthrough/libs/configExpand/reflect.py @@ -1,7 +1,7 @@ #! /usr/bin/env python3 """Simple reflection script. - Sends argument back as provided. - Optionally sleeps for `--sleep` seconds.""" +Sends argument back as provided. +Optionally sleeps for `--sleep` seconds.""" import argparse import sys @@ -11,10 +11,17 @@ import time def main(): """Main Method.""" - parser = argparse.ArgumentParser(description='MongoDB Mock Config Expandsion EXEC Endpoint.') - parser.add_argument('-s', '--sleep', type=int, default=0, - help="Add artificial delay for timeout testing") - parser.add_argument('value', type=str, help="Content to reflect to stdout") + parser = argparse.ArgumentParser( + description="MongoDB Mock Config Expandsion EXEC Endpoint." + ) + parser.add_argument( + "-s", + "--sleep", + type=int, + default=0, + help="Add artificial delay for timeout testing", + ) + parser.add_argument("value", type=str, help="Content to reflect to stdout") args = parser.parse_args() @@ -28,5 +35,5 @@ def main(): sys.stdout.write(args.value) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/jstests/noPassthrough/libs/configExpand/rest_server.py b/jstests/noPassthrough/libs/configExpand/rest_server.py index f1704e8fe77..f1631a0f1d2 100644 --- a/jstests/noPassthrough/libs/configExpand/rest_server.py +++ b/jstests/noPassthrough/libs/configExpand/rest_server.py @@ -16,7 +16,7 @@ class ConfigExpandRestHandler(http.server.BaseHTTPRequestHandler): Handle requests from mongod during config expansion. """ - protocol_version = 'HTTP/1.1' + protocol_version = "HTTP/1.1" def handle(self): global connect_count @@ -29,37 +29,37 @@ class ConfigExpandRestHandler(http.server.BaseHTTPRequestHandler): path = parts.path query = urllib.parse.parse_qs(parts.query) - code = int(query.get('code', [http.HTTPStatus.OK])[0]) - sleep = int(query.get('sleep', [0])[0]) + code = int(query.get("code", [http.HTTPStatus.OK])[0]) + sleep = int(query.get("sleep", [0])[0]) if sleep > 0: time.sleep(sleep) try: - response = b'' - content_type = 'text/plain' - connection = 'keep-alive' + response = b"" + content_type = "text/plain" + connection = "keep-alive" - if path == '/reflect/string': + if path == "/reflect/string": # Parses 'string' value from query string and echoes it back. - response = ','.join(query['string']).encode() - elif path == '/reflect/yaml': + response = ",".join(query["string"]).encode() + elif path == "/reflect/yaml": # Parses 'json' value from query string as JSON and reencodes as YAML. - response = query['yaml'][0].encode() - content_type = 'text/yaml' - elif path == '/connect_count': + response = query["yaml"][0].encode() + content_type = "text/yaml" + elif path == "/connect_count": global connect_count response = str(connect_count).encode() - elif path == '/connection_close': - connection = 'close' - response = b'closed' + elif path == "/connection_close": + connection = "close" + response = b"closed" else: code = http.HTTPStatus.NOT_FOUND - response = b'Unknown URL' + response = b"Unknown URL" self.send_response(code) - self.send_header('content-type', content_type) - self.send_header('content-length', len(response)) - self.send_header('connection', connection) + self.send_header("content-type", content_type) + self.send_header("content-length", len(response)) + self.send_header("connection", connection) self.end_headers() self.wfile.write(response) @@ -70,16 +70,16 @@ class ConfigExpandRestHandler(http.server.BaseHTTPRequestHandler): def do_POST(self): self.send_response(http.HTTPStatus.NOT_FOUND) - self.send_header('content-type', 'text/plain') + self.send_header("content-type", "text/plain") self.end_headers() - self.wfile.write('POST not supported') + self.wfile.write("POST not supported") def run(port): """Run web server.""" http.server.HTTPServer.protocol_version = "HTTP/1.1" - server_address = ('', port) + server_address = ("", port) httpd = http.server.HTTPServer(server_address, ConfigExpandRestHandler) print("Mock Web Server Listening on %s" % (str(server_address))) @@ -89,11 +89,17 @@ def run(port): def main(): """Main Method.""" - parser = argparse.ArgumentParser(description='MongoDB Mock Config Expandsion REST Endpoint.') + parser = argparse.ArgumentParser( + description="MongoDB Mock Config Expandsion REST Endpoint." + ) - parser.add_argument('-p', '--port', type=int, default=8000, help="Port to listen on") + parser.add_argument( + "-p", "--port", type=int, default=8000, help="Port to listen on" + ) - parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose tracing" + ) args = parser.parse_args() if args.verbose: @@ -102,6 +108,5 @@ def main(): run(args.port) -if __name__ == '__main__': - +if __name__ == "__main__": main() diff --git a/jstests/ocsp/lib/ocsp_mock.py b/jstests/ocsp/lib/ocsp_mock.py index d41bfc90ea9..94a564ac760 100644 --- a/jstests/ocsp/lib/ocsp_mock.py +++ b/jstests/ocsp/lib/ocsp_mock.py @@ -10,56 +10,111 @@ import time import sys import os -sys.path.append(os.path.join(os.getcwd() ,'src', 'third_party', 'mock_ocsp_responder')) +sys.path.append(os.path.join(os.getcwd(), "src", "third_party", "mock_ocsp_responder")) import mock_ocsp_responder logger = logging.getLogger(__name__) + @atexit.register def on_exit(): - logger.debug('Mock OCSP Responder is exiting') + logger.debug("Mock OCSP Responder is exiting") + def main(): """Main entry point""" parser = argparse.ArgumentParser(description="MongoDB Mock OCSP Responder.") - parser.add_argument('-p', '--port', type=int, default=8080, help="Port to listen on") + parser.add_argument( + "-p", "--port", type=int, default=8080, help="Port to listen on" + ) - parser.add_argument('-b', '--bind_ip', type=str, default=None, help="IP to listen on") + parser.add_argument( + "-b", "--bind_ip", type=str, default=None, help="IP to listen on" + ) - parser.add_argument('--ca_file', type=str, required=True, help="CA file for OCSP responder") + parser.add_argument( + "--ca_file", type=str, required=True, help="CA file for OCSP responder" + ) - parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing") + parser.add_argument( + "-v", "--verbose", action="count", help="Enable verbose tracing" + ) - parser.add_argument('--ocsp_responder_cert', type=str, required=True, help="OCSP Responder Certificate") + parser.add_argument( + "--ocsp_responder_cert", + type=str, + required=True, + help="OCSP Responder Certificate", + ) - parser.add_argument('--ocsp_responder_key', type=str, required=True, help="OCSP Responder Keyfile") + parser.add_argument( + "--ocsp_responder_key", type=str, required=True, help="OCSP Responder Keyfile" + ) - parser.add_argument('--fault', choices=[mock_ocsp_responder.FAULT_REVOKED, mock_ocsp_responder.FAULT_UNKNOWN, None], default=None, type=str, help="Specify a specific fault to test") + parser.add_argument( + "--fault", + choices=[ + mock_ocsp_responder.FAULT_REVOKED, + mock_ocsp_responder.FAULT_UNKNOWN, + None, + ], + default=None, + type=str, + help="Specify a specific fault to test", + ) - parser.add_argument('--next_update_seconds', type=int, default=32400, help="Specify how long the OCSP response should be valid for") + parser.add_argument( + "--next_update_seconds", + type=int, + default=32400, + help="Specify how long the OCSP response should be valid for", + ) - parser.add_argument('--response_delay_seconds', type=int, default=0, help="Delays the response by this number of seconds") + parser.add_argument( + "--response_delay_seconds", + type=int, + default=0, + help="Delays the response by this number of seconds", + ) - parser.add_argument('--include_extraneous_status', action='store_true', help="Include status of extraneous certificates in the response") + parser.add_argument( + "--include_extraneous_status", + action="store_true", + help="Include status of extraneous certificates in the response", + ) - parser.add_argument('--issuer_hash_algorithm', type=str, default='sha1', help="Algorithm to use when hashing issuer name and key") + parser.add_argument( + "--issuer_hash_algorithm", + type=str, + default="sha1", + help="Algorithm to use when hashing issuer name and key", + ) args = parser.parse_args() - level=logging.DEBUG if args.verbose else logging.INFO - logging.basicConfig(level=level, format="%(asctime)s %(levelname)s %(module)s: %(message)s") + level = logging.DEBUG if args.verbose else logging.INFO + logging.basicConfig( + level=level, format="%(asctime)s %(levelname)s %(module)s: %(message)s" + ) logging.Formatter.converter = time.gmtime - logger.info('Initializing OCSP Responder') - mock_ocsp_responder.init_responder(issuer_cert=args.ca_file, responder_cert=args.ocsp_responder_cert, - responder_key=args.ocsp_responder_key, fault=args.fault, next_update_seconds=args.next_update_seconds, - response_delay_seconds=args.response_delay_seconds, include_extraneous_status=args.include_extraneous_status, - issuer_hash_algorithm=args.issuer_hash_algorithm) + logger.info("Initializing OCSP Responder") + mock_ocsp_responder.init_responder( + issuer_cert=args.ca_file, + responder_cert=args.ocsp_responder_cert, + responder_key=args.ocsp_responder_key, + fault=args.fault, + next_update_seconds=args.next_update_seconds, + response_delay_seconds=args.response_delay_seconds, + include_extraneous_status=args.include_extraneous_status, + issuer_hash_algorithm=args.issuer_hash_algorithm, + ) - logger.debug('Mock OCSP Responder will be started on port %s' % (str(args.port))) + logger.debug("Mock OCSP Responder will be started on port %s" % (str(args.port))) mock_ocsp_responder.init(port=args.port, debug=args.verbose, host=args.bind_ip) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/jstests/sharding/libs/proxy_protocol_server.py b/jstests/sharding/libs/proxy_protocol_server.py index d6bf876bad1..9218e29977c 100644 --- a/jstests/sharding/libs/proxy_protocol_server.py +++ b/jstests/sharding/libs/proxy_protocol_server.py @@ -6,6 +6,6 @@ Python script to interact with proxy protocol server. from proxyprotocol.server.main import * import sys -if __name__ == '__main__': +if __name__ == "__main__": print("Starting proxy protocol server...") sys.exit(main()) diff --git a/jstests/ssl/tls_enumerator.py b/jstests/ssl/tls_enumerator.py index 6fa428f5636..c12a58813c0 100644 --- a/jstests/ssl/tls_enumerator.py +++ b/jstests/ssl/tls_enumerator.py @@ -5,12 +5,13 @@ import argparse exception_ciphers = {} + def enumerate_tls_ciphers(protocol_options, host, port, cert, cafile): root_context = ssl.SSLContext(ssl.PROTOCOL_TLS) root_context.options |= protocol_options - root_context.set_ciphers('ALL:COMPLEMENTOFALL:-PSK:-SRP') + root_context.set_ciphers("ALL:COMPLEMENTOFALL:-PSK:-SRP") - ciphers = {cipher['name'] for cipher in root_context.get_ciphers()} + ciphers = {cipher["name"] for cipher in root_context.get_ciphers()} accepted_ciphers = [] @@ -35,18 +36,25 @@ def enumerate_tls_ciphers(protocol_options, host, port, cert, cafile): return sorted(accepted_ciphers) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='MongoDB TLS Cipher Suite Enumerator') - parser.add_argument('--port', type=int, default=27017, help='Port to connect to') - parser.add_argument('-o', '--outfile', type=str, default='ciphers.json', - help='file to write the output to') - parser.add_argument('--host', type=str, default='localhost', help='host to connect to') - parser.add_argument('--cafile', type=str, help='Path to CA certificate') - parser.add_argument('--cert', type=str, help='Path to client certificate') +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="MongoDB TLS Cipher Suite Enumerator") + parser.add_argument("--port", type=int, default=27017, help="Port to connect to") + parser.add_argument( + "-o", + "--outfile", + type=str, + default="ciphers.json", + help="file to write the output to", + ) + parser.add_argument( + "--host", type=str, default="localhost", help="host to connect to" + ) + parser.add_argument("--cafile", type=str, help="Path to CA certificate") + parser.add_argument("--cert", type=str, help="Path to client certificate") args = parser.parse_args() # MacOS version of the toolchain does not have python linked with OpenSSL 1.1.1 yet, so we monkey patch this in here - if not hasattr(ssl, 'OP_NO_TLSv1_3'): + if not hasattr(ssl, "OP_NO_TLSv1_3"): ssl.OP_NO_TLSv1_3 = 0 exclude_ops = { @@ -65,23 +73,28 @@ if __name__ == '__main__': return option suites = { - 'sslv2': exclude_except(ssl.OP_NO_SSLv2), - 'sslv3': exclude_except(ssl.OP_NO_SSLv3), - 'tls1': exclude_except(ssl.OP_NO_TLSv1), - 'tls1_1': exclude_except(ssl.OP_NO_TLSv1_1), - 'tls1_2': exclude_except(ssl.OP_NO_TLSv1_2), + "sslv2": exclude_except(ssl.OP_NO_SSLv2), + "sslv3": exclude_except(ssl.OP_NO_SSLv3), + "tls1": exclude_except(ssl.OP_NO_TLSv1), + "tls1_1": exclude_except(ssl.OP_NO_TLSv1_1), + "tls1_2": exclude_except(ssl.OP_NO_TLSv1_2), } results = { - key: enumerate_tls_ciphers(protocol_options=proto, host=args.host, port=args.port, - cafile=args.cafile, cert=args.cert) + key: enumerate_tls_ciphers( + protocol_options=proto, + host=args.host, + port=args.port, + cafile=args.cafile, + cert=args.cert, + ) for key, proto in suites.items() } if exception_ciphers: print("System could not process the following ciphers") for cipher, error in exception_ciphers.items(): - print(cipher + '\tError: ' + error) + print(cipher + "\tError: " + error) - with open(args.outfile, 'w+') as outfile: + with open(args.outfile, "w+") as outfile: json.dump(results, outfile) diff --git a/jstests/ssl/x509/mkcert.py b/jstests/ssl/x509/mkcert.py index 8f51ef466e2..a2af2ff17f9 100755 --- a/jstests/ssl/x509/mkcert.py +++ b/jstests/ssl/x509/mkcert.py @@ -5,6 +5,7 @@ Invoke as `python3 jstests/ssl/x509/mkcert.py --config jstests/ssl/x509/certs.ym and `python3 jstests/ssl/x509/mkcert.py --config src/mongo/db/modules/enterprise/jstests/libs/certs.yml` Optionally providing a cert ID to only regenerate a single cert. """ + import argparse import binascii import datetime @@ -24,94 +25,109 @@ import mkdigest try: # Newer versions of PyOpenSSL hide OBJ_create, but also seem okay without it. OBJ_create = OpenSSL._util.lib.OBJ_create - OBJ_create(b'1.2.3.45', b'DummyOID45', b'Dummy OID 45') - OBJ_create(b'1.2.3.56', b'DummyOID56', b'Dummy OID 56') - OBJ_create(b'1.3.6.1.4.1.34601.2.1.1', b'mongoRoles', - b'Sequence of MongoDB Database Roles') - OBJ_create(b'1.3.6.1.4.1.34601.2.1.2', b'mongoClusterMembership', - b'Name of MongoDB cluster this cert is a member of') + OBJ_create(b"1.2.3.45", b"DummyOID45", b"Dummy OID 45") + OBJ_create(b"1.2.3.56", b"DummyOID56", b"Dummy OID 56") + OBJ_create( + b"1.3.6.1.4.1.34601.2.1.1", b"mongoRoles", b"Sequence of MongoDB Database Roles" + ) + OBJ_create( + b"1.3.6.1.4.1.34601.2.1.2", + b"mongoClusterMembership", + b"Name of MongoDB cluster this cert is a member of", + ) except: pass # pylint: enable=protected-access -CONFIGFILE = 'jstests/ssl/x509/certs.yml' +CONFIGFILE = "jstests/ssl/x509/certs.yml" CONFIG = Dict[str, Any] # tlsfeature = status_request isn't supported by older versions of OpenSSL so we manually define this below # 1.3.6.1.5.5.7.1.24: "tls_feature" extension as defined in https://tools.ietf.org/html/rfc7633#section-6 -MUST_STAPLE_KEY_STR = '1.3.6.1.5.5.7.1.24' +MUST_STAPLE_KEY_STR = "1.3.6.1.5.5.7.1.24" MUST_STAPLE_KEY = bytes(MUST_STAPLE_KEY_STR, "utf-8") # status_request extension as defined in https://tools.ietf.org/html/rfc4366#section-2.3 -MUST_STAPLE_VALUE_STR = 'DER:30:03:02:01:05' # ASN.1 value: SEQUENCE { INTEGER 0x05 (5 decimal) } -MUST_STAPLE_VALUE = str(MUST_STAPLE_VALUE_STR).encode('utf-8') +MUST_STAPLE_VALUE_STR = ( + "DER:30:03:02:01:05" # ASN.1 value: SEQUENCE { INTEGER 0x05 (5 decimal) } +) +MUST_STAPLE_VALUE = str(MUST_STAPLE_VALUE_STR).encode("utf-8") -# <= 825 in order to abide by https://support.apple.com/en-us/HT210176. +# <= 825 in order to abide by https://support.apple.com/en-us/HT210176. MAX_VALIDITY_PERIOD_DAYS = 824 + def glbl(key, default=None): """Fetch a key from the global dict.""" - return CONFIG.get('global', {}).get(key, default) + return CONFIG.get("global", {}).get(key, default) + def idx(cert, key, default=None): """Fetch a key from the cert dict, falling back through global dict.""" return cert.get(key, None) or glbl(key, default) + def make_key(cert): """Generate an RSA or DSA private key.""" # Note that ECDSA keys are generated in the # process_csdsa_*() functions below. - type_str = idx(cert, 'key_type', 'RSA') - if type_str == 'RSA': + type_str = idx(cert, "key_type", "RSA") + if type_str == "RSA": key_type = OpenSSL.crypto.TYPE_RSA - elif type_str == 'DSA': + elif type_str == "DSA": key_type = OpenSSL.crypto.TYPE_DSA else: - raise ValueError('Unknown key_type: ' + type_str) + raise ValueError("Unknown key_type: " + type_str) - key_size = int(idx(cert, 'key_size', '2048')) + key_size = int(idx(cert, "key_size", "2048")) if key_size < 1024: - raise ValueError('Invalid key_size: ' + key_size) + raise ValueError("Invalid key_size: " + key_size) key = OpenSSL.crypto.PKey() key.generate_key(key_type, key_size) return key + def make_filename(cert): """Form a pathname from a certificate name.""" - return idx(cert, 'output_path') + '/' + cert['name'] + return idx(cert, "output_path") + "/" + cert["name"] + def find_certificate_definition(name): """Locate a definition by name.""" - for ca_cert in CONFIG['certs']: - if ca_cert['name'] == name: + for ca_cert in CONFIG["certs"]: + if ca_cert["name"] == name: return ca_cert return None + def get_cert_path(name): """Determine certificate path by name.""" entry = find_certificate_definition(name) return make_filename(entry) if entry else name + def load_authority_file(issuer): """Locate the cert/key file for a given ID and load their parts.""" ca_cert = find_certificate_definition(issuer) if ca_cert: - pem = open(make_filename(ca_cert), 'rt').read() + pem = open(make_filename(ca_cert), "rt").read() certificate = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem) - passphrase = ca_cert.get('passphrase', None) + passphrase = ca_cert.get("passphrase", None) if passphrase: - passphrase = passphrase.encode('utf-8') + passphrase = passphrase.encode("utf-8") - signing_key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, pem, passphrase=passphrase) + signing_key = OpenSSL.crypto.load_privatekey( + OpenSSL.crypto.FILETYPE_PEM, pem, passphrase=passphrase + ) return (certificate, signing_key) # Externally sourced certifiate, try by path. Hopefully unencrypted. # pylint: disable=bare-except try: - pem = open(issuer, 'rt').read() + pem = open(issuer, "rt").read() certificate = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, pem) signing_key = OpenSSL.crypto.load_privatekey(OpenSSL.crypto.FILETYPE_PEM, pem) return (certificate, signing_key) @@ -120,24 +136,26 @@ def load_authority_file(issuer): return (None, None) + def set_subject(x509, cert): """Translate a subject dict to X509Name elements.""" - if not cert.get('Subject'): - if cert.get('explicit_subject', False): + if not cert.get("Subject"): + if cert.get("explicit_subject", False): # do nothing if an empty subject is explicitly provided return - raise ValueError(cert['name'] + ' requires a Subject') + raise ValueError(cert["name"] + " requires a Subject") - if not cert.get('explicit_subject', False): - for key, val in glbl('Subject', {}).items(): + if not cert.get("explicit_subject", False): + for key, val in glbl("Subject", {}).items(): setattr(x509.get_subject(), key, val) - for key, val in cert['Subject'].items(): + for key, val in cert["Subject"].items(): setattr(x509.get_subject(), key, val) + def set_validity(x509, cert): """Set validity range for the certificate.""" - not_before = idx(cert, 'not_before', None) + not_before = idx(cert, "not_before", None) if not_before: # TODO: Parse human readable dates and/or datedeltas not_before = int(not_before) @@ -146,7 +164,7 @@ def set_validity(x509, cert): not_before = -7 * 24 * 60 * 60 x509.gmtime_adj_notBefore(not_before) - not_after = idx(cert, 'not_after', None) + not_after = idx(cert, "not_after", None) if not_after: # TODO: Parse human readable dates and/or datedeltas not_after = int(not_after) @@ -154,38 +172,47 @@ def set_validity(x509, cert): not_after = not_before + MAX_VALIDITY_PERIOD_DAYS * 24 * 60 * 60 x509.gmtime_adj_notAfter(not_after) + def set_general_dict_extension(x509, exts, cert, name, typed_values): """Set dict key/value pairs for an extension.""" - tags = cert.get('extensions', {}).get(name, {}) + tags = cert.get("extensions", {}).get(name, {}) if not tags: return critical = False value = [] for key, val in tags.items(): - if key == 'critical': + if key == "critical": if not val is True: - raise ValueError('critical must be precisely equal to TRUE') + raise ValueError("critical must be precisely equal to TRUE") critical = True continue if not key in typed_values: - raise ValueError('Unknown key for extensions. ' + name + ': ' + key) + raise ValueError("Unknown key for extensions. " + name + ": " + key) if not isinstance(val, type(typed_values[key])): - raise ValueError('Type mismatch for extensions. ' + name + '.' + key) + raise ValueError("Type mismatch for extensions. " + name + "." + key) if isinstance(val, bool): - value.append(key + ':' + ('TRUE' if val else 'FALSE')) + value.append(key + ":" + ("TRUE" if val else "FALSE")) else: - value.append(key + ':' + val) + value.append(key + ":" + val) + + exts.append( + OpenSSL.crypto.X509Extension( + bytes(name, "utf-8"), + critical, + ",".join(value).encode("utf-8"), + subject=x509, + ) + ) - exts.append(OpenSSL.crypto.X509Extension(bytes(name, 'utf-8'), critical, ','.join(value).encode('utf-8'), subject=x509)) def set_general_list_extension(x509, exts, cert, name, values): """Set value elements for a given extension.""" - tags = cert.get('extensions', {}).get(name, ()) + tags = cert.get("extensions", {}).get(name, ()) if not tags: return @@ -193,167 +220,258 @@ def set_general_list_extension(x509, exts, cert, name, values): tags = [tags] critical = False - if 'critical' in tags: + if "critical" in tags: critical = True - tags.remove('critical') + tags.remove("critical") for key in tags: if key not in values: - raise ValueError('Illegal tag: ' + key) + raise ValueError("Illegal tag: " + key) + + exts.append( + OpenSSL.crypto.X509Extension( + name.encode("utf-8"), critical, ",".join(tags).encode("utf-8"), subject=x509 + ) + ) - exts.append(OpenSSL.crypto.X509Extension(name.encode('utf-8'), critical, ','.join(tags).encode('utf-8'), subject=x509)) def set_ocsp_extension(x509, exts, cert): """Set the OCSP extension""" - ocsp = cert.get('extensions', {}).get('authorityInfoAccess') + ocsp = cert.get("extensions", {}).get("authorityInfoAccess") if not ocsp: return - exts.append(OpenSSL.crypto.X509Extension(b'authorityInfoAccess', False, ocsp.encode('utf-8'), subject=x509)) + exts.append( + OpenSSL.crypto.X509Extension( + b"authorityInfoAccess", False, ocsp.encode("utf-8"), subject=x509 + ) + ) + def set_no_check_extension(x509, exts, cert): """Set the OCSP No Check extension""" - noCheck = cert.get('extensions', {}).get('noCheck') + noCheck = cert.get("extensions", {}).get("noCheck") if not noCheck: return # "The OCSP No Check extension is a string extension but its value is ignored." https://www.openssl.org/docs/man1.1.1/man5/x509v3_config.html - exts.append(OpenSSL.crypto.X509Extension(b'noCheck', False, "this-value-ignored".encode('utf8'), subject=x509)) + exts.append( + OpenSSL.crypto.X509Extension( + b"noCheck", False, "this-value-ignored".encode("utf8"), subject=x509 + ) + ) + def set_tls_feature_extension(x509, exts, cert): """Set the OCSP Must Staple extension""" - mustStaple = cert.get('extensions', {}).get('mustStaple') + mustStaple = cert.get("extensions", {}).get("mustStaple") if not mustStaple: return - exts.append(OpenSSL.crypto.X509Extension(MUST_STAPLE_KEY, False, MUST_STAPLE_VALUE, subject=x509)) + exts.append( + OpenSSL.crypto.X509Extension( + MUST_STAPLE_KEY, False, MUST_STAPLE_VALUE, subject=x509 + ) + ) + def set_san_extension(x509, exts, cert): """Set the Subject Alternate Name extension.""" - san = cert.get('extensions', {}).get('subjectAltName') + san = cert.get("extensions", {}).get("subjectAltName") if not san: return critical = False sans = [] for typ, vals in san.items(): - if typ == 'critical': + if typ == "critical": if not vals is True: - raise ValueError('critical must be precisely equal to TRUE') + raise ValueError("critical must be precisely equal to TRUE") critical = True continue - if not typ in ['IP', 'DNS']: # Other things can live here, but this is all we use. - raise ValueError('Fix me? Only IP and DNS SANs are handled') + if not typ in [ + "IP", + "DNS", + ]: # Other things can live here, but this is all we use. + raise ValueError("Fix me? Only IP and DNS SANs are handled") if not isinstance(vals, list): vals = [vals] for val in vals: - sans.append(typ + ':' + val) + sans.append(typ + ":" + val) if not sans: return - exts.append(OpenSSL.crypto.X509Extension(b'subjectAltName', critical, ','.join(sans).encode('utf-8'), subject=x509)) + exts.append( + OpenSSL.crypto.X509Extension( + b"subjectAltName", critical, ",".join(sans).encode("utf-8"), subject=x509 + ) + ) + def enable_subject_key_identifier_extension(x509, exts, cert): """Enable the subject key identifier extension.""" - ident = cert.get('extensions', {}).get('subjectKeyIdentifier', False) + ident = cert.get("extensions", {}).get("subjectKeyIdentifier", False) if not ident: return - if ident not in ['hash', 'hash-critical']: + if ident not in ["hash", "hash-critical"]: raise ValueError("Only the value 'hash' is accepted for subejctKeyIdentifier") - exts.append(OpenSSL.crypto.X509Extension(b'subjectKeyIdentifier', ident == 'hash-critical', b'hash', subject=x509)) + exts.append( + OpenSSL.crypto.X509Extension( + b"subjectKeyIdentifier", ident == "hash-critical", b"hash", subject=x509 + ) + ) + def enable_authority_key_identifier_extension(x509, exts, cert): """Enable the authority key identifier extension.""" - ident = cert.get('extensions', {}).get('authorityKeyIdentifier', False) + ident = cert.get("extensions", {}).get("authorityKeyIdentifier", False) if not ident: return - if ident not in ['keyid', 'issuer']: - raise ValueError("Only the 'keyid' or 'issuer' values are accepted for authorityKeyIdentifier") - issuer = cert.get('Issuer', 'ca.pem') - if issuer == 'self': + if ident not in ["keyid", "issuer"]: + raise ValueError( + "Only the 'keyid' or 'issuer' values are accepted for authorityKeyIdentifier" + ) + issuer = cert.get("Issuer", "ca.pem") + if issuer == "self": issuer_cert = x509 else: issuer_cert = load_authority_file(issuer)[0] - exts.append(OpenSSL.crypto.X509Extension(b'authorityKeyIdentifier', False, ident.encode('utf-8'), subject=x509, issuer=issuer_cert)) + exts.append( + OpenSSL.crypto.X509Extension( + b"authorityKeyIdentifier", + False, + ident.encode("utf-8"), + subject=x509, + issuer=issuer_cert, + ) + ) + def to_der_varint(val): """Translate a native int to a variable length ASN.1 encoded integer.""" if val < 0: - raise ValueError('Negative values nor permitted in DER payload') + raise ValueError("Negative values nor permitted in DER payload") if val < 0x80: - return chr(val).encode('ascii') + return chr(val).encode("ascii") - ret = bytearray(b'') + ret = bytearray(b"") while (val > 0) and (len(ret) < 8): ret.insert(0, val & 0xFF) val = val >> 8 if val > 0: - raise ValueError('Length is too large to represent in 64bits') + raise ValueError("Length is too large to represent in 64bits") ret.insert(0, 0x80 + len(ret)) return ret + def to_der_utf8_string(val): """Encode a unicode string as a ASN.1 UTF8 String.""" - utf8_val = str(val).encode('utf-8') - return b'\x0C' + to_der_varint(len(utf8_val)) + utf8_val + utf8_val = str(val).encode("utf-8") + return b"\x0c" + to_der_varint(len(utf8_val)) + utf8_val + def to_der_sequence_pair(name, value): """Encode a pair of ASN.1 values as a sequence pair.""" # Simplified sequence which always expects two string, a key and a value. bin_name = to_der_utf8_string(name) bin_value = to_der_utf8_string(value) - return b'\x30' + to_der_varint(len(bin_name) + len(bin_value)) + bin_name + bin_value + return ( + b"\x30" + to_der_varint(len(bin_name) + len(bin_value)) + bin_name + bin_value + ) + def set_mongo_roles_extension(exts, cert): """Encode a set of role/db pairs into a MongoDB DER packet.""" - roles = cert.get('extensions', {}).get('mongoRoles') + roles = cert.get("extensions", {}).get("mongoRoles") if not roles: return - pair = b'' + pair = b"" for role in roles: - if (len(role) != 2) or ('role' not in role) or ('db' not in role): - raise ValueError('mongoRoles must consist of a series of role/db pairs') - pair = pair + to_der_sequence_pair(role['role'], role['db']) + if (len(role) != 2) or ("role" not in role) or ("db" not in role): + raise ValueError("mongoRoles must consist of a series of role/db pairs") + pair = pair + to_der_sequence_pair(role["role"], role["db"]) - value = b'DER:31' + binascii.hexlify(to_der_varint(len(pair))) + binascii.hexlify(pair) + value = ( + b"DER:31" + binascii.hexlify(to_der_varint(len(pair))) + binascii.hexlify(pair) + ) + + exts.append(OpenSSL.crypto.X509Extension(b"1.3.6.1.4.1.34601.2.1.1", False, value)) - exts.append(OpenSSL.crypto.X509Extension(b'1.3.6.1.4.1.34601.2.1.1', False, value)) def set_mongo_cluster_membership_extension(exts, cert): """Encode a symbolic name to a mongodbClusterMembership extension.""" - name = cert.get('extensions', {}).get('mongoClusterMembership') + name = cert.get("extensions", {}).get("mongoClusterMembership") if not name: return - value = b'DER:' + binascii.hexlify(to_der_utf8_string(name)) - exts.append(OpenSSL.crypto.X509Extension(b'1.3.6.1.4.1.34601.2.1.2', False, value)) + value = b"DER:" + binascii.hexlify(to_der_utf8_string(name)) + exts.append(OpenSSL.crypto.X509Extension(b"1.3.6.1.4.1.34601.2.1.2", False, value)) + def set_crl_distribution_point_extension(exts, cert): """Specify URI(s) for CRL distribution point(s).""" - uris = cert.get('extensions', {}).get('crlDistributionPoints') + uris = cert.get("extensions", {}).get("crlDistributionPoints") if not uris: return - exts.append(OpenSSL.crypto.X509Extension(b'crlDistributionPoints', False, (','.join(uris)).encode('utf-8'))) + exts.append( + OpenSSL.crypto.X509Extension( + b"crlDistributionPoints", False, (",".join(uris)).encode("utf-8") + ) + ) + def set_extensions(x509, cert): """Setup X509 extensions.""" exts = [] - set_general_dict_extension(x509, exts, cert, 'basicConstraints', {'CA': False, 'pathlen': 0}) - set_general_list_extension(x509, exts, cert, 'keyUsage', [ - 'digitalSignature', 'nonRepudiation', 'keyEncipherment', 'dataEncipherment', - 'keyAgreement', 'keyCertSign', 'cRLSign', 'encipherOnly', 'decipherOnly']) - set_general_list_extension(x509, exts, cert, 'extendedKeyUsage', [ - 'serverAuth', 'clientAuth', 'codeSigning', 'emailProtection', 'timeStamping', - 'msCodeInd', 'msCodeCom', 'msCTLSign', 'msSGC', 'msEFS', 'nsSGC', 'OCSPSigning']) + set_general_dict_extension( + x509, exts, cert, "basicConstraints", {"CA": False, "pathlen": 0} + ) + set_general_list_extension( + x509, + exts, + cert, + "keyUsage", + [ + "digitalSignature", + "nonRepudiation", + "keyEncipherment", + "dataEncipherment", + "keyAgreement", + "keyCertSign", + "cRLSign", + "encipherOnly", + "decipherOnly", + ], + ) + set_general_list_extension( + x509, + exts, + cert, + "extendedKeyUsage", + [ + "serverAuth", + "clientAuth", + "codeSigning", + "emailProtection", + "timeStamping", + "msCodeInd", + "msCodeCom", + "msCTLSign", + "msSGC", + "msEFS", + "nsSGC", + "OCSPSigning", + ], + ) enable_subject_key_identifier_extension(x509, exts, cert) enable_authority_key_identifier_extension(x509, exts, cert) set_ocsp_extension(x509, exts, cert) @@ -364,19 +482,24 @@ def set_extensions(x509, cert): set_mongo_roles_extension(exts, cert) set_mongo_cluster_membership_extension(exts, cert) - ns_comment = cert.get('extensions', {}).get('nsComment') + ns_comment = cert.get("extensions", {}).get("nsComment") if ns_comment: - exts.append(OpenSSL.crypto.X509Extension(b'nsComment', False, ns_comment.encode('utf-8'))) + exts.append( + OpenSSL.crypto.X509Extension( + b"nsComment", False, ns_comment.encode("utf-8") + ) + ) if exts: x509.add_extensions(exts) + def sign_cert(x509, cert, key): """Sign the new certificate.""" - sig = idx(cert, 'hash', 'sha256') + sig = idx(cert, "hash", "sha256") - issuer = cert.get('Issuer', 'ca.pem') - if issuer == 'self': + issuer = cert.get("Issuer", "ca.pem") + if issuer == "self": x509.set_issuer(x509.get_subject()) x509.sign(key, sig) return @@ -385,45 +508,69 @@ def sign_cert(x509, cert, key): (signing_cert, signing_key) = load_authority_file(issuer) if not signing_key: - raise ValueError('No issuer available to sign with') + raise ValueError("No issuer available to sign with") x509.set_issuer(signing_cert.get_subject()) x509.sign(signing_key, sig) + def get_header_comment(cert): - if not cert.get('include_header', True): - return '' + if not cert.get("include_header", True): + return "" """Header comment for every generated file.""" comment = "# Autogenerated file, do not edit.\n" - comment = comment + '# Generate using jstests/ssl/x509/mkcert.py --config ' + CONFIGFILE - comment = comment + ' ' + cert['name'] + "\n#\n" - comment = comment + "# " + cert.get('description', '').replace("\n", "\n# ") + comment = ( + comment + "# Generate using jstests/ssl/x509/mkcert.py --config " + CONFIGFILE + ) + comment = comment + " " + cert["name"] + "\n#\n" + comment = comment + "# " + cert.get("description", "").replace("\n", "\n# ") comment = comment + "\n" return comment + def convert_cert_to_pkcs1(cert): """Reencodes the main certificate to use PKCS#1 private key encryption.""" src = make_filename(cert) - pswd = 'pass:' + cert['passphrase'] + pswd = "pass:" + cert["passphrase"] tmpcert = tempfile.mkstemp()[1] tmpkey = tempfile.mkstemp()[1] - subprocess.check_call(['openssl', 'x509', '-in', src, '-out', tmpcert]) - subprocess.check_call(['openssl', 'rsa', '-in', src, '-passin', pswd, '-out', tmpkey, '-aes256', '-passout', pswd]) - open(src, 'wt').write(get_header_comment(cert) + "\n" + open(tmpcert, 'rt').read() + open(tmpkey, 'rt').read()) + subprocess.check_call(["openssl", "x509", "-in", src, "-out", tmpcert]) + subprocess.check_call( + [ + "openssl", + "rsa", + "-in", + src, + "-passin", + pswd, + "-out", + tmpkey, + "-aes256", + "-passout", + pswd, + ] + ) + open(src, "wt").write( + get_header_comment(cert) + + "\n" + + open(tmpcert, "rt").read() + + open(tmpkey, "rt").read() + ) os.remove(tmpcert) os.remove(tmpkey) + def convert_cert_to_pkcs12(cert): """Makes a new copy of the cert/key pair using PKCS#12 encoding.""" - pkcs12 = cert.get('pkcs12') - if not pkcs12.get('passphrase'): - raise ValueError('PKCS#12 requires a passphrase') + pkcs12 = cert.get("pkcs12") + if not pkcs12.get("passphrase"): + raise ValueError("PKCS#12 requires a passphrase") src = make_filename(cert) - dest = idx(cert, 'output_path') + '/' + pkcs12.get('name', cert['name']) - ca = get_cert_path(cert['Issuer']) - passout = 'pass:' + pkcs12['passphrase'] + dest = idx(cert, "output_path") + "/" + pkcs12.get("name", cert["name"]) + ca = get_cert_path(cert["Issuer"]) + passout = "pass:" + pkcs12["passphrase"] args = [ "openssl", @@ -447,87 +594,107 @@ def convert_cert_to_pkcs12(cert): subprocess.check_call(args) + def create_cert(cert): """Create a new X509 certificate.""" x509 = OpenSSL.crypto.X509() key = make_key(cert) x509.set_pubkey(key) - x509.set_version(int(cert.get('version', 3)) - 1) + x509.set_version(int(cert.get("version", 3)) - 1) set_subject(x509, cert) set_validity(x509, cert) set_extensions(x509, cert) # Serial numbers 0..999 are reserved for fixed serial numbers. # Other will be assigned randomly. - x509.set_serial_number(cert.get('serial', random.randint(1000, 0x7FFFFFFF))) + x509.set_serial_number(cert.get("serial", random.randint(1000, 0x7FFFFFFF))) sign_cert(x509, cert, key) - passphrase = cert.get('passphrase', None) + passphrase = cert.get("passphrase", None) cipher = None if passphrase: - passphrase = passphrase.encode('utf-8') - cipher = 'aes256' + passphrase = passphrase.encode("utf-8") + cipher = "aes256" header = get_header_comment(cert) - if bool(cert.get('keyfile', False)) != bool(cert.get('crtfile', False)): + if bool(cert.get("keyfile", False)) != bool(cert.get("crtfile", False)): raise ValueError("Either include both keyfile and crtfile or neither") # The OCSP responder certificate needs to have the key and the pem file separated. # Since there are only a few cases where we need split key and crt files, and since we # sometimes need the unified pem file as well, we can always generate the pem file. - if cert.get('keyfile', False) and cert.get('crtfile', False): - keyfile = cert['keyfile'] - crtfile = cert['crtfile'] + if cert.get("keyfile", False) and cert.get("crtfile", False): + keyfile = cert["keyfile"] + crtfile = cert["crtfile"] - key_path_dict = {'output_path': cert['output_path'], 'name': keyfile} - crt_path_dict = {'output_path': cert['output_path'], 'name': crtfile} + key_path_dict = {"output_path": cert["output_path"], "name": keyfile} + crt_path_dict = {"output_path": cert["output_path"], "name": crtfile} - open(make_filename(crt_path_dict), 'wt').write( - header + - OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode('ascii')) + open(make_filename(crt_path_dict), "wt").write( + header + + OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode( + "ascii" + ) + ) - open(make_filename(key_path_dict), 'wt').write( - header + - OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key, cipher=cipher, passphrase=passphrase).decode('ascii')) + open(make_filename(key_path_dict), "wt").write( + header + + OpenSSL.crypto.dump_privatekey( + OpenSSL.crypto.FILETYPE_PEM, key, cipher=cipher, passphrase=passphrase + ).decode("ascii") + ) - open(make_filename(cert), 'wt').write( - header + - OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode('ascii') + - OpenSSL.crypto.dump_privatekey(OpenSSL.crypto.FILETYPE_PEM, key, cipher=cipher, passphrase=passphrase).decode('ascii')) + open(make_filename(cert), "wt").write( + header + + OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode( + "ascii" + ) + + OpenSSL.crypto.dump_privatekey( + OpenSSL.crypto.FILETYPE_PEM, key, cipher=cipher, passphrase=passphrase + ).decode("ascii") + ) - if cert.get('pkcs1'): + if cert.get("pkcs1"): convert_cert_to_pkcs1(cert) - if cert.get('pkcs12'): + if cert.get("pkcs12"): convert_cert_to_pkcs12(cert) + def check_special_case_keys(cert): """All special cases must contain three keys with an optional tags key""" keys = set(cert.keys()) - required_keys = {'name', 'description', 'Issuer'} - optional_keys = {'tags'} - allowed_tags = {'ecdsa', 'ocsp', 'responder', 'must-staple'} + required_keys = {"name", "description", "Issuer"} + optional_keys = {"tags"} + allowed_tags = {"ecdsa", "ocsp", "responder", "must-staple"} allowed_keys = required_keys.union(optional_keys) if not keys.issubset(allowed_keys): unexpected_keys = keys - allowed_keys - raise ValueError('Unexpected fields in special entry: ' + ", ".join(map(str, unexpected_keys))) + raise ValueError( + "Unexpected fields in special entry: " + + ", ".join(map(str, unexpected_keys)) + ) if not required_keys.issubset(keys): missing_keys = required_keys - keys - raise ValueError('Missing fields in special entry: ' + ", ".join(map(str, missing_keys))) + raise ValueError( + "Missing fields in special entry: " + ", ".join(map(str, missing_keys)) + ) if "tags" in keys: tags = set(cert["tags"]) if not tags.issubset(allowed_tags): unexpected_tags = tags - allowed_tags - raise ValueError('Unexpected tags: ' + ", ".join(map(str, unexpected_tags))) + raise ValueError("Unexpected tags: " + ", ".join(map(str, unexpected_tags))) + def check_for_ecdsa_in_tags(cert): - if not cert.get('tags') or not 'ecdsa' in cert['tags']: - raise ValueError('ECDSA special case certs must contain an ECDSA tag') + if not cert.get("tags") or not "ecdsa" in cert["tags"]: + raise ValueError("ECDSA special case certs must contain an ECDSA tag") + def process_client_multivalue_rdn(cert): """Special handling for client-multivalue-rdn.pem""" @@ -539,19 +706,58 @@ def process_client_multivalue_rdn(cert): pem = tempfile.mkstemp()[1] dest = make_filename(cert) - ca = get_cert_path(cert['Issuer']) + ca = get_cert_path(cert["Issuer"]) serial = str(random.randint(1000, 0x7FFFFFFF)) - subject = '/CN=client+OU=KernelUser+O=MongoDB/L=New York City+ST=New York+C=US' - subprocess.check_call(['openssl', 'req', '-new', '-nodes', '-multivalue-rdn', '-subj', subject, '-keyout', key, '-out', csr]) - subprocess.check_call(['openssl', 'rsa', '-in', key, '-out', rsa]) - subprocess.check_call(['openssl', 'x509', '-in', csr, '-out', pem, '-req', '-CA', ca, '-CAkey', ca, '-days', str(MAX_VALIDITY_PERIOD_DAYS), '-sha256', '-set_serial', serial]) + subject = "/CN=client+OU=KernelUser+O=MongoDB/L=New York City+ST=New York+C=US" + subprocess.check_call( + [ + "openssl", + "req", + "-new", + "-nodes", + "-multivalue-rdn", + "-subj", + subject, + "-keyout", + key, + "-out", + csr, + ] + ) + subprocess.check_call(["openssl", "rsa", "-in", key, "-out", rsa]) + subprocess.check_call( + [ + "openssl", + "x509", + "-in", + csr, + "-out", + pem, + "-req", + "-CA", + ca, + "-CAkey", + ca, + "-days", + str(MAX_VALIDITY_PERIOD_DAYS), + "-sha256", + "-set_serial", + serial, + ] + ) - open(dest, 'wt').write(get_header_comment(cert) + "\n" + open(pem, 'rt').read() + open(rsa, 'rt').read()) + open(dest, "wt").write( + get_header_comment(cert) + + "\n" + + open(pem, "rt").read() + + open(rsa, "rt").read() + ) os.remove(key) os.remove(csr) os.remove(rsa) os.remove(pem) + def convert_ecdsa_key_to_pkcs8(ec_key_file, pkcs8_key_file): """ Convert ECDSA key with explicit text header into a PEM-encoded PKCS#8 object @@ -559,40 +765,53 @@ def convert_ecdsa_key_to_pkcs8(ec_key_file, pkcs8_key_file): :param pkcs8_key_file: name of file to contain PEM-encoded PKCS#8 object describing the ECDSA key """ # wrap the PEM-encoded EC key with explicit text header into a PEM-encoded PKCS#8 object - pkcs8args = ['openssl', 'pkcs8', '-topk8', '-nocrypt', '-in', ec_key_file, '-out', pkcs8_key_file] + pkcs8args = [ + "openssl", + "pkcs8", + "-topk8", + "-nocrypt", + "-in", + ec_key_file, + "-out", + pkcs8_key_file, + ] subprocess.check_call(pkcs8args) + def process_ecdsa_cert(cert, pem, key, dest, filename, split_pem=True): """Convert the ECDSA key and write the public/private key pair key into a .pem file, optionally writing the - public key to a .crt file and private key to a .key file""" + public key to a .crt file and private key to a .key file""" # copy the public key to a temp crt file temp_cert_filename = tempfile.mkstemp()[1] - cert_filename = make_filename({'name': f"{filename}.crt"}) - shutil.copy(src = pem, dst = temp_cert_filename) + cert_filename = make_filename({"name": f"{filename}.crt"}) + shutil.copy(src=pem, dst=temp_cert_filename) # convert and create the temp key file temp_key_filename = tempfile.mkstemp()[1] convert_ecdsa_key_to_pkcs8(ec_key_file=key, pkcs8_key_file=temp_key_filename) # combine public and private key into a .pem file with no comments - open(dest, 'wt').write(open(pem, 'rt').read() + open(temp_key_filename, 'rt').read()) + open(dest, "wt").write( + open(pem, "rt").read() + open(temp_key_filename, "rt").read() + ) - if split_pem: # copy the temp files into .crt + .key files - cert_filename = make_filename({'name': f"{filename}.crt"}) - shutil.copy(src = temp_cert_filename, dst = cert_filename) - key_filename = make_filename({'name': f"{filename}.key"}) - shutil.copy(src = temp_key_filename, dst = key_filename) + if split_pem: # copy the temp files into .crt + .key files + cert_filename = make_filename({"name": f"{filename}.crt"}) + shutil.copy(src=temp_cert_filename, dst=cert_filename) + key_filename = make_filename({"name": f"{filename}.key"}) + shutil.copy(src=temp_key_filename, dst=key_filename) os.remove(temp_key_filename) os.remove(temp_cert_filename) + def process_ecdsa_ca(cert): """Create CA for ECDSA tree.""" check_special_case_keys(cert) check_for_ecdsa_in_tags(cert) - if cert['Issuer'] != 'self': - raise ValueError('ECDSA-CA should be self-signed') + if cert["Issuer"] != "self": + raise ValueError("ECDSA-CA should be self-signed") key = tempfile.mkstemp()[1] csr = tempfile.mkstemp()[1] @@ -601,35 +820,60 @@ def process_ecdsa_ca(cert): dest = make_filename(cert) serial = str(random.randint(1000, 0x7FFFFFFF)) - subject = '/C=US/ST=New York/L=New York City/O=MongoDB/OU=Kernel/CN=Kernel Test ESCDA CA/' + subject = ( + "/C=US/ST=New York/L=New York City/O=MongoDB/OU=Kernel/CN=Kernel Test ESCDA CA/" + ) - reqargs = ['openssl', 'req', '-new', '-key', key, '-out', csr, '-subj', subject] - x509args = ['openssl', 'x509', '-in', csr, '-out', pem, '-req', '-signkey', key, '-days', str(MAX_VALIDITY_PERIOD_DAYS), '-sha256', '-set_serial', serial] - ecparamargs = (['openssl', 'ecparam', '-name', 'prime256v1', '-genkey', '-out', key, '-noout'] - if "ocsp" in cert.get('tags', []) - else ['openssl', 'ecparam', '-name', 'prime256v1', '-genkey', '-out', key]) + reqargs = ["openssl", "req", "-new", "-key", key, "-out", csr, "-subj", subject] + x509args = [ + "openssl", + "x509", + "-in", + csr, + "-out", + pem, + "-req", + "-signkey", + key, + "-days", + str(MAX_VALIDITY_PERIOD_DAYS), + "-sha256", + "-set_serial", + serial, + ] + ecparamargs = ( + ["openssl", "ecparam", "-name", "prime256v1", "-genkey", "-out", key, "-noout"] + if "ocsp" in cert.get("tags", []) + else ["openssl", "ecparam", "-name", "prime256v1", "-genkey", "-out", key] + ) - reqargs = reqargs + ['-reqexts', 'v3_req'] + reqargs = reqargs + ["-reqexts", "v3_req"] extfile = tempfile.mkstemp()[1] - open(extfile, 'wt').write('basicConstraints=CA:TRUE\n') - x509args.append('-extfile') + open(extfile, "wt").write("basicConstraints=CA:TRUE\n") + x509args.append("-extfile") x509args.append(extfile) subprocess.check_call(ecparamargs) subprocess.check_call(reqargs) subprocess.check_call(x509args) - if "ocsp" in cert['name']: + if "ocsp" in cert["name"]: # given foo.pem, we'll generate foo.crt and foo.key as well - filename = re.search('(.*)\.pem', cert['name']).group(1) + filename = re.search("(.*)\.pem", cert["name"]).group(1) process_ecdsa_cert(cert, pem, key, dest, filename) else: - open(dest, 'wt').write(get_header_comment(cert) + "\n" + open(pem, 'rt').read() + open(key, 'rt').read()) + open(dest, "wt").write( + get_header_comment(cert) + + "\n" + + open(pem, "rt").read() + + open(key, "rt").read() + ) os.remove(key) os.remove(csr) os.remove(pem) + def process_ecdsa_leaf(cert): """Create leaf certificates for ECDSA tree.""" check_special_case_keys(cert) @@ -641,52 +885,82 @@ def process_ecdsa_leaf(cert): extfile = None dest = make_filename(cert) - ca = get_cert_path(cert['Issuer']) + ca = get_cert_path(cert["Issuer"]) serial = str(random.randint(1000, 0x7FFFFFFF)) - mode = 'client' if cert['name'] == 'ecdsa-client.pem' else 'server' - ou = 'Kernel' if mode == 'server' else 'KernelUser' - subject = '/C=US/ST=New York/L=New York City/O=MongoDB/OU=' + ou + '/CN=' + mode + mode = "client" if cert["name"] == "ecdsa-client.pem" else "server" + ou = "Kernel" if mode == "server" else "KernelUser" + subject = "/C=US/ST=New York/L=New York City/O=MongoDB/OU=" + ou + "/CN=" + mode - reqargs = ['openssl', 'req', '-new', '-key', key, '-out', csr, '-subj', subject] - x509args = ['openssl', 'x509', '-in', csr, '-out', pem, '-req', '-CA', ca, '-CAkey', ca, '-days', str(MAX_VALIDITY_PERIOD_DAYS), '-sha256', '-set_serial', serial] - if mode == 'server': - reqargs = reqargs + ['-reqexts', 'v3_req'] + reqargs = ["openssl", "req", "-new", "-key", key, "-out", csr, "-subj", subject] + x509args = [ + "openssl", + "x509", + "-in", + csr, + "-out", + pem, + "-req", + "-CA", + ca, + "-CAkey", + ca, + "-days", + str(MAX_VALIDITY_PERIOD_DAYS), + "-sha256", + "-set_serial", + serial, + ] + if mode == "server": + reqargs = reqargs + ["-reqexts", "v3_req"] extfile = tempfile.mkstemp()[1] - with open(extfile, 'wt') as f: - f.write('basicConstraints=CA:FALSE\n') - f.write('subjectAltName=DNS:localhost,IP:127.0.0.1\n') - f.write('subjectKeyIdentifier=hash\n') - key_usage = ('keyUsage=nonRepudiation,digitalSignature,keyEncipherment\n' - if "responder" in cert.get('tags', []) - else 'keyUsage=digitalSignature,keyEncipherment\n') + with open(extfile, "wt") as f: + f.write("basicConstraints=CA:FALSE\n") + f.write("subjectAltName=DNS:localhost,IP:127.0.0.1\n") + f.write("subjectKeyIdentifier=hash\n") + key_usage = ( + "keyUsage=nonRepudiation,digitalSignature,keyEncipherment\n" + if "responder" in cert.get("tags", []) + else "keyUsage=digitalSignature,keyEncipherment\n" + ) f.write(key_usage) - extended_key_usage = ('extendedKeyUsage=serverAuth,clientAuth,OCSPSigning\n' - if "responder" in cert.get('tags', []) - else 'extendedKeyUsage=serverAuth,clientAuth\n') + extended_key_usage = ( + "extendedKeyUsage=serverAuth,clientAuth,OCSPSigning\n" + if "responder" in cert.get("tags", []) + else "extendedKeyUsage=serverAuth,clientAuth\n" + ) f.write(extended_key_usage) - if cert.get('tags'): - if "ocsp" in cert['tags']: - if not "responder" in cert['tags']: - f.write('authorityInfoAccess = OCSP;URI:http://localhost:9001/power/level,OCSP;URI:http://localhost:8100/status\n') - if "must-staple" in cert['tags']: - f.write(f'{MUST_STAPLE_KEY_STR}={MUST_STAPLE_VALUE_STR}\n') - x509args.append('-extfile') + if cert.get("tags"): + if "ocsp" in cert["tags"]: + if not "responder" in cert["tags"]: + f.write( + "authorityInfoAccess = OCSP;URI:http://localhost:9001/power/level,OCSP;URI:http://localhost:8100/status\n" + ) + if "must-staple" in cert["tags"]: + f.write(f"{MUST_STAPLE_KEY_STR}={MUST_STAPLE_VALUE_STR}\n") + x509args.append("-extfile") x509args.append(extfile) - subprocess.check_call(['openssl', 'ecparam', '-name', 'prime256v1', '-genkey', '-out', key]) + subprocess.check_call( + ["openssl", "ecparam", "-name", "prime256v1", "-genkey", "-out", key] + ) subprocess.check_call(reqargs) subprocess.check_call(x509args) - if 'responder' in cert.get('tags', []): - # given foo.crt, we'll generate foo.pem and foo.key - filename = re.search('(.*)\.crt', cert['name']).group(1) - process_ecdsa_cert(cert, pem, key, dest, filename) - elif "ocsp" in cert.get('tags', []): - # given foo.pem, we'll regenerate foo.pem and delete foo.crt and foo.key - filename = re.search('(.*)\.pem', cert['name']).group(1) - process_ecdsa_cert(cert, pem, key, dest, filename, split_pem=False) + if "responder" in cert.get("tags", []): + # given foo.crt, we'll generate foo.pem and foo.key + filename = re.search("(.*)\.crt", cert["name"]).group(1) + process_ecdsa_cert(cert, pem, key, dest, filename) + elif "ocsp" in cert.get("tags", []): + # given foo.pem, we'll regenerate foo.pem and delete foo.crt and foo.key + filename = re.search("(.*)\.pem", cert["name"]).group(1) + process_ecdsa_cert(cert, pem, key, dest, filename, split_pem=False) else: - open(dest, 'wt').write(get_header_comment(cert) + "\n" + open(pem, 'rt').read() + open(key, 'rt').read()) + open(dest, "wt").write( + get_header_comment(cert) + + "\n" + + open(pem, "rt").read() + + open(key, "rt").read() + ) os.remove(key) os.remove(csr) @@ -694,82 +968,131 @@ def process_ecdsa_leaf(cert): if extfile: os.remove(extfile) + def process_cert(cert): """Process a certificate.""" - print('Processing certificate: ' + cert['name']) + print("Processing certificate: " + cert["name"]) - if cert['name'] == 'client-multivalue-rdn.pem': + if cert["name"] == "client-multivalue-rdn.pem": process_client_multivalue_rdn(cert) return - if cert['name'] in ['ecdsa-ca.pem', 'ecdsa-ca-ocsp.pem']: + if cert["name"] in ["ecdsa-ca.pem", "ecdsa-ca-ocsp.pem"]: process_ecdsa_ca(cert) return - if cert['name'] in ['ecdsa-client.pem', 'ecdsa-server.pem', 'ecdsa-server-ocsp.pem', - 'ecdsa-server-ocsp-mustStaple.pem', 'ecdsa-ocsp-responder.crt']: + if cert["name"] in [ + "ecdsa-client.pem", + "ecdsa-server.pem", + "ecdsa-server-ocsp.pem", + "ecdsa-server-ocsp-mustStaple.pem", + "ecdsa-ocsp-responder.crt", + ]: process_ecdsa_leaf(cert) return - append_certs = cert.get('append_cert', []) + append_certs = cert.get("append_cert", []) if isinstance(append_certs, str): append_certs = [append_certs] - subject = cert.get('Subject'); - explicit_empty_subject = cert.get('explicit_subject', False) and not subject; - + subject = cert.get("Subject") + explicit_empty_subject = cert.get("explicit_subject", False) and not subject if subject or explicit_empty_subject: create_cert(cert) elif append_certs: # Pure composing certificate. Start with a basic preamble. - open(make_filename(cert), 'wt').write(get_header_comment(cert) + "\n") + open(make_filename(cert), "wt").write(get_header_comment(cert) + "\n") else: - raise ValueError("Certificate definitions must have at least one of 'Subject' and/or 'append_cert'") + raise ValueError( + "Certificate definitions must have at least one of 'Subject' and/or 'append_cert'" + ) for append_cert in append_certs: x509 = load_authority_file(append_cert)[0] if not x509: - raise ValueError("Unable to find certificate '" + append_cert + "' to append") - header = "# Certificate from " + append_cert + "\n" if cert.get('include_header', True) else "" - open(make_filename(cert), 'at').write( - header + - OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode('ascii')) + raise ValueError( + "Unable to find certificate '" + append_cert + "' to append" + ) + header = ( + "# Certificate from " + append_cert + "\n" + if cert.get("include_header", True) + else "" + ) + open(make_filename(cert), "at").write( + header + + OpenSSL.crypto.dump_certificate(OpenSSL.crypto.FILETYPE_PEM, x509).decode( + "ascii" + ) + ) + def parse_command_line(): """Accept a named config file.""" # pylint: disable=global-statement global CONFIGFILE - parser = argparse.ArgumentParser(description='X509 Test Certificate Generator') - parser.add_argument('--config', help='Certificate definition file', type=str, default=CONFIGFILE) - parser.add_argument('cert', nargs='*', help='Certificate to generate (blank for all)') + parser = argparse.ArgumentParser(description="X509 Test Certificate Generator") + parser.add_argument( + "--config", help="Certificate definition file", type=str, default=CONFIGFILE + ) + parser.add_argument( + "cert", nargs="*", help="Certificate to generate (blank for all)" + ) args = parser.parse_args() CONFIGFILE = args.config return args.cert or [] + def validate_config(): """Perform basic start up time validation of config file.""" - if not glbl('output_path'): - raise ValueError('global.output_path required') + if not glbl("output_path"): + raise ValueError("global.output_path required") - if not CONFIG.get('certs'): - raise ValueError('No certificates defined') + if not CONFIG.get("certs"): + raise ValueError("No certificates defined") - permissible = ['name', 'description', 'Subject', 'Issuer', 'append_cert', 'extensions', 'passphrase', 'output_path', 'hash', 'include_header', 'key_type', 'keyfile', 'crtfile', 'explicit_subject', 'serial', 'not_before', 'not_after', 'pkcs1', 'pkcs12', 'version', 'tags'] - for cert in CONFIG.get('certs', []): + permissible = [ + "name", + "description", + "Subject", + "Issuer", + "append_cert", + "extensions", + "passphrase", + "output_path", + "hash", + "include_header", + "key_type", + "keyfile", + "crtfile", + "explicit_subject", + "serial", + "not_before", + "not_after", + "pkcs1", + "pkcs12", + "version", + "tags", + ] + for cert in CONFIG.get("certs", []): keys = cert.keys() - if not 'name' in keys: - raise ValueError('Name field required for all certificate definitions') - if not 'description' in keys: - raise ValueError('description field required for all certificate definitions') + if not "name" in keys: + raise ValueError("Name field required for all certificate definitions") + if not "description" in keys: + raise ValueError( + "description field required for all certificate definitions" + ) for key in keys: if not key in permissible: - raise ValueError("Unknown element '" + key + "' in certificate: " + cert['name']) + raise ValueError( + "Unknown element '" + key + "' in certificate: " + cert["name"] + ) + def select_items(names): """Select all certificates requested and their leaf nodes.""" if not names: - return CONFIG['certs'] + return CONFIG["certs"] # Temporarily treat like dictionary for easy de-duping. ret = {} @@ -777,24 +1100,37 @@ def select_items(names): for name in names: cert = find_certificate_definition(name) if not cert: - raise ValueError('Unknown certificate: ' + name) + raise ValueError("Unknown certificate: " + name) ret[name] = cert last_count = -1 while last_count != len(ret): last_count = len(ret) # Add any certs who use our current set as an issuer. - ret.update({cert['name']: cert for cert in CONFIG['certs'] if cert.get('Issuer') in names}) + ret.update( + { + cert["name"]: cert + for cert in CONFIG["certs"] + if cert.get("Issuer") in names + } + ) # Add any certs who are composed of our current set. - ret.update({cert['name']: cert for cert in CONFIG['certs'] if [True for append in cert.get('append_cert', []) if append in names]}) + ret.update( + { + cert["name"]: cert + for cert in CONFIG["certs"] + if [True for append in cert.get("append_cert", []) if append in names] + } + ) # Repeat until no new names are added. names = ret.keys() return ret.values() + def sort_items(items): """Ensure that leaves are produced after roots (as much as possible within one file).""" - all_names = [cert['name'] for cert in items] + all_names = [cert["name"] for cert in items] all_names.sort() processed_names = [] @@ -802,33 +1138,43 @@ def sort_items(items): while len(ret) != len(items): for cert in items: # only concern ourselves with prependents in this config file. - unmet_prependents = [name for name in cert.get('append_certs', []) if (name in all_names) and (not name in processed_names)] + unmet_prependents = [ + name + for name in cert.get("append_certs", []) + if (name in all_names) and (not name in processed_names) + ] # Self-signed, signed by someone in ret already, or signed externally - issuer = cert.get('Issuer') - has_issuer = (issuer == 'self') or (issuer in processed_names) or (issuer not in all_names) + issuer = cert.get("Issuer") + has_issuer = ( + (issuer == "self") + or (issuer in processed_names) + or (issuer not in all_names) + ) if has_issuer and not unmet_prependents: ret.append(cert) - processed_names.append(cert['name']) + processed_names.append(cert["name"]) return ret + def main(): """Go go go.""" # pylint: disable=global-statement global CONFIG items_to_process = parse_command_line() - CONFIG = yaml.load(open(CONFIGFILE, 'r'), Loader=yaml.FullLoader) + CONFIG = yaml.load(open(CONFIGFILE, "r"), Loader=yaml.FullLoader) validate_config() items = select_items(items_to_process) items = sort_items(items) for item in items: process_cert(item) filename = make_filename(item) - mkdigest.make_digest(filename, 'cert', 'sha256') - mkdigest.make_digest(filename, 'cert', 'sha1') + mkdigest.make_digest(filename, "cert", "sha256") + mkdigest.make_digest(filename, "cert", "sha1") -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/jstests/ssl/x509/mkdigest.py b/jstests/ssl/x509/mkdigest.py index a7b0bdb448d..d52d061eeeb 100755 --- a/jstests/ssl/x509/mkdigest.py +++ b/jstests/ssl/x509/mkdigest.py @@ -3,40 +3,52 @@ This script calculates and writes out digests for x509 certificates/CRLs. Invoke as `mkdigest.py [filename2 ...]` """ + import argparse import OpenSSL import cryptography.hazmat.primitives.hashes as hashes -DIGEST_NAME_TO_HASH = {'sha256': hashes.SHA256(), 'sha1': hashes.SHA1()} +DIGEST_NAME_TO_HASH = {"sha256": hashes.SHA256(), "sha1": hashes.SHA1()} + def make_digest(filename, item_type, digest_type): """Calculate the given digest of the certificate/CRL passed in and write it out to .digest.""" assert item_type in {"cert", "crl"} assert digest_type in {"sha256", "sha1"} - with open(filename, 'r') as f: + with open(filename, "r") as f: data = f.read() - if item_type == 'cert': + if item_type == "cert": cert = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, data) rawdigest = cert.digest(digest_type) - digest = rawdigest.decode('utf8').replace(':', '') - elif item_type == 'crl': + digest = rawdigest.decode("utf8").replace(":", "") + elif item_type == "crl": crl = OpenSSL.crypto.load_crl(OpenSSL.crypto.FILETYPE_PEM, data) rawdigest = crl.to_cryptography().fingerprint(DIGEST_NAME_TO_HASH[digest_type]) digest = rawdigest.hex().upper() - with open(filename + '.digest.' + digest_type, 'w') as f: + with open(filename + ".digest." + digest_type, "w") as f: f.write(digest) + def main(): - parser = argparse.ArgumentParser(description='X509 Digest Generator') - parser.add_argument('type', choices={"cert", "crl"}, help='Type of X509 object to generate digest for') - parser.add_argument('digest', choices={"sha1", "sha256"}, help='Algorithm for digest') - parser.add_argument('filename', nargs='+', help='Path of X509 file to generate digest for') + parser = argparse.ArgumentParser(description="X509 Digest Generator") + parser.add_argument( + "type", + choices={"cert", "crl"}, + help="Type of X509 object to generate digest for", + ) + parser.add_argument( + "digest", choices={"sha1", "sha256"}, help="Algorithm for digest" + ) + parser.add_argument( + "filename", nargs="+", help="Path of X509 file to generate digest for" + ) args = parser.parse_args() for fname in args.filename: make_digest(fname, args.type, args.digest) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/jstests/ssl_linear/windows_castore_cleanup.py b/jstests/ssl_linear/windows_castore_cleanup.py index ff9de13219c..342ad03a12d 100644 --- a/jstests/ssl_linear/windows_castore_cleanup.py +++ b/jstests/ssl_linear/windows_castore_cleanup.py @@ -2,6 +2,7 @@ import subprocess import sys import re + def findMongoCertsFromStore(store): command = ["certutil", "-store", store] subject_pattern = re.compile(r"Subject:.*O=MongoDB") @@ -21,21 +22,27 @@ def findMongoCertsFromStore(store): cns.append(cn_match.group(1)) return cns + def deleteCertsByCNFromStore(store, cns): command = ["certutil", "-delstore", "-f", store, "cn"] for cn in cns: command[4] = cn try: - print(f"Deleting 'CN={cn}' from the '{store}' certificate store:\n\t{' ' .join(command)}") + print( + f"Deleting 'CN={cn}' from the '{store}' certificate store:\n\t{' ' .join(command)}" + ) subprocess.check_call(command, shell=True) except subprocess.CalledProcessError as e: print(f"Command {command} failed with error: {e}", file=sys.stderr) sys.exit(1) + my_cns = findMongoCertsFromStore("My") root_cns = findMongoCertsFromStore("Root") if my_cns + root_cns: - print(f"Unexpected MongoDB certs found on host. Clearing them from the system cert stores.") + print( + f"Unexpected MongoDB certs found on host. Clearing them from the system cert stores." + ) deleteCertsByCNFromStore("My", my_cns) -deleteCertsByCNFromStore("Root", root_cns) \ No newline at end of file +deleteCertsByCNFromStore("Root", root_cns) diff --git a/src/mongo/base/generate_error_codes.py b/src/mongo/base/generate_error_codes.py index 21da251a314..29ceb705184 100755 --- a/src/mongo/base/generate_error_codes.py +++ b/src/mongo/base/generate_error_codes.py @@ -33,7 +33,7 @@ import argparse import sys import yaml -help_epilog=""" +help_epilog = """ The error_codes_spec YAML document is a mapping containing two toplevel fields: `error_categories`: sequence of string - The error category names @@ -47,32 +47,39 @@ The error_codes_spec YAML document is a mapping containing two toplevel fields: the ErrorCode. """ + def init_parser(): global parser parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, - description=__doc__, - epilog=help_epilog) - parser.add_argument('--verbose', - action='store_true', - help='extra debug logging to stderr') - parser.add_argument('error_codes_spec', - help='YAML file describing error codes and categories') - parser.add_argument('template_file', - help='Cheetah template file') - parser.add_argument('output_file') + formatter_class=argparse.RawDescriptionHelpFormatter, + description=__doc__, + epilog=help_epilog, + ) + parser.add_argument( + "--verbose", action="store_true", help="extra debug logging to stderr" + ) + parser.add_argument( + "error_codes_spec", help="YAML file describing error codes and categories" + ) + parser.add_argument("template_file", help="Cheetah template file") + parser.add_argument("output_file") + verbose = False def render_template(template_path, **kw): - '''Renders the template file located at template_path, using the variables defined by kw, and - returns the result as a string''' + """Renders the template file located at template_path, using the variables defined by kw, and + returns the result as a string""" template = Template.compile( file=template_path, - compilerSettings=dict(directiveStartToken="//#", directiveEndToken="//#", - commentStartToken="//##"), baseclass=dict, useCache=False) + compilerSettings=dict( + directiveStartToken="//#", directiveEndToken="//#", commentStartToken="//##" + ), + baseclass=dict, + useCache=False, + ) return str(template(**kw)) @@ -83,13 +90,17 @@ class ErrorCode: self.extra = extra self.extraIsOptional = extraIsOptional if extra: - split = extra.split('::') + split = extra.split("::") if not split[0]: - die("Error for %s with extra info %s: fully qualified namespaces aren't supported" % - (name, extra)) + die( + "Error for %s with extra info %s: fully qualified namespaces aren't supported" + % (name, extra) + ) if split[0] == "mongo": - die("Error for %s with extra info %s: don't include the mongo namespace" % (name, - extra)) + die( + "Error for %s with extra info %s: don't include the mongo namespace" + % (name, extra) + ) if len(split) > 1: self.extra_class = split.pop() self.extra_ns = "::".join(split) @@ -104,6 +115,7 @@ class ErrorClass: self.name = name self.codes = codes + def main(): init_parser() parsed = parser.parse_args() @@ -119,51 +131,55 @@ def main(): # Render the templates to the output files. if verbose: - print(f'rendering {template_file} => {output_file}') - text = render_template(template_file, - codes=error_codes, - categories=error_classes, - ) - with open(output_file, 'w') as outfile: + print(f"rendering {template_file} => {output_file}") + text = render_template( + template_file, + codes=error_codes, + categories=error_classes, + ) + with open(output_file, "w") as outfile: outfile.write(text) + def die(message=None): sys.stderr.write(message or "Fatal error\n") sys.exit(1) + def usage(message=None): parser.error(message) # writes a usage message and exits the program dies + def parse_error_definitions_from_file(errors_filename): error_codes = [] error_classes = [] - with open(errors_filename, 'r') as errors_file: + with open(errors_filename, "r") as errors_file: doc = yaml.safe_load(errors_file) if verbose: yaml.dump(doc, sys.stderr) cats = {} - for v in doc['error_categories']: + for v in doc["error_categories"]: cats[v] = [] - for v in doc['error_codes']: + for v in doc["error_codes"]: assert type(v) is dict - name, code = v['name'], v['code'] + name, code = v["name"], v["code"] extraIsOptional = False - if 'extraIsOptional' in v: - extraIsOptional = v['extraIsOptional'] + if "extraIsOptional" in v: + extraIsOptional = v["extraIsOptional"] - if 'categories' in v: - for cat in v['categories']: - assert cat in cats, f'invalid category {cat} for code {name}' + if "categories" in v: + for cat in v["categories"]: + assert cat in cats, f"invalid category {cat} for code {name}" cats[cat].append(name) kw = {} - if 'extra' in v: - kw['extra'] = v['extra'] + if "extra" in v: + kw["extra"] = v["extra"] error_codes.append(ErrorCode(name, code, **kw, extraIsOptional=extraIsOptional)) @@ -174,6 +190,7 @@ def parse_error_definitions_from_file(errors_filename): return error_codes, error_classes + def check_for_conflicts(error_codes, error_classes): failed = has_duplicate_error_codes(error_codes) if has_duplicate_error_classes(error_classes): @@ -193,7 +210,9 @@ def has_duplicate_error_codes(error_codes): for curr in sorted_by_name[1:]: if curr.name == prev.name: sys.stdout.write( - 'Duplicate name %s with codes %s and %s\n' % (curr.name, curr.code, prev.code)) + "Duplicate name %s with codes %s and %s\n" + % (curr.name, curr.code, prev.code) + ) failed = True prev = curr @@ -201,7 +220,9 @@ def has_duplicate_error_codes(error_codes): for curr in sorted_by_code[1:]: if curr.code == prev.code: sys.stdout.write( - 'Duplicate code %s with names %s and %s\n' % (curr.code, curr.name, prev.name)) + "Duplicate code %s with names %s and %s\n" + % (curr.code, curr.name, prev.name) + ) failed = True prev = curr @@ -215,7 +236,7 @@ def has_duplicate_error_classes(error_classes): prev_name = names[0] for name in names[1:]: if prev_name == name: - sys.stdout.write('Duplicate error class name %s\n' % name) + sys.stdout.write("Duplicate error class name %s\n" % name) failed = True prev_name = name return failed @@ -229,11 +250,13 @@ def has_missing_error_codes(error_codes, error_classes): try: code_names[name].categories.append(category.name) except KeyError: - sys.stdout.write('Undeclared error code %s in class %s\n' % (name, category.name)) + sys.stdout.write( + "Undeclared error code %s in class %s\n" % (name, category.name) + ) failed = True return failed -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/src/mongo/db/auth/builtin_roles_gen.py b/src/mongo/db/auth/builtin_roles_gen.py index 1d51d995943..25c5cd79edd 100755 --- a/src/mongo/db/auth/builtin_roles_gen.py +++ b/src/mongo/db/auth/builtin_roles_gen.py @@ -33,7 +33,7 @@ import argparse import sys import yaml -help_epilog=""" +help_epilog = """ The builtin_roles_spec YAML document is a mapping containing one toplevel field: `roles` is a mapping for role name to role definitions. @@ -70,48 +70,64 @@ The builtin_roles_spec YAML document is a mapping containing one toplevel field: 'tenant': Only applies to non-system tenants in multi-tenancy. """ + def init_parser(): parser = argparse.ArgumentParser( - formatter_class=argparse.RawDescriptionHelpFormatter, - description=__doc__, - epilog=help_epilog) - parser.add_argument('--verbose', - action='store_true', - help='extra debug logging to stderr') - parser.add_argument('builtin_roles_spec', - help='YAML file describing builtin roles') - parser.add_argument('template_file', - help='Cheetah template file') - parser.add_argument('output_file') + formatter_class=argparse.RawDescriptionHelpFormatter, + description=__doc__, + epilog=help_epilog, + ) + parser.add_argument( + "--verbose", action="store_true", help="extra debug logging to stderr" + ) + parser.add_argument("builtin_roles_spec", help="YAML file describing builtin roles") + parser.add_argument("template_file", help="Cheetah template file") + parser.add_argument("output_file") return parser + def render_template(template_path, **kw): - '''Renders the template file located at template_path, using the variables defined by kw, and - returns the result as a string''' + """Renders the template file located at template_path, using the variables defined by kw, and + returns the result as a string""" template = Template.compile( file=template_path, - compilerSettings=dict(directiveStartToken="//#", directiveEndToken="//#", - commentStartToken="//##"), baseclass=dict, useCache=False) + compilerSettings=dict( + directiveStartToken="//#", directiveEndToken="//#", commentStartToken="//##" + ), + baseclass=dict, + useCache=False, + ) return str(template(**kw)) + def check_allowed_fields(mapping, allowed): for field in mapping: if field not in allowed: raise Exception("Unknown field '%s' in %r" % (field, mapping)) + def check_required_fields(mapping, required): for field in required: if field not in mapping: raise Exception("Missing required field '%s' in %r" % (field, mapping)) + def assert_str_field(name, value): if type(value) is not str: - raise Exception("Invalid type for string field '%s', got '%s' ': %r" % (name, type(value), value)) + raise Exception( + "Invalid type for string field '%s', got '%s' ': %r" + % (name, type(value), value) + ) + def assert_list_field(name, value): if type(value) is not list: - raise Exception("Invalid type for list field '%s', got '%s' ': %r" % (name, type(value), value)) + raise Exception( + "Invalid type for list field '%s', got '%s' ': %r" + % (name, type(value), value) + ) + def get_nonempty_str_field(mapping, fieldName): assert_str_field(fieldName, mapping[fieldName]) @@ -119,6 +135,7 @@ def get_nonempty_str_field(mapping, fieldName): raise Exception("Field '%s' value must be a non-empty string" % (fieldName)) return mapping[fieldName] + class InheritedRole: def __init__(self, spec): self.db = None @@ -127,63 +144,84 @@ class InheritedRole: self.role = spec else: if type(spec) is not dict: - raise Exception('Inherited role must be either a simple name, or a role/db tuple, got: %r' % (spec)) - check_allowed_fields(spec, ['role', 'db']) - check_required_fields(spec, ['role']) - self.role = get_nonempty_str_field(spec, 'role') - if 'db' in spec: - self.db = get_nonempty_str_field(spec, 'db') + raise Exception( + "Inherited role must be either a simple name, or a role/db tuple, got: %r" + % (spec) + ) + check_allowed_fields(spec, ["role", "db"]) + check_required_fields(spec, ["role"]) + self.role = get_nonempty_str_field(spec, "role") + if "db" in spec: + self.db = get_nonempty_str_field(spec, "db") + class Privilege: def __init__(self, spec): + check_allowed_fields( + spec, + ["matchType", "db", "collection", "system_buckets", "actions", "tenancy"], + ) + check_required_fields(spec, ["matchType", "actions"]) - check_allowed_fields(spec, ['matchType', 'db', 'collection', 'system_buckets', 'actions', 'tenancy']) - check_required_fields(spec, ['matchType', 'actions']) - - self.matchType = get_nonempty_str_field(spec, 'matchType') + self.matchType = get_nonempty_str_field(spec, "matchType") self.db = None self.collection = None self.system_buckets = None self.actions = [] - self.tenancy = 'any' + self.tenancy = "any" - db_valid_types = ['database', 'exact_namespace', 'system_buckets', 'any_system_buckets_in_db'] - if 'db' in spec: + db_valid_types = [ + "database", + "exact_namespace", + "system_buckets", + "any_system_buckets_in_db", + ] + if "db" in spec: if self.matchType in db_valid_types: - self.db = get_nonempty_str_field(spec, 'db') + self.db = get_nonempty_str_field(spec, "db") else: - raise Exception("db field is not valid for matchType: %s" % (self.matchType)) + raise Exception( + "db field is not valid for matchType: %s" % (self.matchType) + ) - coll_valid_types = ['collection', 'exact_namespace'] + coll_valid_types = ["collection", "exact_namespace"] if self.matchType in coll_valid_types: - check_required_fields(spec, ['collection']) - self.collection = get_nonempty_str_field(spec, 'collection') - elif 'collection' in spec: - raise Exception("collection field is not valid for matchType: %s" % (self.matchType)) + check_required_fields(spec, ["collection"]) + self.collection = get_nonempty_str_field(spec, "collection") + elif "collection" in spec: + raise Exception( + "collection field is not valid for matchType: %s" % (self.matchType) + ) - buckets_valid_types = ['system_buckets', 'system_buckets_in_any_db'] + buckets_valid_types = ["system_buckets", "system_buckets_in_any_db"] if self.matchType in buckets_valid_types: - check_required_fields(spec, ['system_buckets']) - db.system_buckets = get_nonempty_str_field(spec, 'system_buckets') - elif 'system_buckets' in spec: - raise Exception("system_buckets field is not valid for matchType: %s" % (self.matchType)) + check_required_fields(spec, ["system_buckets"]) + db.system_buckets = get_nonempty_str_field(spec, "system_buckets") + elif "system_buckets" in spec: + raise Exception( + "system_buckets field is not valid for matchType: %s" % (self.matchType) + ) - assert_list_field('actions', spec['actions']) - for action in spec['actions']: + assert_list_field("actions", spec["actions"]) + for action in spec["actions"]: if type(action) is list: for subaction in action: - assert_str_field('actions', subaction) + assert_str_field("actions", subaction) self.actions.append(subaction) else: - assert_str_field('actions', action) + assert_str_field("actions", action) self.actions.append(action) - if 'tenancy' in spec: - assert_str_field('tenancy', spec['tenancy']) - tenancy_options = ['any', 'single', 'multi', 'system', 'tenant'] - if spec['tenancy'] not in tenancy_options: - raise Exception("Invalid value for enum field 'tenancy', got '%s', expeted one of %r" % (spec['tenancy'], tenancy_options)) - self.tenancy = spec['tenancy'] + if "tenancy" in spec: + assert_str_field("tenancy", spec["tenancy"]) + tenancy_options = ["any", "single", "multi", "system", "tenant"] + if spec["tenancy"] not in tenancy_options: + raise Exception( + "Invalid value for enum field 'tenancy', got '%s', expeted one of %r" + % (spec["tenancy"], tenancy_options) + ) + self.tenancy = spec["tenancy"] + class BuiltinRole: def __init__(self, name, spec): @@ -192,36 +230,40 @@ class BuiltinRole: self.roles = [] self.privileges = [] - check_allowed_fields(spec, ['adminOnly', 'roles', 'privileges']) + check_allowed_fields(spec, ["adminOnly", "roles", "privileges"]) - if 'adminOnly' in spec: - if type(spec['adminOnly']) is not bool: - raise Exception('adminOnly must be a bool, got: %r' % (spec['adminOnly'])) - self.adminOnly = spec['adminOnly'] + if "adminOnly" in spec: + if type(spec["adminOnly"]) is not bool: + raise Exception( + "adminOnly must be a bool, got: %r" % (spec["adminOnly"]) + ) + self.adminOnly = spec["adminOnly"] - if 'roles' in spec: - assert_list_field('roles', spec['roles']) - for role in spec['roles']: + if "roles" in spec: + assert_list_field("roles", spec["roles"]) + for role in spec["roles"]: self.roles.append(InheritedRole(role)) - if 'privileges' in spec: - assert_list_field('privileges', spec['privileges']) - for priv in spec['privileges']: + if "privileges" in spec: + assert_list_field("privileges", spec["privileges"]) + for priv in spec["privileges"]: self.privileges.append(Privilege(priv)) + def parse_builtin_role_definitions_from_file(roles_filename, verbose=False): roles = [] - with open(roles_filename, 'r') as roles_file: + with open(roles_filename, "r") as roles_file: doc = yaml.safe_load(roles_file) if verbose: yaml.dump(doc, sys.stderr) - for roleName in doc['roles']: - roles.append(BuiltinRole(roleName, doc['roles'][roleName])) + for roleName in doc["roles"]: + roles.append(BuiltinRole(roleName, doc["roles"][roleName])) return roles + def main(): parsed = init_parser().parse_args() verbose = parsed.verbose @@ -229,14 +271,17 @@ def main(): output_file = parsed.output_file # Parse and validate builtin_roles.yml - builtin_roles = parse_builtin_role_definitions_from_file(parsed.builtin_roles_spec, verbose) + builtin_roles = parse_builtin_role_definitions_from_file( + parsed.builtin_roles_spec, verbose + ) # Render the templates to the output files. if verbose: - print(f'rendering {template_file} => {output_file}') + print(f"rendering {template_file} => {output_file}") text = render_template(template_file, roles=builtin_roles) - with open(output_file, 'w') as outfile: + with open(output_file, "w") as outfile: outfile.write(text) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/src/mongo/db/concurrency/lock_gdb_test.py b/src/mongo/db/concurrency/lock_gdb_test.py index 7e1ef5b63b9..8099d2555c8 100644 --- a/src/mongo/db/concurrency/lock_gdb_test.py +++ b/src/mongo/db/concurrency/lock_gdb_test.py @@ -1,16 +1,15 @@ -"""Script to be invoked by GDB for testing lock manager pretty printer. -""" +"""Script to be invoked by GDB for testing lock manager pretty printer.""" import traceback import gdb try: - gdb.execute('break main') - gdb.execute('run') - gdb_type = lookup_type('mongo::LockManager') - assert gdb_type is not None, 'Failed to lookup type mongo::LockManager' - gdb.write('TEST PASSED\n') + gdb.execute("break main") + gdb.execute("run") + gdb_type = lookup_type("mongo::LockManager") + assert gdb_type is not None, "Failed to lookup type mongo::LockManager" + gdb.write("TEST PASSED\n") except Exception: - gdb.write('TEST FAILED -- {!s}\n'.format(traceback.format_exc())) - gdb.execute('quit 1', to_string=True) + gdb.write("TEST FAILED -- {!s}\n".format(traceback.format_exc())) + gdb.execute("quit 1", to_string=True) diff --git a/src/mongo/db/fts/generate_stop_words.py b/src/mongo/db/fts/generate_stop_words.py index 6f4504933e5..c6c24b88d19 100644 --- a/src/mongo/db/fts/generate_stop_words.py +++ b/src/mongo/db/fts/generate_stop_words.py @@ -7,7 +7,11 @@ def read_stop_words(file_path): print(f"Warning: File {file_path} does not exist. Skipping.") return [] with open(file_path, "r", encoding="utf-8") as f: - return [line.strip() for line in f if line.strip() and not line.strip().startswith("#")] + return [ + line.strip() + for line in f + if line.strip() and not line.strip().startswith("#") + ] def generate(source, language_files): @@ -44,4 +48,4 @@ if __name__ == "__main__": source_file = sys.argv[-1] language_files = sys.argv[1:-1] - generate(source_file, language_files) \ No newline at end of file + generate(source_file, language_files) diff --git a/src/mongo/db/fts/unicode/gen_casefold_map.py b/src/mongo/db/fts/unicode/gen_casefold_map.py index 740a08d471d..d0f6bc0b571 100644 --- a/src/mongo/db/fts/unicode/gen_casefold_map.py +++ b/src/mongo/db/fts/unicode/gen_casefold_map.py @@ -3,8 +3,7 @@ import os import sys -from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, \ - include +from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, include def generate(unicode_casefold_file, target): @@ -14,7 +13,7 @@ def generate(unicode_casefold_file, target): The case folding function contains a switch statement with cases for every Unicode codepoint that has a case folding mapping. """ - out = open(target, "w", encoding='utf-8') + out = open(target, "w", encoding="utf-8") out.write(getCopyrightNotice()) out.write(include("mongo/db/fts/unicode/codepoints.h")) @@ -23,30 +22,30 @@ def generate(unicode_casefold_file, target): case_mappings = {} - cf_file = open(unicode_casefold_file, 'r', encoding='utf-8') + cf_file = open(unicode_casefold_file, "r", encoding="utf-8") for line in cf_file: # Filter out blank lines and lines that start with # - data = line[:line.find('#')] - if(data == ""): + data = line[: line.find("#")] + if data == "": continue # Parse the data on the line values = data.split("; ") - assert(len(values) == 4) + assert len(values) == 4 status = values[1] - if status == 'C' or status == 'S': + if status == "C" or status == "S": # We only include the "Common" and "Simple" mappings. "Full" case # folding mappings expand certain letters to multiple codepoints, # which we currently do not support. original_codepoint = int(values[0], 16) - codepoint_mapping = int(values[2], 16) + codepoint_mapping = int(values[2], 16) case_mappings[original_codepoint] = codepoint_mapping turkishMapping = { 0x49: 0x131, # I -> ı - 0x130: 0x069, # İ -> i + 0x130: 0x069, # İ -> i } out.write( @@ -60,7 +59,8 @@ def generate(unicode_casefold_file, target): return codepoint; } - switch (codepoint) {\n""") + switch (codepoint) {\n""" + ) mappings_list = [] @@ -76,17 +76,21 @@ def generate(unicode_casefold_file, target): sorted_mappings = sorted(mappings_list, key=lambda mapping: mapping[0]) for mapping in sorted_mappings: - if mapping[0] <= 0x7f: + if mapping[0] <= 0x7F: continue # ascii is special cased above. if mapping[0] in turkishMapping: - out.write("case 0x%x: return mode == CaseFoldMode::kTurkish ? 0x%x : 0x%x;\n" % - (mapping[0], turkishMapping[mapping[0]], mapping[1])) + out.write( + "case 0x%x: return mode == CaseFoldMode::kTurkish ? 0x%x : 0x%x;\n" + % (mapping[0], turkishMapping[mapping[0]], mapping[1]) + ) else: out.write("case 0x%x: return 0x%x;\n" % mapping) - out.write("\ - default: return codepoint;\n }\n}") + out.write( + "\ + default: return codepoint;\n }\n}" + ) out.write(closeNamespaces()) diff --git a/src/mongo/db/fts/unicode/gen_delimiter_list.py b/src/mongo/db/fts/unicode/gen_delimiter_list.py index 4ca79bcdc28..617980d322c 100644 --- a/src/mongo/db/fts/unicode/gen_delimiter_list.py +++ b/src/mongo/db/fts/unicode/gen_delimiter_list.py @@ -2,8 +2,7 @@ # -*- coding: utf-8 -*- import sys -from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, \ - include +from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, include def generate(unicode_proplist_file, target): @@ -22,27 +21,32 @@ def generate(unicode_proplist_file, target): delim_codepoints = set() - proplist_file = open(unicode_proplist_file, 'r') + proplist_file = open(unicode_proplist_file, "r") delim_properties = [ - "White_Space", "Dash", "Hyphen", "Quotation_Mark", "Terminal_Punctuation", "Pattern_Syntax", - "STerm" + "White_Space", + "Dash", + "Hyphen", + "Quotation_Mark", + "Terminal_Punctuation", + "Pattern_Syntax", + "STerm", ] for line in proplist_file: # Filter out blank lines and lines that start with # - data = line[:line.find('#')] - if (data == ""): + data = line[: line.find("#")] + if data == "": continue # Parse the data on the line values = data.split("; ") - assert (len(values) == 2) + assert len(values) == 2 uproperty = values[1].strip() if uproperty in delim_properties: - if len(values[0].split('..')) == 2: - codepoint_range = values[0].split('..') + if len(values[0].split("..")) == 2: + codepoint_range = values[0].split("..") start = int(codepoint_range[0], 16) end = int(codepoint_range[1], 16) + 1 @@ -80,13 +84,19 @@ def generate(unicode_proplist_file, target): switch (codepoint) {\n""") for delim in sorted(delim_codepoints): - if delim <= 0x7f: # ascii codepoints handled in lists above. + if delim <= 0x7F: # ascii codepoints handled in lists above. continue - out.write("\ - case " + str(hex(delim)) + ": return true;\n") + out.write( + "\ + case " + + str(hex(delim)) + + ": return true;\n" + ) - out.write("\ - default: return false;\n }\n}") + out.write( + "\ + default: return false;\n }\n}" + ) out.write(closeNamespaces()) diff --git a/src/mongo/db/fts/unicode/gen_diacritic_list.py b/src/mongo/db/fts/unicode/gen_diacritic_list.py index 649525c52da..ace1623c2c3 100644 --- a/src/mongo/db/fts/unicode/gen_diacritic_list.py +++ b/src/mongo/db/fts/unicode/gen_diacritic_list.py @@ -2,8 +2,7 @@ # -*- coding: utf-8 -*- import sys -from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, \ - include +from gen_helper import getCopyrightNotice, openNamespaces, closeNamespaces, include def generate(unicode_proplist_file, target): @@ -21,22 +20,22 @@ def generate(unicode_proplist_file, target): diacritics = set() - proplist_file = open(unicode_proplist_file, 'r') + proplist_file = open(unicode_proplist_file, "r") for line in proplist_file: # Filter out blank lines and lines that start with # - data = line[:line.find('#')] - if (data == ""): + data = line[: line.find("#")] + if data == "": continue # Parse the data on the line values = data.split("; ") - assert (len(values) == 2) + assert len(values) == 2 uproperty = values[1].strip() if uproperty in "Diacritic": - if len(values[0].split('..')) == 2: - codepoint_range = values[0].split('..') + if len(values[0].split("..")) == 2: + codepoint_range = values[0].split("..") start = int(codepoint_range[0], 16) end = int(codepoint_range[1], 16) + 1 @@ -52,11 +51,17 @@ def generate(unicode_proplist_file, target): switch (codepoint) {\n""") for diacritic in sorted(diacritics): - out.write("\ - case " + str(hex(diacritic)) + ": return true;\n") + out.write( + "\ + case " + + str(hex(diacritic)) + + ": return true;\n" + ) - out.write("\ - default: return false;\n }\n}") + out.write( + "\ + default: return false;\n }\n}" + ) out.write(closeNamespaces()) diff --git a/src/mongo/db/fts/unicode/gen_diacritic_map.py b/src/mongo/db/fts/unicode/gen_diacritic_map.py index cabb33c9419..0f73be92b7c 100644 --- a/src/mongo/db/fts/unicode/gen_diacritic_map.py +++ b/src/mongo/db/fts/unicode/gen_diacritic_map.py @@ -9,22 +9,22 @@ diacritics = set() def load_diacritics(unicode_proplist_file): - proplist_file = open(unicode_proplist_file, 'r') + proplist_file = open(unicode_proplist_file, "r") for line in proplist_file: # Filter out blank lines and lines that start with # - data = line[:line.find('#')] - if (data == ""): + data = line[: line.find("#")] + if data == "": continue # Parse the data on the line values = data.split("; ") - assert (len(values) == 2) + assert len(values) == 2 uproperty = values[1].strip() if uproperty == "Diacritic": - if len(values[0].split('..')) == 2: - codepoint_range = values[0].split('..') + if len(values[0].split("..")) == 2: + codepoint_range = values[0].split("..") start = int(codepoint_range[0], 16) end = int(codepoint_range[1], 16) + 1 @@ -46,18 +46,18 @@ def add_diacritic_mapping(codepoint): # r : decomposed unicode character with diacritics removed # c : recomposed unicode character with diacritics removed a = chr(codepoint) - d = normalize('NFD', a) - r = '' + d = normalize("NFD", a) + r = "" for i in range(len(d)): if ord(d[i]) not in diacritics: r += d[i] - c = normalize('NFC', r) + c = normalize("NFC", r) # Only use mappings where the final recomposed form is a single codepoint - if (a != c and len(c) == 1): - assert c != '\0' # This is used to indicate the codepoint is a pure diacritic. + if a != c and len(c) == 1: + assert c != "\0" # This is used to indicate the codepoint is a pure diacritic. assert ord(c) not in diacritics diacritic_mappings[codepoint] = ord(c[0]) @@ -98,8 +98,13 @@ def generate(target): sorted_mappings = sorted(mappings_list, key=lambda mapping: mapping[0]) for mapping in sorted_mappings: - out.write(" case " + str(hex(mapping[0])) + ": return " + \ - str(hex(mapping[1])) +";\n") + out.write( + " case " + + str(hex(mapping[0])) + + ": return " + + str(hex(mapping[1])) + + ";\n" + ) out.write(" default: return codepoint;\n }\n}") @@ -107,7 +112,7 @@ def generate(target): if __name__ == "__main__": - if (unidata_version != '8.0.0'): + if unidata_version != "8.0.0": print("""ERROR: This script must be run with a version of Python that \ contains the Unicode 8.0.0 Character Database.""") sys.exit(1) diff --git a/src/mongo/db/query/optimizer/optimizer_gdb_test.py b/src/mongo/db/query/optimizer/optimizer_gdb_test.py index 21eb9a0d09f..50c22ea84f5 100644 --- a/src/mongo/db/query/optimizer/optimizer_gdb_test.py +++ b/src/mongo/db/query/optimizer/optimizer_gdb_test.py @@ -1,5 +1,4 @@ -"""Script to be invoked by GDB for testing optimizer pretty printers. -""" +"""Script to be invoked by GDB for testing optimizer pretty printers.""" import difflib import string @@ -10,90 +9,105 @@ import gdb def output_diff(actual, expected): str = "" for text in difflib.unified_diff(expected.split("\n"), actual.split("\n")): - if text[:3] not in ('+++', '---', '@@ '): + if text[:3] not in ("+++", "---", "@@ "): str += text + "\n" return str + def remove_whitespace(str): remove = string.whitespace mapping = {ord(c): None for c in remove} return str.translate(mapping) -# Asserts on the pretty printed string of the local 'variable' by comparing to 'expected'. + +# Asserts on the pretty printed string of the local 'variable' by comparing to 'expected'. def assertPrintedOutput(variable, expected): - actual = gdb.execute('print ' + variable, to_string=True).split(" = ", 1)[1] - assert remove_whitespace(actual) == remove_whitespace(expected), \ - '[case: \'' + variable + '\'] Diff:\n' + output_diff(actual, expected) + actual = gdb.execute("print " + variable, to_string=True).split(" = ", 1)[1] + assert remove_whitespace(actual) == remove_whitespace(expected), ( + "[case: '" + variable + "'] Diff:\n" + output_diff(actual, expected) + ) print("TEST PASSED - " + variable) -if __name__ == '__main__': - try: - gdb.execute('run') - gdb.execute('frame function main') - # These tests work in tandem with the test binary 'optimizer_gdb_test_program'. Each test - # case inspects a local variable by invoking the appropriate pretty printer and comparing - # to the expected output. +if __name__ == "__main__": + try: + gdb.execute("run") + gdb.execute("frame function main") + + # These tests work in tandem with the test binary 'optimizer_gdb_test_program'. Each test + # case inspects a local variable by invoking the appropriate pretty printer and comparing + # to the expected output. # Test ABT containing a conjunction over the fields 'a' and 'b'. - assertPrintedOutput('testABT', - "\nRoot[\"root\"]\n" + - "Filter\n" + - "| EvalFilter\n" + - "| | Variable[\"root\"], \n" + - "| PathComposeM\n" + - "| | PathGet[\"b\"]\n" + - "| | PathCompare[Eq]\n" + - "| | Constant[\"1\"], \n" + - "| PathGet[\"a\"]\n" + - "| PathCompare[Eq]\n" + - "| Constant[\"1\"], \n" + - "Scan[\"coll\", \"root\"]\n") - - # After exploration phase, the optimized ABT should demonstrate the conversion to + assertPrintedOutput( + "testABT", + '\nRoot["root"]\n' + + "Filter\n" + + "| EvalFilter\n" + + '| | Variable["root"], \n' + + "| PathComposeM\n" + + '| | PathGet["b"]\n' + + "| | PathCompare[Eq]\n" + + '| | Constant["1"], \n' + + '| PathGet["a"]\n' + + "| PathCompare[Eq]\n" + + '| Constant["1"], \n' + + 'Scan["coll", "root"]\n', + ) + + # After exploration phase, the optimized ABT should demonstrate the conversion to # SargableNodes. Since the test program mimics an index over 'a', only one of the # predicates can use the index with the other remaining as residual. - assertPrintedOutput('optimized', - "\nRoot[\"root\"]\n" + - "RIDIntersect[\"root\"]\n" + - "| Sargable [Seek]\n" + - "| | | | requirements: root, 'PathGet [b] PathIdentity []', =Const [1], \n" + - "| | | candidateIndexes: [], \n" + - "| | scan_params: (fields: std::map with 1 element : ([\"b\"] : \"evalTemp_4\"), " \ - "residual: ), \n" + - "| Scan[\"coll\", \"root\"], \n" + - "Sargable [Index]\n" + - "| | requirements: root, 'PathGet [a] PathIdentity []', =Const [1], \n" + - "| candidateIndexes: [], \n" + - "Scan[\"coll\", \"root\"]\n") + assertPrintedOutput( + "optimized", + '\nRoot["root"]\n' + + 'RIDIntersect["root"]\n' + + "| Sargable [Seek]\n" + + "| | | | requirements: root, 'PathGet [b] PathIdentity []', =Const [1], \n" + + "| | | candidateIndexes: [], \n" + + '| | scan_params: (fields: std::map with 1 element : (["b"] : "evalTemp_4"), ' + "residual: ), \n" + + '| Scan["coll", "root"], \n' + + "Sargable [Index]\n" + + "| | requirements: root, 'PathGet [a] PathIdentity []', =Const [1], \n" + + "| candidateIndexes: [], \n" + + 'Scan["coll", "root"]\n', + ) # Verify interesting pieces of the indexed SargableNode. - assertPrintedOutput('indexSargable.getCandidateIndexes()', - "std::vector of length 1, capacity 1 = {index1, {}, {SimpleEquality}, {{{=Const [1]}}}\n}") - + assertPrintedOutput( + "indexSargable.getCandidateIndexes()", + "std::vector of length 1, capacity 1 = {index1, {}, {SimpleEquality}, {{{=Const [1]}}}\n}", + ) + # Verify interesting pieces of the seek SargableNode, including the residual predicates in the scan params. - assertPrintedOutput('residualSargable', - "Sargable [Seek] = {\n" + - "| | | = requirements: {{{root, 'PathGet [b] PathIdentity []', {{{=Const [1]}}}}}}, \n" + - "| | = candidateIndexes: [], \n" + - "| = scan_params: (fields: {std::map with 1 element : ([\"b\"] : \"evalTemp_4\")}, " \ - "residual: ), \n" + - " = Scan[\"coll\", \"root\"]}\n") - - assertPrintedOutput('testInterval', - "{\n" + - " {[Const [1], Const [3]]}\n" + - " U \n" + - " {[Const [4], Const [5]]}\n" + - "}\n") + assertPrintedOutput( + "residualSargable", + "Sargable [Seek] = {\n" + + "| | | = requirements: {{{root, 'PathGet [b] PathIdentity []', {{{=Const [1]}}}}}}, \n" + + "| | = candidateIndexes: [], \n" + + '| = scan_params: (fields: {std::map with 1 element : (["b"] : "evalTemp_4")}, ' + "residual: ), \n" + + ' = Scan["coll", "root"]}\n', + ) - assertPrintedOutput('emptyProjectionMap', - "{std::map with 0 elements}\n") + assertPrintedOutput( + "testInterval", + "{\n" + + " {[Const [1], Const [3]]}\n" + + " U \n" + + " {[Const [4], Const [5]]}\n" + + "}\n", + ) - assertPrintedOutput('testProjectionMap', - "{: \"test\", std::map with 2 elements : ([\"a\"] : \"b\", [\"c\"] : \"d\")}") + assertPrintedOutput("emptyProjectionMap", "{std::map with 0 elements}\n") - gdb.write('TEST PASSED\n') + assertPrintedOutput( + "testProjectionMap", + '{: "test", std::map with 2 elements : (["a"] : "b", ["c"] : "d")}', + ) + + gdb.write("TEST PASSED\n") except Exception as err: - gdb.write('TEST FAILED -- {!s}\n'.format(err)) - gdb.execute('quit 1', to_string=True) + gdb.write("TEST FAILED -- {!s}\n".format(err)) + gdb.execute("quit 1", to_string=True) diff --git a/src/mongo/mongo_config_header.py b/src/mongo/mongo_config_header.py index 5b98803cf34..39b686748b5 100644 --- a/src/mongo/mongo_config_header.py +++ b/src/mongo/mongo_config_header.py @@ -145,7 +145,9 @@ def explicit_bzero_present_flag() -> list[HeaderDefinition]: def pthread_setname_np_present_flag() -> list[HeaderDefinition]: - log_check("[MONGO_CONFIG_HAVE_PTHREAD_SETNAME_NP] Checking for pthread_setname_np...") + log_check( + "[MONGO_CONFIG_HAVE_PTHREAD_SETNAME_NP] Checking for pthread_setname_np..." + ) if compile_check(""" #ifndef _GNU_SOURCE @@ -377,7 +379,9 @@ def altivec_vbpermq_output_flag() -> list[HeaderDefinition]: for index in [0, 1]: if check_altivec_vbpermq_output(index): - return [HeaderDefinition("MONGO_CONFIG_ALTIVEC_VEC_VBPERMQ_OUTPUT_INDEX", index)] + return [ + HeaderDefinition("MONGO_CONFIG_ALTIVEC_VEC_VBPERMQ_OUTPUT_INDEX", index) + ] return [] @@ -385,7 +389,9 @@ def usdt_provider_flags() -> list[HeaderDefinition]: if platform.system() == "Darwin": return [] - log_check("[MONGO_CONFIG_USDT_PROVIDER] Checking if SDT usdt provider is available...") + log_check( + "[MONGO_CONFIG_USDT_PROVIDER] Checking if SDT usdt provider is available..." + ) if compile_check(""" #include int main(void) { return 0; } @@ -399,39 +405,64 @@ def usdt_provider_flags() -> list[HeaderDefinition]: def get_config_header_substs(): config_header_substs = ( - ('@mongo_config_altivec_vec_vbpermq_output_index@', - 'MONGO_CONFIG_ALTIVEC_VEC_VBPERMQ_OUTPUT_INDEX'), - ('@mongo_config_debug_build@', 'MONGO_CONFIG_DEBUG_BUILD'), - ('@mongo_config_have_execinfo_backtrace@', 'MONGO_CONFIG_HAVE_EXECINFO_BACKTRACE'), - ('@mongo_config_have_explicit_bzero@', 'MONGO_CONFIG_HAVE_EXPLICIT_BZERO'), - ('@mongo_config_have_fips_mode_set@', 'MONGO_CONFIG_HAVE_FIPS_MODE_SET'), - ('@mongo_config_have_header_unistd_h@', 'MONGO_CONFIG_HAVE_HEADER_UNISTD_H'), - ('@mongo_config_have_memset_s@', 'MONGO_CONFIG_HAVE_MEMSET_S'), - ('@mongo_config_have_posix_monotonic_clock@', 'MONGO_CONFIG_HAVE_POSIX_MONOTONIC_CLOCK'), - ('@mongo_config_have_pthread_setname_np@', 'MONGO_CONFIG_HAVE_PTHREAD_SETNAME_NP'), - ('@mongo_config_have_ssl_ec_key_new@', 'MONGO_CONFIG_HAVE_SSL_EC_KEY_NEW'), - ('@mongo_config_have_ssl_set_ecdh_auto@', 'MONGO_CONFIG_HAVE_SSL_SET_ECDH_AUTO'), - ('@mongo_config_have_strnlen@', 'MONGO_CONFIG_HAVE_STRNLEN'), - ('@mongo_config_max_extended_alignment@', 'MONGO_CONFIG_MAX_EXTENDED_ALIGNMENT'), - ('@mongo_config_ocsp_stapling_enabled@', 'MONGO_CONFIG_OCSP_STAPLING_ENABLED'), - ('@mongo_config_optimized_build@', 'MONGO_CONFIG_OPTIMIZED_BUILD'), - ('@mongo_config_ssl_has_asn1_any_definitions@', 'MONGO_CONFIG_HAVE_ASN1_ANY_DEFINITIONS'), - ('@mongo_config_ssl_provider@', 'MONGO_CONFIG_SSL_PROVIDER'), - ('@mongo_config_ssl@', 'MONGO_CONFIG_SSL'), - ('@mongo_config_usdt_enabled@', 'MONGO_CONFIG_USDT_ENABLED'), - ('@mongo_config_usdt_provider@', 'MONGO_CONFIG_USDT_PROVIDER'), - ('@mongo_config_use_libunwind@', 'MONGO_CONFIG_USE_LIBUNWIND'), - ('@mongo_config_use_raw_latches@', 'MONGO_CONFIG_USE_RAW_LATCHES'), - ('@mongo_config_wiredtiger_enabled@', 'MONGO_CONFIG_WIREDTIGER_ENABLED'), - ('@mongo_config_glibc_rseq@', 'MONGO_CONFIG_GLIBC_RSEQ'), - ('@mongo_config_tcmalloc_google@', 'MONGO_CONFIG_TCMALLOC_GOOGLE'), - ('@mongo_config_tcmalloc_gperf@', 'MONGO_CONFIG_TCMALLOC_GPERF'), + ( + "@mongo_config_altivec_vec_vbpermq_output_index@", + "MONGO_CONFIG_ALTIVEC_VEC_VBPERMQ_OUTPUT_INDEX", + ), + ("@mongo_config_debug_build@", "MONGO_CONFIG_DEBUG_BUILD"), + ( + "@mongo_config_have_execinfo_backtrace@", + "MONGO_CONFIG_HAVE_EXECINFO_BACKTRACE", + ), + ("@mongo_config_have_explicit_bzero@", "MONGO_CONFIG_HAVE_EXPLICIT_BZERO"), + ("@mongo_config_have_fips_mode_set@", "MONGO_CONFIG_HAVE_FIPS_MODE_SET"), + ("@mongo_config_have_header_unistd_h@", "MONGO_CONFIG_HAVE_HEADER_UNISTD_H"), + ("@mongo_config_have_memset_s@", "MONGO_CONFIG_HAVE_MEMSET_S"), + ( + "@mongo_config_have_posix_monotonic_clock@", + "MONGO_CONFIG_HAVE_POSIX_MONOTONIC_CLOCK", + ), + ( + "@mongo_config_have_pthread_setname_np@", + "MONGO_CONFIG_HAVE_PTHREAD_SETNAME_NP", + ), + ("@mongo_config_have_ssl_ec_key_new@", "MONGO_CONFIG_HAVE_SSL_EC_KEY_NEW"), + ( + "@mongo_config_have_ssl_set_ecdh_auto@", + "MONGO_CONFIG_HAVE_SSL_SET_ECDH_AUTO", + ), + ("@mongo_config_have_strnlen@", "MONGO_CONFIG_HAVE_STRNLEN"), + ( + "@mongo_config_max_extended_alignment@", + "MONGO_CONFIG_MAX_EXTENDED_ALIGNMENT", + ), + ("@mongo_config_ocsp_stapling_enabled@", "MONGO_CONFIG_OCSP_STAPLING_ENABLED"), + ("@mongo_config_optimized_build@", "MONGO_CONFIG_OPTIMIZED_BUILD"), + ( + "@mongo_config_ssl_has_asn1_any_definitions@", + "MONGO_CONFIG_HAVE_ASN1_ANY_DEFINITIONS", + ), + ("@mongo_config_ssl_provider@", "MONGO_CONFIG_SSL_PROVIDER"), + ("@mongo_config_ssl@", "MONGO_CONFIG_SSL"), + ("@mongo_config_usdt_enabled@", "MONGO_CONFIG_USDT_ENABLED"), + ("@mongo_config_usdt_provider@", "MONGO_CONFIG_USDT_PROVIDER"), + ("@mongo_config_use_libunwind@", "MONGO_CONFIG_USE_LIBUNWIND"), + ("@mongo_config_use_raw_latches@", "MONGO_CONFIG_USE_RAW_LATCHES"), + ("@mongo_config_wiredtiger_enabled@", "MONGO_CONFIG_WIREDTIGER_ENABLED"), + ("@mongo_config_glibc_rseq@", "MONGO_CONFIG_GLIBC_RSEQ"), + ("@mongo_config_tcmalloc_google@", "MONGO_CONFIG_TCMALLOC_GOOGLE"), + ("@mongo_config_tcmalloc_gperf@", "MONGO_CONFIG_TCMALLOC_GPERF"), ) return config_header_substs def generate_config_header( - compiler_path, compiler_args, env_vars, logpath, additional_inputs=[], extra_definitions={} + compiler_path, + compiler_args, + env_vars, + logpath, + additional_inputs=[], + extra_definitions={}, ) -> Dict[str, str]: global logfile_path CompilerSettings.compiler_path = compiler_path diff --git a/src/mongo/tools/mongo_tidy_checks/tests/MongoTidyCheck_unittest.py b/src/mongo/tools/mongo_tidy_checks/tests/MongoTidyCheck_unittest.py index facd531d5ca..7d8ad13146a 100644 --- a/src/mongo/tools/mongo_tidy_checks/tests/MongoTidyCheck_unittest.py +++ b/src/mongo/tools/mongo_tidy_checks/tests/MongoTidyCheck_unittest.py @@ -15,17 +15,29 @@ class MongoTidyTests(unittest.TestCase): "--@bazel_clang_tidy//:clang_tidy_config=//src/mongo/tools/mongo_tidy_checks/tests:" + self._testMethodName + "_tidy_config", - "//src/mongo/tools/mongo_tidy_checks/tests:" + self._testMethodName + "_with_debug", + "//src/mongo/tools/mongo_tidy_checks/tests:" + + self._testMethodName + + "_with_debug", ] p = subprocess.run( - cmd, cwd=os.environ.get("BUILD_WORKSPACE_DIRECTORY"), capture_output=True, text=True + cmd, + cwd=os.environ.get("BUILD_WORKSPACE_DIRECTORY"), + capture_output=True, + text=True, ) if isinstance(self.expected_output, list): - passed = all([expected_output in p.stdout for expected_output in self.expected_output]) + passed = all( + [ + expected_output in p.stdout + for expected_output in self.expected_output + ] + ) print_expected_output = "\n".join(self.expected_output) else: - passed = self.expected_output is not None and self.expected_output in p.stdout + passed = ( + self.expected_output is not None and self.expected_output in p.stdout + ) print_expected_output = self.expected_output msg = "\n".join( @@ -59,7 +71,6 @@ class MongoTidyTests(unittest.TestCase): self.fail() def test_MongoHeaderBracketCheck(self): - self.expected_output = [ "error: non-mongo include 'cctype' should use angle brackets", "error: mongo include 'test_MongoHeaderBracketCheck.h' should use double quotes", @@ -69,7 +80,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoUninterruptibleLockGuardCheck(self): - self.expected_output = ( "Potentially incorrect use of UninterruptibleLockGuard, " "the programming model inside MongoDB requires that all operations be interruptible. " @@ -79,7 +89,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoUninterruptibleLockGuardCheckForOpCtxMember(self): - self.expected_output = ( "Potentially incorrect use of " "OperationContext::uninterruptibleLocksRequested_DO_NOT_USE, this is a legacy " @@ -92,16 +101,14 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoCctypeCheck(self): - self.expected_output = [ - "Use of prohibited \"cctype\" header, use \"mongo/util/ctype.h\"", - "Use of prohibited header, use \"mongo/util/ctype.h\"", + 'Use of prohibited "cctype" header, use "mongo/util/ctype.h"', + 'Use of prohibited header, use "mongo/util/ctype.h"', ] self.run_clang_tidy() def test_MongoCxx20BannedIncludesCheck(self): - self.expected_output = [ "Use of prohibited header.", "Use of prohibited header.", @@ -113,15 +120,23 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoCxx20StdChronoCheck(self): - - prohibited_types = ["day", "day", "month", "year", "month_day", "month", "day", "day"] + prohibited_types = [ + "day", + "day", + "month", + "year", + "month_day", + "month", + "day", + "day", + ] self.expected_output = [ - f"Illegal use of prohibited type 'std::chrono::{t}'." for t in prohibited_types + f"Illegal use of prohibited type 'std::chrono::{t}'." + for t in prohibited_types ] self.run_clang_tidy() def test_MongoStdOptionalCheck(self): - self.expected_output = [ "Use of std::optional, use boost::optional instead. [mongo-std-optional-check,-warnings-as-errors]\nvoid f(std::optional parameterDeclTest) {", "Use of std::optional, use boost::optional instead. [mongo-std-optional-check,-warnings-as-errors]\n std::optional variableDeclTest;", @@ -135,17 +150,15 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoVolatileCheck(self): - self.expected_output = [ - "Illegal use of the volatile storage keyword, use AtomicWord instead from \"mongo/platform/atomic_word.h\" [mongo-volatile-check,-warnings-as-errors]\nvolatile int varVolatileTest;", - "Illegal use of the volatile storage keyword, use AtomicWord instead from \"mongo/platform/atomic_word.h\" [mongo-volatile-check,-warnings-as-errors]\n volatile int fieldVolatileTest;", - "Illegal use of the volatile storage keyword, use AtomicWord instead from \"mongo/platform/atomic_word.h\" [mongo-volatile-check,-warnings-as-errors]\nvoid functionName(volatile int varVolatileTest) {}", + 'Illegal use of the volatile storage keyword, use AtomicWord instead from "mongo/platform/atomic_word.h" [mongo-volatile-check,-warnings-as-errors]\nvolatile int varVolatileTest;', + 'Illegal use of the volatile storage keyword, use AtomicWord instead from "mongo/platform/atomic_word.h" [mongo-volatile-check,-warnings-as-errors]\n volatile int fieldVolatileTest;', + 'Illegal use of the volatile storage keyword, use AtomicWord instead from "mongo/platform/atomic_word.h" [mongo-volatile-check,-warnings-as-errors]\nvoid functionName(volatile int varVolatileTest) {}', ] self.run_clang_tidy() def test_MongoTraceCheck(self): - self.expected_output = [ "Illegal use of prohibited tracing support, this is only for local development use and should not be committed. [mongo-trace-check,-warnings-as-errors]\n TracerProvider::initialize();", "Illegal use of prohibited tracing support, this is only for local development use and should not be committed. [mongo-trace-check,-warnings-as-errors]\n TracerProvider provider = TracerProvider::get();", @@ -154,16 +167,14 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoStdAtomicCheck(self): - self.expected_output = [ - "Illegal use of prohibited std::atomic, use AtomicWord or other types from \"mongo/platform/atomic_word.h\" [mongo-std-atomic-check,-warnings-as-errors]\nstd::atomic atomic_var;", - "Illegal use of prohibited std::atomic, use AtomicWord or other types from \"mongo/platform/atomic_word.h\" [mongo-std-atomic-check,-warnings-as-errors]\n std::atomic field_decl;", + 'Illegal use of prohibited std::atomic, use AtomicWord or other types from "mongo/platform/atomic_word.h" [mongo-std-atomic-check,-warnings-as-errors]\nstd::atomic atomic_var;', + 'Illegal use of prohibited std::atomic, use AtomicWord or other types from "mongo/platform/atomic_word.h" [mongo-std-atomic-check,-warnings-as-errors]\n std::atomic field_decl;', ] self.run_clang_tidy() def test_MongoMutexCheck(self): - self.expected_output = [ "Illegal use of prohibited stdx::mutex, use mongo::Mutex from mongo/platform/mutex.h instead. [mongo-mutex-check,-warnings-as-errors]\nstdx::mutex stdxmutex_vardecl;", "Illegal use of prohibited stdx::mutex, use mongo::Mutex from mongo/platform/mutex.h instead. [mongo-mutex-check,-warnings-as-errors]\nstd::mutex stdmutex_vardecl;", @@ -174,7 +185,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoAssertCheck(self): - self.expected_output = [ "error: Illegal use of the bare assert function, use a function from assert_util.h instead", ] @@ -182,7 +192,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoFCVConstantCheck(self): - self.expected_output = [ "error: Illegal use of FCV constant in FCV comparison check functions. FCV gating should be done through feature flags instead.", ] @@ -190,7 +199,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoUnstructuredLogCheck(self): - self.expected_output = [ "error: Illegal use of unstructured logging, this is only for local development use and should not be committed [mongo-unstructured-log-check,-warnings-as-errors]\n logd();", "error: Illegal use of unstructured logging, this is only for local development use and should not be committed [mongo-unstructured-log-check,-warnings-as-errors]\n doUnstructuredLogImpl();", @@ -199,7 +207,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoConfigHeaderCheck(self): - self.expected_output = [ "error: MONGO_CONFIG define used without prior inclusion of config.h [mongo-config-header-check,-warnings-as-errors]\n#define MONGO_CONFIG_TEST1 1", "error: MONGO_CONFIG define used without prior inclusion of config.h [mongo-config-header-check,-warnings-as-errors]\n#ifdef MONGO_CONFIG_TEST1", @@ -210,16 +217,14 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoCollectionShardingRuntimeCheck(self): - self.expected_output = [ - "error: Illegal use of CollectionShardingRuntime outside of mongo/db/s/; use CollectionShardingState instead; see src/mongo/db/s/collection_sharding_state.h for details. [mongo-collection-sharding-runtime-check,-warnings-as-errors]\n CollectionShardingRuntime csr(5, \"Test\");", - "error: Illegal use of CollectionShardingRuntime outside of mongo/db/s/; use CollectionShardingState instead; see src/mongo/db/s/collection_sharding_state.h for details. [mongo-collection-sharding-runtime-check,-warnings-as-errors]\n int result = CollectionShardingRuntime::functionTest(7, \"Test\");", + 'error: Illegal use of CollectionShardingRuntime outside of mongo/db/s/; use CollectionShardingState instead; see src/mongo/db/s/collection_sharding_state.h for details. [mongo-collection-sharding-runtime-check,-warnings-as-errors]\n CollectionShardingRuntime csr(5, "Test");', + 'error: Illegal use of CollectionShardingRuntime outside of mongo/db/s/; use CollectionShardingState instead; see src/mongo/db/s/collection_sharding_state.h for details. [mongo-collection-sharding-runtime-check,-warnings-as-errors]\n int result = CollectionShardingRuntime::functionTest(7, "Test");', ] self.run_clang_tidy() def test_MongoMacroDefinitionLeaksCheck(self): - self.expected_output = [ "Missing #undef 'MONGO_LOGV2_DEFAULT_COMPONENT'", ] @@ -227,7 +232,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoNoUniqueAddressCheck(self): - self.expected_output = [ "Illegal use of [[no_unique_address]]", ] @@ -235,7 +239,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoPolyFillCheck(self): - self.expected_output = [ "error: Illegal use of banned name from std::/boost:: for std::mutex, use mongo::stdx:: variant instead", "error: Illegal use of banned name from std::/boost:: for std::future, use mongo::stdx:: variant instead", @@ -247,7 +250,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoRandCheck(self): - self.expected_output = [ "error: Use of rand or srand, use or PseudoRandom instead. [mongo-rand-check,-warnings-as-errors]\n srand(time(0));", "error: Use of rand or srand, use or PseudoRandom instead. [mongo-rand-check,-warnings-as-errors]\n int random_number = rand();", @@ -256,7 +258,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoStringDataConstRefCheck1(self): - self.expected_output = [ "Prefer passing StringData by value.", ] @@ -264,7 +265,6 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoStringDataConstRefCheck2(self): - self.expected_output = [ "Prefer passing StringData by value.", ] @@ -272,12 +272,12 @@ class MongoTidyTests(unittest.TestCase): self.run_clang_tidy() def test_MongoStringDataConstRefCheck3(self): - self.expected_output = [ "", ] self.run_clang_tidy() + if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main() diff --git a/src/mongo/tools/workload_simulation/process_logs.py b/src/mongo/tools/workload_simulation/process_logs.py index 93e5ad37631..4941936c873 100755 --- a/src/mongo/tools/workload_simulation/process_logs.py +++ b/src/mongo/tools/workload_simulation/process_logs.py @@ -8,10 +8,15 @@ import matplotlib.pyplot as pyplot LOG_ID_WORKLOAD_NAME = 7782100 LOG_ID_METRICS = 7782101 -parser = argparse.ArgumentParser(description='Process simulator log output piped to stdin') -parser.add_argument("-o", "--outputDirectory", - help="Path to an existing directory where output should be stored", - metavar="~/path/for/output") +parser = argparse.ArgumentParser( + description="Process simulator log output piped to stdin" +) +parser.add_argument( + "-o", + "--outputDirectory", + help="Path to an existing directory where output should be stored", + metavar="~/path/for/output", +) args = parser.parse_args() directory = args.outputDirectory @@ -22,26 +27,28 @@ currentWorkload = "" firstTime = 0 for line in sys.stdin: parsed = json.loads(line) - if (parsed["id"] == LOG_ID_WORKLOAD_NAME): + if parsed["id"] == LOG_ID_WORKLOAD_NAME: # Starting output for a new workload. currentWorkload = parsed["attr"]["workload"] # Can't reuse the workload name - assert (currentWorkload not in workloads) + assert currentWorkload not in workloads # Initialize the data structure for the output. workloads[currentWorkload] = {"time": [], "metrics": {}} - elif (parsed["id"] == LOG_ID_METRICS): + elif parsed["id"] == LOG_ID_METRICS: # Check that data structure for current workload has been initialized properly. - assert (isinstance(workloads[currentWorkload], dict)) - assert (isinstance(workloads[currentWorkload]["time"], list)) - assert (isinstance(workloads[currentWorkload]["metrics"], dict)) + assert isinstance(workloads[currentWorkload], dict) + assert isinstance(workloads[currentWorkload]["time"], list) + assert isinstance(workloads[currentWorkload]["metrics"], dict) # Parsing output for the current workload. # Normalize time values so that first time value in series is '0', displayed in seconds. - if (len(workloads[currentWorkload]["time"]) == 0): + if len(workloads[currentWorkload]["time"]) == 0: firstTime = parsed["attr"]["time"] - workloads[currentWorkload]["time"].append((parsed["attr"]["time"] - firstTime) / 1E9) + workloads[currentWorkload]["time"].append( + (parsed["attr"]["time"] - firstTime) / 1e9 + ) # Process the metrics, initializing structures as necessary metrics = parsed["attr"]["metrics"] @@ -59,8 +66,13 @@ for workload, data in workloads.items(): width = min(numPlots, 3) height = int(numPlots / width) + (numPlots % width > 0) - fig, ax = pyplot.subplots(nrows=height, ncols=width, figsize=(7.5 * width, 3.5 * height), - sharex=True, layout="constrained") + fig, ax = pyplot.subplots( + nrows=height, + ncols=width, + figsize=(7.5 * width, 3.5 * height), + sharex=True, + layout="constrained", + ) fig.suptitle("Workload: " + workload) i = 0 diff --git a/src/mongo/util/generate_icu_init_cpp.py b/src/mongo/util/generate_icu_init_cpp.py index 3358fad25f7..f9c4a2f8521 100755 --- a/src/mongo/util/generate_icu_init_cpp.py +++ b/src/mongo/util/generate_icu_init_cpp.py @@ -34,10 +34,20 @@ import sys def main(argv): parser = optparse.OptionParser() - parser.add_option('-o', '--output', action='store', dest='output_cpp_file', - help='path to output cpp file') - parser.add_option('-i', '--input', action='store', dest='input_data_file', - help='input ICU data file, in common format (.dat)') + parser.add_option( + "-o", + "--output", + action="store", + dest="output_cpp_file", + help="path to output cpp file", + ) + parser.add_option( + "-i", + "--input", + action="store", + dest="input_data_file", + help="input ICU data file, in common format (.dat)", + ) (options, args) = parser.parse_args(argv) if len(args) > 1: parser.error("too many arguments") @@ -49,7 +59,7 @@ def main(argv): def generate_cpp_file(data_file_path, cpp_file_path): - source_template = '''// AUTO-GENERATED FILE DO NOT EDIT + source_template = """// AUTO-GENERATED FILE DO NOT EDIT // See generate_icu_init_cpp.py. /** * Copyright (C) 2018-present MongoDB, Inc. @@ -111,12 +121,15 @@ MONGO_INITIALIZER_GENERAL(LoadICUData, (), ("BeginStartupOptionHandling"))( } } // namespace mongo -''' - decimal_encoded_data = '' - with open(data_file_path, 'rb') as data_file: - decimal_encoded_data = ','.join([str(byte) for byte in data_file.read()]) - with open(cpp_file_path, 'w') as cpp_file: - cpp_file.write(source_template % dict(decimal_encoded_data=decimal_encoded_data)) +""" + decimal_encoded_data = "" + with open(data_file_path, "rb") as data_file: + decimal_encoded_data = ",".join([str(byte) for byte in data_file.read()]) + with open(cpp_file_path, "w") as cpp_file: + cpp_file.write( + source_template % dict(decimal_encoded_data=decimal_encoded_data) + ) -if __name__ == '__main__': + +if __name__ == "__main__": main(sys.argv) diff --git a/src/mongo/util/pretty_printer_test.py b/src/mongo/util/pretty_printer_test.py index 14c6409f23c..746a8220ea4 100644 --- a/src/mongo/util/pretty_printer_test.py +++ b/src/mongo/util/pretty_printer_test.py @@ -1,78 +1,95 @@ -"""Script to be invoked by GDB for testing decorable pretty printing. -""" +"""Script to be invoked by GDB for testing decorable pretty printing.""" import re import gdb expected_patterns = [ - r'Decorable with 3 elems', - r'vector of length 3.*\{ *123, *213, *312 *\}', + r"Decorable with 3 elems", + r"vector of length 3.*\{ *123, *213, *312 *\}", r'basic_string.* \= *"hello"', r'basic_string.* \= *"world"', ] -up_pattern = r'std::unique_ptr = \{get\(\) \= 0x[0-9a-fA-F]+\}' -set_pattern = r'std::[__debug::]*set with 4 elements' -static_member_pattern = '128' +up_pattern = r"std::unique_ptr = \{get\(\) \= 0x[0-9a-fA-F]+\}" +set_pattern = r"std::[__debug::]*set with 4 elements" +static_member_pattern = "128" def search(pattern, s): match = re.search(pattern, s) - assert match is not None, 'Did not find {!s} in {!s}'.format(pattern, s) + assert match is not None, "Did not find {!s} in {!s}".format(pattern, s) return match def test_decorable(): - d1_str = gdb.execute('print d1', to_string=True) + d1_str = gdb.execute("print d1", to_string=True) for pattern in expected_patterns: search(pattern, d1_str) - search(up_pattern, gdb.execute('print up', to_string=True)) - search(set_pattern, gdb.execute('print set_type', to_string=True)) - search(static_member_pattern, gdb.execute('print testClass::static_member', to_string=True)) + search(up_pattern, gdb.execute("print up", to_string=True)) + search(set_pattern, gdb.execute("print set_type", to_string=True)) + search( + static_member_pattern, + gdb.execute("print testClass::static_member", to_string=True), + ) def test_dbname_nss(): - dbname_str = gdb.execute('print dbName', to_string=True) + dbname_str = gdb.execute("print dbName", to_string=True) search("foo", dbname_str) - dbname_tid_str = gdb.execute('print dbNameWithTenantId', to_string=True) + dbname_tid_str = gdb.execute("print dbNameWithTenantId", to_string=True) search("6491a2112ef5c818703bf2a7_foo", dbname_tid_str) - nss_str = gdb.execute('print nss', to_string=True) + nss_str = gdb.execute("print nss", to_string=True) search("foo.ba", nss_str) - nss_tid_str = gdb.execute('print nssWithTenantId', to_string=True) + nss_tid_str = gdb.execute("print nssWithTenantId", to_string=True) search("6491a2112ef5c818703bf2a7_foo.barbaz", nss_tid_str) - long_nss_str = gdb.execute('print longNss', to_string=True) + long_nss_str = gdb.execute("print longNss", to_string=True) search("longdatabasenamewithoutsmallstring.longcollection", long_nss_str) - constexpr_str = gdb.execute('print kConstNs', to_string=True) + constexpr_str = gdb.execute("print kConstNs", to_string=True) search("constexpr.name", constexpr_str) - constexpr_str = gdb.execute('print constCopy', to_string=True) + constexpr_str = gdb.execute("print constCopy", to_string=True) search("constexpr.name", constexpr_str) def test_string_map(): - search(r'absl::flat_hash_map.*0 elems', gdb.execute('print emptyMap', to_string=True)) - int_map_results = gdb.execute('print intMap', to_string=True) - search(r'absl::flat_hash_map.*2 elems', int_map_results) + search( + r"absl::flat_hash_map.*0 elems", gdb.execute("print emptyMap", to_string=True) + ) + int_map_results = gdb.execute("print intMap", to_string=True) + search(r"absl::flat_hash_map.*2 elems", int_map_results) search(r'\["a"\] = 1', int_map_results) search(r'\["b"\] = 1', int_map_results) search( r'absl::flat_hash_map.*1 elems.*\{\["a"\] = "a_value"\}', gdb.execute("print strMap", to_string=True), ) - search(r'absl::flat_hash_set.*1 elems.*\{"a"\}', gdb.execute("print strSet", to_string=True)) + search( + r'absl::flat_hash_set.*1 elems.*\{"a"\}', + gdb.execute("print strSet", to_string=True), + ) # Non empty Hash, Eq, or Alloc functors should pretty print without issues. - search(r"absl::flat_hash_set.*0 elems", gdb.execute("print checkNonEmptyHash", to_string=True)) - search(r"absl::flat_hash_set.*0 elems", gdb.execute("print checkNonEmptyEq", to_string=True)) - search(r"absl::flat_hash_set.*0 elems", gdb.execute("print checkNonEmptyAlloc", to_string=True)) + search( + r"absl::flat_hash_set.*0 elems", + gdb.execute("print checkNonEmptyHash", to_string=True), + ) + search( + r"absl::flat_hash_set.*0 elems", + gdb.execute("print checkNonEmptyEq", to_string=True), + ) + search( + r"absl::flat_hash_set.*0 elems", + gdb.execute("print checkNonEmptyAlloc", to_string=True), + ) -if __name__ == '__main__': + +if __name__ == "__main__": try: - gdb.execute('run') - gdb.execute('frame function main') + gdb.execute("run") + gdb.execute("frame function main") test_decorable() test_dbname_nss() test_string_map() - gdb.write('TEST PASSED\n') + gdb.write("TEST PASSED\n") except Exception as err: - gdb.write('TEST FAILED -- {!s}\n'.format(err)) - gdb.execute('quit 1', to_string=True) + gdb.write("TEST FAILED -- {!s}\n".format(err)) + gdb.execute("quit 1", to_string=True) diff --git a/src/mongo/util/version_constants_gen.py b/src/mongo/util/version_constants_gen.py index abaa0cfa5eb..b83243f2132 100644 --- a/src/mongo/util/version_constants_gen.py +++ b/src/mongo/util/version_constants_gen.py @@ -159,7 +159,8 @@ def default_buildinfo_environment_data(compiler_path, extra_definitions, env_var ), ) return { - k: {"key": k, "value": v, "inBuildInfo": ibi, "inVersion": iv} for k, v, ibi, iv in data + k: {"key": k, "value": v, "inBuildInfo": ibi, "inVersion": iv} + for k, v, ibi, iv in data } @@ -195,7 +196,12 @@ def log_check(message): def generate_config_header( - compiler_path, compiler_args, env_vars, logpath, additional_inputs, extra_definitions={} + compiler_path, + compiler_args, + env_vars, + logpath, + additional_inputs, + extra_definitions={}, ) -> Dict[str, str]: global logfile_path logfile_path = logpath @@ -242,7 +248,9 @@ def generate_config_header( version_parts[3] = 0 version_parts = [int(x) for x in version_parts[:4]] - modules = ["enterprise"] if "build_enterprise_enabled" in extra_definitions_dict else [] + modules = ( + ["enterprise"] if "build_enterprise_enabled" in extra_definitions_dict else [] + ) module_list = ",\n".join(['"{0}"_sd'.format(x) for x in modules]) replacements = { @@ -259,4 +267,4 @@ def generate_config_header( "@buildinfo_environment_data@": buildInfoInitializer, } - return replacements \ No newline at end of file + return replacements