Skip to content

Commit

Permalink
Replace WhichOneof("x") with which_x.
Browse files Browse the repository at this point in the history
This change refactors `OneOfField` so that all fields in a given `oneof`
construct share the same backing attributes on their container class --
`which_{oneof name}`, which holds the (string) name of the
currently-active member of the oneof named `{oneof name}` (or `None` if
no member is active), and `_value_{oneof name}`, which holds the value
of the currently-active member (or `None`).

This avoids looping through field specs in order to do an update or to
figure out which member of a `oneof` is currently active.

Since the `WhichOneof()` method is now a trivial read of a
similarly-named attribute, it can be inlined for a small decrease in
overall code size and without sacrificing readability.

As a result of these changes, the compiler now runs 4.5% faster on my
large test `.emb`.
  • Loading branch information
reventlov committed Oct 5, 2024
1 parent 886cfbb commit f90c011
Show file tree
Hide file tree
Showing 17 changed files with 140 additions and 152 deletions.
40 changes: 20 additions & 20 deletions compiler/back_end/cpp/header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,19 +593,19 @@ def _builtin_function_name(function):

def _cpp_basic_type_for_expression_type(expression_type, ir):
"""Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
if expression_type.WhichOneof("type") == "integer":
if expression_type.which_type == "integer":
return _cpp_integer_type_for_range(
int(expression_type.integer.minimum_value),
int(expression_type.integer.maximum_value),
)
elif expression_type.WhichOneof("type") == "boolean":
elif expression_type.which_type == "boolean":
return "bool"
elif expression_type.WhichOneof("type") == "enumeration":
elif expression_type.which_type == "enumeration":
return _get_fully_qualified_name(
expression_type.enumeration.name.canonical_name, ir
)
else:
assert False, "Unknown expression type " + expression_type.WhichOneof("type")
assert False, "Unknown expression type " + expression_type.which_type


def _cpp_basic_type_for_expression(expression, ir):
Expand Down Expand Up @@ -668,12 +668,12 @@ def _render_builtin_operation(expression, ir, field_reader, subexpressions):
enum_types = set()
have_boolean_types = False
for subexpression in [expression] + list(args):
if subexpression.type.WhichOneof("type") == "integer":
if subexpression.type.which_type == "integer":
minimum_integers.append(int(subexpression.type.integer.minimum_value))
maximum_integers.append(int(subexpression.type.integer.maximum_value))
elif subexpression.type.WhichOneof("type") == "enumeration":
elif subexpression.type.which_type == "enumeration":
enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
elif subexpression.type.WhichOneof("type") == "boolean":
elif subexpression.type.which_type == "boolean":
have_boolean_types = True
# At present, all Emboss functions other than `$has` take and return one of
# the following:
Expand Down Expand Up @@ -821,39 +821,39 @@ def _render_expression(expression, ir, field_reader=None, subexpressions=None):
# will fit into C++ types, or that operator arguments and return types can fit
# in the same type: expressions like `-0x8000_0000_0000_0000` and
# `0x1_0000_0000_0000_0000 - 1` can appear.
if expression.type.WhichOneof("type") == "integer":
if expression.type.which_type == "integer":
if expression.type.integer.modulus == "infinity":
return _ExpressionResult(
_render_integer_for_expression(
int(expression.type.integer.modular_value)
),
True,
)
elif expression.type.WhichOneof("type") == "boolean":
elif expression.type.which_type == "boolean":
if expression.type.boolean.HasField("value"):
if expression.type.boolean.value:
return _ExpressionResult(_maybe_type("bool") + "(true)", True)
else:
return _ExpressionResult(_maybe_type("bool") + "(false)", True)
elif expression.type.WhichOneof("type") == "enumeration":
elif expression.type.which_type == "enumeration":
if expression.type.enumeration.HasField("value"):
return _ExpressionResult(
_render_enum_value(expression.type.enumeration, ir), True
)
else:
# There shouldn't be any "opaque" type expressions here.
assert False, "Unhandled expression type {}".format(
expression.type.WhichOneof("type")
expression.type.which_type
)

result = None
# Otherwise, render the operation.
if expression.WhichOneof("expression") == "function":
if expression.which_expression == "function":
result = _render_builtin_operation(expression, ir, field_reader, subexpressions)
elif expression.WhichOneof("expression") == "field_reference":
elif expression.which_expression == "field_reference":
result = field_reader.render_field(expression, ir, subexpressions)
elif (
expression.WhichOneof("expression") == "builtin_reference"
expression.which_expression == "builtin_reference"
and expression.builtin_reference.canonical_name.object_path[-1]
== "$logical_value"
):
Expand Down Expand Up @@ -983,7 +983,7 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir)
definitions should be placed after the class definition. These are
separated to satisfy C++'s declaration-before-use requirements.
"""
if field_ir.write_method.WhichOneof("method") == "alias":
if field_ir.write_method.which_method == "alias":
return _generate_field_indirection(field_ir, enclosing_type_name, ir)

