diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
index 3749d2e14f537..51f5481c270cf 100644
--- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
@@ -1,12 +1,14 @@
 # -------------------------------------------------------------------------
 # Copyright (c) Microsoft Corporation. All rights reserved.
 # Licensed under the MIT License.
-from typing import List
-import numpy as np
-import torch
+import argparse
 import onnx
+import torch
+
+from enum import Enum
 from onnx import ModelProto, TensorProto, helper
 from transformers import AutoConfig, AutoModelForCausalLM
+from typing import List
 
 # --------------------------------------------------------------------------
 # The following code is used when this file is not in the ORT package
@@ -19,9 +21,25 @@ sys.path.append(transformers_dir)
 # --------------------------------------------------------------------------
 
+from benchmark_helper import Precision
+
+
+class AttentionOpType(Enum):
+    Attention = "attention"
+    MultiHeadAttention = "mha"
+    GroupQueryAttention = "gqa"
+
+    def __str__(self):
+        return self.value
+
 
 class ConvertPhi2ToONNX:
-    def __init__(self, model_class: str, device: torch.device, cache_dir: str = "./cache"):
+    def __init__(
+        self,
+        device: torch.device,
+        model_class: str = "microsoft/phi-2",
+        cache_dir: str = "./cache",
+    ):
         self.model_class = model_class
         self.device = device
         self.cache_dir = cache_dir
@@ -29,9 +47,12 @@ def __init__(self, model_class: str, device: torch.device, cache_dir: str = "./c
         self.phi_model = None
         self.batch_size = 2
         self.sequence_length = 8
-        self.phi2_edge_dict = self.__get_phi2_edge_dict(self.phi_config)
 
+    def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
+        self.attn_op_type = attn_op_type
+        self.precision = precision
+
     def __get_phi2_edge_dict(self, config: AutoConfig) -> dict:
         edge_dict = {}
         edge_dict["lm_head_1"] = "logits"
@@ -45,7 +66,7 @@ def __get_phi2_edge_dict(self, config: AutoConfig) -> dict:
             edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
         return edge_dict
 
-    def __simplify_phi2_op_type_name(self, onnx_model: ModelProto):
+    def __simplify_phi2_op_type(self, onnx_model: ModelProto):
         phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
         for node in onnx_model.graph.node:
             index = node.op_type.find(phi2_transformer_layer_name)
@@ -54,7 +75,8 @@ def __simplify_phi2_op_type_name(self, onnx_model: ModelProto):
 
         return onnx_model
 
-    def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto, use_gqa=False):
+    def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto):
+        use_gqa = self.attn_op_type == AttentionOpType.GroupQueryAttention
         graph = onnx_model.graph
         new_inputs = []
         for i, vi in enumerate(graph.input):
@@ -64,16 +86,16 @@ def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto, use_gqa
                     elem_type=TensorProto.INT32,
                     shape=["batch_size", "seq_len"],
                 )
-                vi_mask = helper.make_tensor_value_info(
-                    "attention_mask",
-                    elem_type=TensorProto.INT64 if use_gqa else TensorProto.INT32,
-                    shape=["batch_size", "seq_len"],
-                )
                 vi_pid = helper.make_tensor_value_info(
                     "step",
                     elem_type=TensorProto.INT64,
                     shape=[1],
                 )
+                vi_mask = helper.make_tensor_value_info(
+                    "attention_mask",
+                    elem_type=TensorProto.INT64 if use_gqa else TensorProto.INT32,
+                    shape=["batch_size", "seq_len"],
+                )
                 new_inputs.extend([vi, vi_pid, vi_mask])
             if "past_key" in vi.name or "past_value" in vi.name:
                 vi_cache = helper.make_tensor_value_info(
@@ -136,9 +158,9 @@ def __update_edges(self, model: onnx.ModelProto, edge_mapping: dict):
 
         return model
 
-    def __inline_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
+    def __unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
         """
-        Inlines the function with the given name in the model.
+        Unrolls the function with the given name in the model.
         """
         nodes_to_remove = []
         nodes_to_add = []
@@ -215,6 +237,8 @@ def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):
 
     def erase_onnx_model(self, onnx_path: str):
         assert onnx_path.endswith(".onnx")
+        if not os.path.exists(onnx_path):
+            return
         model = onnx.load_model(onnx_path, load_external_data=False)
         onnx_data_path = None
         for initializer in model.graph.initializer:
@@ -242,7 +266,7 @@ def dynamo_export(self, onnx_path: str):
         onnx.checker.check_model(onnx_path)
         onnx.shape_inference.infer_shapes_path(onnx_path)
 
-    def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: List[str], use_gqa=False):
+    def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: str):
         model = onnx.load_model(onnx_path_in, load_external_data=True)
         function_name = None
         for func in model.functions:
@@ -250,10 +274,10 @@ def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: List
                 function_name = func.name
                 break
         assert function_name is not None
