diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index fe56f84f0a886..814aa1fb3c8f0 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -110,6 +110,11 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
   parameters.do_rotary = do_rotary_;
   parameters.rotary_interleaved = rotary_interleaved_;
 
+  if (do_rotary_ && (cos_cache == nullptr || sin_cache == nullptr)) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1");
+  }
+
   TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size);
   output_shape[1] = static_cast<int64_t>(sequence_length);
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index a2cdd17e19fa5..894e11275056e 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1273,7 +1273,7 @@ def find_past_seq_len_usage(subg: GraphProto):
 
 
 def replace_mha_with_gqa(
-    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = 0
+    model: OnnxModel, attn_mask: str, kv_num_heads: int = 0, world_size: int = 1, window_size: int = -1
 ):
     # Insert attention_mask subgraph to calculate shared inputs for all GroupQueryAttention nodes
     #
@@ -1339,31 +1339,163 @@ def replace_mha_with_gqa(
     )
 
     # Replace MultiHeadAttention with GroupQueryAttention
+    #
+    # When replacing, fuse the following subgraph:
+    #
+    #               root_input
+    #              /    |     \
+    #       MatMul   MatMul   MatMul
+    #         |        |        |
+    #        Add      Add      Add      (optional Adds)
+    #         |        |        |
+    #      RotEmb   RotEmb      |
+    #          \       |       /
+    #         MultiHeadAttention
+    #
+    # to this new subgraph:
+    #
+    #               root_input
+    #                   |
+    #             PackedMatMul  (if possible)
+    #                   |
+    #              PackedAdd    (if possible)
+    #                   |
+    #         GroupQueryAttention
+    #
     mha_nodes = list(filter(lambda node: node.op_type == "MultiHeadAttention", model.model.graph.node))
-    for node in mha_nodes:
-        num_heads_mha = 0
+    for idx, node in enumerate(mha_nodes):
+        # Detect Q path to MHA
+        q_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [0, 0, 0])
+        q_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [0, 0])
+
+        q_rotary, q_add, q_matmul = None, None, None
+        if q_path_1 is not None:
+            q_rotary, q_add, q_matmul = q_path_1
+        elif q_path_2 is not None:
+            q_rotary, q_matmul = q_path_2
+
+        # Detect K path to MHA
+        k_path_1 = model.match_parent_path(node, ["RotaryEmbedding", "Add", "MatMul"], [1, 0, 0])
+        k_path_2 = model.match_parent_path(node, ["RotaryEmbedding", "MatMul"], [1, 0])
+
+        k_rotary, k_add, k_matmul = None, None, None
+        if k_path_1 is not None:
+            k_rotary, k_add, k_matmul = k_path_1
+        elif k_path_2 is not None:
+            k_rotary, k_matmul = k_path_2
+
+        # Detect V path to MHA
+        v_path_1 = model.match_parent_path(node, ["Add", "MatMul"], [2, 0])
+        v_path_2 = model.match_parent_path(node, ["MatMul"], [2])
+
+        v_add, v_matmul = None, None
+        if v_path_1 is not None:
+            v_add, v_matmul = v_path_1
+        elif v_path_2 is not None:
+            v_matmul = v_path_2[0]
+
+        # Get `interleaved` attribute from RotaryEmbedding
+        interleaved = 0
+        if q_rotary is not None and k_rotary is not None:
+            for att in q_rotary.attribute:
+                if att.name == "interleaved":
+                    interleaved = att.i
+
+        # Get `num_heads` attribute from MHA
+        num_heads = 0
         for att in node.attribute:
             if att.name == "num_heads":
-                num_heads_mha = att.i
+                num_heads = att.i
 
+        # Check if root_input to Q/K/V paths is the same
+        root_input_is_same = q_matmul.input[0] == k_matmul.input[0] and k_matmul.input[0] == v_matmul.input[0]
+
+        # Check if Q/K/V paths all have bias or all don't have bias
+        all_paths_have_bias = q_add is not None and k_add is not None and v_add is not None
+        all_paths_have_no_bias = q_add is None and k_add is None and v_add is None
+
+        # Make PackedMatMul node if possible
+        q_input_to_attention, k_input_to_attention, v_input_to_attention = "", "", ""
+        if root_input_is_same and (all_paths_have_bias or all_paths_have_no_bias):
+            qw = NumpyHelper.to_array(model.get_initializer(q_matmul.input[1]))
+            kw = NumpyHelper.to_array(model.get_initializer(k_matmul.input[1]))
+            vw = NumpyHelper.to_array(model.get_initializer(v_matmul.input[1]))
+
+            dim = qw.shape[-1]
+            qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)
+            qkv_weight = onnx.numpy_helper.from_array(qkv_weight, name=f"QKV_Weight_{idx}")
+            model.add_initializer(qkv_weight)
+
+            packed_matmul_node = onnx.helper.make_node(
+                "MatMul",
+                inputs=[q_matmul.input[0], qkv_weight.name],
+                outputs=[f"{qkv_weight.name}_output"],
+                name=model.create_node_name("MatMul"),
+            )
+            model.model.graph.node.extend([packed_matmul_node])
+            model.model.graph.node.remove(q_matmul)
+            model.model.graph.node.remove(k_matmul)
+            model.model.graph.node.remove(v_matmul)
+            q_input_to_attention = packed_matmul_node.output[0]
+
+            # Make PackedAdd node if possible
+            if all_paths_have_bias:
+                qb = NumpyHelper.to_array(model.get_initializer(q_add.input[1]))
+                kb = NumpyHelper.to_array(model.get_initializer(k_add.input[1]))
+                vb = NumpyHelper.to_array(model.get_initializer(v_add.input[1]))
+
+                dim = qb.shape[-1]
+                qkv_bias = np.stack((qb, kb, vb), axis=0).reshape(3 * dim)
+                qkv_bias = onnx.numpy_helper.from_array(qkv_bias, name=f"QKV_Bias_{idx}")
+                model.add_initializer(qkv_bias)
+                packed_add_node = onnx.helper.make_node(
+                    "Add",
+                    inputs=[packed_matmul_node.output[0], qkv_bias.name],
+                    outputs=[f"{qkv_bias.name}_output"],
+                )
+                model.model.graph.node.extend([packed_add_node])
+                model.model.graph.node.remove(q_add)
+                model.model.graph.node.remove(k_add)
+                model.model.graph.node.remove(v_add)
+                q_input_to_attention = packed_add_node.output[0]
+
+        else:
+            q_input_to_attention = q_matmul.output[0]
+            k_input_to_attention = k_matmul.output[0]
+            v_input_to_attention = v_matmul.output[0]
+
+        # Make GQA node
         gqa_node = onnx.helper.make_node(
             "GroupQueryAttention",
             inputs=[
-                node.input[0],  # query
-                node.input[1],  # key
-                node.input[2],  # value
+                q_input_to_attention,  # query
+                k_input_to_attention,  # key
+                v_input_to_attention,  # value
                 node.input[6],  # past_key
                 node.input[7],  # past_value
-                "seqlens_k",  # seqlens_k (for attention_mask)
-                "total_seq_len",  # total_seq_len (for attention_mask)
+                seqlen_k_cast_node.output[0],  # seqlens_k (for attention mask)
+                total_seqlen_cast_node.output[0],  # total_seq_len (for attention mask)
+                q_rotary.input[2] if q_rotary is not None else "",  # cos_cache (for rotary embeddings)
+                q_rotary.input[3] if q_rotary is not None else "",  # sin_cache (for rotary embeddings)
             ],
             outputs=node.output,
             name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"),
             domain="com.microsoft",
-            num_heads=num_heads_mha // world_size,
-            kv_num_heads=num_heads_mha // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            num_heads=num_heads // world_size,
+            kv_num_heads=num_heads // world_size if kv_num_heads == 0 else kv_num_heads // world_size,
+            local_window_size=window_size,
+            do_rotary=int(q_rotary is not None and k_rotary is not None),
+            rotary_interleaved=interleaved,
        )
         model.model.graph.node.remove(node)
         model.model.graph.node.extend([gqa_node])
 
+        if q_rotary is not None:
+            model.model.graph.node.remove(q_rotary)
+        if k_rotary is not None:
+            model.model.graph.node.remove(k_rotary)
+
     return model
diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
index a329b73259dda..18202f4b81c0f 100644
--- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
+++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py
@@ -222,7 +222,8 @@ def get_msft_sample_inputs(
 # Create past_key_values
 # Each is of shape (batch_size, num_heads, past_sequence_length, head_size)
 def get_past_kv_inputs(config: AutoConfig, batch_size: int, past_seq_len: int, use_fp16: bool, world_size: int = 1):
-    num_heads, head_size = config.num_key_value_heads // world_size, config.hidden_size // config.num_attention_heads
+    num_heads = config.num_key_value_heads // world_size
+    head_size = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
     torch_dtype = torch.float16 if use_fp16 else torch.float32
     past_kv = [
         (
@@ -286,7 +287,14 @@ def add_io_bindings(
 ):
     io_binding = model.io_binding()
 
+    model_inputs = set(map(lambda i: i.name, model.get_inputs()))
     for k, v in ort_inputs.items():
+        # Use this check to handle scenarios such as INT4 CUDA and FP16 CUDA models with
+        # GQA + RotaryEmbedding fusion where `position_ids` is removed as an ONNX model input
+        # but `position_ids` is used as a PyTorch model input
+        if k not in model_inputs:
+            continue
+
         # Bind OrtValue inputs to device
         if use_gqa and ("cache" in k or "past_key_values" in k):
             if k not in kv_cache_ortvalues:
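
Hedged usage sketch (illustrative only, not part of the patch above). It shows how the updated
replace_mha_with_gqa entry point might be exercised on an exported decoder model; the file names,
the attention-mask input name, and the kv_num_heads value are placeholders, and the
onnxruntime.transformers.* import paths assume the packaged layout of these tools.

    import onnx
    from onnxruntime.transformers.convert_generation import replace_mha_with_gqa
    from onnxruntime.transformers.onnx_model import OnnxModel

    # Placeholder path to an exported decoder model whose attention is expressed as
    # MatMul -> (Add) -> RotaryEmbedding -> MultiHeadAttention subgraphs.
    onnx_model = OnnxModel(onnx.load("decoder_model.onnx"))

    # window_size now defaults to -1 (no sliding window); pass a positive value only for
    # models that use local/sliding-window attention. kv_num_heads=8 is a placeholder.
    onnx_model = replace_mha_with_gqa(onnx_model, attn_mask="attention_mask", kv_num_heads=8, world_size=1)

    onnx_model.save_model_to_file("decoder_model_gqa.onnx", use_external_data_format=True)

After the rewrite, each GroupQueryAttention node reads seqlens_k / total_seq_len from the inserted
attention-mask subgraph and cos_cache / sin_cache from the fused RotaryEmbedding inputs, which is
why the kernel change above rejects do_rotary = 1 when the caches are missing.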
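
Side note on the PackedMatMul construction (a minimal NumPy sketch under the assumption of square
projection weights and a made-up hidden size; it is not taken from the patch). It checks that
stacking the Q/K/V weights on a new middle axis and flattening, as done in replace_mha_with_gqa,
equals concatenating them along the output dimension, so one MatMul yields Q, K, and V side by side.

    import numpy as np

    dim = 8  # made-up hidden size for illustration
    rng = np.random.default_rng(0)
    qw, kw, vw = (rng.standard_normal((dim, dim)).astype(np.float32) for _ in range(3))

    # Same packing as in the fusion: (dim, 3, dim) -> (dim, 3 * dim).
    qkv_weight = np.stack((qw, kw, vw), axis=1).reshape(dim, 3 * dim)

    # Row i is [qw[i, :], kw[i, :], vw[i, :]], i.e. the matrices concatenated column-wise.
    assert np.array_equal(qkv_weight, np.concatenate((qw, kw, vw), axis=1))

    # A single packed projection therefore produces [Q | K | V] in one output tensor.
    x = rng.standard_normal((2, dim)).astype(np.float32)
    packed = x @ qkv_weight
    np.testing.assert_allclose(packed[:, :dim], x @ qw, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(packed[:, dim : 2 * dim], x @ kw, rtol=1e-5, atol=1e-5)
    np.testing.assert_allclose(packed[:, 2 * dim :], x @ vw, rtol=1e-5, atol=1e-5)

The bias packing follows the same idea in one dimension: stacking on axis 0 and flattening is just
concatenation of the three bias vectors.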