From 6a585e8c9b2dbf667fed10c3eb3917a5e8dc9e17 Mon Sep 17 00:00:00 2001
From: aciddelgado
Date: Mon, 2 Oct 2023 16:20:51 -0700
Subject: [PATCH] fix namespace issue

---
 .../cuda/bert/group_query_attention_impl.cu  | 50 +++++++++----------
 1 file changed, 24 insertions(+), 26 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 62d1b659afffe..7f0a451eb6ae4 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -43,7 +43,6 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/attention_impl.h"
 
 using namespace onnxruntime::cuda;
-using namespace onnxruntime::contrib::attention_softmax_cuda;
 
 namespace onnxruntime {
 namespace contrib {
@@ -247,31 +246,31 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
     const dim3 grid(kv_sequence_length, batch_size, 1);
     const dim3 block(H, kv_num_heads, 1);
     ConcatKVInPlace<<< grid, block, 0, stream >>>(past_sequence_length,
-                                                  reinterpret_cast<const float2*>(data.past_key),
-                                                  reinterpret_cast<const float2*>(data.key),
-                                                  reinterpret_cast<float2*>(data.present_key),
-                                                  past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+                                                  reinterpret_cast<float2*>(data.present_key),
+                                                  reinterpret_cast<const float2*>(data.key),
+                                                  past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
     ConcatKVInPlace<<< grid, block, 0, stream >>>(past_sequence_length,
-                                                  reinterpret_cast<const float2*>(data.past_value),
-                                                  reinterpret_cast<const float2*>(data.value),
-                                                  reinterpret_cast<float2*>(data.present_value),
-                                                  past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
-  } else {
-    const dim3 grid(present_sequence_length, batch_size, 1);
-    const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
-    ConcatNewToPastKVLarge<<< grid, block, 0, stream >>>(kv_sequence_length,
-                                                         H,
-                                                         reinterpret_cast<const float2*>(data.past_key),
-                                                         reinterpret_cast<const float2*>(data.key),
-                                                         reinterpret_cast<float2*>(data.present_key),
-                                                         past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
-    ConcatNewToPastKVLarge<<< grid, block, 0, stream >>>(kv_sequence_length,
-                                                         H,
-                                                         reinterpret_cast<const float2*>(data.past_value),
-                                                         reinterpret_cast<const float2*>(data.value),
-                                                         reinterpret_cast<float2*>(data.present_value),
-                                                         past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+                                                  reinterpret_cast<float2*>(data.present_value),
+                                                  reinterpret_cast<const float2*>(data.value),
+                                                  past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
   }
+  // TODO(aciddelgado): big gulp version
+  // } else {
+  //   const dim3 grid(present_sequence_length, batch_size, 1);
+  //   const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
+  //   ConcatNewToPastKVLarge<<< grid, block, 0, stream >>>(kv_sequence_length,
+  //                                                        H,
+  //                                                        reinterpret_cast<const float2*>(data.past_key),
+  //                                                        reinterpret_cast<const float2*>(data.key),
+  //                                                        reinterpret_cast<float2*>(data.present_key),
+  //                                                        past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+  //   ConcatNewToPastKVLarge<<< grid, block, 0, stream >>>(kv_sequence_length,
+  //                                                        H,
+  //                                                        reinterpret_cast<const float2*>(data.past_value),
+  //                                                        reinterpret_cast<const float2*>(data.value),
+  //                                                        reinterpret_cast<float2*>(data.present_value),
+  //                                                        past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
+  // }
 
   return CUDA_CALL(cudaGetLastError());
 }
@@ -374,10 +373,9 @@ Status FlashAttention(
   const int num_heads = parameters.num_heads;
   const int kv_num_heads = parameters.kv_num_heads;
   const int head_size = parameters.head_size;
-  AttentionQkvFormat qkv_format = parameters.qkv_format;
   AttentionQkvFormat past_kv_format = parameters.past_kv_format;
 
-  assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
+  assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
   assert(parameters.num_heads % parameters.kv_num_heads == 0);
 
   void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));