Skip to content

Commit

Permalink
fp16 postprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
tianleiwu committed Jul 25, 2024
1 parent ca47f0f commit 7cf9999
Showing 1 changed file with 152 additions and 0 deletions.
152 changes: 152 additions & 0 deletions onnxruntime/python/tools/transformers/float16.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,154 @@ def add_node(self, node: NodeProto, is_node_blocked):
self.fp16_nodes.append(node)


def sort_graph_node(graph_proto):
# find the "first" node in Nodes that its input is not any node's output
def find_first_node(output2node_dict):
for node in org_nodes:
is_not_first_node = any(item in output2node_dict for item in node.input)
if not is_not_first_node:
return node
return None

# remove the node from output2node_dict using output as key
def remove_first_node_from_dict2(first_node):
for output in first_node.output:
if output in output2node_dict:
del output2node_dict[output]

org_nodes = graph_proto.node
# create a dict to store output as key and node as value
output2node_dict = {}
for node in org_nodes:
for output in node.output:
output2node_dict[output] = node

# save the final node after sorted
sorted_node = []
# traverse the Nodes to find the first node
while len(output2node_dict) > 0:
first_node = find_first_node(output2node_dict)
sorted_node.append(first_node)
remove_first_node_from_dict2(first_node)
# del node from original nodes list to avoid duplicate traverse
org_nodes.remove(first_node)

for new_node in sorted_node:
graph_proto.node.extend([new_node])


# The input graph should be mode.graph
# Recursively sort the topology for each sub-graph
def sort_topology(graph_proto):
assert isinstance(graph_proto, GraphProto)
sort_graph_node(graph_proto) # sort global graph
for node in graph_proto.node:
for attr in node.attribute:
if isinstance(attr.g, GraphProto) and len(attr.g.node) > 0:
sort_topology(attr.g) # sort sub-graph
for g in attr.graphs:
if isinstance(g, GraphProto):
sort_topology(g) # sort sub-graph


def remove_unnecessary_cast_node(graph_proto):
# 1. find all cast nodes in the graph
cast_node_list = []
input_name_to_cast_node_dict = {}
output_name_to_cast_node_dict = {}
# using name as key to point to a node. because node cannot be key
name_to_node_dict = {}
for node in graph_proto.node:
if node.op_type == "Cast":
if node.name not in ["graph_input_cast0", "graph_output_cast0"]:
cast_node_list.append(node)

name_to_node_dict[node.name] = node
for input_name in node.input:
input_name_to_cast_node_dict[input_name] = node
for output_name in node.output:
output_name_to_cast_node_dict[output_name] = node

# 2. find upstream and downstream node of the cast node
cast_node_upstream_dict = {} # mapping cast node(name) to its upstream node
cast_node_downstream_dict = {} # mapping cast node(name) to its downstream node
for current_node in graph_proto.node:
# find the downstream node(s)
for input_name in current_node.input:
if input_name in output_name_to_cast_node_dict:
# found the downstream node of the cast node, might be multiple
cast_node = output_name_to_cast_node_dict[input_name]
if cast_node.name not in cast_node_downstream_dict:
cast_node_downstream_dict[cast_node.name] = current_node
else: # already exists one downstream node, make it a list
existing_downstream_nodes = cast_node_downstream_dict[cast_node.name]
if isinstance(existing_downstream_nodes, list):
existing_downstream_nodes.append(current_node)
else: # make a list
existing_downstream_nodes = [existing_downstream_nodes, current_node]
cast_node_downstream_dict[cast_node.name] = existing_downstream_nodes
# find the upstream node
for output_name in current_node.output:
if output_name in input_name_to_cast_node_dict:
# found the upstream node of the cast node, should be unique
cast_node = input_name_to_cast_node_dict[output_name]
cast_node_upstream_dict[cast_node.name] = current_node

# 3. remove the cast node which upstream is 'Constant'
for cast_node_name, upstream_node in cast_node_upstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if upstream_node.op_type == "Constant":
cast_node_list.remove(cast_node)

# 4. find the cast(to16) node which downstream is Cast(to32)
remove_candidate = []
for cast_node_name, downstream_node in cast_node_downstream_dict.items():
cast_node = name_to_node_dict[cast_node_name]
if isinstance(downstream_node, list):
for dn in downstream_node:
if (
dn.op_type == "Cast"
and dn.attribute[0].i == 32
and cast_node.attribute[0].i == 16
and dn in cast_node_list
and cast_node in cast_node_list
):
remove_candidate.append((cast_node, dn))
else:
if (
downstream_node.op_type == "Cast"
and cast_node.attribute[0].i == 10
and downstream_node.attribute[0].i == 1
and downstream_node in cast_node_list
and cast_node in cast_node_list
):
remove_candidate.append((cast_node, downstream_node))

# 5. change the connection of "upstream->cast16->cast32->downstream" to "upstream->downstream"
for cast_node_pair in remove_candidate:
first_cast_node = cast_node_pair[0]
second_cast_node = cast_node_pair[1]
upstream_node = cast_node_upstream_dict[first_cast_node.name]
downstream_node = cast_node_downstream_dict[second_cast_node.name]
# find the upstream node's output to first_cast_node
out = None
for output_name in upstream_node.output:
if output_name == first_cast_node.input[0]:
out = output_name
break
# find the downstream node's input as second_cast_node's output
for i, input_name in enumerate(downstream_node.input):
for output_name in second_cast_node.output:
if input_name == output_name:
# change the input as the upstream node's output
downstream_node.input[i] = out

# 6. remove the cast node pair
for cast_node_pair in remove_candidate:
graph_proto.node.remove(cast_node_pair[0])
graph_proto.node.remove(cast_node_pair[1])


def convert_float_to_float16(
model,
min_positive_val=5.96e-08,
Expand Down Expand Up @@ -477,6 +625,10 @@ def convert_float_to_float16(
# change current node's input name
node.output[i] = input_name
break

sort_topology(model.graph)
remove_unnecessary_cast_node(model.graph)

return model


Expand Down

0 comments on commit 7cf9999

Please sign in to comment.