
Commit

update
wangyems committed Jan 26, 2024
1 parent 013f551 commit eed75b8
Showing 1 changed file with 83 additions and 142 deletions.
225 changes: 83 additions & 142 deletions onnxruntime/python/tools/transformers/onnx_model_phi.py
@@ -236,6 +236,22 @@ def gqa(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads
        )
        return [node]

    def attention(self, inputs: List[str], outputs: List[str], prefix: str = "", num_heads=32):
        assert len(inputs) == 5
        assert len(outputs) == 2
        node = helper.make_node(
            "Attention",
            inputs=inputs,
            outputs=outputs,
            name=prefix + "Attention",
            domain="com.microsoft",
            num_heads=num_heads,
            unidirectional=1,
            do_rotary=1,
            rotary_embedding_dim=32,
        )
        return [node]
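
For orientation, here is a minimal standalone sketch of the node this helper emits, using hypothetical edge names (the fusion pass generates the real per-layer names); it assumes only the onnx Python package:

from onnx import helper

# Hypothetical edge names; the fusion pass wires the real per-layer names.
attn_inputs = ["ln_out", "attn_qkv_weight", "attn_qkv_bias", "attention_mask", "past_0"]
attn_outputs = ["attn_out", "present_0"]

attn_node = helper.make_node(
    "Attention",
    inputs=attn_inputs,
    outputs=attn_outputs,
    name="layer0_Attention",
    domain="com.microsoft",    # ONNX Runtime contrib-op domain
    num_heads=32,
    unidirectional=1,          # causal (decoder-style) attention
    do_rotary=1,               # rotary embedding applied inside the op
    rotary_embedding_dim=32,   # rotate only the first 32 dims of each head
)
print(attn_node)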


class FissionTransformerEmbeddingPhi(Fission):
    def __init__(
@@ -398,140 +414,46 @@ def get_gqa_aux_nodes(self):
        ]
        return gqa_aux_nodes

# def fuse_with_attn(
# self,
# node,
# input_name_to_nodes,
# output_name_to_node,
# ):
# layer_id = self.get_layer_id(node)
# print(f"fuse layer {layer_id}")

# # transformer block input and output
# i_hidden_states = node.input[0]
# i_attn_mask = node.input[1]
# i_kv_cache = node.input[3]
# o_hidden_states = node.output[3]
# o_kv_cache = node.output[0]

# # internal nodes weights
# ln_weight = node.input[5] # float32[2560]
# ln_bias = node.input[6] # float32[2560]
# attn_qkv_weight = self.process_initializer(node.input[7], ProcessGemmWFunc()) # float32[7680,2560]
# attn_qkv_bias = node.input[8] # float32[7680]
# attn_out_weight = self.process_initializer(node.input[11], ProcessGemmWFunc()) # float32[2560,2560]
# attn_out_bias = node.input[12] # float32[2560]
# mlp_fc1_weight = self.process_initializer(node.input[13], ProcessGemmWFunc()) # float32[10240,2560]
# mlp_fc1_bias = node.input[14] # float32[10240]
# mlp_fc2_weight = self.process_initializer(node.input[15], ProcessGemmWFunc()) # float32[2560,10240]
# mlp_fc2_bias = node.input[16] # float32[2560]

# # opt graph construction.
# subgraph_nodes = [
# helper.make_node(
# "LayerNormalization",
# inputs=[i_hidden_states, ln_weight, ln_bias],
# outputs=[uname(layer_id, "ln_out")],
# name=uname(layer_id, "LayerNormalization"),
# epsilon=9.999999747378752e-06,
# ),
# helper.make_node(
# "Attention",
# inputs=[
# uname(layer_id, "ln_out"),
# attn_qkv_weight,
# attn_qkv_bias,
# i_attn_mask,
# i_kv_cache,
# # "",
# # "past_sequence_length",
# ],
# outputs=[uname(layer_id, "attn_out"), o_kv_cache],
# name=uname(layer_id, "Attention"),
# domain="com.microsoft",
# num_heads=32,
# unidirectional=1,
# do_rotary=1,
# rotary_embedding_dim=32,
# # past_present_share_buffer=1,
# ),
# helper.make_node(
# "MatMul",
# inputs=[uname(layer_id, "attn_out"), attn_out_weight],
# outputs=[uname(layer_id, "matmul_out")],
# name=uname(layer_id, "OutProj_MatMul"),
# ),
# helper.make_node(
# "Add",
# inputs=[uname(layer_id, "matmul_out"), attn_out_bias],
# outputs=[uname(layer_id, "add_out")],
# name=uname(layer_id, "OutProj_Add"),
# ),
# helper.make_node(
# "MatMul",
# inputs=[uname(layer_id, "ln_out"), mlp_fc1_weight],
# outputs=[uname(layer_id, "fc1_w_out")],
# name=uname(layer_id, "FC1_MatMul"),
# ),
# helper.make_node(
# "Add",
# inputs=[uname(layer_id, "fc1_w_out"), mlp_fc1_bias],
# outputs=[uname(layer_id, "fc1_b_out")],
# name=uname(layer_id, "FC1_Bias"),
# ),
# helper.make_node(
# "FastGelu",
# inputs=[uname(layer_id, "fc1_b_out")],
# outputs=[uname(layer_id, "new_gelu_out")],
# name=uname(layer_id, "FastGelu"),
# domain="com.microsoft",
# ),
# helper.make_node(
# "MatMul",
# inputs=[uname(layer_id, "new_gelu_out"), mlp_fc2_weight],
# outputs=[uname(layer_id, "fc2_w_out")],
# name=uname(layer_id, "FC2_MatMul"),
# ),
# helper.make_node(
# "Add",
# inputs=[uname(layer_id, "fc2_w_out"), mlp_fc2_bias],
# outputs=[uname(layer_id, "fc2_b_out")],
# name=uname(layer_id, "FC2_Bias"),
# ),
# helper.make_node(
# "Add",
# inputs=[uname(layer_id, "add_out"), uname(layer_id, "fc2_b_out")],
# outputs=[uname(layer_id, "residual_1_out")],
# name=uname(layer_id, "Residual_Add_1"),
# ),
# helper.make_node(
# "Add",
# inputs=[i_hidden_states, uname(layer_id, "residual_1_out")],
# outputs=[o_hidden_states],
# name=uname(layer_id, "Residual_Add_2"),
# ),
# ]

# for new_node in subgraph_nodes:
# self.nodes_to_add.append(new_node)
# self.node_name_to_graph_name[new_node.name] = self.this_graph_name

# self.add_fp32_value_info(uname(layer_id, "ln_out"))
# self.add_fp32_value_info(uname(layer_id, "attn_out"))
# self.add_fp32_value_info(uname(layer_id, "matmul_out"))
# self.add_fp32_value_info(uname(layer_id, "add_out"))
# self.add_fp32_value_info(uname(layer_id, "fc1_w_out"))
# self.add_fp32_value_info(uname(layer_id, "fc1_b_out"))
# self.add_fp32_value_info(uname(layer_id, "new_gelu_out"))
# self.add_fp32_value_info(uname(layer_id, "fc2_w_out"))
# self.add_fp32_value_info(uname(layer_id, "fc2_b_out"))
# self.add_fp32_value_info(uname(layer_id, "residual_1_out"))

# self.replace_fp32_value_info(i_hidden_states, ["batch_size", "seq_len", "hidden_size"])
# self.replace_fp32_value_info(o_hidden_states, ["batch_size", "seq_len", "hidden_size"])

# self.nodes_to_remove.append(node)
# self.prune_graph = True
    def pack_qkv_gemm(self, q_w, k_w, v_w, q_b, k_b, v_b, weight_name, bias_name):
        q_weight = self.model.get_initializer(q_w)
        k_weight = self.model.get_initializer(k_w)
        v_weight = self.model.get_initializer(v_w)
        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)
        qkv_weight = np.stack((qw, kw, vw), axis=1)

        q_bias = self.model.get_initializer(q_b)
        k_bias = self.model.get_initializer(k_b)
        v_bias = self.model.get_initializer(v_b)
        qb = NumpyHelper.to_array(q_bias)
        kb = NumpyHelper.to_array(k_bias)
        vb = NumpyHelper.to_array(v_bias)
        qkv_bias = np.stack((qb, kb, vb), axis=0)

        # bugbug: shape is wrong
        weight = helper.make_tensor(
            weight_name,
            data_type=TensorProto.FLOAT,
            dims=qkv_weight.shape,
            vals=qkv_weight.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(weight, self.this_graph_name)

        bias = helper.make_tensor(
            bias_name,
            data_type=TensorProto.FLOAT,
            dims=qkv_bias.shape,
            vals=qkv_bias.flatten().tobytes(),
            raw=True,
        )
        self.model.add_initializer(bias, self.this_graph_name)

        self.add_fp32_value_info(weight.name)
        self.add_fp32_value_info(bias.name)

        return weight_name, bias_name
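
As a sanity check on the packed shapes, here is a small NumPy sketch (assuming square [hidden, hidden] per-projection weights, as in phi-2 where all three projections are 2560x2560); the 3-D result is presumably what the "bugbug: shape is wrong" comment above refers to:

import numpy as np

hidden = 2560  # phi-2 hidden size, used only for illustration

qw = np.zeros((hidden, hidden), dtype=np.float32)
kw = np.zeros((hidden, hidden), dtype=np.float32)
vw = np.zeros((hidden, hidden), dtype=np.float32)
qb = kb = vb = np.zeros(hidden, dtype=np.float32)

qkv_weight = np.stack((qw, kw, vw), axis=1)  # shape (2560, 3, 2560)
qkv_bias = np.stack((qb, kb, vb), axis=0)    # shape (3, 2560)
print(qkv_weight.shape, qkv_bias.shape)

# If the consuming op expects a 2-D GEMM layout, the same data could be viewed as
# qkv_weight.reshape(hidden, 3 * hidden) and qkv_bias.reshape(3 * hidden).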

    def fuse(
        self,
@@ -578,8 +500,16 @@ def fuse(
self.get_io_by_name(node, "rotary_emb.sin_cached"), ProcessRotCacheFunc()
)
else:
attn_qkv_weight = None
attn_qkv_bias = None
attn_qkv_weight, attn_qkv_bias = self.pack_qkv_gemm(
self.get_io_by_name(node, "self_attn.q_proj.weight"),
self.get_io_by_name(node, "self_attn.k_proj.weight"),
self.get_io_by_name(node, "self_attn.v_proj.weight"),
self.get_io_by_name(node, "self_attn.q_proj.bias"),
self.get_io_by_name(node, "self_attn.k_proj.bias"),
self.get_io_by_name(node, "self_attn.v_proj.bias"),
self.get_uname(layer_id, "attn_qkv_weight"),
self.get_uname(layer_id, "attn_qkv_bias")
)

        attn_out_weight = self.process_initializer(
            self.get_io_by_name(node, "self_attn.dense.weight"), ProcessGemmWFunc()
@@ -639,7 +569,15 @@ def fuse(
        elif attn_type == "GroupQueryAttention":
            subgraph_nodes.extend(
                self.gqa(
                    ["query_rot", "key_rot", "value", i_key_cache, i_value_cache, "seqlens_k", "total_sequence_length"],
                    [
                        "query_rot",
                        "key_rot",
                        "value",
                        i_key_cache,
                        i_value_cache,
                        "seqlens_k",
                        "total_sequence_length",
                    ],
                    ["attn_out", o_key_cache, o_value_cache],
                )
            )
@@ -652,7 +590,14 @@ def fuse(
                numpy_helper.from_array(np.array([1], dtype="int64"), name="one"), self.this_graph_name
            )
        else:
            print("bugbug")
            past_name = f"past_{layer_id}"
            present_name = f"present_{layer_id}"
            layer_known_edges_names.extend([past_name, present_name])
            subgraph_nodes.extend(
                self.attention(
                    ["ln_out", attn_qkv_weight, attn_qkv_bias, "attention_mask", past_name], ["attn_out", present_name]
                )
            )
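
If the new past/present edges need explicit value_info, a possible sketch, under the assumption (not stated in this diff) that the com.microsoft Attention op uses the stacked 5-D KV-cache layout [2, batch, num_heads, seq_len, head_size]:

from onnx import TensorProto, helper

num_heads, head_size = 32, 80  # phi-2: hidden_size 2560 = 32 heads * 80 (illustrative)

past_vi = helper.make_tensor_value_info(
    "past_0", TensorProto.FLOAT,
    [2, "batch_size", num_heads, "past_sequence_length", head_size],
)
present_vi = helper.make_tensor_value_info(
    "present_0", TensorProto.FLOAT,
    [2, "batch_size", num_heads, "total_sequence_length", head_size],
)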

        self.set_unique_name_and_add_nodes(subgraph_nodes, layer_id, layer_known_edges_names)

@@ -663,10 +608,6 @@ def fuse(
        self.prune_graph = True


def shape_of(vi):
    return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
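
A quick usage illustration for this helper (the value_info here is hypothetical, built only with the onnx package, mixing one symbolic and one static dimension):

from onnx import TensorProto, helper

vi = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["batch_size", 2560])
print(shape_of(vi))  # -> ('batch_size', 2560)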


class PhiOnnxModel(OnnxModel):
    def __init__(self, model: ModelProto, num_heads: int = 0, head_size: int = 0):
        super().__init__(model)
