SERVER-90571: Enable python formatting checks for buildscripts/idl directory (#22293)
GitOrigin-RevId: a2fbc8ed83f576703cce96ebb5e680cc70aac4d8
This commit is contained in:
parent
83fa212b68
commit
bd2955c297
@ -40,7 +40,7 @@ from typing import Any, Dict, List, Mapping, Set
|
||||
from pymongo import MongoClient
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from idl import syntax
|
||||
@ -53,7 +53,7 @@ from buildscripts.resmokelib.testing.fixtures.shardedcluster import ShardedClust
|
||||
from buildscripts.resmokelib.testing.fixtures.standalone import MongoDFixture
|
||||
# pylint: enable=wrong-import-position
|
||||
|
||||
LOGGER_NAME = 'check-idl-definitions'
|
||||
LOGGER_NAME = "check-idl-definitions"
|
||||
LOGGER = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
|
||||
@ -68,8 +68,9 @@ def is_test_or_third_party_idl(idl_path: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def get_command_definitions(api_version: str, directory: str,
|
||||
import_directories: List[str]) -> Dict[str, syntax.Command]:
|
||||
def get_command_definitions(
|
||||
api_version: str, directory: str, import_directories: List[str]
|
||||
) -> Dict[str, syntax.Command]:
|
||||
"""Get parsed IDL definitions of commands in a given API version."""
|
||||
|
||||
LOGGER.info("Searching for command definitions in %s", directory)
|
||||
@ -109,22 +110,30 @@ def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir:
|
||||
logger = loggers.new_fixture_logger("ShardedClusterFixture", 0)
|
||||
logger.parent = LOGGER
|
||||
fixture = fixturelib.make_fixture(
|
||||
"ShardedClusterFixture", logger, 0, dbpath_prefix=dbpath.name,
|
||||
mongos_executable=mongos_executable, mongod_executable=mongod_executable,
|
||||
mongod_options={"set_parameters": {}})
|
||||
"ShardedClusterFixture",
|
||||
logger,
|
||||
0,
|
||||
dbpath_prefix=dbpath.name,
|
||||
mongos_executable=mongos_executable,
|
||||
mongod_executable=mongod_executable,
|
||||
mongod_options={"set_parameters": {}},
|
||||
)
|
||||
|
||||
fixture.setup()
|
||||
fixture.await_ready()
|
||||
|
||||
try:
|
||||
client = MongoClient(fixture.get_driver_connection_url()) # type: MongoClient
|
||||
reply = client.admin.command('listCommands') # type: Mapping[str, Any]
|
||||
reply = client.admin.command("listCommands") # type: Mapping[str, Any]
|
||||
commands = {
|
||||
name
|
||||
for name, info in reply['commands'].items() if api_version in info['apiVersions']
|
||||
name for name, info in reply["commands"].items() if api_version in info["apiVersions"]
|
||||
}
|
||||
logging.info("Found %s commands in API Version %s on %s", len(commands), api_version,
|
||||
mongod_or_mongos)
|
||||
logging.info(
|
||||
"Found %s commands in API Version %s on %s",
|
||||
len(commands),
|
||||
api_version,
|
||||
mongod_or_mongos,
|
||||
)
|
||||
return commands
|
||||
finally:
|
||||
fixture.teardown()
|
||||
@ -144,13 +153,16 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str]
|
||||
for other_name, other_commands in it:
|
||||
if commands != other_commands:
|
||||
if commands - other_commands:
|
||||
LOGGER.error("%s has commands not in %s: %s", name, other_name,
|
||||
commands - other_commands)
|
||||
LOGGER.error(
|
||||
"%s has commands not in %s: %s", name, other_name, commands - other_commands
|
||||
)
|
||||
if other_commands - commands:
|
||||
LOGGER.error("%s has commands not in %s: %s", other_name, name,
|
||||
other_commands - commands)
|
||||
LOGGER.error(
|
||||
"%s has commands not in %s: %s", other_name, name, other_commands - commands
|
||||
)
|
||||
raise AssertionError(
|
||||
f"{name} and {other_name} have different commands in API Version {api_version}")
|
||||
f"{name} and {other_name} have different commands in API Version {api_version}"
|
||||
)
|
||||
|
||||
|
||||
def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
|
||||
@ -173,10 +185,15 @@ def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
|
||||
def main():
|
||||
"""Run the script."""
|
||||
arg_parser = argparse.ArgumentParser(description=__doc__)
|
||||
arg_parser.add_argument("--include", type=str, action="append",
|
||||
help="Directory to search for IDL import files")
|
||||
arg_parser.add_argument("--install-dir", dest="install_dir", required=True,
|
||||
help="Directory to search for MongoDB binaries")
|
||||
arg_parser.add_argument(
|
||||
"--include", type=str, action="append", help="Directory to search for IDL import files"
|
||||
)
|
||||
arg_parser.add_argument(
|
||||
"--install-dir",
|
||||
dest="install_dir",
|
||||
required=True,
|
||||
help="Directory to search for MongoDB binaries",
|
||||
)
|
||||
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
|
||||
arg_parser.add_argument("api_version", metavar="API_VERSION", help="API Version to check")
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
@ -41,10 +41,14 @@ if __name__ == "__main__" and __package__ is None:
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.resmokelib.multiversionconstants import LAST_LTS_FCV, LAST_CONTINUOUS_FCV, LATEST_FCV
|
||||
from buildscripts.resmokelib.multiversionconstants import (
|
||||
LAST_LTS_FCV,
|
||||
LAST_CONTINUOUS_FCV,
|
||||
LATEST_FCV,
|
||||
)
|
||||
# pylint: enable=wrong-import-position
|
||||
|
||||
LOGGER_NAME = 'checkout-idl'
|
||||
LOGGER_NAME = "checkout-idl"
|
||||
LOGGER = logging.getLogger(LOGGER_NAME)
|
||||
|
||||
|
||||
@ -52,9 +56,9 @@ def get_tags() -> List[str]:
|
||||
"""Get a list of git tags that the IDL compatibility script should check against."""
|
||||
|
||||
def gen_versions_and_tags():
|
||||
for tag in check_output(['git', 'tag']).decode().split():
|
||||
for tag in check_output(["git", "tag"]).decode().split():
|
||||
# Releases are like "r5.6.7". Older ones aren't r-prefixed but we don't care about them.
|
||||
if not tag.startswith('r'):
|
||||
if not tag.startswith("r"):
|
||||
continue
|
||||
|
||||
try:
|
||||
@ -110,14 +114,14 @@ def make_idl_directories(tags: List[str], destination: str) -> None:
|
||||
for tag in tags:
|
||||
LOGGER.info("Checking out IDL files in %s", tag)
|
||||
directory = os.path.join(destination, tag)
|
||||
for path in check_output(['git', 'ls-tree', '--name-only', '-r', tag]).decode().split():
|
||||
if not path.endswith('.idl'):
|
||||
for path in check_output(["git", "ls-tree", "--name-only", "-r", tag]).decode().split():
|
||||
if not path.endswith(".idl"):
|
||||
continue
|
||||
|
||||
contents = check_output(['git', 'show', f'{tag}:{path}']).decode()
|
||||
contents = check_output(["git", "show", f"{tag}:{path}"]).decode()
|
||||
output_path = os.path.join(directory, path)
|
||||
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
||||
with open(output_path, 'w+') as fd:
|
||||
with open(output_path, "w+") as fd:
|
||||
fd.write(contents)
|
||||
|
||||
|
||||
@ -125,8 +129,9 @@ def main():
|
||||
"""Run the script."""
|
||||
arg_parser = argparse.ArgumentParser(description=__doc__)
|
||||
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
|
||||
arg_parser.add_argument("destination", metavar="DESTINATION",
|
||||
help="Directory to check out past IDL file versions")
|
||||
arg_parser.add_argument(
|
||||
"destination", metavar="DESTINATION", help="Directory to check out past IDL file versions"
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
@ -38,7 +38,7 @@ from typing import List
|
||||
import yaml
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.idl import lib
|
||||
@ -60,7 +60,7 @@ def get_all_feature_flags(idl_dirs: List[str] = None):
|
||||
# Most IDL files do not contain feature flags.
|
||||
# We can discard these quickly without expensive YAML parsing.
|
||||
with open(idl_path) as idl_file:
|
||||
if 'feature_flags' not in idl_file.read():
|
||||
if "feature_flags" not in idl_file.read():
|
||||
continue
|
||||
with open(idl_path) as idl_file:
|
||||
doc = parser.parse_file(idl_file, idl_path)
|
||||
@ -100,5 +100,5 @@ def main():
|
||||
gen_all_feature_flags_file()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -36,7 +36,7 @@ import sys
|
||||
from typing import List
|
||||
|
||||
# Permit imports from "buildscripts".
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
|
||||
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
|
||||
|
||||
# pylint: disable=wrong-import-position
|
||||
from buildscripts.idl import lib
|
||||
@ -58,7 +58,7 @@ def gen_all_server_params(idl_dirs: List[str] = None):
|
||||
# Most IDL files do not contain server parameters.
|
||||
# We can discard these quickly without expensive YAML parsing.
|
||||
with open(idl_path) as idl_file:
|
||||
if 'server_parameters' not in idl_file.read():
|
||||
if "server_parameters" not in idl_file.read():
|
||||
continue
|
||||
with open(idl_path) as idl_file:
|
||||
doc = parser.parse_file(idl_file, idl_path)
|
||||
@ -80,5 +80,5 @@ def main():
|
||||
gen_all_server_params_file()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -33,6 +33,7 @@ Represents the derived IDL specification after type resolution in the binding pa
|
||||
This is a lossy translation from the IDL Syntax tree as the IDL AST only contains information about
|
||||
the enums and structs that need code generated for them, and just enough information to do that.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
import enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
@ -46,8 +47,9 @@ class IDLBoundSpec(object):
|
||||
def __init__(self, spec, error_collection):
|
||||
# type: (IDLAST, errors.ParserErrorCollection) -> None
|
||||
"""Must specify either an IDL document or errors, not both."""
|
||||
assert (spec is None and error_collection is not None) or (spec is not None
|
||||
and error_collection is None)
|
||||
assert (spec is None and error_collection is not None) or (
|
||||
spec is not None and error_collection is None
|
||||
)
|
||||
self.spec = spec
|
||||
self.errors = error_collection
|
||||
|
||||
@ -292,7 +294,8 @@ class Field(common.SourceLocation):
|
||||
# type: () -> bool
|
||||
"""Returns true if the IDL compiler should add a call to serialization options for this field."""
|
||||
return self.query_shape is not None and self.query_shape in [
|
||||
QueryShapeFieldType.LITERAL, QueryShapeFieldType.ANONYMIZE
|
||||
QueryShapeFieldType.LITERAL,
|
||||
QueryShapeFieldType.ANONYMIZE,
|
||||
]
|
||||
|
||||
@property
|
||||
|
||||
@ -61,8 +61,9 @@ def _validate_single_bson_type(ctxt, idl_type, syntax_type):
|
||||
subtype = "<unknown>"
|
||||
|
||||
if not bson.is_valid_bindata_subtype(subtype):
|
||||
ctxt.add_bad_bson_bindata_subtype_value_error(idl_type, syntax_type, idl_type.name,
|
||||
subtype)
|
||||
ctxt.add_bad_bson_bindata_subtype_value_error(
|
||||
idl_type, syntax_type, idl_type.name, subtype
|
||||
)
|
||||
elif idl_type.bindata_subtype is not None:
|
||||
ctxt.add_bad_bson_bindata_subtype_error(idl_type, syntax_type, idl_type.name, bson_type)
|
||||
|
||||
@ -107,7 +108,7 @@ def _validate_type(ctxt, idl_type):
|
||||
if idl_type.name.startswith("array<"):
|
||||
ctxt.add_array_not_valid_error(idl_type, "type", idl_type.name)
|
||||
|
||||
_validate_type_properties(ctxt, idl_type, 'type')
|
||||
_validate_type_properties(ctxt, idl_type, "type")
|
||||
|
||||
|
||||
def _validate_cpp_type(ctxt, idl_type, syntax_type):
|
||||
@ -120,15 +121,17 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type):
|
||||
ctxt.add_no_string_data_error(idl_type, syntax_type, idl_type.name)
|
||||
|
||||
# We do not support C++ char and float types for style reasons
|
||||
if idl_type.cpp_type in ['char', 'wchar_t', 'char16_t', 'char32_t', 'float']:
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
|
||||
idl_type.cpp_type)
|
||||
if idl_type.cpp_type in ["char", "wchar_t", "char16_t", "char32_t", "float"]:
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(
|
||||
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
|
||||
)
|
||||
|
||||
# We do not support C++ builtin integer for style reasons
|
||||
for numeric_word in ['signed', "unsigned", "int", "long", "short"]:
|
||||
if re.search(r'\b%s\b' % (numeric_word), idl_type.cpp_type):
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
|
||||
idl_type.cpp_type)
|
||||
for numeric_word in ["signed", "unsigned", "int", "long", "short"]:
|
||||
if re.search(r"\b%s\b" % (numeric_word), idl_type.cpp_type):
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(
|
||||
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
|
||||
)
|
||||
# Return early so we only throw one error for types like "signed short int"
|
||||
return
|
||||
|
||||
@ -151,27 +154,39 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type):
|
||||
# Check for std fixed integer types which are not allowed. These are not allowed even if they
|
||||
# have the "std::" prefix.
|
||||
for std_numeric_type in [
|
||||
"int8_t", "int16_t", "int32_t", "int64_t", "uint8_t", "uint16_t", "uint32_t", "uint64_t"
|
||||
"int8_t",
|
||||
"int16_t",
|
||||
"int32_t",
|
||||
"int64_t",
|
||||
"uint8_t",
|
||||
"uint16_t",
|
||||
"uint32_t",
|
||||
"uint64_t",
|
||||
]:
|
||||
if std_numeric_type in idl_type.cpp_type:
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
|
||||
idl_type.cpp_type)
|
||||
ctxt.add_bad_cpp_numeric_type_use_error(
|
||||
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _validate_chain_type_properties(ctxt, idl_type, syntax_type):
|
||||
# type: (errors.ParserContext, Union[syntax.Type, ast.Type], str) -> None
|
||||
"""Validate a chained type has both a deserializer and serializer."""
|
||||
assert len(
|
||||
idl_type.bson_serialization_type) == 1 and idl_type.bson_serialization_type[0] == 'chain'
|
||||
assert (
|
||||
len(idl_type.bson_serialization_type) == 1
|
||||
and idl_type.bson_serialization_type[0] == "chain"
|
||||
)
|
||||
|
||||
if idl_type.deserializer is None:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"deserializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "deserializer"
|
||||
)
|
||||
|
||||
if idl_type.serializer is None:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"serializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "serializer"
|
||||
)
|
||||
|
||||
|
||||
def _validate_type_properties(ctxt, idl_type, syntax_type):
|
||||
@ -189,29 +204,34 @@ def _validate_type_properties(ctxt, idl_type, syntax_type):
|
||||
# serialization for their C++ type. An internal_only type is not associated with BSON
|
||||
# and thus should not have a deserializer defined.
|
||||
if idl_type.deserializer is None and not idl_type.internal_only:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"deserializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "deserializer"
|
||||
)
|
||||
elif bson_type == "chain":
|
||||
_validate_chain_type_properties(ctxt, idl_type, syntax_type)
|
||||
|
||||
elif bson_type == "string":
|
||||
# Strings support custom serialization unlike other non-object scalar types
|
||||
if idl_type.deserializer is None:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"deserializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "deserializer"
|
||||
)
|
||||
|
||||
elif not bson_type in ["array", "object", "bindata"]:
|
||||
if idl_type.deserializer is None:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"deserializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "deserializer"
|
||||
)
|
||||
|
||||
if idl_type.deserializer is not None and "BSONElement" not in idl_type.deserializer:
|
||||
ctxt.add_not_custom_scalar_serialization_not_supported_error(
|
||||
idl_type, syntax_type, idl_type.name, bson_type)
|
||||
idl_type, syntax_type, idl_type.name, bson_type
|
||||
)
|
||||
|
||||
if idl_type.serializer is not None:
|
||||
ctxt.add_not_custom_scalar_serialization_not_supported_error(
|
||||
idl_type, syntax_type, idl_type.name, bson_type)
|
||||
idl_type, syntax_type, idl_type.name, bson_type
|
||||
)
|
||||
|
||||
if bson_type == "bindata" and isinstance(idl_type, syntax.Type) and idl_type.default:
|
||||
ctxt.add_bindata_no_default(idl_type, syntax_type, idl_type.name)
|
||||
@ -219,8 +239,9 @@ def _validate_type_properties(ctxt, idl_type, syntax_type):
|
||||
else:
|
||||
# Now, this is a list of scalar types
|
||||
if idl_type.deserializer is None:
|
||||
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
|
||||
"deserializer")
|
||||
ctxt.add_missing_ast_required_field_error(
|
||||
idl_type, syntax_type, idl_type.name, "deserializer"
|
||||
)
|
||||
|
||||
_validate_cpp_type(ctxt, idl_type, syntax_type)
|
||||
|
||||
@ -251,26 +272,27 @@ def _is_duplicate_field(ctxt, field_container, fields, ast_field):
|
||||
|
||||
def _get_struct_qualified_cpp_name(struct):
|
||||
# type: (syntax.Struct) -> str
|
||||
return common.qualify_cpp_name(struct.cpp_namespace,
|
||||
common.title_case(struct.cpp_name or struct.name))
|
||||
return common.qualify_cpp_name(
|
||||
struct.cpp_namespace, common.title_case(struct.cpp_name or struct.name)
|
||||
)
|
||||
|
||||
|
||||
def _compute_field_is_view(resolved_field, ctxt, symbols):
|
||||
# type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool
|
||||
"""Compute is_view for a symbol referenced by a field."""
|
||||
# Resolved field is an array.
|
||||
if (isinstance(resolved_field, syntax.ArrayType)):
|
||||
if isinstance(resolved_field, syntax.ArrayType):
|
||||
# Inner type needs to be resolved.
|
||||
return _compute_field_is_view(resolved_field.element_type, ctxt, symbols)
|
||||
|
||||
# Resolved field is a variant.
|
||||
elif (isinstance(resolved_field, syntax.VariantType)):
|
||||
elif isinstance(resolved_field, syntax.VariantType):
|
||||
for variant_type in resolved_field.variant_types:
|
||||
# Inner type needs to be resolved.
|
||||
if (_compute_field_is_view(variant_type, ctxt, symbols)):
|
||||
if _compute_field_is_view(variant_type, ctxt, symbols):
|
||||
return True
|
||||
for variant_struct_type in resolved_field.variant_struct_types:
|
||||
if (_compute_is_view(variant_struct_type, ctxt, symbols)):
|
||||
if _compute_is_view(variant_struct_type, ctxt, symbols):
|
||||
return True
|
||||
return False
|
||||
|
||||
@ -282,12 +304,13 @@ def _compute_field_is_view(resolved_field, ctxt, symbols):
|
||||
def _compute_chained_item_is_view(struct, ctxt, symbols, chained_item):
|
||||
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, Union[syntax.ChainedType, syntax.ChainedStruct]) -> bool
|
||||
"""Helper to compute is_view of chained types or structs."""
|
||||
resolved_chained_item = symbols.resolve_type_from_name(ctxt, struct, chained_item.name,
|
||||
chained_item.name)
|
||||
resolved_chained_item = symbols.resolve_type_from_name(
|
||||
ctxt, struct, chained_item.name, chained_item.name
|
||||
)
|
||||
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
|
||||
# We can rely on symbols.resolve_field_type to add errors.
|
||||
if (resolved_chained_item is None):
|
||||
assert (ctxt.errors.has_errors())
|
||||
if resolved_chained_item is None:
|
||||
assert ctxt.errors.has_errors()
|
||||
return True
|
||||
return _compute_is_view(resolved_chained_item, ctxt, symbols)
|
||||
|
||||
@ -296,25 +319,25 @@ def _compute_command_type_is_view(struct, ctxt, symbols, field_type):
|
||||
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, syntax.FieldType) -> bool
|
||||
"""
|
||||
Compute is_view for the command parameter type.
|
||||
|
||||
|
||||
This function is similar to _compute_field_is_view, but because command parameter types are
|
||||
syntax.FieldType instead of syntax.Type, separate logic must exist to resolve the command
|
||||
parameter types.
|
||||
"""
|
||||
if (isinstance(field_type, syntax.FieldTypeVariant)):
|
||||
if isinstance(field_type, syntax.FieldTypeVariant):
|
||||
for variant_type in field_type.variant:
|
||||
if (_compute_command_type_is_view(struct, ctxt, symbols, variant_type)):
|
||||
if _compute_command_type_is_view(struct, ctxt, symbols, variant_type):
|
||||
return True
|
||||
elif (isinstance(field_type, syntax.FieldTypeArray)):
|
||||
elif isinstance(field_type, syntax.FieldTypeArray):
|
||||
return _compute_command_type_is_view(struct, ctxt, symbols, field_type.element_type)
|
||||
elif (isinstance(field_type, syntax.FieldTypeSingle)):
|
||||
elif isinstance(field_type, syntax.FieldTypeSingle):
|
||||
resolved_type = symbols.resolve_field_type(ctxt, struct, field_type.type_name, field_type)
|
||||
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
|
||||
# We can rely on symbols.resolve_field_type to add errors.
|
||||
if (resolved_type is None):
|
||||
assert (ctxt.errors.has_errors())
|
||||
if resolved_type is None:
|
||||
assert ctxt.errors.has_errors()
|
||||
return True
|
||||
if (_compute_field_is_view(resolved_type, ctxt, symbols)):
|
||||
if _compute_field_is_view(resolved_type, ctxt, symbols):
|
||||
return True
|
||||
else:
|
||||
ctxt.add_unknown_command_type_error(struct, struct.name)
|
||||
@ -325,37 +348,38 @@ def _compute_struct_is_view(struct, ctxt, symbols):
|
||||
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable) -> bool
|
||||
"""Compute is_view for structs. A struct is a view if any of its fields are views."""
|
||||
# Empty structs are non view types.
|
||||
if (not struct.fields):
|
||||
if not struct.fields:
|
||||
return False
|
||||
|
||||
for field in struct.fields:
|
||||
if (field.ignore):
|
||||
if field.ignore:
|
||||
continue
|
||||
# Get the resolved field from the global symbol table.
|
||||
resolved_field = symbols.resolve_field_type(ctxt, field, field.name, field.type)
|
||||
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
|
||||
# We can rely on symbols.resolve_field_type to add errors.
|
||||
if (resolved_field is None):
|
||||
assert (ctxt.errors.has_errors())
|
||||
if resolved_field is None:
|
||||
assert ctxt.errors.has_errors()
|
||||
return True
|
||||
# If any field is a view type, then the struct is also a view type.
|
||||
if (_compute_field_is_view(resolved_field, ctxt, symbols)):
|
||||
if _compute_field_is_view(resolved_field, ctxt, symbols):
|
||||
return True
|
||||
|
||||
if (struct.chained_types):
|
||||
if struct.chained_types:
|
||||
for chained_type in struct.chained_types:
|
||||
if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_type)):
|
||||
if _compute_chained_item_is_view(struct, ctxt, symbols, chained_type):
|
||||
return True
|
||||
|
||||
if (struct.chained_structs):
|
||||
if struct.chained_structs:
|
||||
for chained_struct in struct.chained_structs:
|
||||
if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_struct)):
|
||||
if _compute_chained_item_is_view(struct, ctxt, symbols, chained_struct):
|
||||
return True
|
||||
|
||||
# Check command parameter.
|
||||
if (isinstance(struct, syntax.Command)):
|
||||
if (struct.type is not None
|
||||
and _compute_command_type_is_view(struct, ctxt, symbols, struct.type)):
|
||||
if isinstance(struct, syntax.Command):
|
||||
if struct.type is not None and _compute_command_type_is_view(
|
||||
struct, ctxt, symbols, struct.type
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -364,11 +388,11 @@ def _compute_struct_is_view(struct, ctxt, symbols):
|
||||
def _compute_is_view(symbol, ctxt, symbols):
|
||||
# type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool
|
||||
"""Compute is_view for any symbol."""
|
||||
if (isinstance(symbol, syntax.Type)):
|
||||
if isinstance(symbol, syntax.Type):
|
||||
return symbol.is_view
|
||||
elif (isinstance(symbol, syntax.Enum)):
|
||||
elif isinstance(symbol, syntax.Enum):
|
||||
return False
|
||||
elif (isinstance(symbol, syntax.Struct)):
|
||||
elif isinstance(symbol, syntax.Struct):
|
||||
return _compute_struct_is_view(symbol, ctxt, symbols)
|
||||
else:
|
||||
ctxt.add_unknown_symbol_error(symbol, symbol.name)
|
||||
@ -394,11 +418,16 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
|
||||
ast_struct.is_command_reply = struct.is_command_reply
|
||||
ast_struct.is_catalog_ctxt = struct.is_catalog_ctxt
|
||||
ast_struct.query_shape_component = struct.query_shape_component
|
||||
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = struct.unsafe_dangerous_disable_extra_field_duplicate_checks
|
||||
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = (
|
||||
struct.unsafe_dangerous_disable_extra_field_duplicate_checks
|
||||
)
|
||||
ast_struct.is_view = _compute_is_view(struct, ctxt, parsed_spec.symbols)
|
||||
|
||||
# Check that unsafe_dangerous_disable_extra_field_duplicate_checks is used correctly
|
||||
if ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks and ast_struct.strict is True:
|
||||
if (
|
||||
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks
|
||||
and ast_struct.strict is True
|
||||
):
|
||||
ctxt.add_strict_and_disable_check_not_allowed(ast_struct)
|
||||
|
||||
if struct.is_generic_cmd_list:
|
||||
@ -419,8 +448,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
|
||||
|
||||
for chained_type in struct.chained_types:
|
||||
ast_field = _bind_chained_type(ctxt, parsed_spec, ast_struct, chained_type)
|
||||
if ast_field and not _is_duplicate_field(ctxt, chained_type.name, ast_struct.fields,
|
||||
ast_field):
|
||||
if ast_field and not _is_duplicate_field(
|
||||
ctxt, chained_type.name, ast_struct.fields, ast_field
|
||||
):
|
||||
ast_struct.fields.append(ast_field)
|
||||
|
||||
# Merge chained structs as a chained struct and ignored fields
|
||||
@ -443,12 +473,14 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
|
||||
ast_field.generic_field_info = gen_field_info
|
||||
if ast_field.supports_doc_sequence and not isinstance(ast_struct, ast.Command):
|
||||
# Doc sequences are only supported in commands at the moment
|
||||
ctxt.add_bad_struct_field_as_doc_sequence_error(ast_struct, ast_struct.name,
|
||||
ast_field.name)
|
||||
ctxt.add_bad_struct_field_as_doc_sequence_error(
|
||||
ast_struct, ast_struct.name, ast_field.name
|
||||
)
|
||||
|
||||
if ast_field.non_const_getter and struct.immutable:
|
||||
ctxt.add_bad_field_non_const_getter_in_immutable_struct_error(
|
||||
ast_struct, ast_struct.name, ast_field.name)
|
||||
ast_struct, ast_struct.name, ast_field.name
|
||||
)
|
||||
|
||||
if not _is_duplicate_field(ctxt, ast_struct.name, ast_struct.fields, ast_field):
|
||||
ast_struct.fields.append(ast_field)
|
||||
@ -462,10 +494,12 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
|
||||
ctxt.add_must_be_query_shape_component(ast_field, ast_struct.name, ast_field.name)
|
||||
|
||||
if ast_field.query_shape == ast.QueryShapeFieldType.ANONYMIZE and not (
|
||||
ast_field.type.cpp_type in ["std::string", "std::vector<std::string>"]
|
||||
or 'string' in ast_field.type.bson_serialization_type):
|
||||
ctxt.add_query_shape_anonymize_must_be_string(ast_field, ast_field.name,
|
||||
ast_field.type.cpp_type)
|
||||
ast_field.type.cpp_type in ["std::string", "std::vector<std::string>"]
|
||||
or "string" in ast_field.type.bson_serialization_type
|
||||
):
|
||||
ctxt.add_query_shape_anonymize_must_be_string(
|
||||
ast_field, ast_field.name, ast_field.type.cpp_type
|
||||
)
|
||||
|
||||
# Fill out the field comparison_order property as needed
|
||||
if ast_struct.generate_comparison_operators and ast_struct.fields:
|
||||
@ -478,8 +512,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
|
||||
if not ast_field.comparison_order == -1:
|
||||
use_default_order = False
|
||||
if ast_field.comparison_order in comparison_orders:
|
||||
ctxt.add_duplicate_comparison_order_field_error(ast_struct, ast_struct.name,
|
||||
ast_field.comparison_order)
|
||||
ctxt.add_duplicate_comparison_order_field_error(
|
||||
ast_struct, ast_struct.name, ast_field.comparison_order
|
||||
)
|
||||
|
||||
comparison_orders.add(ast_field.comparison_order)
|
||||
|
||||
@ -500,8 +535,9 @@ def _inject_hidden_fields(struct):
|
||||
|
||||
serialization_context_field = syntax.Field(struct.file_name, struct.line, struct.column)
|
||||
serialization_context_field.name = "serialization_context" # This comes from basic_types.idl
|
||||
serialization_context_field.type = syntax.FieldTypeSingle(struct.file_name, struct.line,
|
||||
struct.column)
|
||||
serialization_context_field.type = syntax.FieldTypeSingle(
|
||||
struct.file_name, struct.line, struct.column
|
||||
)
|
||||
serialization_context_field.type.type_name = "serialization_context"
|
||||
serialization_context_field.cpp_name = "serializationContext"
|
||||
serialization_context_field.optional = False
|
||||
@ -606,7 +642,7 @@ def _bind_variant_field(ctxt, ast_field, idl_type):
|
||||
def gen_cpp_types():
|
||||
for alternative in ast_field.type.variant_types:
|
||||
if alternative.is_array:
|
||||
yield f'std::vector<{alternative.cpp_type}>'
|
||||
yield f"std::vector<{alternative.cpp_type}>"
|
||||
else:
|
||||
yield alternative.cpp_type
|
||||
|
||||
@ -638,8 +674,9 @@ def _bind_command_type(ctxt, parsed_spec, command):
|
||||
ctxt.add_array_not_valid_error(ast_field, "field", ast_field.name)
|
||||
|
||||
# Resolve the command type as a field
|
||||
syntax_symbol = parsed_spec.symbols.resolve_field_type(ctxt, command, command.name,
|
||||
command.type)
|
||||
syntax_symbol = parsed_spec.symbols.resolve_field_type(
|
||||
ctxt, command, command.name, command.type
|
||||
)
|
||||
if syntax_symbol is None:
|
||||
return None
|
||||
|
||||
@ -649,8 +686,9 @@ def _bind_command_type(ctxt, parsed_spec, command):
|
||||
|
||||
assert not isinstance(syntax_symbol, syntax.Enum)
|
||||
|
||||
base_type = (syntax_symbol.element_type
|
||||
if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol)
|
||||
base_type = (
|
||||
syntax_symbol.element_type if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol
|
||||
)
|
||||
|
||||
# Copy over only the needed information if this is a struct or a type.
|
||||
if isinstance(base_type, syntax.Struct):
|
||||
@ -683,8 +721,9 @@ def _bind_command_reply_type(ctxt, parsed_spec, command):
|
||||
ast_field.description = f"{command.name} reply type"
|
||||
|
||||
# Resolve the command type as a field
|
||||
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, command, command.name,
|
||||
command.reply_type)
|
||||
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
|
||||
ctxt, command, command.name, command.reply_type
|
||||
)
|
||||
|
||||
if syntax_symbol is None:
|
||||
# Resolution failed, we've recorded an error.
|
||||
@ -714,8 +753,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value):
|
||||
# type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, str, str) -> str
|
||||
|
||||
# Look up the enum for "enum_name" in the symbol table
|
||||
access_check_enum = parsed_spec.symbols.resolve_type_from_name(ctxt, location, "access_check",
|
||||
enum_name)
|
||||
access_check_enum = parsed_spec.symbols.resolve_type_from_name(
|
||||
ctxt, location, "access_check", enum_name
|
||||
)
|
||||
|
||||
if access_check_enum is None:
|
||||
# Resolution failed, we've recorded an error.
|
||||
@ -725,8 +765,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value):
|
||||
ctxt.add_unknown_type_error(location, enum_name, "enum")
|
||||
return None
|
||||
|
||||
syntax_enum = resolve_enum_value(ctxt, location, cast(syntax.Enum, access_check_enum),
|
||||
enum_value)
|
||||
syntax_enum = resolve_enum_value(
|
||||
ctxt, location, cast(syntax.Enum, access_check_enum), enum_value
|
||||
)
|
||||
if not syntax_enum:
|
||||
return None
|
||||
|
||||
@ -737,22 +778,25 @@ def _bind_single_check(ctxt, parsed_spec, access_check):
|
||||
# type: (errors.ParserContext, syntax.IDLSpec, syntax.AccessCheck) -> ast.AccessCheck
|
||||
"""Bind a single access_check."""
|
||||
|
||||
ast_access_check = ast.AccessCheck(access_check.file_name, access_check.line,
|
||||
access_check.column)
|
||||
ast_access_check = ast.AccessCheck(
|
||||
access_check.file_name, access_check.line, access_check.column
|
||||
)
|
||||
|
||||
assert bool(access_check.check) != bool(access_check.privilege)
|
||||
|
||||
if access_check.check:
|
||||
ast_access_check.check = _bind_enum_value(ctxt, parsed_spec, access_check, "AccessCheck",
|
||||
access_check.check)
|
||||
ast_access_check.check = _bind_enum_value(
|
||||
ctxt, parsed_spec, access_check, "AccessCheck", access_check.check
|
||||
)
|
||||
if not ast_access_check.check:
|
||||
return None
|
||||
else:
|
||||
privilege = access_check.privilege
|
||||
ast_privilege = ast.Privilege(privilege.file_name, privilege.line, privilege.column)
|
||||
|
||||
ast_privilege.resource_pattern = _bind_enum_value(ctxt, parsed_spec, privilege, "MatchType",
|
||||
privilege.resource_pattern)
|
||||
ast_privilege.resource_pattern = _bind_enum_value(
|
||||
ctxt, parsed_spec, privilege, "MatchType", privilege.resource_pattern
|
||||
)
|
||||
if not ast_privilege.resource_pattern:
|
||||
return None
|
||||
|
||||
@ -935,7 +979,7 @@ def _validate_variant_type(ctxt, syntax_symbol, field):
|
||||
|
||||
for type_name, count in array_type_count.items():
|
||||
if count > 1:
|
||||
ctxt.add_variant_duplicate_types_error(syntax_symbol, field.name, f'array<{type_name}>')
|
||||
ctxt.add_variant_duplicate_types_error(syntax_symbol, field.name, f"array<{type_name}>")
|
||||
|
||||
types = len(syntax_symbol.variant_types) + len(syntax_symbol.variant_struct_types)
|
||||
if types < 2:
|
||||
@ -958,14 +1002,14 @@ def _validate_field_properties(ctxt, ast_field):
|
||||
if ast_field.optional:
|
||||
ctxt.add_bad_field_default_and_optional(ast_field, ast_field.name)
|
||||
|
||||
if ast_field.type.bson_serialization_type == ['bindata']:
|
||||
if ast_field.type.bson_serialization_type == ["bindata"]:
|
||||
ctxt.add_bindata_no_default(ast_field, ast_field.type.name, ast_field.name)
|
||||
|
||||
if ast_field.always_serialize and not ast_field.optional:
|
||||
ctxt.add_bad_field_always_serialize_not_optional(ast_field, ast_field.name)
|
||||
|
||||
# A "chain" type should never appear as a field.
|
||||
if ast_field.type.bson_serialization_type == ['chain']:
|
||||
if ast_field.type.bson_serialization_type == ["chain"]:
|
||||
ctxt.add_bad_array_of_chain(ast_field, ast_field.name)
|
||||
|
||||
|
||||
@ -979,7 +1023,7 @@ def _validate_doc_sequence_field(ctxt, ast_field):
|
||||
|
||||
# The only allowed BSON type for a doc_sequence field is "object"
|
||||
for serialization_type in ast_field.type.bson_serialization_type:
|
||||
if serialization_type != 'object':
|
||||
if serialization_type != "object":
|
||||
ctxt.add_bad_non_object_as_doc_sequence_error(ast_field, ast_field.name)
|
||||
|
||||
|
||||
@ -991,7 +1035,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name):
|
||||
return cpp_method_name
|
||||
|
||||
# Global function
|
||||
if cpp_method_name.startswith('::'):
|
||||
if cpp_method_name.startswith("::"):
|
||||
return cpp_method_name
|
||||
|
||||
# Method is full qualified already
|
||||
@ -1003,7 +1047,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name):
|
||||
|
||||
# Method is prefixed with just the type name
|
||||
if cpp_method_name.startswith(type_name):
|
||||
return '::'.join(cpp_type_name.split('::')[0:-1]) + "::" + cpp_method_name
|
||||
return "::".join(cpp_type_name.split("::")[0:-1]) + "::" + cpp_method_name
|
||||
|
||||
return cpp_method_name
|
||||
|
||||
@ -1046,9 +1090,9 @@ def _bind_expression(expr, allow_literal_string=True):
|
||||
# std::string
|
||||
if allow_literal_string:
|
||||
strval = expr.literal
|
||||
for i in ['\\', '"', "'"]:
|
||||
for i in ["\\", '"', "'"]:
|
||||
if i in strval:
|
||||
strval = strval.replace(i, '\\' + i)
|
||||
strval = strval.replace(i, "\\" + i)
|
||||
node.expr = '"' + strval + '"'
|
||||
return node
|
||||
|
||||
@ -1093,11 +1137,11 @@ def _bind_condition(condition, condition_for):
|
||||
ast_condition.preprocessor = condition.preprocessor
|
||||
|
||||
if condition.feature_flag:
|
||||
assert condition_for == 'server_parameter'
|
||||
assert condition_for == "server_parameter"
|
||||
ast_condition.feature_flag = condition.feature_flag
|
||||
|
||||
if condition.min_fcv:
|
||||
assert condition_for == 'server_parameter'
|
||||
assert condition_for == "server_parameter"
|
||||
ast_condition.min_fcv = condition.min_fcv
|
||||
|
||||
return ast_condition
|
||||
@ -1186,12 +1230,14 @@ def _bind_field(ctxt, parsed_spec, field):
|
||||
_validate_array_type(ctxt, cast(syntax.ArrayType, syntax_symbol), field)
|
||||
elif field.supports_doc_sequence:
|
||||
# Doc sequences are only supported for arrays
|
||||
ctxt.add_bad_non_array_as_doc_sequence_error(syntax_symbol, syntax_symbol.name,
|
||||
ast_field.name)
|
||||
ctxt.add_bad_non_array_as_doc_sequence_error(
|
||||
syntax_symbol, syntax_symbol.name, ast_field.name
|
||||
)
|
||||
return None
|
||||
|
||||
base_type = (syntax_symbol.element_type
|
||||
if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol)
|
||||
base_type = (
|
||||
syntax_symbol.element_type if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol
|
||||
)
|
||||
|
||||
# Copy over only the needed information if this is a struct or a type.
|
||||
|
||||
@ -1244,8 +1290,9 @@ def _bind_field(ctxt, parsed_spec, field):
|
||||
def _bind_chained_type(ctxt, parsed_spec, location, chained_type):
|
||||
# type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, syntax.ChainedType) -> ast.Field
|
||||
"""Bind the specified chained type."""
|
||||
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, location, chained_type.name,
|
||||
chained_type.name)
|
||||
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
|
||||
ctxt, location, chained_type.name, chained_type.name
|
||||
)
|
||||
if not syntax_symbol:
|
||||
return None
|
||||
|
||||
@ -1255,9 +1302,10 @@ def _bind_chained_type(ctxt, parsed_spec, location, chained_type):
|
||||
|
||||
idltype = cast(syntax.Type, syntax_symbol)
|
||||
|
||||
if len(idltype.bson_serialization_type) != 1 or idltype.bson_serialization_type[0] != 'chain':
|
||||
ctxt.add_chained_type_wrong_type_error(location, chained_type.name,
|
||||
idltype.bson_serialization_type[0])
|
||||
if len(idltype.bson_serialization_type) != 1 or idltype.bson_serialization_type[0] != "chain":
|
||||
ctxt.add_chained_type_wrong_type_error(
|
||||
location, chained_type.name, idltype.bson_serialization_type[0]
|
||||
)
|
||||
return None
|
||||
|
||||
ast_field = ast.Field(location.file_name, location.line, location.column)
|
||||
@ -1274,7 +1322,8 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
|
||||
# type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct) -> None
|
||||
"""Bind the specified chained struct."""
|
||||
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
|
||||
ctxt, ast_struct, chained_struct.name, chained_struct.name)
|
||||
ctxt, ast_struct, chained_struct.name, chained_struct.name
|
||||
)
|
||||
|
||||
if not syntax_symbol:
|
||||
return
|
||||
@ -1287,12 +1336,14 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
|
||||
|
||||
# chained struct cannot be strict unless it is inlined
|
||||
if struct.strict and not ast_struct.inline_chained_structs:
|
||||
ctxt.add_chained_nested_struct_no_strict_error(ast_struct, ast_struct.name,
|
||||
chained_struct.name)
|
||||
ctxt.add_chained_nested_struct_no_strict_error(
|
||||
ast_struct, ast_struct.name, chained_struct.name
|
||||
)
|
||||
|
||||
if struct.chained_types or struct.chained_structs:
|
||||
ctxt.add_chained_nested_struct_no_nested_error(ast_struct, ast_struct.name,
|
||||
chained_struct.name)
|
||||
ctxt.add_chained_nested_struct_no_nested_error(
|
||||
ast_struct, ast_struct.name, chained_struct.name
|
||||
)
|
||||
|
||||
# Configure a field for the chained struct.
|
||||
ast_chained_field = ast.Field(ast_struct.file_name, ast_struct.line, ast_struct.column)
|
||||
@ -1313,9 +1364,9 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
|
||||
# Don't use internal fields in chained types, stick to local access only
|
||||
if ast_field.type.internal_only:
|
||||
continue
|
||||
if ast_field and not _is_duplicate_field(ctxt, chained_struct.name, ast_struct.fields,
|
||||
ast_field):
|
||||
|
||||
if ast_field and not _is_duplicate_field(
|
||||
ctxt, chained_struct.name, ast_struct.fields, ast_field
|
||||
):
|
||||
if ast_struct.inline_chained_structs:
|
||||
ast_field.chained_struct_field = ast_chained_field
|
||||
else:
|
||||
@ -1329,8 +1380,9 @@ def _bind_globals(ctxt, parsed_spec):
|
||||
# type: (errors.ParserContext, syntax.IDLSpec) -> ast.Global
|
||||
"""Bind the globals object from the idl.syntax tree into the idl.ast tree by doing a deep copy."""
|
||||
if parsed_spec.globals:
|
||||
ast_global = ast.Global(parsed_spec.globals.file_name, parsed_spec.globals.line,
|
||||
parsed_spec.globals.column)
|
||||
ast_global = ast.Global(
|
||||
parsed_spec.globals.file_name, parsed_spec.globals.line, parsed_spec.globals.column
|
||||
)
|
||||
ast_global.cpp_namespace = parsed_spec.globals.cpp_namespace
|
||||
ast_global.cpp_includes = parsed_spec.globals.cpp_includes
|
||||
|
||||
@ -1345,7 +1397,8 @@ def _bind_globals(ctxt, parsed_spec):
|
||||
init = configs.initializer
|
||||
|
||||
ast_global.configs.initializer = ast.GlobalInitializer(
|
||||
init.file_name, init.line, init.column)
|
||||
init.file_name, init.line, init.column
|
||||
)
|
||||
# Parser rule makes it impossible to have both name and register/store.
|
||||
ast_global.configs.initializer.name = init.name
|
||||
ast_global.configs.initializer.register = init.register
|
||||
@ -1371,8 +1424,9 @@ def _validate_enum_int(ctxt, idl_enum):
|
||||
try:
|
||||
int_values_set.add(int(enum_value.value))
|
||||
except ValueError as value_error:
|
||||
ctxt.add_enum_value_not_int_error(idl_enum, idl_enum.name, enum_value.value,
|
||||
str(value_error))
|
||||
ctxt.add_enum_value_not_int_error(
|
||||
idl_enum, idl_enum.name, enum_value.value, str(value_error)
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
@ -1412,7 +1466,7 @@ def _bind_enum(ctxt, idl_enum):
|
||||
if len(idl_enum.values) != len(values_set):
|
||||
ctxt.add_enum_value_not_unique_error(idl_enum, idl_enum.name)
|
||||
|
||||
if ast_enum.type == 'int':
|
||||
if ast_enum.type == "int":
|
||||
_validate_enum_int(ctxt, idl_enum)
|
||||
|
||||
return ast_enum
|
||||
@ -1423,9 +1477,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
|
||||
"""Bind and validate ServerParameter attributes specific to specialized ServerParameters."""
|
||||
|
||||
# Fields specific to bound and unbound standard params.
|
||||
for field in ['cpp_vartype', 'cpp_varname', 'on_update', 'validator']:
|
||||
for field in ["cpp_vartype", "cpp_varname", "on_update", "validator"]:
|
||||
if getattr(param, field) is not None:
|
||||
ctxt.add_server_parameter_invalid_attr(param, field, 'specialized')
|
||||
ctxt.add_server_parameter_invalid_attr(param, field, "specialized")
|
||||
return None
|
||||
|
||||
# Fields specific to specialized stroage.
|
||||
@ -1433,8 +1487,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
|
||||
|
||||
if param.default is not None:
|
||||
if not param.default.is_constexpr:
|
||||
ctxt.add_server_parameter_invalid_attr(param, 'default.is_constexpr=false',
|
||||
'specialized')
|
||||
ctxt.add_server_parameter_invalid_attr(
|
||||
param, "default.is_constexpr=false", "specialized"
|
||||
)
|
||||
return None
|
||||
|
||||
ast_param.default = _bind_expression(param.default)
|
||||
@ -1448,7 +1503,7 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
|
||||
ast_param.cpp_class.override_validate = cls.override_validate
|
||||
|
||||
# If set_at is cluster, then set must be overridden. Otherwise, use the parsed value.
|
||||
ast_param.cpp_class.override_set = True if param.set_at == ['cluster'] else cls.override_set
|
||||
ast_param.cpp_class.override_set = True if param.set_at == ["cluster"] else cls.override_set
|
||||
|
||||
return ast_param
|
||||
|
||||
@ -1458,13 +1513,13 @@ def _bind_server_parameter_with_storage(ctxt, ast_param, param):
|
||||
"""Bind and validate ServerParameter attributes specific to bound ServerParameters."""
|
||||
|
||||
# Fields specific to specialized and unbound standard params.
|
||||
for field in ['cpp_class']:
|
||||
for field in ["cpp_class"]:
|
||||
if getattr(param, field) is not None:
|
||||
ctxt.add_server_parameter_invalid_attr(param, field, 'bound')
|
||||
ctxt.add_server_parameter_invalid_attr(param, field, "bound")
|
||||
return None
|
||||
|
||||
if param.set_at == ['cluster']:
|
||||
ast_param.cpp_vartype = f'TenantIdMap<{param.cpp_vartype}>'
|
||||
if param.set_at == ["cluster"]:
|
||||
ast_param.cpp_vartype = f"TenantIdMap<{param.cpp_vartype}>"
|
||||
else:
|
||||
ast_param.cpp_vartype = param.cpp_vartype
|
||||
ast_param.cpp_varname = param.cpp_varname
|
||||
@ -1487,20 +1542,20 @@ def _bind_server_parameter_set_at(ctxt, param):
|
||||
# type: (errors.ParserContext, syntax.ServerParameter) -> str
|
||||
"""Translate set_at options to C++ enum value."""
|
||||
|
||||
if param.set_at == ['readonly']:
|
||||
if param.set_at == ["readonly"]:
|
||||
# Readonly may not be mixed with startup or runtime
|
||||
return "ServerParameterType::kReadOnly"
|
||||
|
||||
if param.set_at == ['cluster']:
|
||||
if param.set_at == ["cluster"]:
|
||||
# Cluster-wide parameters may not be mixed with startup or runtime.
|
||||
# They are implicitly runtime-only.
|
||||
return "ServerParameterType::kClusterWide"
|
||||
|
||||
set_at = 0
|
||||
for psa in param.set_at:
|
||||
if psa.lower() == 'startup':
|
||||
if psa.lower() == "startup":
|
||||
set_at |= 1
|
||||
elif psa.lower() == 'runtime':
|
||||
elif psa.lower() == "runtime":
|
||||
set_at |= 2
|
||||
else:
|
||||
ctxt.add_bad_setat_specifier(param, psa)
|
||||
@ -1516,7 +1571,7 @@ def _bind_server_parameter_set_at(ctxt, param):
|
||||
return mask_to_text[set_at]
|
||||
|
||||
# Can't happen based on above logic.
|
||||
ctxt.add_bad_setat_specifier(param, ','.join(param.set_at))
|
||||
ctxt.add_bad_setat_specifier(param, ",".join(param.set_at))
|
||||
return None
|
||||
|
||||
|
||||
@ -1526,19 +1581,19 @@ def _bind_server_parameter(ctxt, param):
|
||||
ast_param = ast.ServerParameter(param.file_name, param.line, param.column)
|
||||
ast_param.name = param.name
|
||||
ast_param.description = param.description
|
||||
ast_param.condition = _bind_condition(param.condition, condition_for='server_parameter')
|
||||
ast_param.condition = _bind_condition(param.condition, condition_for="server_parameter")
|
||||
ast_param.redact = param.redact
|
||||
ast_param.test_only = param.test_only
|
||||
ast_param.deprecated_name = param.deprecated_name
|
||||
|
||||
# The omit_in_ftdc flag can only be enabled for cluster parameters.
|
||||
if param.omit_in_ftdc is not None and param.set_at != ['cluster']:
|
||||
ctxt.add_server_parameter_invalid_attr(param, 'omit_in_ftdc=True', ''.join(param.set_at))
|
||||
if param.omit_in_ftdc is not None and param.set_at != ["cluster"]:
|
||||
ctxt.add_server_parameter_invalid_attr(param, "omit_in_ftdc=True", "".join(param.set_at))
|
||||
return None
|
||||
|
||||
# If omit_in_ftdc is None (it has not been set) for a cluster parameter, then emit an error.
|
||||
if param.omit_in_ftdc is None and param.set_at == ['cluster']:
|
||||
ctxt.add_server_parameter_required_attr(param, 'omit_in_ftdc', 'cluster')
|
||||
if param.omit_in_ftdc is None and param.set_at == ["cluster"]:
|
||||
ctxt.add_server_parameter_required_attr(param, "omit_in_ftdc", "cluster")
|
||||
|
||||
ast_param.omit_in_ftdc = param.omit_in_ftdc
|
||||
|
||||
@ -1551,7 +1606,7 @@ def _bind_server_parameter(ctxt, param):
|
||||
elif param.cpp_varname:
|
||||
return _bind_server_parameter_with_storage(ctxt, ast_param, param)
|
||||
else:
|
||||
ctxt.add_server_parameter_required_attr(param, 'cpp_varname', 'server_parameter')
|
||||
ctxt.add_server_parameter_required_attr(param, "cpp_varname", "server_parameter")
|
||||
return None
|
||||
|
||||
|
||||
@ -1572,7 +1627,11 @@ def _bind_feature_flags(ctxt, param):
|
||||
return None
|
||||
|
||||
# Feature flags that default to true and should be FCV gated are required to have a version
|
||||
if param.default.literal == "true" and param.shouldBeFCVGated.literal == "true" and not param.version:
|
||||
if (
|
||||
param.default.literal == "true"
|
||||
and param.shouldBeFCVGated.literal == "true"
|
||||
and not param.version
|
||||
):
|
||||
ctxt.add_feature_flag_default_true_missing_version(param)
|
||||
return None
|
||||
|
||||
@ -1582,9 +1641,11 @@ def _bind_feature_flags(ctxt, param):
|
||||
return None
|
||||
|
||||
expr = syntax.Expression(param.default.file_name, param.default.line, param.default.column)
|
||||
expr.expr = '%s, "%s"_sd, %s' % (param.default.literal, param.version if
|
||||
(param.shouldBeFCVGated.literal == "true"
|
||||
and param.version) else '', param.shouldBeFCVGated.literal)
|
||||
expr.expr = '%s, "%s"_sd, %s' % (
|
||||
param.default.literal,
|
||||
param.version if (param.shouldBeFCVGated.literal == "true" and param.version) else "",
|
||||
param.shouldBeFCVGated.literal,
|
||||
)
|
||||
|
||||
ast_param.default = _bind_expression(expr)
|
||||
ast_param.default.export = False
|
||||
@ -1597,7 +1658,7 @@ def _bind_feature_flags(ctxt, param):
|
||||
def _is_invalid_config_short_name(name):
|
||||
# type: (str) -> bool
|
||||
"""Check if a given name is valid as a short name."""
|
||||
return ('.' in name) or (',' in name)
|
||||
return ("." in name) or ("," in name)
|
||||
|
||||
|
||||
def _parse_config_option_sources(source_list):
|
||||
@ -1635,7 +1696,7 @@ def _bind_config_option(ctxt, globals_spec, option):
|
||||
|
||||
node = ast.ConfigOption(option.file_name, option.line, option.column)
|
||||
|
||||
if _is_invalid_config_short_name(option.short_name or ''):
|
||||
if _is_invalid_config_short_name(option.short_name or ""):
|
||||
ctxt.add_invalid_short_name(option, option.short_name)
|
||||
return None
|
||||
|
||||
@ -1664,13 +1725,13 @@ def _bind_config_option(ctxt, globals_spec, option):
|
||||
ctxt.add_missing_short_name_with_single_name(option, option.single_name)
|
||||
return None
|
||||
|
||||
node.short_name = node.short_name + ',' + option.single_name
|
||||
node.short_name = node.short_name + "," + option.single_name
|
||||
|
||||
node.description = _bind_expression(option.description)
|
||||
node.arg_vartype = option.arg_vartype
|
||||
node.cpp_vartype = option.cpp_vartype
|
||||
node.cpp_varname = option.cpp_varname
|
||||
node.condition = _bind_condition(option.condition, condition_for='config')
|
||||
node.condition = _bind_condition(option.condition, condition_for="config")
|
||||
|
||||
node.requires = option.requires
|
||||
node.conflicts = option.conflicts
|
||||
@ -1694,7 +1755,7 @@ def _bind_config_option(ctxt, globals_spec, option):
|
||||
|
||||
node.source = _parse_config_option_sources(source_list)
|
||||
if node.source is None:
|
||||
ctxt.add_bad_source_specifier(option, ', '.join(source_list))
|
||||
ctxt.add_bad_source_specifier(option, ", ".join(source_list))
|
||||
return None
|
||||
|
||||
if option.duplicate_behavior:
|
||||
@ -1710,17 +1771,17 @@ def _bind_config_option(ctxt, globals_spec, option):
|
||||
return None
|
||||
|
||||
# Parse single digit, closed range, or open range of digits.
|
||||
spread = option.positional.split('-')
|
||||
spread = option.positional.split("-")
|
||||
if len(spread) == 1:
|
||||
# Make a single number behave like a range of that number, (e.g. "2" -> "2-2").
|
||||
spread.append(spread[0])
|
||||
if (len(spread) != 2) or ((spread[0] == "") and (spread[1] == "")):
|
||||
ctxt.add_bad_numeric_range(option, 'positional', option.positional)
|
||||
ctxt.add_bad_numeric_range(option, "positional", option.positional)
|
||||
try:
|
||||
node.positional_start = int(spread[0] or "-1")
|
||||
node.positional_end = int(spread[1] or "-1")
|
||||
except ValueError:
|
||||
ctxt.add_bad_numeric_range(option, 'positional', option.positional)
|
||||
ctxt.add_bad_numeric_range(option, "positional", option.positional)
|
||||
return None
|
||||
|
||||
if option.validator is not None:
|
||||
|
||||
@ -37,43 +37,43 @@ from typing import Dict, List
|
||||
# scalar: True if the type is not an array or object
|
||||
# bson_type_enum: The BSONType enum value for the given type
|
||||
_BSON_TYPE_INFORMATION = {
|
||||
"double": {'scalar': True, 'bson_type_enum': 'NumberDouble'},
|
||||
"string": {'scalar': True, 'bson_type_enum': 'String'},
|
||||
"object": {'scalar': False, 'bson_type_enum': 'Object'},
|
||||
"array": {'scalar': False, 'bson_type_enum': 'Array'},
|
||||
"bindata": {'scalar': True, 'bson_type_enum': 'BinData'},
|
||||
"undefined": {'scalar': True, 'bson_type_enum': 'Undefined'},
|
||||
"objectid": {'scalar': True, 'bson_type_enum': 'jstOID'},
|
||||
"bool": {'scalar': True, 'bson_type_enum': 'Bool'},
|
||||
"date": {'scalar': True, 'bson_type_enum': 'Date'},
|
||||
"null": {'scalar': True, 'bson_type_enum': 'jstNULL'},
|
||||
"regex": {'scalar': True, 'bson_type_enum': 'RegEx'},
|
||||
"int": {'scalar': True, 'bson_type_enum': 'NumberInt'},
|
||||
"timestamp": {'scalar': True, 'bson_type_enum': 'bsonTimestamp'},
|
||||
"long": {'scalar': True, 'bson_type_enum': 'NumberLong'},
|
||||
"decimal": {'scalar': True, 'bson_type_enum': 'NumberDecimal'},
|
||||
"double": {"scalar": True, "bson_type_enum": "NumberDouble"},
|
||||
"string": {"scalar": True, "bson_type_enum": "String"},
|
||||
"object": {"scalar": False, "bson_type_enum": "Object"},
|
||||
"array": {"scalar": False, "bson_type_enum": "Array"},
|
||||
"bindata": {"scalar": True, "bson_type_enum": "BinData"},
|
||||
"undefined": {"scalar": True, "bson_type_enum": "Undefined"},
|
||||
"objectid": {"scalar": True, "bson_type_enum": "jstOID"},
|
||||
"bool": {"scalar": True, "bson_type_enum": "Bool"},
|
||||
"date": {"scalar": True, "bson_type_enum": "Date"},
|
||||
"null": {"scalar": True, "bson_type_enum": "jstNULL"},
|
||||
"regex": {"scalar": True, "bson_type_enum": "RegEx"},
|
||||
"int": {"scalar": True, "bson_type_enum": "NumberInt"},
|
||||
"timestamp": {"scalar": True, "bson_type_enum": "bsonTimestamp"},
|
||||
"long": {"scalar": True, "bson_type_enum": "NumberLong"},
|
||||
"decimal": {"scalar": True, "bson_type_enum": "NumberDecimal"},
|
||||
}
|
||||
|
||||
# Dictionary of BinData subtype type Information
|
||||
# scalar: True if the type is not an array or object
|
||||
# bindata_enum: The BinDataType enum value for the given type
|
||||
_BINDATA_SUBTYPE = {
|
||||
"generic": {'scalar': True, 'bindata_enum': 'BinDataGeneral'},
|
||||
"function": {'scalar': True, 'bindata_enum': 'Function'},
|
||||
"generic": {"scalar": True, "bindata_enum": "BinDataGeneral"},
|
||||
"function": {"scalar": True, "bindata_enum": "Function"},
|
||||
# Also simply known as type 2, deprecated, and requires special handling
|
||||
#"binary": {
|
||||
# "binary": {
|
||||
# 'scalar': False,
|
||||
# 'bindata_enum': 'ByteArrayDeprecated'
|
||||
#},
|
||||
# },
|
||||
# Deprecated
|
||||
# "uuid_old": {
|
||||
# 'scalar': False,
|
||||
# 'bindata_enum': 'bdtUUID'
|
||||
# },
|
||||
"uuid": {'scalar': True, 'bindata_enum': 'newUUID'},
|
||||
"md5": {'scalar': True, 'bindata_enum': 'MD5Type'},
|
||||
"encrypt": {'scalar': True, 'bindata_enum': 'Encrypt'},
|
||||
"sensitive": {'scalar': True, 'bindata_enum': 'Sensitive'},
|
||||
"uuid": {"scalar": True, "bindata_enum": "newUUID"},
|
||||
"md5": {"scalar": True, "bindata_enum": "MD5Type"},
|
||||
"encrypt": {"scalar": True, "bindata_enum": "Encrypt"},
|
||||
"sensitive": {"scalar": True, "bindata_enum": "Sensitive"},
|
||||
}
|
||||
|
||||
|
||||
@ -87,14 +87,14 @@ def is_scalar_bson_type(name):
|
||||
# type: (str) -> bool
|
||||
"""Return True if this bson type is a scalar."""
|
||||
assert is_valid_bson_type(name)
|
||||
return _BSON_TYPE_INFORMATION[name]['scalar'] # type: ignore
|
||||
return _BSON_TYPE_INFORMATION[name]["scalar"] # type: ignore
|
||||
|
||||
|
||||
def cpp_bson_type_name(name):
|
||||
# type: (str) -> str
|
||||
"""Return the C++ type name for a bson type."""
|
||||
assert is_valid_bson_type(name)
|
||||
return _BSON_TYPE_INFORMATION[name]['bson_type_enum'] # type: ignore
|
||||
return _BSON_TYPE_INFORMATION[name]["bson_type_enum"] # type: ignore
|
||||
|
||||
|
||||
def list_valid_types():
|
||||
@ -113,4 +113,4 @@ def cpp_bindata_subtype_type_name(name):
|
||||
# type: (str) -> str
|
||||
"""Return the C++ type name for a bindata subtype."""
|
||||
assert is_valid_bindata_subtype(name)
|
||||
return _BINDATA_SUBTYPE[name]['bindata_enum'] # type: ignore
|
||||
return _BINDATA_SUBTYPE[name]["bindata_enum"] # type: ignore
|
||||
|
||||
@ -48,7 +48,7 @@ def title_case(name):
|
||||
# Only capitalize the last part of a fully-qualified name
|
||||
pos = name.rfind("::")
|
||||
if pos > -1:
|
||||
return name[:pos + 2] + name[pos + 2:pos + 3].upper() + name[pos + 3:]
|
||||
return name[: pos + 2] + name[pos + 2 : pos + 3].upper() + name[pos + 3 :]
|
||||
|
||||
return name[0:1].upper() + name[1:]
|
||||
|
||||
@ -72,9 +72,9 @@ def _escape_template_string(template):
|
||||
# type: (str) -> str
|
||||
"""Escape the '$' in template strings unless followed by '{'."""
|
||||
# See https://docs.python.org/2/library/string.html#template-strings
|
||||
template = template.replace('${', '#{')
|
||||
template = template.replace('$', '$$')
|
||||
return template.replace('#{', '${')
|
||||
template = template.replace("${", "#{")
|
||||
template = template.replace("$", "$$")
|
||||
return template.replace("#{", "${")
|
||||
|
||||
|
||||
def template_format(template, template_params=None):
|
||||
|
||||
@ -80,26 +80,42 @@ class CompilerImportResolver(parser.ImportResolverBase):
|
||||
logging.debug("Resolving imported file '%s' for file '%s'", imported_file_name, base_file)
|
||||
|
||||
# Check for fully-qualified paths
|
||||
logging.debug("Checking for imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, imported_file_name)
|
||||
logging.debug(
|
||||
"Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
imported_file_name,
|
||||
)
|
||||
if os.path.isabs(imported_file_name) and os.path.exists(imported_file_name):
|
||||
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, imported_file_name)
|
||||
logging.debug(
|
||||
"Found imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
imported_file_name,
|
||||
)
|
||||
return imported_file_name
|
||||
|
||||
for candidate_dir in self._import_directories or []:
|
||||
base_dir = os.path.abspath(candidate_dir)
|
||||
resolved_file_name = os.path.normpath(os.path.join(base_dir, imported_file_name))
|
||||
|
||||
logging.debug("Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name, base_file, resolved_file_name)
|
||||
logging.debug(
|
||||
"Checking for imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
resolved_file_name,
|
||||
)
|
||||
|
||||
if os.path.exists(resolved_file_name):
|
||||
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
|
||||
base_file, resolved_file_name)
|
||||
logging.debug(
|
||||
"Found imported file '%s' for file '%s' at '%s'",
|
||||
imported_file_name,
|
||||
base_file,
|
||||
resolved_file_name,
|
||||
)
|
||||
return resolved_file_name
|
||||
|
||||
msg = ("Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file))
|
||||
msg = "Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file)
|
||||
logging.error(msg)
|
||||
|
||||
raise errors.IDLError(msg)
|
||||
@ -107,7 +123,7 @@ class CompilerImportResolver(parser.ImportResolverBase):
|
||||
def open(self, resolved_file_name):
|
||||
# type: (str) -> Any
|
||||
"""Return an io.Stream for the requested file."""
|
||||
return io.open(resolved_file_name, encoding='utf-8')
|
||||
return io.open(resolved_file_name, encoding="utf-8")
|
||||
|
||||
|
||||
def _write_dependencies(spec, write_dependencies_inline):
|
||||
@ -138,7 +154,8 @@ def _update_import_includes(args, spec, header_file_name):
|
||||
|
||||
if args.output_base_dir:
|
||||
base_include_h_file_name = os.path.relpath(
|
||||
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir))
|
||||
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir)
|
||||
)
|
||||
else:
|
||||
base_include_h_file_name = os.path.abspath(header_file_name)
|
||||
|
||||
@ -149,7 +166,7 @@ def _update_import_includes(args, spec, header_file_name):
|
||||
if not spec.globals:
|
||||
spec.globals = syntax.Global(args.input_file, -1, -1)
|
||||
|
||||
first_dir = base_include_h_file_name.split('/')[0]
|
||||
first_dir = base_include_h_file_name.split("/")[0]
|
||||
|
||||
for resolved_file_name in spec.imports.resolved_imports:
|
||||
# Guess: the file naming rules are consistent across IDL invocations
|
||||
@ -161,9 +178,10 @@ def _update_import_includes(args, spec, header_file_name):
|
||||
|
||||
if os.path.isabs(base_dir):
|
||||
include_h_file_name = os.path.join(
|
||||
base_dir, include_h_file_name[include_h_file_name.rfind(first_dir):])
|
||||
base_dir, include_h_file_name[include_h_file_name.rfind(first_dir) :]
|
||||
)
|
||||
else:
|
||||
include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir):]
|
||||
include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir) :]
|
||||
else:
|
||||
include_h_file_name = os.path.abspath(include_h_file_name)
|
||||
|
||||
@ -181,9 +199,12 @@ def compile_idl(args):
|
||||
logging.error("File '%s' not found", args.input_file)
|
||||
|
||||
if args.output_source is None:
|
||||
if not '.' in args.input_file:
|
||||
logging.error("File name '%s' must be end with a filename extension, such as '%s.idl'",
|
||||
args.input_file, args.input_file)
|
||||
if not "." in args.input_file:
|
||||
logging.error(
|
||||
"File name '%s' must be end with a filename extension, such as '%s.idl'",
|
||||
args.input_file,
|
||||
args.input_file,
|
||||
)
|
||||
return False
|
||||
|
||||
file_name_prefix = os.path.splitext(args.input_file)[0]
|
||||
@ -199,9 +220,10 @@ def compile_idl(args):
|
||||
args.target_arch = platform.machine()
|
||||
|
||||
# Compile the IDL through the 3 passes
|
||||
with io.open(args.input_file, encoding='utf-8') as file_stream:
|
||||
parsed_doc = parser.parse(file_stream, args.input_file,
|
||||
CompilerImportResolver(args.import_directories))
|
||||
with io.open(args.input_file, encoding="utf-8") as file_stream:
|
||||
parsed_doc = parser.parse(
|
||||
file_stream, args.input_file, CompilerImportResolver(args.import_directories)
|
||||
)
|
||||
|
||||
if not parsed_doc.errors:
|
||||
if args.write_dependencies or args.write_dependencies_inline:
|
||||
@ -215,8 +237,13 @@ def compile_idl(args):
|
||||
|
||||
bound_doc = binder.bind(parsed_doc.spec)
|
||||
if not bound_doc.errors:
|
||||
generator.generate_code(bound_doc.spec, args.target_arch, args.output_base_dir,
|
||||
header_file_name, source_file_name)
|
||||
generator.generate_code(
|
||||
bound_doc.spec,
|
||||
args.target_arch,
|
||||
args.output_base_dir,
|
||||
header_file_name,
|
||||
source_file_name,
|
||||
)
|
||||
|
||||
return True
|
||||
else:
|
||||
|
||||
@ -37,7 +37,7 @@ from . import bson
|
||||
from . import common
|
||||
from . import writer
|
||||
|
||||
_STD_ARRAY_UINT8_16 = 'std::array<std::uint8_t,16>'
|
||||
_STD_ARRAY_UINT8_16 = "std::array<std::uint8_t,16>"
|
||||
|
||||
|
||||
def is_primitive_scalar_type(cpp_type):
|
||||
@ -47,26 +47,31 @@ def is_primitive_scalar_type(cpp_type):
|
||||
|
||||
Primitive scalar types need to have a default value to prevent warnings from Coverity.
|
||||
"""
|
||||
cpp_type = cpp_type.replace(' ', '')
|
||||
cpp_type = cpp_type.replace(" ", "")
|
||||
# TODO (SERVER-50101): Remove 'multiversion::FeatureCompatibilityVersion' once IDL supports
|
||||
# a commmand cpp_type of C++ enum.
|
||||
return cpp_type in [
|
||||
'bool', 'double', 'std::int32_t', 'std::uint32_t', 'std::uint64_t', 'std::int64_t',
|
||||
'multiversion::FeatureCompatibilityVersion'
|
||||
"bool",
|
||||
"double",
|
||||
"std::int32_t",
|
||||
"std::uint32_t",
|
||||
"std::uint64_t",
|
||||
"std::int64_t",
|
||||
"multiversion::FeatureCompatibilityVersion",
|
||||
]
|
||||
|
||||
|
||||
def is_primitive_type(cpp_type):
|
||||
# type: (str) -> bool
|
||||
"""Return True if a cpp_type is a primitive type and should not be returned as reference."""
|
||||
cpp_type = cpp_type.replace(' ', '')
|
||||
cpp_type = cpp_type.replace(" ", "")
|
||||
return is_primitive_scalar_type(cpp_type) or cpp_type == _STD_ARRAY_UINT8_16
|
||||
|
||||
|
||||
def _qualify_optional_type(cpp_type):
|
||||
# type: (str) -> str
|
||||
"""Qualify the type as optional."""
|
||||
return 'boost::optional<%s>' % (cpp_type)
|
||||
return "boost::optional<%s>" % (cpp_type)
|
||||
|
||||
|
||||
def _qualify_array_type(cpp_type):
|
||||
@ -79,7 +84,7 @@ def _optionally_make_call(method_name, param):
|
||||
# type: (str, str) -> str
|
||||
"""Return a call to method_name if it is not None, otherwise return an empty string."""
|
||||
if not method_name:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
return "%s(%s);" % (method_name, param)
|
||||
|
||||
@ -173,14 +178,15 @@ class _CppTypeBasic(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'${optionally_call_validator} ${member_name} = std::move(value);',
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name,
|
||||
'value'), member_name=member_name)
|
||||
"${optionally_call_validator} ${member_name} = std::move(value);",
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "value"),
|
||||
member_name=member_name,
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
@ -222,15 +228,16 @@ class _CppTypeView(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name,
|
||||
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
|
||||
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
|
||||
value=self.get_transform_to_storage_type("value"),
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
@ -239,7 +246,7 @@ class _CppTypeView(CppTypeBase):
|
||||
def get_transform_to_storage_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args(
|
||||
'${expression}.toString()',
|
||||
"${expression}.toString()",
|
||||
expression=expression,
|
||||
)
|
||||
|
||||
@ -249,7 +256,7 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def __init__(self, field):
|
||||
# type: (ast.Field) -> None
|
||||
super(_CppTypeVector, self).__init__(field, 'std::vector<std::uint8_t>')
|
||||
super(_CppTypeVector, self).__init__(field, "std::vector<std::uint8_t>")
|
||||
|
||||
def get_type_name(self):
|
||||
# type: () -> str
|
||||
@ -261,7 +268,7 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def get_getter_setter_type(self):
|
||||
# type: () -> str
|
||||
return 'ConstDataRange'
|
||||
return "ConstDataRange"
|
||||
|
||||
def return_by_reference(self):
|
||||
# type: () -> bool
|
||||
@ -273,27 +280,30 @@ class _CppTypeVector(CppTypeBase):
|
||||
|
||||
def get_getter_body(self, member_name):
|
||||
# type: (str) -> str
|
||||
return common.template_args('return ConstDataRange(${member_name});',
|
||||
member_name=member_name)
|
||||
return common.template_args(
|
||||
"return ConstDataRange(${member_name});", member_name=member_name
|
||||
)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name,
|
||||
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
|
||||
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
|
||||
value=self.get_transform_to_storage_type("value"),
|
||||
)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args('ConstDataRange(${expression});', expression=expression)
|
||||
return common.template_args("ConstDataRange(${expression});", expression=expression)
|
||||
|
||||
def get_transform_to_storage_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
return common.template_args(
|
||||
'std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), ' +
|
||||
'reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())',
|
||||
expression=expression)
|
||||
"std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), "
|
||||
+ "reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())",
|
||||
expression=expression,
|
||||
)
|
||||
|
||||
|
||||
class _CppTypeDelegating(CppTypeBase):
|
||||
@ -362,7 +372,7 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
# type: (str) -> str
|
||||
convert = self.get_transform_to_getter_type(member_name)
|
||||
if convert:
|
||||
return common.template_args('return ${convert};', convert=convert)
|
||||
return common.template_args("return ${convert};", convert=convert)
|
||||
return self._base.get_getter_body(member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
@ -370,16 +380,18 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
convert = self.get_transform_to_storage_type("value")
|
||||
if convert:
|
||||
return common.template_args(
|
||||
'auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
|
||||
member_name=member_name, optionally_call_validator=_optionally_make_call(
|
||||
validator_method_name, '_tmpValue'), convert=convert)
|
||||
"auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
|
||||
member_name=member_name,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
|
||||
convert=convert,
|
||||
)
|
||||
return self._base.get_setter_body(member_name, validator_method_name)
|
||||
|
||||
def get_transform_to_getter_type(self, expression):
|
||||
# type: (str) -> Optional[str]
|
||||
if self._base.get_storage_type() != self._base.get_getter_setter_type():
|
||||
return common.template_args(
|
||||
'transformVector(${expression})',
|
||||
"transformVector(${expression})",
|
||||
expression=expression,
|
||||
)
|
||||
return None
|
||||
@ -388,7 +400,7 @@ class _CppTypeArray(_CppTypeDelegating):
|
||||
# type: (str) -> Optional[str]
|
||||
if self._base.get_storage_type() != self._base.get_getter_setter_type():
|
||||
return common.template_args(
|
||||
'transformVector(${expression})',
|
||||
"transformVector(${expression})",
|
||||
expression=expression,
|
||||
)
|
||||
return None
|
||||
@ -427,13 +439,18 @@ class _CppTypeOptional(_CppTypeDelegating):
|
||||
} else {
|
||||
return boost::none;
|
||||
}
|
||||
"""), member_name=member_name, convert=convert)
|
||||
"""),
|
||||
member_name=member_name,
|
||||
convert=convert,
|
||||
)
|
||||
elif self.is_view_type():
|
||||
# For optionals around view types, do an explicit construction
|
||||
return common.template_args('return ${param_type}{${member_name}};',
|
||||
param_type=self.get_getter_setter_type(),
|
||||
member_name=member_name)
|
||||
return common.template_args('return ${member_name};', member_name=member_name)
|
||||
return common.template_args(
|
||||
"return ${param_type}{${member_name}};",
|
||||
param_type=self.get_getter_setter_type(),
|
||||
member_name=member_name,
|
||||
)
|
||||
return common.template_args("return ${member_name};", member_name=member_name)
|
||||
|
||||
def get_setter_body(self, member_name, validator_method_name):
|
||||
# type: (str, str) -> str
|
||||
@ -450,8 +467,11 @@ class _CppTypeOptional(_CppTypeDelegating):
|
||||
} else {
|
||||
${member_name} = boost::none;
|
||||
}
|
||||
"""), member_name=member_name, convert=convert,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, '_tmpValue'))
|
||||
"""),
|
||||
member_name=member_name,
|
||||
convert=convert,
|
||||
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
|
||||
)
|
||||
return self._base.get_setter_body(member_name, validator_method_name)
|
||||
|
||||
|
||||
@ -459,9 +479,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array):
|
||||
# type: (ast.Field, str, bool) -> CppTypeBase
|
||||
"""Get the C++ Type information for the given C++ type name, e.g. std::string."""
|
||||
cpp_type_info: CppTypeBase
|
||||
if cpp_type_name == 'std::string':
|
||||
cpp_type_info = _CppTypeView(field, 'std::string', 'std::string', 'StringData')
|
||||
elif cpp_type_name == 'std::vector<std::uint8_t>':
|
||||
if cpp_type_name == "std::string":
|
||||
cpp_type_info = _CppTypeView(field, "std::string", "std::string", "StringData")
|
||||
elif cpp_type_name == "std::vector<std::uint8_t>":
|
||||
cpp_type_info = _CppTypeVector(field)
|
||||
else:
|
||||
cpp_type_info = _CppTypeBasic(field, cpp_type_name)
|
||||
@ -511,15 +531,17 @@ class BsonCppTypeBase(object, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
"""Generate code with the text writer and return an expression to serialize the type."""
|
||||
pass
|
||||
|
||||
|
||||
def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def _call_method_or_global_function(
|
||||
expression, ast_type, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (str, ast.Type, bool, bool) -> str
|
||||
"""
|
||||
Given a fully-qualified method name, call it correctly.
|
||||
@ -529,25 +551,25 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
enum deserializers/serializers which are not methods.
|
||||
"""
|
||||
method_name = ast_type.serializer
|
||||
serialization_context = 'getSerializationContext()' if ast_type.deserialize_with_tenant else ''
|
||||
shape_options = ''
|
||||
serialization_context = "getSerializationContext()" if ast_type.deserialize_with_tenant else ""
|
||||
shape_options = ""
|
||||
if should_shapify:
|
||||
shape_options = 'options'
|
||||
shape_options = "options"
|
||||
|
||||
short_method_name = writer.get_method_name(method_name)
|
||||
if writer.is_function(method_name):
|
||||
if ast_type.deserialize_with_tenant:
|
||||
if is_catalog_ctxt:
|
||||
# serializeForCatalog doesn't need a serializationContext
|
||||
serialization_context = ''
|
||||
serialization_context = ""
|
||||
method_name = method_name.replace("serialize", "serializeForCatalog")
|
||||
else:
|
||||
serialization_context = ', ' + serialization_context
|
||||
serialization_context = ", " + serialization_context
|
||||
if should_shapify:
|
||||
shape_options = ', ' + shape_options
|
||||
shape_options = ", " + shape_options
|
||||
|
||||
return common.template_args(
|
||||
'${method_name}(${expression}${shape_options}${serialization_context})',
|
||||
"${method_name}(${expression}${shape_options}${serialization_context})",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
shape_options=shape_options,
|
||||
@ -555,7 +577,7 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
|
||||
)
|
||||
|
||||
return common.template_args(
|
||||
'${expression}.${method_name}(${shape_options}${serialization_context})',
|
||||
"${expression}.${method_name}(${shape_options}${serialization_context})",
|
||||
expression=expression,
|
||||
method_name=short_method_name,
|
||||
shape_options=shape_options,
|
||||
@ -573,19 +595,23 @@ class _CommonBsonCppTypeBase(BsonCppTypeBase):
|
||||
|
||||
def gen_deserializer_expression(self, indented_writer, object_instance):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
return common.template_args('${object_instance}.${method_name}()',
|
||||
object_instance=object_instance,
|
||||
method_name=self._deserialize_method_name)
|
||||
return common.template_args(
|
||||
"${object_instance}.${method_name}()",
|
||||
object_instance=object_instance,
|
||||
method_name=self._deserialize_method_name,
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
return _call_method_or_global_function(expression, self._ast_type, should_shapify,
|
||||
is_catalog_ctxt)
|
||||
return _call_method_or_global_function(
|
||||
expression, self._ast_type, should_shapify, is_catalog_ctxt
|
||||
)
|
||||
|
||||
|
||||
class _ObjectBsonCppTypeBase(BsonCppTypeBase):
|
||||
@ -596,34 +622,41 @@ class _ObjectBsonCppTypeBase(BsonCppTypeBase):
|
||||
if self._ast_type.deserializer:
|
||||
# Call a method like: Class::method(const BSONObj& value)
|
||||
indented_writer.write_line(
|
||||
common.template_args('const BSONObj localObject = ${object_instance}.Obj();',
|
||||
object_instance=object_instance))
|
||||
common.template_args(
|
||||
"const BSONObj localObject = ${object_instance}.Obj();",
|
||||
object_instance=object_instance,
|
||||
)
|
||||
)
|
||||
return "localObject"
|
||||
|
||||
# Just pass the BSONObj through without trying to parse it.
|
||||
return common.template_args('${object_instance}.Obj()', object_instance=object_instance)
|
||||
return common.template_args("${object_instance}.Obj()", object_instance=object_instance)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
function_arguments = []
|
||||
# SerializationContext is tied to tenant deserialization
|
||||
if self._ast_type.deserialize_with_tenant:
|
||||
function_arguments.append('getSerializationContext()')
|
||||
function_arguments.append("getSerializationContext()")
|
||||
# Provide options if custom shapification required.
|
||||
if should_shapify:
|
||||
function_arguments.append('options')
|
||||
function_arguments.append("options")
|
||||
|
||||
indented_writer.write_line(
|
||||
common.template_args(
|
||||
'const BSONObj localObject = ${expression}.${method_name}(${function_arguments});',
|
||||
expression=expression, method_name=method_name,
|
||||
function_arguments=', '.join(function_arguments)))
|
||||
"const BSONObj localObject = ${expression}.${method_name}(${function_arguments});",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
function_arguments=", ".join(function_arguments),
|
||||
)
|
||||
)
|
||||
return "localObject"
|
||||
|
||||
|
||||
@ -634,25 +667,34 @@ class _ArrayBsonCppTypeBase(BsonCppTypeBase):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
if self._ast_type.deserializer:
|
||||
indented_writer.write_line(
|
||||
common.template_args('BSONArray localArray(${object_instance}.Obj());',
|
||||
object_instance=object_instance))
|
||||
common.template_args(
|
||||
"BSONArray localArray(${object_instance}.Obj());",
|
||||
object_instance=object_instance,
|
||||
)
|
||||
)
|
||||
return "localArray"
|
||||
|
||||
# Just pass the BSONObj through without trying to parse it.
|
||||
return common.template_args('BSONArray(${object_instance}.Obj())',
|
||||
object_instance=object_instance)
|
||||
return common.template_args(
|
||||
"BSONArray(${object_instance}.Obj())", object_instance=object_instance
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return self._ast_type.serializer is not None
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
indented_writer.write_line(
|
||||
common.template_args('BSONArray localArray(${expression}.${method_name}());',
|
||||
expression=expression, method_name=method_name))
|
||||
common.template_args(
|
||||
"BSONArray localArray(${expression}.${method_name}());",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
)
|
||||
)
|
||||
return "localArray"
|
||||
|
||||
|
||||
@ -661,32 +703,42 @@ class _BinDataBsonCppTypeBase(BsonCppTypeBase):
|
||||
|
||||
def gen_deserializer_expression(self, indented_writer, object_instance):
|
||||
# type: (writer.IndentedTextWriter, str) -> str
|
||||
if self._ast_type.bindata_subtype == 'uuid':
|
||||
return common.template_args('uassertStatusOK(UUID::parse(${object_instance}))',
|
||||
object_instance=object_instance)
|
||||
return common.template_args('${object_instance}._binDataVector()',
|
||||
object_instance=object_instance)
|
||||
if self._ast_type.bindata_subtype == "uuid":
|
||||
return common.template_args(
|
||||
"uassertStatusOK(UUID::parse(${object_instance}))", object_instance=object_instance
|
||||
)
|
||||
return common.template_args(
|
||||
"${object_instance}._binDataVector()", object_instance=object_instance
|
||||
)
|
||||
|
||||
def has_serializer(self):
|
||||
# type: () -> bool
|
||||
return True
|
||||
|
||||
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
|
||||
is_catalog_ctxt=False):
|
||||
def gen_serializer_expression(
|
||||
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
|
||||
):
|
||||
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
|
||||
if self._ast_type.serializer:
|
||||
method_name = writer.get_method_name(self._ast_type.serializer)
|
||||
indented_writer.write_line(
|
||||
common.template_args('ConstDataRange tempCDR = ${expression}.${method_name}();',
|
||||
expression=expression, method_name=method_name))
|
||||
common.template_args(
|
||||
"ConstDataRange tempCDR = ${expression}.${method_name}();",
|
||||
expression=expression,
|
||||
method_name=method_name,
|
||||
)
|
||||
)
|
||||
else:
|
||||
indented_writer.write_line(
|
||||
common.template_args('ConstDataRange tempCDR(${expression});',
|
||||
expression=expression))
|
||||
common.template_args(
|
||||
"ConstDataRange tempCDR(${expression});", expression=expression
|
||||
)
|
||||
)
|
||||
|
||||
return common.template_args(
|
||||
'BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})',
|
||||
bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype))
|
||||
"BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})",
|
||||
bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype),
|
||||
)
|
||||
|
||||
|
||||
# For some types, we want to support custom serialization but defer most of the serialization to
|
||||
@ -700,19 +752,19 @@ def get_bson_cpp_type(ast_type):
|
||||
if len(ast_type.bson_serialization_type) > 1:
|
||||
return None
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'string':
|
||||
if ast_type.bson_serialization_type[0] == "string":
|
||||
return _CommonBsonCppTypeBase(ast_type, "valueStringData")
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'object':
|
||||
if ast_type.bson_serialization_type[0] == "object":
|
||||
return _ObjectBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'array':
|
||||
if ast_type.bson_serialization_type[0] == "array":
|
||||
return _ArrayBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'bindata':
|
||||
if ast_type.bson_serialization_type[0] == "bindata":
|
||||
return _BinDataBsonCppTypeBase(ast_type)
|
||||
|
||||
if ast_type.bson_serialization_type[0] == 'int':
|
||||
if ast_type.bson_serialization_type[0] == "int":
|
||||
return _CommonBsonCppTypeBase(ast_type, "_numberInt")
|
||||
|
||||
# Unsupported type
|
||||
|
||||
@ -71,32 +71,37 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
def _get_enum_deserializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of deserializer function without prefix."""
|
||||
return common.template_args("${enum_name}_parse",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_parse", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_enum_deserializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of deserializer function with non-method prefix."""
|
||||
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
|
||||
self._get_enum_deserializer_name())
|
||||
return "::" + common.qualify_cpp_name(
|
||||
self._enum.cpp_namespace, self._get_enum_deserializer_name()
|
||||
)
|
||||
|
||||
def _get_enum_serializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of serializer function without prefix."""
|
||||
return common.template_args("${enum_name}_serializer",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_serializer", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_enum_serializer_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of serializer function with non-method prefix."""
|
||||
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
|
||||
self._get_enum_serializer_name())
|
||||
return "::" + common.qualify_cpp_name(
|
||||
self._enum.cpp_namespace, self._get_enum_serializer_name()
|
||||
)
|
||||
|
||||
def _get_enum_extra_data_name(self):
|
||||
# type: () -> str
|
||||
"""Return the name of the get_extra_data function without prefix."""
|
||||
return common.template_args("${enum_name}_get_extra_data",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}_get_extra_data", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def get_cpp_value_assignment(self, enum_value):
|
||||
@ -139,9 +144,11 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
if len(self._get_populated_extra_values()) == 0:
|
||||
return None
|
||||
|
||||
return common.template_args("BSONObj ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_extra_data_name())
|
||||
return common.template_args(
|
||||
"BSONObj ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_extra_data_name(),
|
||||
)
|
||||
|
||||
def gen_extra_data_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
@ -154,26 +161,30 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
|
||||
# Generate an anonymous namespace full of BSON constants.
|
||||
#
|
||||
with writer.NamespaceScopeBlock(indented_writer, ['']):
|
||||
with writer.NamespaceScopeBlock(indented_writer, [""]):
|
||||
for enum_value in extra_values:
|
||||
indented_writer.write_line(
|
||||
common.template_args('// %s' % json.dumps(enum_value.extra_data)))
|
||||
common.template_args("// %s" % json.dumps(enum_value.extra_data))
|
||||
)
|
||||
|
||||
bson_value = ''.join(
|
||||
[('\\x%02x' % (b)) for b in bson.BSON.encode(enum_value.extra_data)])
|
||||
bson_value = "".join(
|
||||
[("\\x%02x" % (b)) for b in bson.BSON.encode(enum_value.extra_data)]
|
||||
)
|
||||
indented_writer.write_line(
|
||||
common.template_args(
|
||||
'const BSONObj ${const_name}("${bson_value}");',
|
||||
const_name=_get_constant_enum_extra_data_name(self._enum, enum_value),
|
||||
bson_value=bson_value))
|
||||
bson_value=bson_value,
|
||||
)
|
||||
)
|
||||
|
||||
indented_writer.write_empty_line()
|
||||
|
||||
# Generate implementation of get_extra_data function.
|
||||
#
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_extra_data_declaration(),
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_extra_data_declaration(),
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
@ -181,16 +192,19 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
|
||||
with writer.IndentedScopedBlock(indented_writer, "switch (value) {", "}"):
|
||||
for enum_value in extra_values:
|
||||
indented_writer.write_template(
|
||||
'case ${enum_name}::%s: return %s;' %
|
||||
(enum_value.name,
|
||||
_get_constant_enum_extra_data_name(self._enum, enum_value)))
|
||||
"case ${enum_name}::%s: return %s;"
|
||||
% (
|
||||
enum_value.name,
|
||||
_get_constant_enum_extra_data_name(self._enum, enum_value),
|
||||
)
|
||||
)
|
||||
if len(extra_values) != len(self._enum.values):
|
||||
# One or more enums does not have associated extra data.
|
||||
indented_writer.write_line('default: return BSONObj();')
|
||||
indented_writer.write_line("default: return BSONObj();")
|
||||
|
||||
if len(extra_values) == len(self._enum.values):
|
||||
# All enum cases handled, the compiler should know this.
|
||||
indented_writer.write_line('MONGO_UNREACHABLE;')
|
||||
indented_writer.write_line("MONGO_UNREACHABLE;")
|
||||
|
||||
|
||||
class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
@ -212,17 +226,19 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
# type: () -> str
|
||||
return common.template_args(
|
||||
"${enum_name} ${function_name}(const IDLParserContext& ctxt, std::int32_t value)",
|
||||
enum_name=self.get_cpp_type_name(), function_name=self._get_enum_deserializer_name())
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_deserializer_name(),
|
||||
)
|
||||
|
||||
def gen_deserializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
enum_values = sorted(cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value))
|
||||
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_deserializer_declaration(),
|
||||
'min_value': enum_values[0].name,
|
||||
'max_value': enum_values[-1].name,
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_deserializer_declaration(),
|
||||
"min_value": enum_values[0].name,
|
||||
"max_value": enum_values[-1].name,
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
@ -237,33 +253,39 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
} else {
|
||||
return static_cast<${enum_name}>(value);
|
||||
}
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
def get_serializer_declaration(self):
|
||||
# type: () -> str
|
||||
"""Get the serializer function declaration minus trailing semicolon."""
|
||||
return common.template_args("std::int32_t ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_serializer_name())
|
||||
return common.template_args(
|
||||
"std::int32_t ${function_name}(${enum_name} value)",
|
||||
enum_name=self.get_cpp_type_name(),
|
||||
function_name=self._get_enum_serializer_name(),
|
||||
)
|
||||
|
||||
def gen_serializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
"""Generate the serializer function definition."""
|
||||
template_params = {
|
||||
'enum_name': self.get_cpp_type_name(),
|
||||
'function_name': self.get_serializer_declaration(),
|
||||
"enum_name": self.get_cpp_type_name(),
|
||||
"function_name": self.get_serializer_declaration(),
|
||||
}
|
||||
|
||||
with writer.TemplateContext(indented_writer, template_params):
|
||||
with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"):
|
||||
indented_writer.write_template('return static_cast<std::int32_t>(value);')
|
||||
indented_writer.write_template("return static_cast<std::int32_t>(value);")
|
||||
|
||||
|
||||
def _get_constant_enum_extra_data_name(idl_enum, enum_value):
|
||||
# type: (Union[syntax.Enum,ast.Enum], Union[syntax.EnumValue,ast.EnumValue]) -> str
|
||||
"""Return the C++ name for a string constant of enum extra data value."""
|
||||
return common.template_args('k${enum_name}_${name}_extra_data',
|
||||
enum_name=common.title_case(idl_enum.name), name=enum_value.name)
|
||||
return common.template_args(
|
||||
"k${enum_name}_${name}_extra_data",
|
||||
enum_name=common.title_case(idl_enum.name),
|
||||
name=enum_value.name,
|
||||
)
|
||||
|
||||
|
||||
class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
@ -271,8 +293,9 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
|
||||
def get_cpp_type_name(self):
|
||||
# type: () -> str
|
||||
return common.template_args("${enum_name}Enum",
|
||||
enum_name=common.title_case(self._enum.name))
|
||||
return common.template_args(
|
||||
"${enum_name}Enum", enum_name=common.title_case(self._enum.name)
|
||||
)
|
||||
|
||||
def get_bson_types(self):
|
||||
# type: () -> List[str]
|
||||
@ -280,65 +303,76 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
|
||||
|
||||
def get_cpp_value_assignment(self, enum_value):
|
||||
# type: (ast.EnumValue) -> str
|
||||
return ''
|
||||
return ""
|
||||
|
||||
def get_deserializer_declaration(self):
|
||||
# type: () -> str
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_deserializer_name()
|
||||
return f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)'
|
||||
return f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)"
|
||||
|
||||
def gen_deserializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_deserializer_name()
|
||||
with writer.NamespaceScopeBlock(indented_writer, ['']):
|
||||
with writer.IndentedScopedBlock(indented_writer,
|
||||
f'constexpr std::array {cpp_type}_values{{', '};'):
|
||||
with writer.NamespaceScopeBlock(indented_writer, [""]):
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"constexpr std::array {cpp_type}_values{{", "};"
|
||||
):
|
||||
for e in self._enum.values:
|
||||
indented_writer.write_line(f'{cpp_type}::{e.name},')
|
||||
with writer.IndentedScopedBlock(indented_writer,
|
||||
f'constexpr std::array {cpp_type}_names{{', '};'):
|
||||
indented_writer.write_line(f"{cpp_type}::{e.name},")
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"constexpr std::array {cpp_type}_names{{", "};"
|
||||
):
|
||||
for e in self._enum.values:
|
||||
indented_writer.write_line(f'"{e.value}"_sd,')
|
||||
indented_writer.write_empty_line()
|
||||
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer,
|
||||
f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{",
|
||||
"}",
|
||||
):
|
||||
indented_writer.write_line(
|
||||
f"static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};"
|
||||
)
|
||||
indented_writer.write_line(
|
||||
f"auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};"
|
||||
)
|
||||
writer.gen_string_table_find_function_block(
|
||||
indented_writer,
|
||||
f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{', '}'):
|
||||
indented_writer.write_line(
|
||||
f'static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};')
|
||||
indented_writer.write_line(
|
||||
f'auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};')
|
||||
writer.gen_string_table_find_function_block(indented_writer, 'value', 'onMatch({})',
|
||||
'onFail()',
|
||||
[e.value for e in self._enum.values])
|
||||
"value",
|
||||
"onMatch({})",
|
||||
"onFail()",
|
||||
[e.value for e in self._enum.values],
|
||||
)
|
||||
|
||||
def get_serializer_declaration(self):
|
||||
# type: () -> str
|
||||
"""Get the serializer function declaration minus trailing semicolon."""
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
func = self._get_enum_serializer_name()
|
||||
return f'StringData {func}({cpp_type} value)'
|
||||
return f"StringData {func}({cpp_type} value)"
|
||||
|
||||
def gen_serializer_definition(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
"""Generate the serializer function definition."""
|
||||
func = self._get_enum_serializer_name()
|
||||
cpp_type = self.get_cpp_type_name()
|
||||
with writer.IndentedScopedBlock(indented_writer, f'StringData {func}({cpp_type} value) {{',
|
||||
'}'):
|
||||
indented_writer.write_line('auto idx = static_cast<size_t>(value);')
|
||||
indented_writer.write_line(f'invariant(idx < {cpp_type}_names.size());')
|
||||
indented_writer.write_line(f'return {cpp_type}_names[idx];')
|
||||
with writer.IndentedScopedBlock(
|
||||
indented_writer, f"StringData {func}({cpp_type} value) {{", "}"
|
||||
):
|
||||
indented_writer.write_line("auto idx = static_cast<size_t>(value);")
|
||||
indented_writer.write_line(f"invariant(idx < {cpp_type}_names.size());")
|
||||
indented_writer.write_line(f"return {cpp_type}_names[idx];")
|
||||
|
||||
|
||||
def get_type_info(idl_enum):
|
||||
# type: (Union[syntax.Enum,ast.Enum]) -> Optional[EnumTypeInfoBase]
|
||||
"""Get the type information for a given enumeration type, return None if not supported."""
|
||||
if idl_enum.type == 'int':
|
||||
if idl_enum.type == "int":
|
||||
return _EnumTypeInt(idl_enum)
|
||||
elif idl_enum.type == 'string':
|
||||
elif idl_enum.type == "string":
|
||||
return _EnumTypeString(idl_enum)
|
||||
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -43,7 +43,7 @@ class FieldListInfo:
|
||||
# type: () -> MethodInfo
|
||||
"""Get the hasField method for a generic argument or generic reply field list."""
|
||||
class_name = common.title_case(self.struct.cpp_name)
|
||||
return MethodInfo(class_name, 'hasField', ['StringData fieldName'], 'bool', static=True)
|
||||
return MethodInfo(class_name, "hasField", ["StringData fieldName"], "bool", static=True)
|
||||
|
||||
def get_should_forward_name(self):
|
||||
"""Get the name of the shard-forwarding rule for a generic argument or reply field."""
|
||||
@ -62,8 +62,13 @@ class FieldListInfo:
|
||||
# type: () -> MethodInfo
|
||||
"""Get the method for checking the shard-forwarding rule of an argument or reply field."""
|
||||
class_name = common.title_case(self.struct.cpp_name)
|
||||
return MethodInfo(class_name, self.get_should_forward_name(), ['StringData fieldName'],
|
||||
'bool', static=True)
|
||||
return MethodInfo(
|
||||
class_name,
|
||||
self.get_should_forward_name(),
|
||||
["StringData fieldName"],
|
||||
"bool",
|
||||
static=True,
|
||||
)
|
||||
|
||||
|
||||
def get_field_list_info(struct):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -40,8 +40,14 @@ from . import writer
|
||||
def _is_required_constructor_arg(field):
|
||||
# type: (ast.Field) -> bool
|
||||
"""Get whether we require this field to have a value set for constructor purposes."""
|
||||
return not field.ignore and not field.optional and not field.default and not field.chained \
|
||||
and not field.chained_struct_field and not field.serialize_op_msg_request_only
|
||||
return (
|
||||
not field.ignore
|
||||
and not field.optional
|
||||
and not field.default
|
||||
and not field.chained
|
||||
and not field.chained_struct_field
|
||||
and not field.serialize_op_msg_request_only
|
||||
)
|
||||
|
||||
|
||||
def _get_arg_for_field(field):
|
||||
@ -66,7 +72,7 @@ def _get_required_parameters(struct):
|
||||
|
||||
|
||||
def _get_serialization_ctx_arg():
|
||||
return 'boost::optional<SerializationContext> serializationContext = boost::none'
|
||||
return "boost::optional<SerializationContext> serializationContext = boost::none"
|
||||
|
||||
|
||||
class ArgumentInfo(object):
|
||||
@ -76,12 +82,12 @@ class ArgumentInfo(object):
|
||||
# type: (str) -> None
|
||||
"""Create a instance of the ArgumentInfo class by parsing the argument string."""
|
||||
self.defaults = None
|
||||
equal_tokens = arg.split('=')
|
||||
equal_tokens = arg.split("=")
|
||||
if len(equal_tokens) > 1:
|
||||
self.defaults = equal_tokens[-1].strip()
|
||||
|
||||
space_tokens = equal_tokens[0].strip().split(' ')
|
||||
self.type = ' '.join(space_tokens[0:-1])
|
||||
space_tokens = equal_tokens[0].strip().split(" ")
|
||||
self.type = " ".join(space_tokens[0:-1])
|
||||
self.name = space_tokens[-1]
|
||||
|
||||
def get_string(self, get_defaults):
|
||||
@ -97,8 +103,17 @@ class MethodInfo(object):
|
||||
|
||||
# pylint: disable=too-many-instance-attributes
|
||||
|
||||
def __init__(self, class_name, method_name, args, return_type=None, static=False, const=False,
|
||||
explicit=False, desc_for_comment=None):
|
||||
def __init__(
|
||||
self,
|
||||
class_name,
|
||||
method_name,
|
||||
args,
|
||||
return_type=None,
|
||||
static=False,
|
||||
const=False,
|
||||
explicit=False,
|
||||
desc_for_comment=None,
|
||||
):
|
||||
# type: (str, str, List[str], str, bool, bool, bool, Optional[str]) -> None
|
||||
# pylint: disable=too-many-arguments
|
||||
"""Create a MethodInfo instance."""
|
||||
@ -114,59 +129,68 @@ class MethodInfo(object):
|
||||
def get_declaration(self):
|
||||
# type: () -> str
|
||||
"""Get a declaration for a method."""
|
||||
pre_modifiers = ''
|
||||
post_modifiers = ''
|
||||
return_type_str = ''
|
||||
pre_modifiers = ""
|
||||
post_modifiers = ""
|
||||
return_type_str = ""
|
||||
|
||||
if self.static:
|
||||
pre_modifiers = 'static '
|
||||
pre_modifiers = "static "
|
||||
|
||||
if self.const:
|
||||
post_modifiers = ' const'
|
||||
post_modifiers = " const"
|
||||
|
||||
if self.explicit:
|
||||
pre_modifiers += 'explicit '
|
||||
pre_modifiers += "explicit "
|
||||
|
||||
if self.return_type:
|
||||
return_type_str = self.return_type + ' '
|
||||
return_type_str = self.return_type + " "
|
||||
|
||||
return common.template_args(
|
||||
"${pre_modifiers}${return_type}${method_name}(${args})${post_modifiers};",
|
||||
pre_modifiers=pre_modifiers, return_type=return_type_str, method_name=self.method_name,
|
||||
args=', '.join(
|
||||
[arg.get_string(True) for arg in self.args]), post_modifiers=post_modifiers)
|
||||
pre_modifiers=pre_modifiers,
|
||||
return_type=return_type_str,
|
||||
method_name=self.method_name,
|
||||
args=", ".join([arg.get_string(True) for arg in self.args]),
|
||||
post_modifiers=post_modifiers,
|
||||
)
|
||||
|
||||
def get_definition(self):
|
||||
# type: () -> str
|
||||
"""Get a definition for a method."""
|
||||
pre_modifiers = ''
|
||||
post_modifiers = ''
|
||||
return_type_str = ''
|
||||
pre_modifiers = ""
|
||||
post_modifiers = ""
|
||||
return_type_str = ""
|
||||
|
||||
if self.const:
|
||||
post_modifiers = ' const'
|
||||
post_modifiers = " const"
|
||||
|
||||
if self.return_type:
|
||||
return_type_str = self.return_type + ' '
|
||||
return_type_str = self.return_type + " "
|
||||
|
||||
return common.template_args(
|
||||
"${pre_modifiers}${return_type}${class_name}::${method_name}(${args})${post_modifiers}",
|
||||
pre_modifiers=pre_modifiers, return_type=return_type_str, class_name=self.class_name,
|
||||
method_name=self.method_name, args=', '.join(
|
||||
[arg.get_string(False) for arg in self.args]), post_modifiers=post_modifiers)
|
||||
pre_modifiers=pre_modifiers,
|
||||
return_type=return_type_str,
|
||||
class_name=self.class_name,
|
||||
method_name=self.method_name,
|
||||
args=", ".join([arg.get_string(False) for arg in self.args]),
|
||||
post_modifiers=post_modifiers,
|
||||
)
|
||||
|
||||
def get_call(self, obj):
|
||||
# type: (Optional[str]) -> str
|
||||
"""Generate a simple call to the method using the defined args list."""
|
||||
|
||||
args = ', '.join([arg.name for arg in self.args])
|
||||
args = ", ".join([arg.name for arg in self.args])
|
||||
|
||||
if obj:
|
||||
return common.template_args("${obj}.${method_name}(${args});", obj=obj,
|
||||
method_name=self.method_name, args=args)
|
||||
return common.template_args(
|
||||
"${obj}.${method_name}(${args});", obj=obj, method_name=self.method_name, args=args
|
||||
)
|
||||
|
||||
return common.template_args("${method_name}(${args});", method_name=self.method_name,
|
||||
args=args)
|
||||
return common.template_args(
|
||||
"${method_name}(${args});", method_name=self.method_name, args=args
|
||||
)
|
||||
|
||||
def get_desc_for_comment(self):
|
||||
# type: () -> Optional[str]
|
||||
@ -292,9 +316,14 @@ class _StructTypeInfo(StructTypeInfoBase):
|
||||
comment = textwrap.dedent(f"""\
|
||||
Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed
|
||||
this way participates in ownership of the data underlying the BSONObj.""")
|
||||
return MethodInfo(class_name, 'parseSharingOwnership',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name,
|
||||
static=True, desc_for_comment=comment)
|
||||
return MethodInfo(
|
||||
class_name,
|
||||
"parseSharingOwnership",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
class_name,
|
||||
static=True,
|
||||
desc_for_comment=comment,
|
||||
)
|
||||
|
||||
def get_owned_deserializer_static_method(self):
|
||||
# type: () -> MethodInfo
|
||||
@ -302,9 +331,14 @@ class _StructTypeInfo(StructTypeInfoBase):
|
||||
comment = textwrap.dedent(f"""\
|
||||
Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed
|
||||
this way takes ownership of the data underlying the BSONObj.""")
|
||||
return MethodInfo(class_name, 'parseOwned',
|
||||
['const IDLParserContext& ctxt', 'BSONObj&& bsonObject'], class_name,
|
||||
static=True, desc_for_comment=comment)
|
||||
return MethodInfo(
|
||||
class_name,
|
||||
"parseOwned",
|
||||
["const IDLParserContext& ctxt", "BSONObj&& bsonObject"],
|
||||
class_name,
|
||||
static=True,
|
||||
desc_for_comment=comment,
|
||||
)
|
||||
|
||||
def get_deserializer_static_method(self):
|
||||
# type: () -> MethodInfo
|
||||
@ -315,23 +349,32 @@ class _StructTypeInfo(StructTypeInfoBase):
|
||||
ensure the validity any members of this struct that point-into the BSONObj (i.e.
|
||||
unowned
|
||||
objects).""")
|
||||
return MethodInfo(class_name, 'parse',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name,
|
||||
static=True, desc_for_comment=comment)
|
||||
return MethodInfo(
|
||||
class_name,
|
||||
"parse",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
class_name,
|
||||
static=True,
|
||||
desc_for_comment=comment,
|
||||
)
|
||||
|
||||
def get_deserializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'parseProtected',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"parseProtected",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
"void",
|
||||
)
|
||||
|
||||
def get_serializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
args = ['BSONObjBuilder* builder']
|
||||
args = ["BSONObjBuilder* builder"]
|
||||
if self._struct.query_shape_component:
|
||||
args.append("const SerializationOptions& options = {}")
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize', args, 'void', const=True)
|
||||
common.title_case(self._struct.cpp_name), "serialize", args, "void", const=True
|
||||
)
|
||||
|
||||
def get_to_bson_method(self):
|
||||
# type: () -> MethodInfo
|
||||
@ -339,7 +382,8 @@ class _StructTypeInfo(StructTypeInfoBase):
|
||||
if self._struct.query_shape_component:
|
||||
args.append("const SerializationOptions& options = {}")
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'toBSON', args, 'BSONObj', const=True)
|
||||
common.title_case(self._struct.cpp_name), "toBSON", args, "BSONObj", const=True
|
||||
)
|
||||
|
||||
def get_op_msg_request_serializer_method(self):
|
||||
# type: () -> Optional[MethodInfo]
|
||||
@ -383,21 +427,32 @@ class _CommandBaseTypeInfo(_StructTypeInfo):
|
||||
def get_op_msg_request_serializer_method(self):
|
||||
# type: () -> Optional[MethodInfo]
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize',
|
||||
['const BSONObj& commandPassthroughFields = {}'], 'OpMsgRequest', const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"serialize",
|
||||
["const BSONObj& commandPassthroughFields = {}"],
|
||||
"OpMsgRequest",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_op_msg_request_deserializer_static_method(self):
|
||||
# type: () -> Optional[MethodInfo]
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
return MethodInfo(class_name, 'parse',
|
||||
['const IDLParserContext& ctxt', 'const OpMsgRequest& request'],
|
||||
class_name, static=True)
|
||||
return MethodInfo(
|
||||
class_name,
|
||||
"parse",
|
||||
["const IDLParserContext& ctxt", "const OpMsgRequest& request"],
|
||||
class_name,
|
||||
static=True,
|
||||
)
|
||||
|
||||
def get_op_msg_request_deserializer_method(self):
|
||||
# type: () -> Optional[MethodInfo]
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'parseProtected',
|
||||
['const IDLParserContext& ctxt', 'const OpMsgRequest& request'], 'void')
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"parseProtected",
|
||||
["const IDLParserContext& ctxt", "const OpMsgRequest& request"],
|
||||
"void",
|
||||
)
|
||||
|
||||
|
||||
class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo):
|
||||
@ -413,16 +468,23 @@ class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo):
|
||||
def get_serializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize',
|
||||
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
|
||||
const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"serialize",
|
||||
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
|
||||
"void",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_to_bson_method(self):
|
||||
# type: () -> MethodInfo
|
||||
# Commands that require namespaces require it as a parameter to serialize()
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'toBSON',
|
||||
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"toBSON",
|
||||
["const BSONObj& commandPassthroughFields = {}"],
|
||||
"BSONObj",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def gen_serializer(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
@ -440,8 +502,8 @@ def _get_command_type_parameter(command, gen_header=False):
|
||||
# Use the storage type for the constructor argument since the generated code will use std::move.
|
||||
member_type = cpp_type_info.get_storage_type()
|
||||
result = f"{member_type} {common.camel_case(command.command_field.cpp_name)}"
|
||||
if not gen_header or '&' in result:
|
||||
result = 'const ' + result
|
||||
if not gen_header or "&" in result:
|
||||
result = "const " + result
|
||||
return result
|
||||
|
||||
|
||||
@ -468,27 +530,38 @@ class _CommandFromType(_CommandBaseTypeInfo):
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
|
||||
arg = _get_command_type_parameter(self._command, gen_header)
|
||||
return MethodInfo(class_name, class_name, [arg] + _get_required_parameters(self._struct),
|
||||
explicit=True)
|
||||
return MethodInfo(
|
||||
class_name, class_name, [arg] + _get_required_parameters(self._struct), explicit=True
|
||||
)
|
||||
|
||||
def get_serializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize',
|
||||
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
|
||||
const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"serialize",
|
||||
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
|
||||
"void",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_to_bson_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'toBSON',
|
||||
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"toBSON",
|
||||
["const BSONObj& commandPassthroughFields = {}"],
|
||||
"BSONObj",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_deserializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'parseProtected',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"parseProtected",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
"void",
|
||||
)
|
||||
|
||||
def gen_getter_method(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
@ -520,78 +593,96 @@ class _CommandWithNamespaceTypeInfo(_CommandBaseTypeInfo):
|
||||
|
||||
@staticmethod
|
||||
def _get_nss_param(gen_header):
|
||||
nss_param = 'NamespaceString nss'
|
||||
nss_param = "NamespaceString nss"
|
||||
if not gen_header:
|
||||
nss_param = 'const ' + nss_param
|
||||
nss_param = "const " + nss_param
|
||||
return nss_param
|
||||
|
||||
def get_constructor_method(self, gen_header=False):
|
||||
# type: (bool) -> MethodInfo
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
sc_arg = _get_serialization_ctx_arg()
|
||||
return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg],
|
||||
explicit=True)
|
||||
return MethodInfo(
|
||||
class_name, class_name, [self._get_nss_param(gen_header), sc_arg], explicit=True
|
||||
)
|
||||
|
||||
def get_required_constructor_method(self, gen_header=False):
|
||||
# type: (bool) -> MethodInfo
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
return MethodInfo(
|
||||
class_name, class_name,
|
||||
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct))
|
||||
class_name,
|
||||
class_name,
|
||||
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct),
|
||||
)
|
||||
|
||||
def get_serializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize',
|
||||
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
|
||||
const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"serialize",
|
||||
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
|
||||
"void",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_to_bson_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'toBSON',
|
||||
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"toBSON",
|
||||
["const BSONObj& commandPassthroughFields = {}"],
|
||||
"BSONObj",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_deserializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'parseProtected',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"parseProtected",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
"void",
|
||||
)
|
||||
|
||||
def gen_getter_method(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
indented_writer.write_line('const NamespaceString& getNamespace() const { return _nss; }')
|
||||
indented_writer.write_line("const NamespaceString& getNamespace() const { return _nss; }")
|
||||
if self._struct.non_const_getter:
|
||||
indented_writer.write_line('NamespaceString& getNamespace() { return _nss; }')
|
||||
indented_writer.write_line("NamespaceString& getNamespace() { return _nss; }")
|
||||
|
||||
def gen_member(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
indented_writer.write_line('NamespaceString _nss;')
|
||||
indented_writer.write_line("NamespaceString _nss;")
|
||||
|
||||
def gen_serializer(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
if self._struct.allow_global_collection_name:
|
||||
indented_writer.write_line(
|
||||
'_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name))
|
||||
'_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name)
|
||||
)
|
||||
else:
|
||||
indented_writer.write_line('invariant(!_nss.isEmpty());')
|
||||
indented_writer.write_line("invariant(!_nss.isEmpty());")
|
||||
indented_writer.write_line(
|
||||
'builder->append("%s"_sd, _nss.coll());' % (self._command.name))
|
||||
'builder->append("%s"_sd, _nss.coll());' % (self._command.name)
|
||||
)
|
||||
indented_writer.write_empty_line()
|
||||
|
||||
def gen_namespace_check(self, indented_writer, db_name, element):
|
||||
# type: (writer.IndentedTextWriter, str, str) -> None
|
||||
# TODO: should the name of the first element be validated??
|
||||
indented_writer.write_line('invariant(_nss.isEmpty());')
|
||||
allow_global = 'true' if self._struct.allow_global_collection_name else 'false'
|
||||
indented_writer.write_line("invariant(_nss.isEmpty());")
|
||||
allow_global = "true" if self._struct.allow_global_collection_name else "false"
|
||||
indented_writer.write_line(
|
||||
'auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);' % (element,
|
||||
allow_global))
|
||||
"auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);"
|
||||
% (element, allow_global)
|
||||
)
|
||||
indented_writer.write_line(
|
||||
'_nss = NamespaceStringUtil::deserialize(%s, collectionName);' % (db_name))
|
||||
"_nss = NamespaceStringUtil::deserialize(%s, collectionName);" % (db_name)
|
||||
)
|
||||
indented_writer.write_line(
|
||||
'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "'
|
||||
' << _nss.toStringForErrorMsg(), _nss.isValid());')
|
||||
" << _nss.toStringForErrorMsg(), _nss.isValid());"
|
||||
)
|
||||
|
||||
|
||||
class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
|
||||
@ -606,55 +697,70 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
|
||||
|
||||
@staticmethod
|
||||
def _get_nss_param(gen_header):
|
||||
nss_param = 'NamespaceStringOrUUID nssOrUUID'
|
||||
nss_param = "NamespaceStringOrUUID nssOrUUID"
|
||||
if not gen_header:
|
||||
nss_param = 'const ' + nss_param
|
||||
nss_param = "const " + nss_param
|
||||
return nss_param
|
||||
|
||||
def get_constructor_method(self, gen_header=False):
|
||||
# type: (bool) -> MethodInfo
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
sc_arg = _get_serialization_ctx_arg()
|
||||
return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg],
|
||||
explicit=True)
|
||||
return MethodInfo(
|
||||
class_name, class_name, [self._get_nss_param(gen_header), sc_arg], explicit=True
|
||||
)
|
||||
|
||||
def get_required_constructor_method(self, gen_header=False):
|
||||
# type: (bool) -> MethodInfo
|
||||
class_name = common.title_case(self._struct.cpp_name)
|
||||
return MethodInfo(
|
||||
class_name, class_name,
|
||||
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct))
|
||||
class_name,
|
||||
class_name,
|
||||
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct),
|
||||
)
|
||||
|
||||
def get_serializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'serialize',
|
||||
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
|
||||
const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"serialize",
|
||||
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
|
||||
"void",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_to_bson_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'toBSON',
|
||||
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"toBSON",
|
||||
["const BSONObj& commandPassthroughFields = {}"],
|
||||
"BSONObj",
|
||||
const=True,
|
||||
)
|
||||
|
||||
def get_deserializer_method(self):
|
||||
# type: () -> MethodInfo
|
||||
return MethodInfo(
|
||||
common.title_case(self._struct.cpp_name), 'parseProtected',
|
||||
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
|
||||
common.title_case(self._struct.cpp_name),
|
||||
"parseProtected",
|
||||
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
|
||||
"void",
|
||||
)
|
||||
|
||||
def gen_getter_method(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
indented_writer.write_line(
|
||||
'const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }')
|
||||
"const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }"
|
||||
)
|
||||
if self._struct.non_const_getter:
|
||||
indented_writer.write_line(
|
||||
'NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }')
|
||||
"NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }"
|
||||
)
|
||||
|
||||
def gen_member(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
indented_writer.write_line('NamespaceStringOrUUID _nssOrUUID;')
|
||||
indented_writer.write_line("NamespaceStringOrUUID _nssOrUUID;")
|
||||
|
||||
def gen_serializer(self, indented_writer):
|
||||
# type: (writer.IndentedTextWriter) -> None
|
||||
@ -664,14 +770,17 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
|
||||
def gen_namespace_check(self, indented_writer, db_name, element):
|
||||
# type: (writer.IndentedTextWriter, str, str) -> None
|
||||
indented_writer.write_line(
|
||||
'auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);' % (element))
|
||||
"auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);" % (element)
|
||||
)
|
||||
indented_writer.write_line(
|
||||
'_nssOrUUID = std::holds_alternative<StringData>(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get<StringData>(collOrUUID)) : NamespaceStringOrUUID(%s, get<UUID>(collOrUUID));'
|
||||
% (db_name, db_name))
|
||||
"_nssOrUUID = std::holds_alternative<StringData>(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get<StringData>(collOrUUID)) : NamespaceStringOrUUID(%s, get<UUID>(collOrUUID));"
|
||||
% (db_name, db_name)
|
||||
)
|
||||
indented_writer.write_line(
|
||||
'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "'
|
||||
' << _nssOrUUID.toStringForErrorMsg()'
|
||||
', !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());')
|
||||
" << _nssOrUUID.toStringForErrorMsg()"
|
||||
", !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());"
|
||||
)
|
||||
|
||||
|
||||
def get_struct_info(struct):
|
||||
|
||||
@ -46,8 +46,9 @@ class IDLParsedSpec(object):
|
||||
def __init__(self, spec, error_collection):
|
||||
# type: (IDLSpec, errors.ParserErrorCollection) -> None
|
||||
"""Must specify either an IDL document or errors, not both."""
|
||||
assert (spec is None and error_collection is not None) or (spec is not None
|
||||
and error_collection is None)
|
||||
assert (spec is None and error_collection is not None) or (
|
||||
spec is not None and error_collection is None
|
||||
)
|
||||
self.spec = spec
|
||||
self.errors = error_collection
|
||||
|
||||
@ -76,11 +77,11 @@ def parse_array_variant_types(name):
|
||||
if not name.startswith("array<variant<") and not name.endswith(">>"):
|
||||
return None
|
||||
|
||||
name = name[len("array<variant<"):]
|
||||
name = name[len("array<variant<") :]
|
||||
name = name[:-2]
|
||||
|
||||
variant_types = []
|
||||
for variant_type in name.split(','):
|
||||
for variant_type in name.split(","):
|
||||
variant_type = variant_type.strip()
|
||||
# Ban array<variant<..., array<...>, ...>> types.
|
||||
if variant_type.startswith("array<") and variant_type.endswith(">"):
|
||||
@ -96,7 +97,7 @@ def parse_array_type(name):
|
||||
if not name.startswith("array<") and not name.endswith(">"):
|
||||
return None
|
||||
|
||||
name = name[len("array<"):]
|
||||
name = name[len("array<") :]
|
||||
name = name[:-1]
|
||||
|
||||
# V1 restriction, ban nested array types to reduce scope.
|
||||
@ -139,19 +140,22 @@ class SymbolTable(object):
|
||||
def _is_duplicate(self, ctxt, location, name, duplicate_class_name):
|
||||
# type: (errors.ParserContext, common.SourceLocation, str, str) -> bool
|
||||
"""Return true if the given item already exist in the symbol table."""
|
||||
for (item, entity_type) in _item_and_type({
|
||||
for item, entity_type in _item_and_type(
|
||||
{
|
||||
"command": self.commands,
|
||||
"enum": self.enums,
|
||||
"struct": self.structs,
|
||||
"type": self.types,
|
||||
}):
|
||||
}
|
||||
):
|
||||
if item.name == name:
|
||||
ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, entity_type)
|
||||
return True
|
||||
if entity_type == "command":
|
||||
if name in [item.command_name, item.command_alias if item.command_alias else '']:
|
||||
ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name,
|
||||
entity_type)
|
||||
if name in [item.command_name, item.command_alias if item.command_alias else ""]:
|
||||
ctxt.add_duplicate_symbol_error(
|
||||
location, name, duplicate_class_name, entity_type
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -177,8 +181,9 @@ class SymbolTable(object):
|
||||
def add_command(self, ctxt, command):
|
||||
# type: (errors.ParserContext, Command) -> None
|
||||
"""Add an IDL command to the symbol table and check for duplicates."""
|
||||
if (not self._is_duplicate(ctxt, command, command.name, "command")
|
||||
and not self._is_duplicate(ctxt, command, command.command_alias, "command")):
|
||||
if not self._is_duplicate(
|
||||
ctxt, command, command.name, "command"
|
||||
) and not self._is_duplicate(ctxt, command, command.command_alias, "command"):
|
||||
self.commands.append(command)
|
||||
|
||||
def add_generic_argument_list(self, field_list):
|
||||
@ -270,7 +275,7 @@ class SymbolTable(object):
|
||||
# cause parsing ambiguity.
|
||||
first_element = alternative_type.fields[0].name
|
||||
if first_element in [
|
||||
elem.fields[0].name for elem in variant.variant_struct_types
|
||||
elem.fields[0].name for elem in variant.variant_struct_types
|
||||
]:
|
||||
ctxt.add_variant_structs_error(location, field_name)
|
||||
continue
|
||||
@ -293,11 +298,13 @@ class SymbolTable(object):
|
||||
return variant
|
||||
|
||||
if isinstance(field_type, FieldTypeArray):
|
||||
element_type = self.resolve_field_type(ctxt, location, field_name,
|
||||
field_type.element_type)
|
||||
element_type = self.resolve_field_type(
|
||||
ctxt, location, field_name, field_type.element_type
|
||||
)
|
||||
if not element_type:
|
||||
ctxt.add_unknown_type_error(location, field_name,
|
||||
field_type.element_type.debug_string())
|
||||
ctxt.add_unknown_type_error(
|
||||
location, field_name, field_type.element_type.debug_string()
|
||||
)
|
||||
return None
|
||||
|
||||
if isinstance(element_type, Enum):
|
||||
@ -308,7 +315,7 @@ class SymbolTable(object):
|
||||
|
||||
assert isinstance(field_type, FieldTypeSingle)
|
||||
type_name = field_type.type_name
|
||||
if type_name.startswith('array<'):
|
||||
if type_name.startswith("array<"):
|
||||
# The caller should've already stripped "array<...>" from type_name, this may be an
|
||||
# illegal nested array like "array<array<...>>".
|
||||
ctxt.add_bad_array_type_name_error(location, field_name, type_name)
|
||||
@ -405,13 +412,14 @@ class ArrayType(Type):
|
||||
def __init__(self, element_type):
|
||||
# type: (Union[Struct, Type]) -> None
|
||||
"""Construct an ArrayType."""
|
||||
super(ArrayType, self).__init__(element_type.file_name, element_type.line,
|
||||
element_type.column)
|
||||
self.name = f'array<{element_type.name}>'
|
||||
super(ArrayType, self).__init__(
|
||||
element_type.file_name, element_type.line, element_type.column
|
||||
)
|
||||
self.name = f"array<{element_type.name}>"
|
||||
self.element_type = element_type
|
||||
if isinstance(element_type, Type):
|
||||
if element_type.cpp_type:
|
||||
self.cpp_type = f'std::vector<{element_type.cpp_type}>'
|
||||
self.cpp_type = f"std::vector<{element_type.cpp_type}>"
|
||||
else:
|
||||
assert isinstance(element_type, VariantType)
|
||||
# cpp_type can't be set here for array of variants as element_type.cpp_type is not set yet.
|
||||
@ -424,7 +432,7 @@ class VariantType(Type):
|
||||
# type: (str, int, int) -> None
|
||||
"""Construct a VariantType."""
|
||||
super(VariantType, self).__init__(file_name, line, column)
|
||||
self.name = 'variant'
|
||||
self.name = "variant"
|
||||
self.variant_types = [] # type: List[Type]
|
||||
self.variant_struct_types = [] # type: List[Struct]
|
||||
|
||||
@ -451,9 +459,14 @@ class Validator(common.SourceLocation):
|
||||
super(Validator, self).__init__(file_name, line, column)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, Validator) and self.gt == other.gt and self.lt == other.lt
|
||||
and self.gte == other.gte and self.lte == other.lte
|
||||
and self.callback == other.callback)
|
||||
return (
|
||||
isinstance(other, Validator)
|
||||
and self.gt == other.gt
|
||||
and self.lt == other.lt
|
||||
and self.gte == other.gte
|
||||
and self.lte == other.lte
|
||||
and self.callback == other.callback
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
@ -602,7 +615,11 @@ class Privilege(common.SourceLocation):
|
||||
"""
|
||||
location = super(Privilege, self).__str__()
|
||||
msg = "location: %s, resource_pattern: %s, action_type: %s, agg_stage: %s" % (
|
||||
location, self.resource_pattern, self.action_type, self.agg_stage)
|
||||
location,
|
||||
self.resource_pattern,
|
||||
self.action_type,
|
||||
self.agg_stage,
|
||||
)
|
||||
return msg # type: ignore
|
||||
|
||||
|
||||
@ -709,8 +726,9 @@ class EnumValue(common.SourceLocation):
|
||||
super(EnumValue, self).__init__(file_name, line, column)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, EnumValue) and self.name == other.name
|
||||
and self.value == other.value)
|
||||
return (
|
||||
isinstance(other, EnumValue) and self.name == other.name and self.value == other.value
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
@ -790,12 +808,13 @@ class FieldTypeArray(FieldType):
|
||||
"""Construct a FieldTypeArray."""
|
||||
self.element_type = element_type # type: Union[FieldTypeSingle, FieldTypeVariant]
|
||||
|
||||
super(FieldTypeArray, self).__init__(element_type.file_name, element_type.line,
|
||||
element_type.column)
|
||||
super(FieldTypeArray, self).__init__(
|
||||
element_type.file_name, element_type.line, element_type.column
|
||||
)
|
||||
|
||||
def debug_string(self):
|
||||
"""Display this field type in error messages."""
|
||||
return f'array<{self.element_type.type_name}>'
|
||||
return f"array<{self.element_type.type_name}>"
|
||||
|
||||
|
||||
class FieldTypeVariant(FieldType):
|
||||
@ -810,7 +829,7 @@ class FieldTypeVariant(FieldType):
|
||||
|
||||
def debug_string(self):
|
||||
"""Display this field type in error messages."""
|
||||
return 'variant<%s>' % (', '.join(v.debug_string() for v in self.variant))
|
||||
return "variant<%s>" % (", ".join(v.debug_string() for v in self.variant))
|
||||
|
||||
|
||||
class Expression(common.SourceLocation):
|
||||
@ -827,8 +846,12 @@ class Expression(common.SourceLocation):
|
||||
super(Expression, self).__init__(file_name, line, column)
|
||||
|
||||
def __eq__(self, other):
|
||||
return (isinstance(other, Expression) and self.literal == other.literal
|
||||
and self.expr == other.expr and self.is_constexpr == other.is_constexpr)
|
||||
return (
|
||||
isinstance(other, Expression)
|
||||
and self.literal == other.literal
|
||||
and self.expr == other.expr
|
||||
and self.is_constexpr == other.is_constexpr
|
||||
)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not self == other
|
||||
|
||||
@ -39,9 +39,9 @@ _INDENT_SPACE_COUNT = 4
|
||||
def _fill_spaces(count):
|
||||
# type: (int) -> str
|
||||
"""Fill a string full of spaces."""
|
||||
fill = ''
|
||||
fill = ""
|
||||
for _ in range(count * _INDENT_SPACE_COUNT):
|
||||
fill += ' '
|
||||
fill += " "
|
||||
|
||||
return fill
|
||||
|
||||
@ -51,7 +51,7 @@ def _indent_text(count, unindented_text):
|
||||
"""Indent each line of a multi-line string."""
|
||||
lines = unindented_text.splitlines()
|
||||
fill = _fill_spaces(count)
|
||||
return '\n'.join(fill + line for line in lines)
|
||||
return "\n".join(fill + line for line in lines)
|
||||
|
||||
|
||||
def is_function(name):
|
||||
@ -68,10 +68,10 @@ def is_function(name):
|
||||
def get_method_name(name):
|
||||
# type: (str) -> str
|
||||
"""Get a method name from a fully qualified method name."""
|
||||
pos = name.rfind('::')
|
||||
pos = name.rfind("::")
|
||||
if pos == -1:
|
||||
return name
|
||||
return name[pos + 2:]
|
||||
return name[pos + 2 :]
|
||||
|
||||
|
||||
def get_method_name_from_qualified_method_name(name):
|
||||
@ -82,12 +82,12 @@ def get_method_name_from_qualified_method_name(name):
|
||||
if name.startswith("::"):
|
||||
name = name[2:]
|
||||
|
||||
prefix = 'mongo::'
|
||||
prefix = "mongo::"
|
||||
pos = name.find(prefix)
|
||||
if pos == -1:
|
||||
return name
|
||||
|
||||
return name[len(prefix):]
|
||||
return name[len(prefix) :]
|
||||
|
||||
|
||||
class IndentedTextWriter(object):
|
||||
@ -243,7 +243,7 @@ class NamespaceScopeBlock(WriterBlock):
|
||||
# type: () -> None
|
||||
"""Write the beginning of the block and do not indent."""
|
||||
for namespace in self._namespaces:
|
||||
self._writer.write_unindented_line('namespace %s {' % (namespace))
|
||||
self._writer.write_unindented_line("namespace %s {" % (namespace))
|
||||
|
||||
def __exit__(self, *args):
|
||||
# type: (*str) -> None
|
||||
@ -251,7 +251,7 @@ class NamespaceScopeBlock(WriterBlock):
|
||||
self._namespaces.reverse()
|
||||
|
||||
for namespace in self._namespaces:
|
||||
self._writer.write_unindented_line('} // namespace %s' % (namespace))
|
||||
self._writer.write_unindented_line("} // namespace %s" % (namespace))
|
||||
|
||||
|
||||
class UnindentedBlock(WriterBlock):
|
||||
@ -304,7 +304,7 @@ def _get_common_prefix(words):
|
||||
"""
|
||||
empty_words = [lw for lw in words if len(lw) == 0]
|
||||
if empty_words:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
first_letters = {w[0] for w in words}
|
||||
|
||||
@ -317,7 +317,7 @@ def _get_common_prefix(words):
|
||||
|
||||
return words[0][0] + _get_common_prefix(suffix_words)
|
||||
else:
|
||||
return ''
|
||||
return ""
|
||||
|
||||
|
||||
def gen_trie(words, writer, callback):
|
||||
@ -329,7 +329,7 @@ def gen_trie(words, writer, callback):
|
||||
i.e. for ["abc", "def"], then callback() will be called twice, once for each string.
|
||||
"""
|
||||
|
||||
_gen_trie('', words, writer, callback)
|
||||
_gen_trie("", words, writer, callback)
|
||||
|
||||
|
||||
def _gen_trie(prefix, words, writer, callback):
|
||||
@ -352,12 +352,14 @@ def _gen_trie(prefix, words, writer, callback):
|
||||
suffix = words[0]
|
||||
suffix_len = len(suffix)
|
||||
|
||||
predicate = f'fieldName.size() == {len(word_to_check)} && ' \
|
||||
predicate = (
|
||||
f"fieldName.size() == {len(word_to_check)} && "
|
||||
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{suffix}", {suffix_len}) == 0'
|
||||
)
|
||||
|
||||
# If there is no trailing text, we just need to check length to validate we matched
|
||||
if suffix_len == 0:
|
||||
predicate = f'fieldName.size() == {len(word_to_check)}'
|
||||
predicate = f"fieldName.size() == {len(word_to_check)}"
|
||||
|
||||
# Optimization:
|
||||
# Checking strings of length 1 or even length is efficient. Strings of 3 byte length are
|
||||
@ -365,10 +367,12 @@ def _gen_trie(prefix, words, writer, callback):
|
||||
# length strings require just 1. Since we know the field name is zero terminated, we can
|
||||
# just use memcmp and compare with the trailing null byte.
|
||||
elif suffix_len % 4 == 3:
|
||||
predicate = f'fieldName.size() == {len(word_to_check)} && ' \
|
||||
predicate = (
|
||||
f"fieldName.size() == {len(word_to_check)} && "
|
||||
+ f' memcmp(fieldName.rawData() + {prefix_len}, "{suffix}\\0", {suffix_len + 1}) == 0'
|
||||
)
|
||||
|
||||
with IndentedScopedBlock(writer, f'if ({predicate}) {{', '}'):
|
||||
with IndentedScopedBlock(writer, f"if ({predicate}) {{", "}"):
|
||||
callback(word_to_check)
|
||||
|
||||
return
|
||||
@ -379,7 +383,7 @@ def _gen_trie(prefix, words, writer, callback):
|
||||
empty_words = [lw for lw in words if len(lw) == 0]
|
||||
if empty_words:
|
||||
word_to_check = prefix
|
||||
with IndentedScopedBlock(writer, f'if (fieldName.size() == {len(word_to_check)}) {{', "}"):
|
||||
with IndentedScopedBlock(writer, f"if (fieldName.size() == {len(word_to_check)}) {{", "}"):
|
||||
callback(word_to_check)
|
||||
|
||||
# Filter out empty words
|
||||
@ -395,10 +399,11 @@ def _gen_trie(prefix, words, writer, callback):
|
||||
suffix_words = [flw[gcp_len:] for flw in words]
|
||||
|
||||
with IndentedScopedBlock(
|
||||
writer,
|
||||
f'if (fieldName.size() >= {gcp_len} && '\
|
||||
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{',
|
||||
"}"):
|
||||
writer,
|
||||
f"if (fieldName.size() >= {gcp_len} && "
|
||||
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{',
|
||||
"}",
|
||||
):
|
||||
_gen_trie(prefix + gcp, suffix_words, writer, callback)
|
||||
|
||||
return
|
||||
@ -410,16 +415,16 @@ def _gen_trie(prefix, words, writer, callback):
|
||||
first_letters = {w[0] for w in sorted_words}
|
||||
min_len = len(prefix) + min([len(w) for w in sorted_words])
|
||||
|
||||
with IndentedScopedBlock(writer, f'if (fieldName.size() >= {min_len}) {{', "}"):
|
||||
with IndentedScopedBlock(writer, f"if (fieldName.size() >= {min_len}) {{", "}"):
|
||||
first_if = True
|
||||
|
||||
for first_letter in first_letters:
|
||||
|
||||
fl_words = [flw[1:] for flw in words if flw[0] == first_letter]
|
||||
|
||||
ei = "else " if not first_if else ''
|
||||
ei = "else " if not first_if else ""
|
||||
with IndentedScopedBlock(
|
||||
writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}"):
|
||||
writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}"
|
||||
):
|
||||
_gen_trie(prefix + first_letter, fl_words, writer, callback)
|
||||
|
||||
first_if = False
|
||||
@ -429,6 +434,6 @@ def gen_string_table_find_function_block(out, in_str, on_match, on_fail, words):
|
||||
# type: (IndentedTextWriter, str, str, str, list[str]) -> None
|
||||
"""Wrap a gen_trie generated block as a function."""
|
||||
index = {word: i for i, word in enumerate(words)}
|
||||
out.write_line(f'StringData fieldName{{{in_str}}};')
|
||||
gen_trie(words, out, lambda w: out.write_line(f'return {on_match.format(index[w])};'))
|
||||
out.write_line(f'return {on_fail};')
|
||||
out.write_line(f"StringData fieldName{{{in_str}}};")
|
||||
gen_trie(words, out, lambda w: out.write_line(f"return {on_match.format(index[w])};"))
|
||||
out.write_line(f"return {on_fail};")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -39,29 +39,43 @@ import idl.compiler
|
||||
def main():
|
||||
# type: () -> None
|
||||
"""Execute Main Entry point."""
|
||||
parser = argparse.ArgumentParser(description='MongoDB IDL Compiler.')
|
||||
parser = argparse.ArgumentParser(description="MongoDB IDL Compiler.")
|
||||
|
||||
parser.add_argument('file', type=str, help="IDL input file")
|
||||
parser.add_argument("file", type=str, help="IDL input file")
|
||||
|
||||
parser.add_argument('-o', '--output', type=str, help="IDL output source file")
|
||||
parser.add_argument("-o", "--output", type=str, help="IDL output source file")
|
||||
|
||||
parser.add_argument('--header', type=str, help="IDL output header file")
|
||||
parser.add_argument("--header", type=str, help="IDL output header file")
|
||||
|
||||
parser.add_argument('-i', '--include', type=str, action="append",
|
||||
help="Directory to search for IDL import files")
|
||||
parser.add_argument(
|
||||
"-i",
|
||||
"--include",
|
||||
type=str,
|
||||
action="append",
|
||||
help="Directory to search for IDL import files",
|
||||
)
|
||||
|
||||
parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
|
||||
parser.add_argument("-v", "--verbose", action="count", help="Enable verbose tracing")
|
||||
|
||||
parser.add_argument('--base_dir', type=str, help="IDL output relative base directory")
|
||||
parser.add_argument("--base_dir", type=str, help="IDL output relative base directory")
|
||||
|
||||
parser.add_argument('--write-dependencies', action='store_true',
|
||||
help='only print out a list of dependent imports')
|
||||
parser.add_argument(
|
||||
"--write-dependencies",
|
||||
action="store_true",
|
||||
help="only print out a list of dependent imports",
|
||||
)
|
||||
|
||||
parser.add_argument('--write-dependencies-inline', action='store_true',
|
||||
help='print out a list of dependent imports during file generation')
|
||||
parser.add_argument(
|
||||
"--write-dependencies-inline",
|
||||
action="store_true",
|
||||
help="print out a list of dependent imports during file generation",
|
||||
)
|
||||
|
||||
parser.add_argument('--target_arch', type=str,
|
||||
help="IDL target archiecture (amd64, s390x). defaults to current machine")
|
||||
parser.add_argument(
|
||||
"--target_arch",
|
||||
type=str,
|
||||
help="IDL target archiecture (amd64, s390x). defaults to current machine",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@ -81,8 +95,9 @@ def main():
|
||||
compiler_args.write_dependencies = args.write_dependencies
|
||||
compiler_args.write_dependencies_inline = args.write_dependencies_inline
|
||||
|
||||
if (args.output is not None and args.header is None) or \
|
||||
(args.output is None and args.header is not None):
|
||||
if (args.output is not None and args.header is None) or (
|
||||
args.output is None and args.header is not None
|
||||
):
|
||||
print("ERROR: Either both --header and --output must be specified or neither.")
|
||||
sys.exit(1)
|
||||
|
||||
@ -93,5 +108,5 @@ def main():
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@ -38,8 +38,9 @@ def list_idls(directory: str) -> Set[str]:
|
||||
"""Find all IDL files in the current directory."""
|
||||
return {
|
||||
os.path.join(dirpath, filename)
|
||||
for dirpath, dirnames, filenames in os.walk(directory) for filename in filenames
|
||||
if filename.endswith('.idl')
|
||||
for dirpath, dirnames, filenames in os.walk(directory)
|
||||
for filename in filenames
|
||||
if filename.endswith(".idl")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -46,11 +46,11 @@ def run_tests():
|
||||
# my-py type information.
|
||||
all_tests = unittest.defaultTestLoader.discover(start_dir="tests") # type: ignore
|
||||
|
||||
runner = XMLTestRunner(verbosity=2, failfast=False, output='results')
|
||||
runner = XMLTestRunner(verbosity=2, failfast=False, output="results")
|
||||
result = runner.run(all_tests)
|
||||
|
||||
sys.exit(not result.wasSuccessful())
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
||||
@ -30,7 +30,7 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
|
||||
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||
|
||||
import idl.ast # pylint: disable=wrong-import-position
|
||||
import idl.binder # pylint: disable=wrong-import-position
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@ -43,6 +43,7 @@ from textwrap import dedent
|
||||
# import package so that it works regardless of whether we run as a module or file
|
||||
if __package__ is None:
|
||||
import sys
|
||||
|
||||
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
||||
from context import idl
|
||||
import testcase
|
||||
@ -61,16 +62,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
def _src_dir(self):
|
||||
"""Get the directory of the src folder."""
|
||||
base_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
)
|
||||
return os.path.join(
|
||||
base_dir,
|
||||
'src',
|
||||
"src",
|
||||
)
|
||||
|
||||
@property
|
||||
def _idl_dir(self):
|
||||
"""Get the directory of the idl folder."""
|
||||
return os.path.join(self._src_dir, 'mongo', 'idl')
|
||||
return os.path.join(self._src_dir, "mongo", "idl")
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""Cleanup resources created by tests."""
|
||||
@ -87,14 +89,15 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
args.output_suffix = self.output_suffix
|
||||
args.import_directories = [self._src_dir]
|
||||
|
||||
unittest_idl_file = os.path.join(self._idl_dir, f'{self.idl_files_to_test[0]}.idl')
|
||||
unittest_idl_file = os.path.join(self._idl_dir, f"{self.idl_files_to_test[0]}.idl")
|
||||
if not os.path.exists(unittest_idl_file):
|
||||
unittest.skip(
|
||||
"Skipping IDL Generator testing since %s could not be found." % (unittest_idl_file))
|
||||
"Skipping IDL Generator testing since %s could not be found." % (unittest_idl_file)
|
||||
)
|
||||
return
|
||||
|
||||
for idl_file in self.idl_files_to_test:
|
||||
args.input_file = os.path.join(self._idl_dir, f'{idl_file}.idl')
|
||||
args.input_file = os.path.join(self._idl_dir, f"{idl_file}.idl")
|
||||
self.assertTrue(idl.compiler.compile_idl(args))
|
||||
|
||||
def test_enum_non_const(self):
|
||||
@ -130,13 +133,15 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
# Make sure the getter is marked as const.
|
||||
# Make sure the return type is not marked as const by validating the getter marked as const
|
||||
# is the only occurrence of the word "const".
|
||||
header_lines = header.split('\n')
|
||||
header_lines = header.split("\n")
|
||||
|
||||
found = False
|
||||
for header_line in header_lines:
|
||||
if header_line.find("getValue") > 0 \
|
||||
and header_line.find("const {") > 0 \
|
||||
and header_line.find("const {") == header_line.find("const"):
|
||||
if (
|
||||
header_line.find("getValue") > 0
|
||||
and header_line.find("const {") > 0
|
||||
and header_line.find("const {") == header_line.find("const")
|
||||
):
|
||||
found = True
|
||||
|
||||
self.assertTrue(found, "Bad Header: " + header)
|
||||
@ -267,7 +272,8 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
self.assertIn(expected, source)
|
||||
|
||||
def test_array_of_object_type_with_custom_serializer_and_query_shape_specification_custom(
|
||||
self) -> None:
|
||||
self,
|
||||
) -> None:
|
||||
"""Serialization with custom query_shape used, array use case."""
|
||||
_, source = self.assert_generate("""
|
||||
types:
|
||||
@ -373,7 +379,9 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
|
||||
def test_view_struct_generates_anchor(self) -> None:
|
||||
"""Test anchor generation on view struct."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
ViewStruct:
|
||||
description: ViewStruct
|
||||
@ -382,14 +390,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
value2: random_type_not_view
|
||||
value3: random_type_not_view
|
||||
value4: random_type_not_view
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertIn(expected, header)
|
||||
|
||||
def test_non_view_struct_does_not_generate_anchor(self) -> None:
|
||||
"""Test anchor is not generated on non view struct."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
NonViewStruct:
|
||||
description: NonViewStruct
|
||||
@ -398,14 +409,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
value2: random_type_not_view
|
||||
value3: random_type_not_view
|
||||
value4: random_type_not_view
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertNotIn(expected, header)
|
||||
|
||||
def test_compound_view_struct_generates_anchor(self) -> None:
|
||||
"""Test anchor generation on view struct with compound type."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
ViewStruct:
|
||||
description: ViewStruct
|
||||
@ -414,14 +428,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
value2: random_type_not_view
|
||||
value3: random_type_not_view
|
||||
value4: random_type_not_view
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertIn(expected, header)
|
||||
|
||||
def test_compound_non_view_struct_does_not_generate_anchor(self) -> None:
|
||||
"""Test anchor is not generated on non view struct with compound type."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
NonViewStruct:
|
||||
description: NonViewStruct
|
||||
@ -430,14 +447,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
value2: random_type_not_view
|
||||
value3: random_type_not_view
|
||||
value4: random_type_not_view
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertNotIn(expected, header)
|
||||
|
||||
def test_command_view_type_generates_anchor(self) -> None:
|
||||
"""Test anchor generation on command with view parameter."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
commands:
|
||||
CommandTypeArrayObjectCommand:
|
||||
description: CommandTypeArrayObjectCommand
|
||||
@ -445,14 +465,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
namespace: type
|
||||
api_version: ""
|
||||
type: array<object_is_view>
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertIn(expected, header)
|
||||
|
||||
def test_command_non_view_type_does_not_generate_anchor(self) -> None:
|
||||
"""Test anchor is not generated on command with nont view parameter."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
commands:
|
||||
CommandTypeArrayObjectCommand:
|
||||
description: CommandTypeArrayObjectCommand
|
||||
@ -460,14 +483,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
namespace: type
|
||||
api_version: ""
|
||||
type: array<object_is_not_view>
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertNotIn(expected, header)
|
||||
|
||||
def test_chained_view_struct_generates_anchor(self) -> None:
|
||||
"""Test anchor generation on struct chained with view struct."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
ViewStruct:
|
||||
description: ViewStruct
|
||||
@ -480,14 +506,17 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
description: ViewStructChainedStruct
|
||||
chained_structs:
|
||||
ViewStruct: ViewStruct
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertIn(expected, header)
|
||||
|
||||
def test_chained_non_view_struct_does_not_generate_anchor(self) -> None:
|
||||
"""Test anchor not generated on struct chained with non view struct."""
|
||||
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
|
||||
header, _ = self.assert_generate(
|
||||
self.view_test_common_types
|
||||
+ dedent("""
|
||||
structs:
|
||||
NonViewStruct:
|
||||
description: NonViewStruct
|
||||
@ -500,12 +529,12 @@ class TestGenerator(testcase.IDLTestcase):
|
||||
description: NonViewStructChainedStruct
|
||||
chained_structs:
|
||||
NonViewStruct: NonViewStruct
|
||||
"""))
|
||||
""")
|
||||
)
|
||||
|
||||
expected = dedent("BSONObj _anchorObj;")
|
||||
self.assertNotIn(expected, header)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@ -38,6 +38,7 @@ from typing import Any, Dict
|
||||
if __package__ is None:
|
||||
import sys
|
||||
from os import path
|
||||
|
||||
sys.path.append(path.dirname(path.abspath(__file__)))
|
||||
from context import idl
|
||||
import testcase
|
||||
@ -88,27 +89,32 @@ class TestImport(testcase.IDLTestcase):
|
||||
|
||||
imports:
|
||||
- "b.idl"
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_NODE)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_NODE,
|
||||
)
|
||||
|
||||
self.assert_parse_fail(
|
||||
textwrap.dedent("""
|
||||
imports: "basetypes.idl"
|
||||
"""), idl.errors.ERROR_ID_IS_NODE_TYPE)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_IS_NODE_TYPE,
|
||||
)
|
||||
|
||||
self.assert_parse_fail(
|
||||
textwrap.dedent("""
|
||||
imports:
|
||||
a: "a.idl"
|
||||
b: "b.idl"
|
||||
"""), idl.errors.ERROR_ID_IS_NODE_TYPE)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_IS_NODE_TYPE,
|
||||
)
|
||||
|
||||
def test_import_positive(self):
|
||||
# type: () -> None
|
||||
"""Postive import tests."""
|
||||
|
||||
import_dict = {
|
||||
"basetypes.idl":
|
||||
textwrap.dedent("""
|
||||
"basetypes.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -135,8 +141,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
fields:
|
||||
foo: string
|
||||
"""),
|
||||
"recurse1.idl":
|
||||
textwrap.dedent("""
|
||||
"recurse1.idl": textwrap.dedent("""
|
||||
imports:
|
||||
- "basetypes.idl"
|
||||
|
||||
@ -149,8 +154,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
is_view: false
|
||||
|
||||
"""),
|
||||
"recurse2.idl":
|
||||
textwrap.dedent("""
|
||||
"recurse2.idl": textwrap.dedent("""
|
||||
imports:
|
||||
- "recurse1.idl"
|
||||
|
||||
@ -163,8 +167,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
is_view: false
|
||||
|
||||
"""),
|
||||
"recurse1b.idl":
|
||||
textwrap.dedent("""
|
||||
"recurse1b.idl": textwrap.dedent("""
|
||||
imports:
|
||||
- "basetypes.idl"
|
||||
|
||||
@ -176,8 +179,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
deserializer: BSONElement::fake
|
||||
is_view: false
|
||||
"""),
|
||||
"cycle1a.idl":
|
||||
textwrap.dedent("""
|
||||
"cycle1a.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -208,8 +210,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
foo: string
|
||||
foo1: bool
|
||||
"""),
|
||||
"cycle1b.idl":
|
||||
textwrap.dedent("""
|
||||
"cycle1b.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -232,8 +233,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
foo: string
|
||||
foo1: bool
|
||||
"""),
|
||||
"cycle2.idl":
|
||||
textwrap.dedent("""
|
||||
"cycle2.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -275,7 +275,9 @@ class TestImport(testcase.IDLTestcase):
|
||||
strict: false
|
||||
fields:
|
||||
foo: string
|
||||
"""), resolver=resolver)
|
||||
"""),
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Test nested import
|
||||
self.assert_bind(
|
||||
@ -294,7 +296,9 @@ class TestImport(testcase.IDLTestcase):
|
||||
foo: string
|
||||
foo1: int
|
||||
foo2: double
|
||||
"""), resolver=resolver)
|
||||
"""),
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Test diamond import
|
||||
self.assert_bind(
|
||||
@ -315,7 +319,9 @@ class TestImport(testcase.IDLTestcase):
|
||||
foo1: int
|
||||
foo2: double
|
||||
foo3: bool
|
||||
"""), resolver=resolver)
|
||||
"""),
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Test cycle import
|
||||
self.assert_bind(
|
||||
@ -333,7 +339,9 @@ class TestImport(testcase.IDLTestcase):
|
||||
fields:
|
||||
foo: string
|
||||
foo1: bool
|
||||
"""), resolver=resolver)
|
||||
"""),
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Test self cycle import
|
||||
self.assert_bind(
|
||||
@ -350,15 +358,16 @@ class TestImport(testcase.IDLTestcase):
|
||||
strict: false
|
||||
fields:
|
||||
foo: string
|
||||
"""), resolver=resolver)
|
||||
"""),
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
def test_import_negative(self):
|
||||
# type: () -> None
|
||||
"""Negative import tests."""
|
||||
|
||||
import_dict = {
|
||||
"basetypes.idl":
|
||||
textwrap.dedent("""
|
||||
"basetypes.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -388,8 +397,7 @@ class TestImport(testcase.IDLTestcase):
|
||||
b1: 1
|
||||
|
||||
"""),
|
||||
"bug.idl":
|
||||
textwrap.dedent("""
|
||||
"bug.idl": textwrap.dedent("""
|
||||
global:
|
||||
cpp_namespace: 'mongo'
|
||||
|
||||
@ -409,7 +417,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
textwrap.dedent("""
|
||||
imports:
|
||||
- "notfound.idl"
|
||||
"""), idl.errors.ERROR_ID_BAD_IMPORT, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_BAD_IMPORT,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Duplicate types
|
||||
self.assert_parse_fail(
|
||||
@ -423,7 +434,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
cpp_type: foo
|
||||
bson_serialization_type: string
|
||||
is_view: false
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Duplicate structs
|
||||
self.assert_parse_fail(
|
||||
@ -436,7 +450,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
description: foo
|
||||
fields:
|
||||
foo1: string
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Duplicate struct and type
|
||||
self.assert_parse_fail(
|
||||
@ -449,7 +466,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
description: foo
|
||||
fields:
|
||||
foo1: string
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Duplicate type and struct
|
||||
self.assert_parse_fail(
|
||||
@ -463,7 +483,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
cpp_type: foo
|
||||
bson_serialization_type: string
|
||||
is_view: false
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Duplicate enums
|
||||
self.assert_parse_fail(
|
||||
@ -478,7 +501,10 @@ class TestImport(testcase.IDLTestcase):
|
||||
values:
|
||||
a0: 0
|
||||
b1: 1
|
||||
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
# Import a file with errors
|
||||
self.assert_parse_fail(
|
||||
@ -493,9 +519,11 @@ class TestImport(testcase.IDLTestcase):
|
||||
cpp_type: foo
|
||||
bson_serialization_type: string
|
||||
is_view: false
|
||||
"""), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, resolver=resolver)
|
||||
"""),
|
||||
idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD,
|
||||
resolver=resolver,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -30,9 +30,10 @@
|
||||
import unittest
|
||||
from typing import Any, Tuple
|
||||
|
||||
if __name__ == 'testcase':
|
||||
if __name__ == "testcase":
|
||||
import sys
|
||||
from os import path
|
||||
|
||||
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
|
||||
from context import idl
|
||||
else:
|
||||
@ -79,8 +80,9 @@ class IDLTestcase(unittest.TestCase):
|
||||
"""Assert a document parsed correctly by the IDL compiler and returned no errors."""
|
||||
self.assertIsNone(
|
||||
parsed_doc.errors,
|
||||
"Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s" %
|
||||
(doc_str, errors_to_str(parsed_doc.errors)))
|
||||
"Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s"
|
||||
% (doc_str, errors_to_str(parsed_doc.errors)),
|
||||
)
|
||||
self.assertIsNotNone(parsed_doc.spec, "Expected a parsed doc")
|
||||
|
||||
def assert_parse(self, doc_str, resolver=NothingImportResolver()):
|
||||
@ -89,8 +91,9 @@ class IDLTestcase(unittest.TestCase):
|
||||
parsed_doc = self._parse(doc_str, resolver)
|
||||
self._assert_parse(doc_str, parsed_doc)
|
||||
|
||||
def assert_parse_fail(self, doc_str, error_id, multiple=False,
|
||||
resolver=NothingImportResolver()):
|
||||
def assert_parse_fail(
|
||||
self, doc_str, error_id, multiple=False, resolver=NothingImportResolver()
|
||||
):
|
||||
# type: (str, str, bool, idl.parser.ImportResolverBase) -> None
|
||||
"""
|
||||
Assert a document parsed correctly by the YAML parser, but not the by the IDL compiler.
|
||||
@ -107,12 +110,14 @@ class IDLTestcase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
multiple or parsed_doc.errors.count() == 1,
|
||||
"For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s"
|
||||
% (doc_str, error_id, errors_to_str(parsed_doc.errors)))
|
||||
% (doc_str, error_id, errors_to_str(parsed_doc.errors)),
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
parsed_doc.errors.contains(error_id),
|
||||
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s" %
|
||||
(doc_str, error_id, errors_to_str(parsed_doc.errors)))
|
||||
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s"
|
||||
% (doc_str, error_id, errors_to_str(parsed_doc.errors)),
|
||||
)
|
||||
|
||||
def assert_bind(self, doc_str, resolver=NothingImportResolver()):
|
||||
# type: (str, idl.parser.ImportResolverBase) -> idl.ast.IDLBoundSpec
|
||||
@ -123,8 +128,10 @@ class IDLTestcase(unittest.TestCase):
|
||||
bound_doc = idl.binder.bind(parsed_doc.spec)
|
||||
|
||||
self.assertIsNone(
|
||||
bound_doc.errors, "Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s"
|
||||
% (doc_str, errors_to_str(bound_doc.errors)))
|
||||
bound_doc.errors,
|
||||
"Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s"
|
||||
% (doc_str, errors_to_str(bound_doc.errors)),
|
||||
)
|
||||
self.assertIsNotNone(bound_doc.spec, "Expected a bound doc")
|
||||
|
||||
return bound_doc.spec
|
||||
@ -149,12 +156,14 @@ class IDLTestcase(unittest.TestCase):
|
||||
self.assertTrue(
|
||||
(multiple and bound_doc.errors.count() >= 1) or bound_doc.errors.count() == 1,
|
||||
"For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s"
|
||||
% (doc_str, error_id, errors_to_str(bound_doc.errors)))
|
||||
% (doc_str, error_id, errors_to_str(bound_doc.errors)),
|
||||
)
|
||||
|
||||
self.assertTrue(
|
||||
bound_doc.errors.contains(error_id),
|
||||
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s" %
|
||||
(doc_str, error_id, errors_to_str(bound_doc.errors)))
|
||||
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s"
|
||||
% (doc_str, error_id, errors_to_str(bound_doc.errors)),
|
||||
)
|
||||
|
||||
def assert_generate(self, doc_str, resolver=NothingImportResolver()):
|
||||
# type: (str, idl.parser.ImportResolverBase) -> Tuple[str,str]
|
||||
|
||||
@ -212,7 +212,6 @@ ignore = [
|
||||
[tool.ruff.format]
|
||||
exclude = [
|
||||
# allow-list: incrementally remove as each is formatted and locked down
|
||||
"buildscripts/idl/*",
|
||||
"jstests/*",
|
||||
"site_scons/*",
|
||||
"src/*",
|
||||
|
||||
Loading…
Reference in New Issue
Block a user