Skip to content

Commit

Permalink
Missed the case of if we run of list only one side. Handle that, add
Browse files Browse the repository at this point in the history
more tests to cover everything.
  • Loading branch information
pratyai committed Oct 24, 2024
1 parent f363f34 commit ee927a3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 14 deletions.
12 changes: 10 additions & 2 deletions dace/subsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,8 +1390,8 @@ def map(self, r: Range) -> Optional[Range]:
assert self.src.dims() == r.dims()
out = []
src_i, dst_i = 0, 0
while src_i < self.src.dims():
assert dst_i < self.dst.dims()
while src_i < self.src.dims() and dst_i < self.dst.dims():
# If we run out only on one side, handle that case after the loop.

# Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have
# a mapping).
Expand Down Expand Up @@ -1446,4 +1446,12 @@ def map(self, r: Range) -> Optional[Range]:
return None

src_i, dst_i = src_j, dst_j
if src_i < self.src.dims():
src_segment = Range(self.src.ranges[src_i: self.src.dims()])
assert src_segment.volume_exact() == 1
if dst_i < self.dst.dims():
# Take the remaining dst segment which must have a volume of 1 by now.
dst_segment = Range(self.dst.ranges[dst_i: self.dst.dims()])
assert dst_segment.volume_exact() == 1
out.extend(dst_segment.ranges)
return Range(out)
65 changes: 53 additions & 12 deletions tests/subsets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,11 @@ def test_mapping_with_reshaping(self):
dst = Range([(0, K - 1, 1), (0, N * M - 1, 1)])
# A Mapper
sm = SubrangeMapper(src, dst)
sm_inv = SubrangeMapper(dst, src)

# Pick the entire range.
self.assertEqual(dst, sm.map(src))
self.assertEqual(src, sm_inv.map(dst))

# NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`.
# Hence, the numerical approach.
Expand All @@ -135,25 +137,64 @@ def test_mapping_with_reshaping(self):
np.random.randint(1, 10, size=20))]
# Pick a point K//2, N//2, M//2.
for args in argslist:
want = eval_range(
Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]),
args)
got = eval_range(
sm.map(Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])),
args)
orig = Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])
orig_maps_to = Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
self.assertEqual(want, got)
# Pick a quadrant.
for args in argslist:
# But its mapping cannot be expressed as a simple range with offset and stride.
self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)])))
# Pick only points in problematic quadrants, but larger subsets elsewhere.
for args in argslist:
want = eval_range(
Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]),
args)
got = eval_range(
sm.map(Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])),
args)
orig = Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])
orig_maps_to = Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
self.assertEqual(want, got)

def test_mapping_with_reshaping_unit_dims(self):
K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True)

# A regular cube.
src = Range([(0, K - 1, 1), (0, N - 1, 1), (0, M - 1, 1), (0, 0, 1)])
# A regular cube with different shape.
dst = Range([(0, K - 1, 1), (0, N * M - 1, 1), (0, 0, 1), (0, 0, 1)])
# A Mapper
sm = SubrangeMapper(src, dst)
sm_inv = SubrangeMapper(dst, src)

# Pick the entire range.
self.assertEqual(dst, sm.map(src))
self.assertEqual(src, sm_inv.map(dst))

# NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`.
# Hence, the numerical approach.
argslist = [{'K': k, 'N': n, 'M': m} for k, n, m in zip(np.random.randint(1, 10, size=20),
np.random.randint(1, 10, size=20),
np.random.randint(1, 10, size=20))]
# Pick a point K//2, N//2, M//2.
for args in argslist:
orig = Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)])
orig_maps_to = Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
self.assertEqual(want, got)
# Pick a quadrant.
for args in argslist:
# But its mapping cannot be expressed as a simple range with offset and stride.
self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1), (0, 0, 1)])))
# Pick only points in problematic quadrants, but larger subsets elsewhere.
for args in argslist:
orig = Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)])
orig_maps_to = Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)])
want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args)
self.assertEqual(want, got)
want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args)
self.assertEqual(want, got)


Expand Down

0 comments on commit ee927a3

Please sign in to comment.