Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OTFMapFusion: Minor bug fixes #1448

Merged
merged 2 commits into from
Nov 26, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions dace/transformation/dataflow/otf_map_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,27 +132,28 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
intermediate_access_node = self.array
first_map_exit = self.first_map_exit
first_map_entry = graph.entry_node(first_map_exit)
second_map_entry = self.second_map_entry

# Prepare: Make first and second map parameters disjoint
# This avoids mutual matching: i -> j, j -> i
subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for param in first_map_entry.map.params:
i = 0
new_param = f"_i{i}"
while new_param in self.second_map_entry.map.params or new_param in first_map_entry.map.params:
while new_param in second_map_entry.map.params or new_param in first_map_entry.map.params:
i = i + 1
new_param = f"_i{i}"

advanced_replace(subgraph, param, new_param)

# Prepare: Preemptively rename params defined by second map in scope of first
# This prevents local variables (e.g., in a nested SDFG) from colliding with the new map scope
for param in self.second_map_entry.map.params:
for param in second_map_entry.map.params:
new_param = param + "_local"
advanced_replace(subgraph, param, new_param)

# Add local buffers for array-like OTFs
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
if edge.data is None or edge.data.data != intermediate_access_node.data:
continue

Expand Down Expand Up @@ -208,18 +209,18 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
save=False)

# Phase 1: Add new access nodes to second map
for edge in graph.edges_between(intermediate_access_node, self.second_map_entry):
for edge in graph.edges_between(intermediate_access_node, second_map_entry):
graph.remove_edge_and_connectors(edge)

connector_mapping = {}
for edge in graph.in_edges(first_map_entry):
new_in_connector = self.second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = second_map_entry.next_connector(edge.dst_conn[3:])
new_in_connector = "IN_" + new_in_connector
if not self.second_map_entry.add_in_connector(new_in_connector):
if not second_map_entry.add_in_connector(new_in_connector):
raise ValueError("Failed to add new in connector")

memlet = copy.deepcopy(edge.data)
graph.add_edge(edge.src, edge.src_conn, self.second_map_entry, new_in_connector, memlet)
graph.add_edge(edge.src, edge.src_conn, second_map_entry, new_in_connector, memlet)

connector_mapping[edge.dst_conn] = new_in_connector

Expand All @@ -231,7 +232,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Group by same access scheme
consume_memlets = {}
for edge in graph.out_edges(self.second_map_entry):
for edge in graph.out_edges(second_map_entry):
memlet = edge.data
if memlet.data not in produce_memlets:
continue
Expand All @@ -246,7 +247,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
consume_memlets[memlet.data][accesses].append(edge)

# And remove from second map
self.second_map_entry.remove_out_connector(edge.src_conn)
second_map_entry.remove_out_connector(edge.src_conn)
graph.remove_edge(edge)

# Phase 3: OTF - copy content of first map for each memlet of second according to matches
Expand All @@ -256,7 +257,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
for second_accesses in consume_memlets[array]:
# Step 1: Infer index access of second map to new inputs with respect to original first map
mapping = OTFMapFusion.solve(first_map_entry.map.params, first_accesses,
self.second_map_entry.map.params, second_accesses)
second_map_entry.map.params, second_accesses)

# Step 2: Add Temporary buffer
tmp_name = sdfg.temp_data_name()
Expand Down Expand Up @@ -296,16 +297,16 @@ def apply(self, graph: SDFGState, sdfg: SDFG):
else:
out_connector = edge.src_conn

if out_connector not in self.second_map_entry.out_connectors:
self.second_map_entry.add_out_connector(out_connector)
if out_connector not in second_map_entry.out_connectors:
second_map_entry.add_out_connector(out_connector)
else:
out_connector = None

graph.add_edge(self.second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.add_edge(second_map_entry, out_connector, node, edge.dst_conn, memlet)
graph.remove_edge(edge)

# Step 4: Rename all symbols of first map in copied content by the matched symbol of second map
otf_nodes.append(self.second_map_entry)
otf_nodes.append(second_map_entry)
otf_subgraph = StateSubgraphView(graph, otf_nodes)
for param in mapping:
if isinstance(param, tuple):
Expand All @@ -316,14 +317,9 @@ def apply(self, graph: SDFGState, sdfg: SDFG):

# Check if first_map is still consumed by some node
if graph.out_degree(intermediate_access_node) == 0:
del sdfg.arrays[intermediate_access_node.data]
graph.remove_node(intermediate_access_node)

subgraph = graph.scope_subgraph(first_map_entry, include_entry=True, include_exit=True)
for dnode in subgraph.data_nodes():
if dnode.data in sdfg.arrays:
del sdfg.arrays[dnode.data]

obsolete_nodes = graph.all_nodes_between(first_map_entry,
first_map_exit) | {first_map_entry, first_map_exit}
graph.remove_nodes_from(obsolete_nodes)
Expand Down
Loading