diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py index 6a66461..7c6e56d 100644 --- a/compiler/back_end/cpp/header_generator.py +++ b/compiler/back_end/cpp/header_generator.py @@ -593,19 +593,19 @@ def _builtin_function_name(function): def _cpp_basic_type_for_expression_type(expression_type, ir): """Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType.""" - if expression_type.WhichOneof("type") == "integer": + if expression_type.which_type == "integer": return _cpp_integer_type_for_range( int(expression_type.integer.minimum_value), int(expression_type.integer.maximum_value), ) - elif expression_type.WhichOneof("type") == "boolean": + elif expression_type.which_type == "boolean": return "bool" - elif expression_type.WhichOneof("type") == "enumeration": + elif expression_type.which_type == "enumeration": return _get_fully_qualified_name( expression_type.enumeration.name.canonical_name, ir ) else: - assert False, "Unknown expression type " + expression_type.WhichOneof("type") + assert False, "Unknown expression type " + expression_type.which_type def _cpp_basic_type_for_expression(expression, ir): @@ -668,12 +668,12 @@ def _render_builtin_operation(expression, ir, field_reader, subexpressions): enum_types = set() have_boolean_types = False for subexpression in [expression] + list(args): - if subexpression.type.WhichOneof("type") == "integer": + if subexpression.type.which_type == "integer": minimum_integers.append(int(subexpression.type.integer.minimum_value)) maximum_integers.append(int(subexpression.type.integer.maximum_value)) - elif subexpression.type.WhichOneof("type") == "enumeration": + elif subexpression.type.which_type == "enumeration": enum_types.add(_cpp_basic_type_for_expression(subexpression, ir)) - elif subexpression.type.WhichOneof("type") == "boolean": + elif subexpression.type.which_type == "boolean": have_boolean_types = True # At present, all Emboss functions other than `$has` take and return one of # the following: @@ -821,7 +821,7 @@ def _render_expression(expression, ir, field_reader=None, subexpressions=None): # will fit into C++ types, or that operator arguments and return types can fit # in the same type: expressions like `-0x8000_0000_0000_0000` and # `0x1_0000_0000_0000_0000 - 1` can appear. - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": if expression.type.integer.modulus == "infinity": return _ExpressionResult( _render_integer_for_expression( @@ -829,13 +829,13 @@ def _render_expression(expression, ir, field_reader=None, subexpressions=None): ), True, ) - elif expression.type.WhichOneof("type") == "boolean": + elif expression.type.which_type == "boolean": if expression.type.boolean.HasField("value"): if expression.type.boolean.value: return _ExpressionResult(_maybe_type("bool") + "(true)", True) else: return _ExpressionResult(_maybe_type("bool") + "(false)", True) - elif expression.type.WhichOneof("type") == "enumeration": + elif expression.type.which_type == "enumeration": if expression.type.enumeration.HasField("value"): return _ExpressionResult( _render_enum_value(expression.type.enumeration, ir), True @@ -843,17 +843,17 @@ def _render_expression(expression, ir, field_reader=None, subexpressions=None): else: # There shouldn't be any "opaque" type expressions here. assert False, "Unhandled expression type {}".format( - expression.type.WhichOneof("type") + expression.type.which_type ) result = None # Otherwise, render the operation. - if expression.WhichOneof("expression") == "function": + if expression.which_expression == "function": result = _render_builtin_operation(expression, ir, field_reader, subexpressions) - elif expression.WhichOneof("expression") == "field_reference": + elif expression.which_expression == "field_reference": result = field_reader.render_field(expression, ir, subexpressions) elif ( - expression.WhichOneof("expression") == "builtin_reference" + expression.which_expression == "builtin_reference" and expression.builtin_reference.canonical_name.object_path[-1] == "$logical_value" ): @@ -983,7 +983,7 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir) definitions should be placed after the class definition. These are separated to satisfy C++'s declaration-before-use requirements. """ - if field_ir.write_method.WhichOneof("method") == "alias": + if field_ir.write_method.which_method == "alias": return _generate_field_indirection(field_ir, enclosing_type_name, ir) read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_") @@ -1012,7 +1012,7 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir) _TEMPLATES.structure_single_virtual_field_method_definitions ) - if field_ir.write_method.WhichOneof("method") == "transform": + if field_ir.write_method.which_method == "transform": destination = _render_variable( ir_util.hashable_form_of_field_reference( field_ir.write_method.transform.destination @@ -1043,15 +1043,15 @@ def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir) assert logical_type, "Could not find appropriate C++ type for {}".format( field_ir.read_transform ) - if field_ir.read_transform.type.WhichOneof("type") == "integer": + if field_ir.read_transform.type.which_type == "integer": write_to_text_stream_function = "WriteIntegerViewToTextStream" - elif field_ir.read_transform.type.WhichOneof("type") == "boolean": + elif field_ir.read_transform.type.which_type == "boolean": write_to_text_stream_function = "WriteBooleanViewToTextStream" - elif field_ir.read_transform.type.WhichOneof("type") == "enumeration": + elif field_ir.read_transform.type.which_type == "enumeration": write_to_text_stream_function = "WriteEnumViewToTextStream" else: assert False, "Unexpected read-only virtual field type {}".format( - field_ir.read_transform.type.WhichOneof("type") + field_ir.read_transform.type.which_type ) value_is_ok = _generate_validator_expression_for(field_ir, ir) diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py index 9c867d2..9e6ed50 100644 --- a/compiler/front_end/attribute_checker.py +++ b/compiler/front_end/attribute_checker.py @@ -441,7 +441,7 @@ def _verify_requires_attribute_on_field(field, source_file_name, ir, errors): field_expression_type = type_check.unbounded_expression_type_for_physical_type( field_type ) - if field_expression_type.WhichOneof("type") not in ( + if field_expression_type.which_type not in ( "integer", "enumeration", "boolean", diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py index 852c9ad..3753ffb 100644 --- a/compiler/front_end/constraints.py +++ b/compiler/front_end/constraints.py @@ -51,7 +51,7 @@ def _render_atomic_type_name(type_ir, ir, suffix=None): def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, errors): """Checks that inner array dimensions are constant.""" - if type_ir.WhichOneof("size") == "automatic": + if type_ir.which_size == "automatic": errors.append( [ error.error( @@ -61,7 +61,7 @@ def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, e ) ] ) - elif type_ir.WhichOneof("size") == "element_count": + elif type_ir.which_size == "element_count": if not ir_util.is_constant(type_ir.element_count): errors.append( [ @@ -140,7 +140,7 @@ def _check_that_array_base_types_in_structs_are_multiples_of_bytes( def _check_constancy_of_constant_references(expression, source_file_name, errors, ir): """Checks that constant_references are constant.""" - if expression.WhichOneof("expression") != "constant_reference": + if expression.which_expression != "constant_reference": return # This is a bit of a hack: really, we want to know that the referred-to object # has no dependencies on any instance variables of its parent structure; i.e., @@ -326,7 +326,7 @@ def _check_type_requirements_for_parameter_type( physical_type = runtime_parameter.physical_type_alias logical_type = runtime_parameter.type size = ir_util.constant_value(physical_type.size_in_bits) - if logical_type.WhichOneof("type") == "integer": + if logical_type.which_type == "integer": integer_errors = _integer_bounds_errors( logical_type.integer, "parameter", @@ -345,7 +345,7 @@ def _check_type_requirements_for_parameter_type( source_file_name, ) ) - elif logical_type.WhichOneof("type") == "enumeration": + elif logical_type.which_type == "enumeration": if physical_type.HasField("size_in_bits"): # This seems a little weird: for `UInt`, `Int`, etc., the explicit size is # required, but for enums it is banned. This is because enums have a @@ -567,9 +567,7 @@ def _bounds_can_fit_any_64_bit_integer_type(minimum, maximum): def _integer_bounds_errors_for_expression(expression, source_file_name): """Checks that `expression` is in range for int64_t or uint64_t.""" # Only check non-constant subexpressions. - if expression.WhichOneof( - "expression" - ) == "function" and not ir_util.is_constant_type(expression.type): + if expression.which_expression == "function" and not ir_util.is_constant_type(expression.type): errors = [] for arg in expression.function.args: errors += _integer_bounds_errors_for_expression(arg, source_file_name) @@ -577,7 +575,7 @@ def _integer_bounds_errors_for_expression(expression, source_file_name): # Don't cascade bounds errors: report them at the lowest level they # appear. return errors - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": errors = _integer_bounds_errors( expression.type.integer, "expression", @@ -586,13 +584,11 @@ def _integer_bounds_errors_for_expression(expression, source_file_name): ) if errors: return errors - if expression.WhichOneof( - "expression" - ) == "function" and not ir_util.is_constant_type(expression.type): + if expression.which_expression == "function" and not ir_util.is_constant_type(expression.type): int64_only_clauses = [] uint64_only_clauses = [] for clause in [expression] + list(expression.function.args): - if clause.type.WhichOneof("type") == "integer": + if clause.type.which_type == "integer": arg_minimum = int(clause.type.integer.minimum_value) arg_maximum = int(clause.type.integer.maximum_value) if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum): diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py index cca36ee..51efca5 100644 --- a/compiler/front_end/expression_bounds.py +++ b/compiler/front_end/expression_bounds.py @@ -36,7 +36,7 @@ def compute_constraints_of_expression(expression, ir): """Adds appropriate bounding constraints to the given expression.""" if ir_util.is_constant_type(expression.type): return - expression_variety = expression.WhichOneof("expression") + expression_variety = expression.which_expression if expression_variety == "constant": _compute_constant_value_of_constant(expression) elif expression_variety == "constant_reference": @@ -51,7 +51,7 @@ def compute_constraints_of_expression(expression, ir): _compute_constant_value_of_boolean_constant(expression) else: assert False, "Unknown expression variety {!r}".format(expression_variety) - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": _assert_integer_constraints(expression) @@ -138,7 +138,7 @@ def _compute_constraints_of_field_reference(expression, ir): ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) return # Non-virtual non-integer fields do not (yet) have constraints. - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": # TODO(bolms): These lines will need to change when support is added for # fixed-point types. expression.type.integer.modulus = "1" @@ -205,7 +205,7 @@ def _set_integer_constraints_from_physical_type(expression, physical_type, type_ def _compute_constraints_of_parameter(parameter): - if parameter.type.WhichOneof("type") == "integer": + if parameter.type.which_type == "integer": type_size = ir_util.constant_value(parameter.physical_type_alias.size_in_bits) _set_integer_constraints_from_physical_type( parameter, parameter.physical_type_alias, type_size @@ -238,14 +238,14 @@ def _compute_constraints_of_builtin_value(expression): # [requires] attribute, are elevated to write-through fields, so that the # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite, # Read, and Ok. - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": assert expression.type.integer.modulus assert expression.type.integer.modular_value assert expression.type.integer.minimum_value assert expression.type.integer.maximum_value - elif expression.type.WhichOneof("type") == "enumeration": + elif expression.type.which_type == "enumeration": assert expression.type.enumeration.name - elif expression.type.WhichOneof("type") == "boolean": + elif expression.type.which_type == "boolean": pass else: assert False, "Unexpected type for $logical_value" @@ -559,9 +559,9 @@ def _compute_constraints_of_bound_function(expression): def _compute_constraints_of_maximum_function(expression): """Computes the constraints of the $max function.""" - assert expression.type.WhichOneof("type") == "integer" + assert expression.type.which_type == "integer" args = expression.function.args - assert args[0].type.WhichOneof("type") == "integer" + assert args[0].type.which_type == "integer" # The minimum value of the result occurs when every argument takes its minimum # value, which means that the minimum result is the maximum-of-minimums. expression.type.integer.minimum_value = str( @@ -683,7 +683,7 @@ def _compute_constraints_of_choice_operator(expression): # constraints.check_constraints() will complain if minimum and maximum are not # set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its # weight, but for completeness I've left it in. - if if_true.type.WhichOneof("type") == "integer": + if if_true.type.which_type == "integer": # The minimum value of the choice is the minimum value of either side, and # the maximum is the maximum value of either side. expression.type.integer.minimum_value = str( @@ -709,10 +709,10 @@ def _compute_constraints_of_choice_operator(expression): expression.type.integer.modulus = str(new_modulus) expression.type.integer.modular_value = str(new_modular_value) else: - assert if_true.type.WhichOneof("type") in ( + assert if_true.type.which_type in ( "boolean", "enumeration", - ), "Unknown type {} for expression".format(if_true.type.WhichOneof("type")) + ), "Unknown type {} for expression".format(if_true.type.which_type) def _greatest_common_divisor(a, b): diff --git a/compiler/front_end/expression_bounds_test.py b/compiler/front_end/expression_bounds_test.py index 7af6836..e5bc25d 100644 --- a/compiler/front_end/expression_bounds_test.py +++ b/compiler/front_end/expression_bounds_test.py @@ -1007,7 +1007,7 @@ def test_choice_non_integer_arguments(self): ) self.assertEqual([], expression_bounds.compute_constants(ir)) expr = ir.module[0].type[0].structure.field[1].existence_condition - self.assertEqual("boolean", expr.type.WhichOneof("type")) + self.assertEqual("boolean", expr.type.which_type) self.assertFalse(expr.type.boolean.HasField("value")) def test_uint_value_range_for_explicit_size(self): diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py index 4a459c2..f54b953 100644 --- a/compiler/front_end/module_ir.py +++ b/compiler/front_end/module_ir.py @@ -811,7 +811,7 @@ def _bottom_expression_from_reference(reference): @_handles("field-reference -> snake-reference field-reference-tail*") def _indirect_field_reference(field_reference, field_references): - if field_references.source_location.HasField("end"): + if field_references.source_location.end is not None: end_location = field_references.source_location.end else: end_location = field_reference.source_location.end @@ -1097,7 +1097,7 @@ def _field( if abbreviation.list: field.abbreviation.CopyFrom(abbreviation.list[0]) field.source_location.start.CopyFrom(location.source_location.start) - if field_body.source_location.HasField("end"): + if field_body.source_location.end is not None: field.source_location.end.CopyFrom(field_body.source_location.end) else: field.source_location.end.CopyFrom(newline.source_location.end) @@ -1122,7 +1122,7 @@ def _virtual_field(let, name, equals, value, comment, newline, field_body): field.attribute.extend(field_body.list[0].attribute) field.documentation.extend(field_body.list[0].documentation) field.source_location.start.CopyFrom(let.source_location.start) - if field_body.source_location.HasField("end"): + if field_body.source_location.end is not None: field.source_location.end.CopyFrom(field_body.source_location.end) else: field.source_location.end.CopyFrom(newline.source_location.end) @@ -1202,12 +1202,12 @@ def _inline_type_field(location, name, abbreviation, body): ir_data_utils.builder(body.source_location).start.CopyFrom( location.source_location.start ) - if body.HasField("enumeration"): + if body.enumeration is not None: ir_data_utils.builder(body.enumeration).source_location.CopyFrom( body.source_location ) else: - assert body.HasField("structure") + assert body.structure is not None ir_data_utils.builder(body.structure).source_location.CopyFrom( body.source_location ) diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py index 498b1a9..53f2baa 100644 --- a/compiler/front_end/symbol_resolver.py +++ b/compiler/front_end/symbol_resolver.py @@ -460,7 +460,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir): for ref in field_reference.path[1:]: while ir_util.field_is_virtual(previous_field): if ( - previous_field.read_transform.WhichOneof("expression") + previous_field.read_transform.which_expression == "field_reference" ): # Pass a separate error list into the recursive _resolve_field_reference @@ -494,7 +494,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir): ) ) return - if previous_field.type.WhichOneof("type") == "array_type": + if previous_field.type.which_type == "array_type": errors.append( array_subfield_error( source_file_name, @@ -503,7 +503,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir): ) ) return - assert previous_field.type.WhichOneof("type") == "atomic_type" + assert previous_field.type.which_type == "atomic_type" member_name = ir_data_utils.copy( previous_field.type.atomic_type.reference.canonical_name ) diff --git a/compiler/front_end/symbol_resolver_test.py b/compiler/front_end/symbol_resolver_test.py index 693d157..ddf7783 100644 --- a/compiler/front_end/symbol_resolver_test.py +++ b/compiler/front_end/symbol_resolver_test.py @@ -179,7 +179,7 @@ def test_symbol_resolution_in_expression_in_void_array_length(self): struct_ir = ir.module[0].type[4].structure array_type = struct_ir.field[0].type.array_type # The symbol resolver should ignore void fields. - self.assertEqual("automatic", array_type.WhichOneof("size")) + self.assertEqual("automatic", array_type.which_size) def test_name_definitions_have_correct_canonical_names(self): ir = self._construct_ir(_HAPPY_EMB) diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py index c3f7870..0bfe916 100644 --- a/compiler/front_end/type_check.py +++ b/compiler/front_end/type_check.py @@ -24,10 +24,10 @@ def _type_check_expression(expression, source_file_name, ir, errors): """Checks and annotates the type of an expression and all subexpressions.""" - if ir_data_utils.reader(expression).type.WhichOneof("type"): + if ir_data_utils.reader(expression).type.which_type: # This expression has already been type checked. return - expression_variety = expression.WhichOneof("expression") + expression_variety = expression.which_expression if expression_variety == "constant": _type_check_integer_constant(expression) elif expression_variety == "constant_reference": @@ -55,7 +55,7 @@ def _annotate_as_boolean(expression): def _type_check( expression, source_file_name, errors, type_oneof, type_name, expression_name ): - if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof: + if ir_data_utils.reader(expression).type.which_type != type_oneof: errors.append( [ error.error( @@ -80,7 +80,7 @@ def _type_check_boolean(expression, source_file_name, errors, expression_name): def _kind_check_field_reference(expression, source_file_name, errors, expression_name): - if expression.WhichOneof("expression") != "field_reference": + if expression.which_expression != "field_reference": errors.append( [ error.error( @@ -286,7 +286,7 @@ def _type_check_local_reference(expression, ir, errors): ) ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type) return - if not field.type.HasField("atomic_type"): + if field.type.atomic_type is None: ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType()) else: _set_expression_type_from_physical_type_reference( @@ -313,7 +313,7 @@ def unbounded_expression_type_for_physical_type(type_definition): elif tuple(type_definition.name.canonical_name.object_path) == ("Flag",): # This is a hack: the Flag type should say that it is a boolean. return ir_data.ExpressionType(boolean=ir_data.BooleanType()) - elif type_definition.HasField("enumeration"): + elif type_definition.enumeration is not None: return ir_data.ExpressionType( enumeration=ir_data.EnumType( name=ir_data.Reference( @@ -335,7 +335,7 @@ def _set_expression_type_from_physical_type_reference(expression, type_reference def _annotate_parameter_type(parameter, ir, source_file_name, errors): - if parameter.physical_type_alias.WhichOneof("type") != "atomic_type": + if parameter.physical_type_alias.which_type != "atomic_type": errors.append( [ error.error( @@ -353,13 +353,13 @@ def _annotate_parameter_type(parameter, ir, source_file_name, errors): def _types_are_compatible(a, b): """Returns true if a and b have compatible types.""" - if a.type.WhichOneof("type") != b.type.WhichOneof("type"): + if a.type.which_type != b.type.which_type: return False - elif a.type.WhichOneof("type") == "enumeration": + elif a.type.which_type == "enumeration": return ir_util.hashable_form_of_reference( a.type.enumeration.name ) == ir_util.hashable_form_of_reference(b.type.enumeration.name) - elif a.type.WhichOneof("type") in ("integer", "boolean"): + elif a.type.which_type in ("integer", "boolean"): # All integers are compatible with integers; booleans are compatible with # booleans return True @@ -383,7 +383,7 @@ def _type_check_comparison_operator(expression, source_file_name, errors): left = expression.function.args[0] right = expression.function.args[1] for argument, name in ((left, "Left"), (right, "Right")): - if argument.type.WhichOneof("type") not in acceptable_types: + if argument.type.which_type not in acceptable_types: errors.append( [ error.error( @@ -415,7 +415,7 @@ def _type_check_comparison_operator(expression, source_file_name, errors): def _type_check_choice_operator(expression, source_file_name, errors): """Checks the type of the choice operator cond ? if_true : if_false.""" condition = expression.function.args[0] - if condition.type.WhichOneof("type") != "boolean": + if condition.type.which_type != "boolean": errors.append( [ error.error( @@ -426,7 +426,7 @@ def _type_check_choice_operator(expression, source_file_name, errors): ] ) if_true = expression.function.args[1] - if if_true.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"): + if if_true.type.which_type not in ("integer", "boolean", "enumeration"): errors.append( [ error.error( @@ -450,11 +450,11 @@ def _type_check_choice_operator(expression, source_file_name, errors): ) ] ) - if if_true.type.WhichOneof("type") == "integer": + if if_true.type.which_type == "integer": _annotate_as_integer(expression) - elif if_true.type.WhichOneof("type") == "boolean": + elif if_true.type.which_type == "boolean": _annotate_as_boolean(expression) - elif if_true.type.WhichOneof("type") == "enumeration": + elif if_true.type.which_type == "enumeration": ir_data_utils.builder(expression).type.enumeration.name.CopyFrom( if_true.type.enumeration.name ) @@ -492,9 +492,9 @@ def _type_check_field_existence_condition(field, source_file_name, errors): def _type_name_for_error_messages(expression_type): - if expression_type.WhichOneof("type") == "integer": + if expression_type.which_type == "integer": return "integer" - elif expression_type.WhichOneof("type") == "enumeration": + elif expression_type.which_type == "enumeration": # TODO(bolms): Should this be the fully-qualified name? return expression_type.enumeration.name.canonical_name.object_path[-1] assert False, "Shouldn't be here." @@ -526,7 +526,7 @@ def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors): ) return for i in range(len(referenced_type.runtime_parameter)): - if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in ( + if referenced_type.runtime_parameter[i].type.which_type not in ( "integer", "boolean", "enumeration", @@ -535,9 +535,7 @@ def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors): # definition site; no need for another, probably-confusing error at any # usage sites. continue - if atomic_type.runtime_parameter[i].type.WhichOneof( - "type" - ) != referenced_type.runtime_parameter[i].type.WhichOneof("type"): + if atomic_type.runtime_parameter[i].type.which_type != referenced_type.runtime_parameter[i].type.which_type: errors.append( [ error.error( @@ -565,7 +563,7 @@ def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors): def _type_check_parameter(runtime_parameter, source_file_name, errors): """Checks the type of a parameter to a physical type.""" - if runtime_parameter.type.WhichOneof("type") not in ("integer", "enumeration"): + if runtime_parameter.type.which_type not in ("integer", "enumeration"): errors.append( [ error.error( diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py index 995f20d..fce5a14 100644 --- a/compiler/front_end/type_check_test.py +++ b/compiler/front_end/type_check_test.py @@ -39,7 +39,7 @@ def test_adds_integer_constant_type(self): ) self.assertEqual([], type_check.annotate_types(ir)) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "integer") + self.assertEqual(expression.type.which_type, "integer") def test_adds_boolean_constant_type(self): ir = self._make_ir( @@ -51,7 +51,7 @@ def test_adds_boolean_constant_type(self): ir_data_utils.IrDataSerializer(ir).to_json(indent=2), ) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "boolean") + self.assertEqual(expression.type.which_type, "boolean") def test_adds_enum_constant_type(self): ir = self._make_ir( @@ -62,7 +62,7 @@ def test_adds_enum_constant_type(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[0].location.size - self.assertEqual(expression.type.WhichOneof("type"), "enumeration") + self.assertEqual(expression.type.which_type, "enumeration") enum_type_name = expression.type.enumeration.name.canonical_name self.assertEqual(enum_type_name.module_file, "m.emb") self.assertEqual(enum_type_name.object_path[0], "Enum") @@ -77,7 +77,7 @@ def test_adds_enum_field_type(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "enumeration") + self.assertEqual(expression.type.which_type, "enumeration") enum_type_name = expression.type.enumeration.name.canonical_name self.assertEqual(enum_type_name.module_file, "m.emb") self.assertEqual(enum_type_name.object_path[0], "Enum") @@ -88,9 +88,9 @@ def test_adds_integer_operation_types(self): ) self.assertEqual([], type_check.annotate_types(ir)) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "integer") - self.assertEqual(expression.function.args[0].type.WhichOneof("type"), "integer") - self.assertEqual(expression.function.args[1].type.WhichOneof("type"), "integer") + self.assertEqual(expression.type.which_type, "integer") + self.assertEqual(expression.function.args[0].type.which_type, "integer") + self.assertEqual(expression.function.args[1].type.which_type, "integer") def test_adds_enum_operation_type(self): ir = self._make_ir( @@ -102,12 +102,12 @@ def test_adds_enum_operation_type(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "boolean") + self.assertEqual(expression.type.which_type, "boolean") self.assertEqual( - expression.function.args[0].type.WhichOneof("type"), "enumeration" + expression.function.args[0].type.which_type, "enumeration" ) self.assertEqual( - expression.function.args[1].type.WhichOneof("type"), "enumeration" + expression.function.args[1].type.which_type, "enumeration" ) def test_adds_enum_comparison_operation_type(self): @@ -120,12 +120,12 @@ def test_adds_enum_comparison_operation_type(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "boolean") + self.assertEqual(expression.type.which_type, "boolean") self.assertEqual( - expression.function.args[0].type.WhichOneof("type"), "enumeration" + expression.function.args[0].type.which_type, "enumeration" ) self.assertEqual( - expression.function.args[1].type.WhichOneof("type"), "enumeration" + expression.function.args[1].type.which_type, "enumeration" ) def test_adds_integer_field_type(self): @@ -134,7 +134,7 @@ def test_adds_integer_field_type(self): ) self.assertEqual([], type_check.annotate_types(ir)) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "integer") + self.assertEqual(expression.type.which_type, "integer") def test_adds_opaque_field_type(self): ir = self._make_ir( @@ -146,7 +146,7 @@ def test_adds_opaque_field_type(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "opaque") + self.assertEqual(expression.type.which_type, "opaque") def test_adds_opaque_field_type_for_array(self): ir = self._make_ir( @@ -154,7 +154,7 @@ def test_adds_opaque_field_type_for_array(self): ) self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) expression = ir.module[0].type[0].structure.field[1].location.size - self.assertEqual(expression.type.WhichOneof("type"), "opaque") + self.assertEqual(expression.type.which_type, "opaque") def test_error_on_bad_plus_operand_types(self): ir = self._make_ir( @@ -395,7 +395,7 @@ def test_choice_of_bools(self): ) expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) - self.assertEqual("boolean", expression.type.WhichOneof("type")) + self.assertEqual("boolean", expression.type.which_type) def test_choice_of_integers(self): ir = self._make_ir( @@ -405,7 +405,7 @@ def test_choice_of_integers(self): ) expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([], type_check.annotate_types(ir)) - self.assertEqual("integer", expression.type.WhichOneof("type")) + self.assertEqual("integer", expression.type.which_type) def test_choice_of_enums(self): ir = self._make_ir( @@ -417,7 +417,7 @@ def test_choice_of_enums(self): ) expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) - self.assertEqual("enumeration", expression.type.WhichOneof("type")) + self.assertEqual("enumeration", expression.type.which_type) self.assertFalse(expression.type.enumeration.HasField("value")) self.assertEqual( "m.emb", expression.type.enumeration.name.canonical_name.module_file @@ -579,7 +579,7 @@ def test_max_return_type(self): ir = self._make_ir("struct Foo:\n" " $max(1, 2, 3) [+1] UInt:8[] x\n") expression = ir.module[0].type[0].structure.field[0].location.start self.assertEqual([], type_check.annotate_types(ir)) - self.assertEqual("integer", expression.type.WhichOneof("type")) + self.assertEqual("integer", expression.type.which_type) def test_error_on_bad_max_argument(self): ir = self._make_ir( @@ -622,7 +622,7 @@ def test_upper_bound_return_type(self): ir = self._make_ir("struct Foo:\n" " $upper_bound(3) [+1] UInt:8[] x\n") expression = ir.module[0].type[0].structure.field[0].location.start self.assertEqual([], type_check.annotate_types(ir)) - self.assertEqual("integer", expression.type.WhichOneof("type")) + self.assertEqual("integer", expression.type.which_type) def test_upper_bound_too_few_arguments(self): ir = self._make_ir("struct Foo:\n" " $upper_bound() [+1] UInt:8[] x\n") @@ -681,7 +681,7 @@ def test_lower_bound_return_type(self): ir = self._make_ir("struct Foo:\n" " $lower_bound(3) [+1] UInt:8[] x\n") expression = ir.module[0].type[0].structure.field[0].location.start self.assertEqual([], type_check.annotate_types(ir)) - self.assertEqual("integer", expression.type.WhichOneof("type")) + self.assertEqual("integer", expression.type.which_type) def test_lower_bound_too_few_arguments(self): ir = self._make_ir("struct Foo:\n" " $lower_bound() [+1] UInt:8[] x\n") diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py index 8353306..bd4e2ba 100644 --- a/compiler/front_end/write_inference.py +++ b/compiler/front_end/write_inference.py @@ -51,9 +51,9 @@ def _find_field_reference_path(expression): def _recursively_find_field_reference_path(expression): """Recursive implementation of _find_field_reference_path.""" - if expression.WhichOneof("expression") == "field_reference": + if expression.which_expression == "field_reference": return 1, [] - elif expression.WhichOneof("expression") == "function": + elif expression.which_expression == "function": field_count = 0 path = [] for index in range(len(expression.function.args)): @@ -228,7 +228,7 @@ def _add_write_method(field, ir): # requirement. requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES) if ( - field_checker.read_transform.WhichOneof("expression") != "field_reference" + field_checker.read_transform.which_expression != "field_reference" or requires_attr is not None ): inverse = _invert_expression(field.read_transform, ir) diff --git a/compiler/front_end/write_inference_test.py b/compiler/front_end/write_inference_test.py index c6afa2f..f4ea8bf 100644 --- a/compiler/front_end/write_inference_test.py +++ b/compiler/front_end/write_inference_test.py @@ -185,7 +185,7 @@ def test_does_not_add_transform_write_method_for_parameter_target(self): ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = 50 + x\n") self.assertEqual([], write_inference.set_write_methods(ir)) field = ir.module[0].type[0].structure.field[0] - self.assertEqual("read_only", field.write_method.WhichOneof("method")) + self.assertEqual("read_only", field.write_method.which_method) def test_adds_transform_write_method_with_complex_auxiliary_subexpression(self): ir = self._make_ir( diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py index 2ebbffa..d20f191 100644 --- a/compiler/util/attribute_util.py +++ b/compiler/util/attribute_util.py @@ -57,7 +57,7 @@ def _is_constant_boolean(attr, module_source_file): def _is_boolean(attr, module_source_file): """Checks if the given attr is a boolean.""" - if attr.value.expression.type.WhichOneof("type") != "boolean": + if attr.value.expression.type.which_type != "boolean": return [ [ error.error( @@ -76,7 +76,7 @@ def _is_constant_integer(attr, module_source_file): """Checks if the given attr is an integer constant expression.""" if ( not attr.value.HasField("expression") - or attr.value.expression.type.WhichOneof("type") != "integer" + or attr.value.expression.type.which_type != "integer" ): return [ [ diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py index fda124e..cd12b96 100644 --- a/compiler/util/ir_data.py +++ b/compiler/util/ir_data.py @@ -106,20 +106,6 @@ def HasField(self, name): # pylint:disable=invalid-name """Indicates if this class has the given field defined and it is set.""" return getattr(self, name, None) is not None - # Non-PEP8 name to mimic the Google Protobuf interface. - def WhichOneof(self, oneof_name): # pylint:disable=invalid-name - """Indicates which field has been set for the oneof value. - - Args: - oneof_name: the name of the oneof construct to test. - - Returns: the field name, or None if no field has been set. - """ - for field_name, oneof in self.field_specs.oneof_mappings: - if oneof == oneof_name and self.HasField(field_name): - return field_name - return None - ################################################################################ # From here to the end of the file are actual structure definitions. diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py index 76df36c..53be16f 100644 --- a/compiler/util/ir_data_fields.py +++ b/compiler/util/ir_data_fields.py @@ -399,18 +399,23 @@ def __init__(self, oneof: str) -> None: super().__init__() self.oneof = oneof self.owner_type = None - self.proxy_name: str = "" + self.proxy_name: str = f"_value_{oneof}" + self.proxy_choice_name: str = f"which_{oneof}" self.name: str = "" def __set_name__(self, owner, name): self.name = name - self.proxy_name = f"_{name}" self.owner_type = owner - # Add our empty proxy field to the class. + # Add the empty proxy fields to the class. This may re-initialize + # these if another field in this oneof got there first. setattr(owner, self.proxy_name, None) + setattr(owner, self.proxy_choice_name, None) def __get__(self, obj, objtype=None): - return getattr(obj, self.proxy_name) + if getattr(obj, self.proxy_choice_name, None) == self.name: + return getattr(obj, self.proxy_name) + else: + return None def __set__(self, obj, value): if value is self: @@ -418,15 +423,13 @@ def __set__(self, obj, value): # default to None. value = None - if value is not None: - # Clear the others - for name, oneof in IrDataclassSpecs.get_specs( - self.owner_type - ).oneof_mappings: - if oneof == self.oneof and name != self.name: - setattr(obj, name, None) - - setattr(obj, self.proxy_name, value) + if value is None: + if getattr(obj, self.proxy_choice_name) == self.name: + setattr(obj, self.proxy_name, None) + setattr(obj, self.proxy_choice_name, None) + else: + setattr(obj, self.proxy_name, value) + setattr(obj, self.proxy_choice_name, self.name) def oneof_field(name: str): diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py index f880231..a79f060 100644 --- a/compiler/util/ir_data_utils.py +++ b/compiler/util/ir_data_utils.py @@ -277,7 +277,7 @@ def __getattribute__(self, name: str) -> Any: if ir is None: return object.__getattribute__(self, name) - if name in ("HasField", "WhichOneof"): + if name in ("HasField",): return getattr(ir, name) field_spec = field_specs(ir).get(name) @@ -362,8 +362,13 @@ def __getattribute__( if isinstance(ir_or_spec, ir_data_fields.FieldSpec): if name == "HasField": return lambda x: False - if name == "WhichOneof": - return lambda x: None + # This *should* be limited to only the `which_` attributes that + # correspond to real oneofs, but that would add complexity and + # runtime, and the odds are low that laxness here causes a bug + # -- the same code needs to run against real IR objects that + # will raise if a nonexistent `which_` field is accessed. + if name.startswith("which_"): + return None return object.__getattribute__(ir_or_spec, name) if isinstance(ir_or_spec, ir_data_fields.FieldSpec): diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py index 2d8763b..c1de244 100644 --- a/compiler/util/ir_util.py +++ b/compiler/util/ir_util.py @@ -72,7 +72,7 @@ def get_integer_attribute(attribute_list, name, default_value=None): attribute_value = get_attribute(attribute_list, name) if ( not attribute_value - or attribute_value.expression.type.WhichOneof("type") != "integer" + or attribute_value.expression.type.which_type != "integer" or not is_constant(attribute_value.expression) ): return default_value @@ -98,42 +98,42 @@ def constant_value(expression, bindings=None): if expression is None: return None expression = ir_data_utils.reader(expression) - if expression.WhichOneof("expression") == "constant": + if expression.which_expression == "constant": return int(expression.constant.value or 0) - elif expression.WhichOneof("expression") == "constant_reference": + elif expression.which_expression == "constant_reference": # We can't look up the constant reference without the IR, but by the time # constant_value is called, the actual values should have been propagated to # the type information. - if expression.type.WhichOneof("type") == "integer": + if expression.type.which_type == "integer": assert expression.type.integer.modulus == "infinity" return int(expression.type.integer.modular_value) - elif expression.type.WhichOneof("type") == "boolean": + elif expression.type.which_type == "boolean": assert expression.type.boolean.HasField("value") return expression.type.boolean.value - elif expression.type.WhichOneof("type") == "enumeration": + elif expression.type.which_type == "enumeration": assert expression.type.enumeration.HasField("value") return int(expression.type.enumeration.value) else: assert False, "Unexpected expression type {}".format( - expression.type.WhichOneof("type") + expression.type.which_type ) - elif expression.WhichOneof("expression") == "function": + elif expression.which_expression == "function": return _constant_value_of_function(expression.function, bindings) - elif expression.WhichOneof("expression") == "field_reference": + elif expression.which_expression == "field_reference": return None - elif expression.WhichOneof("expression") == "boolean_constant": + elif expression.which_expression == "boolean_constant": return expression.boolean_constant.value - elif expression.WhichOneof("expression") == "builtin_reference": + elif expression.which_expression == "builtin_reference": name = expression.builtin_reference.canonical_name.object_path[0] if bindings and name in bindings: return bindings[name] else: return None - elif expression.WhichOneof("expression") is None: + elif expression.which_expression is None: return None else: assert False, "Unexpected expression kind {}".format( - expression.WhichOneof("expression") + expression.which_expression ) @@ -366,11 +366,11 @@ def fixed_size_of_type_in_bits(type_ir, ir): """ array_multiplier = 1 while type_ir.HasField("array_type"): - if type_ir.array_type.WhichOneof("size") == "automatic": + if type_ir.array_type.which_size == "automatic": return None else: assert ( - type_ir.array_type.WhichOneof("size") == "element_count" + type_ir.array_type.which_size == "element_count" ), 'Expected array size to be "automatic" or "element_count".' element_count = type_ir.array_type.element_count if not is_constant(element_count):