From f53d2c2465d81cdb4e14c7241eab327184192c88 Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Wed, 14 Feb 2024 18:08:11 +0000
Subject: [PATCH] Phi2 script fixes (#19500)

### Description
This PR is intended to support the Phi2 passes in Olive. Merge it before https://github.com/microsoft/Olive/pull/938

### Motivation and Context
---
 .../tools/transformers/fusion_options.py      |  7 ++
 .../models/phi2/convert_to_onnx.py            |  3 -
 .../tools/transformers/onnx_model_phi.py      | 98 +++++++++++--------
 3 files changed, 62 insertions(+), 46 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py
index 4c43e4487bfb1..edac1989e4e9e 100644
--- a/onnxruntime/python/tools/transformers/fusion_options.py
+++ b/onnxruntime/python/tools/transformers/fusion_options.py
@@ -29,6 +29,13 @@ class AttentionOpType(Enum):
     def __str__(self):
         return self.value

+    # Override __hash__ and __eq__ so that members hash and compare by their string value
+    def __hash__(self):
+        return hash(self.value)
+
+    def __eq__(self, other):
+        return other.value == self.value
+

 class FusionOptions:
     """Options of fusion in graph optimization"""
diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
index b7881d064067d..796d6ec55ef80 100644
--- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
+++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
@@ -138,9 +138,6 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
         # We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
         node_block_list = (
             [
-                "GroupQueryAttention_29",
-                "GroupQueryAttention_30",
-                "GroupQueryAttention_31",
                 "Attention_29",
                 "Attention_30",
                 "Attention_31",
diff --git a/onnxruntime/python/tools/transformers/onnx_model_phi.py b/onnxruntime/python/tools/transformers/onnx_model_phi.py
index e68c3120e3f09..0fdce29ae0fa0 100644
--- a/onnxruntime/python/tools/transformers/onnx_model_phi.py
+++ b/onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -80,14 +80,17 @@ def set_attention_op_type(self, attn_op_type: AttentionOpType):
     def get_uname(self, layer_id, name):
         return name + "_" + str(layer_id)

-    def get_io_by_name(self, node, name):
-        for input in node.input:
-            if input == name or input.endswith(name) or input.startswith(name):
-                return input
-        for output in node.output:
-            if output == name or output.endswith(name) or output.startswith(name):
-                return output
-        raise Exception(f"input {name} not found in node {node.name}")
+    def get_edge_by_name(self, edges, name):
+        for edge in edges:
+            if edge == name or edge.endswith(name) or edge.startswith(name):
+                return edge
+        raise ValueError(f"Edge {name} not found")
+
+    def get_input_by_name(self, node, name):
+        return self.get_edge_by_name(node.input, name)
+
+    def get_output_by_name(self, node, name):
+        return self.get_edge_by_name(node.output, name)

     def process_initializer(self, initializer_name, functor, custom_name=None):
         i = self.model.get_initializer(initializer_name)
@@ -287,7 +290,6 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
         self.num_attention_heads = num_heads
         self.hidden_size = hidden_size
-        self.phi2_edge_dict = self.get_phi2_edge_dict()
         self.func_name = "modeling_phi_PhiModel_model_1"

     def get_phi2_edge_dict(self) -> dict:
         edge_dict = {}
         edge_dict["lm_head_1"] = "logits"
         edge_dict["l_input_ids_"] = "input_ids"
edge_dict["key_states"] = "past_key_0" edge_dict["value_states"] = "past_value_0" - for i in range(self.num_hidden_layers): + for i in range(1, self.num_hidden_layers, 1): edge_dict[f"key_states_{i}"] = f"past_key_{i}" edge_dict[f"value_states_{i}"] = f"past_value_{i}" edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}" edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}" + + outputs = [o.name for o in self.model.graph.output] + if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs: + edge_dict["model_layers_0_1_1"] = "present_key_0" + edge_dict["model_layers_0_1_2"] = "present_value_0" + else: + assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs + edge_dict["model_layers_0_1"] = "present_key_0" + edge_dict["model_layers_0_1_1"] = "present_value_0" return edge_dict def simplify_phi2_op_type(self): @@ -441,7 +452,7 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType): break assert function_name is not None self.unroll_function(function_name) - self.update_edges(self.phi2_edge_dict) + self.update_edges(self.get_phi2_edge_dict()) self.simplify_phi2_op_type() self.remove_dropout_layer() if attn_op_type == AttentionOpType.PagedAttention: @@ -465,7 +476,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - embedding = self.get_io_by_name(node, "embed_tokens.weight") + embedding = self.get_input_by_name(node, "embed_tokens.weight") layer_known_edges_names = [input, output, embedding] @@ -499,8 +510,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[0] output = node.output[0] - ln_weight = self.get_io_by_name(node, "final_layernorm.weight") - ln_bias = self.get_io_by_name(node, "final_layernorm.bias") + ln_weight = self.get_input_by_name(node, "final_layernorm.weight") + ln_bias = self.get_input_by_name(node, "final_layernorm.bias") layer_known_edges_names = [input, output, ln_weight, ln_bias] @@ -532,8 +543,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): input = node.input[2] output = node.output[0] - fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) - fc_bias = self.get_io_by_name(node, "lm_head.bias") + fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc()) + fc_bias = self.get_input_by_name(node, "lm_head.bias") layer_known_edges_names = [input, output, fc_weight, fc_bias] @@ -670,15 +681,15 @@ def fuse( layer_id = self.get_layer_id(node) i_hidden_states = node.input[0] - i_key_cache = self.get_io_by_name(node, "past_key") - i_value_cache = self.get_io_by_name(node, "past_value") + i_key_cache = self.get_input_by_name(node, "past_key") + i_value_cache = self.get_input_by_name(node, "past_value") - o_hidden_states = node.output[3] - o_key_cache = self.get_io_by_name(node, "present_key") - o_value_cache = self.get_io_by_name(node, "present_value") + o_hidden_states = node.output[-1] + o_key_cache = self.get_output_by_name(node, "present_key") + o_value_cache = self.get_output_by_name(node, "present_value") - ln_weight = self.get_io_by_name(node, "input_layernorm.weight") - ln_bias = self.get_io_by_name(node, "input_layernorm.bias") + ln_weight = self.get_input_by_name(node, "input_layernorm.weight") + ln_bias = self.get_input_by_name(node, "input_layernorm.bias") attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = ( None, @@ -693,45 +704,45 @@ def fuse( if self.attn_op_type != 
             attn_q_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
             )
             attn_k_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
             )
             attn_v_weight = self.process_initializer(
-                self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
+                self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
             )
-            attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias")
-            attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias")
-            attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias")
+            attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
+            attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
+            attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")

             cos_cache = self.process_initializer(
-                self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
+                self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
             )
             sin_cache = self.process_initializer(
-                self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
+                self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
             )
         else:
             attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
-                self.get_io_by_name(node, "self_attn.q_proj.weight"),
-                self.get_io_by_name(node, "self_attn.k_proj.weight"),
-                self.get_io_by_name(node, "self_attn.v_proj.weight"),
-                self.get_io_by_name(node, "self_attn.q_proj.bias"),
-                self.get_io_by_name(node, "self_attn.k_proj.bias"),
-                self.get_io_by_name(node, "self_attn.v_proj.bias"),
+                self.get_input_by_name(node, "self_attn.q_proj.weight"),
+                self.get_input_by_name(node, "self_attn.k_proj.weight"),
+                self.get_input_by_name(node, "self_attn.v_proj.weight"),
+                self.get_input_by_name(node, "self_attn.q_proj.bias"),
+                self.get_input_by_name(node, "self_attn.k_proj.bias"),
+                self.get_input_by_name(node, "self_attn.v_proj.bias"),
                 self.get_uname(layer_id, "attn_qkv_weight"),
                 self.get_uname(layer_id, "attn_qkv_bias"),
             )

         attn_out_weight = self.process_initializer(
-            self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
+            self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
         )
-        attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias")
+        attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")

-        mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
-        mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
-        mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias")
-        mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias")
+        mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
+        mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
+        mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
+        mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")

         layer_known_edges_names = []
         layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
@@ -771,6 +782,7 @@ def fuse(
         subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
         subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
attn_k_bias], ["key"], "K_")) subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_")) + # vllm engine requires full position ids as the input pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step" subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_")) subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))