diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 8c836db7b9ef6..60be2d84b2bc8 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -816,51 +816,77 @@ def prune_graph(self, outputs=None, allow_remove_graph_inputs=True): """ if len(self.graphs()) > 1: + # TODO(tianleiwu): handle subgraph logger.debug("Skip prune_graph since graph has subgraph") return - if outputs is None: - outputs = [output.name for output in self.model.graph.output] + keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs output_name_to_node = self.output_name_to_node() - all_nodes = [] - for output in outputs: - if output in output_name_to_node: - last_node = output_name_to_node[output] - if last_node in all_nodes: - continue - nodes = self.get_parent_subgraph_nodes(last_node, []) - all_nodes.append(last_node) - all_nodes.extend(nodes) - nodes_to_remove = [node for node in self.model.graph.node if node not in all_nodes] + def get_first_output(node): + if node.output[0]: + return node.output[0] + return next(iter([o for o in node.output if o]), None) - self.remove_nodes(nodes_to_remove) + # Keep track of nodes to keep. The key is first output of node, and the value is the node. + output_to_node = {} - # remove outputs not in list - output_to_remove = [] - for output in self.model.graph.output: - if output.name not in outputs: - output_to_remove.append(output) - for output in output_to_remove: - self.model.graph.output.remove(output) + # Start from graph outputs, and find parent nodes recurisvely, and add nodes to the output_to_node dictionary. + dq = deque() + for output in keep_outputs: + if output in output_name_to_node: + dq.append(output_name_to_node[output]) + while len(dq) > 0: + node = dq.pop() + first_output = get_first_output(node) + if first_output and (first_output not in output_to_node): + output_to_node[first_output] = node + for name in node.input: + if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node): + dq.appendleft(output_name_to_node[name]) + + # Keep only those nodes in the output_to_node dictionary. + nodes_to_keep = [] + num_nodes_removed = 0 + for node in self.model.graph.node: + first_output = get_first_output(node) + kept_node = output_to_node[first_output] if first_output in output_to_node else None - # remove inputs not used by any node. + # Need double check the node since fused node might reuse output name of some nodes to be removed. + # It is slow to compare whole node, so we compare op_type first to avoid comparing node in most cases. + if kept_node and kept_node.op_type == node.op_type and kept_node == node: + nodes_to_keep.append(node) + else: + num_nodes_removed += 1 + self.model.graph.ClearField("node") + self.model.graph.node.extend(nodes_to_keep) + + # Remove graph outputs not in list + output_to_remove = [] + if outputs is not None: + for output in self.model.graph.output: + if output.name not in outputs: + output_to_remove.append(output) + for output in output_to_remove: + self.model.graph.output.remove(output) + + # Remove graph inputs not used by any node. input_to_remove = [] if allow_remove_graph_inputs: input_name_to_nodes = self.input_name_to_nodes() input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes] - for input in input_to_remove: - self.model.graph.input.remove(input) + for name in input_to_remove: + self.model.graph.input.remove(name) - if input_to_remove or output_to_remove or nodes_to_remove: + if input_to_remove or output_to_remove or num_nodes_removed > 0: removed = [] if input_to_remove: removed.append(f"{len(input_to_remove)} inputs") if output_to_remove: removed.append(f"{len(output_to_remove)} outputs") - if nodes_to_remove: - removed.append(f"{len(nodes_to_remove)} nodes") + if num_nodes_removed > 0: + removed.append(f"{num_nodes_removed} nodes") logger.info("Removed %s", ", ".join(removed)) self.update_graph()