diff --git a/dace/transformation/dataflow/map_expansion.py b/dace/transformation/dataflow/map_expansion.py
index 8bc14213b0..fd092a252d 100644
--- a/dace/transformation/dataflow/map_expansion.py
+++ b/dace/transformation/dataflow/map_expansion.py
@@ -136,10 +136,15 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG):
             graph.add_edge(entries[-1], edge.src_conn, edge.dst, edge.dst_conn, memlet=copy.deepcopy(edge.data))
             graph.remove_edge(edge)
 
-        if graph.in_degree(map_entry) == 0:
+        if graph.in_degree(map_entry) == 0 or all(
+            e.dst_conn is None or not e.dst_conn.startswith("IN_")
+            for e in graph.in_edges(map_entry)
+        ):
             graph.add_memlet_path(map_entry, *entries, memlet=dace.Memlet())
         else:
             for edge in graph.in_edges(map_entry):
+                if edge.dst_conn is None:
+                    continue
                 if not edge.dst_conn.startswith("IN_"):
                     continue
 
diff --git a/tests/transformations/map_expansion_test.py b/tests/transformations/map_expansion_test.py
index 6e4b965ba2..504b2ca37a 100644
--- a/tests/transformations/map_expansion_test.py
+++ b/tests/transformations/map_expansion_test.py
@@ -163,8 +163,43 @@ def mymap(i: _[0:20], j: _[0:30], k: _[0:5]):
     assert len(map_entries) == 2
 
 
+def test_expand_with_dependency_edges():
+    """
+    Regression test for MapExpansion on a map entry that has incoming
+    dependency edges (edges whose destination connector is ``None``).
+    """
+
+    @dace.program
+    def expansion(A: dace.float32[2], B: dace.float32[2, 2, 2]):
+        for i in dace.map[0:2]:
+            A[i] = i
+
+            for j, k in dace.map[0:2, 0:2]:
+                B[i, j, k] = i * j + k
+
+    sdfg = expansion.to_sdfg()
+    sdfg.simplify()
+    sdfg.validate()
+
+    # Call the transformation directly: if dependency edges are mishandled, the
+    # exception propagates and pytest reports it with the full traceback.
+    num_app = sdfg.apply_transformations_repeated(MapExpansion)
+    assert num_app == 1
+    sdfg.validate()
+
+    A = np.random.rand(2).astype(np.float32)
+    B = np.random.rand(2, 2, 2).astype(np.float32)
+    sdfg(A=A, B=B)
+
+    A_expected = np.array([0, 1], dtype=np.float32)
+    B_expected = np.array([[[0, 1], [0, 1]], [[0, 1], [1, 2]]], dtype=np.float32)
+    assert np.array_equal(A, A_expected)
+    assert np.array_equal(B, B_expected)
+
+
 if __name__ == '__main__':
     test_expand_with_inputs()
     test_expand_without_inputs()
     test_expand_without_dynamic_inputs()
     test_expand_with_limits()
+    test_expand_with_dependency_edges()