Add colfax_cutlass backend to flash_attention operator (#2296)
Summary:
Add compilation of the colfax_cutlass kernels:
```
$ python install.py --userbenchmark triton --cutlass
```

Run with the sdpa, triton_tutorial_flash_v2, and colfax_cutlass backends on H100:
```
$ python run_benchmark.py triton --op flash_attention --only sdpa,triton_tutorial_flash_v2,colfax_cutlass --batch 128 --input-id 3 --num-inputs 5 --n-heads 8 --d-head 128 --metrics latency,tflops
  SeqLen    sdpa-latency    sdpa-tflops    triton_tutorial_flash_v2-latency    triton_tutorial_flash_v2-tflops    colfax_cutlass-latency    colfax_cutlass-tflops
--------  --------------  -------------  ----------------------------------  ---------------------------------  ------------------------  -----------------------
    1024         1.91248        287.457                             1.55574                            353.372                   1.38538                  396.828
    2048         7.49987        293.208                             5.70656                            385.35                    5.4792                   401.34
    4096        29.4748         298.428                            21.7369                             404.662                  20.8335                   422.21
    8192       122.297          287.696                            85.1293                             413.305                  82.3884                   427.055
   16384       462.649          304.199                           334.992                              420.122                 328.363                    428.604
```

Pull Request resolved: #2296

Reviewed By: aaronenyeshi

Differential Revision: D58671502

Pulled By: xuzhao9

fbshipit-source-id: 38cba58463c6783c535eda3c11e5a75707ef9730
xuzhao9 authored and facebook-github-bot committed Jun 17, 2024
1 parent 48223b8 commit d5f0a12
Showing 10 changed files with 510 additions and 26 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -7,3 +7,6 @@
[submodule "submodules/FBGEMM"]
path = submodules/FBGEMM
url = https://github.com/pytorch/FBGEMM.git
[submodule "submodules/cutlass-kernels"]
path = submodules/cutlass-kernels
url = https://github.com/ColfaxResearch/cutlass-kernels.git
1 change: 1 addition & 0 deletions submodules/cutlass-kernels
Submodule cutlass-kernels added at c796d7
63 changes: 37 additions & 26 deletions torchbenchmark/operators/flash_attention/operator.py
@@ -1,4 +1,8 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
This benchmark script is based on the benchmark code from:
@@ -10,17 +14,18 @@
https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
* Flash-V2: the FA-V2 from //ai_codesign/gen_ai/flash_attention_v2:flash_attention_v2,
* SDPA: the torch.nn.attention version of FA-V2
* [optional] Flash-V2: the FA-V2 from //ai_codesign/gen_ai/flash_attention_v2:flash_attention_v2,
which was imported from https://github.com/Dao-AILab/flash-attention
* Xformers: the memory-efficient attention from xformers:
* [optional] Xformers: the memory-efficient attention from xformers:
https://fburl.com/code/cuorcm9h
* [optional] Xformers-Splitk: the triton-splitk FMHA kernel from xformers:
https://fburl.com/code/awt36vjj
Disabled by default because it failed with some configs. Note that
the relevant benchmark only works with causal = False at the moment.
Known to work with "--batch=8 --n-heads=8 --xformers-splitk"
@@ -29,34 +34,44 @@
import argparse
import math
import os
from typing import Callable, Optional

import numpy

import torch
import triton # @manual=//triton:triton
import xformers # @manual=//fair/xformers:xformers
import xformers.ops.fmha as xformers_fmha # @manual=//fair/xformers:xformers

from triton.ops.flash_attention import attention as triton_op_FA2
from torchbenchmark.util.kernels.triton_fused_attention import attention as triton_tutorial_FA2
from aikl.gpu.triton.tests.test_fmha_utils import (
generate_qkv,
make_packed_qkv,
permute_qkv,
)
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func as flash_attn_func
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.nn.functional import scaled_dot_product_attention as sdpa
from triton.ops.flash_attention import attention as triton_op_FA2

from typing import Callable, Optional

# [Optional] flash_attn_func
try:
# colfax Flash Attention V2 for Hopper
# https://www.internalfb.com/code/fbsource/fbcode/ai_codesign/gen_ai/cutlass-kernels/src/fmha/README.md
torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
from .test_fmha_utils import make_packed_qkv
from flash_attn.flash_attn_interface import flash_attn_qkvpacked_func as flash_attn_func
except (ImportError, IOError, AttributeError):
pass

# [Optional] xformers backend
try:
import xformers # @manual=//fair/xformers:xformers
import xformers.ops.fmha as xformers_fmha # @manual=//fair/xformers:xformers
from .test_fmha_utils import permute_qkv
except (ImportError, IOError, AttributeError):
pass

# [Optional] colfax cutlass backend
try:
if not hasattr(torch.version, "git_version"):
# colfax Flash Attention V2 for Hopper
torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
else:
from userbenchmark.triton.utils import load_library
load_library("colfax_cutlass/fmha_forward_lib.so")
colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
except (ImportError, IOError, AttributeError):
colfax_cutlass_fmha = None

# [Optional] ThunderKittens backend
try:
import h100_fwd as tk_fwd
import h100_fwd_causal as tk_fwd_causal
@@ -65,10 +80,6 @@
tk_fwd_causal = None

from typing import Any, Generator, List

import torch
import triton
import triton.language as tl
from torchbenchmark.util.input import input_filter

from torchbenchmark.util.triton_op import (
@@ -200,15 +211,15 @@ def xformers_preprocess(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> xformers_fmha.Inputs:
):
q_1, k_1, v_1 = permute_qkv(q, k, v, perm=(0, 2, 1, 3))
attn_bias = xformers.ops.LowerTriangularMask() if self.causal else None
fhma_input = xformers_fmha.Inputs(
query=q_1, key=k_1, value=v_1, attn_bias=attn_bias, scale=self.sm_scale
)
return fhma_input

@register_benchmark()
@register_benchmark(enabled=False)
def xformers(
self,
q: torch.Tensor,
70 changes: 70 additions & 0 deletions torchbenchmark/operators/flash_attention/test_fmha_utils.py
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple

import torch


def generate_qkv(
BATCH: int,
H: int,
N_CTX: int,
D_HEAD: int,
dtype: torch.dtype,
device: str = "cuda",
requires_grad: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
torch.manual_seed(20)
q = torch.randn(
(BATCH, H, N_CTX, D_HEAD),
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
k = torch.randn(
(BATCH, H, N_CTX, D_HEAD),
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
v = torch.randn(
(BATCH, H, N_CTX, D_HEAD),
dtype=dtype,
device=device,
requires_grad=requires_grad,
)
return (q, k, v)


def permute_qkv(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
perm: Tuple[int],
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
q_1 = torch.permute(q, perm)
k_1 = torch.permute(k, perm)
v_1 = torch.permute(v, perm)
return (q_1, k_1, v_1)


def make_packed_qkv(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
) -> torch.Tensor:
"""
Make a packed qkv tensor for flash_attention:
from 3 * (batch, num_head, seq, head_dim) -> (batch, seq, 3, num_head, head_dim)
"""
assert (
q.size() == k.size() == v.size()
), f"{q.size()=}, {k.size()=}, {v.size()=} must be equal!"
(BATCH, H, N_CTX, D_HEAD) = q.size()
(q_1, k_1, v_1) = permute_qkv(q, k, v, perm=(0, 2, 1, 3))
qkv = torch.cat([q_1, k_1, v_1], dim=2)
return torch.reshape(qkv, (BATCH, N_CTX, 3, H, D_HEAD))
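
For context, a minimal usage sketch of these helpers (not part of the commit; the batch/head/sequence sizes are arbitrary examples, and a CUDA device is assumed):
```
# Illustrative only: generate Q/K/V and pack them for flash_attn_qkvpacked_func.
import torch

from torchbenchmark.operators.flash_attention.test_fmha_utils import (
    generate_qkv,
    make_packed_qkv,
)

BATCH, H, N_CTX, D_HEAD = 8, 8, 1024, 128
q, k, v = generate_qkv(BATCH, H, N_CTX, D_HEAD, dtype=torch.float16, device="cuda")
qkv = make_packed_qkv(q, k, v)
# Packed layout: (batch, seq, 3, num_head, head_dim)
assert qkv.shape == (BATCH, N_CTX, 3, H, D_HEAD)
```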
24 changes: 24 additions & 0 deletions userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
@@ -0,0 +1,24 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

template <typename PrecType, typename OutputType, typename AccumType, int HEADDIM>
void fmhaForwardDevice(
int SEQLEN,
int KEYLEN,
int NUMHEADS,
int BATCH,
PrecType const* tensorQ,
PrecType const* tensorK,
OutputType const* tensorV,
OutputType* tensorS,
OutputType* tensorO,
AccumType* miOut,
AccumType* sPrimeOut,
int iterations,
float scale,
cudaStream_t stream = 0);
48 changes: 48 additions & 0 deletions userbenchmark/triton/cutlass_kernels/include/pytorch_utils.h
@@ -0,0 +1,48 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <cutlass/numeric_types.h>

template <typename scalar_t>
struct CutlassToAtenDtype;

template <>
struct CutlassToAtenDtype<cutlass::half_t> {
using scalar_t = cutlass::half_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
};

template <>
struct CutlassToAtenDtype<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
};

template <>
struct CutlassToAtenDtype<float> {
using scalar_t = float;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
};

template <>
struct CutlassToAtenDtype<cutlass::float_e4m3_t> {
using scalar_t = float;

static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float8_e4m3fn;
}
};
93 changes: 93 additions & 0 deletions userbenchmark/triton/cutlass_kernels/install.py
@@ -0,0 +1,93 @@
import os
from pathlib import Path
import subprocess
import torch

CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent
FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels")

TORCH_BASE_PATH = Path(torch.__file__).parent

NVCC_GENCODE = "-gencode=arch=compute_90a,code=[sm_90a]"

NVCC_FLAGS = [
NVCC_GENCODE,
"--use_fast_math",
"-forward-unknown-to-host-compiler",
"-O3",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"-forward-unknown-to-host-compiler",
"--use_fast_math",
"-Xcompiler=-fno-strict-aliasing",
"-Xcompiler=-fPIE",
"-Xcompiler=-lcuda",
"-DNDEBUG",
"-DCUTLASS_TEST_LEVEL=0",
"-DCUTLASS_DEBUG_TRACE_LEVEL=0",
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1",
"-DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1",
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1",
"-D_GLIBCXX_USE_CXX11_ABI=0",
]
COMPILER_FLAGS = [
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
f"-I{CUDA_HOME}/include",
f"-I{str(TORCH_BASE_PATH.joinpath('include').resolve())}",
f"-I{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('include').resolve())}",
f"-Wl,-rpath,'{CUDA_HOME}/lib64'",
f"-Wl,-rpath,'{CUDA_HOME}/lib'",
]
LINKER_FLAGS = [
"--shared",
"-fPIC",
f"-L{str(TORCH_BASE_PATH.joinpath('lib').resolve())}",
"-ltorch",
"-ltorch_cuda",
"-lc10",
"-lc10_cuda",
"-lcuda",
"-lcudadevrt",
"-lcudart_static",
"-lcublas",
"-lrt",
"-lpthread",
"-ldl",
]
FMHA_SOURCES = [
# Source 1
f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
# Source 2
f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
"-o",
"fmha_forward_lib.so",
]


def test_colfax_cutlass(colfax_cutlass_lib: str):
assert os.path.exists(colfax_cutlass_lib), \
f"{colfax_cutlass_lib} should exist as the built cutlass kernel."
torch.ops.load_library(colfax_cutlass_lib)

def install_colfax_cutlass():
# compile colfax_cutlass kernels
output_dir = COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath(".data")
output_dir.mkdir(parents=True, exist_ok=True)
cmd = ["nvcc"]
cmd.extend(COMPILER_FLAGS)
cmd.extend(NVCC_FLAGS)
cmd.extend(FMHA_SOURCES)
cmd.extend(LINKER_FLAGS)
print(" ".join(cmd))
print(str(output_dir.resolve()))
subprocess.check_call(cmd, cwd=str(output_dir.resolve()))
colfax_cutlass_lib = str(output_dir.joinpath(FMHA_SOURCES[-1]).resolve())
test_colfax_cutlass(colfax_cutlass_lib)
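
For reference, a minimal sketch of how the benchmark picks up the artifact this script builds; it mirrors the optional colfax_cutlass import block in operator.py, and the hard-coded path below is an assumption based on the output_dir used above:
```
# Sketch: load the library built by install_colfax_cutlass() and resolve the op.
import torch

lib = "userbenchmark/triton/cutlass_kernels/.data/fmha_forward_lib.so"
torch.ops.load_library(lib)  # registers the cutlass::fmha_forward custom op
colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
```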