From 7551dd039f67e1adc9fa6d2b2627c6c7f17a0a8d Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 13 Oct 2023 10:21:06 -0700 Subject: [PATCH] Fix build break with cuda 12.2 (#17922) ### Description nvcc 12.2 crashes while building onnxruntime/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_* for SM<8.0. nvcc 18.8 works though. It should be a bug in nvcc 12.2. This PR excludes building flashattention for arch < 800. --- .../bert/flash_attention/flash_fwd_launch_template.h | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h index e0be6b828f85d..784335a124c75 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h @@ -12,18 +12,30 @@ namespace flash { template __global__ void flash_fwd_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 flash::compute_attn(params); +#else + (void)params; +#endif } template __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 flash::compute_attn_splitkv(params); +#else + (void)params; +#endif } template __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 static_assert(Log_max_splits >= 1); flash::combine_attn_seqk_parallel(params); +#else + (void)params; +#endif } template