Phi2 script fixes (#19500)
### Description

This PR is intended to support the Phi2 passes in Olive and should be merged before microsoft/Olive#938.

wangyems authored Feb 14, 2024
1 parent 5444070 commit f53d2c2
Showing 3 changed files with 62 additions and 46 deletions.
7 changes: 7 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_options.py
@@ -29,6 +29,13 @@ class AttentionOpType(Enum):
def __str__(self):
return self.value

# Compare and hash members by their string value (overriding __eq__ removes the default __hash__, so it is redefined too)
def __hash__(self):
return hash(self.value)

def __eq__(self, other):
return other.value == self.value


class FusionOptions:
"""Options of fusion in graph optimization"""
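The new dunder methods make `AttentionOpType` members compare and hash by their string value rather than by identity, which keeps them usable as dict/set keys and lets equality hold even when an equivalent member is created outside the original enum object. A minimal sketch of the idea, using a hypothetical stand-in enum rather than the library class:

```python
# Minimal sketch (hypothetical stand-in, not the library class): once __eq__ is
# overridden, Python drops the default __hash__, so it must be redefined for the
# members to remain usable as dict/set keys.
from enum import Enum

class AttnOp(Enum):
    Attention = "Attention"
    PagedAttention = "PagedAttention"

    def __hash__(self):
        return hash(self.value)

    def __eq__(self, other):
        # Compare by string value; falls back to the raw value for plain strings.
        return getattr(other, "value", other) == self.value

dispatch = {AttnOp.Attention: "fuse_attention"}   # works because __hash__ is defined
print(AttnOp.PagedAttention == "PagedAttention")  # True with the getattr fallback above
print(dispatch[AttnOp.Attention])                 # "fuse_attention"
```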
@@ -138,9 +138,6 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
# We keep last three layers of Attention as float32 or bfloat16 to avoid overflow.
node_block_list = (
[
"GroupQueryAttention_29",
"GroupQueryAttention_30",
"GroupQueryAttention_31",
"Attention_29",
"Attention_30",
"Attention_31",
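For context, `node_block_list` is what keeps the last three attention layers out of the float16 conversion so their outputs do not overflow; with the `GroupQueryAttention_*` entries removed, only the `Attention_*` node names remain blocked. A rough sketch of how such a block list is typically applied with the transformers float16 helper (the model paths and exact keyword support are assumptions here, not part of this diff):

```python
# Hedged sketch: keep the last three attention nodes in float32 while converting
# the rest of the graph to float16. Assumes the onnxruntime transformers helper
# accepts a node_block_list argument; the model paths are hypothetical.
import onnx
from onnxruntime.transformers.onnx_model import OnnxModel

model = OnnxModel(onnx.load("phi2_optimized.onnx"))
node_block_list = [f"Attention_{i}" for i in (29, 30, 31)]
model.convert_float_to_float16(keep_io_types=False, node_block_list=node_block_list)
model.save_model_to_file("phi2_fp16.onnx")
```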
98 changes: 55 additions & 43 deletions onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -80,14 +80,17 @@ def set_attention_op_type(self, attn_op_type: AttentionOpType):
def get_uname(self, layer_id, name):
return name + "_" + str(layer_id)

def get_io_by_name(self, node, name):
for input in node.input:
if input == name or input.endswith(name) or input.startswith(name):
return input
for output in node.output:
if output == name or output.endswith(name) or output.startswith(name):
return output
raise Exception(f"input {name} not found in node {node.name}")
def get_edge_by_name(self, edges, name):
for edge in edges:
if edge == name or edge.endswith(name) or edge.startswith(name):
return edge
raise ValueError(f"Edge {name} not found")

def get_input_by_name(self, node, name):
return self.get_edge_by_name(node.input, name)

def get_output_by_name(self, node, name):
return self.get_edge_by_name(node.output, name)
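Splitting the old `get_io_by_name` into `get_input_by_name` / `get_output_by_name` removes the ambiguity of searching inputs and outputs in one pass: callers now state which side of the node they expect the edge on, and a miss raises a `ValueError` instead of a bare `Exception`. A small self-contained illustration of the prefix/suffix matching (toy edge names, not the real exported graph):

```python
# Toy illustration of the matching rule used by get_edge_by_name: an edge is
# returned if it equals the query, or starts or ends with it.
def get_edge_by_name(edges, name):
    for edge in edges:
        if edge == name or edge.endswith(name) or edge.startswith(name):
            return edge
    raise ValueError(f"Edge {name} not found")

inputs = ["model_layers_0_input_layernorm_weight", "hidden_states"]
outputs = ["attn_output"]

print(get_edge_by_name(inputs, "input_layernorm_weight"))  # suffix match
print(get_edge_by_name(outputs, "attn_"))                  # prefix match
# get_edge_by_name(inputs, "lm_head.weight") would raise ValueError
```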

def process_initializer(self, initializer_name, functor, custom_name=None):
i = self.model.get_initializer(initializer_name)
@@ -287,7 +290,6 @@ def __init__(self, model: ModelProto, num_heads: int, hidden_size: int):
self.num_attention_heads = num_heads
self.hidden_size = hidden_size

self.phi2_edge_dict = self.get_phi2_edge_dict()
self.func_name = "modeling_phi_PhiModel_model_1"

def get_phi2_edge_dict(self) -> dict:
@@ -296,11 +298,20 @@ def get_phi2_edge_dict(self) -> dict:
edge_dict["l_input_ids_"] = "input_ids"
edge_dict["key_states"] = "past_key_0"
edge_dict["value_states"] = "past_value_0"
for i in range(self.num_hidden_layers):
for i in range(1, self.num_hidden_layers, 1):
edge_dict[f"key_states_{i}"] = f"past_key_{i}"
edge_dict[f"value_states_{i}"] = f"past_value_{i}"
edge_dict[f"model_layers_{i}_1"] = f"present_key_{i}"
edge_dict[f"model_layers_{i}_1_1"] = f"present_value_{i}"

outputs = [o.name for o in self.model.graph.output]
if "model_layers_0_1_1" in outputs and "model_layers_0_1_2" in outputs:
edge_dict["model_layers_0_1_1"] = "present_key_0"
edge_dict["model_layers_0_1_2"] = "present_value_0"
else:
assert "model_layers_0_1" in outputs and "model_layers_0_1_1" in outputs
edge_dict["model_layers_0_1"] = "present_key_0"
edge_dict["model_layers_0_1_1"] = "present_value_0"
return edge_dict
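The edge dictionary renames the exporter's auto-generated edge names to the stable I/O names the runtime expects. Layer 0 is now handled separately because, depending on how the model was exported, its present-KV outputs appear either as `model_layers_0_1_1`/`model_layers_0_1_2` or as `model_layers_0_1`/`model_layers_0_1_1`; the loop therefore starts at layer 1 and the graph outputs are inspected to pick the right pair. A toy run of that disambiguation, with a made-up output list:

```python
# Toy reproduction of the layer-0 check above, using a made-up graph output list.
def layer0_present_edges(output_names):
    if "model_layers_0_1_1" in output_names and "model_layers_0_1_2" in output_names:
        return {"model_layers_0_1_1": "present_key_0", "model_layers_0_1_2": "present_value_0"}
    assert "model_layers_0_1" in output_names and "model_layers_0_1_1" in output_names
    return {"model_layers_0_1": "present_key_0", "model_layers_0_1_1": "present_value_0"}

outputs = ["logits", "model_layers_0_1", "model_layers_0_1_1", "model_layers_1_1", "model_layers_1_1_1"]
print(layer0_present_edges(outputs))
# {'model_layers_0_1': 'present_key_0', 'model_layers_0_1_1': 'present_value_0'}
```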

def simplify_phi2_op_type(self):
@@ -441,7 +452,7 @@ def preprocess_onnx(self, attn_op_type: AttentionOpType):
break
assert function_name is not None
self.unroll_function(function_name)
self.update_edges(self.phi2_edge_dict)
self.update_edges(self.get_phi2_edge_dict())
self.simplify_phi2_op_type()
self.remove_dropout_layer()
if attn_op_type == AttentionOpType.PagedAttention:
@@ -465,7 +476,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
input = node.input[0]
output = node.output[0]

embedding = self.get_io_by_name(node, "embed_tokens.weight")
embedding = self.get_input_by_name(node, "embed_tokens.weight")

layer_known_edges_names = [input, output, embedding]

@@ -499,8 +510,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
input = node.input[0]
output = node.output[0]

ln_weight = self.get_io_by_name(node, "final_layernorm.weight")
ln_bias = self.get_io_by_name(node, "final_layernorm.bias")
ln_weight = self.get_input_by_name(node, "final_layernorm.weight")
ln_bias = self.get_input_by_name(node, "final_layernorm.bias")

layer_known_edges_names = [input, output, ln_weight, ln_bias]

@@ -532,8 +543,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node):
input = node.input[2]
output = node.output[0]

fc_weight = self.process_initializer(self.get_io_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
fc_bias = self.get_io_by_name(node, "lm_head.bias")
fc_weight = self.process_initializer(self.get_input_by_name(node, "lm_head.weight"), ProcessGemmWFunc())
fc_bias = self.get_input_by_name(node, "lm_head.bias")

layer_known_edges_names = [input, output, fc_weight, fc_bias]
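Throughout these fusions, weights are routed through `process_initializer` with a functor such as `ProcessGemmWFunc` before being wired into the new subgraph. The actual functor is not shown in this diff; the sketch below only illustrates the general read-transform-register pattern, with a transpose standing in for whatever `ProcessGemmWFunc` really does:

```python
# Hedged sketch of the read-transform-register pattern; the transpose is only a
# placeholder and is not the real ProcessGemmWFunc implementation.
import numpy as np
from onnx import ModelProto, numpy_helper

def process_gemm_w(weight: np.ndarray) -> np.ndarray:
    return weight.transpose()  # assumed layout fix for Gemm, illustrative only

def process_initializer(model: ModelProto, initializer_name: str, functor, custom_name=None):
    init = next(i for i in model.graph.initializer if i.name == initializer_name)
    new_array = functor(numpy_helper.to_array(init))
    new_name = custom_name or initializer_name + "_processed"
    model.graph.initializer.append(numpy_helper.from_array(new_array, name=new_name))
    return new_name
```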

@@ -670,15 +681,15 @@ def fuse(
layer_id = self.get_layer_id(node)

i_hidden_states = node.input[0]
i_key_cache = self.get_io_by_name(node, "past_key")
i_value_cache = self.get_io_by_name(node, "past_value")
i_key_cache = self.get_input_by_name(node, "past_key")
i_value_cache = self.get_input_by_name(node, "past_value")

o_hidden_states = node.output[3]
o_key_cache = self.get_io_by_name(node, "present_key")
o_value_cache = self.get_io_by_name(node, "present_value")
o_hidden_states = node.output[-1]
o_key_cache = self.get_output_by_name(node, "present_key")
o_value_cache = self.get_output_by_name(node, "present_value")

ln_weight = self.get_io_by_name(node, "input_layernorm.weight")
ln_bias = self.get_io_by_name(node, "input_layernorm.bias")
ln_weight = self.get_input_by_name(node, "input_layernorm.weight")
ln_bias = self.get_input_by_name(node, "input_layernorm.bias")

attn_q_weight, attn_q_bias, attn_k_weight, attn_k_bias, attn_v_weight, attn_v_bias = (
None,
@@ -693,45 +704,45 @@

if self.attn_op_type != AttentionOpType.Attention:
attn_q_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
self.get_input_by_name(node, "self_attn.q_proj.weight"), ProcessGemmWFunc()
)
attn_k_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
self.get_input_by_name(node, "self_attn.k_proj.weight"), ProcessGemmWFunc()
)
attn_v_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
self.get_input_by_name(node, "self_attn.v_proj.weight"), ProcessGemmWFunc()
)
attn_q_bias = self.get_io_by_name(node, "self_attn.q_proj.bias")
attn_k_bias = self.get_io_by_name(node, "self_attn.k_proj.bias")
attn_v_bias = self.get_io_by_name(node, "self_attn.v_proj.bias")
attn_q_bias = self.get_input_by_name(node, "self_attn.q_proj.bias")
attn_k_bias = self.get_input_by_name(node, "self_attn.k_proj.bias")
attn_v_bias = self.get_input_by_name(node, "self_attn.v_proj.bias")

cos_cache = self.process_initializer(
self.get_io_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
self.get_input_by_name(node, "rotary_emb.cos_cached"), ProcessRotCacheFunc()
)
sin_cache = self.process_initializer(
self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
self.get_input_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
)
else:
attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
self.get_io_by_name(node, "self_attn.q_proj.weight"),
self.get_io_by_name(node, "self_attn.k_proj.weight"),
self.get_io_by_name(node, "self_attn.v_proj.weight"),
self.get_io_by_name(node, "self_attn.q_proj.bias"),
self.get_io_by_name(node, "self_attn.k_proj.bias"),
self.get_io_by_name(node, "self_attn.v_proj.bias"),
self.get_input_by_name(node, "self_attn.q_proj.weight"),
self.get_input_by_name(node, "self_attn.k_proj.weight"),
self.get_input_by_name(node, "self_attn.v_proj.weight"),
self.get_input_by_name(node, "self_attn.q_proj.bias"),
self.get_input_by_name(node, "self_attn.k_proj.bias"),
self.get_input_by_name(node, "self_attn.v_proj.bias"),
self.get_uname(layer_id, "attn_qkv_weight"),
self.get_uname(layer_id, "attn_qkv_bias"),
)

attn_out_weight = self.process_initializer(
self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
self.get_input_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
)
attn_out_bias = self.get_io_by_name(node, "self_attn.dense.bias")
attn_out_bias = self.get_input_by_name(node, "self_attn.dense.bias")

mlp_fc1_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
mlp_fc2_weight = self.process_initializer(self.get_io_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
mlp_fc1_bias = self.get_io_by_name(node, "mlp.fc1.bias")
mlp_fc2_bias = self.get_io_by_name(node, "mlp.fc2.bias")
mlp_fc1_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc1.weight"), ProcessGemmWFunc())
mlp_fc2_weight = self.process_initializer(self.get_input_by_name(node, "mlp.fc2.weight"), ProcessGemmWFunc())
mlp_fc1_bias = self.get_input_by_name(node, "mlp.fc1.bias")
mlp_fc2_bias = self.get_input_by_name(node, "mlp.fc2.bias")

layer_known_edges_names = []
layer_known_edges_names.extend([i_hidden_states, i_key_cache, i_value_cache])
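For the plain `Attention` op type, the separate Q/K/V projections are packed by `pack_qkv_gemm` (called earlier in this hunk) into a single GEMM so one matmul produces all three projections. The exact packing lives elsewhere in the file; a rough numpy sketch of the idea, assuming PyTorch-style `[out, in]` weight layout and the Phi-2 hidden size of 2560:

```python
# Rough sketch of Q/K/V packing into one GEMM; layouts are assumptions, not the
# exact pack_qkv_gemm implementation.
import numpy as np

hidden = 2560  # Phi-2 hidden size
q_w, k_w, v_w = (np.random.randn(hidden, hidden).astype(np.float32) for _ in range(3))
q_b, k_b, v_b = (np.zeros(hidden, dtype=np.float32) for _ in range(3))

qkv_w = np.concatenate([q_w.T, k_w.T, v_w.T], axis=1)  # [hidden, 3 * hidden]
qkv_b = np.concatenate([q_b, k_b, v_b], axis=0)        # [3 * hidden]

x = np.random.randn(4, hidden).astype(np.float32)      # [tokens, hidden]
query, key, value = np.split(x @ qkv_w + qkv_b, 3, axis=-1)
```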
@@ -771,6 +782,7 @@ def fuse(
subgraph_nodes.extend(self.gemm(["ln_out", attn_q_weight, attn_q_bias], ["query"], "Q_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_k_weight, attn_k_bias], ["key"], "K_"))
subgraph_nodes.extend(self.gemm(["ln_out", attn_v_weight, attn_v_bias], ["value"], "V_"))
# vllm engine requires full position ids as the input
pos_ids_name = "position_ids" if self.attn_op_type == AttentionOpType.PagedAttention else "step"
subgraph_nodes.extend(self.rotary(["query", pos_ids_name, cos_cache, sin_cache], ["query_rot"], "Q_"))
subgraph_nodes.extend(self.rotary(["key", pos_ids_name, cos_cache, sin_cache], ["key_rot"], "K_"))
