diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py
index 5bcbce1df8c1c..9628e2a74137a 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention.py
@@ -78,14 +78,7 @@ def process_mask(self, input: str) -> str:
         # ReduceSum-13: axes is moved from attribute to input
         axes_name = "ort_const_1_reduce_sum_axes"
         if self.model.get_initializer(axes_name) is None:
-            self.model.add_initializer(
-                helper.make_tensor(
-                    name=axes_name,
-                    data_type=TensorProto.INT64,
-                    dims=[1],
-                    vals=[1],
-                )
-            )
+            self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False)
         mask_index_node = helper.make_node(
             "ReduceSum",
             inputs=[input_name, axes_name],
@@ -428,19 +421,12 @@ def create_combined_qkv_bias(
         qkv_bias_dim = 3 * np.prod(qb.shape)
 
         bias_name = name_prefix + "_qkv_bias"
-        bias = helper.make_tensor(
+        self.add_initializer(
             name=bias_name,
-            data_type=TensorProto.FLOAT,
+            data_type=q_bias.data_type,
             dims=[qkv_bias_dim],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias,
         )
-
-        # Convert bias to FP16 if model is using FP16
-        if q_bias.data_type == 10:
-            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
-
-        self.model.add_initializer(bias, self.this_graph_name)
-
         return bias_name
 
     def create_packed_qkv_matmul_node(
@@ -488,13 +474,13 @@
         qkv_weight = np.stack((qw, kw, vw), axis=1).reshape((d, 3 * d))
 
         qkv_weight_name = matmul_node_name + "_qkv_weight"
-        weight = helper.make_tensor(
+
+        self.add_initializer(
             name=qkv_weight_name,
-            data_type=TensorProto.FLOAT,
+            data_type=q_weight.data_type,
             dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
-            vals=qkv_weight.flatten().tolist(),
+            vals=qkv_weight,
         )
-        self.model.add_initializer(weight, self.this_graph_name)
 
         # Created packed QKV MatMul with output (B, S, 3*D)
         # Output is of the form:
@@ -519,23 +505,15 @@
 
         # Create Slice nodes to access Q, K, V
         q_slice_name = matmul_node_name + "_q_start_index"
-        q_start_tensor = helper.make_tensor(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0])
+        self.add_initializer(name=q_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[0], raw=False)
         k_slice_name = matmul_node_name + "_k_start_index"
-        k_start_tensor = helper.make_tensor(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d])
+        self.add_initializer(name=k_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[d], raw=False)
         v_slice_name = matmul_node_name + "_v_start_index"
-        v_start_tensor = helper.make_tensor(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d])
+        self.add_initializer(name=v_slice_name, data_type=TensorProto.INT64, dims=[1], vals=[2 * d], raw=False)
         end_of_qkv_name = matmul_node_name + "_end_of_qkv_index"
-        end_of_qkv_tensor = helper.make_tensor(
-            name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d]
-        )
+        self.add_initializer(name=end_of_qkv_name, data_type=TensorProto.INT64, dims=[1], vals=[3 * d], raw=False)
         qkv_last_axis_name = matmul_node_name + "_qkv_last_axis"
-        qkv_axis_tensor = helper.make_tensor(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1])
-
-        self.model.add_initializer(q_start_tensor, self.this_graph_name)
-        self.model.add_initializer(k_start_tensor, self.this_graph_name)
-        self.model.add_initializer(v_start_tensor, self.this_graph_name)
-        self.model.add_initializer(end_of_qkv_tensor, self.this_graph_name)
-        self.model.add_initializer(qkv_axis_tensor, self.this_graph_name)
+        self.add_initializer(name=qkv_last_axis_name, data_type=TensorProto.INT64, dims=[1], vals=[-1], raw=False)
 
         q_slice_output = matmul_node_name + "_q_out"
         q_slice = helper.make_node(
@@ -823,7 +801,6 @@ def create_attention_node(
             assert q_bias_shape == k_bias_shape == qw_out_size
             assert v_bias_shape == vw_out_size
 
-        qkv_bias_dim = 0
         if is_qkv_diff_dims:
             qkv_bias = np.concatenate((qb, kb, vb), axis=0)
             qkv_bias_dim = q_bias_shape + k_bias_shape + v_bias_shape
@@ -834,29 +811,20 @@ def create_attention_node(
         attention_node_name = self.model.create_node_name("Attention")
 
         if not self.use_multi_head_attention:
-            weight = helper.make_tensor(
+            self.add_initializer(
                 name=attention_node_name + "_qkv_weight",
-                data_type=TensorProto.FLOAT,
+                data_type=q_weight.data_type,
                 dims=[qw_in_size, qkv_weight_dim],
-                vals=qkv_weight.flatten().tolist(),
+                vals=qkv_weight,
             )
-            # Sometimes weights and bias are stored in fp16
-            if q_weight.data_type == 10:
-                weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
-            self.model.add_initializer(weight, self.this_graph_name)
-
-        bias = None
         if has_bias:
-            bias = helper.make_tensor(
+            self.add_initializer(
                 name=attention_node_name + "_qkv_bias",
-                data_type=TensorProto.FLOAT,
+                data_type=q_bias.data_type,
                 dims=[qkv_bias_dim],
-                vals=qkv_bias.flatten().tolist(),
+                vals=qkv_bias,
             )
-            if q_bias.data_type == 10:
-                bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
-            self.model.add_initializer(bias, self.this_graph_name)
 
         # For MultiHeadAttention operator, use separated inputs for query, key and value, and no weights.
         if self.use_multi_head_attention:
@@ -1198,14 +1166,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
         if einsum_node is not None:
             unique_index = einsum_node.input[0]
             new_edge = "edge_modified_" + unique_index
-            shape_tensor = helper.make_tensor(
+
+            shape_tensor = self.add_initializer(
                 name="shape_modified_tensor" + unique_index,
                 data_type=TensorProto.INT64,
                 dims=[4],
-                vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]).tobytes(),
-                raw=True,
+                vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]),
+                raw=False,
             )
-            self.model.add_initializer(shape_tensor, self.this_graph_name)
+
             self.model.add_node(
                 helper.make_node(
                     "Reshape",
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
index 902b1f4f9549e..250ec5f3eb159 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py
@@ -210,15 +210,13 @@ def create_attention_node(
             )
 
             matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
-            weight = helper.make_tensor(
+            self.add_initializer(
                 name=matmul_node_name + "_weight",
                 data_type=TensorProto.FLOAT,
                 dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
-                vals=qkv_weight.flatten().tolist(),
+                vals=qkv_weight,
             )
-            self.model.add_initializer(weight, self.this_graph_name)
-
             matmul_node = helper.make_node(
                 "MatMul",
                 inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
@@ -227,13 +225,13 @@
             )
             self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
 
-            shape_tensor = helper.make_tensor(
+            self.add_initializer(
                 name=matmul_node_name + "_reshape_shape",
                 data_type=TensorProto.INT64,
                 dims=[5],
                 vals=[0, 0, n, 3, h],
+                raw=False,
             )
-            self.model.add_initializer(shape_tensor, self.this_graph_name)
 
             reshape_node = helper.make_node(
                 "Reshape",
@@ -251,14 +249,12 @@
 
             attention_node_name = self.model.create_node_name("Attention")
 
-            weight = helper.make_tensor(
+            self.add_initializer(
                 name=attention_node_name + "_qkv_weight",
                 data_type=TensorProto.FLOAT,
                 dims=[qw_in_size, qkv_weight_dim],
-                vals=qkv_weight.flatten().tolist(),
+                vals=qkv_weight,
             )
-
-            self.model.add_initializer(weight, self.this_graph_name)
         else:  # cross attention
             attention_node_name = self.model.create_node_name("MultiHeadAttention")
             if self.enable_packed_kv:
@@ -282,15 +278,13 @@
                 kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
 
                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
-                weight = helper.make_tensor(
+                self.add_initializer(
                     name=matmul_node_name + "_weight",
                     data_type=TensorProto.FLOAT,
                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
-                    vals=kv_weight.flatten().tolist(),
+                    vals=kv_weight,
                 )
-                self.model.add_initializer(weight, self.this_graph_name)
-
                 matmul_node = helper.make_node(
                     "MatMul",
                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
@@ -299,13 +293,13 @@
                 )
                 self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
 
-                shape_tensor = helper.make_tensor(
+                self.add_initializer(
                     name=matmul_node_name + "_reshape_shape",
                     data_type=TensorProto.INT64,
                     dims=[5],
                     vals=[0, 0, n, 2, h],
+                    raw=False,
                 )
-                self.model.add_initializer(shape_tensor, self.this_graph_name)
 
                 reshape_node = helper.make_node(
                     "Reshape",
@@ -321,13 +315,12 @@
         qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
         qkv_bias_dim = 3 * hidden_size
 
-        bias = helper.make_tensor(
+        self.add_initializer(
             name=attention_node_name + "_qkv_bias",
             data_type=TensorProto.FLOAT,
             dims=[qkv_bias_dim],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias,
         )
-        self.model.add_initializer(bias, self.this_graph_name)
 
         if is_self_attention:
             if not self.enable_packed_qkv:
@@ -519,15 +512,13 @@
             )
 
             matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV")
-            weight = helper.make_tensor(
+            self.add_initializer(
                 name=matmul_node_name + "_weight",
                 data_type=TensorProto.FLOAT,
                 dims=[qkv_weight.shape[0], qkv_weight.shape[1]],
-                vals=qkv_weight.flatten().tolist(),
+                vals=qkv_weight,
             )
-            self.model.add_initializer(weight, self.this_graph_name)
-
             matmul_node = helper.make_node(
                 "MatMul",
                 inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
@@ -539,13 +530,14 @@
             # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
             # the Q/K/V weights to be changed without having to re-run the optimizer.
             lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
-            lora_weight_shape_tensor = helper.make_tensor(
+
+            self.add_initializer(
                 name=lora_weight_shape_tensor_name,
                 data_type=TensorProto.INT64,
                 dims=[4],
                 vals=[0, 0, n, h],
+                raw=False,
             )
-            self.model.add_initializer(lora_weight_shape_tensor, self.this_graph_name)
 
             # Reshape the LoRA Q weights
             q_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_Q")
@@ -594,13 +586,13 @@
             # Reshape the LoRA concatenated weights to [..., n * 3 * h]
             reshaped_lora_weights_shape_tensor_name = qkv_lora_concat_node.name + "_reshape_shape"
-            reshaped_lora_weights_shape_tensor = helper.make_tensor(
+            self.add_initializer(
                 name=reshaped_lora_weights_shape_tensor_name,
                 data_type=TensorProto.INT64,
                 dims=[3],
                 vals=[0, 0, n * 3 * h],
+                raw=False,
             )
-            self.model.add_initializer(reshaped_lora_weights_shape_tensor, self.this_graph_name)
 
             qkv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_QKV")
             qkv_lora_reshaped_node = helper.make_node(
@@ -623,13 +615,13 @@
             # Finally, reshape the concatenated Q/K/V result to 5D
             shape_tensor_name = add_weights_node_name + "_reshape_shape"
-            shape_tensor = helper.make_tensor(
+            self.add_initializer(
                 name=shape_tensor_name,
                 data_type=TensorProto.INT64,
                 dims=[5],
                 vals=[0, 0, n, 3, h],
+                raw=False,
             )
-            self.model.add_initializer(shape_tensor, self.this_graph_name)
 
             reshape_node = helper.make_node(
                 "Reshape",
@@ -678,15 +670,13 @@
                 kv_weight = np.dstack([kw.reshape(c, n, h), vw.reshape(c, n, h)]).reshape(c, n * 2 * h)
 
                 matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV")
-                weight = helper.make_tensor(
+                self.add_initializer(
                     name=matmul_node_name + "_weight",
                     data_type=TensorProto.FLOAT,
                     dims=[kv_weight.shape[0], kv_weight.shape[1]],
-                    vals=kv_weight.flatten().tolist(),
+                    vals=kv_weight,
                 )
-                self.model.add_initializer(weight, self.this_graph_name)
-
                 matmul_node = helper.make_node(
                     "MatMul",
                     inputs=[k_matmul.input[0], matmul_node_name + "_weight"],
@@ -698,13 +688,13 @@
                 # Do the same thing with the LoRA weights, but don't constant fold the result. The goal is to allow
                 # the Q/K/V weights to be changed without having to re-run the optimizer.
                 kv_lora_weight_shape_tensor_name = q_lora_last_node.name + "_reshape_shape"
-                lora_weight_shape_tensor = helper.make_tensor(
+                self.add_initializer(
                     name=kv_lora_weight_shape_tensor_name,
                     data_type=TensorProto.INT64,
                     dims=[4],
                     vals=[0, 0, n, h],
+                    raw=False,
                 )
-                self.model.add_initializer(lora_weight_shape_tensor, self.this_graph_name)
 
                 # Reshape the LoRA K weights
                 k_lora_reshape_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_K")
@@ -739,13 +729,13 @@
                 # Reshape the LoRA concatenated weights to [..., n * 2 * h]
                 reshaped_kv_lora_weights_shape_tensor_name = kv_lora_concat_node.name + "_reshape_shape"
-                reshaped_kv_lora_weights_shape_tensor = helper.make_tensor(
+                self.add_initializer(
                     name=reshaped_kv_lora_weights_shape_tensor_name,
                     data_type=TensorProto.INT64,
                     dims=[3],
                     vals=[0, 0, n * 2 * h],
+                    raw=False,
                 )
-                self.model.add_initializer(reshaped_kv_lora_weights_shape_tensor, self.this_graph_name)
 
                 kv_lora_reshaped_node_name = self.model.create_node_name("Reshape", name_prefix="Reshape_LoRA_KV")
                 kv_lora_reshaped_node = helper.make_node(
@@ -768,13 +758,13 @@
                 # Finally, reshape the concatenated K/V result to 5D
                 shape_tensor_name = add_kv_weights_node_name + "_reshape_shape"
-                shape_tensor = helper.make_tensor(
+                self.add_initializer(
                     name=shape_tensor_name,
                     data_type=TensorProto.INT64,
                     dims=[5],
                     vals=[0, 0, n, 2, h],
+                    raw=False,
                 )
-                self.model.add_initializer(shape_tensor, self.this_graph_name)
 
                 reshape_node = helper.make_node(
                     "Reshape",
@@ -802,14 +792,12 @@
         # No bias, use zeros
         qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
         qkv_bias_dim = 3 * hidden_size
-
-        bias = helper.make_tensor(
+        self.add_initializer(
             name=attention_node_name + "_qkv_bias",
             data_type=TensorProto.FLOAT,
             dims=[qkv_bias_dim],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias,
         )
-        self.model.add_initializer(bias, self.this_graph_name)
 
         if is_self_attention:
             if not self.enable_packed_qkv:
diff --git a/onnxruntime/python/tools/transformers/fusion_attention_vae.py b/onnxruntime/python/tools/transformers/fusion_attention_vae.py
index e91a8a61fcc24..151c04f9334fe 100644
--- a/onnxruntime/python/tools/transformers/fusion_attention_vae.py
+++ b/onnxruntime/python/tools/transformers/fusion_attention_vae.py
@@ -170,26 +170,23 @@ def create_attention_node(
         qkv_bias = np.stack((q_bias, k_bias, v_bias), axis=0)
         qkv_bias_dim = 3 * q_bias_shape
 
-        weight = helper.make_tensor(
+        self.add_initializer(
             name=attention_node_name + "_qkv_weight",
             data_type=TensorProto.FLOAT,
             dims=[qw_in_size, qkv_weight_dim],
-            vals=qkv_weight.flatten().tolist(),
+            vals=qkv_weight,
         )
-        self.model.add_initializer(weight, self.this_graph_name)
-
         # No bias, use zeros
         qkv_bias = np.zeros([3, hidden_size], dtype=np.float32)
         qkv_bias_dim = 3 * hidden_size
 
-        bias = helper.make_tensor(
+        self.add_initializer(
             name=attention_node_name + "_qkv_bias",
             data_type=TensorProto.FLOAT,
             dims=[qkv_bias_dim],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias,
         )
-        self.model.add_initializer(bias, self.this_graph_name)
 
         attention_inputs = [
             input_name,
diff --git a/onnxruntime/python/tools/transformers/fusion_bart_attention.py b/onnxruntime/python/tools/transformers/fusion_bart_attention.py
index 513c68a29dbd1..71801401e9d06 100644
--- a/onnxruntime/python/tools/transformers/fusion_bart_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_bart_attention.py
@@ -4,6 +4,7 @@
 # --------------------------------------------------------------------------
 import logging
 
+import numpy as np
 from fusion_attention import AttentionMask, FusionAttention
 from onnx import TensorProto, helper
 from onnx_model import OnnxModel
@@ -259,8 +260,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
             empty_bias_name = "empty_bias"
             empty_tensor = self.model.get_initializer(empty_bias_name)
             if empty_tensor is None:
-                empty_tensor = helper.make_tensor(empty_bias_name, TensorProto.FLOAT, [bias_dim], [0.0] * bias_dim)
-                self.model.add_initializer(empty_tensor, self.this_graph_name)
+                self.add_initializer(
+                    empty_bias_name,
+                    TensorProto.FLOAT,
+                    dims=[bias_dim],
+                    vals=np.array([0.0] * bias_dim, dtype=np.float32),
+                )
 
             add_name = self.model.create_node_name("Add")
             add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k_1.name], add_name)
diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py
index d53a2f4ba4d2b..117468be412fa 100644
--- a/onnxruntime/python/tools/transformers/fusion_base.py
+++ b/onnxruntime/python/tools/transformers/fusion_base.py
@@ -4,9 +4,10 @@
 # --------------------------------------------------------------------------
 from collections import defaultdict
 from logging import getLogger
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Sequence, Union
 
-from onnx import NodeProto
+import numpy as np
+from onnx import NodeProto, helper
 from onnx_model import OnnxModel
 
 logger = getLogger(__name__)
@@ -86,3 +87,29 @@ def apply(self):
             self.model.prune_graph()
         elif self.nodes_to_remove or self.nodes_to_add:
             self.model.update_graph()
+
+    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
+        if raw:
+            np_type = helper.tensor_dtype_to_np_dtype(data_type)
+            if not isinstance(vals, np.ndarray):
+                bytes = np.array(vals, dtype=np_type).tobytes()
+            else:
+                bytes = vals.astype(np_type).tobytes()
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=bytes,
+                raw=True,
+            )
+        else:
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=vals,
+                raw=False,
+            )
+
+        self.model.add_initializer(tensor, self.this_graph_name)
+        return tensor
diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py
index 7b9e758178e2d..a3f98d411ebad 100644
--- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py
+++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py
@@ -239,7 +239,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
                 [0, None, 0, 0, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
                 return_indice=return_indice,
-            )  # yapf: disable
+            )
         else:
             qkv_nodes = self.model.match_parent_path(
                 normalize_node,
@@ -247,7 +247,7 @@
                 [None, 0, 0, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
                 return_indice=return_indice,
-            )  # yapf: disable
+            )
 
         if qkv_nodes is None:
             return
@@ -361,7 +361,7 @@
                     "Div",
                 ],
                 [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
-            )  # yapf: disable
+            )
             if mask_nodes is None:
                 logger.debug("fuse_attention: failed to match unidirectional mask path")
                 return
@@ -414,7 +414,7 @@
                 ),  # useless cast and reshape are removed.
             ],
             output_name_to_node,
-        )  # yapf: disable
+        )
         if input_mask_nodes is None:
             logger.debug("fuse_attention: failed to match input attention mask path")
             return
@@ -437,7 +437,7 @@
                 ),
             ],
             output_name_to_node,
-        )  # yapf: disable
+        )
         if mask_nodes is None:
             # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
             logger.debug("fuse_attention: failed to match mask path")
diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py
index 052dd243fd788..7eb774b746cac 100644
--- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py
+++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py
@@ -72,9 +72,7 @@ def fuse_attention_node(
         self.prune_graph = True
 
     def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
-        mask_nodes = self.model.match_parent_path(
-            sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0]
-        )  # yapf: disable
+        mask_nodes = self.model.match_parent_path(sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0])
         if mask_nodes is None:
             logger.debug("fuse_attention: failed to match unidirectional mask path")
             return None
@@ -176,14 +174,14 @@
                 ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                 [0, 1, None, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
-            )  # yapf: disable
+            )
         else:
             qkv_nodes = self.model.match_parent_path(
                 normalize_node,
                 ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
                 [1, None, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
-            )  # yapf: disable
+            )
 
         if qkv_nodes is None:
             return
@@ -223,7 +221,7 @@
                 "LayerNormalization",
             ],
             [1, 1, 0, 0, 0, None, 0],
-        )  # yapf: disable
+        )
 
         if v_nodes is None:
             v_nodes = self.model.match_parent_path(
@@ -238,7 +236,7 @@
                     "SkipLayerNormalization",
                 ],
                 [1, 1, 0, 0, 0, None, 0],
-            )  # yapf: disable
+            )
 
             if v_nodes is None:
                 logger.debug("fuse_attention: failed to match v path")
diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py
index 83fa51dcfafa6..b217743c4ab14 100644
--- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py
+++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py
@@ -76,7 +76,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
                 [0, None, 0, 0, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
                 return_indice=return_indice,
-            )  # yapf: disable
+            )
         else:
             qkv_nodes = self.model.match_parent_path(
                 normalize_node,
@@ -84,7 +84,7 @@
                 [None, 0, 0, 0, 0, 0],
                 output_name_to_node=output_name_to_node,
                 return_indice=return_indice,
-            )  # yapf: disable
+            )
 
         if qkv_nodes is None:
             return
@@ -116,7 +116,7 @@
             matmul_qkv,
             ["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"],
             [1, 0, 0, 0, 0, 0],
-        )  # yapf: disable
+        )
         if v_nodes is None:
             logger.debug("fuse_attention: failed to match v path")
             return
@@ -168,7 +168,7 @@
                     "Div",
                 ],
                 [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
-            )  # yapf: disable
+            )
             if mask_nodes is None:
                 logger.debug("fuse_attention: failed to match mask path")
                 return
@@ -201,7 +201,7 @@
                     "Div",
                 ],
                 [0, 0, 0, 1, 0, 0, 0, 0, 0],
-            )  # yapf: disable
+            )
             if mask_nodes is None:
                 logger.debug("fuse_attention: failed to match mask path")
                 return
@@ -225,7 +225,7 @@
                 mul_qk,
                 ["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"],
                 [1, 0, 2, 0, 0, 0, 0],
-            )  # yapf: disable
+            )
             if mask_nodes is None:
                 logger.debug("fuse_attention: failed to match mask path")
                 return
diff --git a/onnxruntime/python/tools/transformers/fusion_group_norm.py b/onnxruntime/python/tools/transformers/fusion_group_norm.py
index 2cae366d3f9bd..a4491d29b3698 100644
--- a/onnxruntime/python/tools/transformers/fusion_group_norm.py
+++ b/onnxruntime/python/tools/transformers/fusion_group_norm.py
@@ -107,21 +107,19 @@ def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
         if weight_elements not in [320, 640, 960, 1280, 1920, 2560, 128, 256, 512]:
             logger.info("GroupNorm channels=%d", weight_elements)
 
-        gamma = helper.make_tensor(
+        self.add_initializer(
             name=group_norm_name + "_gamma",
             data_type=TensorProto.FLOAT,
             dims=[weight_elements],
-            vals=weight.flatten().tolist(),
+            vals=weight,
         )
-        self.model.add_initializer(gamma, self.this_graph_name)
 
-        beta = helper.make_tensor(
+        self.add_initializer(
             name=group_norm_name + "_beta",
             data_type=TensorProto.FLOAT,
             dims=[bias_elements],
-            vals=bias.flatten().tolist(),
+            vals=bias,
         )
-        self.model.add_initializer(beta, self.this_graph_name)
 
         last_node = add_node
         subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node]
diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py
index ec485e0dfaac0..68d26fc46fa23 100644
--- a/onnxruntime/python/tools/transformers/fusion_layernorm.py
+++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py
@@ -187,7 +187,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
                 ),
             ],
             output_name_to_node,
-        )  # yapf: disable
+        )
 
         if parent_nodes is None:
             return
diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
index d8ecb652800f6..141ebb1f95a11 100644
--- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
+++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py
@@ -54,13 +54,12 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node):
                 weight = weight.transpose(0, 2, 3, 1)
 
                 weight_name = node_name + "_weight_NHWC"
-                nhwc_weight = helper.make_tensor(
+                self.add_initializer(
                     name=weight_name,
                     data_type=TensorProto.FLOAT,
                     dims=list(weight.shape),
-                    vals=weight.flatten().tolist(),
+                    vals=weight,
                 )
-                self.model.add_initializer(nhwc_weight, self.this_graph_name)
                 weight_transpose_node = None
             else:
                 weight_transpose_node = self.create_transpose_node(conv.input[1], [0, 2, 3, 1])
diff --git a/onnxruntime/python/tools/transformers/fusion_transpose.py b/onnxruntime/python/tools/transformers/fusion_transpose.py
index 6602d168309f0..2762d95dd7b00 100644
--- a/onnxruntime/python/tools/transformers/fusion_transpose.py
+++ b/onnxruntime/python/tools/transformers/fusion_transpose.py
@@ -139,23 +139,23 @@ def fuse(
         # Here we use hard-coded name so that it could be shared for the whole model.
         axes_1 = "ort_const_unsqueeze_axes_1"
         if self.model.get_initializer(axes_1) is None:
-            axes_1_tensor = helper.make_tensor(
+            self.add_initializer(
                 name=axes_1,
                 data_type=TensorProto.INT64,
                 dims=[1],
                 vals=[1],
+                raw=False,
             )
-            self.model.add_initializer(axes_1_tensor, self.this_graph_name)
 
         axes_2 = "ort_const_unsqueeze_axes_2"
         if self.model.get_initializer(axes_2) is None:
-            axes_2_tensor = helper.make_tensor(
+            self.add_initializer(
                 name=axes_2,
                 data_type=TensorProto.INT64,
                 dims=[1],
                 vals=[2],
+                raw=False,
             )
-            self.model.add_initializer(axes_2_tensor, self.this_graph_name)
 
         unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
         unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
index 1229825fec3d4..c781a91c9e493 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py
@@ -435,7 +435,7 @@ def remove_extra_reshape_2(self):
                     "SkipLayerNormalization",
                 ],
                 [None, 0, 0, 0, 0, 0, 0, 0, 0, 0],
-            )  # yapf: disable
+            )
 
             if path is None:
                 continue
diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py
index 8fb31da4a61f7..ab6a7c72a2c7a 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_t5.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py
@@ -111,7 +111,8 @@ def create_attention_node(
             name=attention_node_name + "_qkv_weight",
             data_type=TensorProto.FLOAT,
             dims=[qw_in_size, qkv_weight_dim],
-            vals=qkv_weight.flatten().tolist(),
+            vals=qkv_weight.tobytes(),
+            raw=True,
         )
 
         self.model.add_initializer(weight, self.this_graph_name)
@@ -665,7 +666,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
             name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix),
             data_type=TensorProto.FLOAT,
             dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
-            vals=table_weight_t.flatten().tolist(),
+            vals=table_weight_t.tobytes(),
+            raw=True,
         )
 
         self.model.add_initializer(bias_table, self.this_graph_name)
diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
index d1815394e9661..98235de6ba6fd 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py
@@ -5,10 +5,9 @@
 import logging
 from typing import Union
 
-import numpy as np
 from fusion_attention import AttentionMask, FusionAttention
 from fusion_utils import NumpyHelper
-from onnx import NodeProto, TensorProto, helper, numpy_helper
+from onnx import NodeProto, helper
 from onnx_model import OnnxModel
 from onnx_model_bert import BertOnnxModel
 
@@ -57,26 +56,24 @@ def create_attention_node(
 
         attention_node_name = self.model.create_node_name("Attention")
 
+        tensor_dtype = weight.data_type
+        np_type = helper.tensor_dtype_to_np_dtype(tensor_dtype)
         weight = helper.make_tensor(
             name=attention_node_name + "_qkv_weight",
-            data_type=TensorProto.FLOAT,
+            data_type=tensor_dtype,
             dims=[hidden_size, 3 * hidden_size],
-            vals=qkv_weight.flatten().tolist(),
+            vals=qkv_weight.astype(np_type).tobytes(),
+            raw=True,
         )
-
-        # Sometimes weights and bias are stored in fp16
-        if weight.data_type == 10:
-            weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name))
         self.model.add_initializer(weight, self.this_graph_name)
 
         bias = helper.make_tensor(
             name=attention_node_name + "_qkv_bias",
-            data_type=TensorProto.FLOAT,
+            data_type=tensor_dtype,
             dims=[3 * hidden_size],
-            vals=qkv_bias.flatten().tolist(),
+            vals=qkv_bias.astype(np_type).tobytes(),
+            raw=True,
         )
-        if bias.data_type == 10:
-            bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name))
         self.model.add_initializer(bias, self.this_graph_name)
 
         attention_inputs = [
diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py
index 2edc2ec06d631..76d1dcf013321 100644
--- a/onnxruntime/test/python/transformers/test_attention_fusion.py
+++ b/onnxruntime/test/python/transformers/test_attention_fusion.py
@@ -31,7 +31,18 @@ def verify_fusion(self, optimized_model, expected_model_filename):
         expected_model = OnnxModel(onnx.load(expected_model_path))
         expected_model.topological_sort(is_deterministic=True)
 
-        self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
+        nodes = optimized_model.model.graph.node
+        self.assertEqual(len(nodes), len(expected_model.model.graph.node))
+
+        for i in range(len(nodes)):
+            self.assertEqual(nodes[i], expected_model.model.graph.node[i])
+
+        for expected_initializer in expected_model.model.graph.initializer:
+            self.assertTrue(
+                OnnxModel.has_same_value(
+                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
+                )
+            )
 
     def test_multi_head_attention_fusion(self):
         model = create_bert_attention()
diff --git a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py
index ad4117f997567..85b30bea4f0af 100644
--- a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py
+++ b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py
@@ -339,7 +339,7 @@ def verify_attention(
 
     ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past)
 
-    tolerance = 1e-03 if float16 else 1e-05
+    tolerance = 1e-02 if float16 else 1e-04
     is_all_close, max_diff = compare_outputs(torch_outputs, ort_outputs, atol=tolerance, verbose=True)
     max_diffs.append(max_diff)
     if is_all_close:
diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py
index a2aa6383c2fbe..ebda0bccaadcf 100644
--- a/onnxruntime/test/python/transformers/test_whisper.py
+++ b/onnxruntime/test/python/transformers/test_whisper.py
@@ -37,7 +37,18 @@ def verify_fusion(self, optimized_model, expected_model_filename):
         expected_model = OnnxModel(onnx.load(expected_model_path))
         expected_model.topological_sort(is_deterministic=True)
 
-        self.assertEqual(str(optimized_model.model.graph), str(expected_model.model.graph))
+        nodes = optimized_model.model.graph.node
+        self.assertEqual(len(nodes), len(expected_model.model.graph.node))
+
+        for i in range(len(nodes)):
+            self.assertEqual(nodes[i], expected_model.model.graph.node[i])
+
+        for expected_initializer in expected_model.model.graph.initializer:
+            self.assertTrue(
+                OnnxModel.has_same_value(
+                    optimized_model.get_initializer(expected_initializer.name), expected_initializer
+                )
+            )
 
     # Attention type #1 in onnx_model_bart.py
     def test_encoder_attention_fusion_with_skiplayernorm(self):
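Reviewer note (outside the diff): the hunks above replace the repeated `helper.make_tensor` + `OnnxModel.add_initializer` pairs with the new `Fusion.add_initializer` helper introduced in fusion_base.py. That helper casts numpy values to the tensor's declared data type and stores them as raw bytes (`raw=True`), while small INT64 shape/axes constants keep plain value lists (`raw=False`). A minimal standalone sketch of the raw-data round trip, using illustrative names and values rather than anything from the ONNX Runtime sources:

# Standalone sketch of the raw=True path; names and shapes here are made up for illustration.
import numpy as np
from onnx import TensorProto, helper, numpy_helper

data_type = TensorProto.FLOAT16  # e.g. a model that has already been converted to fp16
np_type = helper.tensor_dtype_to_np_dtype(data_type)

qkv_weight = np.random.rand(4, 12).astype(np.float32)  # stand-in for a fused QKV weight

# Same idea as the raw branch of Fusion.add_initializer: cast to the declared dtype
# and serialize raw bytes, so no separate fp16 CopyFrom conversion is needed afterwards.
tensor = helper.make_tensor(
    name="example_qkv_weight",
    data_type=data_type,
    dims=list(qkv_weight.shape),
    vals=qkv_weight.astype(np_type).tobytes(),
    raw=True,
)

restored = numpy_helper.to_array(tensor)
assert restored.dtype == np.float16
assert np.allclose(restored, qkv_weight, atol=1e-3)

Because the initializer bytes now carry the dtype directly, the updated tests compare graphs node by node and check initializers with OnnxModel.has_same_value instead of comparing the whole graph's string form.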