Commit

update
wangyems committed Jan 26, 2024
1 parent 990c1da commit 15324e5
Showing 1 changed file with 160 additions and 35 deletions.
195 changes: 160 additions & 35 deletions onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
@@ -1,12 +1,14 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
-from typing import List
-import numpy as np
-import torch
+import argparse
import onnx
+import torch

from enum import Enum
from onnx import ModelProto, TensorProto, helper
from transformers import AutoConfig, AutoModelForCausalLM
+from typing import List

# --------------------------------------------------------------------------
# The following code is used when this file is not in the ORT package
@@ -19,19 +21,38 @@
sys.path.append(transformers_dir)
# --------------------------------------------------------------------------

from benchmark_helper import Precision


class AttentionOpType(Enum):
Attention = "attention"
MultiHeadAttention = "mha"
GroupQueryAttention = "gqa"

def __str__(self):
return self.value
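
The __str__ override makes each member print as its short operator name, so the enum can be dropped directly into op-type or path strings; a minimal illustration (assumed usage, not from the diff):

# Illustration only: members stringify to their short names.
assert str(AttentionOpType.MultiHeadAttention) == "mha"
assert str(AttentionOpType.GroupQueryAttention) == "gqa"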


class ConvertPhi2ToONNX:
-    def __init__(self, model_class: str, device: torch.device, cache_dir: str = "./cache"):
+    def __init__(
+        self,
+        device: torch.device,
+        model_class: str = "microsoft/phi-2",
+        cache_dir: str = "./cache",
+    ):
self.model_class = model_class
self.device = device
self.cache_dir = cache_dir
self.phi_config = AutoConfig.from_pretrained(self.model_class, trust_remote_code=True, cache_dir=self.cache_dir)
self.phi_model = None
        self.batch_size = 2
        self.sequence_length = 8
        # Default to "not configured"; init_attn_type_and_precision() sets these,
        # and __process_graph_io() reads attn_op_type during preprocess_onnx().
        self.attn_op_type = None
        self.precision = None

self.phi2_edge_dict = self.__get_phi2_edge_dict(self.phi_config)

def init_attn_type_and_precision(self, attn_op_type: AttentionOpType, precision: Precision):
self.attn_op_type = attn_op_type
self.precision = precision

def __get_phi2_edge_dict(self, config: AutoConfig) -> dict:
edge_dict = {}
edge_dict["lm_head_1"] = "logits"
@@ -45,7 +66,7 @@ def __get_phi2_edge_dict(self, config: AutoConfig) -> dict:
edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"
return edge_dict
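
With the folded loop over config.num_hidden_layers, the returned mapping renames dynamo-generated edge names to stable I/O names. A sketch of the result for a hypothetical 2-layer config (phi-2 itself has 32 layers; input-side entries folded above are omitted):

# Assumed shape of the mapping, extrapolated from the visible entries:
# {
#     "lm_head_1": "logits",
#     "model_layers_0_1_0": "present_key_0",
#     "model_layers_0_1_1": "present_value_0",
#     "model_layers_1_1_0": "present_key_1",
#     "model_layers_1_1_1": "present_value_1",
# }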

-    def __simplify_phi2_op_type_name(self, onnx_model: ModelProto):
+    def __simplify_phi2_op_type(self, onnx_model: ModelProto):
phi2_transformer_layer_name = "modeling_phi_PhiDecoderLayer_model_layers"
for node in onnx_model.graph.node:
index = node.op_type.find(phi2_transformer_layer_name)
@@ -54,7 +75,8 @@ def __simplify_phi2_op_type_name(self, onnx_model: ModelProto):

return onnx_model

-    def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto, use_gqa=False):
+    def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto):
+        use_gqa = self.attn_op_type == AttentionOpType.GroupQueryAttention
graph = onnx_model.graph
new_inputs = []
for i, vi in enumerate(graph.input):
@@ -64,16 +86,16 @@ def __process_graph_io(self, config: AutoConfig, onnx_model: ModelProto, use_gqa
elem_type=TensorProto.INT32,
shape=["batch_size", "seq_len"],
)
-            vi_mask = helper.make_tensor_value_info(
-                "attention_mask",
-                elem_type=TensorProto.INT64 if use_gqa else TensorProto.INT32,
-                shape=["batch_size", "seq_len"],
-            )
            vi_pid = helper.make_tensor_value_info(
                "step",
                elem_type=TensorProto.INT64,
                shape=[1],
            )
+            vi_mask = helper.make_tensor_value_info(
+                "attention_mask",
+                elem_type=TensorProto.INT64 if use_gqa else TensorProto.INT32,
+                shape=["batch_size", "seq_len"],
+            )
new_inputs.extend([vi, vi_pid, vi_mask])
if "past_key" in vi.name or "past_value" in vi.name:
vi_cache = helper.make_tensor_value_info(
@@ -136,9 +158,9 @@ def __update_edges(self, model: onnx.ModelProto, edge_mapping: dict):

return model
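
The body of __update_edges is folded here; as a sketch, renaming every edge per a mapping in ONNX usually touches node inputs, node outputs, and graph I/O, along these lines (my sketch, not the committed implementation):

def rename_edges_sketch(model: onnx.ModelProto, edge_mapping: dict) -> onnx.ModelProto:
    # Rewrite producer/consumer edge names that appear in the mapping.
    for node in model.graph.node:
        for i, name in enumerate(node.input):
            if name in edge_mapping:
                node.input[i] = edge_mapping[name]
        for i, name in enumerate(node.output):
            if name in edge_mapping:
                node.output[i] = edge_mapping[name]
    # Keep graph-level inputs/outputs consistent with the renamed edges.
    for vi in list(model.graph.input) + list(model.graph.output):
        if vi.name in edge_mapping:
            vi.name = edge_mapping[vi.name]
    return model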

-    def __inline_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
+    def __unroll_function(self, model: onnx.ModelProto, func_name: str) -> onnx.ModelProto:
        """
-        Inlines the function with the given name in the model.
+        Unrolls the function with the given name in the model.
"""
nodes_to_remove = []
nodes_to_add = []
@@ -215,6 +237,8 @@ def get_phi2_torch_inputs(self, batch_size: int, sequence_length: int):

def erase_onnx_model(self, onnx_path: str):
assert onnx_path.endswith(".onnx")
+        if not os.path.exists(onnx_path):
+            return
model = onnx.load_model(onnx_path, load_external_data=False)
onnx_data_path = None
for initializer in model.graph.initializer:
@@ -242,18 +266,18 @@ def dynamo_export(self, onnx_path: str):
onnx.checker.check_model(onnx_path)
onnx.shape_inference.infer_shapes_path(onnx_path)
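
Most of dynamo_export is folded; only the post-export checks are visible. For orientation, a TorchDynamo-based export typically reduces to the following (a sketch assuming torch.onnx.dynamo_export is used, as the method name suggests; the helper name and arguments are illustrative):

def dynamo_export_sketch(model: torch.nn.Module, example_inputs: tuple, out_path: str):
    # Trace/export with the dynamo exporter, then persist to disk.
    onnx_program = torch.onnx.dynamo_export(model, *example_inputs)
    onnx_program.save(out_path)
    # Same sanity checks as the committed code above.
    onnx.checker.check_model(out_path)
    onnx.shape_inference.infer_shapes_path(out_path)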

-    def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: List[str], use_gqa=False):
+    def preprocess_onnx(self, onnx_path_in: str, onnx_path_out: str, func_name: str):
model = onnx.load_model(onnx_path_in, load_external_data=True)
function_name = None
for func in model.functions:
if func.name.endswith(func_name):
function_name = func.name
break
assert function_name is not None
-        model = self.__inline_function(model, function_name)
+        model = self.__unroll_function(model, function_name)
        model = self.__update_edges(model, self.phi2_edge_dict)
-        model = self.__simplify_phi2_op_type_name(model)
-        model = self.__process_graph_io(self.phi_config, model, use_gqa)
+        model = self.__simplify_phi2_op_type(model)
+        model = self.__process_graph_io(self.phi_config, model)
model = self.__remove_dropout_layer(model)
onnx.save_model(
model,
@@ -279,9 +303,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool
)

if use_fp16:
node_block_list = ["GroupQueryAttention_0_29",
"GroupQueryAttention_0_30",
"GroupQueryAttention_0_31"]
node_block_list = ["GroupQueryAttention_0_29", "GroupQueryAttention_0_30", "GroupQueryAttention_0_31"]
optimizer.convert_float_to_float16(
keep_io_types=False,
node_block_list=node_block_list,
@@ -293,17 +315,120 @@
optimizer.get_operator_statistics()
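
The node_block_list keeps the last decoder layers' attention nodes in float32 during the fp16 cast, a common guard against activation overflow late in the network. A sketch of that pattern with the ORT transformers optimizer (model_type, head count, and hidden size here are assumptions, not the committed call):

from optimizer import optimize_model  # sibling module in onnxruntime/python/tools/transformers

def fp16_optimize_sketch(onnx_path: str, out_path: str):
    # Fuse transformer patterns, then cast to fp16 except the blocked nodes.
    opt = optimize_model(onnx_path, model_type="phi", num_heads=32, hidden_size=2560)
    opt.convert_float_to_float16(
        keep_io_types=False,
        node_block_list=["GroupQueryAttention_0_31"],  # these stay in fp32
    )
    opt.save_model_to_file(out_path)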


model_class = "microsoft/phi-2"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

converter = ConvertPhi2ToONNX(model_class, device)
# converter.dynamo_export("phi-2_temp.onnx")
# # TODO:preprocessed onnx model takes up large disk space
# converter.preprocess_onnx(
# "phi-2_temp.onnx",
# "phi-2.onnx",
# "modeling_phi_PhiModel_model_1",
# use_gqa=True,
# )
# converter.erase_onnx_model("phi-2_temp.onnx")
converter.optimize_phi2_onnx("phi-2.onnx", "phi-2_opt.onnx", use_fp16=True)
def parse_arguments():
parser = argparse.ArgumentParser()

parser.add_argument(
"--fp32_cpu",
required=False,
action="store_true",
help="Generate fp32 onnx model for CPU",
)

parser.add_argument(
"--int4_cpu",
required=False,
action="store_true",
help="Generate int4 onnx model for CPU",
)

parser.add_argument(
"--fp32_gpu",
required=False,
action="store_true",
help="Generate fp32 onnx model for Nvidia GPUs",
)

parser.add_argument(
"--fp16_gpu",
required=False,
action="store_true",
help="Generate fp16 onnx model for Nvidia GPUs",
)

parser.add_argument(
"--int4_gpu",
required=False,
action="store_true",
help="Generate int4 onnx model for Nvidia GPUs",
)

parser.add_argument(
"--fp16_a100",
required=False,
action="store_true",
help="Generate fp16 onnx model for Nvidia A100",
)

parser.add_argument(
"--int4_a100",
required=False,
action="store_true",
help="Generate int4 onnx model for Nvidia A100",
)

parser.add_argument(
"--overwrite",
required=False,
action="store_true",
help="Overwrite existing onnx models",
)

args = parser.parse_args()
return args
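
Illustrative invocation: python convert_to_onnx.py --fp16_gpu --overwrite re-exports phi2.onnx and writes the optimized model to fp16_gpu/phi2_opt.onnx via the matching branch in main() below.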


def main():
args = parse_arguments()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

converter = ConvertPhi2ToONNX(device)

temp_onnx_path = "phi2_temp.onnx"
original_onnx_path = "phi2.onnx"

if not os.path.exists(original_onnx_path) or args.overwrite:
converter.dynamo_export(temp_onnx_path)
converter.preprocess_onnx(
temp_onnx_path,
original_onnx_path,
func_name="modeling_phi_PhiModel_model_1", # The function to unroll
)
converter.erase_onnx_model(temp_onnx_path)

# TODO: support batch export
if args.fp32_cpu:
converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.FLOAT32)
converter.optimize_phi2_onnx(original_onnx_path, "fp32_cpu/phi2_opt.onnx")
elif args.int4_cpu:
converter.init_attn_type_and_precision(AttentionOpType.MultiHeadAttention, Precision.INT4)
converter.optimize_phi2_onnx(original_onnx_path, "int4_cpu/phi2_opt.onnx")
elif args.fp32_gpu:
converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT32)
converter.optimize_phi2_onnx(original_onnx_path, "fp32_gpu/phi2_opt.onnx")
elif args.fp16_gpu:
converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.FLOAT16)
converter.optimize_phi2_onnx(original_onnx_path, "fp16_gpu/phi2_opt.onnx")
elif args.int4_gpu:
converter.init_attn_type_and_precision(AttentionOpType.Attention, Precision.INT4)
converter.optimize_phi2_onnx(original_onnx_path, "int4_gpu/phi2_opt.onnx")
elif args.fp16_a100:
converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.FLOAT16)
converter.optimize_phi2_onnx(original_onnx_path, "fp16_a100/phi2_opt.onnx")
elif args.int4_a100:
converter.init_attn_type_and_precision(AttentionOpType.GroupQueryAttention, Precision.INT4)
converter.optimize_phi2_onnx(original_onnx_path, "int4_a100/phi2_opt.onnx")
else:
print(
"Please specify a valid option from --fp32_cpu, --int4_cpu, --fp32_gpu, --fp16_gpu, --int4_gpu, --fp16_a100, --int4_a100"
)
return

# converter.erase_onnx_model(original_onnx_path)
print("done")


if __name__ == "__main__":
main()
