diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index e0be6b828f85d..784335a124c75 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -12,18 +12,30 @@ namespace flash {
 
 template <...>
 __global__ void flash_fwd_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   flash::compute_attn(params);
+#else
+  (void)params;
+#endif
 }
 
 template <...>
 __global__ void flash_fwd_splitkv_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   flash::compute_attn_splitkv(params);
+#else
+  (void)params;
+#endif
 }
 
 template <...>
 __global__ void flash_fwd_splitkv_combine_kernel(Flash_fwd_params params) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   static_assert(Log_max_splits >= 1);
   flash::combine_attn_seqk_parallel(params);
+#else
+  (void)params;
+#endif
 }
 
 template <...>
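
A note on the guard pattern introduced above: `nvcc` compiles every `__global__` function once per `-gencode` target, and the flash-attention kernel bodies rely on features that only exist on compute capability 8.0 (Ampere) and newer, such as `cp.async`. Wrapping each body in `#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800` makes pre-Ampere device passes emit an empty stub instead of failing to compile, and the `(void)params;` cast silences the unused-parameter warning in the stub branch. Below is a minimal sketch of the same pattern on a hypothetical toy kernel (not ORT code; the kernel and file names are made up, and `__reduce_add_sync` is simply a convenient intrinsic that exists only on sm_80+):

```cuda
// toy_guard.cu -- illustrative only, not part of ONNX Runtime.
#include <cuda_runtime.h>

// Warp-sums 'in' into 'out' (one result per warp). Launch with blockDim.x
// a multiple of 32 so every lane of each warp participates in the reduce.
__global__ void toy_warp_sum_kernel(const unsigned* in, unsigned* out, int n) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  // Out-of-range lanes contribute 0 so the full-warp mask stays valid.
  unsigned v = (i < n) ? in[i] : 0u;
  // __reduce_add_sync is an sm_80+ warp-reduction intrinsic; without the
  // guard, a device pass for sm_75 would fail on this line.
  unsigned warp_sum = __reduce_add_sync(0xffffffffu, v);
  if (i < n && threadIdx.x % 32 == 0) {
    out[i / 32] = warp_sum;
  }
#else
  // Pre-sm_80 passes compile an empty stub; keep the parameters "used".
  (void)in;
  (void)out;
  (void)n;
#endif
}
```

With the guard in place, a fat-binary build such as `nvcc -gencode arch=compute_75,code=sm_75 -gencode arch=compute_80,code=sm_80 -c toy_guard.cu` succeeds, with the sm_75 cubin carrying only the stub. Host-side dispatch still has to avoid launching the kernel on pre-Ampere GPUs, which ONNX Runtime's attention code already does when deciding whether flash attention is usable on the current device.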