Commit
update GQA message (#396)
yufenglee authored May 3, 2024
1 parent 88d46dd commit b272ba4
Showing 1 changed file with 1 addition and 1 deletion.
src/python/py/models/builder.py (1 addition, 1 deletion)
@@ -171,7 +171,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (enable_GQA_on_CPU and self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT):
             # Change model settings for GroupQueryAttention
             self.attention_attrs["op_type"] = "GroupQueryAttention"
-            print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 and FP16 on the CUDA and DML execution providers.")
+            print("GroupQueryAttention (GQA) is used in this model.")
 
             # DML doesn't support packed Q/K/V for GQA yet
             self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads
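
The one-line change drops the stale part of the message: the condition in the context lines already enables GQA on CPU with FP32 (when enable_GQA_on_CPU is set), not only FP16 on the CUDA and DML execution providers, so the old print statement no longer described all the cases. A minimal sketch of that selection logic, extracted from the hunk above for illustration (TensorProto comes from the onnx package; this is not the full builder code):

from onnx import TensorProto

def uses_gqa(ep: str, io_dtype: int, enable_GQA_on_CPU: bool) -> bool:
    # GQA is chosen for FP16 models on the CUDA/DML execution providers,
    # or for FP32 models on CPU when the caller opts in.
    return (ep in {"cuda", "dml"} and io_dtype == TensorProto.FLOAT16) or \
           (enable_GQA_on_CPU and ep == "cpu" and io_dtype == TensorProto.FLOAT)

assert uses_gqa("cuda", TensorProto.FLOAT16, False)   # FP16 on CUDA -> GQA
assert uses_gqa("cpu", TensorProto.FLOAT, True)       # FP32 on CPU, opted in -> GQA
assert not uses_gqa("cpu", TensorProto.FLOAT, False)  # FP32 on CPU, no opt-in -> no GQA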
