diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index c502a47376..b7eed49f17 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -6,7 +6,7 @@ import dace from dace import dtypes from dace import data -from dace.sdfg import SDFG +from dace.sdfg import SDFG, utils as sdutils from dace.codegen.targets import framecode from dace.codegen.codeobject import CodeObject from dace.config import Config @@ -178,6 +178,9 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]: shutil.move(f"{tmp_dir}/test2.sdfg", "test2.sdfg") raise RuntimeError('SDFG serialization failed - files do not match') + # Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops. + # TODO (later): Adapt codegen to deal with hierarchical CFGs instead. + sdutils.inline_loop_blocks(sdfg) # Before generating the code, run type inference on the SDFG connectors infer_types.infer_connector_types(sdfg) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index a3c6eed168..ccc30df6ca 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: import dace.sdfg.scope + from dace.sdfg import SDFG NodeT = Union[nd.Node, 'ControlFlowBlock'] @@ -100,7 +101,7 @@ def out_degree(self, node: NodeT) -> int: ... @property - def sdfg(self) -> 'dace.sdfg.SDFG': + def sdfg(self) -> 'SDFG': ... ################################################################### @@ -777,7 +778,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: def unordered_arglist(self, defined_syms=None, shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]: - sdfg: 'dace.sdfg.SDFG' = self.sdfg + sdfg: 'SDFG' = self.sdfg shared_transients = shared_transients or sdfg.shared_transients() sdict = self.scope_dict() @@ -1077,7 +1078,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): def __init__(self, label: str='', - sdfg: Optional['dace.SDFG'] = None, + sdfg: Optional['SDFG'] = None, parent: Optional['ControlFlowRegion'] = None): super(ControlFlowBlock, self).__init__() self._label = label @@ -1121,11 +1122,11 @@ def name(self) -> str: return self._label @property - def sdfg(self) -> 'dace.SDFG': + def sdfg(self) -> 'SDFG': return self._sdfg @sdfg.setter - def sdfg(self, sdfg: 'dace.SDFG'): + def sdfg(self, sdfg: 'SDFG'): self._sdfg = sdfg @property @@ -1512,7 +1513,7 @@ def add_tasklet( def add_nested_sdfg( self, - sdfg: 'dace.sdfg.SDFG', + sdfg: 'SDFG', parent, inputs: Union[Set[str], Dict[str, dtypes.typeclass]], outputs: Union[Set[str], Dict[str, dtypes.typeclass]], @@ -2344,7 +2345,7 @@ def __init__(self, graph, subgraph_nodes): super().__init__(graph, subgraph_nodes) @property - def sdfg(self) -> 'dace.sdfg.SDFG': + def sdfg(self) -> 'SDFG': state: SDFGState = self.graph return state.sdfg @@ -2353,7 +2354,7 @@ def sdfg(self) -> 'dace.sdfg.SDFG': class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, ControlFlowBlock): - def __init__(self, label: str='', sdfg: Optional['dace.SDFG'] = None): + def __init__(self, label: str='', sdfg: Optional['SDFG'] = None): OrderedDiGraph.__init__(self) ControlGraphView.__init__(self) ControlFlowBlock.__init__(self, label, sdfg) @@ -2523,7 +2524,7 @@ def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegi elif isinstance(block, ControlFlowRegion): yield from block.all_control_flow_regions(recursive=recursive) - def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + def all_sdfgs_recursive(self) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ for cfg in self.all_control_flow_regions(recursive=True): if isinstance(cfg, dace.SDFG): @@ -2590,6 +2591,25 @@ def start_block(self, block_id): @make_properties class LoopRegion(ControlFlowRegion): + """ + A control flow region that represents a loop. + + Like in traditional programming languages, a loop has a condition that is checked before each iteration. + It may have zero or more initialization statements that are executed before the first loop iteration, and zero or + more update statements that are executed after each iteration. For example, a loop with only a condition and neither + an initialization nor an update statement is equivalent to a while loop, while a loop with initialization and update + statements represents a for loop. Loops may additionally be inverted, meaning that the condition is checked after + the first iteration instead of before. + + A loop region, like any other control flow region, has a single distinct entry / start block, and one or more + exit blocks. Exit blocks are blocks that have no outgoing edges or only conditional outgoing edges. Whenever an + exit block finshes executing, one iteration of the loop is completed. + + Loops may have an arbitrary number of break states. Whenever a break state finishes executing, the loop is exited + immediately. A loop may additionally have an arbitrary number of continue states. Whenever a continue state finishes + executing, the next iteration of the loop is started immediately (with execution of the update statement(s), if + present). + """ update_statement = CodeProperty(optional=True, allow_none=True, default=None, desc='The loop update statement. May be None if the update happens elsewhere.') @@ -2673,8 +2693,7 @@ def replace_dict(self, repl: Dict[str, str], def to_json(self, parent=None): return super().to_json(parent) - def add_node(self, node, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None): - super().add_node(node, is_start_block, is_start_state=is_start_state) + def _add_node_internal(self, node, is_continue=False, is_break=False): if is_continue: if is_break: raise ValueError('Cannot set both is_continue and is_break') @@ -2684,15 +2703,12 @@ def add_node(self, node, is_start_block=False, is_continue=False, is_break=False raise ValueError('Cannot set both is_continue and is_break') self.break_states.add(self.node_id(node)) + def add_node(self, node, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None): + super().add_node(node, is_start_block, is_start_state=is_start_state) + self._add_node_internal(node, is_continue, is_break) + def add_state(self, label=None, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None) -> SDFGState: state = super().add_state(label, is_start_block, is_start_state=is_start_state) - if is_continue: - if is_break: - raise ValueError('Cannot set both is_continue and is_break') - self.continue_states.add(self.node_id(state)) - if is_break: - if is_continue: - raise ValueError('Cannot set both is_continue and is_break') - self.break_states.add(self.node_id(state)) + self._add_node_internal(state, is_continue, is_break) return state diff --git a/dace/transformation/interstate/control_flow_inline.py b/dace/transformation/interstate/control_flow_inline.py index d0dd5c8f3c..b86317b8ed 100644 --- a/dace/transformation/interstate/control_flow_inline.py +++ b/dace/transformation/interstate/control_flow_inline.py @@ -27,6 +27,17 @@ def expressions(cls): return [sdutil.node_path_graph(cls.loop)] def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + # Check that the loop initialization and update statements each only contain assignments, if the loop has any. + if self.loop.init_statement is not None: + if isinstance(self.loop.init_statement.code, list): + for stmt in self.loop.init_statement.code: + if not isinstance(stmt, astutils.ast.Assign): + return False + if self.loop.update_statement is not None: + if isinstance(self.loop.update_statement.code, list): + for stmt in self.loop.update_statement.code: + if not isinstance(stmt, astutils.ast.Assign): + return False return True def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]: @@ -71,9 +82,10 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]: # Add an initialization edge that initializes the loop variable if applicable. init_edge = InterstateEdge() if self.loop.init_statement is not None: - init_edge.assignments = { - self.loop.loop_variable: self.loop.init_statement.as_string.rpartition('=')[2].strip() - } + init_edge.assignments = {} + for stmt in self.loop.init_statement.code: + assign: astutils.ast.Assign = stmt + init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) if self.loop.inverted: parent.add_edge(init_state, internal_start, init_edge) else: @@ -82,9 +94,10 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]: # Connect the loop tail. update_edge = InterstateEdge() if self.loop.update_statement is not None: - update_edge.assignments = { - self.loop.loop_variable: self.loop.update_statement.as_string.rpartition('=')[2].strip() - } + update_edge.assignments = {} + for stmt in self.loop.update_statement.code: + assign: astutils.ast.Assign = stmt + update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value) parent.add_edge(loop_tail_state, guard_state, update_edge) # Add condition checking edges and connect the guard state. diff --git a/tests/transformations/control_flow_inline_test.py b/tests/transformations/control_flow_inline_test.py index bca34ee584..106a955143 100644 --- a/tests/transformations/control_flow_inline_test.py +++ b/tests/transformations/control_flow_inline_test.py @@ -222,6 +222,68 @@ def test_loop_inlining_for_continue_break(): assert len(sdfg.edges_between(state2, tail_state)) == 1 +def test_loop_inlining_multi_assignments(): + sdfg = dace.SDFG('inlining') + sdfg.add_symbol('j', dace.int32) + state0 = sdfg.add_state('state0', is_start_block=True) + loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0; j = 10 + 200 - 1', + update_expr='i = i + 1; j = j + i', inverted=False) + sdfg.add_node(loop1) + state1 = loop1.add_state('state1', is_start_block=True) + state2 = loop1.add_state('state2') + loop1.add_edge(state1, state2, dace.InterstateEdge()) + state3 = sdfg.add_state('state3') + sdfg.add_edge(state0, loop1, dace.InterstateEdge()) + sdfg.add_edge(loop1, state3, dace.InterstateEdge()) + + sdutils.inline_loop_blocks(sdfg) + + states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong + assert len(states) == 8 + assert state0 in states + assert state1 in states + assert state2 in states + assert state3 in states + + guard_state = None + init_state = None + tail_state = None + for state in sdfg.states(): + if state.label == 'loop1_guard': + guard_state = state + elif state.label == 'loop1_init': + init_state = state + elif state.label == 'loop1_tail': + tail_state = state + init_edge = sdfg.edges_between(init_state, guard_state)[0] + assert 'i' in init_edge.data.assignments + assert 'j' in init_edge.data.assignments + update_edge = sdfg.edges_between(tail_state, guard_state)[0] + assert 'i' in update_edge.data.assignments + assert 'j' in update_edge.data.assignments + + +def test_loop_inlining_invalid_update_statement(): + # Inlining should not be applied here. + sdfg = dace.SDFG('inlining') + sdfg.add_symbol('j', dace.int32) + state0 = sdfg.add_state('state0', is_start_block=True) + loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0', + update_expr='i = i + 1; j < i', inverted=False) + sdfg.add_node(loop1) + state1 = loop1.add_state('state1', is_start_block=True) + state2 = loop1.add_state('state2') + loop1.add_edge(state1, state2, dace.InterstateEdge()) + state3 = sdfg.add_state('state3') + sdfg.add_edge(state0, loop1, dace.InterstateEdge()) + sdfg.add_edge(loop1, state3, dace.InterstateEdge()) + + sdutils.inline_loop_blocks(sdfg) + + nodes = sdfg.nodes() + assert len(nodes) == 3 + + if __name__ == '__main__': test_loop_inlining_regular_for() test_loop_inlining_regular_while() @@ -229,3 +291,5 @@ def test_loop_inlining_for_continue_break(): test_loop_inlining_do_for() test_inline_triple_nested_for() test_loop_inlining_for_continue_break() + test_loop_inlining_multi_assignments() + test_loop_inlining_invalid_update_statement()