Skip to content

Commit

Permalink
Improve performance of prune_graph in onnx_model.py (#17502)
Browse files Browse the repository at this point in the history
During optimization of SDXL UNet, the prune_graph takes up to 5 minutes.
The cause is to find a node in all nodes is time-consuming. This
optimization will reduce the latency of prune_graph to 2 seconds.

New algorithm will use a hash table (key is first node output, value is
node) to speed up.
  • Loading branch information
tianleiwu authored Sep 12, 2023
1 parent cf672c5 commit 49511b5
Showing 1 changed file with 52 additions and 26 deletions.
78 changes: 52 additions & 26 deletions onnxruntime/python/tools/transformers/onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -816,51 +816,77 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True):
"""

if len(self.graphs()) > 1:
# TODO(tianleiwu): handle subgraph
logger.debug("Skip prune_graph since graph has subgraph")
return

if outputs is None:
outputs = [output.name for output in self.model.graph.output]
keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs

output_name_to_node = self.output_name_to_node()
all_nodes = []
for output in outputs:
if output in output_name_to_node:
last_node = output_name_to_node[output]
if last_node in all_nodes:
continue
nodes = self.get_parent_subgraph_nodes(last_node, [])
all_nodes.append(last_node)
all_nodes.extend(nodes)

nodes_to_remove = [node for node in self.model.graph.node if node not in all_nodes]
def get_first_output(node):
if node.output[0]:
return node.output[0]
return next(iter([o for o in node.output if o]), None)

self.remove_nodes(nodes_to_remove)
# Keep track of nodes to keep. The key is first output of node, and the value is the node.
output_to_node = {}

# remove outputs not in list
output_to_remove = []
for output in self.model.graph.output:
if output.name not in outputs:
output_to_remove.append(output)
for output in output_to_remove:
self.model.graph.output.remove(output)
# Start from graph outputs, and find parent nodes recurisvely, and add nodes to the output_to_node dictionary.
dq = deque()
for output in keep_outputs:
if output in output_name_to_node:
dq.append(output_name_to_node[output])
while len(dq) > 0:
node = dq.pop()
first_output = get_first_output(node)
if first_output and (first_output not in output_to_node):
output_to_node[first_output] = node
for name in node.input:
if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node):
dq.appendleft(output_name_to_node[name])

# Keep only those nodes in the output_to_node dictionary.
nodes_to_keep = []
num_nodes_removed = 0
for node in self.model.graph.node:
first_output = get_first_output(node)
kept_node = output_to_node[first_output] if first_output in output_to_node else None

# remove inputs not used by any node.
# Need double check the node since fused node might reuse output name of some nodes to be removed.
# It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases.
if kept_node and kept_node.op_type == node.op_type and kept_node == node:
nodes_to_keep.append(node)
else:
num_nodes_removed += 1
self.model.graph.ClearField("node")
self.model.graph.node.extend(nodes_to_keep)

# Remove graph outputs not in list
output_to_remove = []
if outputs is not None:
for output in self.model.graph.output:
if output.name not in outputs:
output_to_remove.append(output)
for output in output_to_remove:
self.model.graph.output.remove(output)

# Remove graph inputs not used by any node.
input_to_remove = []
if allow_remove_graph_inputs:
input_name_to_nodes = self.input_name_to_nodes()
input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes]
for input in input_to_remove:
self.model.graph.input.remove(input)
for name in input_to_remove:
self.model.graph.input.remove(name)

if input_to_remove or output_to_remove or nodes_to_remove:
if input_to_remove or output_to_remove or num_nodes_removed > 0:
removed = []
if input_to_remove:
removed.append(f"{len(input_to_remove)} inputs")
if output_to_remove:
removed.append(f"{len(output_to_remove)} outputs")
if nodes_to_remove:
removed.append(f"{len(nodes_to_remove)} nodes")
if num_nodes_removed > 0:
removed.append(f"{num_nodes_removed} nodes")
logger.info("Removed %s", ", ".join(removed))

self.update_graph()
Expand Down

0 comments on commit 49511b5

Please sign in to comment.