# Helpers taken from the patch "fp16 postprocess"
# (commit 7cf9999b5223d04ac9adcc9681c8f69746291646, Tianlei Wu, 2024-07-25) against
# onnxruntime/python/tools/transformers/float16.py.
#
# In that patch the helpers are invoked at the very end of convert_float_to_float16,
# right before `return model`:
#
#     sort_topology(model.graph)
#     remove_unnecessary_cast_node(model.graph)


def sort_graph_node(graph_proto):
    """Topologically sort ``graph_proto.node`` in place (this graph level only).

    A node is emitted once every node producing one of its inputs has been
    emitted. Among the nodes that are ready, the one that appears earliest in
    the original list is always emitted first, so a list that is already
    sorted keeps its order (and is not touched at all).

    Inputs that no node of this graph produces (graph inputs, initializers,
    values from an enclosing graph) are not dependencies. Empty names, which
    ONNX uses for omitted optional inputs/outputs, are ignored.

    If the nodes contain a dependency cycle, the nodes that can never become
    ready are kept, in their original relative order, after the sorted ones.
    No node is ever dropped.

    NOTE(review): values that a node uses only inside its sub-graph attributes
    are not listed in ``node.input`` and are therefore not seen as
    dependencies here.

    Args:
        graph_proto: object with a mutable ``node`` sequence whose items
            expose ``input`` and ``output`` name sequences (an ONNX
            ``GraphProto`` in the caller).
    """
    import heapq

    nodes = list(graph_proto.node)
    node_count = len(nodes)

    # Output name -> index of the node producing it.
    producer_of = {}
    for index, node in enumerate(nodes):
        for name in node.output:
            if name:
                producer_of.setdefault(name, index)

    # dependents[i]: nodes waiting for node i; pending[i]: producers node i still waits for.
    dependents = [[] for _ in range(node_count)]
    pending = [0] * node_count
    for index, node in enumerate(nodes):
        producers = {producer_of[name] for name in node.input if name in producer_of}
        pending[index] = len(producers)
        for producer in producers:
            dependents[producer].append(index)

    # An ascending list is already a valid min-heap; popping the smallest index
    # keeps the result as close to the original order as possible.
    ready = [index for index in range(node_count) if pending[index] == 0]
    order = []
    while ready:
        index = heapq.heappop(ready)
        order.append(index)
        for dependent in dependents[index]:
            pending[dependent] -= 1
            if pending[dependent] == 0:
                heapq.heappush(ready, dependent)

    if len(order) < node_count:
        # Dependency cycle: keep the unsortable nodes instead of losing them.
        emitted = set(order)
        order.extend(index for index in range(node_count) if index not in emitted)

    if order == list(range(node_count)):
        return  # already sorted, leave the graph untouched

    sorted_nodes = [nodes[index] for index in order]
    del graph_proto.node[:]
    graph_proto.node.extend(sorted_nodes)


# The input graph should be model.graph
# Recursively sort the topology for each sub-graph
def sort_topology(graph_proto):
    """Topologically sort ``graph_proto`` and every sub-graph nested in it.

    Sub-graphs are the graphs stored in node attributes: ``attr.g`` when it
    contains nodes, and every entry of ``attr.graphs``.

    Args:
        graph_proto: the ``GraphProto`` to sort in place.

    Raises:
        TypeError: if ``graph_proto`` is not a ``GraphProto``.
    """
    if not isinstance(graph_proto, GraphProto):
        raise TypeError(f"graph_proto must be a GraphProto, got {type(graph_proto).__name__}")

    sort_graph_node(graph_proto)  # sort this graph level

    for node in graph_proto.node:
        for attr in node.attribute:
            if len(attr.g.node) > 0:
                sort_topology(attr.g)  # single sub-graph attribute
            for sub_graph in attr.graphs:
                sort_topology(sub_graph)  # list-of-graphs attribute


