Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Nov 27, 2023
1 parent f270852 commit fba7e87
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 26 deletions.
5 changes: 4 additions & 1 deletion dace/codegen/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import dace
from dace import dtypes
from dace import data
from dace.sdfg import SDFG
from dace.sdfg import SDFG, utils as sdutils
from dace.codegen.targets import framecode
from dace.codegen.codeobject import CodeObject
from dace.config import Config
Expand Down Expand Up @@ -178,6 +178,9 @@ def generate_code(sdfg, validate=True) -> List[CodeObject]:
shutil.move(f"{tmp_dir}/test2.sdfg", "test2.sdfg")
raise RuntimeError('SDFG serialization failed - files do not match')

# Convert any loop constructs with hierarchical loop regions into simple 1-level state machine loops.
# TODO (later): Adapt codegen to deal with hierarchical CFGs instead.
sdutils.inline_loop_blocks(sdfg)

# Before generating the code, run type inference on the SDFG connectors
infer_types.infer_connector_types(sdfg)
Expand Down
54 changes: 35 additions & 19 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

if TYPE_CHECKING:
import dace.sdfg.scope
from dace.sdfg import SDFG


NodeT = Union[nd.Node, 'ControlFlowBlock']
Expand Down Expand Up @@ -100,7 +101,7 @@ def out_degree(self, node: NodeT) -> int:
...

@property
def sdfg(self) -> 'dace.sdfg.SDFG':
def sdfg(self) -> 'SDFG':
...

###################################################################
Expand Down Expand Up @@ -777,7 +778,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]:
def unordered_arglist(self,
defined_syms=None,
shared_transients=None) -> Tuple[Dict[str, dt.Data], Dict[str, dt.Data]]:
sdfg: 'dace.sdfg.SDFG' = self.sdfg
sdfg: 'SDFG' = self.sdfg
shared_transients = shared_transients or sdfg.shared_transients()
sdict = self.scope_dict()

Expand Down Expand Up @@ -1077,7 +1078,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC):

def __init__(self,
label: str='',
sdfg: Optional['dace.SDFG'] = None,
sdfg: Optional['SDFG'] = None,
parent: Optional['ControlFlowRegion'] = None):
super(ControlFlowBlock, self).__init__()
self._label = label
Expand Down Expand Up @@ -1121,11 +1122,11 @@ def name(self) -> str:
return self._label

@property
def sdfg(self) -> 'dace.SDFG':
def sdfg(self) -> 'SDFG':
return self._sdfg

@sdfg.setter
def sdfg(self, sdfg: 'dace.SDFG'):
def sdfg(self, sdfg: 'SDFG'):
self._sdfg = sdfg

