Skip to content

Commit

Permalink
Loop dependency analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 3, 2024
1 parent 3a7634f commit 4d2b7b3
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
26 changes: 16 additions & 10 deletions dace/subsets.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace.serialize
from dace import data, symbolic, dtypes
from dace import symbolic
import re
import sympy as sp
from functools import reduce
Expand Down Expand Up @@ -1243,6 +1243,9 @@ def intersection(self, other: Subset) -> 'SubsetUnion':
except TypeError:
return None

def intersects(self, other: Subset):
return self.intersection(other) is not None

@property
def free_symbols(self) -> Set[str]:
result = set()
Expand Down Expand Up @@ -1349,17 +1352,13 @@ def bounding_box_union(subset_a: Subset, subset_b: Subset) -> Range:
return Range(result)




def union(subset_a: Subset, subset_b: Subset) -> Subset:
""" Compute the union of two Subset objects.
If the subsets are not of the same type, degenerates to bounding-box
union.
If the subsets are not of the same type, degenerates to bounding-box union.
:param subset_a: The first subset.
:param subset_b: The second subset.
:return: A Subset object whose size is at least the union of the two
inputs. If union failed, returns None.
:return: A Subset object whose size is at least the union of the two inputs. If union failed, returns None.
"""
try:

Expand All @@ -1369,8 +1368,7 @@ def union(subset_a: Subset, subset_b: Subset) -> Subset:
return subset_b
elif subset_a is None and subset_b is None:
raise TypeError('Both subsets cannot be None')
elif isinstance(subset_a, SubsetUnion) or isinstance(
subset_b, SubsetUnion):
elif isinstance(subset_a, SubsetUnion) or isinstance(subset_b, SubsetUnion):
return list_union(subset_a, subset_b)
elif type(subset_a) != type(subset_b):
return bounding_box_union(subset_a, subset_b)
Expand Down Expand Up @@ -1437,6 +1435,10 @@ def intersects(subset_a: Subset, subset_b: Subset) -> Union[bool, None]:
subset_b = Range.from_indices(subset_b)
if type(subset_a) is type(subset_b):
return subset_a.intersects(subset_b)
elif isinstance(subset_a, SubsetUnion):
return subset_a.intersects(subset_b)
elif isinstance(subset_b, SubsetUnion):
return subset_b.intersects(subset_a)
return None
except TypeError: # cannot determine truth value of Relational
return None
Expand All @@ -1451,6 +1453,10 @@ def intersection(subset_a: Subset, subset_b: Subset) -> Optional[Subset]:
subset_b = Range.from_indices(subset_b)
if type(subset_a) is type(subset_b):
return subset_a.intersection(subset_b)
elif isinstance(subset_a, SubsetUnion):
return subset_a.intersection(subset_b)
elif isinstance(subset_b, SubsetUnion):
return subset_b.intersection(subset_a)
return None
except TypeError:
return None
62 changes: 60 additions & 2 deletions dace/transformation/passes/analysis/loop_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes.
"""

from typing import Dict, Optional
from dace.frontend.python import astutils
from typing import Dict, Optional, Set, Union

import sympy

from dace import symbolic
from dace.frontend.python import astutils
from dace.memlet import Memlet
from dace.sdfg.state import LoopRegion
from dace.subsets import Range, SubsetUnion, intersects


def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]:
Expand Down Expand Up @@ -97,3 +99,59 @@ def get_loop_stride(loop: LoopRegion) -> Optional[symbolic.SymbolicType]:
if update_assignment:
return update_assignment - symbolic.pystr_to_symbolic(loop.loop_variable)
return None


def _loop_read_intersects_loop_write(loop: LoopRegion, write_subset: Union[SubsetUnion, Range],
read_subset: Union[SubsetUnion, Range], update: sympy.Basic) -> bool:
"""
Check if a write subset intersects a read subset after being offset by the loop stride. The offset is performed
based on the symbolic loop update assignment expression.
"""
offset = update - symbolic.symbol(loop.loop_variable)
offset_list = []
for i in range(write_subset.dims()):
if loop.loop_variable in write_subset.get_free_symbols_by_indices([i]):
offset_list.append(offset)
else:
offset_list.append(0)
offset_write = write_subset.offset_new(offset_list, True)
return intersects(offset_write, read_subset)

def get_loop_carry_dependencies(loop: LoopRegion) -> Optional[Dict[Memlet, Memlet]]:
"""
Compute loop carry dependencies.
:return: A dictionary mapping loop reads to writes in the same loop, from which they may carry a RAW dependency.
None if the loop cannot be analyzed.
"""
update_assignment = None
raw_deps: Dict[Memlet, Memlet] = dict()
for data in loop._possible_reads:
if not data in loop._possible_writes:
continue

input = loop._possible_reads[data]
read_subset = input.src_subset or input.subset
if loop.loop_variable and loop.loop_variable in input.free_symbols:
# If the iteration variable is involved in an access, we need to first offset it by the loop
# stride and then check for an overlap/intersection. If one is found after offsetting, there
# is a RAW loop carry dependency.
output = loop._possible_writes[data]
# Get and cache the update assignment for the loop.
if update_assignment is None:
update_assignment = get_update_assignment(loop)
if update_assignment is None:
return None

if isinstance(output.subset, SubsetUnion):
if any([_loop_read_intersects_loop_write(loop, s, read_subset, update_assignment)
for s in output.subset.subset_list]):
raw_deps[input] = output
elif _loop_read_intersects_loop_write(loop, output.subset, read_subset, update_assignment):
raw_deps[input] = output
else:
# Check for basic overlaps/intersections in RAW loop carry dependencies, when there is no
# iteration variable involved.
output = loop._possible_writes[data]
if intersects(output.subset, read_subset):
raw_deps[input] = output
return raw_deps
9 changes: 7 additions & 2 deletions dace/transformation/passes/analysis/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,7 @@ def _propagate_loop(self, loop: LoopRegion) -> None:
self._propagate_cfg(loop)

# TODO: Remove loop-carried dependencies from the writes (i.e., only the first read would be a true read)
deps = loop_analysis.get_loop_carry_dependencies(loop)

# Propagate memlets from inside the loop through the loop ranges.
# Collect loop information and form the loop variable range first.
Expand Down Expand Up @@ -550,8 +551,12 @@ def _propagate_loop(self, loop: LoopRegion) -> None:
for dat in memlet_repo.keys():
memlet = memlet_repo[dat]
arr = loop.sdfg.data(dat)
new_memlet = propagate_subset([memlet], arr, [itvar], loop_range, defined_symbols, use_dst)
memlet_repo[dat] = new_memlet
if memlet in deps:
dep_write = deps[memlet]
print(memlet)
else:
new_memlet = propagate_subset([memlet], arr, [itvar], loop_range, defined_symbols, use_dst)
memlet_repo[dat] = new_memlet

def _propagate_cfg(self, cfg: ControlFlowRegion) -> None:
cfg._possible_reads = {}
Expand Down

0 comments on commit 4d2b7b3

Please sign in to comment.