diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index f8bb820743f87..0fd8790e0d29d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -96,7 +96,7 @@ struct GroupQueryAttentionParameters { int head_size; int kv_hidden_size; int kv_num_heads; - int num_splits; // number of splits for splitkv + int num_splits; // number of splits for splitkv bool is_unidirectional; // causal float scale; AttentionQkvFormat qkv_format; diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index a7fed3903b78f..89a27c4d2b0d3 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -195,7 +195,7 @@ std::tuple get_num_splits_and_buffer_sizes(int batch_size, int se int max_splits = 128; // split kv buffers int num_splits = num_splits_heuristic(batch_size, seqlen_q, seqlen_k, num_heads, head_size, - num_SMs, max_splits); + num_SMs, max_splits); if (num_splits > 1) { // softmax_lse_accum buffer int softmax_lse_accum_bytes = get_softmax_lse_accum_size(num_splits, batch_size, num_heads, seqlen_q);