Skip to content

Commit

Permalink
fix namespace issue
Browse files Browse the repository at this point in the history
  • Loading branch information
aciddelgado committed Oct 2, 2023
1 parent c6dabc1 commit 6a585e8
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ limitations under the License.
#include "contrib_ops/cuda/bert/attention_impl.h"

using namespace onnxruntime::cuda;
using namespace onnxruntime::contrib::attention_softmax_cuda;

namespace onnxruntime {
namespace contrib {
Expand Down Expand Up @@ -247,31 +246,31 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
const dim3 grid(kv_sequence_length, batch_size, 1);
const dim3 block(H, kv_num_heads, 1);
ConcatKVInPlace<float2><<< grid, block, 0, stream >>>(past_sequence_length,
reinterpret_cast<const float2*>(data.past_key),
reinterpret_cast<const float2*>(data.key),
reinterpret_cast<float2*>(data.present_key),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
reinterpret_cast<float2*>(data.present_key),
reinterpret_cast<const float2*>(data.key),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatKVInPlace<float2><<< grid, block, 0, stream >>>(past_sequence_length,
reinterpret_cast<const float2*>(data.past_value),
reinterpret_cast<const float2*>(data.value),
reinterpret_cast<float2*>(data.present_value),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
} else {
const dim3 grid(present_sequence_length, batch_size, 1);
const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
ConcatNewToPastKVLarge<float2><<< grid, block, 0, stream >>>(kv_sequence_length,
H,
reinterpret_cast<const float2*>(data.past_key),
reinterpret_cast<const float2*>(data.key),
reinterpret_cast<float2*>(data.present_key),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
ConcatNewToPastKVLarge<float2><<< grid, block, 0, stream >>>(kv_sequence_length,
H,
reinterpret_cast<const float2*>(data.past_value),
reinterpret_cast<const float2*>(data.value),
reinterpret_cast<float2*>(data.present_value),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
reinterpret_cast<float2*>(data.present_value),
reinterpret_cast<const float2*>(data.value),
past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
}
// TODO(aciddelgado): big gulp version
// } else {
// const dim3 grid(present_sequence_length, batch_size, 1);
// const dim3 block(max_threads_per_block / kv_num_heads, kv_num_heads, 1);
// ConcatNewToPastKVLarge<float2><<< grid, block, 0, stream >>>(kv_sequence_length,
// H,
// reinterpret_cast<const float2*>(data.past_key),
// reinterpret_cast<const float2*>(data.key),
// reinterpret_cast<float2*>(data.present_key),
// past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
// ConcatNewToPastKVLarge<float2><<< grid, block, 0, stream >>>(kv_sequence_length,
// H,
// reinterpret_cast<const float2*>(data.past_value),
// reinterpret_cast<const float2*>(data.value),
// reinterpret_cast<float2*>(data.present_value),
// past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
// }
return CUDA_CALL(cudaGetLastError());
}

Expand Down Expand Up @@ -374,10 +373,9 @@ Status FlashAttention(
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat qkv_format = parameters.qkv_format;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;

assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(parameters.qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
assert(parameters.num_heads % parameters.kv_num_heads == 0);

void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
Expand Down

0 comments on commit 6a585e8

Please sign in to comment.