From 990c1da5f3dd0def9b9c1e41d9beb5ac81569c7d Mon Sep 17 00:00:00 2001
From: Your
Date: Thu, 25 Jan 2024 23:24:50 +0000
Subject: [PATCH] update

---
 .../models/phi2/convert_to_onnx.py        |  18 +-
 .../tools/transformers/onnx_model_phi.py  | 473 +++++++++---------
 2 files changed, 253 insertions(+), 238 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
index ad02bfbba06cf..3749d2e14f537 100644
--- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
@@ -297,13 +297,13 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str, use_fp16: bool
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
 converter = ConvertPhi2ToONNX(model_class, device)
-converter.dynamo_export("phi-2_temp.onnx")
-# TODO:preprocessed onnx model takes up large disk space
-converter.preprocess_onnx(
-    "phi-2_temp.onnx",
-    "phi-2.onnx",
-    "modeling_phi_PhiModel_model_1",
-    use_gqa=True,
-)
-converter.erase_onnx_model("phi-2_temp.onnx")
+# converter.dynamo_export("phi-2_temp.onnx")
+# # TODO:preprocessed onnx model takes up large disk space
+# converter.preprocess_onnx(
+#     "phi-2_temp.onnx",
+#     "phi-2.onnx",
+#     "modeling_phi_PhiModel_model_1",
+#     use_gqa=True,
+# )
+# converter.erase_onnx_model("phi-2_temp.onnx")
 converter.optimize_phi2_onnx("phi-2.onnx", "phi-2_opt.onnx", use_fp16=True)
diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py
index ffeb4e94c39d7..6bd0559353cfa 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_phi.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -64,10 +64,7 @@ def __call__(self, x):
         return x
 
 
-def uname(layer_id, name):
-    return name + "_" + str(layer_id)
-
-
+# TODO: move to a separate file
 class Fission(Fusion):
     def __init__(
         self,
@@ -124,7 +121,9 @@ def replace_fp32_value_info(self, name, shape):
         )
         self.model.graph().value_info.extend([new_value_info])
 
-    def set_unique_name(self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]):
+    def set_unique_name_and_add_nodes(
+        self, subgraph_nodes: List[NodeProto], layer_id: int, layer_known_edges_names: List[str]
+    ):
         for new_node in subgraph_nodes:
             for i, name in enumerate(new_node.input):
                 if name == "":
@@ -142,6 +141,167 @@ def set_unique_name(self, subgraph_nodes: List[NodeProto], layer_id: int, layer_
             self.nodes_to_add.append(new_node)
             self.node_name_to_graph_name[new_node.name] = self.this_graph_name
 
+    def layernorm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+        assert len(inputs) == 3
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "LayerNormalization",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "_LayerNormalization",
+            epsilon=9.999999747378752e-06,
+        )
+        return [node]
+
+    def gemm(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+        assert len(inputs) == 3
+        assert len(outputs) == 1
+        matmul = helper.make_node(
+            "MatMul",
+            inputs=[inputs[0], inputs[1]],
+            outputs=[prefix + "matmul_out"],
+            name=prefix + "MatMul",
+        )
+        add = helper.make_node(
+            "Add",
+            inputs=[prefix + "matmul_out", inputs[2]],
+            outputs=outputs,
+            name=prefix + "Bias",
+        )
+        return [matmul, add]
+
+    def rotary(self, inputs: List[str], outputs: List[str], prefix: str = "", rot_dim=32, num_heads=32):
+        assert len(inputs) == 4
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "RotaryEmbedding",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "RotaryEmbedding",
+            domain="com.microsoft",
+            rotary_embedding_dim=rot_dim,
+            num_heads=num_heads,
+        )
+        return [node]
+
+    def fastgelu(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+        assert len(inputs) == 1
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "FastGelu",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "FastGelu",
+            domain="com.microsoft",
+        )
+        return [node]
+
+    def add(self, inputs: List[str], outputs: List[str], prefix: str = ""):
+        assert len(inputs) == 2
+        assert len(outputs) == 1
+        node = helper.make_node(
+            "Add",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "Add",
+        )
+        return [node]
+
+    def mha(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
+        assert len(inputs) == 8
+        assert len(outputs) == 3
+        node = helper.make_node(
+            "MultiHeadAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "MultiHeadAttention",
+            domain="com.microsoft",
+            num_heads=num_heads,
+            unidirectional=1,
+        )
+        return [node]
+
+    def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
+        assert len(inputs) == 7
+        assert len(outputs) == 3
+        node = helper.make_node(
+            "GroupQueryAttention",
+            inputs=inputs,
+            outputs=outputs,
+            name=prefix + "GroupQueryAttention",
+            domain="com.microsoft",
+            num_heads=num_heads,
+            kv_num_heads=num_heads,
+        )
+        return [node]
+
+
+class FissionTransformerEmbeddingPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+    ):
+        super().__init__(model, ["torch_nn_modules_sparse_Embedding_model_embed_tokens_1"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        print(node.name)
+
+        assert len(node.input) == 2
+        assert len(node.output) == 1
+
+        input = node.input[0]
+        output = node.output[0]
+
+        embedding = self.get_io_by_name(node, "embed_tokens.weight")
+
+        layer_known_edges_names = [input, output, embedding]
+
+        subgraph_nodes = [
+            helper.make_node(
+                "Gather",
+                inputs=[embedding, input],
+                outputs=[output],
+                name="Embedding_Gather",
+            ),
+        ]
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 0, layer_known_edges_names)
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
+
+class FissionTransformerLayerNormPhi(Fission):
+    def __init__(
+        self,
+        model: OnnxModel,
+    ):
+        super().__init__(model, ["torch_nn_modules_normalization_LayerNorm_model_final_layernorm_1"])
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        print(node.name)
+
+        assert len(node.input) == 3
+        assert len(node.output) == 1
+
+        input = node.input[0]
+        output = node.output[0]
+
+        ln_weight = self.get_io_by_name(node, "final_layernorm.weight")
+        ln_bias = self.get_io_by_name(node, "final_layernorm.bias")
+
+        layer_known_edges_names = [input, output, ln_weight, ln_bias]
+
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.layernorm([input, ln_weight, ln_bias], [output], "Final"))
+
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
+
+        self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
+        self.replace_fp32_value_info(output, ["batch_size", "seq_len", "hidden_size"])
+
+        self.nodes_to_remove.append(node)
+        self.prune_graph = True
+
 
 class FissionTransformerCausalLMHeadPhi(Fission):
     def __init__(
@@ -151,6 +311,11 @@ def __init__(
         super().__init__(model, ["torch_nn_modules_linear_Linear_lm_head_1"])
 
     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        print(node.name)
+
+        assert len(node.input) == 5
+        assert len(node.output) == 1
+
         input = node.input[2]
         output = node.output[0]
 
@@ -159,23 +324,10 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
 
         layer_known_edges_names = [input, output, fc_weight, fc_bias]
 
-        # opt graph construction.
-        subgraph_nodes = [
-            helper.make_node(
-                "MatMul",
-                inputs=[input, fc_weight],
-                outputs=["matmul_out"],
-                name="OutProj_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["matmul_out", fc_bias],
-                outputs=[output],
-                name="OutProj_Add",
-            ),
-        ]
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.gemm([input, fc_weight, fc_bias], [output], "LMHead_"))
 
-        self.set_unique_name(subgraph_nodes, 99, layer_known_edges_names)
+        self.set_unique_name_and_add_nodes(subgraph_nodes, 99, layer_known_edges_names)
 
         self.replace_fp32_value_info(input, ["batch_size", "seq_len", "hidden_size"])
         self.replace_fp32_value_info(output, ["batch_size", "seq_len", 51200])
@@ -204,6 +356,47 @@ def __init__(
     def get_layer_id(self, node):
         return self.func_to_layer_id[node.op_type]
 
+    def get_gqa_aux_nodes(self):
+        gqa_aux_nodes = [
+            helper.make_node(
+                "ReduceSum",
+                inputs=["attention_mask", "one"],
+                outputs=["attention_mask_row_sums"],
+                name="ReduceSum_gqa_aux",
+            ),
+            helper.make_node(
+                "Sub",
+                inputs=["attention_mask_row_sums", "one"],
+                outputs=["seqlens_k_int64"],
+                name="Sub_gqa_aux",
+            ),
+            helper.make_node(
+                "Cast",
+                inputs=["seqlens_k_int64"],
+                outputs=["seqlens_k"],
+                name="Cast_gqa_aux_0",
+                to=TensorProto.INT32,
+            ),
+            helper.make_node(
+                "Shape", inputs=["attention_mask"], outputs=["attention_mask_shape"], name="Shape_gqa_aux_0"
+            ),
+            helper.make_node(
+                "Gather",
+                inputs=["attention_mask_shape", "one"],
+                outputs=["total_seq_len_int64"],
+                name="Gather_gqa_aux_0",
+                axis=0,
+            ),
+            helper.make_node(
+                "Cast",
+                inputs=["total_seq_len_int64"],
+                outputs=["total_sequence_length"],
+                name="Cast_gqa_aux_1",
+                to=TensorProto.INT32,
+            ),
+        ]
+        return gqa_aux_nodes
+
     # def fuse_with_attn(
     #     self,
     #     node,
@@ -345,8 +538,9 @@ def fuse(
         input_name_to_nodes,
         output_name_to_node,
     ):
+        print(node.name)
+
         layer_id = self.get_layer_id(node)
-        print(f"fuse layer {layer_id}")
 
         i_hidden_states = node.input[0]
         i_key_cache = self.get_io_by_name(node, "past_key")
@@ -405,207 +599,37 @@ def fuse(
         layer_known_edges_names.extend([mlp_fc1_weight, mlp_fc1_bias, mlp_fc2_weight, mlp_fc2_bias])
         layer_known_edges_names.extend(["attention_mask", "step", "seqlens_k", "total_sequence_length"])
 
-        # opt graph construction.
-        subgraph_nodes = [
-            helper.make_node(
-                "LayerNormalization",
-                inputs=[i_hidden_states, ln_weight, ln_bias],
-                outputs=["ln_out"],
-                name="LayerNormalization",
-                epsilon=9.999999747378752e-06,
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["ln_out", attn_q_weight],
-                outputs=["q_matmul_out"],
-                name="Q_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["q_matmul_out", attn_q_bias],
-                outputs=["query"],
-                name="Q_Bias",
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["ln_out", attn_k_weight],
-                outputs=["k_matmul_out"],
-                name="K_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["k_matmul_out", attn_k_bias],
-                outputs=["key"],
-                name="K_Bias",
-            ),
-            helper.make_node(
-                "RotaryEmbedding",
-                inputs=["query", "step", cos_cache, sin_cache],
-                outputs=["query_rot"],
-                name="RotaryEmbedding_Q",
-                domain="com.microsoft",
-                rotary_embedding_dim=32,
-                num_heads=self.num_heads,
-            ),
-            helper.make_node(
-                "RotaryEmbedding",
-                inputs=["key", "step", cos_cache, sin_cache],
-                outputs=["key_rot"],
-                name="RotaryEmbedding_K",
-                domain="com.microsoft",
-                rotary_embedding_dim=32,
-                num_heads=self.num_heads,
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["ln_out", attn_v_weight],
-                outputs=["v_matmul_out"],
-                name="V_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["v_matmul_out", attn_v_bias],
-                outputs=["value"],
-                name="V_Bias",
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["attn_out", attn_out_weight],
-                outputs=["matmul_out"],
-                name="OutProj_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["matmul_out", attn_out_bias],
-                outputs=["add_out"],
-                name="OutProj_Add",
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["ln_out", mlp_fc1_weight],
-                outputs=["fc1_w_out"],
-                name="FC1_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["fc1_w_out", mlp_fc1_bias],
-                outputs=["fc1_b_out"],
-                name="FC1_Bias",
-            ),
-            helper.make_node(
-                "FastGelu",
-                inputs=["fc1_b_out"],
-                outputs=["new_gelu_out"],
-                name="FastGelu",
-                domain="com.microsoft",
-            ),
-            helper.make_node(
-                "MatMul",
-                inputs=["new_gelu_out", mlp_fc2_weight],
-                outputs=["fc2_w_out"],
-                name="FC2_MatMul",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["fc2_w_out", mlp_fc2_bias],
-                outputs=["fc2_b_out"],
-                name="FC2_Bias",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=["add_out", "fc2_b_out"],
-                outputs=["residual_1_out"],
-                name="Residual_Add_1",
-            ),
-            helper.make_node(
-                "Add",
-                inputs=[i_hidden_states, "residual_1_out"],
-                outputs=[o_hidden_states],
-                name="Residual_Add_2",
-            ),
-        ]
+        subgraph_nodes = []
+        subgraph_nodes.extend(self.layernorm([i_hidden_states, ln_weight, ln_bias], ["ln_out"]))
+        subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
+        subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
+        subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
+        subgraph_nodes.extend(self.rotary(["query", "step", cos_cache, sin_cache], ["query_rot"], "Q_"))
+        subgraph_nodes.extend(self.rotary(["key", "step", cos_cache, sin_cache], ["key_rot"], "K_"))
+        subgraph_nodes.extend(self.gemm(["attn_out", attn_out_weight, attn_out_bias], ["attn_add_out"], "OutProj_"))
+        subgraph_nodes.extend(self.gemm(["ln_out", mlp_fc1_weight, mlp_fc1_bias], ["fc1_out"], "FC1_"))
+        subgraph_nodes.extend(self.fastgelu(["fc1_out"], ["gelu_out"]))
+        subgraph_nodes.extend(self.gemm(["gelu_out", mlp_fc2_weight, mlp_fc2_bias], ["fc2_out"], "FC2_"))
+        subgraph_nodes.extend(self.add(["attn_add_out", "fc2_out"], ["residual_1_out"], "Residual_1"))
+        subgraph_nodes.extend(self.add([i_hidden_states, "residual_1_out"], [o_hidden_states], "Residual_2"))
 
         use_mha = False
         if use_mha:
-            subgraph_nodes.append(
-                helper.make_node(
-                    "MultiHeadAttention",
-                    inputs=[
-                        "query_rot",
-                        "key_rot",
-                        "value",
-                        "",
-                        "attention_mask",
-                        "",
-                        i_key_cache,
-                        i_value_cache,
-                    ],
-                    outputs=["attn_out", o_key_cache, o_value_cache],
-                    name="MultiHeadAttention_0",
-                    domain="com.microsoft",
-                    num_heads=self.num_heads,
-                    unidirectional=1,
+            subgraph_nodes.extend(
+                self.mha(
+                    ["query_rot", "key_rot", "value", "", "attention_mask", "", i_key_cache, i_value_cache],
+                    ["attn_out", o_key_cache, o_value_cache],
                 )
             )
         else:
-            subgraph_nodes.append(
-                helper.make_node(
-                    "GroupQueryAttention",
-                    inputs=[
-                        "query_rot",
-                        "key_rot",
-                        "value",
-                        i_key_cache,
-                        i_value_cache,
-                        "seqlens_k",
-                        "total_sequence_length",
-                    ],
-                    outputs=["attn_out", o_key_cache, o_value_cache],
-                    name="GroupQueryAttention_0",
-                    domain="com.microsoft",
-                    num_heads=self.num_heads,
-                    kv_num_heads=self.num_heads,
-                ),
+            subgraph_nodes.extend(
+                self.gqa(
+                    ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "seqlens_k", "total_sequence_length"],
+                    ["attn_out", o_key_cache, o_value_cache],
+                )
             )
 
         if layer_id == 0:
-            gqa_aux_nodes = [
-                helper.make_node(
-                    "ReduceSum",
-                    inputs=["attention_mask", "one"],
-                    outputs=["attention_mask_row_sums"],
-                    name="ReduceSum_gqa_aux",
-                ),
-                helper.make_node(
-                    "Sub",
-                    inputs=["attention_mask_row_sums", "one"],
-                    outputs=["seqlens_k_int64"],
-                    name="Sub_gqa_aux",
-                ),
-                helper.make_node(
-                    "Cast",
-                    inputs=["seqlens_k_int64"],
-                    outputs=["seqlens_k"],
-                    name="Cast_gqa_aux_0",
-                    to=TensorProto.INT32,
-                ),
-                helper.make_node(
-                    "Shape", inputs=["attention_mask"], outputs=["attention_mask_shape"], name="Shape_gqa_aux_0"
-                ),
-                helper.make_node(
-                    "Gather",
-                    inputs=["attention_mask_shape", "one"],
-                    outputs=["total_seq_len_int64"],
-                    name="Gather_gqa_aux_0",
-                    axis=0,
-                ),
-                helper.make_node(
-                    "Cast",
-                    inputs=["total_seq_len_int64"],
-                    outputs=["total_sequence_length"],
-                    name="Cast_gqa_aux_1",
-                    to=TensorProto.INT32,
-                ),
-            ]
+            gqa_aux_nodes = self.get_gqa_aux_nodes()
             for new_node in gqa_aux_nodes:
                 self.nodes_to_add.append(new_node)
                 self.node_name_to_graph_name[new_node.name] = self.this_graph_name
@@ -613,7 +637,7 @@ def fuse(
                 numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
             )
 
-        self.set_unique_name(subgraph_nodes, layer_id, layer_known_edges_names)
+        self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)
 
         self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
         self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])
@@ -713,6 +737,8 @@ def __init__(self, model: ModelProto, num_heads: int = 0, head_size: int = 0):
         super().__init__(model)
         self.fission_transformer_block = FissionTransformerBlockPhi(self, num_heads)
         self.fission_causal_lm_head = FissionTransformerCausalLMHeadPhi(self)
+        self.fission_transformer_layernorm = FissionTransformerLayerNormPhi(self)
+        self.fission_transformer_embedding = FissionTransformerEmbeddingPhi(self)
         self.fuse_sln = FusionSkipLayerNormalization(self)
         self.fuse_bias_sln = FusionBiasSkipLayerNormalization(self)
@@ -723,20 +749,9 @@ def postprocess(self):
 
     def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
         self.fission_transformer_block.apply()
+        self.fission_transformer_layernorm.apply()
         self.fission_causal_lm_head.apply()
+        self.fission_transformer_embedding.apply()
         self.fuse_sln.apply()
         self.fuse_bias_sln.apply()
         self.postprocess()
-
-    # def get_fused_operator_statistics(self):
-    #     """
-    #     Returns node count of fused operators.
-    #     """
-    #     op_count = {}
-    #     return op_count
-
-    # def is_fully_optimized(self, fused_op_count=None):
-    #     """
-    #     Returns True when the model is fully optimized.
-    #     """
-    #     return False
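
Reviewer note (not part of the patch): the new Fission.gemm() helper expands one logical GEMM into a MatMul followed by a bias Add, with node and edge names derived from the prefix argument. Below is a standalone sketch that mirrors that two-node expansion for the Q projection and runs it through the ONNX checker; the tensor names, the 2560 hidden size, and the symbolic batch/sequence dims are illustrative assumptions, not values taken from the patch.

from onnx import TensorProto, checker, helper

# Mirror of the MatMul + Add pair that Fission.gemm() would emit for prefix "Q_".
prefix = "Q_"
nodes = [
    helper.make_node("MatMul", ["ln_out", "attn_q_weight"], [prefix + "matmul_out"], name=prefix + "MatMul"),
    helper.make_node("Add", [prefix + "matmul_out", "attn_q_bias"], ["query"], name=prefix + "Bias"),
]

graph = helper.make_graph(
    nodes,
    "gemm_fission_check",
    inputs=[
        helper.make_tensor_value_info("ln_out", TensorProto.FLOAT, ["batch_size", "seq_len", 2560]),
        helper.make_tensor_value_info("attn_q_weight", TensorProto.FLOAT, [2560, 2560]),
        helper.make_tensor_value_info("attn_q_bias", TensorProto.FLOAT, [2560]),
    ],
    outputs=[helper.make_tensor_value_info("query", TensorProto.FLOAT, ["batch_size", "seq_len", 2560])],
)
checker.check_model(helper.make_model(graph))  # raises if the two-node expansion is malformed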
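
Reviewer note (not part of the patch): get_gqa_aux_nodes() builds the ReduceSum -> Sub -> Cast chain that feeds GroupQueryAttention's seqlens_k input and the Shape -> Gather -> Cast chain that feeds total_sequence_length. A minimal NumPy sketch of the same arithmetic, assuming a 2D right-padded attention_mask of ones and zeros (the function name and sample values are made up for illustration):

import numpy as np

def gqa_aux_reference(attention_mask: np.ndarray):
    # seqlens_k: number of valid tokens per row minus one, cast to int32
    seqlens_k = (attention_mask.sum(axis=1) - 1).astype(np.int32)
    # total_sequence_length: the mask's second dimension, cast to int32
    total_sequence_length = np.int32(attention_mask.shape[1])
    return seqlens_k, total_sequence_length

mask = np.array([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=np.int64)
print(gqa_aux_reference(mask))  # seqlens_k == [2, 1], total_sequence_length == 4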