-        model = self.__inline_function(model, function_name)
+        model = self.__unroll_function(model, function_name)
         model = self.__update_edges(model, self.phi2_edge_dict)
-        model = self.__simplify_phi2_op_type_name(model)
-        model = self.__process_graph_io(self.phi_config, model, use_gqa)
+        model = self.__simplify_phi2_op_type(model)
+        model = self.__process_graph_io(self.phi_config, model)
         model = self.__remove_dropout_layer(model)
         onnx.save_model(
             model,
@@ -279,9 +303,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool
         )
 
         if use_fp16:
-            node_block_list = ["GroupQueryAttention_0_29",
-                               "GroupQueryAttention_0_30",
-                               "GroupQueryAttention_0_31"]
+            node_block_list = ["GroupQueryAttention_0_29", "GroupQueryAttention_0_30", "GroupQueryAttention_0_31"]
             optimizer.convert_float_to_float16(
                 keep_io_types=False,
                 node_block_list=node_block_list,
@@ -293,17 +315,120 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool
 
         optimizer.get_operator_statistics()
 
-model_class = "microsoft/phi-2"
-device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-converter = ConvertPhi2ToONNX(model_class, device)
-# converter.dynamo_export("phi-2_temp.onnx")
-# # TODO:preprocessed onnx model takes up large disk space
-# converter.preprocess_onnx(
-#     "phi-2_temp.onnx",
-#     "phi-2.onnx",
-#     "modeling_phi_PhiModel_model_1",
-#     use_gqa=True,
-# )
-# converter.erase_onnx_model("phi-2_temp.onnx")
-converter.optimize_phi2_onnx("phi-2.onnx", "phi-2_opt.onnx", use_fp16=True)
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "--fp32_cpu",
+        required=False,
+        action="store_true",
+        help="Generate fp32 onnx model for CPU",
+    )
+
+    parser.add_argument(
+        "--int4_cpu",
+        required=False,
+        action="store_true",
+        help="Generate int4 onnx model for CPU",
+    )
+
+    parser.add_argument(
+        "--fp32_gpu",
+        required=False,
+        action="store_true",
+        help="Generate fp32 onnx model for Nvidia GPUs",
+    )
+
+    parser.add_argument(
"--fp16_gpu", + required=False, + action="store_true", + help="Generate fp16 onnx model for Nvidia GPUs", + ) + + parser.add_argument( + "--int4_gpu", + required=False, + action="store_true", + help="Generate int4 onnx model for Nvidia GPUs", + ) + + parser.add_argument( + "--fp16_a100", + required=False, + action="store_true", + help="Generate fp16 onnx model for Nvidia A100", + ) + + parser.add_argument( + "--int4_a100", + required=False, + action="store_true", + help="Generate int4 onnx model for Nvidia A100", + ) + + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite existing onnx models", + ) + + args = parser.parse_args() + return args + + +def main(): + args = parse_arguments() + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + converter = ConvertPhi2ToONNX(device) + + temp_onnx_path = "phi2_temp.onnx" + original_onnx_path = "phi2.onnx" + + if not os.path.exists(original_onnx_path) or args.overwrite: + converter.dynamo_export(temp_onnx_path) + converter.preprocess_onnx( + temp_onnx_path, + original_onnx_path, + func_name="modeling_phi_PhiModel_model_1", # The function to unroll + use_gqa=True, + ) + converter.erase_onnx_model(temp_onnx_path) + + # TODO: support batch export + if args.fp32_cpu: + converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.FLOAT32) + converter.optimize_phi2_onnx(original_onnx_path, "fp32_cpu/phi2_opt.onnx") + elif args.int4_cpu: + converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.INT4) + converter.optimize_phi2_onnx(original_onnx_path, "int4_cpu/phi2_opt.onnx") + elif args.fp32_gpu: + converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT32) + converter.optimize_phi2_onnx(original_onnx_path, "fp32_gpu/phi2_opt.onnx") + elif args.fp16_gpu: + converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT16) + converter.optimize_phi2_onnx(original_onnx_path, "fp16_gpu/phi2_opt.onnx") + elif args.int4_gpu: + converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.INT4) + converter.optimize_phi2_onnx(original_onnx_path, "int4_gpu/phi2_opt.onnx") + elif args.fp16_a100: + converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.FLOAT16) + converter.optimize_phi2_onnx(original_onnx_path, "fp16_a100/phi2_opt.onnx") + elif args.int4_a100: + converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.INT4) + converter.optimize_phi2_onnx(original_onnx_path, "int4_a100/phi2_opt.onnx") + else: + print( + "Please specify a valid option from --fp32_cpu, --int4_cpu, --fp32_gpu, --fp16_gpu, --int4_gpu, --fp16_a100, --int4_a100" + ) + return + + # converter.erase_onnx_model(original_onnx_path) + print("done") + + +if __name__ == "__main__": + main()