def remove_unnecessary_cast_node(graph_proto):
    """Remove ``Cast(to=FLOAT16) -> Cast(to=FLOAT)`` round trips from the graph.

    For a chain ``source -> Cast(to float16) -> Cast(to float) -> consumers``
    the consumers are rewired to read ``source`` directly and the second cast
    is removed. The first cast is removed as well once nothing reads its
    output any more; it is kept when another node still consumes the float16
    value or when that value is a graph output.

    A pair is left untouched when:
      * either cast is named "graph_input_cast0" or "graph_output_cast0";
      * the first cast is fed by a ``Constant`` node;
      * the output of the second cast is a graph output, or no node of this
        graph consumes it.

    NOTE(review): like the chain it replaces, this assumes ``source`` is a
    float32 tensor; element types are not checked. Uses inside sub-graph
    attributes are not visible through ``node.input`` and are not rewired.

    Args:
        graph_proto: graph to edit in place. Needs a mutable ``node`` sequence
            (items with ``op_type``, ``name``, ``input``, ``output`` and
            ``attribute``) and an ``output`` sequence of items with ``name``.
    """
    float_type = 1  # onnx.TensorProto.FLOAT
    float16_type = 10  # onnx.TensorProto.FLOAT16
    protected_names = ("graph_input_cast0", "graph_output_cast0")

    def cast_target(node):
        # Look the attribute up by name: "to" is not guaranteed to be the first one.
        for attr in node.attribute:
            if attr.name == "to":
                return attr.i
        return None

    def is_removable_cast(node, target_type):
        return (
            node.op_type == "Cast"
            and node.name not in protected_names
            and len(node.input) == 1
            and len(node.output) == 1
            and cast_target(node) == target_type
        )

    # Work on positions in a snapshot so nothing depends on node names being
    # unique/non-empty or on message equality.
    nodes = list(graph_proto.node)

    producer_of = {}  # value name -> index of the producing node
    consumers_of = {}  # value name -> indices of the consuming nodes (no duplicates)
    for index, node in enumerate(nodes):
        for name in node.output:
            if name:
                producer_of[name] = index
        for name in node.input:
            if name:
                users = consumers_of.setdefault(name, [])
                if not users or users[-1] != index:
                    users.append(index)

    graph_output_names = {value_info.name for value_info in graph_proto.output}
    removed = set()

    for first_index, first_cast in enumerate(nodes):
        if first_index in removed or not is_removable_cast(first_cast, float16_type):
            continue

        source_name = first_cast.input[0]
        half_name = first_cast.output[0]
        if not source_name or not half_name:
            continue

        upstream_index = producer_of.get(source_name)
        if upstream_index is not None and nodes[upstream_index].op_type == "Constant":
            continue

        half_users = consumers_of.get(half_name, [])
        for second_index in list(half_users):
            if second_index in removed:
                continue
            second_cast = nodes[second_index]
            if not is_removable_cast(second_cast, float_type):
                continue

            restored_name = second_cast.output[0]
            restored_users = consumers_of.get(restored_name, [])
            # The cast must stay if it defines a graph output, or if nothing in this
            # graph reads it (it may still be referenced from somewhere we cannot see).
            if not restored_name or restored_name in graph_output_names or not restored_users:
                continue

            source_users = consumers_of.setdefault(source_name, [])
            for user_index in list(restored_users):
                user = nodes[user_index]
                for position, name in enumerate(user.input):
                    if name == restored_name:
                        user.input[position] = source_name
                # Keep the consumer map current so later pairs in a chain see this user.
                if user_index not in source_users:
                    source_users.append(user_index)
            removed.add(second_index)

        still_used = any(index not in removed for index in half_users)
        if half_users and not still_used and half_name not in graph_output_names:
            removed.add(first_index)

    # Delete from the back so the remaining indices stay valid.
    for index in sorted(removed, reverse=True):
        del graph_proto.node[index]