Skip to content

Commit

Permalink
OTFMapFusion: Minor bug fixes (#1448)
Browse files Browse the repository at this point in the history
This PR fixes the problem that data containers were removed from an
SDFG, although they might be used in another state. Furthermore, it
fixes the problem that self.second_map_entry may not point to the
correct node after adding/removing nodes
  • Loading branch information
lukastruemper committed Dec 2, 2023
1 parent 8e6331d commit 2a9f0a4
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
intermediate_access_node = self.array
first_map_exit = self.first_map_exit
first_map_entry = graph.entry_node(first_map_exit)
second_map_entry = self.second_map_entry

# Prepare: Make first and second map parameters disjoint
# This avoids mutual matching: i -> j, j -> i
subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for param in first_map_entry.map.params:
i = 0
new_param = f"_i{i}"
while new_param in self.second_map_entry.map.params or new_param in first_map_entry.map.params:
while new_param in second_map_entry.map.params or new_param in first_map_entry.map.params:
i = i + 1
new_param = f"_i{i}"

advanced_replace(subgraph, param, new_param)

# Prepare: Preemptively rename params defined by second map in scope of first
# This avoids that local variables (e.g., in nested SDFG) have collisions with new map scope
for param in self.second_map_entry.map.params:
for param in second_map_entry.map.params:
new_param = param + "_local"
advanced_replace(subgraph, param, new_param)

# Add local buffers for array-like OTFs
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
if edge.data is None or edge.data.data != intermediate_access_node.data:
continue

Expand Down Expand Up @@ -208,18 +209,18 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
save=False)

# Phase 1: Add new access nodes to second map
for edge in graph.edges_between(intermediate_access_node, self.second_map_entry):
for edge in graph.edges_between(intermediate_access_node, second_map_entry):
graph.remove_edge_and_connectors(edge)

connector_mapping = {}
for edge in graph.in_edges(first_map_entry):
new_in_connector = self.second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = "IN_" + new_in_connector
if not self.second_map_entry.add_in_connector(new_in_connector):
if not second_map_entry.add_in_connector(new_in_connector):
raise ValueError("Failed to add new in connector")

memlet = copy.deepcopy(edge.data)
graph.add_edge(edge.src, edge.src_conn, self.second_map_entry, new_in_connector, memlet)
graph.add_edge(edge.src, edge.src_conn, second_map_entry, new_in_connector, memlet)

connector_mapping[edge.dst_conn] = new_in_connector

Expand All @@ -231,7 +232,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Group by same access scheme
consume_memlets = {}
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
memlet = edge.data
if memlet.data not in produce_memlets:
continue
Expand All @@ -246,7 +247,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
consume_memlets[memlet.data][accesses].append(edge)

# And remove from second map
self.second_map_entry.remove_out_connector(edge.src_conn)
second_map_entry.remove_out_connector(edge.src_conn)
graph.remove_edge(edge)

# Phase 3: OTF - copy content of first map for each memlet of second according to matches
Expand All @@ -256,7 +257,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
for second_accesses in consume_memlets[array]:
# Step 1: Infer index access of second map to new inputs with respect to original first map
mapping = OTFMapFusion.solve(first_map_entry.map.params, first_accesses,
self.second_map_entry.map.params, second_accesses)
second_map_entry.map.params, second_accesses)

# Step 2: Add Temporary buffer
tmp_name = sdfg.temp_data_name()
Expand Down Expand Up @@ -289,20 +290,26 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
for edge in graph.edges_between(first_map_entry, node):
memlet = copy.deepcopy(edge.data)

in_connector = edge.src_conn.replace("OUT", "IN")
if in_connector in connector_mapping:
out_connector = connector_mapping[in_connector].replace("IN", "OUT")
if edge.src_conn is not None:
in_connector = edge.src_conn.replace("OUT", "IN")
if in_connector in connector_mapping:
out_connector = connector_mapping[in_connector].replace("IN", "OUT")
else:
out_connector = edge.src_conn

if out_connector not in second_map_entry.out_connectors:
second_map_entry.add_out_connector(out_connector)
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)

graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.add_edge(second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.remove_edge(edge)

# Step 4: Rename all symbols of first map in copied content my matched symbol of second map
otf_nodes.append(self.second_map_entry)
otf_nodes.append(second_map_entry)
otf_subgraph = StateSubgraphView(graph, otf_nodes)
for param in mapping:
if isinstance(param, tuple):
Expand All @@ -313,14 +320,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Check if first_map is still consumed by some node
if graph.out_degree(intermediate_access_node) == 0:
del sdfg.arrays[intermediate_access_node.data]
graph.remove_node(intermediate_access_node)

subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for dnode in subgraph.data_nodes():
if dnode.data in sdfg.arrays:
del sdfg.arrays[dnode.data]

obsolete_nodes = graph.all_nodes_between(first_map_entry,
first_map_exit) | {first_map_entry, first_map_exit}
graph.remove_nodes_from(obsolete_nodes)
Expand Down

0 comments on commit 2a9f0a4

Please sign in to comment.