Handle dtype attribute in float16 conversion script (#17321)
Some operators have a dtype attribute (search for `dtype` in
https://github.com/onnx/onnx/blob/main/docs/Operators.md).
This change makes sure the dtype attribute is handled correctly in float16
conversion.
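For context, this is the pattern the fix targets (illustrative snippet, not code from this commit): a node such as RandomNormal can carry an explicit dtype attribute, and its integer value has to be rewritten from FLOAT to FLOAT16 during conversion.

    # Illustrative only: a RandomNormal node with an explicit dtype attribute
    # that a float16 converter has to rewrite.
    from onnx import TensorProto, helper

    node = helper.make_node(
        "RandomNormal", inputs=[], outputs=["x"],
        dtype=TensorProto.FLOAT, shape=[2, 3],
    )
    for attr in node.attribute:
        if attr.name == "dtype" and attr.i == TensorProto.FLOAT:
            attr.i = TensorProto.FLOAT16  # node now generates float16 output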
tianleiwu authored Aug 30, 2023
1 parent 8224891 commit c961f67
1 changed file: onnxruntime/python/tools/transformers/float16.py (50 additions, 29 deletions)
@@ -20,8 +20,7 @@

 import numpy as np
 import onnx
-from onnx import helper, numpy_helper
-from onnx import onnx_pb as onnx_proto
+from onnx import AttributeProto, GraphProto, ModelProto, NodeProto, TensorProto, helper, numpy_helper
 from onnx.shape_inference import infer_shapes, infer_shapes_path
 from packaging import version

@@ -87,11 +86,11 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit
         TensorProto: the converted tensor.
     """

-    if not isinstance(tensor, onnx_proto.TensorProto):
+    if not isinstance(tensor, TensorProto):
         raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")

-    if tensor.data_type == onnx_proto.TensorProto.FLOAT:
-        tensor.data_type = onnx_proto.TensorProto.FLOAT16
+    if tensor.data_type == TensorProto.FLOAT:
+        tensor.data_type = TensorProto.FLOAT16
         # convert float_data (float type) to float16 and write to int32_data
         if tensor.float_data:
             float16_data = convert_np_to_float16(np.array(tensor.float_data), min_positive_val, max_finite_val)
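As background on the two thresholds threaded through this call: they clamp values that would underflow or overflow in float16 before the cast. A minimal sketch of that behavior, assuming the semantics of the module's convert_np_to_float16 (the function below is a simplified stand-in, not the real implementation):

    import numpy as np

    def clamp_then_cast(arr, min_positive_val=5.96e-08, max_finite_val=65504.0):
        """Simplified sketch: push tiny magnitudes up to the smallest value that
        survives in float16, cap large magnitudes at float16's max finite value,
        then cast. The real function has extra handling (e.g. NaN/inf)."""
        sign = np.sign(arr)
        mag = np.abs(arr)
        mag = np.where((mag > 0) & (mag < min_positive_val), min_positive_val, mag)
        mag = np.where(mag > max_finite_val, max_finite_val, mag)
        return (sign * mag).astype(np.float16)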
@@ -152,12 +151,12 @@ def make_value_info_from_tensor(tensor):
 class InitializerTracker:
     """Class for keeping track of initializer."""

-    def __init__(self, initializer: onnx_proto.TensorProto):
+    def __init__(self, initializer: TensorProto):
         self.initializer = initializer
         self.fp32_nodes = []
         self.fp16_nodes = []

-    def add_node(self, node: onnx_proto.NodeProto, is_node_blocked):
+    def add_node(self, node: NodeProto, is_node_blocked):
         if is_node_blocked:
             self.fp32_nodes.append(node)
         else:
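A hypothetical usage sketch for this class: one tracker per fp32 initializer, with each consuming node registered as blocked (stays fp32) or not (can take fp16), so a shared initializer can later be handled per consumer. The node and weight names below are made up.

    import numpy as np
    from onnx import helper, numpy_helper

    weight = numpy_helper.from_array(np.zeros((4, 4), dtype=np.float32), name="W")
    matmul = helper.make_node("MatMul", ["X", "W"], ["Y"], name="matmul_0")

    tracker = InitializerTracker(weight)             # class defined above
    tracker.add_node(matmul, is_node_blocked=False)  # this consumer can take fp16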
@@ -219,7 +218,7 @@ def convert_float_to_float16(
     else:
         model = onnx.load(model_path)

-    if not isinstance(model, onnx_proto.ModelProto):
+    if not isinstance(model, ModelProto):
         raise ValueError(f"Expected an ONNX ModelProto but got {type(model)}")

     func_infer_shape = None
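For reference, a hedged usage sketch of the function being patched here; the file names are placeholders, the import path assumes the released onnxruntime wheel, and only keep_io_types, which appears later in this diff, is passed explicitly:

    import onnx
    from onnxruntime.transformers.float16 import convert_float_to_float16

    model = onnx.load("model.onnx")  # placeholder path
    fp16_model = convert_float_to_float16(model, keep_io_types=True)
    onnx.save(fp16_model, "model_fp16.onnx")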
@@ -259,8 +258,8 @@ def convert_float_to_float16(
     graph_io_to_skip = set()
     io_casts = set()

-    fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT]
-    fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT]
+    fp32_inputs = [n.name for n in model.graph.input if n.type.tensor_type.elem_type == TensorProto.FLOAT]
+    fp32_outputs = [n.name for n in model.graph.output if n.type.tensor_type.elem_type == TensorProto.FLOAT]
     if isinstance(keep_io_types, list):
         fp32_inputs = [n for n in fp32_inputs if n in keep_io_types]
         fp32_outputs = [n for n in fp32_outputs if n in keep_io_types]
@@ -278,9 +277,9 @@ def convert_float_to_float16(
             new_value_info = model.graph.value_info.add()
             new_value_info.CopyFrom(n)
             new_value_info.name = output_name
-            new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+            new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
             # add Cast node (from tensor(float) to tensor(float16) after graph input
-            new_node = [helper.make_node("Cast", [n.name], [output_name], to=10, name=node_name)]
+            new_node = [helper.make_node("Cast", [n.name], [output_name], to=TensorProto.FLOAT16, name=node_name)]
             model.graph.node.extend(new_node)
             value_info_list.append(new_value_info)
             io_casts.add(node_name)
@@ -296,7 +295,7 @@ def convert_float_to_float16(
             new_value_info = model.graph.value_info.add()
             new_value_info.CopyFrom(n)
             new_value_info.name = input_name
-            new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+            new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT16
             new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)]
             model.graph.node.extend(new_node)
             value_info_list.append(new_value_info)
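These hunks insert the boundary casts for graph I/O kept in fp32. The raw integers in the `to` attribute are TensorProto enum values, which this commit starts spelling out by name; a standalone illustration:

    from onnx import TensorProto, helper

    assert TensorProto.FLOAT == 1 and TensorProto.FLOAT16 == 10

    # fp32 graph input -> fp16 for the converted graph body
    cast_in = helper.make_node("Cast", ["graph_in"], ["graph_in_fp16"],
                               to=TensorProto.FLOAT16, name="graph_input_cast0")
    # fp16 result -> fp32 graph output
    cast_out = helper.make_node("Cast", ["y_fp16"], ["graph_out"],
                                to=TensorProto.FLOAT, name="graph_output_cast0")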
@@ -307,12 +306,12 @@ def convert_float_to_float16(
         next_level = []
         for q in queue:
             # if q is model, push q.graph (GraphProto)
-            if isinstance(q, onnx_proto.ModelProto):
+            if isinstance(q, ModelProto):
                 next_level.append(q.graph)
             # if q is model.graph, push q.node.attribute (AttributeProto)
-            if isinstance(q, onnx_proto.GraphProto):
+            if isinstance(q, GraphProto):
                 for n in q.initializer:  # TensorProto type
-                    if n.data_type == onnx_proto.TensorProto.FLOAT:
+                    if n.data_type == TensorProto.FLOAT:
                         assert n.name not in fp32_initializers
                         fp32_initializers[n.name] = InitializerTracker(n)
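The loop in this hunk is one level of a breadth-first walk over the model; a condensed, simplified sketch of its overall shape (omitting the per-object processing the diff shows):

    from onnx import AttributeProto, GraphProto, ModelProto

    def walk(model: ModelProto):
        queue = [model]
        while queue:
            next_level = []
            for q in queue:
                if isinstance(q, ModelProto):
                    next_level.append(q.graph)
                if isinstance(q, GraphProto):
                    for node in q.node:
                        next_level.extend(node.attribute)
                if isinstance(q, AttributeProto):
                    next_level.append(q.g)       # single-subgraph attribute (If/Loop)
                    next_level.extend(q.graphs)  # repeated-subgraph attribute
            queue = next_level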

@@ -343,10 +342,32 @@ def convert_float_to_float16(
                     else:
                         if n.op_type == "Cast":
                             for attr in n.attribute:
-                                if attr.name == "to" and attr.i == 1:
-                                    attr.i = 10
+                                if attr.name == "to" and attr.i == TensorProto.FLOAT:
+                                    attr.i = TensorProto.FLOAT16
                                     break

+                        if n.op_type in [
+                            "EyeLike",
+                            "Multinomial",
+                            "RandomNormal",
+                            "RandomNormalLike",
+                            "RandomUniform",
+                            "RandomUniformLike",
+                            "SequenceEmpty",
+                            "Bernoulli",
+                        ]:
+                            has_dtype = False
+                            for attr in n.attribute:
+                                if attr.name == "dtype":
+                                    has_dtype = True
+                                    if attr.i == TensorProto.FLOAT:
+                                        attr.i = TensorProto.FLOAT16
+
+                            # The dtype attribute is optional and defaults to FLOAT in the following operators,
+                            # so we need to add the dtype attribute to specify the data type float16
+                            if (n.op_type in ["RandomNormal", "RandomUniform", "SequenceEmpty"]) and not has_dtype:
+                                n.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)])

                         # For Resize/GroupNorm, attribute data type cannot be changed
                         if n.op_type not in ALWAYS_FLOAT_INPUTS or n.op_type in force_fp16_inputs_dict:
                             for attr in n.attribute:
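To see why the second branch above is needed: for RandomNormal, RandomUniform, and SequenceEmpty the dtype attribute is optional and defaults to FLOAT, so a node without the attribute would silently keep producing fp32 outputs. An illustrative snippet mirroring the new code:

    from onnx import TensorProto, helper

    node = helper.make_node("RandomNormal", [], ["x"], shape=[2, 3])  # no dtype attr

    has_dtype = any(attr.name == "dtype" for attr in node.attribute)
    if not has_dtype:
        # pin the output type explicitly, as the added code does
        node.attribute.extend([helper.make_attribute("dtype", TensorProto.FLOAT16)])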
@@ -356,27 +377,27 @@ def convert_float_to_float16(

             # if q is model.graph.node.attribute, push q.g and q.graphs (GraphProto)
             # and process node.attribute.t and node.attribute.tensors (TensorProto)
-            if isinstance(q, onnx_proto.AttributeProto):
+            if isinstance(q, AttributeProto):
                 next_level.append(q.g)
                 for n in q.graphs:
                     next_level.append(n)  # noqa: PERF402
                 q.t.CopyFrom(convert_tensor_float_to_float16(q.t, min_positive_val, max_finite_val))
                 for n in q.tensors:
                     n = convert_tensor_float_to_float16(n, min_positive_val, max_finite_val)  # noqa: PLW2901
             # if q is graph, process input, output and value_info (ValueInfoProto)
-            if isinstance(q, onnx_proto.GraphProto):
+            if isinstance(q, GraphProto):
                 # Note that float initializers tracked by fp32_initializers will be processed later.
                 # for all ValueInfoProto with tensor(float) type in input, output and value_info, convert them to
                 # tensor(float16) except map and seq(map). And save them in value_info_list for further processing
                 for n in itertools.chain(q.input, q.output, q.value_info):
-                    if n.type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                    if n.type.tensor_type.elem_type == TensorProto.FLOAT:
                         if n.name not in graph_io_to_skip:
-                            n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+                            n.type.tensor_type.elem_type = TensorProto.FLOAT16
                             value_info_list.append(n)
                     if n.type.HasField("sequence_type"):
-                        if n.type.sequence_type.elem_type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT:
+                        if n.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT:
                             if n.name not in graph_io_to_skip:
-                                n.type.sequence_type.elem_type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16
+                                n.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16
                                 value_info_list.append(n)

         queue = next_level
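The elem_type rewrites above cover plain tensor types and sequences of tensors; a small standalone sketch of both field paths on hand-built value infos (assuming onnx.helper.make_tensor_sequence_value_info, which recent onnx releases provide):

    from onnx import TensorProto, helper

    vi = helper.make_tensor_value_info("t", TensorProto.FLOAT, [1, 4])
    vi.type.tensor_type.elem_type = TensorProto.FLOAT16

    seq = helper.make_tensor_sequence_value_info("s", TensorProto.FLOAT, [1, 4])
    seq.type.sequence_type.elem_type.tensor_type.elem_type = TensorProto.FLOAT16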
@@ -405,7 +426,7 @@ def convert_float_to_float16(
                 new_value_info.CopyFrom(value_info)
                 output_name = node.name + "_input_cast_" + str(i)
                 new_value_info.name = output_name
-                new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+                new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                 # add Cast node (from tensor(float16) to tensor(float) before current node
                 node_name = node.name + "_input_cast" + str(i)
                 new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
@@ -428,7 +449,7 @@ def convert_float_to_float16(
                 new_value_info.CopyFrom(value_info)
                 output_name = node.name + "_input_cast_" + str(i)
                 new_value_info.name = output_name
-                new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+                new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                 # add Cast node (from tensor(float16) to tensor(float) before current node
                 node_name = node.name + "_input_cast" + str(i)
                 new_node = [helper.make_node("Cast", [input_name], [output_name], to=1, name=node_name)]
@@ -447,7 +468,7 @@ def convert_float_to_float16(
                 new_value_info.CopyFrom(value_info)
                 input_name = node.name + "_output_cast_" + str(i)
                 new_value_info.name = input_name
-                new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT
+                new_value_info.type.tensor_type.elem_type = TensorProto.FLOAT
                 # add Cast node (from tensor(float) to tensor(float16) after current node
                 node_name = node.name + "_output_cast" + str(i)
                 new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)]
@@ -460,9 +481,9 @@

 def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0):
     """Measure the maximum absolute difference after converting a float tensor to float16."""
-    if not isinstance(tensor, onnx_proto.TensorProto):
+    if not isinstance(tensor, TensorProto):
         raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
-    if tensor.data_type != onnx_proto.TensorProto.FLOAT:
+    if tensor.data_type != TensorProto.FLOAT:
         raise ValueError("Expected tensor data type is float.")

     float32_data = None
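A hedged usage sketch for this helper, handy for spotting initializers that would lose too much precision; the weight below is random placeholder data:

    import numpy as np
    from onnx import numpy_helper

    w = numpy_helper.from_array(np.random.rand(128, 128).astype(np.float32), name="W")
    print("max |fp32 - fp16(fp32)| for W:", float_to_float16_max_diff(w))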
