Skip to content

Commit

Permalink
Fix build break with cuda 12.2 (#17922)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
nvcc 12.2 crashes while building
onnxruntime/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_*
for SM<8.0. nvcc 18.8 works though. It should be a bug in nvcc 12.2.

This PR excludes building flashattention for arch < 800.
  • Loading branch information
yufenglee authored Oct 13, 2023
1 parent 28c1944 commit 7551dd0
Showing 1 changed file with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,30 @@ namespace flash {

template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Return_softmax>
__global__ void flash_fwd_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Return_softmax>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Split, bool Append_KV>
__global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
flash::compute_attn_splitkv<Kernel_traits, Is_causal, Is_even_MN, Is_even_K, Split, Append_KV>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, int Log_max_splits, bool Is_even_K>
__global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(Log_max_splits >= 1);
flash::combine_attn_seqk_parallel<Kernel_traits, Log_max_splits, Is_even_K>(params);
#else
(void)params;
#endif
}

template <typename Kernel_traits, bool Is_causal>
Expand Down

0 comments on commit 7551dd0

Please sign in to comment.