@property
Expand Down Expand Up @@ -1512,7 +1513,7 @@ def add_tasklet(

def add_nested_sdfg(
self,
sdfg: 'dace.sdfg.SDFG',
sdfg: 'SDFG',
parent,
inputs: Union[Set[str], Dict[str, dtypes.typeclass]],
outputs: Union[Set[str], Dict[str, dtypes.typeclass]],
Expand Down Expand Up @@ -2344,7 +2345,7 @@ def __init__(self, graph, subgraph_nodes):
super().__init__(graph, subgraph_nodes)

@property
def sdfg(self) -> 'dace.sdfg.SDFG':
def sdfg(self) -> 'SDFG':
state: SDFGState = self.graph
return state.sdfg

Expand All @@ -2353,7 +2354,7 @@ def sdfg(self) -> 'dace.sdfg.SDFG':
class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView,
ControlFlowBlock):

def __init__(self, label: str='', sdfg: Optional['dace.SDFG'] = None):
def __init__(self, label: str='', sdfg: Optional['SDFG'] = None):
OrderedDiGraph.__init__(self)
ControlGraphView.__init__(self)
ControlFlowBlock.__init__(self, label, sdfg)
Expand Down Expand Up @@ -2523,7 +2524,7 @@ def all_control_flow_regions(self, recursive=False) -> Iterator['ControlFlowRegi
elif isinstance(block, ControlFlowRegion):
yield from block.all_control_flow_regions(recursive=recursive)

def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']:
def all_sdfgs_recursive(self) -> Iterator['SDFG']:
""" Iterate over this and all nested SDFGs. """
for cfg in self.all_control_flow_regions(recursive=True):
if isinstance(cfg, dace.SDFG):
Expand Down Expand Up @@ -2590,6 +2591,25 @@ def start_block(self, block_id):

@make_properties
class LoopRegion(ControlFlowRegion):
"""
A control flow region that represents a loop.
Like in traditional programming languages, a loop has a condition that is checked before each iteration.
It may have zero or more initialization statements that are executed before the first loop iteration, and zero or
more update statements that are executed after each iteration. For example, a loop with only a condition and neither
an initialization nor an update statement is equivalent to a while loop, while a loop with initialization and update
statements represents a for loop. Loops may additionally be inverted, meaning that the condition is checked after
the first iteration instead of before.
A loop region, like any other control flow region, has a single distinct entry / start block, and one or more
exit blocks. Exit blocks are blocks that have no outgoing edges or only conditional outgoing edges. Whenever an
exit block finshes executing, one iteration of the loop is completed.
Loops may have an arbitrary number of break states. Whenever a break state finishes executing, the loop is exited
immediately. A loop may additionally have an arbitrary number of continue states. Whenever a continue state finishes
executing, the next iteration of the loop is started immediately (with execution of the update statement(s), if
present).
"""

update_statement = CodeProperty(optional=True, allow_none=True, default=None,
desc='The loop update statement. May be None if the update happens elsewhere.')
Expand Down Expand Up @@ -2673,8 +2693,7 @@ def replace_dict(self, repl: Dict[str, str],
def to_json(self, parent=None):
return super().to_json(parent)

def add_node(self, node, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None):
super().add_node(node, is_start_block, is_start_state=is_start_state)
def _add_node_internal(self, node, is_continue=False, is_break=False):
if is_continue:
if is_break:
raise ValueError('Cannot set both is_continue and is_break')
Expand All @@ -2684,15 +2703,12 @@ def add_node(self, node, is_start_block=False, is_continue=False, is_break=False
raise ValueError('Cannot set both is_continue and is_break')
self.break_states.add(self.node_id(node))

def add_node(self, node, is_start_block=False, is_continue=False, is_break=False, *, is_start_state: bool = None):
super().add_node(node, is_start_block, is_start_state=is_start_state)
self._add_node_internal(node, is_continue, is_break)

def add_state(self, label=None, is_start_block=False, is_continue=False, is_break=False, *,
is_start_state: bool = None) -> SDFGState:
state = super().add_state(label, is_start_block, is_start_state=is_start_state)
if is_continue:
if is_break:
raise ValueError('Cannot set both is_continue and is_break')
self.continue_states.add(self.node_id(state))
if is_break:
if is_continue:
raise ValueError('Cannot set both is_continue and is_break')
self.break_states.add(self.node_id(state))
self._add_node_internal(state, is_continue, is_break)
return state
25 changes: 19 additions & 6 deletions dace/transformation/interstate/control_flow_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,17 @@ def expressions(cls):
return [sdutil.node_path_graph(cls.loop)]

def can_be_applied(self, graph: ControlFlowRegion, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool:
# Check that the loop initialization and update statements each only contain assignments, if the loop has any.
if self.loop.init_statement is not None:
if isinstance(self.loop.init_statement.code, list):
for stmt in self.loop.init_statement.code:
if not isinstance(stmt, astutils.ast.Assign):
return False
if self.loop.update_statement is not None:
if isinstance(self.loop.update_statement.code, list):
for stmt in self.loop.update_statement.code:
if not isinstance(stmt, astutils.ast.Assign):
return False
return True

def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]:
Expand Down Expand Up @@ -71,9 +82,10 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]:
# Add an initialization edge that initializes the loop variable if applicable.
init_edge = InterstateEdge()
if self.loop.init_statement is not None:
init_edge.assignments = {
self.loop.loop_variable: self.loop.init_statement.as_string.rpartition('=')[2].strip()
}
init_edge.assignments = {}
for stmt in self.loop.init_statement.code:
assign: astutils.ast.Assign = stmt
init_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value)
if self.loop.inverted:
parent.add_edge(init_state, internal_start, init_edge)
else:
Expand All @@ -82,9 +94,10 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG) -> Optional[int]:
# Connect the loop tail.
update_edge = InterstateEdge()
if self.loop.update_statement is not None:
update_edge.assignments = {
self.loop.loop_variable: self.loop.update_statement.as_string.rpartition('=')[2].strip()
}
update_edge.assignments = {}
for stmt in self.loop.update_statement.code:
assign: astutils.ast.Assign = stmt
update_edge.assignments[assign.targets[0].id] = astutils.unparse(assign.value)
parent.add_edge(loop_tail_state, guard_state, update_edge)

# Add condition checking edges and connect the guard state.
Expand Down
64 changes: 64 additions & 0 deletions tests/transformations/control_flow_inline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,10 +222,74 @@ def test_loop_inlining_for_continue_break():
assert len(sdfg.edges_between(state2, tail_state)) == 1


def test_loop_inlining_multi_assignments():
sdfg = dace.SDFG('inlining')
sdfg.add_symbol('j', dace.int32)
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0; j = 10 + 200 - 1',
update_expr='i = i + 1; j = j + i', inverted=False)
sdfg.add_node(loop1)
state1 = loop1.add_state('state1', is_start_block=True)
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
state3 = sdfg.add_state('state3')
sdfg.add_edge(state0, loop1, dace.InterstateEdge())
sdfg.add_edge(loop1, state3, dace.InterstateEdge())

