From 013f55100fa2f6278cf916f2b7c9ebdb0872c087 Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 26 Jan 2024 06:08:37 +0000 Subject: [PATCH] update --- .../models/phi2/convert_to_onnx.py | 228 +++++++++++++----- .../tools/transformers/onnx_model_phi.py | 228 ++++++------------ 2 files changed, 244 insertions(+), 212 deletions(-) diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index 51f5481c270cf..73643faf6b578 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -2,17 +2,19 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. import argparse +import logging import onnx +import os import torch from enum import Enum from onnx import ModelProto, TensorProto, helper +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer from transformers import AutoConfig, AutoModelForCausalLM -from typing import List # -------------------------------------------------------------------------- # The following code is used when this file is not in the ORT package -import sys, os +import sys sys.path.append(os.path.dirname(__file__)) @@ -25,14 +27,20 @@ class AttentionOpType(Enum): - Attention = "attention" - MultiHeadAttention = "mha" - GroupQueryAttention = "gqa" + Attention = "Attention" + MultiHeadAttention = "MultiHeadAttention" + GroupQueryAttention = "GroupQueryAttention" def __str__(self): return self.value +def env_reset(): + for flag in ["AttentionOpType"]: + if flag in os.environ: + del os.environ[flag] + + class ConvertPhi2ToONNX: def __init__( self, @@ -48,11 +56,17 @@ def __init__( self.batch_size = 2 self.sequence_length = 8 self.phi2_edge_dict = self.__get_phi2_edge_dict(self.phi_config) + self.attn_op_type = None + self.precision = None + self.optimized_model = False def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision): self.attn_op_type = attn_op_type self.precision = precision + env_reset() + os.environ["AttentionOpType"] = str(attn_op_type) + def __get_phi2_edge_dict(self, config: AutoConfig) -> dict: edge_dict = {} edge_dict["lm_head_1"] = "logits" @@ -77,6 +91,7 @@ def __simplify_phi2_op_type(self, onnx_model: ModelProto): def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto): use_gqa = self.attn_op_type == AttentionOpType.GroupQueryAttention + use_attn = self.attn_op_type == AttentionOpType.Attention graph = onnx_model.graph new_inputs = [] for i, vi in enumerate(graph.input): @@ -96,19 +111,34 @@ def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto): elem_type=TensorProto.INT64 if use_gqa else TensorProto.INT32, shape=["batch_size", "seq_len"], ) - new_inputs.extend([vi, vi_pid, vi_mask]) - if "past_key" in vi.name or "past_value" in vi.name: - vi_cache = helper.make_tensor_value_info( - vi.name, - elem_type=vi.type.tensor_type.elem_type, - shape=[ - "batch_size", - config.num_attention_heads, - "past_seq_len", - config.hidden_size // config.num_attention_heads, - ], - ) - new_inputs.extend([vi_cache]) + new_inputs.extend([vi, vi_pid, vi_mask]) if not use_attn else new_inputs.extend([vi, vi_mask]) + if not use_attn: + if "past_key" in vi.name or "past_value" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + config.num_attention_heads, + "past_seq_len", 
+ config.hidden_size // config.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) + else: + if "past_key" in vi.name: + vi_cache = helper.make_tensor_value_info( + vi.name.replace("past_key", "past"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + config.num_attention_heads, + "seq_len", + config.hidden_size // config.num_attention_heads, + ], + ) + new_inputs.extend([vi_cache]) graph.ClearField("input") graph.input.extend(new_inputs) @@ -119,18 +149,34 @@ def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto): vi = helper.make_tensor_value_info( vi.name, elem_type=vi.type.tensor_type.elem_type, shape=["batch_size", "seq_len", config.vocab_size] ) + new_outputs.extend([vi]) else: - vi = helper.make_tensor_value_info( - vi.name, - elem_type=vi.type.tensor_type.elem_type, - shape=[ - "batch_size", - config.num_attention_heads, - "total_seq_len", - config.hidden_size // config.num_attention_heads, - ], - ) - new_outputs.extend([vi]) + if not use_attn: + vi = helper.make_tensor_value_info( + vi.name, + elem_type=vi.type.tensor_type.elem_type, + shape=[ + "batch_size", + config.num_attention_heads, + "total_seq_len", + config.hidden_size // config.num_attention_heads, + ], + ) + new_outputs.extend([vi]) + else: + if "present_key" in vi.name: + vi = helper.make_tensor_value_info( + vi.name.replace("present_key", "present"), + elem_type=vi.type.tensor_type.elem_type, + shape=[ + 2, + "batch_size", + config.num_attention_heads, + "seq_len", + config.hidden_size // config.num_attention_heads, + ], + ) + new_outputs.extend([vi]) graph.ClearField("output") graph.output.extend(new_outputs) @@ -162,6 +208,7 @@ def __unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.Mode """ Unrolls the function with the given name in the model. """ + logging.info(f"Unrolling function {func_name}...") nodes_to_remove = [] nodes_to_add = [] edges_to_remove = [] @@ -203,6 +250,7 @@ def __remove_dropout_layer(self, model: onnx.ModelProto): """ Removes the dropout layer in the model. 
""" + logging.info("Removing dropout layer...") edge_mapping = {} nodes_to_remove = [] for node in model.graph.node: @@ -217,6 +265,7 @@ def __remove_dropout_layer(self, model: onnx.ModelProto): return self.__update_edges(model, edge_mapping) def __get_phi2_torch_model(self): + logging.info("Loading phi2 torch model...") if self.phi_model is not None: return self.phi_model = AutoModelForCausalLM.from_pretrained( @@ -225,9 +274,16 @@ def __get_phi2_torch_model(self): self.phi_model.eval() self.phi_model.to(self.device) + def optimized_model_exists(self): + return self.optimized_model + def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int): input_ids = torch.randint( - low=0, high=self.phi_config.vocab_size, size=(batch_size, sequence_length), dtype=torch.int64, device=device + low=0, + high=self.phi_config.vocab_size, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=self.device, ) self.__get_phi2_torch_model() torch_inputs = self.phi_model.prepare_inputs_for_generation( @@ -239,14 +295,17 @@ def erase_onnx_model(self, onnx_path: str): assert onnx_path.endswith(".onnx") if not os.path.exists(onnx_path): return + model = onnx.load_model(onnx_path, load_external_data=False) onnx_data_path = None for initializer in model.graph.initializer: if initializer.data_location == 1 and initializer.external_data[0].key == "location": onnx_data_path = "./" + initializer.external_data[0].value break + logging.info(f"Erasing {onnx_path}...") os.remove(onnx_path) if onnx_data_path is not None: + logging.info(f"Erasing {onnx_data_path}...") os.remove(onnx_data_path) def dynamo_export(self, onnx_path: str): @@ -256,6 +315,8 @@ def dynamo_export(self, onnx_path: str): from torch._dynamo import config config.capture_scalar_outputs = True + + logging.info("Exporting Phi2 torch model to ONNX...") torch.onnx.dynamo_export( self.phi_model, input_ids, @@ -277,7 +338,6 @@ def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: str) model = self.__unroll_function(model, function_name) model = self.__update_edges(model, self.phi2_edge_dict) model = self.__simplify_phi2_op_type(model) - model = self.__process_graph_io(self.phi_config, model) model = self.__remove_dropout_layer(model) onnx.save_model( model, @@ -287,13 +347,26 @@ def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: str) location=onnx_path_out + ".data", ) - def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool = False): + def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): + self.optimized_model = True + from fusion_options import FusionOptions from optimizer import optimize_model + processed_onnx_path = "phi2_processed.onnx" + model = onnx.load_model(onnx_path, load_external_data=True) + model = self.__process_graph_io(self.phi_config, model) + onnx.save_model( + model, + processed_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + location=processed_onnx_path + ".data", + ) + optimization_options = FusionOptions("phi") optimizer = optimize_model( - onnx_path, + processed_onnx_path, model_type="phi", num_heads=self.phi_config.num_attention_heads, hidden_size=self.phi_config.hidden_size, @@ -302,17 +375,45 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool only_onnxruntime=False, ) - if use_fp16: - node_block_list = ["GroupQueryAttention_0_29", "GroupQueryAttention_0_30", "GroupQueryAttention_0_31"] + self.erase_onnx_model(processed_onnx_path) + + if self.precision == 
Precision.FLOAT32: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + + if ( + self.precision == Precision.FLOAT16 or self.precision == Precision.INT4 + ) and self.attn_op_type != AttentionOpType.MultiHeadAttention: + node_block_list = [ + "GroupQueryAttention_29", + "GroupQueryAttention_30", + "GroupQueryAttention_31", + "Attention_29", + "Attention_30", + "Attention_31", + ] + logging.info(f"Converting onnx model to float16/bfloat16...") optimizer.convert_float_to_float16( keep_io_types=False, node_block_list=node_block_list, use_symbolic_shape_infer=True, - use_bfloat16_as_blocked_nodes_dtype=True, + use_bfloat16_as_blocked_nodes_dtype=True + if self.attn_op_type == AttentionOpType.GroupQueryAttention + else False, ) - optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) - optimizer.get_operator_statistics() + if self.precision == Precision.FLOAT16: + optimizer.save_model_to_file(onnx_path_opt, use_external_data_format=True) + return + else: + assert self.precision == Precision.INT4 + quant = MatMul4BitsQuantizer( + model=optimizer.model, + block_size=16, + is_symmetric=True, + ) + quant.process() + quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) def parse_arguments(): @@ -322,56 +423,56 @@ def parse_arguments(): "--fp32_cpu", required=False, action="store_true", - help="Generate fp32 onnx model for CPU", + help="Generate fp32 ONNX model for CPU", ) parser.add_argument( "--int4_cpu", required=False, action="store_true", - help="Generate int4 onnx model for CPU", + help="Generate int4 ONNX model for CPU", ) parser.add_argument( "--fp32_gpu", required=False, action="store_true", - help="Generate fp32 onnx model for Nvidia GPUs", + help="Generate fp32 ONNX model for Nvidia GPUs", ) parser.add_argument( "--fp16_gpu", required=False, action="store_true", - help="Generate fp16 onnx model for Nvidia GPUs", + help="Generate fp16 ONNX model for Nvidia GPUs", ) parser.add_argument( "--int4_gpu", required=False, action="store_true", - help="Generate int4 onnx model for Nvidia GPUs", + help="Generate int4 ONNX model for Nvidia GPUs", ) parser.add_argument( "--fp16_a100", required=False, action="store_true", - help="Generate fp16 onnx model for Nvidia A100", + help="Generate fp16 ONNX model for Nvidia A100", ) parser.add_argument( "--int4_a100", required=False, action="store_true", - help="Generate int4 onnx model for Nvidia A100", + help="Generate int4 ONNX model for Nvidia A100", ) parser.add_argument( "--overwrite", required=False, action="store_true", - help="Overwrite existing onnx models", + help="Overwrite existing ONNX models", ) args = parser.parse_args() @@ -394,40 +495,39 @@ def main(): temp_onnx_path, original_onnx_path, func_name="modeling_phi_PhiModel_model_1", # The function to unroll - use_gqa=True, ) converter.erase_onnx_model(temp_onnx_path) - # TODO: support batch export if args.fp32_cpu: converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.FLOAT32) - converter.optimize_phi2_onnx(original_onnx_path, "fp32_cpu/phi2_opt.onnx") - elif args.int4_cpu: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_fp32_cpu_opt.onnx") + if args.int4_cpu: converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.INT4) - converter.optimize_phi2_onnx(original_onnx_path, "int4_cpu/phi2_opt.onnx") - elif args.fp32_gpu: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_int4_cpu_opt.onnx") + if args.fp32_gpu: 
converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT32) - converter.optimize_phi2_onnx(original_onnx_path, "fp32_gpu/phi2_opt.onnx") - elif args.fp16_gpu: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_fp32_gpu_opt.onnx") + if args.fp16_gpu: converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT16) - converter.optimize_phi2_onnx(original_onnx_path, "fp16_gpu/phi2_opt.onnx") - elif args.int4_gpu: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_fp16_gpu_opt.onnx") + if args.int4_gpu: converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.INT4) - converter.optimize_phi2_onnx(original_onnx_path, "int4_gpu/phi2_opt.onnx") - elif args.fp16_a100: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_int4_gpu_opt.onnx") + if args.fp16_a100: converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.FLOAT16) - converter.optimize_phi2_onnx(original_onnx_path, "fp16_a100/phi2_opt.onnx") - elif args.int4_a100: + converter.optimize_phi2_onnx(original_onnx_path, "phi2_fp16_a100_opt.onnx") + if args.int4_a100: converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.INT4) - converter.optimize_phi2_onnx(original_onnx_path, "int4_a100/phi2_opt.onnx") - else: - print( + converter.optimize_phi2_onnx(original_onnx_path, "phi2_int4_a100_opt.onnx") + + if not converter.optimized_model_exists(): + logging.warning( "Please specify a valid option from --fp32_cpu, --int4_cpu, --fp32_gpu, --fp16_gpu, --int4_gpu, --fp16_a100, --int4_a100" ) return # converter.erase_onnx_model(original_onnx_path) - print("done") + logging.info("done") if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py index 6bd0559353cfa..7d5388da81418 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_phi.py +++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py @@ -12,6 +12,7 @@ from fusion_options import FusionOptions from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization import numpy as np +import os logger = getLogger(__name__) @@ -244,7 +245,7 @@ def __init__( super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"]) def fuse(self, node, input_name_to_nodes, output_name_to_node): - print(node.name) + logger.info("Optimizing %s...", node.name) assert len(node.input) == 2 assert len(node.output) == 1 @@ -278,7 +279,7 @@ def __init__( super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"]) def fuse(self, node, input_name_to_nodes, output_name_to_node): - print(node.name) + logger.info("Optimizing %s...", node.name) assert len(node.input) == 3 assert len(node.output) == 1 @@ -311,7 +312,7 @@ def __init__( super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"]) def fuse(self, node, input_name_to_nodes, output_name_to_node): - print(node.name) + logger.info("Optimizing %s...", node.name) assert len(node.input) == 5 assert len(node.output) == 1 @@ -538,7 +539,10 @@ def fuse( input_name_to_nodes, output_name_to_node, ): - print(node.name) + logger.info("Optimizing %s...", node.name) + + attn_type = os.environ.get("AttentionOpType") + logger.info(f"AttentionOpType: {attn_type}") layer_id = self.get_layer_id(node) @@ -552,27 +556,36 @@ def fuse( ln_weight = self.get_io_by_name(node, "input_layernorm.weight") ln_bias = self.get_io_by_name(node, "input_layernorm.bias") 
- attn_q_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() - ) - attn_k_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() - ) - attn_v_weight = self.process_initializer( - self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() - ) + + if attn_type != "Attention": + attn_q_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc() + ) + attn_k_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc() + ) + attn_v_weight = self.process_initializer( + self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc() + ) + attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") + attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") + attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") + + cos_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc() + ) + sin_cache = self.process_initializer( + self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc() + ) + else: + attn_qkv_weight = None + attn_qkv_bias = None + attn_out_weight = self.process_initializer( self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc() ) - - attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias") - attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias") - attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias") attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias") - cos_cache = self.process_initializer(self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()) - sin_cache = self.process_initializer(self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()) - mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc()) mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc()) mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias") @@ -582,60 +595,64 @@ def fuse( layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache]) layer_known_edges_names.extend([o_hidden_states, o_key_cache, o_value_cache]) layer_known_edges_names.extend([ln_weight, ln_bias]) + if attn_type != "Attention": + layer_known_edges_names.extend( + [ + attn_q_weight, + attn_q_bias, + attn_k_weight, + attn_k_bias, + attn_v_weight, + attn_v_bias, + cos_cache, + sin_cache, + ] + ) + else: + layer_known_edges_names.extend([attn_qkv_weight, attn_qkv_bias]) layer_known_edges_names.extend( - [ - attn_q_weight, - attn_q_bias, - attn_k_weight, - attn_k_bias, - attn_v_weight, - attn_v_bias, - cos_cache, - sin_cache, - attn_out_weight, - attn_out_bias, - ] + [attn_out_weight, attn_out_bias, mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias] ) - layer_known_edges_names.extend([mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias]) layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"]) subgraph_nodes = [] subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"])) - subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) - subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) - subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) - 
subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_")) - subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_")) subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_")) subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_")) subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"])) subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_")) subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1")) subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2")) - - use_mha = False - if use_mha: - subgraph_nodes.extend( - self.mha( - ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], - ["attn_out", o_key_cache, o_value_cache], + if attn_type != "Attention": + subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_")) + subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_")) + subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_")) + if attn_type == "MultiHeadAttention": + subgraph_nodes.extend( + self.mha( + ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache], + ["attn_out", o_key_cache, o_value_cache], + ) ) - ) - else: - subgraph_nodes.extend( - self.gqa( - ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "seqlens_k", "total_sequence_length"], - ["attn_out", o_key_cache, o_value_cache], - ) - ) - if layer_id == 0: - gqa_aux_nodes = self.get_gqa_aux_nodes() - for new_node in gqa_aux_nodes: - self.nodes_to_add.append(new_node) - self.node_name_to_graph_name[new_node.name] = self.this_graph_name - self.model.add_initializer( - numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + elif attn_type == "GroupQueryAttention": + subgraph_nodes.extend( + self.gqa( + ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "seqlens_k", "total_sequence_length"], + ["attn_out", o_key_cache, o_value_cache], + ) ) + if layer_id == 0: + gqa_aux_nodes = self.get_gqa_aux_nodes() + for new_node in gqa_aux_nodes: + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + self.model.add_initializer( + numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name + ) + else: + print("bugbug") self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names) @@ -650,88 +667,6 @@ def shape_of(vi): return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim]) -def postprocess_io(model: ModelProto): - graph = model.graph - new_inputs = [] - for i, vi in enumerate(graph.input): - if "attention_mask" in vi.name: - vi = helper.make_tensor_value_info( - IO_MAPPING[vi.name], - elem_type=TensorProto.INT32, - shape=["batch_size", "seq_len"], - ) - # vi_pid = helper.make_tensor_value_info( - # "step", - # elem_type=TensorProto.INT64, - # shape=[1], - # ) - new_inputs.extend([vi]) - if "input_ids" in vi.name: - vi = helper.make_tensor_value_info( - IO_MAPPING[vi.name], - 
elem_type=TensorProto.INT32, - shape=["batch_size", "seq_len"], - ) - new_inputs.extend([vi]) - if "kv_cache" in vi.name: - vi = helper.make_tensor_value_info( - IO_MAPPING[vi.name], - elem_type=vi.type.tensor_type.elem_type, - shape=[2, "batch_size", 32, "past_seq_len", 80], - ) - new_inputs.extend([vi]) - # add past_sequence_length - # vi = helper.make_tensor_value_info( - # "past_sequence_length", - # elem_type=TensorProto.INT32, - # shape=[], - # ) - # new_inputs.extend([vi]) - - graph.ClearField("input") - graph.input.extend(new_inputs) - - new_outputs = [] - for i, vi in enumerate(graph.output): - if i == 0: - vi = helper.make_tensor_value_info( - IO_MAPPING[vi.name], elem_type=vi.type.tensor_type.elem_type, shape=["batch_size", "seq_len", 51200] - ) - else: - shape = shape_of(vi) - vi = helper.make_tensor_value_info( - IO_MAPPING[vi.name], - elem_type=vi.type.tensor_type.elem_type, - shape=[2, "batch_size", 32, "total_seq_len", 80], - ) - new_outputs.extend([vi]) - - graph.ClearField("output") - graph.output.extend(new_outputs) - - for node in graph.node: - for i, name in enumerate(node.input): - if name in IO_MAPPING: - node.input[i] = IO_MAPPING[name] - for i, name in enumerate(node.output): - if name in IO_MAPPING: - node.output[i] = IO_MAPPING[name] - - -# def postprocess_value_info(model: ModelProto): -# for value_info in model.graph.value_info: -# shape = shape_of(value_info) -# if len(shape) == 3 and shape[0] == 2: -# print("value info: ", value_info.name, shape) -# new_value_info = helper.make_tensor_value_info( -# value_info.name, -# elem_type=value_info.type.tensor_type.elem_type, -# shape=["batch_size", shape[1], shape[2]], -# ) -# model.graph.value_info.remove(value_info) -# model.graph.value_info.extend([new_value_info]) - - class PhiOnnxModel(OnnxModel): def __init__(self, model: ModelProto, num_heads: int = 0, head_size: int = 0): super().__init__(model) @@ -742,16 +677,13 @@ def __init__(self, model: ModelProto, num_heads: int = 0, head_size: int = 0): self.fuse_sln = FusionSkipLayerNormalization(self) self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self) - def postprocess(self): - print("post process") - # postprocess_io(self.model) - # postprocess_io_split_kv(self.model, True) - def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False): self.fission_transformer_block.apply() self.fission_transformer_layernorm.apply() self.fission_causal_lm_head.apply() self.fission_transformer_embedding.apply() + + super().prune_graph() + self.fuse_sln.apply() self.fuse_bias_sln.apply() - self.postprocess()
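
Usage sketch (not part of the diff above): after this patch, each target/precision pair is selected with a single CLI flag, for example `python convert_to_onnx.py --fp16_a100`, and the chosen attention operator is handed from convert_to_onnx.py to the fusion pass in onnx_model_phi.py through the "AttentionOpType" environment variable (set in init_attn_type_and_precision, read in the fusion's fuse method). Below is a minimal, self-contained sketch of that handshake; the flag name and the three values mirror the AttentionOpType enum in the patch, while the helper function names and the returned path labels are purely illustrative:

    import os


    def set_attention_op_type(op_type: str) -> None:
        """Producer side (convert_to_onnx.py): clear any stale flag, then set it.

        Mirrors env_reset() followed by os.environ["AttentionOpType"] = str(attn_op_type).
        """
        assert op_type in ("Attention", "MultiHeadAttention", "GroupQueryAttention")
        os.environ.pop("AttentionOpType", None)
        os.environ["AttentionOpType"] = op_type


    def pick_attention_subgraph() -> str:
        """Consumer side (onnx_model_phi.py): choose which attention subgraph to emit.

        The fusion reads the flag with os.environ.get("AttentionOpType"); for the
        non-Attention paths it builds separate Q/K/V GEMMs plus rotary embedding,
        then either MultiHeadAttention or GroupQueryAttention nodes.
        """
        attn_type = os.environ.get("AttentionOpType")
        if attn_type == "GroupQueryAttention":
            return "gqa"          # illustrative label for the GQA subgraph
        if attn_type == "MultiHeadAttention":
            return "mha"          # illustrative label for the MHA subgraph
        return "attention"        # packed Attention path (fp32 GPU and similar)


    if __name__ == "__main__":
        set_attention_op_type("GroupQueryAttention")
        print(pick_attention_subgraph())  # expected: "gqa"

Using a process-wide environment variable is a deliberate side channel between the two modules (the optimizer's fusion classes are constructed without access to the converter's state), which is also why env_reset() clears the flag before every init_attn_type_and_precision call so that back-to-back exports in one process do not inherit the previous attention type.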