
Llama 70B model fusion and sharding #18175

Merged 15 commits on Nov 2, 2023
1 change: 1 addition & 0 deletions onnxruntime/python/tools/symbolic_shape_infer.py
@@ -147,6 +147,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"GatherElements": self._infer_GatherElements,
"GatherND": self._infer_GatherND,
"Identity": self._pass_on_shape_and_type,
"AllReduce": self._pass_on_shape_and_type,
"If": self._infer_If,
"Loop": self._infer_Loop,
"MatMul": self._infer_MatMul,
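Registering AllReduce with the pass-through handler means symbolic shape inference copies the input's shape and element type to the output unchanged, which matches AllReduce semantics: the reduction happens across ranks, not across tensor axes. A minimal, self-contained sketch of that pass-through behavior (not the actual SymbolicShapeInference code):

```python
# Sketch only: a "pass on shape and type" handler copies the first input's
# element type and shape onto the node's first output, which is correct for
# AllReduce because the sum is taken across ranks, not tensor dimensions.
import onnx
from onnx import TensorProto, helper

def pass_on_shape_and_type(node: onnx.NodeProto, value_infos: dict) -> onnx.ValueInfoProto:
    vi = value_infos[node.input[0]]
    elem_type = vi.type.tensor_type.elem_type
    dims = [d.dim_param if d.dim_param else d.dim_value for d in vi.type.tensor_type.shape.dim]
    return helper.make_tensor_value_info(node.output[0], elem_type, dims)

# Hypothetical AllReduce node in a sharded graph: the output keeps the input shape.
value_infos = {"x": helper.make_tensor_value_info("x", TensorProto.FLOAT16, ["batch", "seq_len", 8192])}
allreduce = helper.make_node("AllReduce", inputs=["x"], outputs=["y"], domain="com.microsoft")
print(pass_on_shape_and_type(allreduce, value_infos))  # y: float16, ["batch", "seq_len", 8192]
```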
10 changes: 7 additions & 3 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1272,7 +1272,7 @@ def find_past_seq_len_usage(subg: GraphProto):
return tensor_names_to_rename, nodes_to_remove


def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0):
def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0, world_size: int = 1):
past_seq_len = past_seq_len_input
if past_seq_len not in model.get_graphs_input_names():
# Add model input for past sequence length
@@ -1282,6 +1282,10 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
# Replace MultiHeadAttention with GroupQueryAttention
for node in model.model.graph.node:
if node.op_type == "MultiHeadAttention":
num_heads_mha = 0
for att in node.attribute:
if att.name == "num_heads":
num_heads_mha = att.i
gqa_node = onnx.helper.make_node(
"GroupQueryAttention",
inputs=[
Expand All @@ -1295,8 +1299,8 @@ def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads
outputs=node.output,
name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
domain="com.microsoft",
num_heads=node.attribute[0].i,
kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads,
num_heads=num_heads_mha // world_size,
kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
is_past_bsnh=0,
)
model.model.graph.node.remove(node)
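The new world_size argument reads num_heads from each MultiHeadAttention node and divides both num_heads and kv_num_heads by world_size, so each tensor-parallel rank's GroupQueryAttention node carries its per-rank head count. A hedged usage sketch; the file names, past-sequence-length input name, rank count, and module path below are assumptions for illustration, not taken from this PR:

```python
# Hedged usage sketch: convert one tensor-parallel rank's Llama-70B ONNX model so
# each MultiHeadAttention becomes a GroupQueryAttention with per-rank head counts.
import onnx
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa  # module path assumed
from onnxruntime.transformers.onnx_model import OnnxModel

world_size = 8  # number of tensor-parallel ranks the model is sharded across (illustrative)
model = OnnxModel(onnx.load("llama70b_rank0.onnx"))  # file name illustrative

# The graph is modified in place: each MHA node's num_heads attribute is read and
# divided by world_size before the replacement GroupQueryAttention node is created.
replace_mha_with_gqa(model, past_seq_len_input="past_sequence_length", world_size=world_size)

model.save_model_to_file("llama70b_rank0_gqa.onnx")
```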
5 changes: 5 additions & 0 deletions onnxruntime/python/tools/transformers/fusion_base.py
@@ -130,3 +130,8 @@ def add_nodes_to_remove(self, nodes: List[NodeProto]):
for node in nodes:
if node not in self.nodes_to_remove:
self.nodes_to_remove.append(node)

def add_nodes_to_remove_with_nodes_to_keep(self, nodes: List[NodeProto], nodes_to_keep: List[NodeProto]):
for node in nodes:
if node not in self.nodes_to_remove and node not in nodes_to_keep:
self.nodes_to_remove.append(node)
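The new helper mirrors add_nodes_to_remove but skips any node listed in nodes_to_keep, which is useful when a matched fusion pattern shares nodes with another consumer. A self-contained toy illustration of the behavior; FusionStub only mirrors the list-management logic shown in this diff and is not the real Fusion base class:

```python
from onnx import helper

class FusionStub:
    def __init__(self):
        self.nodes_to_remove = []

    def add_nodes_to_remove_with_nodes_to_keep(self, nodes, nodes_to_keep):
        # Queue nodes for removal unless they are explicitly kept.
        for node in nodes:
            if node not in self.nodes_to_remove and node not in nodes_to_keep:
                self.nodes_to_remove.append(node)

matmul = helper.make_node("MatMul", ["x", "w"], ["y"], name="MatMul_0")
allreduce = helper.make_node("AllReduce", ["y"], ["z"], name="AllReduce_0", domain="com.microsoft")

fusion = FusionStub()
# AllReduce is still needed by another consumer, so only MatMul is queued for removal.
fusion.add_nodes_to_remove_with_nodes_to_keep([matmul, allreduce], nodes_to_keep=[allreduce])
print([n.name for n in fusion.nodes_to_remove])  # ['MatMul_0']
```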