Skip to content

Commit

Permalink
Add subset difference
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 3, 2024
1 parent 4d2b7b3 commit 6ff0a6e
Show file tree
Hide file tree
Showing 5 changed files with 354 additions and 208 deletions.
99 changes: 95 additions & 4 deletions dace/subsets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
from copy import deepcopy
import dace.serialize
from dace import symbolic
import re
Expand Down Expand Up @@ -820,7 +821,7 @@ def replace(self, repl_dict):
rs.subs(repl_dict) if symbolic.issymbolic(rs) else rs)
self.tile_sizes[i] = (ts.subs(repl_dict) if symbolic.issymbolic(ts) else ts)

def intersection(self, other: 'Range') -> 'Range':
def intersection(self, other: 'Range') -> Optional['Range']:
type_error = False
expected_length = len(self.ranges)
if expected_length != len(other.ranges):
Expand All @@ -830,7 +831,7 @@ def intersection(self, other: 'Range') -> 'Range':
for i, (rng, orng) in enumerate(zip(self.ranges, other.ranges)):
if (rng[2] != 1 or orng[2] != 1 or self.tile_sizes[i] != 1 or other.tile_sizes[i] != 1):
# TODO: This function does not consider strides or tiles
return None
raise NotImplementedError('^This function does not yet consider strides or tiles')

# Special case: ranges match
if rng[0] == orng[0] and rng[1] == orng[1]:
Expand All @@ -852,9 +853,9 @@ def intersection(self, other: 'Range') -> 'Range':
return None

if cond3 == True:
rng_start = rng[0]
else:
rng_start = orng[0]
else:
rng_start = rng[0]
if cond4 == True:
rng_end = rng[1]
else:
Expand All @@ -874,6 +875,44 @@ def intersection(self, other: 'Range') -> 'Range':
def intersects(self, other: 'Range'):
return self.intersection(other) is not None

def difference(self, other: 'Range') -> Subset:
isect = self.intersection(other)
if isect is None:
return self
diff_ranges = [[]]
for i, (r1, r2) in enumerate(zip(self.ranges, isect.ranges)):
if r2[0] == r1[0]:
# Intersection over the start of the current range.
if r2[1] == r1[1]:
return Range([])
else:
for dr in diff_ranges:
dr.append((r2[1] + self.ranges[i][2], r1[1], r1[2]))
elif r2[1] == r1[1]:
# Intersection over the end of the current range.
if r2[0] == r1[0]:
return Range([])
else:
for dr in diff_ranges:
dr.append((r1[0], r2[0] - self.ranges[i][2], r1[2]))
else:
# Intersection completely contained inside the current range, split into subset union is necessary.
split_left = (r1[0], r2[0] - self.ranges[i][2], r1[2])
split_right = (r2[1] + self.ranges[i][2], r1[1], r1[2])
for dr in diff_ranges:
dr.append(split_left)
dr_copy = deepcopy(diff_ranges)
for dr in dr_copy:
dr[-1] = split_right
diff_ranges.append(dr)
if len(diff_ranges) == 1:
return Range(diff_ranges[0])
else:
subset_list = []
for dr in diff_ranges:
subset_list.append(Range(dr))
return SubsetUnion(subset_list)


@dace.serialize.serializable
class Indices(Subset):
Expand Down Expand Up @@ -1246,6 +1285,40 @@ def intersection(self, other: Subset) -> 'SubsetUnion':
def intersects(self, other: Subset):
return self.intersection(other) is not None

def difference(self, other: Subset) -> 'SubsetUnion':
try:
if isinstance(other, SubsetUnion):
differences = []
for subs in self.subset_list:
sub_diff = subs
for osubs in other.subset_list:
sub_diff = difference(sub_diff, osubs)
if sub_diff is not None:
if isinstance(sub_diff, SubsetUnion):
for s in sub_diff.subset_list:
differences.append(s)
else:
differences.append(sub_diff)
if differences:
return SubsetUnion(differences)
elif isinstance(other, (Indices, Range)):
differences = []
for subs in self.subset_list:
diff = difference(subs, other)
if diff is not None:
if isinstance(diff, SubsetUnion):
for sub_diff in diff.subset_list:
differences.append(sub_diff)
else:
differences.append(diff)
if differences:
return SubsetUnion(differences)
else:
raise TypeError
except TypeError:
return None
pass

@property
def free_symbols(self) -> Set[str]:
result = set()
Expand Down Expand Up @@ -1460,3 +1533,21 @@ def intersection(subset_a: Subset, subset_b: Subset) -> Optional[Subset]:
return None
except TypeError:
return None

def difference(subset_a: Subset, subset_b: Subset) -> Optional[Subset]:
try:
if subset_a is None or subset_b is None:
return None
if isinstance(subset_a, Indices):
subset_a = Range.from_indices(subset_a)
if isinstance(subset_b, Indices):
subset_b = Range.from_indices(subset_b)
if type(subset_a) is type(subset_b):
return subset_a.difference(subset_b)
elif isinstance(subset_a, SubsetUnion):
return subset_a.difference(subset_b)
elif isinstance(subset_b, SubsetUnion):
return subset_b.difference(subset_a)
return None
except TypeError:
return None
27 changes: 15 additions & 12 deletions dace/transformation/passes/analysis/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dace.sdfg.scope import ScopeTree
from dace.sdfg.sdfg import SDFG, memlets_in_ast
from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, SDFGState
from dace.subsets import Range, SubsetUnion
from dace.subsets import Range, SubsetUnion, intersection
from dace.transformation import pass_pipeline as ppl
from dace.transformation import transformation
from dace.transformation.helpers import unsqueeze_memlet
Expand Down Expand Up @@ -542,21 +542,24 @@ def _propagate_loop(self, loop: LoopRegion) -> None:
symbols_at_loop.update(new_symbols)
pivot = pivot.parent_graph
defined_symbols = [symbolic.pystr_to_symbolic(s) for s in symbols_at_loop.keys()]
repos_to_propagate = [(loop._certain_reads, False),
(loop._certain_writes, True),
(loop._possible_reads, False),
(loop._possible_writes, True)]
# Propagate memlet subsets through the loop variable and its range.
for (memlet_repo, use_dst) in repos_to_propagate:
for memlet_repo in [loop._certain_reads, loop._possible_reads]:
for dat in memlet_repo.keys():
memlet = memlet_repo[dat]
read = memlet_repo[dat]
arr = loop.sdfg.data(dat)
if memlet in deps:
dep_write = deps[memlet]
print(memlet)
if read in deps:
dep_write = deps[read]
inters = intersection(dep_write.subset, read.subset)
...
else:
new_memlet = propagate_subset([memlet], arr, [itvar], loop_range, defined_symbols, use_dst)
memlet_repo[dat] = new_memlet
new_read = propagate_subset([read], arr, [itvar], loop_range, defined_symbols, use_dst=False)
memlet_repo[dat] = new_read
for memlet_repo in [loop._certain_writes, loop._possible_writes]:
for dat in memlet_repo.keys():
write = memlet_repo[dat]
arr = loop.sdfg.data(dat)
new_write = propagate_subset([write], arr, [itvar], loop_range, defined_symbols, use_dst=True)
memlet_repo[dat] = new_write

def _propagate_cfg(self, cfg: ControlFlowRegion) -> None:
cfg._possible_reads = {}
Expand Down
78 changes: 0 additions & 78 deletions tests/subset_intersects_test.py

This file was deleted.

114 changes: 0 additions & 114 deletions tests/subsets_squeeze_test.py

This file was deleted.

Loading

0 comments on commit 6ff0a6e

Please sign in to comment.