Skip to content

Commit

Permalink
add an option to control GQA on CPU
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee committed May 1, 2024
1 parent 70ab890 commit 7dc9de1
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/python/py/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"use_rotemb_in_attn": False, # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op)
"use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
}
if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT):
enable_GQA_on_CPU = True if "enable_GQA_on_CPU" in extra_options and extra_options["enable_GQA_on_CPU"] == "1" else False
if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (enable_GQA_on_CPU and self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT):
# Change model settings for GroupQueryAttention
self.attention_attrs["op_type"] = "GroupQueryAttention"
print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 and FP16 on the CUDA and DML execution providers.")
Expand All @@ -176,7 +177,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads

# GQA + Rot.Emb. does not require `position ids` as input
if self.ep == "cuda" or self.ep == "cpu":
if self.ep in {"cuda", "cpu"}:
self.attention_attrs["use_rotemb_in_attn"] = True
self.input_names.remove("position_ids")

Expand Down Expand Up @@ -1979,6 +1980,7 @@ def get_args():
enable_cuda_graph = 1 : The model can use CUDA graph capture for CUDA execution provider. If enabled, all nodes being placed on the CUDA EP
is the prerequisite for the CUDA graph to be used correctly. It is not guaranteed that cuda graph be enabled as it depends on the model
and the graph structure.
enable_GQA_on_CPU = 1 : Enable GroupQueryAttention (GQA) on CPU.
"""),
)

Expand Down

0 comments on commit 7dc9de1

Please sign in to comment.