diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
index 9bf2f581e2..26dcf12425 100644
--- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
+++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h
@@ -56,7 +56,8 @@ __global__ void apply_proj_bias_qkv(DT *input_ptr,
                                     int num_heads,
                                     int num_kv_heads,
                                     bool scaling_query,
-                                    float scaling_factor);
+                                    float scaling_factor,
+                                    int hidden_size);
 
 #if defined(FF_USE_CUDA) || defined(FF_USE_HIP_CUDA)
 template <typename DT>
diff --git a/src/ops/attention.cu b/src/ops/attention.cu
index 9b8b90da70..18fc810aed 100644
--- a/src/ops/attention.cu
+++ b/src/ops/attention.cu
@@ -206,7 +206,7 @@ MultiHeadAttentionMeta::MultiHeadAttentionMeta(FFHandler handler,
   checkCUDNN(cudnnCreateSeqDataDescriptor(&oDesc));
   // Currently do not support adding bias to key/value projection
   assert(!attn->add_bias_kv);
-  cudnnAttnQueryMap_t attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE;
+  unsigned attnMode = CUDNN_ATTN_QUERYMAP_ALL_TO_ONE;
   // Assume no beam search for now
   int maxBeamSize = 1;
   // printf("batchSize(%d) qSize(%d) kSize(%d) vSize(%d) qProjSize(%d)