Skip to content

Commit

Permalink
Fix build break with cuda 12.2
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee committed Oct 13, 2023
1 parent a441a71 commit 8f3a3ae
Showing 1 changed file with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,30 @@ namespace flash {

template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, bool Is_causal>
Expand Down

0 comments on commit 8f3a3ae

Please sign in to comment.