Fix colfax_cutlass flash_attention operator (#2401)
Summary:
The colfax_cutlass kernels fail because of missing C++ template instantiations: register_op.cu only included a header, so the templated kernels defined in fmha_forward.cu were never instantiated for the template parameters it uses. We need to include the kernel source file explicitly so that all required template instantiations are generated.
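
For context, a minimal sketch of the failure mode. The file layout mirrors the real one, but the function signature and bodies below are simplified assumptions, not the actual cutlass-kernels sources:

```
// --- fmha_forward.cu (sketch): the template definition lives only here ---
template <typename PrecType, typename AccumType, int HEADDIM>
void fmhaForwardDevice(int seq_len, int key_len, const PrecType* q,
                       const PrecType* k, PrecType* out) {
  // ... configure and launch the CUDA kernel ...
}

// --- register_op.cu (sketch) ---
// #include "fmha_forward.h"  // Before: declaration only. The call below still
//                            // compiles, but no definition is visible in this
//                            // translation unit, so no instantiation is emitted
//                            // and fmha_forward_lib.so is left with an
//                            // undefined symbol.
#include "fmha_forward.cu"    // After: the definition is visible, so the call
                              // triggers implicit instantiation.

void call_site(const float* q, const float* k, float* out) {
  // Instantiates fmhaForwardDevice<float, float, 64> in this translation unit.
  fmhaForwardDevice<float, float, 64>(512, 512, q, k, out);
}
```

The install.py change below goes together with this: fmha_forward.cu is dropped from FMHA_SOURCES, and src/fmha is added to the include path so that register_op.cu can include it directly.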

Pull Request resolved: #2401

Test Plan:
Install the colfax_cutlass operators:
```
python install.py --userbenchmark triton --cutlass
/home/xz/git/benchmark/submodules/cutlass-kernels/src/fmha/fmha_forward.cu(826): warning #117: non-void function "main" should return a value
      return;
            ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"

/home/xz/git/benchmark/submodules/cutlass-kernels/src/fmha/fmha_forward.cu(826): warning #117: non-void function "main" should return a value
      return;
            ^

Remark: The warnings can be suppressed with "-diag-suppress <warning-number>"
```

Run the flash_attention operator from colfax_cutlass:

```
python run_benchmark.py triton --op flash_attention --only colfax_cutlass --num-inputs 1

  (Batch, Heads, SeqLen, Dhead)    colfax_cutlass-latency
-------------------------------  ------------------------
              (32, 32, 512, 64)                  0.001024
```

Reviewed By: manman-ren

Differential Revision: D60557212

Pulled By: xuzhao9

fbshipit-source-id: 25b216f850d2e82815041059d372627806bfd3ca
xuzhao9 authored and facebook-github-bot committed Aug 1, 2024
Parent: f4ed185 · Commit: 0a2ff22
Showing 4 changed files with 22 additions and 50 deletions.
4 changes: 2 additions & 2 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -83,8 +83,8 @@
 # colfax Flash Attention V2 for Hopper
 torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
 else:
-from userbenchmark.triton.utils import load_library
-load_library("colfax_cutlass/fmha_forward_lib.so")
+from userbenchmark.triton.loader import load_library
+load_library("cutlass_kernels/fmha_forward_lib.so")
 colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
 except (ImportError, IOError, AttributeError):
 colfax_cutlass_fmha = None
24 changes: 0 additions & 24 deletions userbenchmark/triton/cutlass_kernels/include/fmha_forward.h

This file was deleted.

7 changes: 3 additions & 4 deletions userbenchmark/triton/cutlass_kernels/install.py
@@ -6,7 +6,7 @@
 CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
 REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent
 FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
-FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
+FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
 COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
 COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels")

@@ -37,6 +37,7 @@
 COMPILER_FLAGS = [
 f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
 f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
+f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
 f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
 f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
 f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",

@@ -63,9 +64,7 @@
 "-ldl",
 ]
 FMHA_SOURCES = [
-# Source 1
-f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
-# Source 2
+# Source
 f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
 "-o",
 "fmha_forward_lib.so",
37 changes: 17 additions & 20 deletions userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
@@ -21,17 +21,17 @@

 // #include "autogen/cutlassF.h"
 #include "pytorch_utils.h"
-#include "fmha_forward.h"
+#include "fmha_forward.cu"

-template <typename PrecType, typename OutputType, int HEADDIM>
+template <typename PrecType, int HEADDIM>
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 fmha_forward(
 const int64_t& seq_length,
 const int64_t& key_length,
 const int64_t& batch,
 const at::Tensor& query, // [b, seqlen, num_heads, K]
 const at::Tensor& key, // [b, seqlen, num_heads, K]
-const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+at::Tensor& value, // [b, seqlen, num_heads, Kv]
 const float& scale) {
 TORCH_CHECK(query.dim() == 4);
 TORCH_CHECK(key.dim() == 4);

@@ -70,7 +70,7 @@ fmha_forward(
 query.options().dtype(CutlassToAtenDtype<PrecType>::atScalarType()));
 at::Tensor ret = at::empty(
 {B, M, num_heads, Kv},
-query.options().dtype(CutlassToAtenDtype<OutputType>::atScalarType()));
+query.options().dtype(CutlassToAtenDtype<PrecType>::atScalarType()));
 using AccumType = float; // AccumType is always float.

 at::Tensor devMiOut = at::empty(

@@ -80,16 +80,16 @@
 {B, M, num_heads},
 query.options().dtype(CutlassToAtenDtype<AccumType>::atScalarType()));

-fmhaForwardDevice<PrecType, OutputType, AccumType, HEADDIM>(
+fmhaForwardDevice<PrecType, AccumType, HEADDIM>(
 seq_length,
 key_length,
 num_heads,
 B,
 reinterpret_cast<PrecType const*>(query.data_ptr()),
 reinterpret_cast<PrecType const*>(key.data_ptr()),
-reinterpret_cast<OutputType const*>(value.data_ptr()),
-reinterpret_cast<OutputType*>(S.data_ptr()),
-reinterpret_cast<OutputType*>(ret.data_ptr()),
+reinterpret_cast<PrecType*>(value.data_ptr()),
+reinterpret_cast<PrecType*>(S.data_ptr()),
+reinterpret_cast<PrecType*>(ret.data_ptr()),
 reinterpret_cast<AccumType*>(devMiOut.data_ptr()),
 reinterpret_cast<AccumType*>(devSprimeOut.data_ptr()),
 1,

@@ -99,25 +99,25 @@
 return std::make_tuple(S, ret, devMiOut, devSprimeOut);
 }

-template<typename compute_data_type, typename output_data_type>
+template<typename compute_data_type>
 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
 launch_forward(
 const int64_t& seq_length,
 const int64_t& key_length,
 const int64_t& batch,
 const at::Tensor& query, // [b, seqlen, num_heads, K]
 const at::Tensor& key, // [b, seqlen, num_heads, K]
-const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+at::Tensor& value, // [b, seqlen, num_heads, Kv]
 const double& scale,
 const int64_t& Kdim) {
 if (Kdim == 64) {
-return fmha_forward<compute_data_type, output_data_type, 64>(
+return fmha_forward<compute_data_type, 64>(
 seq_length, key_length, batch, query, key, value, scale);
 } else if (Kdim == 128) {
-return fmha_forward<compute_data_type, output_data_type, 128>(
+return fmha_forward<compute_data_type, 128>(
 seq_length, key_length, batch, query, key, value, scale);
 } else if (Kdim == 256) {
-return fmha_forward<compute_data_type, output_data_type, 256>(
+return fmha_forward<compute_data_type, 256>(
 seq_length, key_length, batch, query, key, value, scale);
 }
 throw std::runtime_error("Kdim wrong");

@@ -131,18 +131,15 @@ fmha_forward_dispatch(
 const int64_t& batch,
 const at::Tensor& query, // [b, seqlen, num_heads, K]
 const at::Tensor& key, // [b, seqlen, num_heads, K]
-const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+at::Tensor& value, // [b, seqlen, num_heads, Kv]
 const double& scale) {
 int64_t Kdim = query.size(-1);

 if (query.scalar_type() == at::kHalf){
-return launch_forward<cutlass::half_t, cutlass::half_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
+return launch_forward<cutlass::half_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
 }
 else if (query.scalar_type() == at::kBFloat16){
-return launch_forward<cutlass::bfloat16_t, cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
-}
-else if (query.scalar_type() == at::kFloat8_e4m3fn){
-return launch_forward<cutlass::float_e4m3_t, cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
+return launch_forward<cutlass::bfloat16_t>(seq_length, key_length, batch, query, key, value, scale, Kdim);
 }
 else {
 std::cout << "unsupported data type: " << query.scalar_type() << std::endl;

@@ -159,7 +156,7 @@ fmha_forward_dispatch_meta(
 const int64_t& batch,
 const at::Tensor& query, // [b, seqlen, num_heads, K]
 const at::Tensor& key, // [b, seqlen, num_heads, K]
-const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+at::Tensor& value, // [b, seqlen, num_heads, Kv]
 const double& scale) {

 TORCH_CHECK(query.dim() == 4);

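For reference, operator.py above binds the operator as torch.ops.cutlass.fmha_forward after loading fmha_forward_lib.so. The registration code itself is outside this diff; the following is only a hypothetical sketch of how fmha_forward_dispatch and fmha_forward_dispatch_meta could be bound under that name, with an assumed namespace and schema:

```
#include <torch/library.h>

// Hypothetical registration sketch -- not the actual contents of register_op.cu.
TORCH_LIBRARY_FRAGMENT(cutlass, m) {
  // Assumed schema; Tensor(a!) marks `value` as mutable to match the
  // non-const at::Tensor& parameter introduced in this change.
  m.def(
      "fmha_forward(int seq_length, int key_length, int batch, "
      "Tensor query, Tensor key, Tensor(a!) value, float scale) "
      "-> (Tensor, Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(cutlass, CUDA, m) {
  m.impl("fmha_forward", fmha_forward_dispatch);       // CUDA path
}

TORCH_LIBRARY_IMPL(cutlass, Meta, m) {
  m.impl("fmha_forward", fmha_forward_dispatch_meta);  // shape-only path for meta tensors
}
```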