SERVER-90571: Enable python formatting checks for buildscripts/idl directory (#22293)

GitOrigin-RevId: a2fbc8ed83f576703cce96ebb5e680cc70aac4d8
This commit is contained in:
Steve McClure 2024-05-17 15:17:45 -04:00 committed by MongoDB Bot
parent 83fa212b68
commit bd2955c297
31 changed files with 8383 additions and 4455 deletions

View File

@ -40,7 +40,7 @@ from typing import Any, Dict, List, Mapping, Set
from pymongo import MongoClient
# Permit imports from "buildscripts".
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
# pylint: disable=wrong-import-position
from idl import syntax
@ -53,7 +53,7 @@ from buildscripts.resmokelib.testing.fixtures.shardedcluster import ShardedClust
from buildscripts.resmokelib.testing.fixtures.standalone import MongoDFixture
# pylint: enable=wrong-import-position
LOGGER_NAME = 'check-idl-definitions'
LOGGER_NAME = "check-idl-definitions"
LOGGER = logging.getLogger(LOGGER_NAME)
@ -68,8 +68,9 @@ def is_test_or_third_party_idl(idl_path: str) -> bool:
return False
def get_command_definitions(api_version: str, directory: str,
import_directories: List[str]) -> Dict[str, syntax.Command]:
def get_command_definitions(
api_version: str, directory: str, import_directories: List[str]
) -> Dict[str, syntax.Command]:
"""Get parsed IDL definitions of commands in a given API version."""
LOGGER.info("Searching for command definitions in %s", directory)
@ -109,22 +110,30 @@ def list_commands_for_api(api_version: str, mongod_or_mongos: str, install_dir:
logger = loggers.new_fixture_logger("ShardedClusterFixture", 0)
logger.parent = LOGGER
fixture = fixturelib.make_fixture(
"ShardedClusterFixture", logger, 0, dbpath_prefix=dbpath.name,
mongos_executable=mongos_executable, mongod_executable=mongod_executable,
mongod_options={"set_parameters": {}})
"ShardedClusterFixture",
logger,
0,
dbpath_prefix=dbpath.name,
mongos_executable=mongos_executable,
mongod_executable=mongod_executable,
mongod_options={"set_parameters": {}},
)
fixture.setup()
fixture.await_ready()
try:
client = MongoClient(fixture.get_driver_connection_url()) # type: MongoClient
reply = client.admin.command('listCommands') # type: Mapping[str, Any]
reply = client.admin.command("listCommands") # type: Mapping[str, Any]
commands = {
name
for name, info in reply['commands'].items() if api_version in info['apiVersions']
name for name, info in reply["commands"].items() if api_version in info["apiVersions"]
}
logging.info("Found %s commands in API Version %s on %s", len(commands), api_version,
mongod_or_mongos)
logging.info(
"Found %s commands in API Version %s on %s",
len(commands),
api_version,
mongod_or_mongos,
)
return commands
finally:
fixture.teardown()
@ -144,13 +153,16 @@ def assert_command_sets_equal(api_version: str, command_sets: Dict[str, Set[str]
for other_name, other_commands in it:
if commands != other_commands:
if commands - other_commands:
LOGGER.error("%s has commands not in %s: %s", name, other_name,
commands - other_commands)
LOGGER.error(
"%s has commands not in %s: %s", name, other_name, commands - other_commands
)
if other_commands - commands:
LOGGER.error("%s has commands not in %s: %s", other_name, name,
other_commands - commands)
LOGGER.error(
"%s has commands not in %s: %s", other_name, name, other_commands - commands
)
raise AssertionError(
f"{name} and {other_name} have different commands in API Version {api_version}")
f"{name} and {other_name} have different commands in API Version {api_version}"
)
def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
@ -173,10 +185,15 @@ def remove_skipped_commands(command_sets: Dict[str, Set[str]]):
def main():
"""Run the script."""
arg_parser = argparse.ArgumentParser(description=__doc__)
arg_parser.add_argument("--include", type=str, action="append",
help="Directory to search for IDL import files")
arg_parser.add_argument("--install-dir", dest="install_dir", required=True,
help="Directory to search for MongoDB binaries")
arg_parser.add_argument(
"--include", type=str, action="append", help="Directory to search for IDL import files"
)
arg_parser.add_argument(
"--install-dir",
dest="install_dir",
required=True,
help="Directory to search for MongoDB binaries",
)
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
arg_parser.add_argument("api_version", metavar="API_VERSION", help="API Version to check")
args = arg_parser.parse_args()

View File

@ -41,10 +41,14 @@ if __name__ == "__main__" and __package__ is None:
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
# pylint: disable=wrong-import-position
from buildscripts.resmokelib.multiversionconstants import LAST_LTS_FCV, LAST_CONTINUOUS_FCV, LATEST_FCV
from buildscripts.resmokelib.multiversionconstants import (
LAST_LTS_FCV,
LAST_CONTINUOUS_FCV,
LATEST_FCV,
)
# pylint: enable=wrong-import-position
LOGGER_NAME = 'checkout-idl'
LOGGER_NAME = "checkout-idl"
LOGGER = logging.getLogger(LOGGER_NAME)
@ -52,9 +56,9 @@ def get_tags() -> List[str]:
"""Get a list of git tags that the IDL compatibility script should check against."""
def gen_versions_and_tags():
for tag in check_output(['git', 'tag']).decode().split():
for tag in check_output(["git", "tag"]).decode().split():
# Releases are like "r5.6.7". Older ones aren't r-prefixed but we don't care about them.
if not tag.startswith('r'):
if not tag.startswith("r"):
continue
try:
@ -110,14 +114,14 @@ def make_idl_directories(tags: List[str], destination: str) -> None:
for tag in tags:
LOGGER.info("Checking out IDL files in %s", tag)
directory = os.path.join(destination, tag)
for path in check_output(['git', 'ls-tree', '--name-only', '-r', tag]).decode().split():
if not path.endswith('.idl'):
for path in check_output(["git", "ls-tree", "--name-only", "-r", tag]).decode().split():
if not path.endswith(".idl"):
continue
contents = check_output(['git', 'show', f'{tag}:{path}']).decode()
contents = check_output(["git", "show", f"{tag}:{path}"]).decode()
output_path = os.path.join(directory, path)
os.makedirs(os.path.dirname(output_path), exist_ok=True)
with open(output_path, 'w+') as fd:
with open(output_path, "w+") as fd:
fd.write(contents)
@ -125,8 +129,9 @@ def main():
"""Run the script."""
arg_parser = argparse.ArgumentParser(description=__doc__)
arg_parser.add_argument("-v", "--verbose", action="count", help="Enable verbose logging")
arg_parser.add_argument("destination", metavar="DESTINATION",
help="Directory to check out past IDL file versions")
arg_parser.add_argument(
"destination", metavar="DESTINATION", help="Directory to check out past IDL file versions"
)
args = arg_parser.parse_args()
logging.basicConfig(level=logging.WARNING)

View File

@ -38,7 +38,7 @@ from typing import List
import yaml
# Permit imports from "buildscripts".
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
# pylint: disable=wrong-import-position
from buildscripts.idl import lib
@ -60,7 +60,7 @@ def get_all_feature_flags(idl_dirs: List[str] = None):
# Most IDL files do not contain feature flags.
# We can discard these quickly without expensive YAML parsing.
with open(idl_path) as idl_file:
if 'feature_flags' not in idl_file.read():
if "feature_flags" not in idl_file.read():
continue
with open(idl_path) as idl_file:
doc = parser.parse_file(idl_file, idl_path)
@ -100,5 +100,5 @@ def main():
gen_all_feature_flags_file()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -36,7 +36,7 @@ import sys
from typing import List
# Permit imports from "buildscripts".
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), '../../..')))
sys.path.append(os.path.normpath(os.path.join(os.path.abspath(__file__), "../../..")))
# pylint: disable=wrong-import-position
from buildscripts.idl import lib
@ -58,7 +58,7 @@ def gen_all_server_params(idl_dirs: List[str] = None):
# Most IDL files do not contain server parameters.
# We can discard these quickly without expensive YAML parsing.
with open(idl_path) as idl_file:
if 'server_parameters' not in idl_file.read():
if "server_parameters" not in idl_file.read():
continue
with open(idl_path) as idl_file:
doc = parser.parse_file(idl_file, idl_path)
@ -80,5 +80,5 @@ def main():
gen_all_server_params_file()
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -33,6 +33,7 @@ Represents the derived IDL specification after type resolution in the binding pa
This is a lossy translation from the IDL Syntax tree as the IDL AST only contains information about
the enums and structs that need code generated for them, and just enough information to do that.
"""
from abc import ABCMeta, abstractmethod
import enum
from typing import Any, Dict, List, Optional
@ -46,8 +47,9 @@ class IDLBoundSpec(object):
def __init__(self, spec, error_collection):
# type: (IDLAST, errors.ParserErrorCollection) -> None
"""Must specify either an IDL document or errors, not both."""
assert (spec is None and error_collection is not None) or (spec is not None
and error_collection is None)
assert (spec is None and error_collection is not None) or (
spec is not None and error_collection is None
)
self.spec = spec
self.errors = error_collection
@ -292,7 +294,8 @@ class Field(common.SourceLocation):
# type: () -> bool
"""Returns true if the IDL compiler should add a call to serialization options for this field."""
return self.query_shape is not None and self.query_shape in [
QueryShapeFieldType.LITERAL, QueryShapeFieldType.ANONYMIZE
QueryShapeFieldType.LITERAL,
QueryShapeFieldType.ANONYMIZE,
]
@property

View File

@ -61,8 +61,9 @@ def _validate_single_bson_type(ctxt, idl_type, syntax_type):
subtype = "<unknown>"
if not bson.is_valid_bindata_subtype(subtype):
ctxt.add_bad_bson_bindata_subtype_value_error(idl_type, syntax_type, idl_type.name,
subtype)
ctxt.add_bad_bson_bindata_subtype_value_error(
idl_type, syntax_type, idl_type.name, subtype
)
elif idl_type.bindata_subtype is not None:
ctxt.add_bad_bson_bindata_subtype_error(idl_type, syntax_type, idl_type.name, bson_type)
@ -107,7 +108,7 @@ def _validate_type(ctxt, idl_type):
if idl_type.name.startswith("array<"):
ctxt.add_array_not_valid_error(idl_type, "type", idl_type.name)
_validate_type_properties(ctxt, idl_type, 'type')
_validate_type_properties(ctxt, idl_type, "type")
def _validate_cpp_type(ctxt, idl_type, syntax_type):
@ -120,15 +121,17 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type):
ctxt.add_no_string_data_error(idl_type, syntax_type, idl_type.name)
# We do not support C++ char and float types for style reasons
if idl_type.cpp_type in ['char', 'wchar_t', 'char16_t', 'char32_t', 'float']:
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
idl_type.cpp_type)
if idl_type.cpp_type in ["char", "wchar_t", "char16_t", "char32_t", "float"]:
ctxt.add_bad_cpp_numeric_type_use_error(
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
)
# We do not support C++ builtin integer for style reasons
for numeric_word in ['signed', "unsigned", "int", "long", "short"]:
if re.search(r'\b%s\b' % (numeric_word), idl_type.cpp_type):
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
idl_type.cpp_type)
for numeric_word in ["signed", "unsigned", "int", "long", "short"]:
if re.search(r"\b%s\b" % (numeric_word), idl_type.cpp_type):
ctxt.add_bad_cpp_numeric_type_use_error(
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
)
# Return early so we only throw one error for types like "signed short int"
return
@ -151,27 +154,39 @@ def _validate_cpp_type(ctxt, idl_type, syntax_type):
# Check for std fixed integer types which are not allowed. These are not allowed even if they
# have the "std::" prefix.
for std_numeric_type in [
"int8_t", "int16_t", "int32_t", "int64_t", "uint8_t", "uint16_t", "uint32_t", "uint64_t"
"int8_t",
"int16_t",
"int32_t",
"int64_t",
"uint8_t",
"uint16_t",
"uint32_t",
"uint64_t",
]:
if std_numeric_type in idl_type.cpp_type:
ctxt.add_bad_cpp_numeric_type_use_error(idl_type, syntax_type, idl_type.name,
idl_type.cpp_type)
ctxt.add_bad_cpp_numeric_type_use_error(
idl_type, syntax_type, idl_type.name, idl_type.cpp_type
)
return
def _validate_chain_type_properties(ctxt, idl_type, syntax_type):
# type: (errors.ParserContext, Union[syntax.Type, ast.Type], str) -> None
"""Validate a chained type has both a deserializer and serializer."""
assert len(
idl_type.bson_serialization_type) == 1 and idl_type.bson_serialization_type[0] == 'chain'
assert (
len(idl_type.bson_serialization_type) == 1
and idl_type.bson_serialization_type[0] == "chain"
)
if idl_type.deserializer is None:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"deserializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "deserializer"
)
if idl_type.serializer is None:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"serializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "serializer"
)
def _validate_type_properties(ctxt, idl_type, syntax_type):
@ -189,29 +204,34 @@ def _validate_type_properties(ctxt, idl_type, syntax_type):
# serialization for their C++ type. An internal_only type is not associated with BSON
# and thus should not have a deserializer defined.
if idl_type.deserializer is None and not idl_type.internal_only:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"deserializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "deserializer"
)
elif bson_type == "chain":
_validate_chain_type_properties(ctxt, idl_type, syntax_type)
elif bson_type == "string":
# Strings support custom serialization unlike other non-object scalar types
if idl_type.deserializer is None:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"deserializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "deserializer"
)
elif not bson_type in ["array", "object", "bindata"]:
if idl_type.deserializer is None:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"deserializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "deserializer"
)
if idl_type.deserializer is not None and "BSONElement" not in idl_type.deserializer:
ctxt.add_not_custom_scalar_serialization_not_supported_error(
idl_type, syntax_type, idl_type.name, bson_type)
idl_type, syntax_type, idl_type.name, bson_type
)
if idl_type.serializer is not None:
ctxt.add_not_custom_scalar_serialization_not_supported_error(
idl_type, syntax_type, idl_type.name, bson_type)
idl_type, syntax_type, idl_type.name, bson_type
)
if bson_type == "bindata" and isinstance(idl_type, syntax.Type) and idl_type.default:
ctxt.add_bindata_no_default(idl_type, syntax_type, idl_type.name)
@ -219,8 +239,9 @@ def _validate_type_properties(ctxt, idl_type, syntax_type):
else:
# Now, this is a list of scalar types
if idl_type.deserializer is None:
ctxt.add_missing_ast_required_field_error(idl_type, syntax_type, idl_type.name,
"deserializer")
ctxt.add_missing_ast_required_field_error(
idl_type, syntax_type, idl_type.name, "deserializer"
)
_validate_cpp_type(ctxt, idl_type, syntax_type)
@ -251,26 +272,27 @@ def _is_duplicate_field(ctxt, field_container, fields, ast_field):
def _get_struct_qualified_cpp_name(struct):
# type: (syntax.Struct) -> str
return common.qualify_cpp_name(struct.cpp_namespace,
common.title_case(struct.cpp_name or struct.name))
return common.qualify_cpp_name(
struct.cpp_namespace, common.title_case(struct.cpp_name or struct.name)
)
def _compute_field_is_view(resolved_field, ctxt, symbols):
# type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool
"""Compute is_view for a symbol referenced by a field."""
# Resolved field is an array.
if (isinstance(resolved_field, syntax.ArrayType)):
if isinstance(resolved_field, syntax.ArrayType):
# Inner type needs to be resolved.
return _compute_field_is_view(resolved_field.element_type, ctxt, symbols)
# Resolved field is a variant.
elif (isinstance(resolved_field, syntax.VariantType)):
elif isinstance(resolved_field, syntax.VariantType):
for variant_type in resolved_field.variant_types:
# Inner type needs to be resolved.
if (_compute_field_is_view(variant_type, ctxt, symbols)):
if _compute_field_is_view(variant_type, ctxt, symbols):
return True
for variant_struct_type in resolved_field.variant_struct_types:
if (_compute_is_view(variant_struct_type, ctxt, symbols)):
if _compute_is_view(variant_struct_type, ctxt, symbols):
return True
return False
@ -282,12 +304,13 @@ def _compute_field_is_view(resolved_field, ctxt, symbols):
def _compute_chained_item_is_view(struct, ctxt, symbols, chained_item):
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, Union[syntax.ChainedType, syntax.ChainedStruct]) -> bool
"""Helper to compute is_view of chained types or structs."""
resolved_chained_item = symbols.resolve_type_from_name(ctxt, struct, chained_item.name,
chained_item.name)
resolved_chained_item = symbols.resolve_type_from_name(
ctxt, struct, chained_item.name, chained_item.name
)
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
# We can rely on symbols.resolve_field_type to add errors.
if (resolved_chained_item is None):
assert (ctxt.errors.has_errors())
if resolved_chained_item is None:
assert ctxt.errors.has_errors()
return True
return _compute_is_view(resolved_chained_item, ctxt, symbols)
@ -296,25 +319,25 @@ def _compute_command_type_is_view(struct, ctxt, symbols, field_type):
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable, syntax.FieldType) -> bool
"""
Compute is_view for the command parameter type.
This function is similar to _compute_field_is_view, but because command parameter types are
syntax.FieldType instead of syntax.Type, separate logic must exist to resolve the command
parameter types.
"""
if (isinstance(field_type, syntax.FieldTypeVariant)):
if isinstance(field_type, syntax.FieldTypeVariant):
for variant_type in field_type.variant:
if (_compute_command_type_is_view(struct, ctxt, symbols, variant_type)):
if _compute_command_type_is_view(struct, ctxt, symbols, variant_type):
return True
elif (isinstance(field_type, syntax.FieldTypeArray)):
elif isinstance(field_type, syntax.FieldTypeArray):
return _compute_command_type_is_view(struct, ctxt, symbols, field_type.element_type)
elif (isinstance(field_type, syntax.FieldTypeSingle)):
elif isinstance(field_type, syntax.FieldTypeSingle):
resolved_type = symbols.resolve_field_type(ctxt, struct, field_type.type_name, field_type)
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
# We can rely on symbols.resolve_field_type to add errors.
if (resolved_type is None):
assert (ctxt.errors.has_errors())
if resolved_type is None:
assert ctxt.errors.has_errors()
return True
if (_compute_field_is_view(resolved_type, ctxt, symbols)):
if _compute_field_is_view(resolved_type, ctxt, symbols):
return True
else:
ctxt.add_unknown_command_type_error(struct, struct.name)
@ -325,37 +348,38 @@ def _compute_struct_is_view(struct, ctxt, symbols):
# type: (syntax.Struct, errors.ParserContext, syntax.SymbolTable) -> bool
"""Compute is_view for structs. A struct is a view if any of its fields are views."""
# Empty structs are non view types.
if (not struct.fields):
if not struct.fields:
return False
for field in struct.fields:
if (field.ignore):
if field.ignore:
continue
# Get the resolved field from the global symbol table.
resolved_field = symbols.resolve_field_type(ctxt, field, field.name, field.type)
# If symbols.resolve_field_type returns None, we can assume an error occured during the function.
# We can rely on symbols.resolve_field_type to add errors.
if (resolved_field is None):
assert (ctxt.errors.has_errors())
if resolved_field is None:
assert ctxt.errors.has_errors()
return True
# If any field is a view type, then the struct is also a view type.
if (_compute_field_is_view(resolved_field, ctxt, symbols)):
if _compute_field_is_view(resolved_field, ctxt, symbols):
return True
if (struct.chained_types):
if struct.chained_types:
for chained_type in struct.chained_types:
if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_type)):
if _compute_chained_item_is_view(struct, ctxt, symbols, chained_type):
return True
if (struct.chained_structs):
if struct.chained_structs:
for chained_struct in struct.chained_structs:
if (_compute_chained_item_is_view(struct, ctxt, symbols, chained_struct)):
if _compute_chained_item_is_view(struct, ctxt, symbols, chained_struct):
return True
# Check command parameter.
if (isinstance(struct, syntax.Command)):
if (struct.type is not None
and _compute_command_type_is_view(struct, ctxt, symbols, struct.type)):
if isinstance(struct, syntax.Command):
if struct.type is not None and _compute_command_type_is_view(
struct, ctxt, symbols, struct.type
):
return True
return False
@ -364,11 +388,11 @@ def _compute_struct_is_view(struct, ctxt, symbols):
def _compute_is_view(symbol, ctxt, symbols):
# type: (Union[syntax.Type, syntax.Enum, syntax.Struct], errors.ParserContext, syntax.SymbolTable) -> bool
"""Compute is_view for any symbol."""
if (isinstance(symbol, syntax.Type)):
if isinstance(symbol, syntax.Type):
return symbol.is_view
elif (isinstance(symbol, syntax.Enum)):
elif isinstance(symbol, syntax.Enum):
return False
elif (isinstance(symbol, syntax.Struct)):
elif isinstance(symbol, syntax.Struct):
return _compute_struct_is_view(symbol, ctxt, symbols)
else:
ctxt.add_unknown_symbol_error(symbol, symbol.name)
@ -394,11 +418,16 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
ast_struct.is_command_reply = struct.is_command_reply
ast_struct.is_catalog_ctxt = struct.is_catalog_ctxt
ast_struct.query_shape_component = struct.query_shape_component
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = struct.unsafe_dangerous_disable_extra_field_duplicate_checks
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks = (
struct.unsafe_dangerous_disable_extra_field_duplicate_checks
)
ast_struct.is_view = _compute_is_view(struct, ctxt, parsed_spec.symbols)
# Check that unsafe_dangerous_disable_extra_field_duplicate_checks is used correctly
if ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks and ast_struct.strict is True:
if (
ast_struct.unsafe_dangerous_disable_extra_field_duplicate_checks
and ast_struct.strict is True
):
ctxt.add_strict_and_disable_check_not_allowed(ast_struct)
if struct.is_generic_cmd_list:
@ -419,8 +448,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
for chained_type in struct.chained_types:
ast_field = _bind_chained_type(ctxt, parsed_spec, ast_struct, chained_type)
if ast_field and not _is_duplicate_field(ctxt, chained_type.name, ast_struct.fields,
ast_field):
if ast_field and not _is_duplicate_field(
ctxt, chained_type.name, ast_struct.fields, ast_field
):
ast_struct.fields.append(ast_field)
# Merge chained structs as a chained struct and ignored fields
@ -443,12 +473,14 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
ast_field.generic_field_info = gen_field_info
if ast_field.supports_doc_sequence and not isinstance(ast_struct, ast.Command):
# Doc sequences are only supported in commands at the moment
ctxt.add_bad_struct_field_as_doc_sequence_error(ast_struct, ast_struct.name,
ast_field.name)
ctxt.add_bad_struct_field_as_doc_sequence_error(
ast_struct, ast_struct.name, ast_field.name
)
if ast_field.non_const_getter and struct.immutable:
ctxt.add_bad_field_non_const_getter_in_immutable_struct_error(
ast_struct, ast_struct.name, ast_field.name)
ast_struct, ast_struct.name, ast_field.name
)
if not _is_duplicate_field(ctxt, ast_struct.name, ast_struct.fields, ast_field):
ast_struct.fields.append(ast_field)
@ -462,10 +494,12 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
ctxt.add_must_be_query_shape_component(ast_field, ast_struct.name, ast_field.name)
if ast_field.query_shape == ast.QueryShapeFieldType.ANONYMIZE and not (
ast_field.type.cpp_type in ["std::string", "std::vector<std::string>"]
or 'string' in ast_field.type.bson_serialization_type):
ctxt.add_query_shape_anonymize_must_be_string(ast_field, ast_field.name,
ast_field.type.cpp_type)
ast_field.type.cpp_type in ["std::string", "std::vector<std::string>"]
or "string" in ast_field.type.bson_serialization_type
):
ctxt.add_query_shape_anonymize_must_be_string(
ast_field, ast_field.name, ast_field.type.cpp_type
)
# Fill out the field comparison_order property as needed
if ast_struct.generate_comparison_operators and ast_struct.fields:
@ -478,8 +512,9 @@ def _bind_struct_common(ctxt, parsed_spec, struct, ast_struct):
if not ast_field.comparison_order == -1:
use_default_order = False
if ast_field.comparison_order in comparison_orders:
ctxt.add_duplicate_comparison_order_field_error(ast_struct, ast_struct.name,
ast_field.comparison_order)
ctxt.add_duplicate_comparison_order_field_error(
ast_struct, ast_struct.name, ast_field.comparison_order
)
comparison_orders.add(ast_field.comparison_order)
@ -500,8 +535,9 @@ def _inject_hidden_fields(struct):
serialization_context_field = syntax.Field(struct.file_name, struct.line, struct.column)
serialization_context_field.name = "serialization_context" # This comes from basic_types.idl
serialization_context_field.type = syntax.FieldTypeSingle(struct.file_name, struct.line,
struct.column)
serialization_context_field.type = syntax.FieldTypeSingle(
struct.file_name, struct.line, struct.column
)
serialization_context_field.type.type_name = "serialization_context"
serialization_context_field.cpp_name = "serializationContext"
serialization_context_field.optional = False
@ -606,7 +642,7 @@ def _bind_variant_field(ctxt, ast_field, idl_type):
def gen_cpp_types():
for alternative in ast_field.type.variant_types:
if alternative.is_array:
yield f'std::vector<{alternative.cpp_type}>'
yield f"std::vector<{alternative.cpp_type}>"
else:
yield alternative.cpp_type
@ -638,8 +674,9 @@ def _bind_command_type(ctxt, parsed_spec, command):
ctxt.add_array_not_valid_error(ast_field, "field", ast_field.name)
# Resolve the command type as a field
syntax_symbol = parsed_spec.symbols.resolve_field_type(ctxt, command, command.name,
command.type)
syntax_symbol = parsed_spec.symbols.resolve_field_type(
ctxt, command, command.name, command.type
)
if syntax_symbol is None:
return None
@ -649,8 +686,9 @@ def _bind_command_type(ctxt, parsed_spec, command):
assert not isinstance(syntax_symbol, syntax.Enum)
base_type = (syntax_symbol.element_type
if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol)
base_type = (
syntax_symbol.element_type if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol
)
# Copy over only the needed information if this is a struct or a type.
if isinstance(base_type, syntax.Struct):
@ -683,8 +721,9 @@ def _bind_command_reply_type(ctxt, parsed_spec, command):
ast_field.description = f"{command.name} reply type"
# Resolve the command type as a field
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, command, command.name,
command.reply_type)
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
ctxt, command, command.name, command.reply_type
)
if syntax_symbol is None:
# Resolution failed, we've recorded an error.
@ -714,8 +753,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value):
# type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, str, str) -> str
# Look up the enum for "enum_name" in the symbol table
access_check_enum = parsed_spec.symbols.resolve_type_from_name(ctxt, location, "access_check",
enum_name)
access_check_enum = parsed_spec.symbols.resolve_type_from_name(
ctxt, location, "access_check", enum_name
)
if access_check_enum is None:
# Resolution failed, we've recorded an error.
@ -725,8 +765,9 @@ def _bind_enum_value(ctxt, parsed_spec, location, enum_name, enum_value):
ctxt.add_unknown_type_error(location, enum_name, "enum")
return None
syntax_enum = resolve_enum_value(ctxt, location, cast(syntax.Enum, access_check_enum),
enum_value)
syntax_enum = resolve_enum_value(
ctxt, location, cast(syntax.Enum, access_check_enum), enum_value
)
if not syntax_enum:
return None
@ -737,22 +778,25 @@ def _bind_single_check(ctxt, parsed_spec, access_check):
# type: (errors.ParserContext, syntax.IDLSpec, syntax.AccessCheck) -> ast.AccessCheck
"""Bind a single access_check."""
ast_access_check = ast.AccessCheck(access_check.file_name, access_check.line,
access_check.column)
ast_access_check = ast.AccessCheck(
access_check.file_name, access_check.line, access_check.column
)
assert bool(access_check.check) != bool(access_check.privilege)
if access_check.check:
ast_access_check.check = _bind_enum_value(ctxt, parsed_spec, access_check, "AccessCheck",
access_check.check)
ast_access_check.check = _bind_enum_value(
ctxt, parsed_spec, access_check, "AccessCheck", access_check.check
)
if not ast_access_check.check:
return None
else:
privilege = access_check.privilege
ast_privilege = ast.Privilege(privilege.file_name, privilege.line, privilege.column)
ast_privilege.resource_pattern = _bind_enum_value(ctxt, parsed_spec, privilege, "MatchType",
privilege.resource_pattern)
ast_privilege.resource_pattern = _bind_enum_value(
ctxt, parsed_spec, privilege, "MatchType", privilege.resource_pattern
)
if not ast_privilege.resource_pattern:
return None
@ -935,7 +979,7 @@ def _validate_variant_type(ctxt, syntax_symbol, field):
for type_name, count in array_type_count.items():
if count > 1:
ctxt.add_variant_duplicate_types_error(syntax_symbol, field.name, f'array<{type_name}>')
ctxt.add_variant_duplicate_types_error(syntax_symbol, field.name, f"array<{type_name}>")
types = len(syntax_symbol.variant_types) + len(syntax_symbol.variant_struct_types)
if types < 2:
@ -958,14 +1002,14 @@ def _validate_field_properties(ctxt, ast_field):
if ast_field.optional:
ctxt.add_bad_field_default_and_optional(ast_field, ast_field.name)
if ast_field.type.bson_serialization_type == ['bindata']:
if ast_field.type.bson_serialization_type == ["bindata"]:
ctxt.add_bindata_no_default(ast_field, ast_field.type.name, ast_field.name)
if ast_field.always_serialize and not ast_field.optional:
ctxt.add_bad_field_always_serialize_not_optional(ast_field, ast_field.name)
# A "chain" type should never appear as a field.
if ast_field.type.bson_serialization_type == ['chain']:
if ast_field.type.bson_serialization_type == ["chain"]:
ctxt.add_bad_array_of_chain(ast_field, ast_field.name)
@ -979,7 +1023,7 @@ def _validate_doc_sequence_field(ctxt, ast_field):
# The only allowed BSON type for a doc_sequence field is "object"
for serialization_type in ast_field.type.bson_serialization_type:
if serialization_type != 'object':
if serialization_type != "object":
ctxt.add_bad_non_object_as_doc_sequence_error(ast_field, ast_field.name)
@ -991,7 +1035,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name):
return cpp_method_name
# Global function
if cpp_method_name.startswith('::'):
if cpp_method_name.startswith("::"):
return cpp_method_name
# Method is full qualified already
@ -1003,7 +1047,7 @@ def _normalize_method_name(cpp_type_name, cpp_method_name):
# Method is prefixed with just the type name
if cpp_method_name.startswith(type_name):
return '::'.join(cpp_type_name.split('::')[0:-1]) + "::" + cpp_method_name
return "::".join(cpp_type_name.split("::")[0:-1]) + "::" + cpp_method_name
return cpp_method_name
@ -1046,9 +1090,9 @@ def _bind_expression(expr, allow_literal_string=True):
# std::string
if allow_literal_string:
strval = expr.literal
for i in ['\\', '"', "'"]:
for i in ["\\", '"', "'"]:
if i in strval:
strval = strval.replace(i, '\\' + i)
strval = strval.replace(i, "\\" + i)
node.expr = '"' + strval + '"'
return node
@ -1093,11 +1137,11 @@ def _bind_condition(condition, condition_for):
ast_condition.preprocessor = condition.preprocessor
if condition.feature_flag:
assert condition_for == 'server_parameter'
assert condition_for == "server_parameter"
ast_condition.feature_flag = condition.feature_flag
if condition.min_fcv:
assert condition_for == 'server_parameter'
assert condition_for == "server_parameter"
ast_condition.min_fcv = condition.min_fcv
return ast_condition
@ -1186,12 +1230,14 @@ def _bind_field(ctxt, parsed_spec, field):
_validate_array_type(ctxt, cast(syntax.ArrayType, syntax_symbol), field)
elif field.supports_doc_sequence:
# Doc sequences are only supported for arrays
ctxt.add_bad_non_array_as_doc_sequence_error(syntax_symbol, syntax_symbol.name,
ast_field.name)
ctxt.add_bad_non_array_as_doc_sequence_error(
syntax_symbol, syntax_symbol.name, ast_field.name
)
return None
base_type = (syntax_symbol.element_type
if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol)
base_type = (
syntax_symbol.element_type if isinstance(syntax_symbol, syntax.ArrayType) else syntax_symbol
)
# Copy over only the needed information if this is a struct or a type.
@ -1244,8 +1290,9 @@ def _bind_field(ctxt, parsed_spec, field):
def _bind_chained_type(ctxt, parsed_spec, location, chained_type):
# type: (errors.ParserContext, syntax.IDLSpec, common.SourceLocation, syntax.ChainedType) -> ast.Field
"""Bind the specified chained type."""
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(ctxt, location, chained_type.name,
chained_type.name)
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
ctxt, location, chained_type.name, chained_type.name
)
if not syntax_symbol:
return None
@ -1255,9 +1302,10 @@ def _bind_chained_type(ctxt, parsed_spec, location, chained_type):
idltype = cast(syntax.Type, syntax_symbol)
if len(idltype.bson_serialization_type) != 1 or idltype.bson_serialization_type[0] != 'chain':
ctxt.add_chained_type_wrong_type_error(location, chained_type.name,
idltype.bson_serialization_type[0])
if len(idltype.bson_serialization_type) != 1 or idltype.bson_serialization_type[0] != "chain":
ctxt.add_chained_type_wrong_type_error(
location, chained_type.name, idltype.bson_serialization_type[0]
)
return None
ast_field = ast.Field(location.file_name, location.line, location.column)
@ -1274,7 +1322,8 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
# type: (errors.ParserContext, syntax.IDLSpec, ast.Struct, syntax.ChainedStruct) -> None
"""Bind the specified chained struct."""
syntax_symbol = parsed_spec.symbols.resolve_type_from_name(
ctxt, ast_struct, chained_struct.name, chained_struct.name)
ctxt, ast_struct, chained_struct.name, chained_struct.name
)
if not syntax_symbol:
return
@ -1287,12 +1336,14 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
# chained struct cannot be strict unless it is inlined
if struct.strict and not ast_struct.inline_chained_structs:
ctxt.add_chained_nested_struct_no_strict_error(ast_struct, ast_struct.name,
chained_struct.name)
ctxt.add_chained_nested_struct_no_strict_error(
ast_struct, ast_struct.name, chained_struct.name
)
if struct.chained_types or struct.chained_structs:
ctxt.add_chained_nested_struct_no_nested_error(ast_struct, ast_struct.name,
chained_struct.name)
ctxt.add_chained_nested_struct_no_nested_error(
ast_struct, ast_struct.name, chained_struct.name
)
# Configure a field for the chained struct.
ast_chained_field = ast.Field(ast_struct.file_name, ast_struct.line, ast_struct.column)
@ -1313,9 +1364,9 @@ def _bind_chained_struct(ctxt, parsed_spec, ast_struct, chained_struct):
# Don't use internal fields in chained types, stick to local access only
if ast_field.type.internal_only:
continue
if ast_field and not _is_duplicate_field(ctxt, chained_struct.name, ast_struct.fields,
ast_field):
if ast_field and not _is_duplicate_field(
ctxt, chained_struct.name, ast_struct.fields, ast_field
):
if ast_struct.inline_chained_structs:
ast_field.chained_struct_field = ast_chained_field
else:
@ -1329,8 +1380,9 @@ def _bind_globals(ctxt, parsed_spec):
# type: (errors.ParserContext, syntax.IDLSpec) -> ast.Global
"""Bind the globals object from the idl.syntax tree into the idl.ast tree by doing a deep copy."""
if parsed_spec.globals:
ast_global = ast.Global(parsed_spec.globals.file_name, parsed_spec.globals.line,
parsed_spec.globals.column)
ast_global = ast.Global(
parsed_spec.globals.file_name, parsed_spec.globals.line, parsed_spec.globals.column
)
ast_global.cpp_namespace = parsed_spec.globals.cpp_namespace
ast_global.cpp_includes = parsed_spec.globals.cpp_includes
@ -1345,7 +1397,8 @@ def _bind_globals(ctxt, parsed_spec):
init = configs.initializer
ast_global.configs.initializer = ast.GlobalInitializer(
init.file_name, init.line, init.column)
init.file_name, init.line, init.column
)
# Parser rule makes it impossible to have both name and register/store.
ast_global.configs.initializer.name = init.name
ast_global.configs.initializer.register = init.register
@ -1371,8 +1424,9 @@ def _validate_enum_int(ctxt, idl_enum):
try:
int_values_set.add(int(enum_value.value))
except ValueError as value_error:
ctxt.add_enum_value_not_int_error(idl_enum, idl_enum.name, enum_value.value,
str(value_error))
ctxt.add_enum_value_not_int_error(
idl_enum, idl_enum.name, enum_value.value, str(value_error)
)
return
@ -1412,7 +1466,7 @@ def _bind_enum(ctxt, idl_enum):
if len(idl_enum.values) != len(values_set):
ctxt.add_enum_value_not_unique_error(idl_enum, idl_enum.name)
if ast_enum.type == 'int':
if ast_enum.type == "int":
_validate_enum_int(ctxt, idl_enum)
return ast_enum
@ -1423,9 +1477,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
"""Bind and validate ServerParameter attributes specific to specialized ServerParameters."""
# Fields specific to bound and unbound standard params.
for field in ['cpp_vartype', 'cpp_varname', 'on_update', 'validator']:
for field in ["cpp_vartype", "cpp_varname", "on_update", "validator"]:
if getattr(param, field) is not None:
ctxt.add_server_parameter_invalid_attr(param, field, 'specialized')
ctxt.add_server_parameter_invalid_attr(param, field, "specialized")
return None
# Fields specific to specialized stroage.
@ -1433,8 +1487,9 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
if param.default is not None:
if not param.default.is_constexpr:
ctxt.add_server_parameter_invalid_attr(param, 'default.is_constexpr=false',
'specialized')
ctxt.add_server_parameter_invalid_attr(
param, "default.is_constexpr=false", "specialized"
)
return None
ast_param.default = _bind_expression(param.default)
@ -1448,7 +1503,7 @@ def _bind_server_parameter_class(ctxt, ast_param, param):
ast_param.cpp_class.override_validate = cls.override_validate
# If set_at is cluster, then set must be overridden. Otherwise, use the parsed value.
ast_param.cpp_class.override_set = True if param.set_at == ['cluster'] else cls.override_set
ast_param.cpp_class.override_set = True if param.set_at == ["cluster"] else cls.override_set
return ast_param
@ -1458,13 +1513,13 @@ def _bind_server_parameter_with_storage(ctxt, ast_param, param):
"""Bind and validate ServerParameter attributes specific to bound ServerParameters."""
# Fields specific to specialized and unbound standard params.
for field in ['cpp_class']:
for field in ["cpp_class"]:
if getattr(param, field) is not None:
ctxt.add_server_parameter_invalid_attr(param, field, 'bound')
ctxt.add_server_parameter_invalid_attr(param, field, "bound")
return None
if param.set_at == ['cluster']:
ast_param.cpp_vartype = f'TenantIdMap<{param.cpp_vartype}>'
if param.set_at == ["cluster"]:
ast_param.cpp_vartype = f"TenantIdMap<{param.cpp_vartype}>"
else:
ast_param.cpp_vartype = param.cpp_vartype
ast_param.cpp_varname = param.cpp_varname
@ -1487,20 +1542,20 @@ def _bind_server_parameter_set_at(ctxt, param):
# type: (errors.ParserContext, syntax.ServerParameter) -> str
"""Translate set_at options to C++ enum value."""
if param.set_at == ['readonly']:
if param.set_at == ["readonly"]:
# Readonly may not be mixed with startup or runtime
return "ServerParameterType::kReadOnly"
if param.set_at == ['cluster']:
if param.set_at == ["cluster"]:
# Cluster-wide parameters may not be mixed with startup or runtime.
# They are implicitly runtime-only.
return "ServerParameterType::kClusterWide"
set_at = 0
for psa in param.set_at:
if psa.lower() == 'startup':
if psa.lower() == "startup":
set_at |= 1
elif psa.lower() == 'runtime':
elif psa.lower() == "runtime":
set_at |= 2
else:
ctxt.add_bad_setat_specifier(param, psa)
@ -1516,7 +1571,7 @@ def _bind_server_parameter_set_at(ctxt, param):
return mask_to_text[set_at]
# Can't happen based on above logic.
ctxt.add_bad_setat_specifier(param, ','.join(param.set_at))
ctxt.add_bad_setat_specifier(param, ",".join(param.set_at))
return None
@ -1526,19 +1581,19 @@ def _bind_server_parameter(ctxt, param):
ast_param = ast.ServerParameter(param.file_name, param.line, param.column)
ast_param.name = param.name
ast_param.description = param.description
ast_param.condition = _bind_condition(param.condition, condition_for='server_parameter')
ast_param.condition = _bind_condition(param.condition, condition_for="server_parameter")
ast_param.redact = param.redact
ast_param.test_only = param.test_only
ast_param.deprecated_name = param.deprecated_name
# The omit_in_ftdc flag can only be enabled for cluster parameters.
if param.omit_in_ftdc is not None and param.set_at != ['cluster']:
ctxt.add_server_parameter_invalid_attr(param, 'omit_in_ftdc=True', ''.join(param.set_at))
if param.omit_in_ftdc is not None and param.set_at != ["cluster"]:
ctxt.add_server_parameter_invalid_attr(param, "omit_in_ftdc=True", "".join(param.set_at))
return None
# If omit_in_ftdc is None (it has not been set) for a cluster parameter, then emit an error.
if param.omit_in_ftdc is None and param.set_at == ['cluster']:
ctxt.add_server_parameter_required_attr(param, 'omit_in_ftdc', 'cluster')
if param.omit_in_ftdc is None and param.set_at == ["cluster"]:
ctxt.add_server_parameter_required_attr(param, "omit_in_ftdc", "cluster")
ast_param.omit_in_ftdc = param.omit_in_ftdc
@ -1551,7 +1606,7 @@ def _bind_server_parameter(ctxt, param):
elif param.cpp_varname:
return _bind_server_parameter_with_storage(ctxt, ast_param, param)
else:
ctxt.add_server_parameter_required_attr(param, 'cpp_varname', 'server_parameter')
ctxt.add_server_parameter_required_attr(param, "cpp_varname", "server_parameter")
return None
@ -1572,7 +1627,11 @@ def _bind_feature_flags(ctxt, param):
return None
# Feature flags that default to true and should be FCV gated are required to have a version
if param.default.literal == "true" and param.shouldBeFCVGated.literal == "true" and not param.version:
if (
param.default.literal == "true"
and param.shouldBeFCVGated.literal == "true"
and not param.version
):
ctxt.add_feature_flag_default_true_missing_version(param)
return None
@ -1582,9 +1641,11 @@ def _bind_feature_flags(ctxt, param):
return None
expr = syntax.Expression(param.default.file_name, param.default.line, param.default.column)
expr.expr = '%s, "%s"_sd, %s' % (param.default.literal, param.version if
(param.shouldBeFCVGated.literal == "true"
and param.version) else '', param.shouldBeFCVGated.literal)
expr.expr = '%s, "%s"_sd, %s' % (
param.default.literal,
param.version if (param.shouldBeFCVGated.literal == "true" and param.version) else "",
param.shouldBeFCVGated.literal,
)
ast_param.default = _bind_expression(expr)
ast_param.default.export = False
@ -1597,7 +1658,7 @@ def _bind_feature_flags(ctxt, param):
def _is_invalid_config_short_name(name):
# type: (str) -> bool
"""Check if a given name is valid as a short name."""
return ('.' in name) or (',' in name)
return ("." in name) or ("," in name)
def _parse_config_option_sources(source_list):
@ -1635,7 +1696,7 @@ def _bind_config_option(ctxt, globals_spec, option):
node = ast.ConfigOption(option.file_name, option.line, option.column)
if _is_invalid_config_short_name(option.short_name or ''):
if _is_invalid_config_short_name(option.short_name or ""):
ctxt.add_invalid_short_name(option, option.short_name)
return None
@ -1664,13 +1725,13 @@ def _bind_config_option(ctxt, globals_spec, option):
ctxt.add_missing_short_name_with_single_name(option, option.single_name)
return None
node.short_name = node.short_name + ',' + option.single_name
node.short_name = node.short_name + "," + option.single_name
node.description = _bind_expression(option.description)
node.arg_vartype = option.arg_vartype
node.cpp_vartype = option.cpp_vartype
node.cpp_varname = option.cpp_varname
node.condition = _bind_condition(option.condition, condition_for='config')
node.condition = _bind_condition(option.condition, condition_for="config")
node.requires = option.requires
node.conflicts = option.conflicts
@ -1694,7 +1755,7 @@ def _bind_config_option(ctxt, globals_spec, option):
node.source = _parse_config_option_sources(source_list)
if node.source is None:
ctxt.add_bad_source_specifier(option, ', '.join(source_list))
ctxt.add_bad_source_specifier(option, ", ".join(source_list))
return None
if option.duplicate_behavior:
@ -1710,17 +1771,17 @@ def _bind_config_option(ctxt, globals_spec, option):
return None
# Parse single digit, closed range, or open range of digits.
spread = option.positional.split('-')
spread = option.positional.split("-")
if len(spread) == 1:
# Make a single number behave like a range of that number, (e.g. "2" -> "2-2").
spread.append(spread[0])
if (len(spread) != 2) or ((spread[0] == "") and (spread[1] == "")):
ctxt.add_bad_numeric_range(option, 'positional', option.positional)
ctxt.add_bad_numeric_range(option, "positional", option.positional)
try:
node.positional_start = int(spread[0] or "-1")
node.positional_end = int(spread[1] or "-1")
except ValueError:
ctxt.add_bad_numeric_range(option, 'positional', option.positional)
ctxt.add_bad_numeric_range(option, "positional", option.positional)
return None
if option.validator is not None:

View File

@ -37,43 +37,43 @@ from typing import Dict, List
# scalar: True if the type is not an array or object
# bson_type_enum: The BSONType enum value for the given type
_BSON_TYPE_INFORMATION = {
"double": {'scalar': True, 'bson_type_enum': 'NumberDouble'},
"string": {'scalar': True, 'bson_type_enum': 'String'},
"object": {'scalar': False, 'bson_type_enum': 'Object'},
"array": {'scalar': False, 'bson_type_enum': 'Array'},
"bindata": {'scalar': True, 'bson_type_enum': 'BinData'},
"undefined": {'scalar': True, 'bson_type_enum': 'Undefined'},
"objectid": {'scalar': True, 'bson_type_enum': 'jstOID'},
"bool": {'scalar': True, 'bson_type_enum': 'Bool'},
"date": {'scalar': True, 'bson_type_enum': 'Date'},
"null": {'scalar': True, 'bson_type_enum': 'jstNULL'},
"regex": {'scalar': True, 'bson_type_enum': 'RegEx'},
"int": {'scalar': True, 'bson_type_enum': 'NumberInt'},
"timestamp": {'scalar': True, 'bson_type_enum': 'bsonTimestamp'},
"long": {'scalar': True, 'bson_type_enum': 'NumberLong'},
"decimal": {'scalar': True, 'bson_type_enum': 'NumberDecimal'},
"double": {"scalar": True, "bson_type_enum": "NumberDouble"},
"string": {"scalar": True, "bson_type_enum": "String"},
"object": {"scalar": False, "bson_type_enum": "Object"},
"array": {"scalar": False, "bson_type_enum": "Array"},
"bindata": {"scalar": True, "bson_type_enum": "BinData"},
"undefined": {"scalar": True, "bson_type_enum": "Undefined"},
"objectid": {"scalar": True, "bson_type_enum": "jstOID"},
"bool": {"scalar": True, "bson_type_enum": "Bool"},
"date": {"scalar": True, "bson_type_enum": "Date"},
"null": {"scalar": True, "bson_type_enum": "jstNULL"},
"regex": {"scalar": True, "bson_type_enum": "RegEx"},
"int": {"scalar": True, "bson_type_enum": "NumberInt"},
"timestamp": {"scalar": True, "bson_type_enum": "bsonTimestamp"},
"long": {"scalar": True, "bson_type_enum": "NumberLong"},
"decimal": {"scalar": True, "bson_type_enum": "NumberDecimal"},
}
# Dictionary of BinData subtype type Information
# scalar: True if the type is not an array or object
# bindata_enum: The BinDataType enum value for the given type
_BINDATA_SUBTYPE = {
"generic": {'scalar': True, 'bindata_enum': 'BinDataGeneral'},
"function": {'scalar': True, 'bindata_enum': 'Function'},
"generic": {"scalar": True, "bindata_enum": "BinDataGeneral"},
"function": {"scalar": True, "bindata_enum": "Function"},
# Also simply known as type 2, deprecated, and requires special handling
#"binary": {
# "binary": {
# 'scalar': False,
# 'bindata_enum': 'ByteArrayDeprecated'
#},
# },
# Deprecated
# "uuid_old": {
# 'scalar': False,
# 'bindata_enum': 'bdtUUID'
# },
"uuid": {'scalar': True, 'bindata_enum': 'newUUID'},
"md5": {'scalar': True, 'bindata_enum': 'MD5Type'},
"encrypt": {'scalar': True, 'bindata_enum': 'Encrypt'},
"sensitive": {'scalar': True, 'bindata_enum': 'Sensitive'},
"uuid": {"scalar": True, "bindata_enum": "newUUID"},
"md5": {"scalar": True, "bindata_enum": "MD5Type"},
"encrypt": {"scalar": True, "bindata_enum": "Encrypt"},
"sensitive": {"scalar": True, "bindata_enum": "Sensitive"},
}
@ -87,14 +87,14 @@ def is_scalar_bson_type(name):
# type: (str) -> bool
"""Return True if this bson type is a scalar."""
assert is_valid_bson_type(name)
return _BSON_TYPE_INFORMATION[name]['scalar'] # type: ignore
return _BSON_TYPE_INFORMATION[name]["scalar"] # type: ignore
def cpp_bson_type_name(name):
# type: (str) -> str
"""Return the C++ type name for a bson type."""
assert is_valid_bson_type(name)
return _BSON_TYPE_INFORMATION[name]['bson_type_enum'] # type: ignore
return _BSON_TYPE_INFORMATION[name]["bson_type_enum"] # type: ignore
def list_valid_types():
@ -113,4 +113,4 @@ def cpp_bindata_subtype_type_name(name):
# type: (str) -> str
"""Return the C++ type name for a bindata subtype."""
assert is_valid_bindata_subtype(name)
return _BINDATA_SUBTYPE[name]['bindata_enum'] # type: ignore
return _BINDATA_SUBTYPE[name]["bindata_enum"] # type: ignore

View File

@ -48,7 +48,7 @@ def title_case(name):
# Only capitalize the last part of a fully-qualified name
pos = name.rfind("::")
if pos > -1:
return name[:pos + 2] + name[pos + 2:pos + 3].upper() + name[pos + 3:]
return name[: pos + 2] + name[pos + 2 : pos + 3].upper() + name[pos + 3 :]
return name[0:1].upper() + name[1:]
@ -72,9 +72,9 @@ def _escape_template_string(template):
# type: (str) -> str
"""Escape the '$' in template strings unless followed by '{'."""
# See https://docs.python.org/2/library/string.html#template-strings
template = template.replace('${', '#{')
template = template.replace('$', '$$')
return template.replace('#{', '${')
template = template.replace("${", "#{")
template = template.replace("$", "$$")
return template.replace("#{", "${")
def template_format(template, template_params=None):

View File

@ -80,26 +80,42 @@ class CompilerImportResolver(parser.ImportResolverBase):
logging.debug("Resolving imported file '%s' for file '%s'", imported_file_name, base_file)
# Check for fully-qualified paths
logging.debug("Checking for imported file '%s' for file '%s' at '%s'", imported_file_name,
base_file, imported_file_name)
logging.debug(
"Checking for imported file '%s' for file '%s' at '%s'",
imported_file_name,
base_file,
imported_file_name,
)
if os.path.isabs(imported_file_name) and os.path.exists(imported_file_name):
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
base_file, imported_file_name)
logging.debug(
"Found imported file '%s' for file '%s' at '%s'",
imported_file_name,
base_file,
imported_file_name,
)
return imported_file_name
for candidate_dir in self._import_directories or []:
base_dir = os.path.abspath(candidate_dir)
resolved_file_name = os.path.normpath(os.path.join(base_dir, imported_file_name))
logging.debug("Checking for imported file '%s' for file '%s' at '%s'",
imported_file_name, base_file, resolved_file_name)
logging.debug(
"Checking for imported file '%s' for file '%s' at '%s'",
imported_file_name,
base_file,
resolved_file_name,
)
if os.path.exists(resolved_file_name):
logging.debug("Found imported file '%s' for file '%s' at '%s'", imported_file_name,
base_file, resolved_file_name)
logging.debug(
"Found imported file '%s' for file '%s' at '%s'",
imported_file_name,
base_file,
resolved_file_name,
)
return resolved_file_name
msg = ("Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file))
msg = "Cannot find imported file '%s' for file '%s'" % (imported_file_name, base_file)
logging.error(msg)
raise errors.IDLError(msg)
@ -107,7 +123,7 @@ class CompilerImportResolver(parser.ImportResolverBase):
def open(self, resolved_file_name):
# type: (str) -> Any
"""Return an io.Stream for the requested file."""
return io.open(resolved_file_name, encoding='utf-8')
return io.open(resolved_file_name, encoding="utf-8")
def _write_dependencies(spec, write_dependencies_inline):
@ -138,7 +154,8 @@ def _update_import_includes(args, spec, header_file_name):
if args.output_base_dir:
base_include_h_file_name = os.path.relpath(
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir))
os.path.normpath(header_file_name), os.path.normpath(args.output_base_dir)
)
else:
base_include_h_file_name = os.path.abspath(header_file_name)
@ -149,7 +166,7 @@ def _update_import_includes(args, spec, header_file_name):
if not spec.globals:
spec.globals = syntax.Global(args.input_file, -1, -1)
first_dir = base_include_h_file_name.split('/')[0]
first_dir = base_include_h_file_name.split("/")[0]
for resolved_file_name in spec.imports.resolved_imports:
# Guess: the file naming rules are consistent across IDL invocations
@ -161,9 +178,10 @@ def _update_import_includes(args, spec, header_file_name):
if os.path.isabs(base_dir):
include_h_file_name = os.path.join(
base_dir, include_h_file_name[include_h_file_name.rfind(first_dir):])
base_dir, include_h_file_name[include_h_file_name.rfind(first_dir) :]
)
else:
include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir):]
include_h_file_name = include_h_file_name[include_h_file_name.find(first_dir) :]
else:
include_h_file_name = os.path.abspath(include_h_file_name)
@ -181,9 +199,12 @@ def compile_idl(args):
logging.error("File '%s' not found", args.input_file)
if args.output_source is None:
if not '.' in args.input_file:
logging.error("File name '%s' must be end with a filename extension, such as '%s.idl'",
args.input_file, args.input_file)
if not "." in args.input_file:
logging.error(
"File name '%s' must be end with a filename extension, such as '%s.idl'",
args.input_file,
args.input_file,
)
return False
file_name_prefix = os.path.splitext(args.input_file)[0]
@ -199,9 +220,10 @@ def compile_idl(args):
args.target_arch = platform.machine()
# Compile the IDL through the 3 passes
with io.open(args.input_file, encoding='utf-8') as file_stream:
parsed_doc = parser.parse(file_stream, args.input_file,
CompilerImportResolver(args.import_directories))
with io.open(args.input_file, encoding="utf-8") as file_stream:
parsed_doc = parser.parse(
file_stream, args.input_file, CompilerImportResolver(args.import_directories)
)
if not parsed_doc.errors:
if args.write_dependencies or args.write_dependencies_inline:
@ -215,8 +237,13 @@ def compile_idl(args):
bound_doc = binder.bind(parsed_doc.spec)
if not bound_doc.errors:
generator.generate_code(bound_doc.spec, args.target_arch, args.output_base_dir,
header_file_name, source_file_name)
generator.generate_code(
bound_doc.spec,
args.target_arch,
args.output_base_dir,
header_file_name,
source_file_name,
)
return True
else:

View File

@ -37,7 +37,7 @@ from . import bson
from . import common
from . import writer
_STD_ARRAY_UINT8_16 = 'std::array<std::uint8_t,16>'
_STD_ARRAY_UINT8_16 = "std::array<std::uint8_t,16>"
def is_primitive_scalar_type(cpp_type):
@ -47,26 +47,31 @@ def is_primitive_scalar_type(cpp_type):
Primitive scalar types need to have a default value to prevent warnings from Coverity.
"""
cpp_type = cpp_type.replace(' ', '')
cpp_type = cpp_type.replace(" ", "")
# TODO (SERVER-50101): Remove 'multiversion::FeatureCompatibilityVersion' once IDL supports
# a commmand cpp_type of C++ enum.
return cpp_type in [
'bool', 'double', 'std::int32_t', 'std::uint32_t', 'std::uint64_t', 'std::int64_t',
'multiversion::FeatureCompatibilityVersion'
"bool",
"double",
"std::int32_t",
"std::uint32_t",
"std::uint64_t",
"std::int64_t",
"multiversion::FeatureCompatibilityVersion",
]
def is_primitive_type(cpp_type):
# type: (str) -> bool
"""Return True if a cpp_type is a primitive type and should not be returned as reference."""
cpp_type = cpp_type.replace(' ', '')
cpp_type = cpp_type.replace(" ", "")
return is_primitive_scalar_type(cpp_type) or cpp_type == _STD_ARRAY_UINT8_16
def _qualify_optional_type(cpp_type):
# type: (str) -> str
"""Qualify the type as optional."""
return 'boost::optional<%s>' % (cpp_type)
return "boost::optional<%s>" % (cpp_type)
def _qualify_array_type(cpp_type):
@ -79,7 +84,7 @@ def _optionally_make_call(method_name, param):
# type: (str, str) -> str
"""Return a call to method_name if it is not None, otherwise return an empty string."""
if not method_name:
return ''
return ""
return "%s(%s);" % (method_name, param)
@ -173,14 +178,15 @@ class _CppTypeBasic(CppTypeBase):
def get_getter_body(self, member_name):
# type: (str) -> str
return common.template_args('return ${member_name};', member_name=member_name)
return common.template_args("return ${member_name};", member_name=member_name)
def get_setter_body(self, member_name, validator_method_name):
# type: (str, str) -> str
return common.template_args(
'${optionally_call_validator} ${member_name} = std::move(value);',
optionally_call_validator=_optionally_make_call(validator_method_name,
'value'), member_name=member_name)
"${optionally_call_validator} ${member_name} = std::move(value);",
optionally_call_validator=_optionally_make_call(validator_method_name, "value"),
member_name=member_name,
)
def get_transform_to_getter_type(self, expression):
# type: (str) -> Optional[str]
@ -222,15 +228,16 @@ class _CppTypeView(CppTypeBase):
def get_getter_body(self, member_name):
# type: (str) -> str
return common.template_args('return ${member_name};', member_name=member_name)
return common.template_args("return ${member_name};", member_name=member_name)
def get_setter_body(self, member_name, validator_method_name):
# type: (str, str) -> str
return common.template_args(
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
member_name=member_name, optionally_call_validator=_optionally_make_call(
validator_method_name,
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
member_name=member_name,
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
value=self.get_transform_to_storage_type("value"),
)
def get_transform_to_getter_type(self, expression):
# type: (str) -> Optional[str]
@ -239,7 +246,7 @@ class _CppTypeView(CppTypeBase):
def get_transform_to_storage_type(self, expression):
# type: (str) -> Optional[str]
return common.template_args(
'${expression}.toString()',
"${expression}.toString()",
expression=expression,
)
@ -249,7 +256,7 @@ class _CppTypeVector(CppTypeBase):
def __init__(self, field):
# type: (ast.Field) -> None
super(_CppTypeVector, self).__init__(field, 'std::vector<std::uint8_t>')
super(_CppTypeVector, self).__init__(field, "std::vector<std::uint8_t>")
def get_type_name(self):
# type: () -> str
@ -261,7 +268,7 @@ class _CppTypeVector(CppTypeBase):
def get_getter_setter_type(self):
# type: () -> str
return 'ConstDataRange'
return "ConstDataRange"
def return_by_reference(self):
# type: () -> bool
@ -273,27 +280,30 @@ class _CppTypeVector(CppTypeBase):
def get_getter_body(self, member_name):
# type: (str) -> str
return common.template_args('return ConstDataRange(${member_name});',
member_name=member_name)
return common.template_args(
"return ConstDataRange(${member_name});", member_name=member_name
)
def get_setter_body(self, member_name, validator_method_name):
# type: (str, str) -> str
return common.template_args(
'auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
member_name=member_name, optionally_call_validator=_optionally_make_call(
validator_method_name,
'_tmpValue'), value=self.get_transform_to_storage_type("value"))
"auto _tmpValue = ${value}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
member_name=member_name,
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
value=self.get_transform_to_storage_type("value"),
)
def get_transform_to_getter_type(self, expression):
# type: (str) -> Optional[str]
return common.template_args('ConstDataRange(${expression});', expression=expression)
return common.template_args("ConstDataRange(${expression});", expression=expression)
def get_transform_to_storage_type(self, expression):
# type: (str) -> Optional[str]
return common.template_args(
'std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), ' +
'reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())',
expression=expression)
"std::vector<std::uint8_t>(reinterpret_cast<const uint8_t*>(${expression}.data()), "
+ "reinterpret_cast<const uint8_t*>(${expression}.data()) + ${expression}.length())",
expression=expression,
)
class _CppTypeDelegating(CppTypeBase):
@ -362,7 +372,7 @@ class _CppTypeArray(_CppTypeDelegating):
# type: (str) -> str
convert = self.get_transform_to_getter_type(member_name)
if convert:
return common.template_args('return ${convert};', convert=convert)
return common.template_args("return ${convert};", convert=convert)
return self._base.get_getter_body(member_name)
def get_setter_body(self, member_name, validator_method_name):
@ -370,16 +380,18 @@ class _CppTypeArray(_CppTypeDelegating):
convert = self.get_transform_to_storage_type("value")
if convert:
return common.template_args(
'auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);',
member_name=member_name, optionally_call_validator=_optionally_make_call(
validator_method_name, '_tmpValue'), convert=convert)
"auto _tmpValue = ${convert}; ${optionally_call_validator} ${member_name} = std::move(_tmpValue);",
member_name=member_name,
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
convert=convert,
)
return self._base.get_setter_body(member_name, validator_method_name)
def get_transform_to_getter_type(self, expression):
# type: (str) -> Optional[str]
if self._base.get_storage_type() != self._base.get_getter_setter_type():
return common.template_args(
'transformVector(${expression})',
"transformVector(${expression})",
expression=expression,
)
return None
@ -388,7 +400,7 @@ class _CppTypeArray(_CppTypeDelegating):
# type: (str) -> Optional[str]
if self._base.get_storage_type() != self._base.get_getter_setter_type():
return common.template_args(
'transformVector(${expression})',
"transformVector(${expression})",
expression=expression,
)
return None
@ -427,13 +439,18 @@ class _CppTypeOptional(_CppTypeDelegating):
} else {
return boost::none;
}
"""), member_name=member_name, convert=convert)
"""),
member_name=member_name,
convert=convert,
)
elif self.is_view_type():
# For optionals around view types, do an explicit construction
return common.template_args('return ${param_type}{${member_name}};',
param_type=self.get_getter_setter_type(),
member_name=member_name)
return common.template_args('return ${member_name};', member_name=member_name)
return common.template_args(
"return ${param_type}{${member_name}};",
param_type=self.get_getter_setter_type(),
member_name=member_name,
)
return common.template_args("return ${member_name};", member_name=member_name)
def get_setter_body(self, member_name, validator_method_name):
# type: (str, str) -> str
@ -450,8 +467,11 @@ class _CppTypeOptional(_CppTypeDelegating):
} else {
${member_name} = boost::none;
}
"""), member_name=member_name, convert=convert,
optionally_call_validator=_optionally_make_call(validator_method_name, '_tmpValue'))
"""),
member_name=member_name,
convert=convert,
optionally_call_validator=_optionally_make_call(validator_method_name, "_tmpValue"),
)
return self._base.get_setter_body(member_name, validator_method_name)
@ -459,9 +479,9 @@ def get_cpp_type_from_cpp_type_name(field, cpp_type_name, array):
# type: (ast.Field, str, bool) -> CppTypeBase
"""Get the C++ Type information for the given C++ type name, e.g. std::string."""
cpp_type_info: CppTypeBase
if cpp_type_name == 'std::string':
cpp_type_info = _CppTypeView(field, 'std::string', 'std::string', 'StringData')
elif cpp_type_name == 'std::vector<std::uint8_t>':
if cpp_type_name == "std::string":
cpp_type_info = _CppTypeView(field, "std::string", "std::string", "StringData")
elif cpp_type_name == "std::vector<std::uint8_t>":
cpp_type_info = _CppTypeVector(field)
else:
cpp_type_info = _CppTypeBasic(field, cpp_type_name)
@ -511,15 +531,17 @@ class BsonCppTypeBase(object, metaclass=ABCMeta):
pass
@abstractmethod
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
is_catalog_ctxt=False):
def gen_serializer_expression(
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
):
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
"""Generate code with the text writer and return an expression to serialize the type."""
pass
def _call_method_or_global_function(expression, ast_type, should_shapify=False,
is_catalog_ctxt=False):
def _call_method_or_global_function(
expression, ast_type, should_shapify=False, is_catalog_ctxt=False
):
# type: (str, ast.Type, bool, bool) -> str
"""
Given a fully-qualified method name, call it correctly.
@ -529,25 +551,25 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
enum deserializers/serializers which are not methods.
"""
method_name = ast_type.serializer
serialization_context = 'getSerializationContext()' if ast_type.deserialize_with_tenant else ''
shape_options = ''
serialization_context = "getSerializationContext()" if ast_type.deserialize_with_tenant else ""
shape_options = ""
if should_shapify:
shape_options = 'options'
shape_options = "options"
short_method_name = writer.get_method_name(method_name)
if writer.is_function(method_name):
if ast_type.deserialize_with_tenant:
if is_catalog_ctxt:
# serializeForCatalog doesn't need a serializationContext
serialization_context = ''
serialization_context = ""
method_name = method_name.replace("serialize", "serializeForCatalog")
else:
serialization_context = ', ' + serialization_context
serialization_context = ", " + serialization_context
if should_shapify:
shape_options = ', ' + shape_options
shape_options = ", " + shape_options
return common.template_args(
'${method_name}(${expression}${shape_options}${serialization_context})',
"${method_name}(${expression}${shape_options}${serialization_context})",
expression=expression,
method_name=method_name,
shape_options=shape_options,
@ -555,7 +577,7 @@ def _call_method_or_global_function(expression, ast_type, should_shapify=False,
)
return common.template_args(
'${expression}.${method_name}(${shape_options}${serialization_context})',
"${expression}.${method_name}(${shape_options}${serialization_context})",
expression=expression,
method_name=short_method_name,
shape_options=shape_options,
@ -573,19 +595,23 @@ class _CommonBsonCppTypeBase(BsonCppTypeBase):
def gen_deserializer_expression(self, indented_writer, object_instance):
# type: (writer.IndentedTextWriter, str) -> str
return common.template_args('${object_instance}.${method_name}()',
object_instance=object_instance,
method_name=self._deserialize_method_name)
return common.template_args(
"${object_instance}.${method_name}()",
object_instance=object_instance,
method_name=self._deserialize_method_name,
)
def has_serializer(self):
# type: () -> bool
return self._ast_type.serializer is not None
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
is_catalog_ctxt=False):
def gen_serializer_expression(
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
):
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
return _call_method_or_global_function(expression, self._ast_type, should_shapify,
is_catalog_ctxt)
return _call_method_or_global_function(
expression, self._ast_type, should_shapify, is_catalog_ctxt
)
class _ObjectBsonCppTypeBase(BsonCppTypeBase):
@ -596,34 +622,41 @@ class _ObjectBsonCppTypeBase(BsonCppTypeBase):
if self._ast_type.deserializer:
# Call a method like: Class::method(const BSONObj& value)
indented_writer.write_line(
common.template_args('const BSONObj localObject = ${object_instance}.Obj();',
object_instance=object_instance))
common.template_args(
"const BSONObj localObject = ${object_instance}.Obj();",
object_instance=object_instance,
)
)
return "localObject"
# Just pass the BSONObj through without trying to parse it.
return common.template_args('${object_instance}.Obj()', object_instance=object_instance)
return common.template_args("${object_instance}.Obj()", object_instance=object_instance)
def has_serializer(self):
# type: () -> bool
return self._ast_type.serializer is not None
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
is_catalog_ctxt=False):
def gen_serializer_expression(
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
):
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
method_name = writer.get_method_name(self._ast_type.serializer)
function_arguments = []
# SerializationContext is tied to tenant deserialization
if self._ast_type.deserialize_with_tenant:
function_arguments.append('getSerializationContext()')
function_arguments.append("getSerializationContext()")
# Provide options if custom shapification required.
if should_shapify:
function_arguments.append('options')
function_arguments.append("options")
indented_writer.write_line(
common.template_args(
'const BSONObj localObject = ${expression}.${method_name}(${function_arguments});',
expression=expression, method_name=method_name,
function_arguments=', '.join(function_arguments)))
"const BSONObj localObject = ${expression}.${method_name}(${function_arguments});",
expression=expression,
method_name=method_name,
function_arguments=", ".join(function_arguments),
)
)
return "localObject"
@ -634,25 +667,34 @@ class _ArrayBsonCppTypeBase(BsonCppTypeBase):
# type: (writer.IndentedTextWriter, str) -> str
if self._ast_type.deserializer:
indented_writer.write_line(
common.template_args('BSONArray localArray(${object_instance}.Obj());',
object_instance=object_instance))
common.template_args(
"BSONArray localArray(${object_instance}.Obj());",
object_instance=object_instance,
)
)
return "localArray"
# Just pass the BSONObj through without trying to parse it.
return common.template_args('BSONArray(${object_instance}.Obj())',
object_instance=object_instance)
return common.template_args(
"BSONArray(${object_instance}.Obj())", object_instance=object_instance
)
def has_serializer(self):
# type: () -> bool
return self._ast_type.serializer is not None
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
is_catalog_ctxt=False):
def gen_serializer_expression(
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
):
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
method_name = writer.get_method_name(self._ast_type.serializer)
indented_writer.write_line(
common.template_args('BSONArray localArray(${expression}.${method_name}());',
expression=expression, method_name=method_name))
common.template_args(
"BSONArray localArray(${expression}.${method_name}());",
expression=expression,
method_name=method_name,
)
)
return "localArray"
@ -661,32 +703,42 @@ class _BinDataBsonCppTypeBase(BsonCppTypeBase):
def gen_deserializer_expression(self, indented_writer, object_instance):
# type: (writer.IndentedTextWriter, str) -> str
if self._ast_type.bindata_subtype == 'uuid':
return common.template_args('uassertStatusOK(UUID::parse(${object_instance}))',
object_instance=object_instance)
return common.template_args('${object_instance}._binDataVector()',
object_instance=object_instance)
if self._ast_type.bindata_subtype == "uuid":
return common.template_args(
"uassertStatusOK(UUID::parse(${object_instance}))", object_instance=object_instance
)
return common.template_args(
"${object_instance}._binDataVector()", object_instance=object_instance
)
def has_serializer(self):
# type: () -> bool
return True
def gen_serializer_expression(self, indented_writer, expression, should_shapify=False,
is_catalog_ctxt=False):
def gen_serializer_expression(
self, indented_writer, expression, should_shapify=False, is_catalog_ctxt=False
):
# type: (writer.IndentedTextWriter, str, bool, bool) -> str
if self._ast_type.serializer:
method_name = writer.get_method_name(self._ast_type.serializer)
indented_writer.write_line(
common.template_args('ConstDataRange tempCDR = ${expression}.${method_name}();',
expression=expression, method_name=method_name))
common.template_args(
"ConstDataRange tempCDR = ${expression}.${method_name}();",
expression=expression,
method_name=method_name,
)
)
else:
indented_writer.write_line(
common.template_args('ConstDataRange tempCDR(${expression});',
expression=expression))
common.template_args(
"ConstDataRange tempCDR(${expression});", expression=expression
)
)
return common.template_args(
'BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})',
bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype))
"BSONBinData(tempCDR.data(), tempCDR.length(), ${bindata_subtype})",
bindata_subtype=bson.cpp_bindata_subtype_type_name(self._ast_type.bindata_subtype),
)
# For some types, we want to support custom serialization but defer most of the serialization to
@ -700,19 +752,19 @@ def get_bson_cpp_type(ast_type):
if len(ast_type.bson_serialization_type) > 1:
return None
if ast_type.bson_serialization_type[0] == 'string':
if ast_type.bson_serialization_type[0] == "string":
return _CommonBsonCppTypeBase(ast_type, "valueStringData")
if ast_type.bson_serialization_type[0] == 'object':
if ast_type.bson_serialization_type[0] == "object":
return _ObjectBsonCppTypeBase(ast_type)
if ast_type.bson_serialization_type[0] == 'array':
if ast_type.bson_serialization_type[0] == "array":
return _ArrayBsonCppTypeBase(ast_type)
if ast_type.bson_serialization_type[0] == 'bindata':
if ast_type.bson_serialization_type[0] == "bindata":
return _BinDataBsonCppTypeBase(ast_type)
if ast_type.bson_serialization_type[0] == 'int':
if ast_type.bson_serialization_type[0] == "int":
return _CommonBsonCppTypeBase(ast_type, "_numberInt")
# Unsupported type

