Skip to content

Commit

Permalink
Main Optimized
Browse files Browse the repository at this point in the history
  • Loading branch information
hariharans29 committed Sep 19, 2023
1 parent 03b56f7 commit 11f6a18
Showing 1 changed file with 8 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
q = add_vec(q, q_bias);
}


T* params_k_cache = reinterpret_cast<T*>(params.k_cache);

const float inv_sqrt_dh = params.scale;
Expand Down Expand Up @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

// The keys loaded from the key cache.
K_vec_k k_vec[K_VECS_PER_THREAD];
if (ti < tlength) {
if (has_beams) {
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;

if (has_beams) {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;

if (ti < tlength) {
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
}
}
} else {
} else {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;

if (ti < tlength) {
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
}
Expand Down

0 comments on commit 11f6a18

Please sign in to comment.