enable GQA on CPU #270

Merged · 7 commits · May 1, 2024
2 changes: 1 addition & 1 deletion src/models/kv_cache.cpp
@@ -117,7 +117,7 @@ KV_Cache::KV_Cache(const Model& model, State& state)
: model_{model},
state_{state},
layer_count_{model_.config_->model.decoder.num_hidden_layers},
- past_present_share_buffer_{state_.params_->search.past_present_share_buffer && state_.params_->search.num_beams == 1 && (model_.device_type_ == DeviceType::CUDA || model_.device_type_ == DeviceType::DML)},
+ past_present_share_buffer_{state_.params_->search.past_present_share_buffer && state_.params_->search.num_beams == 1},
shape_{state_.params_->BatchBeamSize(), model.config_->model.decoder.num_key_value_heads, 0, model.config_->model.decoder.head_size} {
if (g_log.enabled && g_log.warning && past_present_share_buffer_ != state_.params_->search.past_present_share_buffer)
Log("warning", "past_present_share_buffer search option set to true, but has been disabled due to the current configuration. See https://aka.ms/generate_config for details");
6 changes: 4 additions & 2 deletions src/python/py/models/builder.py
@@ -167,7 +167,8 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"use_rotemb_in_attn": False, # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op)
"use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
}
if self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16:
enable_GQA_on_CPU = True if "enable_GQA_on_CPU" in extra_options and extra_options["enable_GQA_on_CPU"] == "1" else False
if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (enable_GQA_on_CPU and self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT):
# Change model settings for GroupQueryAttention
self.attention_attrs["op_type"] = "GroupQueryAttention"
print("GroupQueryAttention (GQA) is used in this model. GQA is currently supported only for INT4 and FP16 on the CUDA and DML execution providers.")
@@ -176,7 +177,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads

# GQA + Rot.Emb. does not require `position ids` as input
if self.ep == "cuda":
if self.ep in {"cuda", "cpu"}:
self.attention_attrs["use_rotemb_in_attn"] = True
self.input_names.remove("position_ids")
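Setting use_rotemb_in_attn means the rotary embedding is applied inside the GroupQueryAttention op itself, so the exported graph needs neither a separate RotaryEmbedding node nor a position_ids input. A rough sketch of what such a fused node looks like follows; the input and attribute names reflect my reading of the ONNX Runtime com.microsoft contrib op, and the head counts are placeholders, so treat the details as assumptions.

```python
from onnx import helper

# Rough sketch of a GroupQueryAttention node with rotary embedding fused in.
# Input/attribute names are assumptions based on the ONNX Runtime contrib op;
# the builder above is the source of truth.
gqa_node = helper.make_node(
    "GroupQueryAttention",
    inputs=[
        "query", "key", "value",
        "past_key", "past_value",
        "seqlens_k", "total_sequence_length",
        "cos_cache", "sin_cache",  # rotary tables stand in for position_ids
    ],
    outputs=["attn_output", "present_key", "present_value"],
    domain="com.microsoft",
    num_heads=32,        # placeholder head counts for illustration
    kv_num_heads=8,
    do_rotary=1,         # rotary embedding computed inside the op
)
```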

@@ -1979,6 +1980,7 @@ def get_args():
enable_cuda_graph = 1 : The model can use CUDA graph capture for the CUDA execution provider. If enabled, all nodes must be placed on the CUDA EP
for the CUDA graph to be used correctly. It is not guaranteed that the CUDA graph will be enabled, as that depends on the model
and the graph structure.
+ enable_GQA_on_CPU = 1 : Enable GroupQueryAttention (GQA) on CPU.
"""),
)
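Since the builder checks extra_options["enable_GQA_on_CPU"] == "1", the option value is handled as a string, like the other extra_options entries. A hypothetical sketch of how key=value pairs could be folded into the dict the builder consumes is shown below; the actual parsing lives elsewhere in builder.py and may differ.

```python
# Hypothetical helper; builder.py's real extra-option parsing may differ.
def parse_extra_options(kv_pairs: list[str]) -> dict:
    extra_options = {}
    for pair in kv_pairs:
        key, sep, value = pair.partition("=")
        if sep:  # keep only well-formed key=value entries
            extra_options[key] = value
    return extra_options

# e.g. parse_extra_options(["enable_GQA_on_CPU=1"]) -> {"enable_GQA_on_CPU": "1"}
```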
