Skip to content

Commit

Permalink
Control Flow Raising (#1657)
Browse files Browse the repository at this point in the history
This PR mainly provides control flow raising passes for the new
intrinsic control flow constructs (Branches and loops) in SDFGs. In
addition to raising, the state and control flow reachability passes have
been adjusted to faithfully work with the intrinsic control flow
constructs.

Along with the raising and reachability passes, a few important bugfixes
and a general cleanup is included in the PR, but no other functionality
is changed.
  • Loading branch information
phschaad authored Oct 12, 2024
1 parent 64c54ab commit 073b613
Show file tree
Hide file tree
Showing 24 changed files with 1,468 additions and 461 deletions.
37 changes: 21 additions & 16 deletions dace/codegen/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,13 @@ def as_cpp(self, codegen, symbols) -> str:
expr += elem.as_cpp(codegen, symbols)
# In a general block, emit transitions and assignments after each individual block or region.
if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
if isinstance(elem, BasicCFBlock):
g_elem = elem.state
else:
g_elem = elem.region
cfg = g_elem.parent_graph
sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
out_edges = cfg.out_edges(g_elem)
for j, e in enumerate(out_edges):
if e not in self.gotos_to_ignore:
# Skip gotos to immediate successors
Expand Down Expand Up @@ -532,26 +536,27 @@ def as_cpp(self, codegen, symbols) -> str:
expr = ''

if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable:
# Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined.
defined_vars = codegen.dispatcher.defined_vars
if not defined_vars.has(self.loop.loop_variable):
try:
init = f'{symbols[self.loop.loop_variable]} '
except KeyError:
init = 'auto '
symbols[self.loop.loop_variable] = None
init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
init = init.strip(';')

update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols)
update = update.strip(';')

if self.loop.inverted:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'\n}} while({cond});\n'
if self.loop.update_before_condition:
expr += f'{init};\n'
expr += 'do {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'{update};\n'
expr += f'}} while({cond});\n'
else:
expr += f'{init};\n'
expr += 'while (1) {\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
expr += f'if (!({cond}))\n'
expr += 'break;\n'
expr += f'{update};\n'
expr += '}\n'
else:
expr += f'for ({init}; {cond}; {update}) {{\n'
expr += _clean_loop_body(self.body.as_cpp(codegen, symbols))
Expand Down
24 changes: 22 additions & 2 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,14 @@
from dace.codegen.prettycode import CodeIOStream
from dace.codegen.common import codeblock_to_cpp, sym2cpp
from dace.codegen.targets.target import TargetCodeGenerator
from dace.codegen.tools.type_inference import infer_expr_type
from dace.frontend.python import astutils
from dace.sdfg import SDFG, SDFGState, nodes
from dace.sdfg import scope as sdscope
from dace.sdfg import utils
from dace.sdfg.analysis import cfg as cfg_analysis
from dace.sdfg.state import ControlFlowRegion
from dace.transformation.passes.analysis import StateReachability
from dace.sdfg.state import ControlFlowRegion, LoopRegion
from dace.transformation.passes.analysis import StateReachability, loop_analysis


def _get_or_eval_sdfg_first_arg(func, sdfg):
Expand Down Expand Up @@ -916,6 +918,24 @@ def generate_code(self,
interstate_symbols.update(symbols)
global_symbols.update(symbols)

if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None:
init_assignment = cfr.init_statement.code[0]
update_assignment = cfr.update_statement.code[0]
if isinstance(init_assignment, astutils.ast.Assign):
init_assignment = init_assignment.value
if isinstance(update_assignment, astutils.ast.Assign):
update_assignment = update_assignment.value
if not cfr.loop_variable in interstate_symbols:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if not cfr.loop_variable in global_symbols:
global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable]

for isvarName, isvarType in interstate_symbols.items():
if isvarType is None:
raise TypeError(f'Type inference failed for symbol {isvarName}')
Expand Down
7 changes: 2 additions & 5 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If):
self._on_block_added(cond_block)

if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg)
cond_block.branches.append((CodeBlock(cond), if_body))
if_body.parent_graph = self.cfg_target
cond_block.add_branch(CodeBlock(cond), if_body)

# Visit recursively
self._recursive_visit(node.body, 'if', node.lineno, if_body, False)
Expand All @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If):
if len(node.orelse) > 0:
else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}',
sdfg=self.sdfg)
#cond_block.branches.append((CodeBlock(cond_else), else_body))
cond_block.branches.append((None, else_body))
else_body.parent_graph = self.cfg_target
cond_block.add_branch(None, else_body)
# Visit recursively
self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False)

Expand Down
2 changes: 2 additions & 0 deletions dace/frontend/python/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF
sdutils.inline_control_flow_regions(nsdfg)
sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks

sdfg.reset_cfg_list()

# Apply simplification pass automatically
if not cached and (simplify == True or
(simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))):
Expand Down
15 changes: 11 additions & 4 deletions dace/sdfg/analysis/schedule_tree/treenodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,17 @@ def as_string(self, indent: int = 0):
loop = self.header.loop
if loop.update_statement and loop.init_statement and loop.loop_variable:
if loop.inverted:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
if loop.update_before_condition:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'do:\n'
pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}'
else:
pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n'
header = indent * INDENTATION + 'while True:\n'
pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n'
pre_footer += (indent + 2) * INDENTATION + 'break\n'
footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n'
return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer
else:
result = (indent * INDENTATION +
Expand Down
Loading

0 comments on commit 073b613

Please sign in to comment.