diff --git a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h index 763f654e28..4edc91f428 100644 --- a/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h +++ b/include/flexflow/ops/kernels/inc_multihead_self_attention_kernels.h @@ -14,6 +14,25 @@ namespace FlexFlow { namespace Kernels { namespace IncMultiHeadAttention { +template +__global__ void compute_attention_kernel_generation_kernel( + DT const *query, + DT const *key_cache, + DT const *value_cache, + DT *output_ptr, + float const scale, + int max_seq_length, + int per_head_size, + int hidden_size, + BatchConfig::PerRequestInfo *request_infos, + bool is_beam, + int max_beam_width); + template __global__ void apply_position_bias_qkprd(DT *input_ptr, int num_tokens,