diff --git a/dace/subsets.py b/dace/subsets.py index 66ba05b14f..9d5e422c0d 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1,6 +1,6 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace.serialize -from dace import data, symbolic, dtypes +from dace import symbolic import re import sympy as sp from functools import reduce @@ -1243,6 +1243,9 @@ def intersection(self, other: Subset) -> 'SubsetUnion': except TypeError: return None + def intersects(self, other: Subset): + return self.intersection(other) is not None + @property def free_symbols(self) -> Set[str]: result = set() @@ -1349,17 +1352,13 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range: return Range(result) - - def union(subset_a: Subset, subset_b: Subset) -> Subset: """ Compute the union of two Subset objects. - If the subsets are not of the same type, degenerates to bounding-box - union. + If the subsets are not of the same type, degenerates to bounding-box union. :param subset_a: The first subset. :param subset_b: The second subset. - :return: A Subset object whose size is at least the union of the two - inputs. If union failed, returns None. + :return: A Subset object whose size is at least the union of the two inputs. If union failed, returns None. """ try: @@ -1369,8 +1368,7 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset: return subset_b elif subset_a is None and subset_b is None: raise TypeError('Both subsets cannot be None') - elif isinstance(subset_a, SubsetUnion) or isinstance( - subset_b, SubsetUnion): + elif isinstance(subset_a, SubsetUnion) or isinstance(subset_b, SubsetUnion): return list_union(subset_a, subset_b) elif type(subset_a) != type(subset_b): return bounding_box_union(subset_a, subset_b) @@ -1437,6 +1435,10 @@ def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]: subset_b = Range.from_indices(subset_b) if type(subset_a) is type(subset_b): return subset_a.intersects(subset_b) + elif isinstance(subset_a, SubsetUnion): + return subset_a.intersects(subset_b) + elif isinstance(subset_b, SubsetUnion): + return subset_b.intersects(subset_a) return None except TypeError: # cannot determine truth value of Relational return None @@ -1451,6 +1453,10 @@ def intersection(subset_a: Subset, subset_b: Subset) -> Optional[Subset]: subset_b = Range.from_indices(subset_b) if type(subset_a) is type(subset_b): return subset_a.intersection(subset_b) + elif isinstance(subset_a, SubsetUnion): + return subset_a.intersection(subset_b) + elif isinstance(subset_b, SubsetUnion): + return subset_b.intersection(subset_a) return None except TypeError: return None diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 69a77422e8..fa6a516bdd 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -3,13 +3,15 @@ Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes. """ -from typing import Dict, Optional -from dace.frontend.python import astutils +from typing import Dict, Optional, Set, Union import sympy from dace import symbolic +from dace.frontend.python import astutils +from dace.memlet import Memlet from dace.sdfg.state import LoopRegion +from dace.subsets import Range, SubsetUnion, intersects def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: @@ -97,3 +99,59 @@ def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: if update_assignment: return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable) return None + + +def _loop_read_intersects_loop_write(loop: LoopRegion, write_subset: Union[SubsetUnion, Range], + read_subset: Union[SubsetUnion, Range], update: sympy.Basic) -> bool: + """ + Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed + based on the symbolic loop update assignment expression. + """ + offset = update - symbolic.symbol(loop.loop_variable) + offset_list = [] + for i in range(write_subset.dims()): + if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]): + offset_list.append(offset) + else: + offset_list.append(0) + offset_write = write_subset.offset_new(offset_list, True) + return intersects(offset_write, read_subset) + +def get_loop_carry_dependencies(loop: LoopRegion) -> Optional[Dict[Memlet, Memlet]]: + """ + Compute loop carry dependencies. + :return: A dictionary mapping loop reads to writes in the same loop, from which they may carry a RAW dependency. + None if the loop cannot be analyzed. + """ + update_assignment = None + raw_deps: Dict[Memlet, Memlet] = dict() + for data in loop._possible_reads: + if not data in loop._possible_writes: + continue + + input = loop._possible_reads[data] + read_subset = input.src_subset or input.subset + if loop.loop_variable and loop.loop_variable in input.free_symbols: + # If the iteration variable is involved in an access, we need to first offset it by the loop + # stride and then check for an overlap/intersection. If one is found after offsetting, there + # is a RAW loop carry dependency. + output = loop._possible_writes[data] + # Get and cache the update assignment for the loop. + if update_assignment is None: + update_assignment = get_update_assignment(loop) + if update_assignment is None: + return None + + if isinstance(output.subset, SubsetUnion): + if any([_loop_read_intersects_loop_write(loop, s, read_subset, update_assignment) + for s in output.subset.subset_list]): + raw_deps[input] = output + elif _loop_read_intersects_loop_write(loop, output.subset, read_subset, update_assignment): + raw_deps[input] = output + else: + # Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no + # iteration variable involved. + output = loop._possible_writes[data] + if intersects(output.subset, read_subset): + raw_deps[input] = output + return raw_deps diff --git a/dace/transformation/passes/analysis/propagation.py b/dace/transformation/passes/analysis/propagation.py index d3d5545c34..0401d0cf95 100644 --- a/dace/transformation/passes/analysis/propagation.py +++ b/dace/transformation/passes/analysis/propagation.py @@ -519,6 +519,7 @@ def _propagate_loop(self, loop: LoopRegion) -> None: self._propagate_cfg(loop) # TODO: Remove loop-carried dependencies from the writes (i.e., only the first read would be a true read) + deps = loop_analysis.get_loop_carry_dependencies(loop) # Propagate memlets from inside the loop through the loop ranges. # Collect loop information and form the loop variable range first. @@ -550,8 +551,12 @@ def _propagate_loop(self, loop: LoopRegion) -> None: for dat in memlet_repo.keys(): memlet = memlet_repo[dat] arr = loop.sdfg.data(dat) - new_memlet = propagate_subset([memlet], arr, [itvar], loop_range, defined_symbols, use_dst) - memlet_repo[dat] = new_memlet + if memlet in deps: + dep_write = deps[memlet] + print(memlet) + else: + new_memlet = propagate_subset([memlet], arr, [itvar], loop_range, defined_symbols, use_dst) + memlet_repo[dat] = new_memlet def _propagate_cfg(self, cfg: ControlFlowRegion) -> None: cfg._possible_reads = {}