From 6b213eb7ebd2175af3129895988e1e5c95c44954 Mon Sep 17 00:00:00 2001
From: Yufeng Li
Date: Thu, 12 Oct 2023 21:05:44 +0000
Subject: [PATCH] fix format

---
 .../cuda/bert/flash_attention/flash_fwd_launch_template.h | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
index 2b8226c7e40c2..c7d200fdcb3a1 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_fwd_launch_template.h
@@ -3,7 +3,7 @@
  ******************************************************************************/
 #pragma once
 
-#include
+#include
 
 #include "contrib_ops/cuda/bert/flash_attention/static_switch.h"
 #include "contrib_ops/cuda/bert/flash_attention/flash.h"
@@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {
 
 template
 void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
-  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
   bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
   constexpr int kBlockM = 64;  // Fixed for all head dimensions
   if (!is_sm8x) {  // A100, H100
@@ -125,11 +125,11 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream)
     constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
     run_flash_splitkv_fwd>(params, stream);
   }
-  #else
+#else
   (void)params;
   (void)stream;
   throw std::runtime_error("FlashAttention is only implemented for SM>=80");
-  #endif
+#endif
 }
 
 template
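
Note on the hunks above: the change itself is formatting only (de-indenting the preprocessor directives), but the surrounding context shows the dispatch logic this file implements. kBlockM is pinned to 64 for every head dimension, kBlockN is chosen from the compile-time head dimension, and the body is guarded by __CUDA_ARCH__ >= 800, so pre-Ampere builds fall through to the #else branch and throw std::runtime_error. A minimal standalone sketch of the kBlockN selection rule, assuming nothing beyond what the hunk shows (pick_block_n and the main driver are illustrative names, not part of the patch):

#include <cstdio>

// Same selection rule as the hunk above: wider K/V tiles for small
// head dimensions, narrower tiles as Headdim grows.
template <int Headdim>
constexpr int pick_block_n() {
  return Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
}

int main() {
  std::printf("Headdim  64 -> kBlockN %d\n", pick_block_n<64>());   // 256
  std::printf("Headdim 128 -> kBlockN %d\n", pick_block_n<128>());  // 128
  std::printf("Headdim 256 -> kBlockN %d\n", pick_block_n<256>());  // 64
  return 0;
}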