Commit 2dab7cb: fix

sfc-gh-goliaro committed Oct 22, 2024
1 parent 674eed7

Showing 4 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/ops/inc_multihead_self_attention.cc
@@ -105,7 +105,7 @@ Tensor FFModel::groupquery_self_attention(const Tensor input,
     bool add_zero_attn,
     DataType data_type,
     Initializer *kernel_initializer,
-    RotaryEmbeddingMeta rotary_embedding_meta,,
+    RotaryEmbeddingMeta rotary_embedding_meta,
     bool scaling_query,
     float scaling_factor,
     bool qk_prod_scaling,
9 changes: 8 additions & 1 deletion src/ops/kernels/inc_multihead_self_attention_kernels.cu
@@ -329,7 +329,7 @@ void apply_pos_encoding_to_tokens_in_batch(
     DT *output_ptr,
     cudaStream_t stream) {
   // apply rotary embedding if needed
-  if (!*m->apply_rotary_embedding) {
+  if (!m->rotary_embedding_meta->apply_rotary_embedding) {
     return;
   }
   int num_tokens = bc->num_active_tokens();
@@ -338,13 +338,20 @@
   }
   int parallelism = num_tokens * m->local_hidden_size;
   size_t q_array_size = m->qk_dim * num_tokens * m->num_q_heads;
+  bool llama3_rope = (m->rotary_embedding_meta->rope_type == "llama3");
   apply_pos_encoding_to_tokens_in_batch_kernel<<<GET_BLOCKS(parallelism),
                                                  min(CUDA_NUM_THREADS,
                                                      parallelism),
                                                  0,
                                                  stream>>>(
       output_ptr,
       m->token_infos,
+      m->rotary_embedding_meta->rope_theta,
+      llama3_rope,
+      m->rotary_embedding_meta->factor,
+      m->rotary_embedding_meta->low_freq_factor,
+      m->rotary_embedding_meta->high_freq_factor,
+      m->rotary_embedding_meta->original_max_position_embeddings,
       m->qk_dim,
       num_tokens,
       q_array_size,
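For context, the parameters threaded into the kernel here (rope_theta, llama3_rope, factor, low_freq_factor, high_freq_factor, original_max_position_embeddings) match the Llama 3 RoPE frequency-scaling scheme. Below is a minimal host-side sketch of that scaling, assuming the same semantics as the published Llama 3 rope_scaling recipe; the struct fields are taken from this diff's call site, while the default values and the helper function are illustrative, not code from this repository.

// Fields implied by this diff's call site; the defaults shown are Llama 3's
// published values and are an assumption, not taken from this repo.
struct RotaryEmbeddingMeta {
  bool apply_rotary_embedding = true;
  float rope_theta = 500000.0f;
  std::string rope_type = "llama3";
  float factor = 8.0f;
  float low_freq_factor = 1.0f;
  float high_freq_factor = 4.0f;
  int original_max_position_embeddings = 8192;
};

constexpr float kTwoPi = 6.283185307179586f;

// Rescale one rotary inverse frequency (inv_freq = rope_theta^(-2i/qk_dim))
// the way the Llama 3 rope_scaling recipe does: leave high-frequency dims
// untouched, divide low-frequency dims by `factor`, and blend in between.
float llama3_scaled_inv_freq(float inv_freq, RotaryEmbeddingMeta const &rm) {
  float const wavelen = kTwoPi / inv_freq;
  float const low_freq_wavelen =
      rm.original_max_position_embeddings / rm.low_freq_factor;
  float const high_freq_wavelen =
      rm.original_max_position_embeddings / rm.high_freq_factor;
  if (wavelen < high_freq_wavelen) {
    return inv_freq; // high-frequency band: unchanged
  }
  if (wavelen > low_freq_wavelen) {
    return inv_freq / rm.factor; // low-frequency band: fully rescaled
  }
  // Mid band: smooth linear interpolation between the two regimes.
  float const smooth =
      (rm.original_max_position_embeddings / wavelen - rm.low_freq_factor) /
      (rm.high_freq_factor - rm.low_freq_factor);
  return (1.0f - smooth) * inv_freq / rm.factor + smooth * inv_freq;
}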
2 changes: 1 addition & 1 deletion src/ops/spec_inc_multihead_self_attention.cu
@@ -314,7 +314,7 @@ void SpecIncMultiHeadSelfAttention::inference_kernel_wrapper(
     GenericTensorAccessorR const &bias) {
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
-  // bool use_bias = *m->qkv_bias || *m->final_bias;
+  bool use_bias = *m->qkv_bias || *m->final_bias;

   cudaEvent_t t_start, t_end;
   if (m->profiling) {
2 changes: 1 addition & 1 deletion src/ops/tree_inc_multihead_self_attention.cu
@@ -522,7 +522,7 @@ void TreeIncMultiHeadSelfAttention::inference_kernel_wrapper(
     GenericTensorAccessorR const &bias) {
   cudaStream_t stream;
   checkCUDA(get_legion_stream(&stream));
-  // bool use_bias = *m->qkv_bias || *m->final_bias;
+  bool use_bias = *m->qkv_bias || *m->final_bias;
   // int device;
   // checkCUDA(cudaGetDevice(&device));
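The last two hunks re-enable the same line in both inference wrappers: use_bias records whether the layer carries a QKV bias or an output-projection bias, so downstream bias handling can be gated on it. A minimal sketch of that gating follows, assuming qkv_bias and final_bias are per-layer boolean flags stored behind pointers in the op metadata (as the dereferences suggest); the surrounding dispatch code is hypothetical, not this repository's implementation.

#include <cstdio>

// Hypothetical metadata mirroring the two flags dereferenced above; in the
// real op these live on the attention meta object, not in this struct.
struct AttentionMeta {
  bool *qkv_bias;   // true if the fused QKV projection carries a bias
  bool *final_bias; // true if the output projection carries a bias
};

// Only touch the bias tensor when at least one bias is actually present.
void launch_attention(AttentionMeta const *m, float const *bias_ptr) {
  bool const use_bias = *m->qkv_bias || *m->final_bias;
  if (use_bias && bias_ptr == nullptr) {
    fprintf(stderr, "layer declares a bias but no bias tensor was bound\n");
    return;
  }
  // ... dispatch the attention kernels, forwarding bias_ptr only when
  //     use_bias is true ...
}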
