Update replacing MultiHeadAttention with GroupQueryAttention (#19882)
### Description
This PR updates the replacement of MultiHeadAttention (MHA) with
GroupQueryAttention (GQA). It is related to the changes in [this
PR](#18906).

### Motivation and Context
The updated replacement of MHA with GQA includes the following fusion changes (a short numpy sketch of the packed-weight fusion follows this list).
- Apply sliding window within GQA
- Fuse the rotary embeddings within GQA
- Fuse the 3 MatMuls into 1 packed MatMul if possible
- Fuse the 3 Adds into 1 packed Add if possible
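
The packed MatMul/Add fusion relies on the identity that multiplying the shared root input by the column-wise concatenation of the Q/K/V weights yields the concatenation of the three projections. Below is a minimal numpy sketch of that identity; the sizes and array names are illustrative and not taken from this PR:

```python
import numpy as np

# Illustrative sizes only; real models use hidden_size-sized projection weights.
hidden = 8
x = np.random.rand(2, hidden).astype(np.float32)  # shared root_input activations
qw, kw, vw = (np.random.rand(hidden, hidden).astype(np.float32) for _ in range(3))

# Packing used by the fusion: stack, then flatten to (hidden, 3 * hidden),
# which is the column-wise concatenation of the three weights.
qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(hidden, 3 * hidden)

packed = x @ qkv_weight  # one MatMul instead of three
q, k, v = np.split(packed, 3, axis=-1)
assert np.allclose(q, x @ qw) and np.allclose(k, x @ kw) and np.allclose(v, x @ vw)
```

The packed Add fusion works the same way with a concatenated 3 * hidden bias, and in the fused graph the packed output feeds GroupQueryAttention through its query input while the key and value inputs are left empty.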
kunal-vaishnavi authored Mar 13, 2024
1 parent 8eb49c5 commit 4ac98d6
Showing 3 changed files with 157 additions and 12 deletions.
5 changes: 5 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -110,6 +110,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
  parameters.do_rotary = do_rotary_;
  parameters.rotary_interleaved = rotary_interleaved_;

  if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                           "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1");
  }

  TensorShapeVector output_shape(3);
  output_shape[0] = static_cast<int64_t>(parameters.batch_size);
  output_shape[1] = static_cast<int64_t>(sequence_length);
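The check added in this hunk requires cos_cache and sin_cache whenever do_rotary = 1. For context, such caches are typically precomputed from the rotary inverse frequencies; the sketch below is a hedged example, and the (max_seq_len, rotary_dim // 2) layout and base of 10000 are assumptions to verify against the GroupQueryAttention spec and the model's config:

```python
import numpy as np

def make_rotary_caches(max_seq_len: int, rotary_dim: int, base: float = 10000.0):
    # Assumed cache layout: (max_seq_len, rotary_dim // 2) for both cos and sin.
    inv_freq = 1.0 / (base ** (np.arange(0, rotary_dim, 2, dtype=np.float32) / rotary_dim))
    angles = np.outer(np.arange(max_seq_len, dtype=np.float32), inv_freq)
    return np.cos(angles).astype(np.float32), np.sin(angles).astype(np.float32)

# Illustrative values; e.g. many LLaMA-style models use a head size of 128.
cos_cache, sin_cache = make_rotary_caches(max_seq_len=4096, rotary_dim=128)
```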
154 changes: 143 additions & 11 deletions onnxruntime/python/tools/transformers/convert_generation.py
@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):


def replace_mha_with_gqa(
-    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
+    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
):
    # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
    #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
    )

    # Replace MultiHeadAttention with GroupQueryAttention
    #
    # When replacing, fuse the following subgraph:
    #
    #               root_input
    #              /    |    \
    #         MatMul  MatMul  MatMul
    #           |       |       |
    #          Add     Add     Add      (optional Adds)
    #           |       |       |
    #        RotEmb  RotEmb     |
    #            \      |      /
    #          MultiHeadAttention
    #
    # to this new subgraph:
    #
    #               root_input
    #                   |
    #             PackedMatMul   (if possible)
    #                   |
    #              PackedAdd     (if possible)
    #                   |
    #          GroupQueryAttention
    #

    mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
-    for node in mha_nodes:
-        num_heads_mha = 0
+    for idx, node in enumerate(mha_nodes):
        # Detect Q path to MHA
        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])

        q_rotary, q_add, q_matmul = None, None, None
        if q_path_1 is not None:
            q_rotary, q_add, q_matmul = q_path_1
        elif q_path_2 is not None:
            q_rotary, q_matmul = q_path_2

        # Detect K path to MHA
        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])

        k_rotary, k_add, k_matmul = None, None, None
        if k_path_1 is not None:
            k_rotary, k_add, k_matmul = k_path_1
        elif k_path_2 is not None:
            k_rotary, k_matmul = k_path_2

        # Detect V path to MHA
        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])

        v_add, v_matmul = None, None
        if v_path_1 is not None:
            v_add, v_matmul = v_path_1
        elif v_path_2 is not None:
            v_matmul = v_path_2[0]

        # Get `interleaved` attribute from RotaryEmbedding
        interleaved = 0
        if q_rotary is not None and k_rotary is not None:
            for att in q_rotary.attribute:
                if att.name == "interleaved":
                    interleaved = att.i

        # Get `num_heads` attribute from MHA
        num_heads = 0
        for att in node.attribute:
            if att.name == "num_heads":
-                num_heads_mha = att.i
+                num_heads = att.i

        # Check if root_input to Q/K/V paths is the same
        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]

        # Check if Q/K/V paths all have bias or all don't have bias
        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None

        # Make PackedMatMul node if possible
        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))

            dim = qw.shape[-1]
            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
            model.add_initializer(qkv_weight)

            packed_matmul_node = onnx.helper.make_node(
                "MatMul",
                inputs=[q_matmul.input[0], qkv_weight.name],
                outputs=[f"{qkv_weight.name}_output"],
                name=model.create_node_name("MatMul"),
            )
            model.model.graph.node.extend([packed_matmul_node])
            model.model.graph.node.remove(q_matmul)
            model.model.graph.node.remove(k_matmul)
            model.model.graph.node.remove(v_matmul)
            q_input_to_attention = packed_matmul_node.output[0]

            # Make PackedAdd node if possible
            if all_paths_have_bias:
                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))

                dim = qb.shape[-1]
                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
                model.add_initializer(qkv_bias)
                packed_add_node = onnx.helper.make_node(
                    "Add",
                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
                    outputs=[f"{qkv_bias.name}_output"],
                )
                model.model.graph.node.extend([packed_add_node])
                model.model.graph.node.remove(q_add)
                model.model.graph.node.remove(k_add)
                model.model.graph.node.remove(v_add)
                q_input_to_attention = packed_add_node.output[0]

        else:
            q_input_to_attention = q_matmul.output[0]
            k_input_to_attention = k_matmul.output[0]
            v_input_to_attention = v_matmul.output[0]

        # Make GQA node
        gqa_node = onnx.helper.make_node(
            "GroupQueryAttention",
            inputs=[
-                node.input[0],  # query
-                node.input[1],  # key
-                node.input[2],  # value
+                q_input_to_attention,  # query
+                k_input_to_attention,  # key
+                v_input_to_attention,  # value
                node.input[6],  # past_key
                node.input[7],  # past_value
-                "seqlens_k",  # seqlens_k (for attention_mask)
-                "total_seq_len",  # total_seq_len (for attention_mask)
+                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
+                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
+                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
+                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
            ],
            outputs=node.output,
            name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
            domain="com.microsoft",
-            num_heads=num_heads_mha // world_size,
-            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            num_heads=num_heads // world_size,
+            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            local_window_size=window_size,
+            do_rotary=int(q_rotary is not None and k_rotary is not None),
+            rotary_interleaved=interleaved,
        )
        model.model.graph.node.remove(node)
        model.model.graph.node.extend([gqa_node])

        if q_rotary is not None:
            model.model.graph.node.remove(q_rotary)
        if k_rotary is not None:
            model.model.graph.node.remove(k_rotary)

    return model


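Taken together, a typical call into the updated fusion might look like the sketch below. The model path, mask input name, and kv_num_heads value are hypothetical, and the import assumes the packaged module path onnxruntime.transformers.convert_generation:

```python
import onnx
from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
from onnxruntime.transformers.onnx_model import OnnxModel

# Hypothetical decoder model that currently uses MultiHeadAttention (+ RotaryEmbedding)
model = OnnxModel(onnx.load("decoder_model_fp16.onnx"))

# attn_mask names the model's attention mask input; kv_num_heads=8 is an example value
model = replace_mha_with_gqa(model, attn_mask="attention_mask", kv_num_heads=8, world_size=1, window_size=-1)

model.prune_graph()
model.save_model_to_file("decoder_model_fp16_gqa.onnx", use_external_data_format=True)
```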
@@ -222,7 +222,8 @@ def get_msft_sample_inputs(
# Create past_key_values
# Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
-    num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
+    num_heads = config.num_key_value_heads // world_size
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
    torch_dtype = torch.float16 if use_fp16 else torch.float32
    past_kv = [
        (
@@ -286,7 +287,14 @@ def add_io_bindings(
):
    io_binding = model.io_binding()

    model_inputs = set(map(lambda i: i.name, model.get_inputs()))
    for k, v in ort_inputs.items():
        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
        # but `position_ids` is used as a PyTorch model input
        if k not in model_inputs:
            continue

        # Bind OrtValue inputs to device
        if use_gqa and ("cache" in k or "past_key_values" in k):
            if k not in kv_cache_ortvalues:
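The filter added in this hunk skips PyTorch-side inputs (such as position_ids after the GQA + RotaryEmbedding fusion) that no longer exist on the ONNX model. Below is a standalone sketch of the same pattern using the public onnxruntime API; the model path and input values are placeholders:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("decoder_model_fp16_gqa.onnx", providers=["CUDAExecutionProvider"])
io_binding = sess.io_binding()

ort_inputs = {
    "input_ids": np.ones((1, 8), dtype=np.int64),
    "position_ids": np.arange(8, dtype=np.int64).reshape(1, 8),  # may have been fused away
}

# Only bind inputs that the ONNX model actually declares
model_inputs = {i.name for i in sess.get_inputs()}
for name, value in ort_inputs.items():
    if name not in model_inputs:  # e.g. position_ids after GQA + rotary fusion
        continue
    io_binding.bind_ortvalue_input(name, ort.OrtValue.ortvalue_from_numpy(value, "cuda", 0))
```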
