From 073b61373a57990a08d78d1843b3bb5ae0af5d0d Mon Sep 17 00:00:00 2001
From: Philipp Schaad
Date: Sat, 12 Oct 2024 21:13:19 +0200
Subject: [PATCH] Control Flow Raising (#1657)

This PR mainly provides control flow raising passes for the new intrinsic
control flow constructs (branches and loops) in SDFGs. In addition to raising,
the state and control flow reachability passes have been adjusted to faithfully
work with the intrinsic control flow constructs. Along with the raising and
reachability passes, a few important bugfixes and a general cleanup are
included in the PR, but no other functionality is changed.

---
 dace/codegen/control_flow.py | 37 +-
 dace/codegen/targets/framecode.py | 24 +-
 dace/frontend/python/newast.py | 7 +-
 dace/frontend/python/parser.py | 2 +
 dace/sdfg/analysis/schedule_tree/treenodes.py | 15 +-
 .../analysis/writeset_underapproximation.py | 397 +++++++++---------
 dace/sdfg/propagation.py | 54 +--
 dace/sdfg/state.py | 38 +-
 dace/transformation/helpers.py | 4 +-
 .../interstate/loop_detection.py | 300 ++++++++++---
 .../transformation/interstate/loop_lifting.py | 99 +++++
 dace/transformation/pass_pipeline.py | 3 +-
 .../passes/analysis/__init__.py | 1 +
 .../passes/{ => analysis}/analysis.py | 141 +++++--
 .../passes/analysis/loop_analysis.py | 116 +++++
 .../simplification/control_flow_raising.py | 96 +++++
 dace/transformation/subgraph/expansion.py | 9 +-
 dace/transformation/subgraph/helpers.py | 17 +-
 .../control_flow_raising_test.py | 98 +++++
 .../writeset_underapproximation_test.py | 102 +++--
 tests/sdfg/conditional_region_test.py | 50 +--
 tests/sdfg/loop_region_test.py | 51 +++
 .../interstate/loop_lifting_test.py | 217 ++++++++++
 tests/transformations/loop_detection_test.py | 51 ++-
 24 files changed, 1468 insertions(+), 461 deletions(-)
 create mode 100644 dace/transformation/interstate/loop_lifting.py
 create mode 100644 dace/transformation/passes/analysis/__init__.py
 rename dace/transformation/passes/{ => analysis}/analysis.py (81%)
 create mode 100644 dace/transformation/passes/analysis/loop_analysis.py
 create mode 100644 dace/transformation/passes/simplification/control_flow_raising.py
 create mode 100644 tests/passes/simplification/control_flow_raising_test.py
 create mode 100644 tests/transformations/interstate/loop_lifting_test.py

diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py
index 7701a19ec2..f5559984e7 100644
--- a/dace/codegen/control_flow.py
+++ b/dace/codegen/control_flow.py
@@ -275,9 +275,13 @@ def as_cpp(self, codegen, symbols) -> str:
             expr += elem.as_cpp(codegen, symbols)
             # In a general block, emit transitions and assignments after each individual block or region.
             if isinstance(elem, BasicCFBlock) or (isinstance(elem, RegionBlock) and elem.region):
-                cfg = elem.state.parent_graph if isinstance(elem, BasicCFBlock) else elem.region.parent_graph
+                if isinstance(elem, BasicCFBlock):
+                    g_elem = elem.state
+                else:
+                    g_elem = elem.region
+                cfg = g_elem.parent_graph
                 sdfg = cfg if isinstance(cfg, SDFG) else cfg.sdfg
-                out_edges = cfg.out_edges(elem.state) if isinstance(elem, BasicCFBlock) else cfg.out_edges(elem.region)
+                out_edges = cfg.out_edges(g_elem)
                 for j, e in enumerate(out_edges):
                     if e not in self.gotos_to_ignore:
                         # Skip gotos to immediate successors
@@ -532,26 +536,27 @@ def as_cpp(self, codegen, symbols) -> str:
         expr = ''

         if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable:
-            # Initialize to either "int i = 0" or "i = 0" depending on whether the type has been defined.
- defined_vars = codegen.dispatcher.defined_vars - if not defined_vars.has(self.loop.loop_variable): - try: - init = f'{symbols[self.loop.loop_variable]} ' - except KeyError: - init = 'auto ' - symbols[self.loop.loop_variable] = None - init += unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) init = init.strip(';') update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) update = update.strip(';') if self.loop.inverted: - expr += f'{init};\n' - expr += 'do {\n' - expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) - expr += f'{update};\n' - expr += f'\n}} while({cond});\n' + if self.loop.update_before_condition: + expr += f'{init};\n' + expr += 'do {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'{update};\n' + expr += f'}} while({cond});\n' + else: + expr += f'{init};\n' + expr += 'while (1) {\n' + expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) + expr += f'if (!({cond}))\n' + expr += 'break;\n' + expr += f'{update};\n' + expr += '}\n' else: expr += f'for ({init}; {cond}; {update}) {{\n' expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 488c1c7fbd..d71ea40fee 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -15,12 +15,14 @@ from dace.codegen.prettycode import CodeIOStream from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator +from dace.codegen.tools.type_inference import infer_expr_type +from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion -from dace.transformation.passes.analysis import StateReachability +from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.transformation.passes.analysis import StateReachability, loop_analysis def _get_or_eval_sdfg_first_arg(func, sdfg): @@ -916,6 +918,24 @@ def generate_code(self, interstate_symbols.update(symbols) global_symbols.update(symbols) + if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: + init_assignment = cfr.init_statement.code[0] + update_assignment = cfr.update_statement.code[0] + if isinstance(init_assignment, astutils.ast.Assign): + init_assignment = init_assignment.value + if isinstance(update_assignment, astutils.ast.Assign): + update_assignment = update_assignment.value + if not cfr.loop_variable in interstate_symbols: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type + if not cfr.loop_variable in global_symbols: + global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] + for isvarName, isvarType in interstate_symbols.items(): if isvarType is None: raise TypeError(f'Type inference failed for symbol {isvarName}') diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 
0d40e13282..cacf15d785 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2565,8 +2565,7 @@ def visit_If(self, node: ast.If): self._on_block_added(cond_block) if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) - cond_block.branches.append((CodeBlock(cond), if_body)) - if_body.parent_graph = self.cfg_target + cond_block.add_branch(CodeBlock(cond), if_body) # Visit recursively self._recursive_visit(node.body, 'if', node.lineno, if_body, False) @@ -2575,9 +2574,7 @@ def visit_If(self, node: ast.If): if len(node.orelse) > 0: else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) - #cond_block.branches.append((CodeBlock(cond_else), else_body)) - cond_block.branches.append((None, else_body)) - else_body.parent_graph = self.cfg_target + cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index b0ef56907f..d99be1265d 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -499,6 +499,8 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF sdutils.inline_control_flow_regions(nsdfg) sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.reset_cfg_list() + # Apply simplification pass automatically if not cached and (simplify == True or (simplify is None and Config.get_bool('optimizer', 'automatic_simplification'))): diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 619b71b770..3b447fa15a 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -162,10 +162,17 @@ def as_string(self, indent: int = 0): loop = self.header.loop if loop.update_statement and loop.init_statement and loop.loop_variable: if loop.inverted: - pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' - header = indent * INDENTATION + 'do:\n' - pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' - footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' + if loop.update_before_condition: + pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' + header = indent * INDENTATION + 'do:\n' + pre_footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' + footer = indent * INDENTATION + f'while {loop.loop_condition.as_string}' + else: + pre_header = indent * INDENTATION + f'{loop.init_statement.as_string}\n' + header = indent * INDENTATION + 'while True:\n' + pre_footer = (indent + 1) * INDENTATION + f'if (not {loop.loop_condition.as_string}):\n' + pre_footer += (indent + 2) * INDENTATION + 'break\n' + footer = (indent + 1) * INDENTATION + f'{loop.update_statement.as_string}\n' return pre_header + header + super().as_string(indent) + '\n' + pre_footer + footer else: result = (indent * INDENTATION + diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index bfd5f4cb00..a0f84e93a6 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -1,42 +1,36 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ -Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in -an SDFG. 
+Pass derived from ``propagation.py`` that under-approximates write-sets of for-loops and Maps in an SDFG. """ -from collections import defaultdict import copy +from dataclasses import dataclass, field import itertools +import sys import warnings -from typing import Any, Dict, List, Set, Tuple, Type, Union +from collections import defaultdict +from typing import Dict, List, Set, Tuple, Union + +if sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + import sympy import dace +from dace import SDFG, Memlet, data, dtypes, registry, subsets, symbolic +from dace.sdfg import SDFGState +from dace.sdfg import graph +from dace.sdfg import graph as gr +from dace.sdfg import nodes, scope +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.nodes import AccessNode, NestedSDFG +from dace.sdfg.state import LoopRegion from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -from dace.transformation.pass_pipeline import Modifies, Pass -from dace import registry, subsets, symbolic, dtypes, data, SDFG, Memlet -from dace.sdfg.nodes import NestedSDFG, AccessNode -from dace.sdfg import nodes, SDFGState, graph as gr -from dace.sdfg.analysis import cfg from dace.transformation import pass_pipeline as ppl -from dace.sdfg import graph -from dace.sdfg import scope - -# dictionary mapping each edge to a copy of the memlet of that edge with its write set -# underapproximated -approximation_dict: Dict[graph.Edge, Memlet] = {} -# dictionary that maps loop headers to "border memlets" that are written to in the -# corresponding loop -loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] = {} -# dictionary containing information about the for loops in the SDFG -loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, - List[SDFGState], str, subsets.Range]] = {} -# dictionary mapping each nested SDFG to the iteration variables surrounding it -iteration_variables: Dict[SDFG, Set[str]] = {} -# dictionary mapping each state to the iteration variables surrounding it -# (including the ones from surrounding SDFGs) -ranges_per_state: Dict[SDFGState, - Dict[str, subsets.Range]] = defaultdict(lambda: {}) +from dace.transformation import transformation +from dace.transformation.pass_pipeline import Modifies @registry.make_registry @@ -81,7 +75,7 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): # Return False if iteration variable appears in multiple dimensions # or if two iteration variables appear in the same dimension - if not self._iteration_variables_appear_multiple_times(data_dims, expressions, other_params, params): + if not self._iteration_variables_appear_only_once(data_dims, expressions, other_params, params): return False node_range = self._make_range(node_range) @@ -89,27 +83,25 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for dim in range(data_dims): dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].expr) - elif isinstance(expr[dim], tuple): - dexprs.append( - (expr[dim][0].expr if isinstance(expr[dim][0], symbolic.SymExpr) else - expr[dim][0], expr[dim][1].expr if isinstance( - expr[dim][1], symbolic.SymExpr) else expr[dim][1], expr[dim][2].expr - if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.expr) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].expr if isinstance(expr_dim[0], 
symbolic.SymExpr) else expr_dim[0], + expr_dim[1].expr if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].expr if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableUnderapproximationMemletPattern.extensions().keys(): smpattern = pattern_class() - if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, - data_dims): + if smpattern.can_be_applied(dexprs, variable_context, node_range, orig_edges, dim, data_dims): self.patterns_per_dim[dim] = smpattern break return None not in self.patterns_per_dim - def _iteration_variables_appear_multiple_times(self, data_dims, expressions, other_params, params): + def _iteration_variables_appear_only_once(self, data_dims, expressions, other_params, params): for expr in expressions: for param in params: occured_before = False @@ -146,8 +138,7 @@ def _iteration_variables_appear_multiple_times(self, data_dims, expressions, oth def _make_range(self, node_range): return subsets.Range([(rb.expr if isinstance(rb, symbolic.SymExpr) else rb, - re.expr if isinstance( - re, symbolic.SymExpr) else re, + re.expr if isinstance(re, symbolic.SymExpr) else re, rs.expr if isinstance(rs, symbolic.SymExpr) else rs) for rb, re, rs in node_range]) @@ -160,19 +151,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].expr) - elif isinstance(expr[i], tuple): - dexprs.append(( - expr[i][0].expr if isinstance( - expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].expr if isinstance( - expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].expr if isinstance( - expr[i][2], symbolic.SymExpr) else expr[i][2], - expr.tile_sizes[i])) + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.expr) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].expr if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].expr if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].expr if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], + expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, node_range) @@ -417,7 +405,7 @@ def _find_unconditionally_executed_states(sdfg: SDFG) -> Set[SDFGState]: sdfg.add_edge(sink_node, dummy_sink, dace.sdfg.InterstateEdge()) # get all the nodes that are executed unconditionally in the state-machine a.k.a nodes # that dominate the sink states - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) states = dominators[dummy_sink] # remove dummy state sdfg.remove_node(dummy_sink) @@ -689,21 +677,44 @@ def _merge_subsets(subset_a: subsets.Subset, subset_b: subsets.Subset) -> subset return subset_b +@dataclass +class UnderapproximateWritesDict: + approximation: Dict[graph.Edge, Memlet] = field(default_factory=dict) + loop_approximation: Dict[SDFGState, Dict[str, Memlet]] = field(default_factory=dict) + loops: Dict[SDFGState, + Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = field(default_factory=dict) + + +@transformation.experimental_cfg_block_compatible class UnderapproximateWrites(ppl.Pass): + # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. 
+ approximation_dict: Dict[graph.Edge, Memlet] + # Dictionary that maps loop headers to "border memlets" that are written to in the corresponding loop. + loop_write_dict: Dict[SDFGState, Dict[str, Memlet]] + # Dictionary containing information about the for loops in the SDFG. + loop_dict: Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] + # Dictionary mapping each nested SDFG to the iteration variables surrounding it. + iteration_variables: Dict[SDFG, Set[str]] + # Mapping of state to the iteration variables surrounding them, including the ones from surrounding SDFGs. + ranges_per_state: Dict[SDFGState, Dict[str, subsets.Range]] + + def __init__(self): + super().__init__() + self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) + def modifies(self) -> Modifies: - return ppl.Modifies.Everything + return ppl.Modifies.States def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - - def apply_pass( - self, sdfg: dace.SDFG, pipeline_results: Dict[str, Any] - ) -> Dict[str, Union[ - Dict[graph.Edge, Memlet], - Dict[SDFGState, Dict[str, Memlet]], - Dict[SDFGState, Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]]]]: + # If anything was modified, reapply. + return modified & ppl.Modifies.Everything + + def apply_pass(self, top_sdfg: dace.SDFG, _) -> Dict[int, UnderapproximateWritesDict]: """ Applies the pass to the given SDFG. @@ -725,55 +736,71 @@ def apply_pass( :notes: The only modification this pass performs on the SDFG is splitting interstate edges. """ - # clear the global dictionaries - approximation_dict.clear() - loop_write_dict.clear() - loop_dict.clear() - iteration_variables.clear() - ranges_per_state.clear() - - # fill the approximation dictionary with the original edges as keys and the edges with the - # approximated memlets as values - for (edge, parent) in sdfg.all_edges_recursive(): - if isinstance(parent, SDFGState): - approximation_dict[edge] = copy.deepcopy(edge.data) - if not isinstance(approximation_dict[edge].subset, - subsets.SubsetUnion) and approximation_dict[edge].subset: - approximation_dict[edge].subset = subsets.SubsetUnion( - [approximation_dict[edge].subset]) - if not isinstance(approximation_dict[edge].dst_subset, - subsets.SubsetUnion) and approximation_dict[edge].dst_subset: - approximation_dict[edge].dst_subset = subsets.SubsetUnion( - [approximation_dict[edge].dst_subset]) - if not isinstance(approximation_dict[edge].src_subset, - subsets.SubsetUnion) and approximation_dict[edge].src_subset: - approximation_dict[edge].src_subset = subsets.SubsetUnion( - [approximation_dict[edge].src_subset]) - - self._underapproximate_writes_sdfg(sdfg) - - # Replace None with empty SubsetUnion in each Memlet - for entry in approximation_dict.values(): - if entry.subset is None: - entry.subset = subsets.SubsetUnion([]) - return { - "approximation": approximation_dict, - "loop_approximation": loop_write_dict, - "loops": loop_dict - } + result = defaultdict(lambda: UnderapproximateWritesDict()) + + for sdfg in top_sdfg.all_sdfgs_recursive(): + # Clear the global dictionaries. 
+ self.approximation_dict = {} + self.loop_write_dict = {} + self.loop_dict = {} + self.iteration_variables = {} + self.ranges_per_state = defaultdict(lambda: {}) + + # fill the approximation dictionary with the original edges as keys and the edges with the + # approximated memlets as values + for (edge, parent) in sdfg.all_edges_recursive(): + if isinstance(parent, SDFGState): + self.approximation_dict[edge] = copy.deepcopy(edge.data) + if not isinstance(self.approximation_dict[edge].subset, + subsets.SubsetUnion) and self.approximation_dict[edge].subset: + self.approximation_dict[edge].subset = subsets.SubsetUnion([ + self.approximation_dict[edge].subset + ]) + if not isinstance(self.approximation_dict[edge].dst_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].dst_subset: + self.approximation_dict[edge].dst_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].dst_subset + ]) + if not isinstance(self.approximation_dict[edge].src_subset, + subsets.SubsetUnion) and self.approximation_dict[edge].src_subset: + self.approximation_dict[edge].src_subset = subsets.SubsetUnion([ + self.approximation_dict[edge].src_subset + ]) + + self._underapproximate_writes_sdfg(sdfg) + + # Replace None with empty SubsetUnion in each Memlet + for entry in self.approximation_dict.values(): + if entry.subset is None: + entry.subset = subsets.SubsetUnion([]) + + result[sdfg.cfg_id].approximation = self.approximation_dict + result[sdfg.cfg_id].loop_approximation = self.loop_write_dict + result[sdfg.cfg_id].loops = self.loop_dict + + return result def _underapproximate_writes_sdfg(self, sdfg: SDFG): """ Underapproximates write-sets of loops, maps and nested SDFGs in the given SDFG. """ from dace.transformation.helpers import split_interstate_edges + from dace.transformation.passes.analysis import loop_analysis split_interstate_edges(sdfg) loops = self._find_for_loops(sdfg) - loop_dict.update(loops) + self.loop_dict.update(loops) + + for region in sdfg.all_control_flow_regions(): + if isinstance(region, LoopRegion): + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + for state in region.all_states(): + self.ranges_per_state[state][region.loop_variable] = subsets.Range([(start, stop, stride)]) - for state in sdfg.nodes(): - self._underapproximate_writes_state(sdfg, state) + for state in region.all_states(): + self._underapproximate_writes_state(sdfg, state) self._underapproximate_writes_loops(loops, sdfg) @@ -792,8 +819,8 @@ def _find_for_loops(self, """ # We import here to avoid cyclic imports. 
- from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop # dictionary mapping loop headers to beginstate, loopstates, looprange identified_loops = {} @@ -885,13 +912,12 @@ def _find_for_loops(self, sources=[begin], condition=lambda _, child: child != guard) - if itvar not in ranges_per_state[begin]: + if itvar not in self.ranges_per_state[begin]: for loop_state in loop_states: - ranges_per_state[loop_state][itervar] = subsets.Range([ - rng]) + self.ranges_per_state[loop_state][itervar] = subsets.Range([rng]) loop_state_list.append(loop_state) - ranges_per_state[guard][itervar] = subsets.Range([rng]) + self.ranges_per_state[guard][itervar] = subsets.Range([rng]) identified_loops[guard] = (begin, last_loop_state, loop_state_list, itvar, subsets.Range([rng])) @@ -934,8 +960,11 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # approximation_dict # First, propagate nested SDFGs in a bottom-up fashion + dnodes: Set[nodes.AccessNode] = set() for node in state.nodes(): - if isinstance(node, nodes.NestedSDFG): + if isinstance(node, AccessNode): + dnodes.add(node) + elif isinstance(node, nodes.NestedSDFG): self._find_live_iteration_variables(node, sdfg, state) # Propagate memlets inside the nested SDFG. @@ -947,6 +976,15 @@ def _underapproximate_writes_state(self, sdfg: SDFG, state: SDFGState): # Process scopes from the leaves upwards self._underapproximate_writes_scope(sdfg, state, state.scope_leaves()) + # Make sure any scalar writes are also added if they have not been processed yet. + for dn in dnodes: + desc = sdfg.data(dn.data) + if isinstance(desc, data.Scalar) or (isinstance(desc, data.Array) and desc.total_size == 1): + for iedge in state.in_edges(dn): + if not iedge in self.approximation_dict: + self.approximation_dict[iedge] = copy.deepcopy(iedge.data) + self.approximation_dict[iedge]._edge = iedge + def _find_live_iteration_variables(self, nsdfg: nodes.NestedSDFG, sdfg: SDFG, @@ -963,15 +1001,14 @@ def symbol_map(mapping, symbol): return None map_iteration_variables = _collect_iteration_variables(state, nsdfg) - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - state_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + state_iteration_variables = self.ranges_per_state[state].keys() iteration_variables_local = (map_iteration_variables | sdfg_iteration_variables | state_iteration_variables) mapped_iteration_variables = set( map(lambda x: symbol_map(nsdfg.symbol_mapping, x), iteration_variables_local)) if mapped_iteration_variables: - iteration_variables[nsdfg.sdfg] = mapped_iteration_variables + self.iteration_variables[nsdfg.sdfg] = mapped_iteration_variables def _underapproximate_writes_nested_sdfg( self, @@ -1025,12 +1062,11 @@ def _init_border_memlet(template_memlet: Memlet, # Collect all memlets belonging to this access node memlets = [] for edge in edges: - inside_memlet = approximation_dict[edge] + inside_memlet = self.approximation_dict[edge] memlets.append(inside_memlet) # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - inside_memlet, node.label) + border_memlet = _init_border_memlet(inside_memlet, node.label) # Given all of this access nodes' memlets union all the subsets to one SubsetUnion if 
len(memlets) > 0: @@ -1042,18 +1078,16 @@ def _init_border_memlet(template_memlet: Memlet, border_memlet.subset, subset) # collect the memlets for each loop in the NSDFG - if state in loop_write_dict: - for node_label, loop_memlet in loop_write_dict[state].items(): + if state in self.loop_write_dict: + for node_label, loop_memlet in self.loop_write_dict[state].items(): if node_label not in border_memlets: continue border_memlet = border_memlets[node_label] # initialize border memlet if it does not exist already if border_memlet is None: - border_memlet = _init_border_memlet( - loop_memlet, node_label) + border_memlet = _init_border_memlet(loop_memlet, node_label) # compute the union of the ranges to merge the subsets. - border_memlet.subset = _merge_subsets( - border_memlet.subset, loop_memlet.subset) + border_memlet.subset = _merge_subsets(border_memlet.subset, loop_memlet.subset) # Make sure any potential NSDFG symbol mapping is correctly reversed # when propagating out. @@ -1068,17 +1102,16 @@ def _init_border_memlet(template_memlet: Memlet, # Propagate the inside 'border' memlets outside the SDFG by # offsetting, and unsqueezing if necessary. for edge in parent_state.out_edges(nsdfg_node): - out_memlet = approximation_dict[edge] + out_memlet = self.approximation_dict[edge] if edge.src_conn in border_memlets: internal_memlet = border_memlets[edge.src_conn] if internal_memlet is None: out_memlet.subset = None out_memlet.dst_subset = None - approximation_dict[edge] = out_memlet + self.approximation_dict[edge] = out_memlet continue - out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, - nsdfg_node) - approximation_dict[edge] = out_memlet + out_memlet = _unsqueeze_memlet_subsetunion(internal_memlet, out_memlet, parent_sdfg, nsdfg_node) + self.approximation_dict[edge] = out_memlet def _underapproximate_writes_loop(self, sdfg: SDFG, @@ -1099,9 +1132,7 @@ def _underapproximate_writes_loop(self, propagate_memlet_loop will be called recursively on the outermost loopheaders """ - def _init_border_memlet(template_memlet: Memlet, - node_label: str - ): + def _init_border_memlet(template_memlet: Memlet, node_label: str): ''' Creates a Memlet with the same data as the template_memlet, stores it in the border_memlets dictionary and returns it. 
@@ -1111,8 +1142,7 @@ def _init_border_memlet(template_memlet: Memlet, border_memlets[node_label] = border_memlet return border_memlet - def filter_subsets(itvar: str, itrange: subsets.Range, - memlet: Memlet) -> List[subsets.Subset]: + def filter_subsets(itvar: str, itrange: subsets.Range, memlet: Memlet) -> List[subsets.Subset]: # helper method that filters out subsets that do not depend on the iteration variable # if the iteration range is symbolic @@ -1134,7 +1164,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, if rng.num_elements() == 0: return # make sure there is no break out of the loop - dominators = cfg.all_dominators(sdfg) + dominators = cfg_analysis.all_dominators(sdfg) if any(begin not in dominators[s] and not begin is s for s in loop_states): return border_memlets = defaultdict(None) @@ -1159,7 +1189,7 @@ def filter_subsets(itvar: str, itrange: subsets.Range, # collect all the subsets of the incoming memlets for the current access node for edge in edges: - inside_memlet = copy.copy(approximation_dict[edge]) + inside_memlet = copy.copy(self.approximation_dict[edge]) # filter out subsets that could become empty depending on assignments # of symbols filtered_subsets = filter_subsets( @@ -1177,35 +1207,27 @@ def filter_subsets(itvar: str, itrange: subsets.Range, self._underapproximate_writes_loop_subset(sdfg, memlets, border_memlet, sdfg.arrays[node.label], itvar, rng) - if state not in loop_write_dict: + if state not in self.loop_write_dict: continue # propagate the border memlets of nested loop - for node_label, other_border_memlet in loop_write_dict[state].items(): + for node_label, other_border_memlet in self.loop_write_dict[state].items(): # filter out subsets that could become empty depending on symbol assignments - filtered_subsets = filter_subsets( - itvar, rng, other_border_memlet) + filtered_subsets = filter_subsets(itvar, rng, other_border_memlet) if not filtered_subsets: continue - other_border_memlet.subset = subsets.SubsetUnion( - filtered_subsets) + other_border_memlet.subset = subsets.SubsetUnion(filtered_subsets) border_memlet = border_memlets.get(node_label) if border_memlet is None: - border_memlet = _init_border_memlet( - other_border_memlet, node_label) + border_memlet = _init_border_memlet(other_border_memlet, node_label) self._underapproximate_writes_loop_subset(sdfg, [other_border_memlet], border_memlet, sdfg.arrays[node_label], itvar, rng) - loop_write_dict[loop_header] = border_memlets + self.loop_write_dict[loop_header] = border_memlets - def _underapproximate_writes_loop_subset(self, - sdfg: dace.SDFG, - memlets: List[Memlet], - dst_memlet: Memlet, - arr: dace.data.Array, - itvar: str, - rng: subsets.Subset, + def _underapproximate_writes_loop_subset(self, sdfg: dace.SDFG, memlets: List[Memlet], dst_memlet: Memlet, + arr: dace.data.Array, itvar: str, rng: subsets.Subset, loop_nest_itvars: Union[Set[str], None] = None): """ Helper function that takes a list of (border) memlets, propagates them out of a @@ -1223,16 +1245,11 @@ def _underapproximate_writes_loop_subset(self, if len(memlets) > 0: params = [itvar] # get all the other iteration variables surrounding this memlet - surrounding_itvars = iteration_variables[sdfg] if sdfg in iteration_variables else set( - ) + surrounding_itvars = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() if loop_nest_itvars: surrounding_itvars |= loop_nest_itvars - subset = self._underapproximate_subsets(memlets, - arr, - params, - rng, - use_dst=True, + subset = 
self._underapproximate_subsets(memlets, arr, params, rng, use_dst=True, surrounding_itvars=surrounding_itvars).subset if subset is None or len(subset.subset_list) == 0: @@ -1240,9 +1257,7 @@ def _underapproximate_writes_loop_subset(self, # compute the union of the ranges to merge the subsets. dst_memlet.subset = _merge_subsets(dst_memlet.subset, subset) - def _underapproximate_writes_scope(self, - sdfg: SDFG, - state: SDFGState, + def _underapproximate_writes_scope(self, sdfg: SDFG, state: SDFGState, scopes: Union[scope.ScopeTree, List[scope.ScopeTree]]): """ Propagate memlets from the given scopes outwards. @@ -1253,8 +1268,7 @@ def _underapproximate_writes_scope(self, """ # for each map scope find the iteration variables of surrounding maps - surrounding_map_vars: Dict[scope.ScopeTree, - Set[str]] = _collect_itvars_scope(scopes) + surrounding_map_vars: Dict[scope.ScopeTree, Set[str]] = _collect_itvars_scope(scopes) if isinstance(scopes, scope.ScopeTree): scopes_to_process = [scopes] else: @@ -1272,8 +1286,7 @@ def _underapproximate_writes_scope(self, sdfg, state, surrounding_map_vars) - self._underapproximate_writes_node( - state, scope_node.exit, surrounding_iteration_variables) + self._underapproximate_writes_node(state, scope_node.exit, surrounding_iteration_variables) # Add parent to next frontier next_scopes.add(scope_node.parent) scopes_to_process = next_scopes @@ -1286,9 +1299,8 @@ def _collect_iteration_variables_scope_node(self, surrounding_map_vars: Dict[scope.ScopeTree, Set[str]]) -> Set[str]: map_iteration_variables = surrounding_map_vars[ scope_node] if scope_node in surrounding_map_vars else set() - sdfg_iteration_variables = iteration_variables[ - sdfg] if sdfg in iteration_variables else set() - loop_iteration_variables = ranges_per_state[state].keys() + sdfg_iteration_variables = self.iteration_variables[sdfg] if sdfg in self.iteration_variables else set() + loop_iteration_variables = self.ranges_per_state[state].keys() surrounding_iteration_variables = (map_iteration_variables | sdfg_iteration_variables | loop_iteration_variables) @@ -1308,12 +1320,8 @@ def _underapproximate_writes_node(self, :param surrounding_itvars: Iteration variables that surround the map scope """ if isinstance(node, nodes.EntryNode): - internal_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] - external_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] + internal_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] + external_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] def geticonn(e): return e.src_conn[4:] @@ -1323,12 +1331,8 @@ def geteconn(e): use_dst = False else: - internal_edges = [ - e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_') - ] - external_edges = [ - e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_') - ] + internal_edges = [e for e in dfg_state.in_edges(node) if e.dst_conn and e.dst_conn.startswith('IN_')] + external_edges = [e for e in dfg_state.out_edges(node) if e.src_conn and e.src_conn.startswith('OUT_')] def geticonn(e): return e.dst_conn[3:] @@ -1339,21 +1343,17 @@ def geteconn(e): use_dst = True for edge in external_edges: - if approximation_dict[edge].is_empty(): + if self.approximation_dict[edge].is_empty(): new_memlet = Memlet() else: internal_edge = next( e for e in internal_edges if geticonn(e) == geteconn(edge)) 
- aligned_memlet = self._align_memlet( - dfg_state, internal_edge, dst=use_dst) - new_memlet = self._underapproximate_memlets(dfg_state, - aligned_memlet, - node, - True, - connector=geteconn( - edge), + aligned_memlet = self._align_memlet(dfg_state, internal_edge, dst=use_dst) + new_memlet = self._underapproximate_memlets(dfg_state, aligned_memlet, node, True, + connector=geteconn(edge), surrounding_itvars=surrounding_itvars) - approximation_dict[edge] = new_memlet + new_memlet._edge = edge + self.approximation_dict[edge] = new_memlet def _align_memlet(self, state: SDFGState, @@ -1373,16 +1373,16 @@ def _align_memlet(self, is_src = edge.data._is_data_src # Memlet is already aligned if is_src is None or (is_src and not dst) or (not is_src and dst): - res = approximation_dict[edge] + res = self.approximation_dict[edge] return res # Data<->Code memlets always have one data container mpath = state.memlet_path(edge) if not isinstance(mpath[0].src, AccessNode) or not isinstance(mpath[-1].dst, AccessNode): - return approximation_dict[edge] + return self.approximation_dict[edge] # Otherwise, find other data container - result = copy.deepcopy(approximation_dict[edge]) + result = copy.deepcopy(self.approximation_dict[edge]) if dst: node = mpath[-1].dst else: @@ -1390,8 +1390,8 @@ def _align_memlet(self, # Fix memlet fields result.data = node.data - result.subset = approximation_dict[edge].other_subset - result.other_subset = approximation_dict[edge].subset + result.subset = self.approximation_dict[edge].other_subset + result.other_subset = self.approximation_dict[edge].subset result._is_data_src = not is_src return result @@ -1448,9 +1448,9 @@ def _underapproximate_memlets(self, # and union their subsets if union_inner_edges: aggdata = [ - approximation_dict[e] + self.approximation_dict[e] for e in neighboring_edges - if approximation_dict[e].data == memlet.data and approximation_dict[e] != memlet + if self.approximation_dict[e].data == memlet.data and self.approximation_dict[e] != memlet ] else: aggdata = [] @@ -1459,8 +1459,7 @@ def _underapproximate_memlets(self, if arr is None: if memlet.data not in sdfg.arrays: - raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % - memlet.data) + raise KeyError('Data descriptor (Array, Stream) "%s" not defined in SDFG.' % memlet.data) # FIXME: A memlet alone (without an edge) cannot figure out whether it is data<->data or data<->code # so this test cannot be used diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1c038dd2e4..f62bb6eb58 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -4,21 +4,22 @@ from internal memory accesses and scope ranges). 
""" -from collections import deque import copy -from dace.symbolic import issymbolic, pystr_to_symbolic, simplify -import itertools import functools +import itertools +import warnings +from collections import deque +from typing import List, Set + import sympy -from sympy import ceiling, Symbol +from sympy import Symbol, ceiling from sympy.concrete.summations import Sum -import warnings -import networkx as nx -from dace import registry, subsets, symbolic, dtypes, data +from dace import data, dtypes, registry, subsets, symbolic from dace.memlet import Memlet -from dace.sdfg import nodes, graph as gr -from typing import List, Set +from dace.sdfg import graph as gr +from dace.sdfg import nodes +from dace.symbolic import issymbolic, pystr_to_symbolic, simplify @registry.make_registry @@ -61,17 +62,17 @@ def can_be_applied(self, expressions, variable_context, node_range, orig_edges): for rb, re, rs in node_range]) for dim in range(data_dims): - dexprs = [] for expr in expressions: - if isinstance(expr[dim], symbolic.SymExpr): - dexprs.append(expr[dim].approx) - elif isinstance(expr[dim], tuple): - dexprs.append((expr[dim][0].approx if isinstance(expr[dim][0], symbolic.SymExpr) else expr[dim][0], - expr[dim][1].approx if isinstance(expr[dim][1], symbolic.SymExpr) else expr[dim][1], - expr[dim][2].approx if isinstance(expr[dim][2], symbolic.SymExpr) else expr[dim][2])) + expr_dim = expr[dim] + if isinstance(expr_dim, symbolic.SymExpr): + dexprs.append(expr_dim.approx) + elif isinstance(expr_dim, tuple): + dexprs.append((expr_dim[0].approx if isinstance(expr_dim[0], symbolic.SymExpr) else expr_dim[0], + expr_dim[1].approx if isinstance(expr_dim[1], symbolic.SymExpr) else expr_dim[1], + expr_dim[2].approx if isinstance(expr_dim[2], symbolic.SymExpr) else expr_dim[2])) else: - dexprs.append(expr[dim]) + dexprs.append(expr_dim) for pattern_class in SeparableMemletPattern.extensions().keys(): smpattern = pattern_class() @@ -93,15 +94,16 @@ def propagate(self, array, expressions, node_range): dexprs = [] for expr in expressions: - if isinstance(expr[i], symbolic.SymExpr): - dexprs.append(expr[i].approx) - elif isinstance(expr[i], tuple): - dexprs.append((expr[i][0].approx if isinstance(expr[i][0], symbolic.SymExpr) else expr[i][0], - expr[i][1].approx if isinstance(expr[i][1], symbolic.SymExpr) else expr[i][1], - expr[i][2].approx if isinstance(expr[i][2], symbolic.SymExpr) else expr[i][2], + expr_i = expr[i] + if isinstance(expr_i, symbolic.SymExpr): + dexprs.append(expr_i.approx) + elif isinstance(expr_i, tuple): + dexprs.append((expr_i[0].approx if isinstance(expr_i[0], symbolic.SymExpr) else expr_i[0], + expr_i[1].approx if isinstance(expr_i[1], symbolic.SymExpr) else expr_i[1], + expr_i[2].approx if isinstance(expr_i[2], symbolic.SymExpr) else expr_i[2], expr.tile_sizes[i])) else: - dexprs.append(expr[i]) + dexprs.append(expr_i) result[i] = smpattern.propagate(array, dexprs, overapprox_range) @@ -569,8 +571,8 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): """ # We import here to avoid cyclic imports. - from dace.transformation.interstate.loop_detection import find_for_loop from dace.sdfg import utils as sdutils + from dace.transformation.interstate.loop_detection import find_for_loop condition_edges = {} @@ -739,8 +741,8 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: # We import here to avoid cyclic imports. 
from dace.sdfg import InterstateEdge - from dace.transformation.helpers import split_interstate_edges from dace.sdfg.analysis import cfg + from dace.transformation.helpers import split_interstate_edges # Reset the state edge annotations (which may have changed due to transformations) reset_state_annotations(sdfg) diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 8d443e6beb..2ae6109b31 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2987,35 +2987,52 @@ class LoopRegion(ControlFlowRegion): inverted = Property(dtype=bool, default=False, desc='If True, the loop condition is checked after the first iteration.') + update_before_condition = Property(dtype=bool, + default=True, + desc='If False, the loop condition is checked before the update statement is' + + ' executed. This only applies to inverted loops, turning them from a typical ' + + 'do-while style into a while(true) with a break before the update (at the end ' + + 'of an iteration) if the condition no longer holds.') loop_variable = Property(dtype=str, default='', desc='The loop variable, if given') def __init__(self, label: str, - condition_expr: Optional[str] = None, + condition_expr: Optional[Union[str, CodeBlock]] = None, loop_var: Optional[str] = None, - initialize_expr: Optional[str] = None, - update_expr: Optional[str] = None, + initialize_expr: Optional[Union[str, CodeBlock]] = None, + update_expr: Optional[Union[str, CodeBlock]] = None, inverted: bool = False, - sdfg: Optional['SDFG'] = None): + sdfg: Optional['SDFG'] = None, + update_before_condition = True): super(LoopRegion, self).__init__(label, sdfg) if initialize_expr is not None: - self.init_statement = CodeBlock(initialize_expr) + if isinstance(initialize_expr, CodeBlock): + self.init_statement = initialize_expr + else: + self.init_statement = CodeBlock(initialize_expr) else: self.init_statement = None if condition_expr: - self.loop_condition = CodeBlock(condition_expr) + if isinstance(condition_expr, CodeBlock): + self.loop_condition = condition_expr + else: + self.loop_condition = CodeBlock(condition_expr) else: self.loop_condition = CodeBlock('True') if update_expr is not None: - self.update_statement = CodeBlock(update_expr) + if isinstance(update_expr, CodeBlock): + self.update_statement = update_expr + else: + self.update_statement = CodeBlock(update_expr) else: self.update_statement = None self.loop_variable = loop_var or '' self.inverted = inverted + self.update_before_condition = update_before_condition def inline(self) -> Tuple[bool, Any]: """ @@ -3234,7 +3251,12 @@ def __repr__(self) -> str: @property def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: return self._branches - + + def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): + self._branches.append([condition, branch]) + branch.parent_graph = self.parent_graph + branch.sdfg = self.sdfg + def nodes(self) -> List['ControlFlowBlock']: return [node for _, node in self._branches if node is not None] diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 74a3d2ee12..6ca4602079 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -379,7 +379,7 @@ def nest_state_subgraph(sdfg: SDFG, SDFG. :raise ValueError: The subgraph is contained in more than one scope. 
""" - if state.parent != sdfg: + if state.sdfg != sdfg: raise KeyError('State does not belong to given SDFG') if subgraph is not state and subgraph.graph is not state: raise KeyError('Subgraph does not belong to given state') @@ -433,7 +433,7 @@ def nest_state_subgraph(sdfg: SDFG, # top-level graph) data_in_subgraph = set(n.data for n in subgraph.nodes() if isinstance(n, nodes.AccessNode)) # Find other occurrences in SDFG - other_nodes = set(n.data for s in sdfg.nodes() for n in s.nodes() + other_nodes = set(n.data for s in sdfg.states() for n in s.nodes() if isinstance(n, nodes.AccessNode) and n not in subgraph.nodes()) subgraph_transients = set() for data in data_in_subgraph: diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 93c2f6ea1c..8081447132 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -1,9 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop detection transformation """ import sympy as sp import networkx as nx -from typing import AnyStr, Optional, Tuple, List, Set +from typing import AnyStr, Iterable, Optional, Tuple, List, Set from dace import sdfg as sd, symbolic from dace.sdfg import graph as gr, utils as sdutil, InterstateEdge @@ -29,6 +29,9 @@ class DetectLoop(transformation.PatternTransformation): # Available for rotated and self loops entry_state = transformation.PatternNode(sd.SDFGState) + # Available for explicit-latch rotated loops + loop_break = transformation.PatternNode(sd.SDFGState) + @classmethod def expressions(cls): # Case 1: Loop with one state @@ -69,7 +72,46 @@ def expressions(cls): ssdfg.add_edge(cls.loop_begin, cls.loop_begin, sd.InterstateEdge()) ssdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) - return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg] + # Case 6: Rotated multi-state loop with explicit exiting and latch states + mlrmsdfg = gr.OrderedDiGraph() + mlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + mlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + # Case 7: Rotated single-state loop with explicit exiting and latch states + mlrsdfg = gr.OrderedDiGraph() + mlrsdfg.add_nodes_from([cls.entry_state, cls.loop_latch, cls.loop_begin, cls.exit_state]) + mlrsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.exit_state, sd.InterstateEdge()) + mlrsdfg.add_edge(cls.loop_begin, cls.loop_latch, sd.InterstateEdge()) + + # Case 8: Guarded rotated multi-state loop with explicit exiting and latch states (modification of case 6) + gmlrmsdfg = gr.OrderedDiGraph() + gmlrmsdfg.add_nodes_from([cls.entry_state, cls.loop_break, cls.loop_latch, cls.loop_begin, cls.exit_state]) + gmlrmsdfg.add_edge(cls.entry_state, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_latch, cls.loop_begin, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_begin, cls.loop_break, sd.InterstateEdge()) + gmlrmsdfg.add_edge(cls.loop_break, cls.exit_state, sd.InterstateEdge()) + 
gmlrmsdfg.add_edge(cls.loop_break, cls.loop_latch, sd.InterstateEdge()) + + return [sdfg, msdfg, rsdfg, rmsdfg, ssdfg, mlrmsdfg, mlrsdfg, gmlrmsdfg] + + @property + def inverted(self) -> bool: + """ + Whether the loop matched a pattern of an inverted (do-while style) loop. + """ + return self.expr_index in (2, 3, 5, 6, 7) + + @property + def first_loop_block(self) -> ControlFlowBlock: + """ + The first control flow block executed in each loop iteration. + """ + return self.loop_guard if self.expr_index <= 1 else self.loop_begin def can_be_applied(self, graph: ControlFlowRegion, @@ -77,19 +119,26 @@ def can_be_applied(self, sdfg: sd.SDFG, permissive: bool = False) -> bool: if expr_index == 0: - return self.detect_loop(graph, False) is not None + return self.detect_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 1: - return self.detect_loop(graph, True) is not None + return self.detect_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 2: - return self.detect_rotated_loop(graph, False) is not None + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive) is not None elif expr_index == 3: - return self.detect_rotated_loop(graph, True) is not None + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive) is not None elif expr_index == 4: - return self.detect_self_loop(graph) is not None + return self.detect_self_loop(graph, accept_missing_itvar=permissive) is not None + elif expr_index in (5, 7): + return self.detect_rotated_loop(graph, multistate_loop=True, accept_missing_itvar=permissive, + separate_latch=True) is not None + elif expr_index == 6: + return self.detect_rotated_loop(graph, multistate_loop=False, accept_missing_itvar=permissive, + separate_latch=True) is not None raise ValueError(f'Invalid expression index {expr_index}') - def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -159,13 +208,19 @@ def detect_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Option # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) - def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) -> Optional[str]: + def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool, + accept_missing_itvar: bool = False, separate_latch: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -181,6 +236,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - :return: The loop variable or ``None`` if not detected. 
""" latch = self.loop_latch + ltest = self.loop_latch + if separate_latch: + ltest = self.loop_break if multistate_loop else self.loop_begin begin = self.loop_begin # A for-loop start has at least two incoming edges (init and increment) @@ -188,18 +246,14 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - if len(begin_inedges) < 2: return None # A for-loop latch only has two outgoing edges (loop condition and exit-loop) - latch_outedges = graph.out_edges(latch) + latch_outedges = graph.out_edges(ltest) if len(latch_outedges) != 2: return None - # All incoming edges to the start of the loop must set the same variable - itvar = None - for iedge in begin_inedges: - if itvar is None: - itvar = set(iedge.data.assignments.keys()) - else: - itvar &= iedge.data.assignments.keys() - if itvar is None: + # A for-loop latch can further only have one incoming edge (the increment edge). A while-loop, i.e., a loop + # with no explicit iteration variable, may have more than that. + latch_inedges = graph.in_edges(latch) + if not accept_missing_itvar and len(latch_inedges) != 1: return None # Outgoing edges must be a negation of each other @@ -208,8 +262,13 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - # All nodes inside loop must be dominated by loop start dominators = nx.dominance.immediate_dominators(graph.nx, graph.start_block) - loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != latch)) - loop_nodes += [latch] + if begin is ltest: + loop_nodes = [begin] + else: + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes.append(latch) + if ltest is not latch and ltest is not begin: + loop_nodes.append(ltest) backedge = None for node in loop_nodes: for e in graph.out_edges(node): @@ -231,16 +290,9 @@ def detect_rotated_loop(self, graph: ControlFlowRegion, multistate_loop: bool) - if backedge is None: return None - # The backedge must reassign the iteration variable - itvar &= backedge.data.assignments.keys() - if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + return rotated_loop_find_itvar(begin_inedges, latch_inedges, backedge, ltest, accept_missing_itvar)[0] - return next(iter(itvar)) - - def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: + def detect_self_loop(self, graph: ControlFlowRegion, accept_missing_itvar: bool = False) -> Optional[str]: """ Detects a loop of the form: @@ -288,9 +340,14 @@ def detect_self_loop(self, graph: ControlFlowRegion) -> Optional[str]: # The backedge must reassign the iteration variable itvar &= backedge.data.assignments.keys() if len(itvar) != 1: - # Either no consistent iteration variable found, or too many - # consistent iteration variables found - return None + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None + else: + if len(itvar) == 0: + return '' + else: + return None return next(iter(itvar)) @@ -310,9 +367,10 @@ def loop_information( if self.expr_index <= 1: guard = self.loop_guard return find_for_loop(guard.parent_graph, guard, entry, itervar) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch - return find_rotated_for_loop(latch.parent_graph, latch, entry, itervar) + return find_rotated_for_loop(latch.parent_graph, latch, entry, 
itervar, + separate_latch=(self.expr_index in (5, 6, 7))) elif self.expr_index == 4: return find_rotated_for_loop(entry.parent_graph, entry, entry, itervar) @@ -334,6 +392,14 @@ def loop_body(self) -> List[ControlFlowBlock]: return loop_nodes elif self.expr_index == 4: return [begin] + elif self.expr_index in (5, 7): + ltest = self.loop_break + latch = self.loop_latch + loop_nodes = list(sdutil.dfs_conditional(graph, sources=[begin], condition=lambda _, child: child != ltest)) + loop_nodes += [ltest, latch] + return loop_nodes + elif self.expr_index == 6: + return [begin, self.loop_latch] return [] @@ -343,8 +409,10 @@ def loop_meta_states(self) -> List[ControlFlowBlock]: """ if self.expr_index in (0, 1): return [self.loop_guard] - if self.expr_index in (2, 3): + if self.expr_index in (2, 3, 6): return [self.loop_latch] + if self.expr_index in (5, 7): + return [self.loop_break, self.loop_latch] return [] def loop_init_edge(self) -> gr.Edge[InterstateEdge]: @@ -357,7 +425,7 @@ def loop_init_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src not in body) - elif self.expr_index in (2, 3): + elif self.expr_index in (2, 3, 5, 6, 7): latch = self.loop_latch return next(e for e in graph.in_edges(begin) if e.src is not latch) elif self.expr_index == 4: @@ -377,9 +445,12 @@ def loop_exit_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index in (2, 3): latch = self.loop_latch return graph.edges_between(latch, exitstate)[0] - elif self.expr_index == 4: + elif self.expr_index in (4, 6): begin = self.loop_begin return graph.edges_between(begin, exitstate)[0] + elif self.expr_index in (5, 7): + ltest = self.loop_break + return graph.edges_between(ltest, exitstate)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -398,6 +469,10 @@ def loop_condition_edge(self) -> gr.Edge[InterstateEdge]: elif self.expr_index == 4: begin = self.loop_begin return graph.edges_between(begin, begin)[0] + elif self.expr_index in (5, 6, 7): + latch = self.loop_latch + ltest = self.loop_break if self.expr_index in (5, 7) else self.loop_begin + return graph.edges_between(ltest, latch)[0] raise ValueError(f'Invalid expression index {self.expr_index}') @@ -411,15 +486,93 @@ def loop_increment_edge(self) -> gr.Edge[InterstateEdge]: guard = self.loop_guard body = self.loop_body() return next(e for e in graph.in_edges(guard) if e.src in body) - elif self.expr_index in (2, 3): - body = self.loop_body() - return next(e for e in graph.in_edges(begin) if e.src in body) + elif self.expr_index in (2, 3, 5, 6, 7): + _, step_edge = rotated_loop_find_itvar(graph.in_edges(begin), graph.in_edges(self.loop_latch), + graph.edges_between(self.loop_latch, begin)[0], self.loop_latch) + return step_edge elif self.expr_index == 4: return graph.edges_between(begin, begin)[0] raise ValueError(f'Invalid expression index {self.expr_index}') +def rotated_loop_find_itvar(begin_inedges: List[gr.Edge[InterstateEdge]], + latch_inedges: List[gr.Edge[InterstateEdge]], + backedge: gr.Edge[InterstateEdge], latch: ControlFlowBlock, + accept_missing_itvar: bool = False) -> Tuple[Optional[str], + Optional[gr.Edge[InterstateEdge]]]: + # The iteration variable must be assigned (initialized) on all edges leading into the beginning block, which + # are not the backedge. Gather all variabes for which that holds - they are all candidates for the iteration + # variable (Phase 1). 
Said iteration variable must then be incremented: + # EITHER: On the backedge, in which case the increment is only executed if the loop does not exit. This + # corresponds to a while(true) loop that checks the condition at the end of the loop body and breaks + # if it does not hold before incrementing. (Scenario 1) + # OR: On the edge(s) leading into the latch, in which case the increment is executed BEFORE the condition is + # checked - which corresponds to a do-while loop. (Scenario 2) + # For either case, the iteration variable may only be incremented on one of these places. Filter the candidates + # down to each variable for which this condition holds (Phase 2). If there is exactly one candidate remaining, + # that is the iteration variable. Otherwise it cannot be determined. + + # Phase 1: Gather iteration variable candidates. + itvar_candidates = None + for e in begin_inedges: + if e is backedge: + continue + if itvar_candidates is None: + itvar_candidates = set(e.data.assignments.keys()) + else: + itvar_candidates &= set(e.data.assignments.keys()) + + # Phase 2: Filter down the candidates according to incrementation edges. + step_edge = None + filtered_candidates = set() + backedge_incremented = set(backedge.data.assignments.keys()) + latch_incremented = None + if backedge.src is not backedge.dst: + # If this is a self loop, there are no edges going into the latch to be considered. The only incoming edges are + # from outside the loop. + for e in latch_inedges: + if e is backedge: + continue + if latch_incremented is None: + latch_incremented = set(e.data.assignments.keys()) + else: + latch_incremented &= set(e.data.assignments.keys()) + if latch_incremented is None: + latch_incremented = set() + for cand in itvar_candidates: + if cand in backedge_incremented: + # Scenario 1. + + # Note, only allow this scenario if the backedge leads directly from the latch to the entry, i.e., there is + # no intermediate block on the backedge path. + if backedge.src is not latch: + continue + + if cand not in latch_incremented: + filtered_candidates.add(cand) + elif cand in latch_incremented: + # Scenario 2. + if cand not in backedge_incremented: + filtered_candidates.add(cand) + if len(filtered_candidates) != 1: + if not accept_missing_itvar: + # Either no consistent iteration variable found, or too many consistent iteration variables found + return None, None + else: + if len(filtered_candidates) == 0: + return '', None + else: + return None, None + else: + itvar = next(iter(filtered_candidates)) + if itvar in backedge_incremented: + step_edge = backedge + elif len(latch_inedges) == 1: + step_edge = latch_inedges[0] + return itvar, step_edge + + def find_for_loop( graph: ControlFlowRegion, guard: sd.SDFGState, @@ -520,6 +673,10 @@ def find_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None @@ -531,14 +688,14 @@ def find_rotated_for_loop( graph: ControlFlowRegion, latch: sd.SDFGState, entry: sd.SDFGState, - itervar: Optional[str] = None + itervar: Optional[str] = None, + separate_latch: bool = False, ) -> Optional[Tuple[AnyStr, Tuple[symbolic.SymbolicType, symbolic.SymbolicType, symbolic.SymbolicType], Tuple[ List[sd.SDFGState], sd.SDFGState]]]: """ Finds rotated loop range from state machine. - :param latch: State from which the outgoing edges detect whether to exit - the loop or not. 
+ :param latch: State from which the outgoing edges detect whether to reenter the loop or not. :param entry: First state in the loop body. :param itervar: An optional field that overrides the analyzed iteration variable. :return: (iteration variable, (start, end, stride), @@ -547,20 +704,19 @@ def find_rotated_for_loop( """ # Extract state transition edge information entry_inedges = graph.in_edges(entry) - condition_edge = graph.edges_between(latch, entry)[0] - - # All incoming edges to the loop entry must set the same variable + if separate_latch: + condition_edge = graph.in_edges(latch)[0] + backedge = graph.edges_between(latch, entry)[0] + else: + condition_edge = graph.edges_between(latch, entry)[0] + backedge = condition_edge + latch_inedges = graph.in_edges(latch) + + self_loop = latch is entry + step_edge = None if itervar is None: - itervars = None - for iedge in entry_inedges: - if itervars is None: - itervars = set(iedge.data.assignments.keys()) - else: - itervars &= iedge.data.assignments.keys() - if itervars and len(itervars) == 1: - itervar = next(iter(itervars)) - else: - # Ambiguous or no iteration variable + itervar, step_edge = rotated_loop_find_itvar(entry_inedges, latch_inedges, backedge, latch) + if itervar is None: return None condition = condition_edge.data.condition_sympy() @@ -570,18 +726,12 @@ def find_rotated_for_loop( # have one assignment. init_edges = [] init_assignment = None - step_edge = None itersym = symbolic.symbol(itervar) for iedge in entry_inedges: + if iedge is condition_edge: + continue assignment = iedge.data.assignments[itervar] - if itersym in symbolic.pystr_to_symbolic(assignment).free_symbols: - if step_edge is None: - step_edge = iedge - else: - # More than one edge with the iteration variable as a free - # symbol, which is not legal. Invalid for loop. - return None - else: + if itersym not in symbolic.pystr_to_symbolic(assignment).free_symbols: if init_assignment is None: init_assignment = assignment init_edges.append(iedge) @@ -591,10 +741,16 @@ def find_rotated_for_loop( return None else: init_edges.append(iedge) - if step_edge is None or len(init_edges) == 0 or init_assignment is None: + if len(init_edges) == 0 or init_assignment is None: # Less than two assignment variations, can't be a valid for loop. return None + if self_loop: + step_edge = condition_edge + else: + if step_edge is None: + return None + # Get the init expression and the stride. start = symbolic.pystr_to_symbolic(init_assignment) stride = (symbolic.pystr_to_symbolic(step_edge.data.assignments[itervar]) - itersym) @@ -626,6 +782,10 @@ def find_rotated_for_loop( match = condition.match(itersym >= a) if match: end = match[a] + if end is None: + match = condition.match(sp.Ne(itersym + stride, a)) + if match: + end = match[a] - stride if end is None: # No match found return None diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py new file mode 100644 index 0000000000..072c2519ed --- /dev/null +++ b/dace/transformation/interstate/loop_lifting.py @@ -0,0 +1,99 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
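Note: both loop finders now also match a loop-continue condition phrased as an inequality on the already-incremented iterator (itersym + stride != bound), as rotated loops tend to produce. A rough standalone sketch of that pattern match using plain sympy symbols i and N and a unit stride (illustrative only, not taken from the patch):

import sympy as sp

i, N = sp.symbols('i N')
a = sp.Wild('a')
stride = 1

# Loop-continue condition of the form "i + stride != N".
condition = sp.Ne(i + stride, N)
match = condition.match(sp.Ne(i + stride, a))
# The recovered end value is one stride short of the mismatching bound.
end = match[a] - stride if match else None   # -> N - 1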
+ +from dace import properties +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.transformation import transformation +from dace.transformation.interstate.loop_detection import DetectLoop + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class LoopLifting(DetectLoop, transformation.MultiStateTransformation): + + def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG, + permissive: bool = False) -> bool: + # Check loop detection with permissive = True, which allows loops where no iteration variable could be detected. + # We want this to detect while loops. + if not super().can_be_applied(graph, expr_index, sdfg, permissive=True): + return False + + # Check that there's a condition edge, that's the only requirement to lift it into loop. + cond_edge = self.loop_condition_edge() + if not cond_edge or cond_edge.data.condition is None: + return False + return True + + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): + first_state = self.first_loop_block + after = self.exit_state + + loop_info = self.loop_information() + + body = self.loop_body() + meta = self.loop_meta_states() + full_body = set(body) + full_body.update(meta) + cond_edge = self.loop_condition_edge() + incr_edge = self.loop_increment_edge() + inverted = self.inverted + init_edge = self.loop_init_edge() + exit_edge = self.loop_exit_edge() + + label = 'loop_' + first_state.label + if loop_info is None: + itvar = None + init_expr = None + incr_expr = None + else: + incr_expr = f'{loop_info[0]} = {incr_edge.data.assignments[loop_info[0]]}' + init_expr = f'{loop_info[0]} = {init_edge.data.assignments[loop_info[0]]}' + itvar = loop_info[0] + + left_over_assignments = {} + for k in init_edge.data.assignments.keys(): + if k != itvar: + left_over_assignments[k] = init_edge.data.assignments[k] + left_over_incr_assignments = {} + if incr_edge is not None: + for k in incr_edge.data.assignments.keys(): + if k != itvar: + left_over_incr_assignments[k] = incr_edge.data.assignments[k] + + if inverted and incr_edge is cond_edge: + update_before_condition = False + else: + update_before_condition = True + + loop = LoopRegion(label, condition_expr=cond_edge.data.condition, loop_var=itvar, initialize_expr=init_expr, + update_expr=incr_expr, inverted=inverted, sdfg=sdfg, + update_before_condition=update_before_condition) + + graph.add_node(loop) + graph.add_edge(init_edge.src, loop, + InterstateEdge(condition=init_edge.data.condition, assignments=left_over_assignments)) + graph.add_edge(loop, after, InterstateEdge(assignments=exit_edge.data.assignments)) + + loop.add_node(first_state, is_start_block=True) + added = set() + for e in graph.all_edges(*full_body): + if e.src in full_body and e.dst in full_body: + if not e in added: + added.add(e) + if e is incr_edge: + if left_over_incr_assignments != {}: + dst = loop.add_state(label + '_tail') if not inverted else e.dst + loop.add_edge(e.src, dst, InterstateEdge(assignments=left_over_incr_assignments)) + elif e is cond_edge: + if not inverted: + e.data.condition = properties.CodeBlock('1') + loop.add_edge(e.src, e.dst, e.data) + else: + loop.add_edge(e.src, e.dst, e.data) + + # Remove old loop. 
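+        # At this point the loop body blocks have been re-added inside the new LoopRegion, so removing them
+        # from the parent graph (together with their now-redundant inter-state edges) completes the lifting.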
+ for n in full_body: + graph.remove_node(n) + + sdfg.root_sdfg.using_experimental_blocks = True + sdfg.reset_cfg_list() diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 494f9c39ae..9a8154df90 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -29,7 +29,8 @@ class Modifies(Flag): Memlets = auto() #: Memlets' existence, contents, or properties were modified Nodes = AccessNodes | Scopes | Tasklets | NestedSDFGs #: Modification of any dataflow node (contained in an SDFG state) was made Edges = InterstateEdges | Memlets #: Any edge (memlet or inter-state) was modified - Everything = Descriptors | Symbols | States | InterstateEdges | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) + CFG = States | InterstateEdges #: A CFG (any level) was modified (connectivity or number of control flow blocks, but not their contents) + Everything = Descriptors | Symbols | CFG | Nodes | Memlets #: Modification to arbitrary parts of SDFGs (nodes, edges, or properties) @properties.make_properties diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py new file mode 100644 index 0000000000..5bc1f6e3f3 --- /dev/null +++ b/dace/transformation/passes/analysis/__init__.py @@ -0,0 +1 @@ +from .analysis import * diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis/analysis.py similarity index 81% rename from dace/transformation/passes/analysis.py rename to dace/transformation/passes/analysis/analysis.py index c8bb0b7a9c..095319f807 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -1,7 +1,8 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict -from dace.transformation import pass_pipeline as ppl +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd @@ -16,6 +17,7 @@ @properties.make_properties +@transformation.experimental_cfg_block_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). @@ -28,25 +30,106 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply - return modified & ppl.Modifies.States + return modified & ppl.Modifies.CFG + + def depends_on(self): + return {ControlFlowBlockReachability} - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: + def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGState, Set[SDFGState]]]: """ :return: A dictionary mapping each state to its other reachable states. """ + # Ensure control flow block reachability is run if not run within a pipeline. 
+ if pipeline_res is None or not ControlFlowBlockReachability.__name__ in pipeline_res: + cf_block_reach_dict = ControlFlowBlockReachability().apply_pass(top_sdfg, {}) + else: + cf_block_reach_dict = pipeline_res[ControlFlowBlockReachability.__name__] reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Set[SDFGState]] = {} + result: Dict[SDFGState, Set[SDFGState]] = defaultdict(set) + for state in sdfg.states(): + for reached in cf_block_reach_dict[state.parent_graph.cfg_id][state]: + if isinstance(reached, SDFGState): + result[state].add(reached) + reachable[sdfg.cfg_id] = result + return reachable + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowBlockReachability(ppl.Pass): + """ + Evaluates control flow block reachability (which control flow block can be executed after each control flow block) + """ + + CATEGORY: str = 'Analysis' + + contain_to_single_level = properties.Property(dtype=bool, default=False) + + def __init__(self, contain_to_single_level=False) -> None: + super().__init__() + + self.contain_to_single_level = contain_to_single_level + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.Nothing + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def _region_closure(self, region: ControlFlowRegion, + block_reach: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]) -> Set[SDFGState]: + closure: Set[SDFGState] = set() + if isinstance(region, LoopRegion): + # Any point inside the loop may reach any other point inside the loop again. + # TODO(later): This is an overapproximation. A branch terminating in a break is excluded from this. + closure.update(region.all_control_flow_blocks()) + + # Add all states that this region can reach in its parent graph to the closure. + for reached_block in block_reach[region.parent_graph.cfg_id][region]: + if isinstance(reached_block, ControlFlowRegion): + closure.update(reached_block.all_control_flow_blocks()) + closure.add(reached_block) + + # Walk up the parent tree. + pivot = region.parent_graph + while pivot and not isinstance(pivot, SDFG): + closure.update(self._region_closure(pivot, block_reach)) + pivot = pivot.parent_graph + return closure + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: + """ + :return: For each control flow region, a dictionary mapping each control flow block to its other reachable + control flow blocks in the same region. + """ + single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( + lambda: defaultdict(set) + ) + for cfg in top_sdfg.all_control_flow_regions(recursive=True): # In networkx this is currently implemented naively for directed graphs. 
# The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) + for n, v in reachable_nodes(cfg.nx): + single_level_reachable[cfg.cfg_id][n] = set(v) + if isinstance(cfg, LoopRegion): + single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) - for n, v in reachable_nodes(sdfg.nx): - result[n] = set(v) - - reachable[sdfg.cfg_id] = result + if self.contain_to_single_level: + return single_level_reachable + reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = {} + for sdfg in top_sdfg.all_sdfgs_recursive(): + for cfg in sdfg.all_control_flow_regions(): + result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) + for block in cfg.nodes(): + for reached in single_level_reachable[block.parent_graph.cfg_id][block]: + if isinstance(reached, ControlFlowRegion): + result[block].update(reached.all_control_flow_blocks()) + result[block].add(reached) + if block.parent_graph is not sdfg: + result[block].update(self._region_closure(block.parent_graph, single_level_reachable)) + reachable[cfg.cfg_id] = result return reachable @@ -99,6 +182,7 @@ def reachable_nodes(G): @properties.make_properties +@transformation.experimental_cfg_block_compatible class SymbolAccessSets(ppl.Pass): """ Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). @@ -116,25 +200,27 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.nodes(): - readset = state.free_symbols - # No symbols may be written to inside states. - result[state] = (readset, set()) - for oedge in sdfg.out_edges(state): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[sdfg.cfg_id] = result + for cfg in sdfg.all_control_flow_regions(): + adesc = set(sdfg.arrays.keys()) + result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} + for block in cfg.nodes(): + if isinstance(block, SDFGState): + # No symbols may be written to inside states. + result[block] = (block.free_symbols, set()) + for oedge in cfg.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + top_result[cfg.cfg_id] = result return top_result @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessSets(ppl.Pass): """ Evaluates memory access sets (which arrays/data descriptors are read/written in each state). @@ -179,6 +265,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[s @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. 
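A short usage sketch of the two reachability passes above when invoked standalone rather than inside a pipeline (assumes an existing SDFG named sdfg, and that both passes are re-exported by the new analysis package the same way StateReachability is imported elsewhere in this patch):

from dace.transformation.passes.analysis import ControlFlowBlockReachability, StateReachability

# Keyed by the cfg_id of each control flow region: block -> other blocks reachable from it.
block_reach = ControlFlowBlockReachability().apply_pass(sdfg, {})
# Keyed by the cfg_id of each SDFG: state -> other SDFGStates reachable from it (derived from the block-level result).
state_reach = StateReachability().apply_pass(sdfg, {})

for state, reached in state_reach[sdfg.cfg_id].items():
    print(state.label, '->', sorted(r.label for r in reached))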
@@ -201,13 +288,13 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Set[SDFGState]] = defaultdict(set) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): result[anode.data].add(state) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames for access in fsyms: result[access].update({e.src, e.dst}) @@ -217,6 +304,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -242,7 +330,7 @@ def apply_pass(self, top_sdfg: SDFG, for sdfg in top_sdfg.all_sdfgs_recursive(): result: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = defaultdict( lambda: defaultdict(lambda: [set(), set()])) - for state in sdfg.nodes(): + for state in sdfg.states(): for anode in state.data_nodes(): if state.in_degree(anode) > 0: result[anode.data][state][1].add(anode) @@ -508,6 +596,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties +@transformation.experimental_cfg_block_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -544,6 +633,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @properties.make_properties +@transformation.experimental_cfg_block_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -586,6 +676,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties +@transformation.experimental_cfg_block_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py new file mode 100644 index 0000000000..3d15f73c73 --- /dev/null +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -0,0 +1,116 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" +Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes. +""" + +import ast +from typing import Any, Dict, Optional +from dace.frontend.python import astutils + +import sympy + +from dace import symbolic +from dace.sdfg.state import LoopRegion + + +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = astutils.unparse(node.value) + return self.generic_visit(node) + + +def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). 
+ """ + end: Optional[symbolic.SymbolicType] = None + a = sympy.Wild('a') + condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) + itersym = symbolic.pystr_to_symbolic(loop.loop_variable) + match = condition.match(itersym < a) + if match: + end = match[a] - 1 + if end is None: + match = condition.match(itersym <= a) + if match: + end = match[a] + if end is None: + match = condition.match(itersym > a) + if match: + end = match[a] + 1 + if end is None: + match = condition.match(itersym >= a) + if match: + end = match[a] + return end + + +def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's init statement to identify the exact init assignment expression. + """ + init_stmt = loop.init_statement + if init_stmt is None: + return None + + init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] + assignments: Dict[str, str] = {} + for code in init_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + """ + Parse a loop region's update statement to identify the exact update assignment expression. + """ + update_stmt = loop.update_statement + if update_stmt is None: + return None + + update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + assignments: Dict[str, str] = {} + for code in update_codes_list: + visitor = FindAssignment() + visitor.visit(code) + if visitor.multiple: + return None + for assign in visitor.assignments: + if assign in assignments: + return None + assignments[assign] = visitor.assignments[assign] + + if loop.loop_variable in assignments: + return symbolic.pystr_to_symbolic(assignments[loop.loop_variable]) + + return None + + +def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: + update_assignment = get_update_assignment(loop) + if update_assignment: + return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) + return None diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py new file mode 100644 index 0000000000..abe305f12c --- /dev/null +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -0,0 +1,96 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Optional, Tuple +import networkx as nx +from dace import properties +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.sdfg import SDFG, InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.utils import dfs_conditional +from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.interstate.loop_lifting import LoopLifting + + +@properties.make_properties +@transformation.experimental_cfg_block_compatible +class ControlFlowRaising(ppl.Pass): + """ + Raises all detectable control flow that can be expressed with native SDFG structures, such as loops and branching. 
+ """ + + CATEGORY: str = 'Simplification' + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.CFG + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def _lift_conditionals(self, sdfg: SDFG) -> int: + cfgs = list(sdfg.all_control_flow_regions()) + n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + + for region in cfgs: + sinks = region.sink_nodes() + dummy_exit = region.add_state('__DACE_DUMMY') + for s in sinks: + region.add_edge(s, dummy_exit, InterstateEdge()) + idom = nx.immediate_dominators(region.nx, region.start_block) + alldoms = cfg_analysis.all_dominators(region, idom) + branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) + + for block in region.nodes(): + graph = block.parent_graph + oedges = graph.out_edges(block) + if len(oedges) > 1 and block in branch_merges: + merge_block = branch_merges[block] + + # Construct the branching block. + conditional = ConditionalBlock('conditional_' + block.label, sdfg, graph) + graph.add_node(conditional) + # Connect it. + graph.add_edge(block, conditional, InterstateEdge()) + + # Populate branches. + for i, oe in enumerate(oedges): + branch_name = 'branch_' + str(i) + '_' + block.label + branch = ControlFlowRegion(branch_name, sdfg) + conditional.add_branch(oe.data.condition, branch) + if oe.dst is merge_block: + # Empty branch. + continue + + branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) + branch_start = branch.add_state(branch_name + '_start', is_start_block=True) + branch.add_nodes_from(branch_nodes) + branch_end = branch.add_state(branch_name + '_end') + branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) + added = set() + for e in graph.all_edges(*branch_nodes): + if not (e in added): + added.add(e) + if e is oe: + continue + elif e.dst is merge_block: + branch.add_edge(e.src, branch_end, e.data) + else: + branch.add_edge(e.src, e.dst, e.data) + graph.remove_nodes_from(branch_nodes) + + # Connect to the end of the branch / what happens after. + if merge_block is not dummy_exit: + graph.add_edge(conditional, merge_block, InterstateEdge()) + region.remove_node(dummy_exit) + + n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) + return n_cond_regions_post - n_cond_regions_pre + + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + lifted_loops = 0 + lifted_branches = 0 + for sdfg in top_sdfg.all_sdfgs_recursive(): + lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) + lifted_branches += self._lift_conditionals(sdfg) + if lifted_branches == 0 and lifted_loops == 0: + return None + return lifted_loops, lifted_branches diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index db1e9b59ab..aa182e8c80 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -1,26 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement the expansion transformation. 
""" -from dace import dtypes, registry, symbolic, subsets +from dace import dtypes, symbolic, subsets from dace.sdfg import nodes -from dace.memlet import Memlet from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg from dace.transformation.subgraph import helpers from collections import defaultdict from copy import deepcopy as dcpy -from typing import List, Union import itertools -import dace.libraries.standard as stdlib import warnings -import sys def offset_map(state, map_entry): diff --git a/dace/transformation/subgraph/helpers.py b/dace/transformation/subgraph/helpers.py index b2af49c879..0ea1903522 100644 --- a/dace/transformation/subgraph/helpers.py +++ b/dace/transformation/subgraph/helpers.py @@ -1,20 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Subgraph Transformation Helper API """ -from dace import dtypes, registry, symbolic, subsets -from dace.sdfg import nodes, utils -from dace.memlet import Memlet -from dace.sdfg import replace, SDFG, SDFGState -from dace.properties import make_properties, Property -from dace.sdfg.propagation import propagate_memlets_sdfg +from dace import subsets +from dace.sdfg import nodes from dace.sdfg.graph import SubgraphView -from collections import defaultdict import copy -from typing import List, Union, Dict, Tuple, Set - -import dace.libraries.standard as stdlib - -import itertools +from typing import List, Dict, Set # **************** # Helper functions diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py new file mode 100644 index 0000000000..53e01df12f --- /dev/null +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -0,0 +1,98 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. 
+ +import dace +import numpy as np +from dace.sdfg.state import ConditionalBlock +from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising + + +def test_dataflow_if_check(): + + @dace.program + def dataflow_if_check(A: dace.int32[10], i: dace.int64): + if A[i] < 10: + return 0 + elif A[i] == 10: + return 10 + return 100 + + sdfg = dataflow_if_check.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.__experimental_cfg_block_compatible__ = True + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + A = np.zeros((10,), np.int32) + A[4] = 10 + A[5] = 100 + assert sdfg(A, 0)[0] == 0 + assert sdfg(A, 4)[0] == 10 + assert sdfg(A, 5)[0] == 100 + assert sdfg(A, 6)[0] == 0 + + +def test_nested_if_chain(): + + @dace.program + def nested_if_chain(i: dace.int64): + if i < 2: + return 0 + else: + if i < 4: + return 1 + else: + if i < 6: + return 2 + else: + if i < 8: + return 3 + else: + return 4 + + sdfg = nested_if_chain.to_sdfg() + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert nested_if_chain(0)[0] == 0 + assert nested_if_chain(2)[0] == 1 + assert nested_if_chain(4)[0] == 2 + assert nested_if_chain(7)[0] == 3 + assert nested_if_chain(15)[0] == 4 + + +def test_elif_chain(): + + @dace.program + def elif_chain(i: dace.int64): + if i < 2: + return 0 + elif i < 4: + return 1 + elif i < 6: + return 2 + elif i < 8: + return 3 + else: + return 4 + + elif_chain.use_experimental_cfg_blocks = True + sdfg = elif_chain.to_sdfg() + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + assert elif_chain(0)[0] == 0 + assert elif_chain(2)[0] == 1 + assert elif_chain(4)[0] == 2 + assert elif_chain(7)[0] == 3 + assert elif_chain(15)[0] == 4 + + +if __name__ == '__main__': + test_dataflow_if_check() + test_nested_if_chain() + test_elif_chain() diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 7d5272d80a..96df87b5e7 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -1,7 +1,8 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. 
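The test changes below all follow the same pattern: results of UnderapproximateWrites are now keyed by cfg_id and wrapped in an UnderapproximateWritesDict. A condensed sketch of the updated access pattern (assumes an SDFG named sdfg to analyze):

from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict
from dace.transformation.pass_pipeline import Pipeline

pipeline = Pipeline([UnderapproximateWrites()])
results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]
per_sdfg: UnderapproximateWritesDict = results[sdfg.cfg_id]
write_approx = per_sdfg.approximation        # underapproximated write memlet per dataflow edge
loop_writes = per_sdfg.loop_approximation    # underapproximated writes per loop guard/header state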
+from typing import Dict import dace -from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites +from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict from dace.subsets import Range from dace.transformation.pass_pipeline import Pipeline @@ -9,8 +10,6 @@ M = dace.symbol("M") K = dace.symbol("K") -pipeline = Pipeline([UnderapproximateWrites()]) - def test_2D_map_overwrites_2D_array(): """ @@ -33,9 +32,10 @@ def test_2D_map_overwrites_2D_array(): output_nodes={'B': a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results['approximation'] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] result_subset_list = result[edge].subset.subset_list result_subset = result_subset_list[0] @@ -65,9 +65,10 @@ def test_2D_map_added_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -94,9 +95,10 @@ def test_2D_map_multiplied_indices(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -121,9 +123,10 @@ def test_1D_map_one_index_multiple_dims(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -146,9 +149,10 @@ def test_1D_map_one_index_squared(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.in_edges(a1)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -185,9 +189,10 @@ def test_map_tree_full_write(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge = Range.from_string("0:M, _i") result_inner_edge_0 = result[inner_edge_0].subset.subset_list[0] @@ -230,9 +235,10 @@ def test_map_tree_no_write_multiple_indices(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation 
result_inner_edge_0 = result[inner_edge_0].subset.subset_list result_inner_edge_1 = result[inner_edge_1].subset.subset_list result_outer_edge = result[outer_edge].subset.subset_list @@ -273,9 +279,10 @@ def test_map_tree_multiple_indices_per_dimension(): inner_edge_1 = map_state.add_edge(inner_map_exit_1, "OUT_B", map_exit, "IN_B", dace.Memlet(data="B")) outer_edge = map_state.add_edge(map_exit, "OUT_B", a1, None, dace.Memlet(data="B")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset_outer_edge = Range.from_string("0:M, 0:N") expected_subset_inner_edge_1 = Range.from_string("0:M, _i") result_inner_edge_1 = result[inner_edge_1].subset.subset_list[0] @@ -300,11 +307,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] nsdfg = sdfg.cfg_list[1].parent_nsdfg_node map_state = sdfg.states()[0] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation edge = map_state.out_edges(nsdfg)[0] assert (len(result[edge].subset.subset_list) == 0) @@ -323,11 +331,12 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] map_state = sdfg.states()[0] edge = map_state.in_edges(map_state.data_nodes()[0])[0] - result = results["approximation"] + result = results[sdfg.cfg_id].approximation expected_subset = Range.from_string("0:N, 0:M") assert (str(result[edge].subset.subset_list[0]) == str(expected_subset)) @@ -357,9 +366,10 @@ def test_map_in_loop(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation expected_subset = Range.from_string("0:N, 0:M") assert (str(result[guard]["B"].subset.subset_list[0]) == str(expected_subset)) @@ -390,9 +400,10 @@ def test_map_in_loop_multiplied_indices_first_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard not in result.keys() or len(result[guard]) == 0) @@ -421,9 +432,10 @@ def test_map_in_loop_multiplied_indices_second_dimension(): output_nodes={"B": a1}, external_edges=True) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard not in result.keys() or len(result[guard]) == 0) @@ -444,8 +456,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None @@ -478,9 +491,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = 
pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None @@ -510,15 +524,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -531,6 +546,7 @@ def test_nested_sdfg_in_map_branches(): Nested SDFG that overwrites second dimension of array conditionally. --> should approximate write-set of map as empty """ + # No, should be approximated precisely - at least certainly with CF regions..? @dace.program def nested_loop(A: dace.float64[M, N]): @@ -542,15 +558,16 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] + pipeline = Pipeline([UnderapproximateWrites()]) + result: Dict[int, UnderapproximateWritesDict] = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation # find write set accessnode = None write_set = None - for node, _ in sdfg.all_nodes_recursive(): + for node, parent in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.AccessNode): - if node.data == "A": + if node.data == "A" and parent.out_degree(node) == 0: accessnode = node for edge, memlet in write_approx.items(): if edge.dst is accessnode: @@ -574,9 +591,10 @@ def test_simple_loop_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result: UnderapproximateWritesDict = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id] - assert (str(result[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) + assert (str(result.loop_approximation[guard]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) def test_loop_2D_overwrite(): @@ -598,7 +616,8 @@ def test_loop_2D_overwrite(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -629,7 +648,8 @@ def test_loop_2D_propagation_gap_symbolic(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + 
pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert ("A" not in result[guard1].keys()) assert ("A" not in result[guard2].keys()) @@ -657,7 +677,8 @@ def test_2_loops_overwrite(): loop_tasklet_2 = loop_body_2.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_2.add_edge(loop_tasklet_2, "a", a1, None, dace.Memlet("A[i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard_1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard_2]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) @@ -687,7 +708,8 @@ def test_loop_2D_overwrite_propagation_gap_non_empty(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (str(result[guard1]["A"].subset) == str(Range.from_array(sdfg.arrays["A"]))) assert (str(result[guard2]["A"].subset) == "j, 0:N") @@ -717,7 +739,8 @@ def test_loop_nest_multiplied_indices(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i,i*j]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -748,7 +771,8 @@ def test_loop_nest_empty_nested_loop(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[j,i]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 not in result.keys() or "A" not in result[guard2].keys()) @@ -779,7 +803,8 @@ def test_loop_nest_inner_loop_conditional(): loop_tasklet = loop_body.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[k]")) - result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__]["loop_approximation"] + pipeline = Pipeline([UnderapproximateWrites()]) + result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__][sdfg.cfg_id].loop_approximation assert (guard1 not in result.keys() or "A" not in result[guard1].keys()) assert (guard2 in result.keys() and "A" in result[guard2].keys() and str(result[guard2]['A'].subset) == "0:N") @@ -799,9 +824,10 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - write_approx = 
result["approximation"] + write_approx = result[sdfg.cfg_id].approximation write_set = None accessnode = None for node, _ in sdfg.all_nodes_recursive(): @@ -828,10 +854,11 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] # find write set - write_approx = result["approximation"] + write_approx = result[sdfg.cfg_id].approximation accessnode = None write_set = None for node, _ in sdfg.all_nodes_recursive(): @@ -864,9 +891,10 @@ def test_loop_break(): loop_tasklet = loop_body_1.add_tasklet("overwrite", {}, {"a"}, "a = 0") loop_body_1.add_edge(loop_tasklet, "a", a0, None, dace.Memlet("A[i]")) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] - result = results["loop_approximation"] + result = results[sdfg.cfg_id].loop_approximation assert (guard3 not in result.keys() or "A" not in result[guard3].keys()) diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 4e4eda3f44..0be40f43d3 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -10,20 +10,20 @@ def test_cond_region_if(): sdfg = dace.SDFG('regular_if') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, InterstateEdge()) - if_body = ControlFlowRegion("if_body", sdfg=sdfg) - if1.branches.append((CodeBlock("i == 1"), if_body)) + if_body = ControlFlowRegion('if_body', sdfg=sdfg) + if1.add_branch(CodeBlock('i == 1'), if_body) - state1 = if_body.add_state("state1", is_start_block=True) + state1 = if_body.add_state('state1', is_start_block=True) acc_a = state1.add_access('A') - t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100") + t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100') state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]')) assert sdfg.is_valid() @@ -36,14 +36,14 @@ def test_cond_region_if(): assert A[0] == 1 def test_serialization(): - sdfg = SDFG("test_serialization") - cond_region = ConditionalBlock("cond_region") + sdfg = SDFG('test_serialization') + cond_region = ConditionalBlock('cond_region') sdfg.add_node(cond_region, is_start_block=True) - sdfg.add_symbol("i", dace.int32) + sdfg.add_symbol('i', dace.int32) for j in range(10): - cfg = ControlFlowRegion(f"cfg_{j}", sdfg) - cond_region.branches.append((CodeBlock(f"i == {j}"), cfg)) + cfg = ControlFlowRegion(f'cfg_{j}', sdfg) + cond_region.add_branch(CodeBlock(f'i == {j}'), cfg) assert sdfg.is_valid() @@ -52,32 +52,32 @@ def test_serialization(): new_cond_region: ConditionalBlock = new_sdfg.nodes()[0] for j in range(10): condition, cfg = new_cond_region.branches[j] - assert condition == CodeBlock(f"i == {j}") - assert cfg.label == f"cfg_{j}" + assert condition == CodeBlock(f'i == {j}') + assert cfg.label == f'cfg_{j}' def test_if_else(): sdfg = dace.SDFG('regular_if_else') - sdfg.add_array("A", (1,), dace.float32) - sdfg.add_symbol("i", dace.int32) + sdfg.add_array('A', (1,), dace.float32) + sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) - if1 = ConditionalBlock("if1") + if1 = ConditionalBlock('if1') sdfg.add_node(if1) sdfg.add_edge(state0, if1, 
InterstateEdge())
-    if_body = ControlFlowRegion("if_body", sdfg=sdfg)
-    state1 = if_body.add_state("state1", is_start_block=True)
+    if_body = ControlFlowRegion('if_body', sdfg=sdfg)
+    state1 = if_body.add_state('state1', is_start_block=True)
     acc_a = state1.add_access('A')
-    t1 = state1.add_tasklet("t1", None, {"a"}, "a = 100")
+    t1 = state1.add_tasklet('t1', None, {'a'}, 'a = 100')
     state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[0]'))
-    if1.branches.append((CodeBlock("i == 1"), if_body))
+    if1.add_branch(CodeBlock('i == 1'), if_body)
 
-    else_body = ControlFlowRegion("else_body", sdfg=sdfg)
-    state2 = else_body.add_state("state1", is_start_block=True)
+    else_body = ControlFlowRegion('else_body', sdfg=sdfg)
+    state2 = else_body.add_state('state1', is_start_block=True)
     acc_a2 = state2.add_access('A')
-    t2 = state2.add_tasklet("t2", None, {"a"}, "a = 200")
+    t2 = state2.add_tasklet('t2', None, {'a'}, 'a = 200')
     state2.add_edge(t2, 'a', acc_a2, None, dace.Memlet('A[0]'))
-    if1.branches.append((CodeBlock("i == 0"), else_body))
+    if1.add_branch(CodeBlock('i == 0'), else_body)
 
     assert sdfg.is_valid()
     A = np.ones((1,), dtype=np.float32)
diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py
index 6aca54f40c..dedafb67ba 100644
--- a/tests/sdfg/loop_region_test.py
+++ b/tests/sdfg/loop_region_test.py
@@ -86,6 +86,27 @@ def _make_do_for_loop() -> SDFG:
     return sdfg
 
 
+def _make_do_for_inverted_cond_loop() -> SDFG:
+    sdfg = dace.SDFG('do_for_inverted_cond')
+    sdfg.using_experimental_blocks = True
+    sdfg.add_symbol('i', dace.int32)
+    sdfg.add_array('A', [10], dace.float32)
+    state0 = sdfg.add_state('state0', is_start_block=True)
+    loop1 = LoopRegion(label='loop1', condition_expr='i < 8', loop_var='i', initialize_expr='i = 0',
+                       update_expr='i = i + 1', inverted=True, update_before_condition=False)
+    sdfg.add_node(loop1)
+    state1 = loop1.add_state('state1', is_start_block=True)
+    acc_a = state1.add_access('A')
+    t1 = state1.add_tasklet('t1', None, {'a'}, 'a = i')
+    state1.add_edge(t1, 'a', acc_a, None, dace.Memlet('A[i]'))
+    state2 = loop1.add_state('state2')
+    loop1.add_edge(state1, state2, dace.InterstateEdge())
+    state3 = sdfg.add_state('state3')
+    sdfg.add_edge(state0, loop1, dace.InterstateEdge())
+    sdfg.add_edge(loop1, state3, dace.InterstateEdge())
+    return sdfg
+
+
 def _make_triple_nested_for_loop() -> SDFG:
     sdfg = dace.SDFG('gemm')
     sdfg.using_experimental_blocks = True
@@ -177,6 +198,19 @@ def test_loop_do_for():
     assert np.allclose(a_validation, a_test)
 
 
+def test_loop_do_for_inverted_condition():
+    sdfg = _make_do_for_inverted_cond_loop()
+
+    assert sdfg.is_valid()
+
+    a_validation = np.zeros([10], dtype=np.float32)
+    a_test = np.zeros([10], dtype=np.float32)
+    sdfg(A=a_test)
+    for i in range(9):
+        a_validation[i] = i
+    assert np.allclose(a_validation, a_test)
+
+
 def test_loop_triple_nested_for():
     sdfg = _make_triple_nested_for_loop()
 
@@ -249,6 +283,21 @@ def test_loop_to_stree_do_for():
                                  f'{tn.INDENTATION}while (i < 10)')
 
 
+def test_loop_to_stree_do_for_inverted_cond():
+    sdfg = _make_do_for_inverted_cond_loop()
+
+    assert sdfg.is_valid()
+
+    stree = s2t.as_schedule_tree(sdfg)
+
+    assert stree.as_string() == (f'{tn.INDENTATION}i = 0\n' +
+                                 f'{tn.INDENTATION}while True:\n' +
+                                 f'{2 * tn.INDENTATION}A[i] = tasklet()\n' +
+                                 f'{2 * tn.INDENTATION}if (not (i < 8)):\n' +
+                                 f'{3 * tn.INDENTATION}break\n' +
+                                 f'{2 * tn.INDENTATION}i = (i + 1)\n')
+
+
 def test_loop_to_stree_triple_nested_for():
     sdfg = _make_triple_nested_for_loop()
 
@@ -267,9 +316,11 @@ def test_loop_to_stree_triple_nested_for():
     test_loop_regular_while()
     test_loop_do_while()
     test_loop_do_for()
+    test_loop_do_for_inverted_condition()
     test_loop_triple_nested_for()
     test_loop_to_stree_regular_for()
     test_loop_to_stree_regular_while()
     test_loop_to_stree_do_while()
     test_loop_to_stree_do_for()
+    test_loop_to_stree_do_for_inverted_cond()
     test_loop_to_stree_triple_nested_for()
diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py
new file mode 100644
index 0000000000..20f244621c
--- /dev/null
+++ b/tests/transformations/interstate/loop_lifting_test.py
@@ -0,0 +1,217 @@
+# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
+""" Tests loop raising transformations. """
+
+import numpy as np
+import pytest
+import dace
+from dace.memlet import Memlet
+from dace.sdfg.sdfg import SDFG, InterstateEdge
+from dace.sdfg.state import LoopRegion
+from dace.transformation.interstate.loop_lifting import LoopLifting
+
+
+def test_lift_regular_for_loop():
+    sdfg = SDFG('regular_for')
+    N = dace.symbol('N')
+    sdfg.add_symbol('i', dace.int32)
+    sdfg.add_symbol('j', dace.int32)
+    sdfg.add_symbol('k', dace.int32)
+    sdfg.add_array('A', (N,), dace.int32)
+    start_state = sdfg.add_state('start', is_start_block=True)
+    init_state = sdfg.add_state('init')
+    guard_state = sdfg.add_state('guard')
+    main_state = sdfg.add_state('loop_state')
+    loop_exit = sdfg.add_state('exit')
+    final_state = sdfg.add_state('final')
+    sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0}))
+    sdfg.add_edge(init_state, guard_state, InterstateEdge(assignments={'i': 0, 'k': 0}))
+    sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N'))
+    sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'}))
+    sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2}))
+    sdfg.add_edge(loop_exit, final_state, InterstateEdge())
+    a_access = main_state.add_access('A')
+    w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1')
+    main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
+    a_access_2 = loop_exit.add_access('A')
+    w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k')
+    loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))
+    a_access_3 = final_state.add_access('A')
+    w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j')
+    final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]'))
+
+    N = 30
+    A = np.zeros((N,)).astype(np.int32)
+    A_valid = np.zeros((N,)).astype(np.int32)
+    sdfg(A=A_valid, N=N)
+    sdfg.apply_transformations_repeated([LoopLifting])
+
+    assert sdfg.using_experimental_blocks == True
+    assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())
+
+    sdfg(A=A, N=N)
+
+    assert np.allclose(A_valid, A)
+
+
+@pytest.mark.parametrize('increment_before_condition', (True, False))
+def test_lift_loop_llvm_canonical(increment_before_condition):
+    addendum = '_incr_before_cond' if increment_before_condition else ''
+    sdfg = dace.SDFG('llvm_canonical' + addendum)
+    N = dace.symbol('N')
+    sdfg.add_symbol('i', dace.int32)
+    sdfg.add_symbol('j', dace.int32)
+    sdfg.add_symbol('k', dace.int32)
+    sdfg.add_array('A', (N,), dace.int32)
+
+    entry = sdfg.add_state('entry', is_start_block=True)
+    guard = sdfg.add_state('guard')
+    preheader = sdfg.add_state('preheader')
+    body = sdfg.add_state('body')
+    latch = sdfg.add_state('latch')
+    loopexit = sdfg.add_state('loopexit')
+    exitstate = sdfg.add_state('exitstate')
+
+    sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0}))
+    sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0'))
+    sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0'))
+    sdfg.add_edge(preheader, body, InterstateEdge(assignments={'i': 0, 'k': 0}))
+    if increment_before_condition:
+        sdfg.add_edge(body, latch, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'}))
+        sdfg.add_edge(latch, body, InterstateEdge(condition='i < N'))
+        sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N', assignments={'k': 2}))
+    else:
+        sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'}))
+        sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2', assignments={'i': 'i + 2'}))
+        sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2}))
+    sdfg.add_edge(loopexit, exitstate, InterstateEdge())
+
+    a_access = body.add_access('A')
+    w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1')
+    body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
+    a_access_2 = loopexit.add_access('A')
+    w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k')
+    loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))
+    a_access_3 = exitstate.add_access('A')
+    w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j')
+    exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]'))
+
+    N = 30
+    A = np.zeros((N,)).astype(np.int32)
+    A_valid = np.zeros((N,)).astype(np.int32)
+    sdfg(A=A_valid, N=N)
+    sdfg.apply_transformations_repeated([LoopLifting])
+
+    assert sdfg.using_experimental_blocks == True
+    assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())
+
+    sdfg(A=A, N=N)
+
+    assert np.allclose(A_valid, A)
+
+
+def test_lift_loop_llvm_canonical_while():
+    sdfg = dace.SDFG('llvm_canonical_while')
+    N = dace.symbol('N')
+    sdfg.add_symbol('j', dace.int32)
+    sdfg.add_symbol('k', dace.int32)
+    sdfg.add_array('A', (N,), dace.int32)
+    sdfg.add_scalar('i', dace.int32, transient=True)
+
+    entry = sdfg.add_state('entry', is_start_block=True)
+    guard = sdfg.add_state('guard')
+    preheader = sdfg.add_state('preheader')
+    body = sdfg.add_state('body')
+    latch = sdfg.add_state('latch')
+    loopexit = sdfg.add_state('loopexit')
+    exitstate = sdfg.add_state('exitstate')
+
+    sdfg.add_edge(entry, guard, InterstateEdge(assignments={'j': 0}))
+    sdfg.add_edge(guard, exitstate, InterstateEdge(condition='N <= 0'))
+    sdfg.add_edge(guard, preheader, InterstateEdge(condition='N > 0'))
+    sdfg.add_edge(preheader, body, InterstateEdge(assignments={'k': 0}))
+    sdfg.add_edge(body, latch, InterstateEdge(assignments={'j': 'j + 1'}))
+    sdfg.add_edge(latch, body, InterstateEdge(condition='i < N - 2'))
+    sdfg.add_edge(latch, loopexit, InterstateEdge(condition='i >= N - 2', assignments={'k': 2}))
+    sdfg.add_edge(loopexit, exitstate, InterstateEdge())
+
+    i_init_write = entry.add_access('i')
+    iw_init_tasklet = entry.add_tasklet('ti', {}, {'out'}, 'out = 0')
+    entry.add_edge(iw_init_tasklet, 'out', i_init_write, None, Memlet('i[0]'))
+    a_access = body.add_access('A')
+    w_tasklet = body.add_tasklet('t1', {}, {'out'}, 'out = 1')
+    body.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
+    i_read = body.add_access('i')
+    i_write = body.add_access('i')
+    iw_tasklet = body.add_tasklet('t2', {'in1'}, {'out'}, 'out = in1 + 2')
+    body.add_edge(i_read, None, iw_tasklet, 'in1', Memlet('i[0]'))
+    body.add_edge(iw_tasklet, 'out', i_write, None, Memlet('i[0]'))
+    a_access_2 = loopexit.add_access('A')
+    w_tasklet_2 = loopexit.add_tasklet('t1', {}, {'out'}, 'out = k')
+    loopexit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))
+    a_access_3 = exitstate.add_access('A')
+    w_tasklet_3 = exitstate.add_tasklet('t1', {}, {'out'}, 'out = j')
+    exitstate.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]'))
+
+    N = 30
+    A = np.zeros((N,)).astype(np.int32)
+    A_valid = np.zeros((N,)).astype(np.int32)
+    sdfg(A=A_valid, N=N)
+    sdfg.apply_transformations_repeated([LoopLifting])
+
+    assert sdfg.using_experimental_blocks == True
+    assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())
+
+    sdfg(A=A, N=N)
+
+    assert np.allclose(A_valid, A)
+
+
+def test_do_while():
+    sdfg = SDFG('do_while')
+    N = dace.symbol('N')
+    sdfg.add_symbol('i', dace.int32)
+    sdfg.add_symbol('j', dace.int32)
+    sdfg.add_symbol('k', dace.int32)
+    sdfg.add_array('A', (N,), dace.int32)
+    start_state = sdfg.add_state('start', is_start_block=True)
+    init_state = sdfg.add_state('init')
+    guard_state = sdfg.add_state('guard')
+    main_state = sdfg.add_state('loop_state')
+    loop_exit = sdfg.add_state('exit')
+    final_state = sdfg.add_state('final')
+    sdfg.add_edge(start_state, init_state, InterstateEdge(assignments={'j': 0}))
+    sdfg.add_edge(init_state, main_state, InterstateEdge(assignments={'i': 0, 'k': 0}))
+    sdfg.add_edge(main_state, guard_state, InterstateEdge(assignments={'i': 'i + 2', 'j': 'j + 1'}))
+    sdfg.add_edge(guard_state, main_state, InterstateEdge(condition='i < N'))
+    sdfg.add_edge(guard_state, loop_exit, InterstateEdge(condition='i >= N', assignments={'k': 2}))
+    sdfg.add_edge(loop_exit, final_state, InterstateEdge())
+    a_access = main_state.add_access('A')
+    w_tasklet = main_state.add_tasklet('t1', {}, {'out'}, 'out = 1')
+    main_state.add_edge(w_tasklet, 'out', a_access, None, Memlet('A[i]'))
+    a_access_2 = loop_exit.add_access('A')
+    w_tasklet_2 = loop_exit.add_tasklet('t1', {}, {'out'}, 'out = k')
+    loop_exit.add_edge(w_tasklet_2, 'out', a_access_2, None, Memlet('A[1]'))
+    a_access_3 = final_state.add_access('A')
+    w_tasklet_3 = final_state.add_tasklet('t1', {}, {'out'}, 'out = j')
+    final_state.add_edge(w_tasklet_3, 'out', a_access_3, None, Memlet('A[3]'))
+
+    N = 30
+    A = np.zeros((N,)).astype(np.int32)
+    A_valid = np.zeros((N,)).astype(np.int32)
+    sdfg(A=A_valid, N=N)
+    sdfg.apply_transformations_repeated([LoopLifting])
+
+    assert sdfg.using_experimental_blocks == True
+    assert any(isinstance(x, LoopRegion) for x in sdfg.nodes())
+
+    sdfg(A=A, N=N)
+
+    assert np.allclose(A_valid, A)
+
+
+if __name__ == '__main__':
+    test_lift_regular_for_loop()
+    test_lift_loop_llvm_canonical(True)
+    test_lift_loop_llvm_canonical(False)
+    test_lift_loop_llvm_canonical_while()
+    test_do_while()
diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py
index 5469f45762..323a27787a 100644
--- a/tests/transformations/loop_detection_test.py
+++ b/tests/transformations/loop_detection_test.py
@@ -27,7 +27,8 @@ def tester(a: dace.float64[20]):
     assert rng == (1, 19, 1)
 
 
-def test_loop_rotated():
+@pytest.mark.parametrize('increment_before_condition', (True, False))
+def test_loop_rotated(increment_before_condition):
     sdfg = dace.SDFG('tester')
     sdfg.add_symbol('N', dace.int32)
 
@@ -37,8 +38,12 @@ def test_loop_rotated():
     exitstate = sdfg.add_state('exitstate')
 
     sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0)))
-    sdfg.add_edge(body, latch, dace.InterstateEdge())
-    sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2')))
+    if increment_before_condition:
+        sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 2')))
+        sdfg.add_edge(latch, body, dace.InterstateEdge('i < N'))
+    else:
+        sdfg.add_edge(body, latch, dace.InterstateEdge())
+        sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 2')))
     sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N'))
 
     xform = CountLoops()
@@ -48,8 +53,9 @@ def test_loop_rotated():
     assert rng == (0, dace.symbol('N') - 1, 2)
 
 
-@pytest.mark.skip('Extra incrementation states should not be supported by loop detection')
 def test_loop_rotated_extra_increment():
+    # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not
+    # be allowed and consequently not be detected as loops.
     sdfg = dace.SDFG('tester')
     sdfg.add_symbol('N', dace.int32)
 
@@ -60,15 +66,13 @@ def test_loop_rotated_extra_increment():
     exitstate = sdfg.add_state('exitstate')
 
     sdfg.add_edge(entry, body, dace.InterstateEdge(assignments=dict(i=0)))
+    sdfg.add_edge(body, latch, dace.InterstateEdge())
     sdfg.add_edge(latch, increment, dace.InterstateEdge('i < N'))
     sdfg.add_edge(increment, body, dace.InterstateEdge(assignments=dict(i='i + 1')))
     sdfg.add_edge(latch, exitstate, dace.InterstateEdge('i >= N'))
 
     xform = CountLoops()
-    assert sdfg.apply_transformations(xform) == 1
-    itvar, rng, _ = xform.loop_information()
-    assert itvar == 'i'
-    assert rng == (0, dace.symbol('N') - 1, 1)
+    assert sdfg.apply_transformations(xform) == 0
 
 
 def test_self_loop():
@@ -91,7 +95,8 @@ def test_self_loop():
     assert rng == (2, dace.symbol('N') - 1, 3)
 
 
-def test_loop_llvm_canonical():
+@pytest.mark.parametrize('increment_before_condition', (True, False))
+def test_loop_llvm_canonical(increment_before_condition):
     sdfg = dace.SDFG('tester')
     sdfg.add_symbol('N', dace.int32)
 
@@ -106,8 +111,12 @@ def test_loop_llvm_canonical():
     sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0'))
     sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0'))
     sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0)))
-    sdfg.add_edge(body, latch, dace.InterstateEdge())
-    sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1')))
+    if increment_before_condition:
+        sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1')))
+        sdfg.add_edge(latch, body, dace.InterstateEdge('i < N'))
+    else:
+        sdfg.add_edge(body, latch, dace.InterstateEdge())
+        sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1')))
     sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N'))
     sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge())
 
@@ -118,9 +127,10 @@ def test_loop_llvm_canonical():
     assert rng == (0, dace.symbol('N') - 1, 1)
 
 
-@pytest.mark.skip('Extra incrementation states should not be supported by loop detection')
 @pytest.mark.parametrize('with_bounds_check', (False, True))
 def test_loop_llvm_canonical_with_extras(with_bounds_check):
+    # Extra incrementation states (i.e., something more than a single edge between the latch and the body) should not
+    # be allowed and consequently not be detected as loops.
     sdfg = dace.SDFG('tester')
     sdfg.add_symbol('N', dace.int32)
 
@@ -148,17 +158,16 @@ def test_loop_llvm_canonical_with_extras(with_bounds_check):
     sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge())
 
     xform = CountLoops()
-    assert sdfg.apply_transformations(xform) == 1
-    itvar, rng, _ = xform.loop_information()
-    assert itvar == 'i'
-    assert rng == (0, dace.symbol('N') - 1, 1)
+    assert sdfg.apply_transformations(xform) == 0
 
 
 if __name__ == '__main__':
     test_pyloop()
-    test_loop_rotated()
-    # test_loop_rotated_extra_increment()
+    test_loop_rotated(True)
+    test_loop_rotated(False)
+    test_loop_rotated_extra_increment()
     test_self_loop()
-    test_loop_llvm_canonical()
-    # test_loop_llvm_canonical_with_extras(False)
-    # test_loop_llvm_canonical_with_extras(True)
+    test_loop_llvm_canonical(True)
+    test_loop_llvm_canonical(False)
+    test_loop_llvm_canonical_with_extras(False)
+    test_loop_llvm_canonical_with_extras(True)
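
Usage sketch (illustrative only, not part of the patch): the tests above drive the new pass by building a guard-style state-machine loop and then calling apply_transformations_repeated with LoopLifting. The snippet below mirrors those calls on a minimal, made-up SDFG ('guard_loop_example'); only imports and methods already used in loop_lifting_test.py are assumed.

import dace
from dace.sdfg.sdfg import SDFG, InterstateEdge
from dace.sdfg.state import LoopRegion
from dace.transformation.interstate.loop_lifting import LoopLifting

# Hypothetical guard-state loop: entry -> guard -> body -> guard, guard -> exit.
sdfg = SDFG('guard_loop_example')
sdfg.add_symbol('i', dace.int32)
sdfg.add_symbol('N', dace.int32)
entry = sdfg.add_state('entry', is_start_block=True)
guard = sdfg.add_state('guard')
body = sdfg.add_state('body')
exit_state = sdfg.add_state('exit')
sdfg.add_edge(entry, guard, InterstateEdge(assignments={'i': 0}))
sdfg.add_edge(guard, body, InterstateEdge(condition='i < N'))
sdfg.add_edge(body, guard, InterstateEdge(assignments={'i': 'i + 1'}))
sdfg.add_edge(guard, exit_state, InterstateEdge(condition='i >= N'))

# Raise the guard/latch pattern into an intrinsic LoopRegion node.
sdfg.apply_transformations_repeated([LoopLifting])
assert any(isinstance(node, LoopRegion) for node in sdfg.nodes())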