fix format
yufenglee committed Oct 12, 2023
1 parent e393299 commit 6b213eb
Showing 1 changed file with 4 additions and 4 deletions.
@@ -3,7 +3,7 @@
 ******************************************************************************/
#pragma once

-#include<stdexcept>
+#include <stdexcept>

#include "contrib_ops/cuda/bert/flash_attention/static_switch.h"
#include "contrib_ops/cuda/bert/flash_attention/flash.h"

@@ -113,7 +113,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params& params, cudaStream_t stream) {

template <typename T, int Headdim>
void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream) {
-#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
  bool is_sm8x = params.dprops->major == 8 && params.dprops->minor > 0;
  constexpr int kBlockM = 64;  // Fixed for all head dimensions
  if (!is_sm8x) {  // A100, H100

@@ -125,11 +125,11 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params& params, cudaStream_t stream)
    constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
    run_flash_splitkv_fwd<Flash_fwd_kernel_traits<Headdim, kBlockM, kBlockN, 4, false, false, T>>(params, stream);
  }
-#else
+#else
  (void)params;
  (void)stream;
  throw std::runtime_error("FlashAttention is only implemented for SM>=80");
-#endif
+#endif
}

template <typename T>
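Note on the changed lines: three of the four removed/added pairs (the #if, #else, and #endif directives) read the same in text here; consistent with the commit title, they are presumably whitespace-only indentation fixes, and the fourth change adds the missing space in #include <stdexcept>.

For context, the guarded function in the last two hunks picks the split-KV tile sizes at compile time from the head dimension. Below is a minimal, self-contained sketch of just that selection rule; the helper name print_block_sizes is illustrative and not part of the repository, and the actual kernel-traits instantiation and launch (Flash_fwd_kernel_traits, run_flash_splitkv_fwd) are omitted.

// Sketch of the compile-time tile-size selection shown in the diff above.
// Assumption: plain host C++ is enough to illustrate the rule; no CUDA is needed here.
#include <cstdio>

template <int Headdim>
void print_block_sizes() {
  constexpr int kBlockM = 64;  // fixed for all head dimensions, as in the diff
  constexpr int kBlockN = Headdim <= 64 ? 256 : (Headdim <= 128 ? 128 : 64);
  std::printf("Headdim=%3d -> kBlockM=%d, kBlockN=%d\n", Headdim, kBlockM, kBlockN);
}

int main() {
  print_block_sizes<32>();   // kBlockN = 256
  print_block_sizes<64>();   // kBlockN = 256
  print_block_sizes<128>();  // kBlockN = 128
  print_block_sizes<256>();  // kBlockN = 64
}

Because kBlockN is a constexpr chosen per Headdim, each head dimension instantiates a distinct kernel-traits specialization, which is why the real dispatch function is templated on Headdim rather than branching at run time.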
