From 8f3a3aedda1f4ffb70127ec1a725b52c8acd7a14 Mon Sep 17 00:00:00 2001
From: Yufeng Li
Date: Fri, 13 Oct 2023 05:53:20 +0000
Subject: [PATCH] Fix build break with cuda 12.2

---
 .../bert/flash_attention/flash_fwd_launch_template.h | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index e0be6b828f85d..784335a124c75 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -12,18 +12,30 @@ namespace flash {
 
 template
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   flash::compute_attn(params);
+#else
+  (void)params;
+#endif
 }
 
 template
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   flash::compute_attn_splitkv(params);
+#else
+  (void)params;
+#endif
 }
 
 template
 __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static_assert(Log_max_splits >= 1);
   flash::combine_attn_seqk_parallel(params);
+#else
+  (void)params;
+#endif
 }
 
 template