read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
Expand Down Expand Up @@ -1012,7 +1012,7 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir)
_TEMPLATES.structure_single_virtual_field_method_definitions
)

if field_ir.write_method.WhichOneof("method") == "transform":
if field_ir.write_method.which_method == "transform":
destination = _render_variable(
ir_util.hashable_form_of_field_reference(
field_ir.write_method.transform.destination
Expand Down Expand Up @@ -1043,15 +1043,15 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir)
assert logical_type, "Could not find appropriate C++ type for {}".format(
field_ir.read_transform
)
if field_ir.read_transform.type.WhichOneof("type") == "integer":
if field_ir.read_transform.type.which_type == "integer":
write_to_text_stream_function = "WriteIntegerViewToTextStream"
elif field_ir.read_transform.type.WhichOneof("type") == "boolean":
elif field_ir.read_transform.type.which_type == "boolean":
write_to_text_stream_function = "WriteBooleanViewToTextStream"
elif field_ir.read_transform.type.WhichOneof("type") == "enumeration":
elif field_ir.read_transform.type.which_type == "enumeration":
write_to_text_stream_function = "WriteEnumViewToTextStream"
else:
assert False, "Unexpected read-only virtual field type {}".format(
field_ir.read_transform.type.WhichOneof("type")
field_ir.read_transform.type.which_type
)

value_is_ok = _generate_validator_expression_for(field_ir, ir)
Expand Down
2 changes: 1 addition & 1 deletion compiler/front_end/attribute_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def _verify_requires_attribute_on_field(field, source_file_name, ir, errors):
field_expression_type = type_check.unbounded_expression_type_for_physical_type(
field_type
)
if field_expression_type.WhichOneof("type") not in (
if field_expression_type.which_type not in (
"integer",
"enumeration",
"boolean",
Expand Down
22 changes: 9 additions & 13 deletions compiler/front_end/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def _render_atomic_type_name(type_ir, ir, suffix=None):

def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, errors):
"""Checks that inner array dimensions are constant."""
if type_ir.WhichOneof("size") == "automatic":
if type_ir.which_size == "automatic":
errors.append(
[
error.error(
Expand All @@ -61,7 +61,7 @@ def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, e
)
]
)
elif type_ir.WhichOneof("size") == "element_count":
elif type_ir.which_size == "element_count":
if not ir_util.is_constant(type_ir.element_count):
errors.append(
[
Expand Down Expand Up @@ -140,7 +140,7 @@ def _check_that_array_base_types_in_structs_are_multiples_of_bytes(

def _check_constancy_of_constant_references(expression, source_file_name, errors, ir):
"""Checks that constant_references are constant."""
if expression.WhichOneof("expression") != "constant_reference":
if expression.which_expression != "constant_reference":
return
# This is a bit of a hack: really, we want to know that the referred-to object
# has no dependencies on any instance variables of its parent structure; i.e.,
Expand Down Expand Up @@ -326,7 +326,7 @@ def _check_type_requirements_for_parameter_type(
physical_type = runtime_parameter.physical_type_alias
logical_type = runtime_parameter.type
size = ir_util.constant_value(physical_type.size_in_bits)
if logical_type.WhichOneof("type") == "integer":
if logical_type.which_type == "integer":
integer_errors = _integer_bounds_errors(
logical_type.integer,
"parameter",
Expand All @@ -345,7 +345,7 @@ def _check_type_requirements_for_parameter_type(
source_file_name,
)
)
elif logical_type.WhichOneof("type") == "enumeration":
elif logical_type.which_type == "enumeration":
if physical_type.HasField("size_in_bits"):
# This seems a little weird: for `UInt`, `Int`, etc., the explicit size is
# required, but for enums it is banned. This is because enums have a
Expand Down Expand Up @@ -567,17 +567,15 @@ def _bounds_can_fit_any_64_bit_integer_type(minimum, maximum):
def _integer_bounds_errors_for_expression(expression, source_file_name):
"""Checks that `expression` is in range for int64_t or uint64_t."""
# Only check non-constant subexpressions.
if expression.WhichOneof(
"expression"
) == "function" and not ir_util.is_constant_type(expression.type):
if expression.which_expression == "function" and not ir_util.is_constant_type(expression.type):
errors = []
for arg in expression.function.args:
errors += _integer_bounds_errors_for_expression(arg, source_file_name)
if errors:
# Don't cascade bounds errors: report them at the lowest level they
# appear.
return errors
if expression.type.WhichOneof("type") == "integer":
if expression.type.which_type == "integer":
errors = _integer_bounds_errors(
expression.type.integer,
"expression",
Expand All @@ -586,13 +584,11 @@ def _integer_bounds_errors_for_expression(expression, source_file_name):
)
if errors:
return errors
if expression.WhichOneof(
"expression"
) == "function" and not ir_util.is_constant_type(expression.type):
if expression.which_expression == "function" and not ir_util.is_constant_type(expression.type):
int64_only_clauses = []
uint64_only_clauses = []
for clause in [expression] + list(expression.function.args):
if clause.type.WhichOneof("type") == "integer":
if clause.type.which_type == "integer":
arg_minimum = int(clause.type.integer.minimum_value)
arg_maximum = int(clause.type.integer.maximum_value)
if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum):
Expand Down
24 changes: 12 additions & 12 deletions compiler/front_end/expression_bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def compute_constraints_of_expression(expression, ir):
"""Adds appropriate bounding constraints to the given expression."""
if ir_util.is_constant_type(expression.type):
return
expression_variety = expression.WhichOneof("expression")
expression_variety = expression.which_expression
if expression_variety == "constant":
_compute_constant_value_of_constant(expression)
elif expression_variety == "constant_reference":
Expand All @@ -51,7 +51,7 @@ def compute_constraints_of_expression(expression, ir):
_compute_constant_value_of_boolean_constant(expression)
else:
assert False, "Unknown expression variety {!r}".format(expression_variety)
if expression.type.WhichOneof("type") == "integer":
if expression.type.which_type == "integer":
_assert_integer_constraints(expression)


Expand Down Expand Up @@ -138,7 +138,7 @@ def _compute_constraints_of_field_reference(expression, ir):
ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
return
# Non-virtual non-integer fields do not (yet) have constraints.
if expression.type.WhichOneof("type") == "integer":
if expression.type.which_type == "integer":
# TODO(bolms): These lines will need to change when support is added for
# fixed-point types.
expression.type.integer.modulus = "1"
Expand Down Expand Up @@ -205,7 +205,7 @@ def _set_integer_constraints_from_physical_type(expression, physical_type, type_


def _compute_constraints_of_parameter(parameter):
if parameter.type.WhichOneof("type") == "integer":
if parameter.type.which_type == "integer":
type_size = ir_util.constant_value(parameter.physical_type_alias.size_in_bits)
_set_integer_constraints_from_physical_type(
parameter, parameter.physical_type_alias, type_size
Expand Down Expand Up @@ -238,14 +238,14 @@ def _compute_constraints_of_builtin_value(expression):
# [requires] attribute, are elevated to write-through fields, so that the
# [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
# Read, and Ok.
if expression.type.WhichOneof("type") == "integer":
if expression.type.which_type == "integer":
assert expression.type.integer.modulus
assert expression.type.integer.modular_value
assert expression.type.integer.minimum_value
assert expression.type.integer.maximum_value
elif expression.type.WhichOneof("type") == "enumeration":
elif expression.type.which_type == "enumeration":
assert expression.type.enumeration.name
elif expression.type.WhichOneof("type") == "boolean":
elif expression.type.which_type == "boolean":
pass
else:
assert False, "Unexpected type for $logical_value"
Expand Down Expand Up @@ -559,9 +559,9 @@ def _compute_constraints_of_bound_function(expression):

def _compute_constraints_of_maximum_function(expression):
"""Computes the constraints of the $max function."""
assert expression.type.WhichOneof("type") == "integer"
assert expression.type.which_type == "integer"
args = expression.function.args
assert args[0].type.WhichOneof("type") == "integer"
assert args[0].type.which_type == "integer"
# The minimum value of the result occurs when every argument takes its minimum
# value, which means that the minimum result is the maximum-of-minimums.
expression.type.integer.minimum_value = str(
Expand Down Expand Up @@ -683,7 +683,7 @@ def _compute_constraints_of_choice_operator(expression):
# constraints.check_constraints() will complain if minimum and maximum are not
# set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its
# weight, but for completeness I've left it in.
if if_true.type.WhichOneof("type") == "integer":
if if_true.type.which_type == "integer":
# The minimum value of the choice is the minimum value of either side, and
# the maximum is the maximum value of either side.
expression.type.integer.minimum_value = str(
Expand All @@ -709,10 +709,10 @@ def _compute_constraints_of_choice_operator(expression):
expression.type.integer.modulus = str(new_modulus)
expression.type.integer.modular_value = str(new_modular_value)
else:
assert if_true.type.WhichOneof("type") in (
assert if_true.type.which_type in (
"boolean",
"enumeration",
), "Unknown type {} for expression".format(if_true.type.WhichOneof("type"))
), "Unknown type {} for expression".format(if_true.type.which_type)


def _greatest_common_divisor(a, b):
Expand Down
2 changes: 1 addition & 1 deletion compiler/front_end/expression_bounds_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def test_choice_non_integer_arguments(self):
)
self.assertEqual([], expression_bounds.compute_constants(ir))
expr = ir.module[0].type[0].structure.field[1].existence_condition
self.assertEqual("boolean", expr.type.WhichOneof("type"))
self.assertEqual("boolean", expr.type.which_type)
self.assertFalse(expr.type.boolean.HasField("value"))

def test_uint_value_range_for_explicit_size(self):
Expand Down
10 changes: 5 additions & 5 deletions compiler/front_end/module_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,7 +811,7 @@ def _bottom_expression_from_reference(reference):

@_handles("field-reference -> snake-reference field-reference-tail*")
def _indirect_field_reference(field_reference, field_references):
if field_references.source_location.HasField("end"):
if field_references.source_location.end is not None:
end_location = field_references.source_location.end
else:
end_location = field_reference.source_location.end
Expand Down Expand Up @@ -1097,7 +1097,7 @@ def _field(
if abbreviation.list:
field.abbreviation.CopyFrom(abbreviation.list[0])
field.source_location.start.CopyFrom(location.source_location.start)
if field_body.source_location.HasField("end"):
if field_body.source_location.end is not None:
field.source_location.end.CopyFrom(field_body.source_location.end)
else:
field.source_location.end.CopyFrom(newline.source_location.end)
Expand All @@ -1122,7 +1122,7 @@ def _virtual_field(let, name, equals, value, comment, newline, field_body):
field.attribute.extend(field_body.list[0].attribute)
field.documentation.extend(field_body.list[0].documentation)
field.source_location.start.CopyFrom(let.source_location.start)
if field_body.source_location.HasField("end"):
if field_body.source_location.end is not None:
field.source_location.end.CopyFrom(field_body.source_location.end)
else:
field.source_location.end.CopyFrom(newline.source_location.end)
Expand Down Expand Up @@ -1202,12 +1202,12 @@ def _inline_type_field(location, name, abbreviation, body):
ir_data_utils.builder(body.source_location).start.CopyFrom(
location.source_location.start
)
if body.HasField("enumeration"):
if body.enumeration is not None:
ir_data_utils.builder(body.enumeration).source_location.CopyFrom(
body.source_location
)
else:
assert body.HasField("structure")
assert body.structure is not None
ir_data_utils.builder(body.structure).source_location.CopyFrom(
body.source_location
)
Expand Down
6 changes: 3 additions & 3 deletions compiler/front_end/symbol_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir):
for ref in field_reference.path[1:]:
while ir_util.field_is_virtual(previous_field):
if (
previous_field.read_transform.WhichOneof("expression")
previous_field.read_transform.which_expression
== "field_reference"
):
# Pass a separate error list into the recursive _resolve_field_reference
Expand Down Expand Up @@ -494,7 +494,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir):
)
)
return
if previous_field.type.WhichOneof("type") == "array_type":
if previous_field.type.which_type == "array_type":
errors.append(
array_subfield_error(
source_file_name,
Expand All @@ -503,7 +503,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir):
)
)
return
assert previous_field.type.WhichOneof("type") == "atomic_type"
assert previous_field.type.which_type == "atomic_type"
member_name = ir_data_utils.copy(
previous_field.type.atomic_type.reference.canonical_name
)
Expand Down
2 changes: 1 addition & 1 deletion compiler/front_end/symbol_resolver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_symbol_resolution_in_expression_in_void_array_length(self):
struct_ir = ir.module[0].type[4].structure
array_type = struct_ir.field[0].type.array_type
# The symbol resolver should ignore void fields.
self.assertEqual("automatic", array_type.WhichOneof("size"))
self.assertEqual("automatic", array_type.which_size)

def test_name_definitions_have_correct_canonical_names(self):
ir = self._construct_ir(_HAPPY_EMB)
Expand Down
Loading

0 comments on commit f90c011

Please sign in to comment.