View File

@ -71,32 +71,37 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
def _get_enum_deserializer_name(self):
# type: () -> str
"""Return the name of deserializer function without prefix."""
return common.template_args("${enum_name}_parse",
enum_name=common.title_case(self._enum.name))
return common.template_args(
"${enum_name}_parse", enum_name=common.title_case(self._enum.name)
)
def get_enum_deserializer_name(self):
# type: () -> str
"""Return the name of deserializer function with non-method prefix."""
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
self._get_enum_deserializer_name())
return "::" + common.qualify_cpp_name(
self._enum.cpp_namespace, self._get_enum_deserializer_name()
)
def _get_enum_serializer_name(self):
# type: () -> str
"""Return the name of serializer function without prefix."""
return common.template_args("${enum_name}_serializer",
enum_name=common.title_case(self._enum.name))
return common.template_args(
"${enum_name}_serializer", enum_name=common.title_case(self._enum.name)
)
def get_enum_serializer_name(self):
# type: () -> str
"""Return the name of serializer function with non-method prefix."""
return "::" + common.qualify_cpp_name(self._enum.cpp_namespace,
self._get_enum_serializer_name())
return "::" + common.qualify_cpp_name(
self._enum.cpp_namespace, self._get_enum_serializer_name()
)
def _get_enum_extra_data_name(self):
# type: () -> str
"""Return the name of the get_extra_data function without prefix."""
return common.template_args("${enum_name}_get_extra_data",
enum_name=common.title_case(self._enum.name))
return common.template_args(
"${enum_name}_get_extra_data", enum_name=common.title_case(self._enum.name)
)
@abstractmethod
def get_cpp_value_assignment(self, enum_value):
@ -139,9 +144,11 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
if len(self._get_populated_extra_values()) == 0:
return None
return common.template_args("BSONObj ${function_name}(${enum_name} value)",
enum_name=self.get_cpp_type_name(),
function_name=self._get_enum_extra_data_name())
return common.template_args(
"BSONObj ${function_name}(${enum_name} value)",
enum_name=self.get_cpp_type_name(),
function_name=self._get_enum_extra_data_name(),
)
def gen_extra_data_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
@ -154,26 +161,30 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
# Generate an anonymous namespace full of BSON constants.
#
with writer.NamespaceScopeBlock(indented_writer, ['']):
with writer.NamespaceScopeBlock(indented_writer, [""]):
for enum_value in extra_values:
indented_writer.write_line(
common.template_args('// %s' % json.dumps(enum_value.extra_data)))
common.template_args("// %s" % json.dumps(enum_value.extra_data))
)
bson_value = ''.join(
[('\\x%02x' % (b)) for b in bson.BSON.encode(enum_value.extra_data)])
bson_value = "".join(
[("\\x%02x" % (b)) for b in bson.BSON.encode(enum_value.extra_data)]
)
indented_writer.write_line(
common.template_args(
'const BSONObj ${const_name}("${bson_value}");',
const_name=_get_constant_enum_extra_data_name(self._enum, enum_value),
bson_value=bson_value))
bson_value=bson_value,
)
)
indented_writer.write_empty_line()
# Generate implementation of get_extra_data function.
#
template_params = {
'enum_name': self.get_cpp_type_name(),
'function_name': self.get_extra_data_declaration(),
"enum_name": self.get_cpp_type_name(),
"function_name": self.get_extra_data_declaration(),
}
with writer.TemplateContext(indented_writer, template_params):
@ -181,16 +192,19 @@ class EnumTypeInfoBase(object, metaclass=ABCMeta):
with writer.IndentedScopedBlock(indented_writer, "switch (value) {", "}"):
for enum_value in extra_values:
indented_writer.write_template(
'case ${enum_name}::%s: return %s;' %
(enum_value.name,
_get_constant_enum_extra_data_name(self._enum, enum_value)))
"case ${enum_name}::%s: return %s;"
% (
enum_value.name,
_get_constant_enum_extra_data_name(self._enum, enum_value),
)
)
if len(extra_values) != len(self._enum.values):
# One or more enums does not have associated extra data.
indented_writer.write_line('default: return BSONObj();')
indented_writer.write_line("default: return BSONObj();")
if len(extra_values) == len(self._enum.values):
# All enum cases handled, the compiler should know this.
indented_writer.write_line('MONGO_UNREACHABLE;')
indented_writer.write_line("MONGO_UNREACHABLE;")
class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
@ -212,17 +226,19 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
# type: () -> str
return common.template_args(
"${enum_name} ${function_name}(const IDLParserContext& ctxt, std::int32_t value)",
enum_name=self.get_cpp_type_name(), function_name=self._get_enum_deserializer_name())
enum_name=self.get_cpp_type_name(),
function_name=self._get_enum_deserializer_name(),
)
def gen_deserializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
enum_values = sorted(cast(ast.Enum, self._enum).values, key=lambda ev: int(ev.value))
template_params = {
'enum_name': self.get_cpp_type_name(),
'function_name': self.get_deserializer_declaration(),
'min_value': enum_values[0].name,
'max_value': enum_values[-1].name,
"enum_name": self.get_cpp_type_name(),
"function_name": self.get_deserializer_declaration(),
"min_value": enum_values[0].name,
"max_value": enum_values[-1].name,
}
with writer.TemplateContext(indented_writer, template_params):
@ -237,33 +253,39 @@ class _EnumTypeInt(EnumTypeInfoBase, metaclass=ABCMeta):
} else {
return static_cast<${enum_name}>(value);
}
"""))
""")
)
def get_serializer_declaration(self):
# type: () -> str
"""Get the serializer function declaration minus trailing semicolon."""
return common.template_args("std::int32_t ${function_name}(${enum_name} value)",
enum_name=self.get_cpp_type_name(),
function_name=self._get_enum_serializer_name())
return common.template_args(
"std::int32_t ${function_name}(${enum_name} value)",
enum_name=self.get_cpp_type_name(),
function_name=self._get_enum_serializer_name(),
)
def gen_serializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
"""Generate the serializer function definition."""
template_params = {
'enum_name': self.get_cpp_type_name(),
'function_name': self.get_serializer_declaration(),
"enum_name": self.get_cpp_type_name(),
"function_name": self.get_serializer_declaration(),
}
with writer.TemplateContext(indented_writer, template_params):
with writer.IndentedScopedBlock(indented_writer, "${function_name} {", "}"):
indented_writer.write_template('return static_cast<std::int32_t>(value);')
indented_writer.write_template("return static_cast<std::int32_t>(value);")
def _get_constant_enum_extra_data_name(idl_enum, enum_value):
# type: (Union[syntax.Enum,ast.Enum], Union[syntax.EnumValue,ast.EnumValue]) -> str
"""Return the C++ name for a string constant of enum extra data value."""
return common.template_args('k${enum_name}_${name}_extra_data',
enum_name=common.title_case(idl_enum.name), name=enum_value.name)
return common.template_args(
"k${enum_name}_${name}_extra_data",
enum_name=common.title_case(idl_enum.name),
name=enum_value.name,
)
class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
@ -271,8 +293,9 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
def get_cpp_type_name(self):
# type: () -> str
return common.template_args("${enum_name}Enum",
enum_name=common.title_case(self._enum.name))
return common.template_args(
"${enum_name}Enum", enum_name=common.title_case(self._enum.name)
)
def get_bson_types(self):
# type: () -> List[str]
@ -280,65 +303,76 @@ class _EnumTypeString(EnumTypeInfoBase, metaclass=ABCMeta):
def get_cpp_value_assignment(self, enum_value):
# type: (ast.EnumValue) -> str
return ''
return ""
def get_deserializer_declaration(self):
# type: () -> str
cpp_type = self.get_cpp_type_name()
func = self._get_enum_deserializer_name()
return f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)'
return f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value)"
def gen_deserializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
cpp_type = self.get_cpp_type_name()
func = self._get_enum_deserializer_name()
with writer.NamespaceScopeBlock(indented_writer, ['']):
with writer.IndentedScopedBlock(indented_writer,
f'constexpr std::array {cpp_type}_values{{', '};'):
with writer.NamespaceScopeBlock(indented_writer, [""]):
with writer.IndentedScopedBlock(
indented_writer, f"constexpr std::array {cpp_type}_values{{", "};"
):
for e in self._enum.values:
indented_writer.write_line(f'{cpp_type}::{e.name},')
with writer.IndentedScopedBlock(indented_writer,
f'constexpr std::array {cpp_type}_names{{', '};'):
indented_writer.write_line(f"{cpp_type}::{e.name},")
with writer.IndentedScopedBlock(
indented_writer, f"constexpr std::array {cpp_type}_names{{", "};"
):
for e in self._enum.values:
indented_writer.write_line(f'"{e.value}"_sd,')
indented_writer.write_empty_line()
with writer.IndentedScopedBlock(
indented_writer,
f"{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{",
"}",
):
indented_writer.write_line(
f"static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};"
)
indented_writer.write_line(
f"auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};"
)
writer.gen_string_table_find_function_block(
indented_writer,
f'{cpp_type} {func}(const IDLParserContext& ctxt, StringData value) {{', '}'):
indented_writer.write_line(
f'static constexpr auto onMatch = [](int i) {{ return {cpp_type}_values[i]; }};')
indented_writer.write_line(
f'auto onFail = [&] {{ ctxt.throwBadEnumValue(value); return {cpp_type}{{}}; }};')
writer.gen_string_table_find_function_block(indented_writer, 'value', 'onMatch({})',
'onFail()',
[e.value for e in self._enum.values])
"value",
"onMatch({})",
"onFail()",
[e.value for e in self._enum.values],
)
def get_serializer_declaration(self):
# type: () -> str
"""Get the serializer function declaration minus trailing semicolon."""
cpp_type = self.get_cpp_type_name()
func = self._get_enum_serializer_name()
return f'StringData {func}({cpp_type} value)'
return f"StringData {func}({cpp_type} value)"
def gen_serializer_definition(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
"""Generate the serializer function definition."""
func = self._get_enum_serializer_name()
cpp_type = self.get_cpp_type_name()
with writer.IndentedScopedBlock(indented_writer, f'StringData {func}({cpp_type} value) {{',
'}'):
indented_writer.write_line('auto idx = static_cast<size_t>(value);')
indented_writer.write_line(f'invariant(idx < {cpp_type}_names.size());')
indented_writer.write_line(f'return {cpp_type}_names[idx];')
with writer.IndentedScopedBlock(
indented_writer, f"StringData {func}({cpp_type} value) {{", "}"
):
indented_writer.write_line("auto idx = static_cast<size_t>(value);")
indented_writer.write_line(f"invariant(idx < {cpp_type}_names.size());")
indented_writer.write_line(f"return {cpp_type}_names[idx];")
def get_type_info(idl_enum):
# type: (Union[syntax.Enum,ast.Enum]) -> Optional[EnumTypeInfoBase]
"""Get the type information for a given enumeration type, return None if not supported."""
if idl_enum.type == 'int':
if idl_enum.type == "int":
return _EnumTypeInt(idl_enum)
elif idl_enum.type == 'string':
elif idl_enum.type == "string":
return _EnumTypeString(idl_enum)
return None

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -43,7 +43,7 @@ class FieldListInfo:
# type: () -> MethodInfo
"""Get the hasField method for a generic argument or generic reply field list."""
class_name = common.title_case(self.struct.cpp_name)
return MethodInfo(class_name, 'hasField', ['StringData fieldName'], 'bool', static=True)
return MethodInfo(class_name, "hasField", ["StringData fieldName"], "bool", static=True)
def get_should_forward_name(self):
"""Get the name of the shard-forwarding rule for a generic argument or reply field."""
@ -62,8 +62,13 @@ class FieldListInfo:
# type: () -> MethodInfo
"""Get the method for checking the shard-forwarding rule of an argument or reply field."""
class_name = common.title_case(self.struct.cpp_name)
return MethodInfo(class_name, self.get_should_forward_name(), ['StringData fieldName'],
'bool', static=True)
return MethodInfo(
class_name,
self.get_should_forward_name(),
["StringData fieldName"],
"bool",
static=True,
)
def get_field_list_info(struct):

File diff suppressed because it is too large Load Diff

View File

@ -40,8 +40,14 @@ from . import writer
def _is_required_constructor_arg(field):
# type: (ast.Field) -> bool
"""Get whether we require this field to have a value set for constructor purposes."""
return not field.ignore and not field.optional and not field.default and not field.chained \
and not field.chained_struct_field and not field.serialize_op_msg_request_only
return (
not field.ignore
and not field.optional
and not field.default
and not field.chained
and not field.chained_struct_field
and not field.serialize_op_msg_request_only
)
def _get_arg_for_field(field):
@ -66,7 +72,7 @@ def _get_required_parameters(struct):
def _get_serialization_ctx_arg():
return 'boost::optional<SerializationContext> serializationContext = boost::none'
return "boost::optional<SerializationContext> serializationContext = boost::none"
class ArgumentInfo(object):
@ -76,12 +82,12 @@ class ArgumentInfo(object):
# type: (str) -> None
"""Create a instance of the ArgumentInfo class by parsing the argument string."""
self.defaults = None
equal_tokens = arg.split('=')
equal_tokens = arg.split("=")
if len(equal_tokens) > 1:
self.defaults = equal_tokens[-1].strip()
space_tokens = equal_tokens[0].strip().split(' ')
self.type = ' '.join(space_tokens[0:-1])
space_tokens = equal_tokens[0].strip().split(" ")
self.type = " ".join(space_tokens[0:-1])
self.name = space_tokens[-1]
def get_string(self, get_defaults):
@ -97,8 +103,17 @@ class MethodInfo(object):
# pylint: disable=too-many-instance-attributes
def __init__(self, class_name, method_name, args, return_type=None, static=False, const=False,
explicit=False, desc_for_comment=None):
def __init__(
self,
class_name,
method_name,
args,
return_type=None,
static=False,
const=False,
explicit=False,
desc_for_comment=None,
):
# type: (str, str, List[str], str, bool, bool, bool, Optional[str]) -> None
# pylint: disable=too-many-arguments
"""Create a MethodInfo instance."""
@ -114,59 +129,68 @@ class MethodInfo(object):
def get_declaration(self):
# type: () -> str
"""Get a declaration for a method."""
pre_modifiers = ''
post_modifiers = ''
return_type_str = ''
pre_modifiers = ""
post_modifiers = ""
return_type_str = ""
if self.static:
pre_modifiers = 'static '
pre_modifiers = "static "
if self.const:
post_modifiers = ' const'
post_modifiers = " const"
if self.explicit:
pre_modifiers += 'explicit '
pre_modifiers += "explicit "
if self.return_type:
return_type_str = self.return_type + ' '
return_type_str = self.return_type + " "
return common.template_args(
"${pre_modifiers}${return_type}${method_name}(${args})${post_modifiers};",
pre_modifiers=pre_modifiers, return_type=return_type_str, method_name=self.method_name,
args=', '.join(
[arg.get_string(True) for arg in self.args]), post_modifiers=post_modifiers)
pre_modifiers=pre_modifiers,
return_type=return_type_str,
method_name=self.method_name,
args=", ".join([arg.get_string(True) for arg in self.args]),
post_modifiers=post_modifiers,
)
def get_definition(self):
# type: () -> str
"""Get a definition for a method."""
pre_modifiers = ''
post_modifiers = ''
return_type_str = ''
pre_modifiers = ""
post_modifiers = ""
return_type_str = ""
if self.const:
post_modifiers = ' const'
post_modifiers = " const"
if self.return_type:
return_type_str = self.return_type + ' '
return_type_str = self.return_type + " "
return common.template_args(
"${pre_modifiers}${return_type}${class_name}::${method_name}(${args})${post_modifiers}",
pre_modifiers=pre_modifiers, return_type=return_type_str, class_name=self.class_name,
method_name=self.method_name, args=', '.join(
[arg.get_string(False) for arg in self.args]), post_modifiers=post_modifiers)
pre_modifiers=pre_modifiers,
return_type=return_type_str,
class_name=self.class_name,
method_name=self.method_name,
args=", ".join([arg.get_string(False) for arg in self.args]),
post_modifiers=post_modifiers,
)
def get_call(self, obj):
# type: (Optional[str]) -> str
"""Generate a simple call to the method using the defined args list."""
args = ', '.join([arg.name for arg in self.args])
args = ", ".join([arg.name for arg in self.args])
if obj:
return common.template_args("${obj}.${method_name}(${args});", obj=obj,
method_name=self.method_name, args=args)
return common.template_args(
"${obj}.${method_name}(${args});", obj=obj, method_name=self.method_name, args=args
)
return common.template_args("${method_name}(${args});", method_name=self.method_name,
args=args)
return common.template_args(
"${method_name}(${args});", method_name=self.method_name, args=args
)
def get_desc_for_comment(self):
# type: () -> Optional[str]
@ -292,9 +316,14 @@ class _StructTypeInfo(StructTypeInfoBase):
comment = textwrap.dedent(f"""\
Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed
this way participates in ownership of the data underlying the BSONObj.""")
return MethodInfo(class_name, 'parseSharingOwnership',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name,
static=True, desc_for_comment=comment)
return MethodInfo(
class_name,
"parseSharingOwnership",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
class_name,
static=True,
desc_for_comment=comment,
)
def get_owned_deserializer_static_method(self):
# type: () -> MethodInfo
@ -302,9 +331,14 @@ class _StructTypeInfo(StructTypeInfoBase):
comment = textwrap.dedent(f"""\
Factory function that parses a {class_name} from a BSONObj. A {class_name} parsed
this way takes ownership of the data underlying the BSONObj.""")
return MethodInfo(class_name, 'parseOwned',
['const IDLParserContext& ctxt', 'BSONObj&& bsonObject'], class_name,
static=True, desc_for_comment=comment)
return MethodInfo(
class_name,
"parseOwned",
["const IDLParserContext& ctxt", "BSONObj&& bsonObject"],
class_name,
static=True,
desc_for_comment=comment,
)
def get_deserializer_static_method(self):
# type: () -> MethodInfo
@ -315,23 +349,32 @@ class _StructTypeInfo(StructTypeInfoBase):
ensure the validity any members of this struct that point-into the BSONObj (i.e.
unowned
objects).""")
return MethodInfo(class_name, 'parse',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], class_name,
static=True, desc_for_comment=comment)
return MethodInfo(
class_name,
"parse",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
class_name,
static=True,
desc_for_comment=comment,
)
def get_deserializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'parseProtected',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
common.title_case(self._struct.cpp_name),
"parseProtected",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
"void",
)
def get_serializer_method(self):
# type: () -> MethodInfo
args = ['BSONObjBuilder* builder']
args = ["BSONObjBuilder* builder"]
if self._struct.query_shape_component:
args.append("const SerializationOptions& options = {}")
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize', args, 'void', const=True)
common.title_case(self._struct.cpp_name), "serialize", args, "void", const=True
)
def get_to_bson_method(self):
# type: () -> MethodInfo
@ -339,7 +382,8 @@ class _StructTypeInfo(StructTypeInfoBase):
if self._struct.query_shape_component:
args.append("const SerializationOptions& options = {}")
return MethodInfo(
common.title_case(self._struct.cpp_name), 'toBSON', args, 'BSONObj', const=True)
common.title_case(self._struct.cpp_name), "toBSON", args, "BSONObj", const=True
)
def get_op_msg_request_serializer_method(self):
# type: () -> Optional[MethodInfo]
@ -383,21 +427,32 @@ class _CommandBaseTypeInfo(_StructTypeInfo):
def get_op_msg_request_serializer_method(self):
# type: () -> Optional[MethodInfo]
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize',
['const BSONObj& commandPassthroughFields = {}'], 'OpMsgRequest', const=True)
common.title_case(self._struct.cpp_name),
"serialize",
["const BSONObj& commandPassthroughFields = {}"],
"OpMsgRequest",
const=True,
)
def get_op_msg_request_deserializer_static_method(self):
# type: () -> Optional[MethodInfo]
class_name = common.title_case(self._struct.cpp_name)
return MethodInfo(class_name, 'parse',
['const IDLParserContext& ctxt', 'const OpMsgRequest& request'],
class_name, static=True)
return MethodInfo(
class_name,
"parse",
["const IDLParserContext& ctxt", "const OpMsgRequest& request"],
class_name,
static=True,
)
def get_op_msg_request_deserializer_method(self):
# type: () -> Optional[MethodInfo]
return MethodInfo(
common.title_case(self._struct.cpp_name), 'parseProtected',
['const IDLParserContext& ctxt', 'const OpMsgRequest& request'], 'void')
common.title_case(self._struct.cpp_name),
"parseProtected",
["const IDLParserContext& ctxt", "const OpMsgRequest& request"],
"void",
)
class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo):
@ -413,16 +468,23 @@ class _IgnoredCommandTypeInfo(_CommandBaseTypeInfo):
def get_serializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize',
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
const=True)
common.title_case(self._struct.cpp_name),
"serialize",
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
"void",
const=True,
)
def get_to_bson_method(self):
# type: () -> MethodInfo
# Commands that require namespaces require it as a parameter to serialize()
return MethodInfo(
common.title_case(self._struct.cpp_name), 'toBSON',
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
common.title_case(self._struct.cpp_name),
"toBSON",
["const BSONObj& commandPassthroughFields = {}"],
"BSONObj",
const=True,
)
def gen_serializer(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
@ -440,8 +502,8 @@ def _get_command_type_parameter(command, gen_header=False):
# Use the storage type for the constructor argument since the generated code will use std::move.
member_type = cpp_type_info.get_storage_type()
result = f"{member_type} {common.camel_case(command.command_field.cpp_name)}"
if not gen_header or '&' in result:
result = 'const ' + result
if not gen_header or "&" in result:
result = "const " + result
return result
@ -468,27 +530,38 @@ class _CommandFromType(_CommandBaseTypeInfo):
class_name = common.title_case(self._struct.cpp_name)
arg = _get_command_type_parameter(self._command, gen_header)
return MethodInfo(class_name, class_name, [arg] + _get_required_parameters(self._struct),
explicit=True)
return MethodInfo(
class_name, class_name, [arg] + _get_required_parameters(self._struct), explicit=True
)
def get_serializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize',
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
const=True)
common.title_case(self._struct.cpp_name),
"serialize",
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
"void",
const=True,
)
def get_to_bson_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'toBSON',
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
common.title_case(self._struct.cpp_name),
"toBSON",
["const BSONObj& commandPassthroughFields = {}"],
"BSONObj",
const=True,
)
def get_deserializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'parseProtected',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
common.title_case(self._struct.cpp_name),
"parseProtected",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
"void",
)
def gen_getter_method(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
@ -520,78 +593,96 @@ class _CommandWithNamespaceTypeInfo(_CommandBaseTypeInfo):
@staticmethod
def _get_nss_param(gen_header):
nss_param = 'NamespaceString nss'
nss_param = "NamespaceString nss"
if not gen_header:
nss_param = 'const ' + nss_param
nss_param = "const " + nss_param
return nss_param
def get_constructor_method(self, gen_header=False):
# type: (bool) -> MethodInfo
class_name = common.title_case(self._struct.cpp_name)
sc_arg = _get_serialization_ctx_arg()
return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg],
explicit=True)
return MethodInfo(
class_name, class_name, [self._get_nss_param(gen_header), sc_arg], explicit=True
)
def get_required_constructor_method(self, gen_header=False):
# type: (bool) -> MethodInfo
class_name = common.title_case(self._struct.cpp_name)
return MethodInfo(
class_name, class_name,
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct))
class_name,
class_name,
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct),
)
def get_serializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize',
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
const=True)
common.title_case(self._struct.cpp_name),
"serialize",
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
"void",
const=True,
)
def get_to_bson_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'toBSON',
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
common.title_case(self._struct.cpp_name),
"toBSON",
["const BSONObj& commandPassthroughFields = {}"],
"BSONObj",
const=True,
)
def get_deserializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'parseProtected',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
common.title_case(self._struct.cpp_name),
"parseProtected",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
"void",
)
def gen_getter_method(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
indented_writer.write_line('const NamespaceString& getNamespace() const { return _nss; }')
indented_writer.write_line("const NamespaceString& getNamespace() const { return _nss; }")
if self._struct.non_const_getter:
indented_writer.write_line('NamespaceString& getNamespace() { return _nss; }')
indented_writer.write_line("NamespaceString& getNamespace() { return _nss; }")
def gen_member(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
indented_writer.write_line('NamespaceString _nss;')
indented_writer.write_line("NamespaceString _nss;")
def gen_serializer(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
if self._struct.allow_global_collection_name:
indented_writer.write_line(
'_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name))
'_nss.serializeCollectionName(builder, "%s"_sd);' % (self._command.name)
)
else:
indented_writer.write_line('invariant(!_nss.isEmpty());')
indented_writer.write_line("invariant(!_nss.isEmpty());")
indented_writer.write_line(
'builder->append("%s"_sd, _nss.coll());' % (self._command.name))
'builder->append("%s"_sd, _nss.coll());' % (self._command.name)
)
indented_writer.write_empty_line()
def gen_namespace_check(self, indented_writer, db_name, element):
# type: (writer.IndentedTextWriter, str, str) -> None
# TODO: should the name of the first element be validated??
indented_writer.write_line('invariant(_nss.isEmpty());')
allow_global = 'true' if self._struct.allow_global_collection_name else 'false'
indented_writer.write_line("invariant(_nss.isEmpty());")
allow_global = "true" if self._struct.allow_global_collection_name else "false"
indented_writer.write_line(
'auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);' % (element,
allow_global))
"auto collectionName = ctxt.checkAndAssertCollectionName(%s, %s);"
% (element, allow_global)
)
indented_writer.write_line(
'_nss = NamespaceStringUtil::deserialize(%s, collectionName);' % (db_name))
"_nss = NamespaceStringUtil::deserialize(%s, collectionName);" % (db_name)
)
indented_writer.write_line(
'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "'
' << _nss.toStringForErrorMsg(), _nss.isValid());')
" << _nss.toStringForErrorMsg(), _nss.isValid());"
)
class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
@ -606,55 +697,70 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
@staticmethod
def _get_nss_param(gen_header):
nss_param = 'NamespaceStringOrUUID nssOrUUID'
nss_param = "NamespaceStringOrUUID nssOrUUID"
if not gen_header:
nss_param = 'const ' + nss_param
nss_param = "const " + nss_param
return nss_param
def get_constructor_method(self, gen_header=False):
# type: (bool) -> MethodInfo
class_name = common.title_case(self._struct.cpp_name)
sc_arg = _get_serialization_ctx_arg()
return MethodInfo(class_name, class_name, [self._get_nss_param(gen_header), sc_arg],
explicit=True)
return MethodInfo(
class_name, class_name, [self._get_nss_param(gen_header), sc_arg], explicit=True
)
def get_required_constructor_method(self, gen_header=False):
# type: (bool) -> MethodInfo
class_name = common.title_case(self._struct.cpp_name)
return MethodInfo(
class_name, class_name,
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct))
class_name,
class_name,
[self._get_nss_param(gen_header)] + _get_required_parameters(self._struct),
)
def get_serializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'serialize',
['BSONObjBuilder* builder', 'const BSONObj& commandPassthroughFields = {}'], 'void',
const=True)
common.title_case(self._struct.cpp_name),
"serialize",
["BSONObjBuilder* builder", "const BSONObj& commandPassthroughFields = {}"],
"void",
const=True,
)
def get_to_bson_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'toBSON',
['const BSONObj& commandPassthroughFields = {}'], 'BSONObj', const=True)
common.title_case(self._struct.cpp_name),
"toBSON",
["const BSONObj& commandPassthroughFields = {}"],
"BSONObj",
const=True,
)
def get_deserializer_method(self):
# type: () -> MethodInfo
return MethodInfo(
common.title_case(self._struct.cpp_name), 'parseProtected',
['const IDLParserContext& ctxt', 'const BSONObj& bsonObject'], 'void')
common.title_case(self._struct.cpp_name),
"parseProtected",
["const IDLParserContext& ctxt", "const BSONObj& bsonObject"],
"void",
)
def gen_getter_method(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
indented_writer.write_line(
'const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }')
"const NamespaceStringOrUUID& getNamespaceOrUUID() const { return _nssOrUUID; }"
)
if self._struct.non_const_getter:
indented_writer.write_line(
'NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }')
"NamespaceStringOrUUID& getNamespaceOrUUID() { return _nssOrUUID; }"
)
def gen_member(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
indented_writer.write_line('NamespaceStringOrUUID _nssOrUUID;')
indented_writer.write_line("NamespaceStringOrUUID _nssOrUUID;")
def gen_serializer(self, indented_writer):
# type: (writer.IndentedTextWriter) -> None
@ -664,14 +770,17 @@ class _CommandWithUUIDNamespaceTypeInfo(_CommandBaseTypeInfo):
def gen_namespace_check(self, indented_writer, db_name, element):
# type: (writer.IndentedTextWriter, str, str) -> None
indented_writer.write_line(
'auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);' % (element))
"auto collOrUUID = ctxt.checkAndAssertCollectionNameOrUUID(%s);" % (element)
)
indented_writer.write_line(
'_nssOrUUID = std::holds_alternative<StringData>(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get<StringData>(collOrUUID)) : NamespaceStringOrUUID(%s, get<UUID>(collOrUUID));'
% (db_name, db_name))
"_nssOrUUID = std::holds_alternative<StringData>(collOrUUID) ? NamespaceStringUtil::deserialize(%s, get<StringData>(collOrUUID)) : NamespaceStringOrUUID(%s, get<UUID>(collOrUUID));"
% (db_name, db_name)
)
indented_writer.write_line(
'uassert(ErrorCodes::InvalidNamespace, str::stream() << "Invalid namespace specified: "'
' << _nssOrUUID.toStringForErrorMsg()'
', !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());')
" << _nssOrUUID.toStringForErrorMsg()"
", !_nssOrUUID.isNamespaceString() || _nssOrUUID.nss().isValid());"
)
def get_struct_info(struct):

View File

@ -46,8 +46,9 @@ class IDLParsedSpec(object):
def __init__(self, spec, error_collection):
# type: (IDLSpec, errors.ParserErrorCollection) -> None
"""Must specify either an IDL document or errors, not both."""
assert (spec is None and error_collection is not None) or (spec is not None
and error_collection is None)
assert (spec is None and error_collection is not None) or (
spec is not None and error_collection is None
)
self.spec = spec
self.errors = error_collection
@ -76,11 +77,11 @@ def parse_array_variant_types(name):
if not name.startswith("array<variant<") and not name.endswith(">>"):
return None
name = name[len("array<variant<"):]
name = name[len("array<variant<") :]
name = name[:-2]
variant_types = []
for variant_type in name.split(','):
for variant_type in name.split(","):
variant_type = variant_type.strip()
# Ban array<variant<..., array<...>, ...>> types.
if variant_type.startswith("array<") and variant_type.endswith(">"):
@ -96,7 +97,7 @@ def parse_array_type(name):
if not name.startswith("array<") and not name.endswith(">"):
return None
name = name[len("array<"):]
name = name[len("array<") :]
name = name[:-1]
# V1 restriction, ban nested array types to reduce scope.
@ -139,19 +140,22 @@ class SymbolTable(object):
def _is_duplicate(self, ctxt, location, name, duplicate_class_name):
# type: (errors.ParserContext, common.SourceLocation, str, str) -> bool
"""Return true if the given item already exist in the symbol table."""
for (item, entity_type) in _item_and_type({
for item, entity_type in _item_and_type(
{
"command": self.commands,
"enum": self.enums,
"struct": self.structs,
"type": self.types,
}):
}
):
if item.name == name:
ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name, entity_type)
return True
if entity_type == "command":
if name in [item.command_name, item.command_alias if item.command_alias else '']:
ctxt.add_duplicate_symbol_error(location, name, duplicate_class_name,
entity_type)
if name in [item.command_name, item.command_alias if item.command_alias else ""]:
ctxt.add_duplicate_symbol_error(
location, name, duplicate_class_name, entity_type
)
return True
return False
@ -177,8 +181,9 @@ class SymbolTable(object):
def add_command(self, ctxt, command):
# type: (errors.ParserContext, Command) -> None
"""Add an IDL command to the symbol table and check for duplicates."""
if (not self._is_duplicate(ctxt, command, command.name, "command")
and not self._is_duplicate(ctxt, command, command.command_alias, "command")):
if not self._is_duplicate(
ctxt, command, command.name, "command"
) and not self._is_duplicate(ctxt, command, command.command_alias, "command"):
self.commands.append(command)
def add_generic_argument_list(self, field_list):
@ -270,7 +275,7 @@ class SymbolTable(object):
# cause parsing ambiguity.
first_element = alternative_type.fields[0].name
if first_element in [
elem.fields[0].name for elem in variant.variant_struct_types
elem.fields[0].name for elem in variant.variant_struct_types
]:
ctxt.add_variant_structs_error(location, field_name)
continue
@ -293,11 +298,13 @@ class SymbolTable(object):
return variant
if isinstance(field_type, FieldTypeArray):
element_type = self.resolve_field_type(ctxt, location, field_name,
field_type.element_type)
element_type = self.resolve_field_type(
ctxt, location, field_name, field_type.element_type
)
if not element_type:
ctxt.add_unknown_type_error(location, field_name,
field_type.element_type.debug_string())
ctxt.add_unknown_type_error(
location, field_name, field_type.element_type.debug_string()
)
return None
if isinstance(element_type, Enum):
@ -308,7 +315,7 @@ class SymbolTable(object):
assert isinstance(field_type, FieldTypeSingle)
type_name = field_type.type_name
if type_name.startswith('array<'):
if type_name.startswith("array<"):
# The caller should've already stripped "array<...>" from type_name, this may be an
# illegal nested array like "array<array<...>>".
ctxt.add_bad_array_type_name_error(location, field_name, type_name)
@ -405,13 +412,14 @@ class ArrayType(Type):
def __init__(self, element_type):
# type: (Union[Struct, Type]) -> None
"""Construct an ArrayType."""
super(ArrayType, self).__init__(element_type.file_name, element_type.line,
element_type.column)
self.name = f'array<{element_type.name}>'
super(ArrayType, self).__init__(
element_type.file_name, element_type.line, element_type.column
)
self.name = f"array<{element_type.name}>"
self.element_type = element_type
if isinstance(element_type, Type):
if element_type.cpp_type:
self.cpp_type = f'std::vector<{element_type.cpp_type}>'
self.cpp_type = f"std::vector<{element_type.cpp_type}>"
else:
assert isinstance(element_type, VariantType)
# cpp_type can't be set here for array of variants as element_type.cpp_type is not set yet.
@ -424,7 +432,7 @@ class VariantType(Type):
# type: (str, int, int) -> None
"""Construct a VariantType."""
super(VariantType, self).__init__(file_name, line, column)
self.name = 'variant'
self.name = "variant"
self.variant_types = [] # type: List[Type]
self.variant_struct_types = [] # type: List[Struct]
@ -451,9 +459,14 @@ class Validator(common.SourceLocation):
super(Validator, self).__init__(file_name, line, column)
def __eq__(self, other):
return (isinstance(other, Validator) and self.gt == other.gt and self.lt == other.lt
and self.gte == other.gte and self.lte == other.lte
and self.callback == other.callback)
return (
isinstance(other, Validator)
and self.gt == other.gt
and self.lt == other.lt
and self.gte == other.gte
and self.lte == other.lte
and self.callback == other.callback
)
def __ne__(self, other):
return not self == other
@ -602,7 +615,11 @@ class Privilege(common.SourceLocation):
"""
location = super(Privilege, self).__str__()
msg = "location: %s, resource_pattern: %s, action_type: %s, agg_stage: %s" % (
location, self.resource_pattern, self.action_type, self.agg_stage)
location,
self.resource_pattern,
self.action_type,
self.agg_stage,
)
return msg # type: ignore
@ -709,8 +726,9 @@ class EnumValue(common.SourceLocation):
super(EnumValue, self).__init__(file_name, line, column)
def __eq__(self, other):
return (isinstance(other, EnumValue) and self.name == other.name
and self.value == other.value)
return (
isinstance(other, EnumValue) and self.name == other.name and self.value == other.value
)
def __ne__(self, other):
return not self == other
@ -790,12 +808,13 @@ class FieldTypeArray(FieldType):
"""Construct a FieldTypeArray."""
self.element_type = element_type # type: Union[FieldTypeSingle, FieldTypeVariant]
super(FieldTypeArray, self).__init__(element_type.file_name, element_type.line,
element_type.column)
super(FieldTypeArray, self).__init__(
element_type.file_name, element_type.line, element_type.column
)
def debug_string(self):
"""Display this field type in error messages."""
return f'array<{self.element_type.type_name}>'
return f"array<{self.element_type.type_name}>"
class FieldTypeVariant(FieldType):
@ -810,7 +829,7 @@ class FieldTypeVariant(FieldType):
def debug_string(self):
"""Display this field type in error messages."""
return 'variant<%s>' % (', '.join(v.debug_string() for v in self.variant))
return "variant<%s>" % (", ".join(v.debug_string() for v in self.variant))
class Expression(common.SourceLocation):
@ -827,8 +846,12 @@ class Expression(common.SourceLocation):
super(Expression, self).__init__(file_name, line, column)
def __eq__(self, other):
return (isinstance(other, Expression) and self.literal == other.literal
and self.expr == other.expr and self.is_constexpr == other.is_constexpr)
return (
isinstance(other, Expression)
and self.literal == other.literal
and self.expr == other.expr
and self.is_constexpr == other.is_constexpr
)
def __ne__(self, other):
return not self == other

