Skip to content

Commit

Permalink
Add ir_data_utils.copy and update
Browse files Browse the repository at this point in the history
This adds `ir_data_utils` helpers to handle copying and updating
`ir_data` instances. Some instances of `CopyFrom` that aren't associated
with a `builder` are updated to use `copy` or `update`.

Part of #118.
  • Loading branch information
EricRahm committed May 31, 2024
1 parent dd186bb commit 899f0f3
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 28 deletions.
3 changes: 1 addition & 2 deletions compiler/back_end/cpp/header_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,8 +1465,7 @@ def _offset_source_location_column(source_location, offset):
Offset should be a tuple of (start, end), which are the offsets relative to
source_location.start.column to set the new start.column and end.column."""

new_location = ir_data.Location()
new_location.CopyFrom(source_location)
new_location = ir_data_utils.copy(source_location)
new_location.start.column = source_location.start.column + offset[0]
new_location.end.column = source_location.start.column + offset[1]

Expand Down
6 changes: 2 additions & 4 deletions compiler/front_end/glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,7 @@ def parse_module_text(source_code, file_name):
# need to re-parse the prelude for every test .emb.
if (source_code, file_name) in _cached_modules:
debug_info = _cached_modules[source_code, file_name]
ir = ir_data.Module()
ir.CopyFrom(debug_info.ir)
ir = ir_data_utils.copy(debug_info.ir)
else:
debug_info = ModuleDebugInfo(file_name)
debug_info.source_code = source_code
Expand All @@ -163,8 +162,7 @@ def parse_module_text(source_code, file_name):
ir = module_ir.build_ir(parse_result.parse_tree, used_productions)
ir.source_text = source_code
debug_info.used_productions = used_productions
debug_info.ir = ir_data.Module()
debug_info.ir.CopyFrom(ir)
debug_info.ir = ir_data_utils.copy(ir)
_cached_modules[source_code, file_name] = debug_info
ir.source_file_name = file_name
return _IrDebugInfo(ir, debug_info, [])
Expand Down
3 changes: 1 addition & 2 deletions compiler/front_end/glue_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,7 @@ def test_circular_dependency_error(self):
self.assertFalse(ir)

def test_ir_from_parse_module(self):
log_file_path_ir = ir_data.Module()
log_file_path_ir.CopyFrom(_SPAN_SE_LOG_FILE_IR)
log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR)
log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH
self.assertEqual(log_file_path_ir, glue.parse_module(
_SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir)
Expand Down
20 changes: 14 additions & 6 deletions compiler/front_end/module_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,10 @@ def _really_build_ir(parse_tree, used_productions):
used_productions.add(parse_tree.production)
result = _handlers[parse_tree.production](*parsed_children)
if parse_tree.source_location is not None:
result.source_location.CopyFrom(parse_tree.source_location)
if result.source_location:
ir_data_utils.update(result.source_location, parse_tree.source_location)
else:
result.source_location = ir_data_utils.copy(parse_tree.source_location)
return result
else:
# For leaf nodes, the temporary "IR" is just the token. Higher-level rules
Expand Down Expand Up @@ -798,7 +801,10 @@ def _structure(struct, name, parameters, colon, comment, newline, struct_body):
struct.source_location.start)
ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
struct_body.source_location.end)
struct_body.name.CopyFrom(name)
if struct_body.name:
ir_data_utils.update(struct_body.name, name)
else:
struct_body.name = ir_data_utils.copy(name)
if parameters.list:
struct_body.runtime_parameter.extend(parameters.list[0].list)
return struct_body
Expand Down Expand Up @@ -1059,8 +1065,7 @@ def _inline_type_field(location, name, abbreviation, body):
# the user wants to use type attributes, they should create a separate type
# definition and reference it.
del body.attribute[:]
type_name = ir_data.NameDefinition()
type_name.CopyFrom(name)
type_name = ir_data_utils.copy(name)
ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text)
field.type.atomic_type.reference.source_name.extend([type_name.name])
field.type.atomic_type.reference.source_location.CopyFrom(
Expand Down Expand Up @@ -1166,7 +1171,10 @@ def _enum_value_body(indent, docs, attributes, dedent):
def _external(external, name, colon, comment, newline, external_body):
del colon, comment, newline # Unused.
ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start)
external_body.name.CopyFrom(name)
if external_body.name:
ir_data_utils.update(external_body.name, name)
else:
external_body.name = ir_data_utils.copy(name)
return external_body


Expand Down Expand Up @@ -1218,7 +1226,7 @@ def _type(reference, parameters, size, array_spec):
atomic_type_source_location_end)
t = ir_data.Type(
atomic_type=ir_data.AtomicType(
reference=reference,
reference=ir_data_utils.copy(reference),
source_location=atomic_type_location,
runtime_parameter=parameters.list[0].list if parameters.list else []),
size_in_bits=size.list[0] if size.list else None,
Expand Down
3 changes: 1 addition & 2 deletions compiler/front_end/symbol_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,8 +417,7 @@ def _resolve_field_reference(field_reference, source_file_name, errors, ir):
previous_reference.source_name[0].text))
return
assert previous_field.type.WhichOneof("type") == "atomic_type"
member_name = ir_data.CanonicalName()
member_name.CopyFrom(
member_name = ir_data_utils.copy(
previous_field.type.atomic_type.reference.canonical_name)
ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
previous_field = ir_util.find_object_or_none(member_name, ir)
Expand Down
17 changes: 7 additions & 10 deletions compiler/front_end/synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,14 @@ def _add_anonymous_aliases(structure, type_definition):
ir_data.Reference(source_name=[subfield.name.name]),
]
)
new_existence_condition = ir_data.Expression()
new_existence_condition.CopyFrom(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
new_existence_condition = ir_data_utils.copy(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
existence_clauses = ir_data_utils.builder(new_existence_condition).function.args
existence_clauses[0].function.args[0].field_reference.CopyFrom(
anonymous_field_reference)
existence_clauses[1].function.args[0].field_reference.CopyFrom(
alias_field_reference)
new_read_transform = ir_data.Expression(
field_reference=alias_field_reference)
field_reference=ir_data_utils.copy(alias_field_reference))
# This treats *most* of the alias field as synthetic, but not its name(s):
# leaving the name(s) as "real" means that symbol collisions with the
# surrounding structure will be properly reported to the user.
Expand All @@ -128,7 +127,7 @@ def _add_anonymous_aliases(structure, type_definition):
new_alias = ir_data.Field(
read_transform=new_read_transform,
existence_condition=new_existence_condition,
name=subfield.name)
name=ir_data_utils.copy(subfield.name))
if subfield.HasField("abbreviation"):
ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation)
_mark_as_synthetic(new_alias.existence_condition)
Expand Down Expand Up @@ -195,16 +194,14 @@ def _add_size_virtuals(structure, type_definition):
# to the size of the structure.
if ir_util.field_is_virtual(field):
continue
size_clause = ir_data.Expression()
size_clause.CopyFrom(_SIZE_CLAUSE_SKELETON)
size_clause = ir_data_utils.builder(size_clause)
size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON)
size_clause = ir_data_utils.builder(size_clause_ir)
# Copy the appropriate clauses into `existence_condition ? start + size : 0`
size_clause.function.args[0].CopyFrom(field.existence_condition)
size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
size_clause.function.args[1].function.args[1].CopyFrom(field.location.size)
size_clauses.append(size_clause)
size_expression = ir_data.Expression()
size_expression.CopyFrom(_SIZE_SKELETON)
size_clauses.append(size_clause_ir)
size_expression = ir_data_utils.copy(_SIZE_SKELETON)
size_expression.function.args.extend(size_clauses)
_mark_as_synthetic(size_expression)
size_field = ir_data.Field(
Expand Down
3 changes: 1 addition & 2 deletions compiler/util/attribute_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,7 @@ def gather_default_attributes(obj, defaults):
defaults = defaults.copy()
for attr in obj.attribute:
if attr.is_default:
defaulted_attr = ir_data.Attribute()
defaulted_attr.CopyFrom(attr)
defaulted_attr = ir_data_utils.copy(attr)
defaulted_attr.is_default = False
defaults[attr.name.text] = defaulted_attr
return {"defaults": defaults}
15 changes: 15 additions & 0 deletions compiler/util/ir_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,18 @@ def reader(obj: ir_data.Message) -> ir_data.Message:
This is a no-op and just used for annotation for now.
"""
return obj


def copy(ir: ir_data.Message | None) -> ir_data.Message | None:
"""Creates a copy of the given IR data class"""
if not ir:
return None
ir_class = type(ir)
ir_copy = ir_class()
update(ir_copy, ir)
return ir_copy


def update(ir: ir_data.Message, template: ir_data.Message):
"""Updates `ir`s fields with all set fields in the template."""
ir.CopyFrom(template)

0 comments on commit 899f0f3

Please sign in to comment.