sdutils.inline_loop_blocks(sdfg)

states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong
assert len(states) == 8
assert state0 in states
assert state1 in states
assert state2 in states
assert state3 in states

guard_state = None
init_state = None
tail_state = None
for state in sdfg.states():
if state.label == 'loop1_guard':
guard_state = state
elif state.label == 'loop1_init':
init_state = state
elif state.label == 'loop1_tail':
tail_state = state
init_edge = sdfg.edges_between(init_state, guard_state)[0]
assert 'i' in init_edge.data.assignments
assert 'j' in init_edge.data.assignments
update_edge = sdfg.edges_between(tail_state, guard_state)[0]
assert 'i' in update_edge.data.assignments
assert 'j' in update_edge.data.assignments


def test_loop_inlining_invalid_update_statement():
# Inlining should not be applied here.
sdfg = dace.SDFG('inlining')
sdfg.add_symbol('j', dace.int32)
state0 = sdfg.add_state('state0', is_start_block=True)
loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0',
update_expr='i = i + 1; j < i', inverted=False)
sdfg.add_node(loop1)
state1 = loop1.add_state('state1', is_start_block=True)
state2 = loop1.add_state('state2')
loop1.add_edge(state1, state2, dace.InterstateEdge())
state3 = sdfg.add_state('state3')
sdfg.add_edge(state0, loop1, dace.InterstateEdge())
sdfg.add_edge(loop1, state3, dace.InterstateEdge())

sdutils.inline_loop_blocks(sdfg)

nodes = sdfg.nodes()
assert len(nodes) == 3


if __name__ == '__main__':
test_loop_inlining_regular_for()
test_loop_inlining_regular_while()
test_loop_inlining_do_while()
test_loop_inlining_do_for()
test_inline_triple_nested_for()
test_loop_inlining_for_continue_break()
test_loop_inlining_multi_assignments()
test_loop_inlining_invalid_update_statement()

0 comments on commit fba7e87

Please sign in to comment.