View File

@ -39,9 +39,9 @@ _INDENT_SPACE_COUNT = 4
def _fill_spaces(count):
# type: (int) -> str
"""Fill a string full of spaces."""
fill = ''
fill = ""
for _ in range(count * _INDENT_SPACE_COUNT):
fill += ' '
fill += " "
return fill
@ -51,7 +51,7 @@ def _indent_text(count, unindented_text):
"""Indent each line of a multi-line string."""
lines = unindented_text.splitlines()
fill = _fill_spaces(count)
return '\n'.join(fill + line for line in lines)
return "\n".join(fill + line for line in lines)
def is_function(name):
@ -68,10 +68,10 @@ def is_function(name):
def get_method_name(name):
# type: (str) -> str
"""Get a method name from a fully qualified method name."""
pos = name.rfind('::')
pos = name.rfind("::")
if pos == -1:
return name
return name[pos + 2:]
return name[pos + 2 :]
def get_method_name_from_qualified_method_name(name):
@ -82,12 +82,12 @@ def get_method_name_from_qualified_method_name(name):
if name.startswith("::"):
name = name[2:]
prefix = 'mongo::'
prefix = "mongo::"
pos = name.find(prefix)
if pos == -1:
return name
return name[len(prefix):]
return name[len(prefix) :]
class IndentedTextWriter(object):
@ -243,7 +243,7 @@ class NamespaceScopeBlock(WriterBlock):
# type: () -> None
"""Write the beginning of the block and do not indent."""
for namespace in self._namespaces:
self._writer.write_unindented_line('namespace %s {' % (namespace))
self._writer.write_unindented_line("namespace %s {" % (namespace))
def __exit__(self, *args):
# type: (*str) -> None
@ -251,7 +251,7 @@ class NamespaceScopeBlock(WriterBlock):
self._namespaces.reverse()
for namespace in self._namespaces:
self._writer.write_unindented_line('} // namespace %s' % (namespace))
self._writer.write_unindented_line("} // namespace %s" % (namespace))
class UnindentedBlock(WriterBlock):
@ -304,7 +304,7 @@ def _get_common_prefix(words):
"""
empty_words = [lw for lw in words if len(lw) == 0]
if empty_words:
return ''
return ""
first_letters = {w[0] for w in words}
@ -317,7 +317,7 @@ def _get_common_prefix(words):
return words[0][0] + _get_common_prefix(suffix_words)
else:
return ''
return ""
def gen_trie(words, writer, callback):
@ -329,7 +329,7 @@ def gen_trie(words, writer, callback):
i.e. for ["abc", "def"], then callback() will be called twice, once for each string.
"""
_gen_trie('', words, writer, callback)
_gen_trie("", words, writer, callback)
def _gen_trie(prefix, words, writer, callback):
@ -352,12 +352,14 @@ def _gen_trie(prefix, words, writer, callback):
suffix = words[0]
suffix_len = len(suffix)
predicate = f'fieldName.size() == {len(word_to_check)} && ' \
predicate = (
f"fieldName.size() == {len(word_to_check)} && "
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{suffix}", {suffix_len}) == 0'
)
# If there is no trailing text, we just need to check length to validate we matched
if suffix_len == 0:
predicate = f'fieldName.size() == {len(word_to_check)}'
predicate = f"fieldName.size() == {len(word_to_check)}"
# Optimization:
# Checking strings of length 1 or even length is efficient. Strings of 3 byte length are
@ -365,10 +367,12 @@ def _gen_trie(prefix, words, writer, callback):
# length strings require just 1. Since we know the field name is zero terminated, we can
# just use memcmp and compare with the trailing null byte.
elif suffix_len % 4 == 3:
predicate = f'fieldName.size() == {len(word_to_check)} && ' \
predicate = (
f"fieldName.size() == {len(word_to_check)} && "
+ f' memcmp(fieldName.rawData() + {prefix_len}, "{suffix}\\0", {suffix_len + 1}) == 0'
)
with IndentedScopedBlock(writer, f'if ({predicate}) {{', '}'):
with IndentedScopedBlock(writer, f"if ({predicate}) {{", "}"):
callback(word_to_check)
return
@ -379,7 +383,7 @@ def _gen_trie(prefix, words, writer, callback):
empty_words = [lw for lw in words if len(lw) == 0]
if empty_words:
word_to_check = prefix
with IndentedScopedBlock(writer, f'if (fieldName.size() == {len(word_to_check)}) {{', "}"):
with IndentedScopedBlock(writer, f"if (fieldName.size() == {len(word_to_check)}) {{", "}"):
callback(word_to_check)
# Filter out empty words
@ -395,10 +399,11 @@ def _gen_trie(prefix, words, writer, callback):
suffix_words = [flw[gcp_len:] for flw in words]
with IndentedScopedBlock(
writer,
f'if (fieldName.size() >= {gcp_len} && '\
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{',
"}"):
writer,
f"if (fieldName.size() >= {gcp_len} && "
+ f'std::char_traits<char>::compare(fieldName.rawData() + {prefix_len}, "{gcp}", {gcp_len}) == 0) {{',
"}",
):
_gen_trie(prefix + gcp, suffix_words, writer, callback)
return
@ -410,16 +415,16 @@ def _gen_trie(prefix, words, writer, callback):
first_letters = {w[0] for w in sorted_words}
min_len = len(prefix) + min([len(w) for w in sorted_words])
with IndentedScopedBlock(writer, f'if (fieldName.size() >= {min_len}) {{', "}"):
with IndentedScopedBlock(writer, f"if (fieldName.size() >= {min_len}) {{", "}"):
first_if = True
for first_letter in first_letters:
fl_words = [flw[1:] for flw in words if flw[0] == first_letter]
ei = "else " if not first_if else ''
ei = "else " if not first_if else ""
with IndentedScopedBlock(
writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}"):
writer, f"{ei}if (fieldName[{len(prefix)}] == '{first_letter}') {{", "}"
):
_gen_trie(prefix + first_letter, fl_words, writer, callback)
first_if = False
@ -429,6 +434,6 @@ def gen_string_table_find_function_block(out, in_str, on_match, on_fail, words):
# type: (IndentedTextWriter, str, str, str, list[str]) -> None
"""Wrap a gen_trie generated block as a function."""
index = {word: i for i, word in enumerate(words)}
out.write_line(f'StringData fieldName{{{in_str}}};')
gen_trie(words, out, lambda w: out.write_line(f'return {on_match.format(index[w])};'))
out.write_line(f'return {on_fail};')
out.write_line(f"StringData fieldName{{{in_str}}};")
gen_trie(words, out, lambda w: out.write_line(f"return {on_match.format(index[w])};"))
out.write_line(f"return {on_fail};")

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -39,29 +39,43 @@ import idl.compiler
def main():
# type: () -> None
"""Execute Main Entry point."""
parser = argparse.ArgumentParser(description='MongoDB IDL Compiler.')
parser = argparse.ArgumentParser(description="MongoDB IDL Compiler.")
parser.add_argument('file', type=str, help="IDL input file")
parser.add_argument("file", type=str, help="IDL input file")
parser.add_argument('-o', '--output', type=str, help="IDL output source file")
parser.add_argument("-o", "--output", type=str, help="IDL output source file")
parser.add_argument('--header', type=str, help="IDL output header file")
parser.add_argument("--header", type=str, help="IDL output header file")
parser.add_argument('-i', '--include', type=str, action="append",
help="Directory to search for IDL import files")
parser.add_argument(
"-i",
"--include",
type=str,
action="append",
help="Directory to search for IDL import files",
)
parser.add_argument('-v', '--verbose', action='count', help="Enable verbose tracing")
parser.add_argument("-v", "--verbose", action="count", help="Enable verbose tracing")
parser.add_argument('--base_dir', type=str, help="IDL output relative base directory")
parser.add_argument("--base_dir", type=str, help="IDL output relative base directory")
parser.add_argument('--write-dependencies', action='store_true',
help='only print out a list of dependent imports')
parser.add_argument(
"--write-dependencies",
action="store_true",
help="only print out a list of dependent imports",
)
parser.add_argument('--write-dependencies-inline', action='store_true',
help='print out a list of dependent imports during file generation')
parser.add_argument(
"--write-dependencies-inline",
action="store_true",
help="print out a list of dependent imports during file generation",
)
parser.add_argument('--target_arch', type=str,
help="IDL target archiecture (amd64, s390x). defaults to current machine")
parser.add_argument(
"--target_arch",
type=str,
help="IDL target archiecture (amd64, s390x). defaults to current machine",
)
args = parser.parse_args()
@ -81,8 +95,9 @@ def main():
compiler_args.write_dependencies = args.write_dependencies
compiler_args.write_dependencies_inline = args.write_dependencies_inline
if (args.output is not None and args.header is None) or \
(args.output is None and args.header is not None):
if (args.output is not None and args.header is None) or (
args.output is None and args.header is not None
):
print("ERROR: Either both --header and --output must be specified or neither.")
sys.exit(1)
@ -93,5 +108,5 @@ def main():
sys.exit(1)
if __name__ == '__main__':
if __name__ == "__main__":
main()

