From 899f0f322220c45408845618d03048a184ce761f Mon Sep 17 00:00:00 2001 From: Eric Rahm Date: Mon, 15 Apr 2024 13:03:37 -0700 Subject: [PATCH] Add `ir_data_utils.copy` and `update` This adds `ir_data_utils` helpers to handle copying and updating `ir_data` instances. Some instances of `CopyFrom` that aren't associated with a `builder` are updated to use `copy` or `update`. Part of #118. --- compiler/back_end/cpp/header_generator.py | 3 +-- compiler/front_end/glue.py | 6 ++---- compiler/front_end/glue_test.py | 3 +-- compiler/front_end/module_ir.py | 20 ++++++++++++++------ compiler/front_end/symbol_resolver.py | 3 +-- compiler/front_end/synthetics.py | 17 +++++++---------- compiler/util/attribute_util.py | 3 +-- compiler/util/ir_data_utils.py | 15 +++++++++++++++ 8 files changed, 42 insertions(+), 28 deletions(-) diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py index 55c0a6e..11fcd17 100644 --- a/compiler/back_end/cpp/header_generator.py +++ b/compiler/back_end/cpp/header_generator.py @@ -1465,8 +1465,7 @@ def _offset_source_location_column(source_location, offset): Offset should be a tuple of (start, end), which are the offsets relative to source_location.start.column to set the new start.column and end.column.""" - new_location = ir_data.Location() - new_location.CopyFrom(source_location) + new_location = ir_data_utils.copy(source_location) new_location.start.column = source_location.start.column + offset[0] new_location.end.column = source_location.start.column + offset[1] diff --git a/compiler/front_end/glue.py b/compiler/front_end/glue.py index 7724da9..a1e1a5b 100644 --- a/compiler/front_end/glue.py +++ b/compiler/front_end/glue.py @@ -143,8 +143,7 @@ def parse_module_text(source_code, file_name): # need to re-parse the prelude for every test .emb. 
if (source_code, file_name) in _cached_modules: debug_info = _cached_modules[source_code, file_name] - ir = ir_data.Module() - ir.CopyFrom(debug_info.ir) + ir = ir_data_utils.copy(debug_info.ir) else: debug_info = ModuleDebugInfo(file_name) debug_info.source_code = source_code @@ -163,8 +162,7 @@ def parse_module_text(source_code, file_name): ir = module_ir.build_ir(parse_result.parse_tree, used_productions) ir.source_text = source_code debug_info.used_productions = used_productions - debug_info.ir = ir_data.Module() - debug_info.ir.CopyFrom(ir) + debug_info.ir = ir_data_utils.copy(ir) _cached_modules[source_code, file_name] = debug_info ir.source_file_name = file_name return _IrDebugInfo(ir, debug_info, []) diff --git a/compiler/front_end/glue_test.py b/compiler/front_end/glue_test.py index 10613d7..2f2ddc5 100644 --- a/compiler/front_end/glue_test.py +++ b/compiler/front_end/glue_test.py @@ -141,8 +141,7 @@ def test_circular_dependency_error(self): self.assertFalse(ir) def test_ir_from_parse_module(self): - log_file_path_ir = ir_data.Module() - log_file_path_ir.CopyFrom(_SPAN_SE_LOG_FILE_IR) + log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR) log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH self.assertEqual(log_file_path_ir, glue.parse_module( _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir) diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py index 163bf6e..c9ba765 100644 --- a/compiler/front_end/module_ir.py +++ b/compiler/front_end/module_ir.py @@ -150,7 +150,10 @@ def _really_build_ir(parse_tree, used_productions): used_productions.add(parse_tree.production) result = _handlers[parse_tree.production](*parsed_children) if parse_tree.source_location is not None: - result.source_location.CopyFrom(parse_tree.source_location) + if result.source_location: + ir_data_utils.update(result.source_location, parse_tree.source_location) + else: + result.source_location = ir_data_utils.copy(parse_tree.source_location) return 
result else: # For leaf nodes, the temporary "IR" is just the token. Higher-level rules @@ -798,7 +801,10 @@ def _structure(struct, name, parameters, colon, comment, newline, struct_body): struct.source_location.start) ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom( struct_body.source_location.end) - struct_body.name.CopyFrom(name) + if struct_body.name: + ir_data_utils.update(struct_body.name, name) + else: + struct_body.name = ir_data_utils.copy(name) if parameters.list: struct_body.runtime_parameter.extend(parameters.list[0].list) return struct_body @@ -1059,8 +1065,7 @@ def _inline_type_field(location, name, abbreviation, body): # the user wants to use type attributes, they should create a separate type # definition and reference it. del body.attribute[:] - type_name = ir_data.NameDefinition() - type_name.CopyFrom(name) + type_name = ir_data_utils.copy(name) ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text) field.type.atomic_type.reference.source_name.extend([type_name.name]) field.type.atomic_type.reference.source_location.CopyFrom( @@ -1166,7 +1171,10 @@ def _enum_value_body(indent, docs, attributes, dedent): def _external(external, name, colon, comment, newline, external_body): del colon, comment, newline # Unused. 
ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start) - external_body.name.CopyFrom(name) + if external_body.name: + ir_data_utils.update(external_body.name, name) + else: + external_body.name = ir_data_utils.copy(name) return external_body @@ -1218,7 +1226,7 @@ def _type(reference, parameters, size, array_spec): atomic_type_source_location_end) t = ir_data.Type( atomic_type=ir_data.AtomicType( - reference=reference, + reference=ir_data_utils.copy(reference), source_location=atomic_type_location, runtime_parameter=parameters.list[0].list if parameters.list else []), size_in_bits=size.list[0] if size.list else None, diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py index 09675b8..5990f13 100644 --- a/compiler/front_end/symbol_resolver.py +++ b/compiler/front_end/symbol_resolver.py @@ -417,8 +417,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir): previous_reference.source_name[0].text)) return assert previous_field.type.WhichOneof("type") == "atomic_type" - member_name = ir_data.CanonicalName() - member_name.CopyFrom( + member_name = ir_data_utils.copy( previous_field.type.atomic_type.reference.canonical_name) ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text]) previous_field = ir_util.find_object_or_none(member_name, ir) diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py index d1bef9a..42b3cff 100644 --- a/compiler/front_end/synthetics.py +++ b/compiler/front_end/synthetics.py @@ -111,15 +111,14 @@ def _add_anonymous_aliases(structure, type_definition): ir_data.Reference(source_name=[subfield.name.name]), ] ) - new_existence_condition = ir_data.Expression() - new_existence_condition.CopyFrom(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON) + new_existence_condition = ir_data_utils.copy(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON) existence_clauses = 
ir_data_utils.builder(new_existence_condition).function.args existence_clauses[0].function.args[0].field_reference.CopyFrom( anonymous_field_reference) existence_clauses[1].function.args[0].field_reference.CopyFrom( alias_field_reference) new_read_transform = ir_data.Expression( - field_reference=alias_field_reference) + field_reference=ir_data_utils.copy(alias_field_reference)) # This treats *most* of the alias field as synthetic, but not its name(s): # leaving the name(s) as "real" means that symbol collisions with the # surrounding structure will be properly reported to the user. @@ -128,7 +127,7 @@ def _add_anonymous_aliases(structure, type_definition): new_alias = ir_data.Field( read_transform=new_read_transform, existence_condition=new_existence_condition, - name=subfield.name) + name=ir_data_utils.copy(subfield.name)) if subfield.HasField("abbreviation"): ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation) _mark_as_synthetic(new_alias.existence_condition) @@ -195,16 +194,14 @@ def _add_size_virtuals(structure, type_definition): # to the size of the structure. if ir_util.field_is_virtual(field): continue - size_clause = ir_data.Expression() - size_clause.CopyFrom(_SIZE_CLAUSE_SKELETON) - size_clause = ir_data_utils.builder(size_clause) + size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON) + size_clause = ir_data_utils.builder(size_clause_ir) # Copy the appropriate clauses into `existence_condition ? 
start + size : 0` size_clause.function.args[0].CopyFrom(field.existence_condition) size_clause.function.args[1].function.args[0].CopyFrom(field.location.start) size_clause.function.args[1].function.args[1].CopyFrom(field.location.size) - size_clauses.append(size_clause) - size_expression = ir_data.Expression() - size_expression.CopyFrom(_SIZE_SKELETON) + size_clauses.append(size_clause_ir) + size_expression = ir_data_utils.copy(_SIZE_SKELETON) size_expression.function.args.extend(size_clauses) _mark_as_synthetic(size_expression) size_field = ir_data.Field( diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py index 063f38e..6e04280 100644 --- a/compiler/util/attribute_util.py +++ b/compiler/util/attribute_util.py @@ -305,8 +305,7 @@ def gather_default_attributes(obj, defaults): defaults = defaults.copy() for attr in obj.attribute: if attr.is_default: - defaulted_attr = ir_data.Attribute() - defaulted_attr.CopyFrom(attr) + defaulted_attr = ir_data_utils.copy(attr) defaulted_attr.is_default = False defaults[attr.name.text] = defaulted_attr return {"defaults": defaults} diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py index 479d974..ac02bb0 100644 --- a/compiler/util/ir_data_utils.py +++ b/compiler/util/ir_data_utils.py @@ -45,3 +45,18 @@ def reader(obj: ir_data.Message) -> ir_data.Message: This is a no-op and just used for annotation for now. """ return obj + + +def copy(ir: ir_data.Message | None) -> ir_data.Message | None: + """Creates a copy of the given IR data class.""" + if not ir: + return None + ir_class = type(ir) + ir_copy = ir_class() + update(ir_copy, ir) + return ir_copy + + +def update(ir: ir_data.Message, template: ir_data.Message): + """Updates `ir`'s fields with all set fields in the template.""" + ir.CopyFrom(template)