diff --git a/dace/subsets.py b/dace/subsets.py index aa3a3269e7..be8946c945 100644 --- a/dace/subsets.py +++ b/dace/subsets.py @@ -1390,8 +1390,8 @@ def map(self, r: Range) -> Optional[Range]: assert self.src.dims() == r.dims() out = [] src_i, dst_i = 0, 0 - while src_i < self.src.dims(): - assert dst_i < self.dst.dims() + while src_i < self.src.dims() and dst_i < self.dst.dims(): + # If we run out only on one side, handle that case after the loop. # Find the next smallest segments of `src` and `dst` whose volumes matches (and therefore can possibly have # a mapping). @@ -1446,4 +1446,12 @@ def map(self, r: Range) -> Optional[Range]: return None src_i, dst_i = src_j, dst_j + if src_i < self.src.dims(): + src_segment = Range(self.src.ranges[src_i: self.src.dims()]) + assert src_segment.volume_exact() == 1 + if dst_i < self.dst.dims(): + # Take the remaining dst segment which must have a volume of 1 by now. + dst_segment = Range(self.dst.ranges[dst_i: self.dst.dims()]) + assert dst_segment.volume_exact() == 1 + out.extend(dst_segment.ranges) return Range(out) diff --git a/tests/subsets_test.py b/tests/subsets_test.py index 4af4449c72..cc8a304400 100644 --- a/tests/subsets_test.py +++ b/tests/subsets_test.py @@ -124,9 +124,11 @@ def test_mapping_with_reshaping(self): dst = Range([(0, K - 1, 1), (0, N * M - 1, 1)]) # A Mapper sm = SubrangeMapper(src, dst) + sm_inv = SubrangeMapper(dst, src) # Pick the entire range. self.assertEqual(dst, sm.map(src)) + self.assertEqual(src, sm_inv.map(dst)) # NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`. # Hence, the numerical approach. @@ -135,12 +137,11 @@ def test_mapping_with_reshaping(self): np.random.randint(1, 10, size=20))] # Pick a point K//2, N//2, M//2. for args in argslist: - want = eval_range( - Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]), - args) - got = eval_range( - sm.map(Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])), - args) + orig = Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)]) + orig_maps_to = Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]) + want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args) + self.assertEqual(want, got) + want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args) self.assertEqual(want, got) # Pick a quadrant. for args in argslist: @@ -148,12 +149,52 @@ def test_mapping_with_reshaping(self): self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1)]))) # Pick only points in problematic quadrants, but larger subsets elsewhere. for args in argslist: - want = eval_range( - Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]), - args) - got = eval_range( - sm.map(Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)])), - args) + orig = Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1)]) + orig_maps_to = Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1)]) + want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args) + self.assertEqual(want, got) + want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args) + self.assertEqual(want, got) + + def test_mapping_with_reshaping_unit_dims(self): + K, N, M = dace.symbol('K', positive=True), dace.symbol('N', positive=True), dace.symbol('M', positive=True) + + # A regular cube. + src = Range([(0, K - 1, 1), (0, N - 1, 1), (0, M - 1, 1), (0, 0, 1)]) + # A regular cube with different shape. + dst = Range([(0, K - 1, 1), (0, N * M - 1, 1), (0, 0, 1), (0, 0, 1)]) + # A Mapper + sm = SubrangeMapper(src, dst) + sm_inv = SubrangeMapper(dst, src) + + # Pick the entire range. + self.assertEqual(dst, sm.map(src)) + self.assertEqual(src, sm_inv.map(dst)) + + # NOTE: I couldn't make SymPy understand that `(K//2) % K == (K//2)` always holds for postive integers `K`. + # Hence, the numerical approach. + argslist = [{'K': k, 'N': n, 'M': m} for k, n, m in zip(np.random.randint(1, 10, size=20), + np.random.randint(1, 10, size=20), + np.random.randint(1, 10, size=20))] + # Pick a point K//2, N//2, M//2. + for args in argslist: + orig = Range([(K // 2, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)]) + orig_maps_to = Range([(K // 2, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)]) + want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args) + self.assertEqual(want, got) + want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args) + self.assertEqual(want, got) + # Pick a quadrant. + for args in argslist: + # But its mapping cannot be expressed as a simple range with offset and stride. + self.assertIsNone(sm.map(Range([(0, K // 2, 1), (0, N // 2, 1), (0, M // 2, 1), (0, 0, 1)]))) + # Pick only points in problematic quadrants, but larger subsets elsewhere. + for args in argslist: + orig = Range([(0, K // 2, 1), (N // 2, N // 2, 1), (M // 2, M // 2, 1), (0, 0, 1)]) + orig_maps_to = Range([(0, K // 2, 1), ((N // 2) + (M // 2) * N, (N // 2) + (M // 2) * N, 1), (0, 0, 1), (0, 0, 1)]) + want, got = eval_range(orig_maps_to, args), eval_range(sm.map(orig), args) + self.assertEqual(want, got) + want, got = eval_range(orig, args), eval_range(sm_inv.map(orig_maps_to), args) self.assertEqual(want, got)