View File

@ -38,8 +38,9 @@ def list_idls(directory: str) -> Set[str]:
"""Find all IDL files in the current directory."""
return {
os.path.join(dirpath, filename)
for dirpath, dirnames, filenames in os.walk(directory) for filename in filenames
if filename.endswith('.idl')
for dirpath, dirnames, filenames in os.walk(directory)
for filename in filenames
if filename.endswith(".idl")
}

View File

@ -46,11 +46,11 @@ def run_tests():
# my-py type information.
all_tests = unittest.defaultTestLoader.discover(start_dir="tests") # type: ignore
runner = XMLTestRunner(verbosity=2, failfast=False, output='results')
runner = XMLTestRunner(verbosity=2, failfast=False, output="results")
result = runner.run(all_tests)
sys.exit(not result.wasSuccessful())
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -30,7 +30,7 @@
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
import idl.ast # pylint: disable=wrong-import-position
import idl.binder # pylint: disable=wrong-import-position

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -43,6 +43,7 @@ from textwrap import dedent
# import package so that it works regardless of whether we run as a module or file
if __package__ is None:
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from context import idl
import testcase
@ -61,16 +62,17 @@ class TestGenerator(testcase.IDLTestcase):
def _src_dir(self):
"""Get the directory of the src folder."""
base_dir = os.path.dirname(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)
return os.path.join(
base_dir,
'src',
"src",
)
@property
def _idl_dir(self):
"""Get the directory of the idl folder."""
return os.path.join(self._src_dir, 'mongo', 'idl')
return os.path.join(self._src_dir, "mongo", "idl")
def tearDown(self) -> None:
"""Cleanup resources created by tests."""
@ -87,14 +89,15 @@ class TestGenerator(testcase.IDLTestcase):
args.output_suffix = self.output_suffix
args.import_directories = [self._src_dir]
unittest_idl_file = os.path.join(self._idl_dir, f'{self.idl_files_to_test[0]}.idl')
unittest_idl_file = os.path.join(self._idl_dir, f"{self.idl_files_to_test[0]}.idl")
if not os.path.exists(unittest_idl_file):
unittest.skip(
"Skipping IDL Generator testing since %s could not be found." % (unittest_idl_file))
"Skipping IDL Generator testing since %s could not be found." % (unittest_idl_file)
)
return
for idl_file in self.idl_files_to_test:
args.input_file = os.path.join(self._idl_dir, f'{idl_file}.idl')
args.input_file = os.path.join(self._idl_dir, f"{idl_file}.idl")
self.assertTrue(idl.compiler.compile_idl(args))
def test_enum_non_const(self):
@ -130,13 +133,15 @@ class TestGenerator(testcase.IDLTestcase):
# Make sure the getter is marked as const.
# Make sure the return type is not marked as const by validating the getter marked as const
# is the only occurrence of the word "const".
header_lines = header.split('\n')
header_lines = header.split("\n")
found = False
for header_line in header_lines:
if header_line.find("getValue") > 0 \
and header_line.find("const {") > 0 \
and header_line.find("const {") == header_line.find("const"):
if (
header_line.find("getValue") > 0
and header_line.find("const {") > 0
and header_line.find("const {") == header_line.find("const")
):
found = True
self.assertTrue(found, "Bad Header: " + header)
@ -267,7 +272,8 @@ class TestGenerator(testcase.IDLTestcase):
self.assertIn(expected, source)
def test_array_of_object_type_with_custom_serializer_and_query_shape_specification_custom(
self) -> None:
self,
) -> None:
"""Serialization with custom query_shape used, array use case."""
_, source = self.assert_generate("""
types:
@ -373,7 +379,9 @@ class TestGenerator(testcase.IDLTestcase):
def test_view_struct_generates_anchor(self) -> None:
"""Test anchor generation on view struct."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
ViewStruct:
description: ViewStruct
@ -382,14 +390,17 @@ class TestGenerator(testcase.IDLTestcase):
value2: random_type_not_view
value3: random_type_not_view
value4: random_type_not_view
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertIn(expected, header)
def test_non_view_struct_does_not_generate_anchor(self) -> None:
"""Test anchor is not generated on non view struct."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
NonViewStruct:
description: NonViewStruct
@ -398,14 +409,17 @@ class TestGenerator(testcase.IDLTestcase):
value2: random_type_not_view
value3: random_type_not_view
value4: random_type_not_view
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertNotIn(expected, header)
def test_compound_view_struct_generates_anchor(self) -> None:
"""Test anchor generation on view struct with compound type."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
ViewStruct:
description: ViewStruct
@ -414,14 +428,17 @@ class TestGenerator(testcase.IDLTestcase):
value2: random_type_not_view
value3: random_type_not_view
value4: random_type_not_view
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertIn(expected, header)
def test_compound_non_view_struct_does_not_generate_anchor(self) -> None:
"""Test anchor is not generated on non view struct with compound type."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
NonViewStruct:
description: NonViewStruct
@ -430,14 +447,17 @@ class TestGenerator(testcase.IDLTestcase):
value2: random_type_not_view
value3: random_type_not_view
value4: random_type_not_view
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertNotIn(expected, header)
def test_command_view_type_generates_anchor(self) -> None:
"""Test anchor generation on command with view parameter."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
commands:
CommandTypeArrayObjectCommand:
description: CommandTypeArrayObjectCommand
@ -445,14 +465,17 @@ class TestGenerator(testcase.IDLTestcase):
namespace: type
api_version: ""
type: array<object_is_view>
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertIn(expected, header)
def test_command_non_view_type_does_not_generate_anchor(self) -> None:
"""Test anchor is not generated on command with nont view parameter."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
commands:
CommandTypeArrayObjectCommand:
description: CommandTypeArrayObjectCommand
@ -460,14 +483,17 @@ class TestGenerator(testcase.IDLTestcase):
namespace: type
api_version: ""
type: array<object_is_not_view>
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertNotIn(expected, header)
def test_chained_view_struct_generates_anchor(self) -> None:
"""Test anchor generation on struct chained with view struct."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
ViewStruct:
description: ViewStruct
@ -480,14 +506,17 @@ class TestGenerator(testcase.IDLTestcase):
description: ViewStructChainedStruct
chained_structs:
ViewStruct: ViewStruct
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertIn(expected, header)
def test_chained_non_view_struct_does_not_generate_anchor(self) -> None:
"""Test anchor not generated on struct chained with non view struct."""
header, _ = self.assert_generate(self.view_test_common_types + dedent("""
header, _ = self.assert_generate(
self.view_test_common_types
+ dedent("""
structs:
NonViewStruct:
description: NonViewStruct
@ -500,12 +529,12 @@ class TestGenerator(testcase.IDLTestcase):
description: NonViewStructChainedStruct
chained_structs:
NonViewStruct: NonViewStruct
"""))
""")
)
expected = dedent("BSONObj _anchorObj;")
self.assertNotIn(expected, header)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -38,6 +38,7 @@ from typing import Any, Dict
if __package__ is None:
import sys
from os import path
sys.path.append(path.dirname(path.abspath(__file__)))
from context import idl
import testcase
@ -88,27 +89,32 @@ class TestImport(testcase.IDLTestcase):
imports:
- "b.idl"
"""), idl.errors.ERROR_ID_DUPLICATE_NODE)
"""),
idl.errors.ERROR_ID_DUPLICATE_NODE,
)
self.assert_parse_fail(
textwrap.dedent("""
imports: "basetypes.idl"
"""), idl.errors.ERROR_ID_IS_NODE_TYPE)
"""),
idl.errors.ERROR_ID_IS_NODE_TYPE,
)
self.assert_parse_fail(
textwrap.dedent("""
imports:
a: "a.idl"
b: "b.idl"
"""), idl.errors.ERROR_ID_IS_NODE_TYPE)
"""),
idl.errors.ERROR_ID_IS_NODE_TYPE,
)
def test_import_positive(self):
# type: () -> None
"""Postive import tests."""
import_dict = {
"basetypes.idl":
textwrap.dedent("""
"basetypes.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -135,8 +141,7 @@ class TestImport(testcase.IDLTestcase):
fields:
foo: string
"""),
"recurse1.idl":
textwrap.dedent("""
"recurse1.idl": textwrap.dedent("""
imports:
- "basetypes.idl"
@ -149,8 +154,7 @@ class TestImport(testcase.IDLTestcase):
is_view: false
"""),
"recurse2.idl":
textwrap.dedent("""
"recurse2.idl": textwrap.dedent("""
imports:
- "recurse1.idl"
@ -163,8 +167,7 @@ class TestImport(testcase.IDLTestcase):
is_view: false
"""),
"recurse1b.idl":
textwrap.dedent("""
"recurse1b.idl": textwrap.dedent("""
imports:
- "basetypes.idl"
@ -176,8 +179,7 @@ class TestImport(testcase.IDLTestcase):
deserializer: BSONElement::fake
is_view: false
"""),
"cycle1a.idl":
textwrap.dedent("""
"cycle1a.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -208,8 +210,7 @@ class TestImport(testcase.IDLTestcase):
foo: string
foo1: bool
"""),
"cycle1b.idl":
textwrap.dedent("""
"cycle1b.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -232,8 +233,7 @@ class TestImport(testcase.IDLTestcase):
foo: string
foo1: bool
"""),
"cycle2.idl":
textwrap.dedent("""
"cycle2.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -275,7 +275,9 @@ class TestImport(testcase.IDLTestcase):
strict: false
fields:
foo: string
"""), resolver=resolver)
"""),
resolver=resolver,
)
# Test nested import
self.assert_bind(
@ -294,7 +296,9 @@ class TestImport(testcase.IDLTestcase):
foo: string
foo1: int
foo2: double
"""), resolver=resolver)
"""),
resolver=resolver,
)
# Test diamond import
self.assert_bind(
@ -315,7 +319,9 @@ class TestImport(testcase.IDLTestcase):
foo1: int
foo2: double
foo3: bool
"""), resolver=resolver)
"""),
resolver=resolver,
)
# Test cycle import
self.assert_bind(
@ -333,7 +339,9 @@ class TestImport(testcase.IDLTestcase):
fields:
foo: string
foo1: bool
"""), resolver=resolver)
"""),
resolver=resolver,
)
# Test self cycle import
self.assert_bind(
@ -350,15 +358,16 @@ class TestImport(testcase.IDLTestcase):
strict: false
fields:
foo: string
"""), resolver=resolver)
"""),
resolver=resolver,
)
def test_import_negative(self):
# type: () -> None
"""Negative import tests."""
import_dict = {
"basetypes.idl":
textwrap.dedent("""
"basetypes.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -388,8 +397,7 @@ class TestImport(testcase.IDLTestcase):
b1: 1
"""),
"bug.idl":
textwrap.dedent("""
"bug.idl": textwrap.dedent("""
global:
cpp_namespace: 'mongo'
@ -409,7 +417,10 @@ class TestImport(testcase.IDLTestcase):
textwrap.dedent("""
imports:
- "notfound.idl"
"""), idl.errors.ERROR_ID_BAD_IMPORT, resolver=resolver)
"""),
idl.errors.ERROR_ID_BAD_IMPORT,
resolver=resolver,
)
# Duplicate types
self.assert_parse_fail(
@ -423,7 +434,10 @@ class TestImport(testcase.IDLTestcase):
cpp_type: foo
bson_serialization_type: string
is_view: false
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
"""),
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
resolver=resolver,
)
# Duplicate structs
self.assert_parse_fail(
@ -436,7 +450,10 @@ class TestImport(testcase.IDLTestcase):
description: foo
fields:
foo1: string
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
"""),
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
resolver=resolver,
)
# Duplicate struct and type
self.assert_parse_fail(
@ -449,7 +466,10 @@ class TestImport(testcase.IDLTestcase):
description: foo
fields:
foo1: string
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
"""),
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
resolver=resolver,
)
# Duplicate type and struct
self.assert_parse_fail(
@ -463,7 +483,10 @@ class TestImport(testcase.IDLTestcase):
cpp_type: foo
bson_serialization_type: string
is_view: false
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
"""),
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
resolver=resolver,
)
# Duplicate enums
self.assert_parse_fail(
@ -478,7 +501,10 @@ class TestImport(testcase.IDLTestcase):
values:
a0: 0
b1: 1
"""), idl.errors.ERROR_ID_DUPLICATE_SYMBOL, resolver=resolver)
"""),
idl.errors.ERROR_ID_DUPLICATE_SYMBOL,
resolver=resolver,
)
# Import a file with errors
self.assert_parse_fail(
@ -493,9 +519,11 @@ class TestImport(testcase.IDLTestcase):
cpp_type: foo
bson_serialization_type: string
is_view: false
"""), idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD, resolver=resolver)
"""),
idl.errors.ERROR_ID_MISSING_REQUIRED_FIELD,
resolver=resolver,
)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

File diff suppressed because it is too large Load Diff

View File

@ -30,9 +30,10 @@
import unittest
from typing import Any, Tuple
if __name__ == 'testcase':
if __name__ == "testcase":
import sys
from os import path
sys.path.append(path.dirname(path.dirname(path.abspath(__file__))))
from context import idl
else:
@ -79,8 +80,9 @@ class IDLTestcase(unittest.TestCase):
"""Assert a document parsed correctly by the IDL compiler and returned no errors."""
self.assertIsNone(
parsed_doc.errors,
"Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s" %
(doc_str, errors_to_str(parsed_doc.errors)))
"Expected no parser errors\nFor document:\n%s\nReceived errors:\n\n%s"
% (doc_str, errors_to_str(parsed_doc.errors)),
)
self.assertIsNotNone(parsed_doc.spec, "Expected a parsed doc")
def assert_parse(self, doc_str, resolver=NothingImportResolver()):
@ -89,8 +91,9 @@ class IDLTestcase(unittest.TestCase):
parsed_doc = self._parse(doc_str, resolver)
self._assert_parse(doc_str, parsed_doc)
def assert_parse_fail(self, doc_str, error_id, multiple=False,
resolver=NothingImportResolver()):
def assert_parse_fail(
self, doc_str, error_id, multiple=False, resolver=NothingImportResolver()
):
# type: (str, str, bool, idl.parser.ImportResolverBase) -> None
"""
Assert a document parsed correctly by the YAML parser, but not the by the IDL compiler.
@ -107,12 +110,14 @@ class IDLTestcase(unittest.TestCase):
self.assertTrue(
multiple or parsed_doc.errors.count() == 1,
"For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s"
% (doc_str, error_id, errors_to_str(parsed_doc.errors)))
% (doc_str, error_id, errors_to_str(parsed_doc.errors)),
)
self.assertTrue(
parsed_doc.errors.contains(error_id),
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s" %
(doc_str, error_id, errors_to_str(parsed_doc.errors)))
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s"
% (doc_str, error_id, errors_to_str(parsed_doc.errors)),
)
def assert_bind(self, doc_str, resolver=NothingImportResolver()):
# type: (str, idl.parser.ImportResolverBase) -> idl.ast.IDLBoundSpec
@ -123,8 +128,10 @@ class IDLTestcase(unittest.TestCase):
bound_doc = idl.binder.bind(parsed_doc.spec)
self.assertIsNone(
bound_doc.errors, "Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s"
% (doc_str, errors_to_str(bound_doc.errors)))
bound_doc.errors,
"Expected no binder errors\nFor document:\n%s\nReceived errors:\n\n%s"
% (doc_str, errors_to_str(bound_doc.errors)),
)
self.assertIsNotNone(bound_doc.spec, "Expected a bound doc")
return bound_doc.spec
@ -149,12 +156,14 @@ class IDLTestcase(unittest.TestCase):
self.assertTrue(
(multiple and bound_doc.errors.count() >= 1) or bound_doc.errors.count() == 1,
"For document:\n%s\nExpected only error message '%s' but received multiple errors:\n\n%s"
% (doc_str, error_id, errors_to_str(bound_doc.errors)))
% (doc_str, error_id, errors_to_str(bound_doc.errors)),
)
self.assertTrue(
bound_doc.errors.contains(error_id),
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s" %
(doc_str, error_id, errors_to_str(bound_doc.errors)))
"For document:\n%s\nExpected error message '%s' but received only errors:\n %s"
% (doc_str, error_id, errors_to_str(bound_doc.errors)),
)
def assert_generate(self, doc_str, resolver=NothingImportResolver()):
# type: (str, idl.parser.ImportResolverBase) -> Tuple[str,str]

View File

@ -212,7 +212,6 @@ ignore = [
[tool.ruff.format]
exclude = [
# allow-list: incrementally remove as each is formatted and locked down
"buildscripts/idl/*",
"jstests/*",
"site_scons/*",
"src/*",