-
Notifications
You must be signed in to change notification settings - Fork 294
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add colfax_cutlass backend to flash_attention operator (#2296)
Summary: Add colfax_cutlass kernel compilation: ``` $ python install.py --userbenchmark triton --cutlass ``` Run with sdpa, triton_tutorial_flash_v2, and colfax_cutlass on H100: ``` $ python run_benchmark.py triton --op flash_attention --only sdpa,triton_tutorial_flash_v2,colfax_cutlass --batch 128 --input-id 3 --num-inputs 5 --n-heads 8 --d-head 128 --metrics latency,tflops SeqLen sdpa-latency sdpa-tflops triton_tutorial_flash_v2-latency triton_tutorial_flash_v2-tflops colfax_cutlass-latency colfax_cutlass-tflops -------- -------------- ------------- ---------------------------------- --------------------------------- ------------------------ ----------------------- 1024 1.91248 287.457 1.55574 353.372 1.38538 396.828 2048 7.49987 293.208 5.70656 385.35 5.4792 401.34 4096 29.4748 298.428 21.7369 404.662 20.8335 422.21 8192 122.297 287.696 85.1293 413.305 82.3884 427.055 16384 462.649 304.199 334.992 420.122 328.363 428.604 ``` Pull Request resolved: #2296 Reviewed By: aaronenyeshi Differential Revision: D58671502 Pulled By: xuzhao9 fbshipit-source-id: 38cba58463c6783c535eda3c11e5a75707ef9730
- Loading branch information
1 parent
48223b8
commit d5f0a12
Showing
10 changed files
with
510 additions
and
26 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule cutlass-kernels
added at
c796d7
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
torchbenchmark/operators/flash_attention/test_fmha_utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Tuple | ||
|
||
import torch | ||
|
||
|
||
def generate_qkv( | ||
BATCH: int, | ||
H: int, | ||
N_CTX: int, | ||
D_HEAD: int, | ||
dtype: torch.dtype, | ||
device: str = "cuda", | ||
requires_grad: bool = False, | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
torch.manual_seed(20) | ||
q = torch.randn( | ||
(BATCH, H, N_CTX, D_HEAD), | ||
dtype=dtype, | ||
device=device, | ||
requires_grad=requires_grad, | ||
) | ||
k = torch.randn( | ||
(BATCH, H, N_CTX, D_HEAD), | ||
dtype=dtype, | ||
device=device, | ||
requires_grad=requires_grad, | ||
) | ||
v = torch.randn( | ||
(BATCH, H, N_CTX, D_HEAD), | ||
dtype=dtype, | ||
device=device, | ||
requires_grad=requires_grad, | ||
) | ||
return (q, k, v) | ||
|
||
|
||
def permute_qkv( | ||
q: torch.Tensor, | ||
k: torch.Tensor, | ||
v: torch.Tensor, | ||
perm: Tuple[int], | ||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | ||
q_1 = torch.permute(q, perm) | ||
k_1 = torch.permute(k, perm) | ||
v_1 = torch.permute(v, perm) | ||
return (q_1, k_1, v_1) | ||
|
||
|
||
def make_packed_qkv( | ||
q: torch.Tensor, | ||
k: torch.Tensor, | ||
v: torch.Tensor, | ||
) -> torch.Tensor: | ||
""" | ||
Make a packed qkv tensor for flash_attention: | ||
from 3 * (batch, num_head, seq, head_dim) -> (batch, seq, 3, num_head, head_dim) | ||
""" | ||
assert ( | ||
q.size() == k.size() == v.size() | ||
), f"{q.size()=}, {k.size()=}, {v.size()=} must be equal!" | ||
(BATCH, H, N_CTX, D_HEAD) = q.size() | ||
(q_1, k_1, v_1) = permute_qkv(q, k, v, perm=(0, 2, 1, 3)) | ||
qkv = torch.cat([q_1, k_1, v_1], dim=2) | ||
return torch.reshape(qkv, (BATCH, N_CTX, 3, H, D_HEAD)) |
24 changes: 24 additions & 0 deletions
24
userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
template <typename PrecType, typename OutputType, typename AccumType, int HEADDIM> | ||
void fmhaForwardDevice( | ||
int SEQLEN, | ||
int KEYLEN, | ||
int NUMHEADS, | ||
int BATCH, | ||
PrecType const* tensorQ, | ||
PrecType const* tensorK, | ||
OutputType const* tensorV, | ||
OutputType* tensorS, | ||
OutputType* tensorO, | ||
AccumType* miOut, | ||
AccumType* sPrimeOut, | ||
int iterations, | ||
float scale, | ||
cudaStream_t stream = 0); |
48 changes: 48 additions & 0 deletions
48
userbenchmark/triton/cutlass_kernels/include/pytorch_utils.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Copyright (c) Meta Platforms, Inc. and affiliates. | ||
// All rights reserved. | ||
// | ||
// This source code is licensed under the BSD-style license found in the | ||
// LICENSE file in the root directory of this source tree. | ||
|
||
#pragma once | ||
|
||
#include <cutlass/numeric_types.h> | ||
|
||
template <typename scalar_t> | ||
struct CutlassToAtenDtype; | ||
|
||
template <> | ||
struct CutlassToAtenDtype<cutlass::half_t> { | ||
using scalar_t = cutlass::half_t; | ||
|
||
static constexpr __host__ at::ScalarType atScalarType() { | ||
return at::ScalarType::Half; | ||
} | ||
}; | ||
|
||
template <> | ||
struct CutlassToAtenDtype<cutlass::bfloat16_t> { | ||
using scalar_t = cutlass::bfloat16_t; | ||
|
||
static constexpr __host__ at::ScalarType atScalarType() { | ||
return at::ScalarType::BFloat16; | ||
} | ||
}; | ||
|
||
template <> | ||
struct CutlassToAtenDtype<float> { | ||
using scalar_t = float; | ||
|
||
static constexpr __host__ at::ScalarType atScalarType() { | ||
return at::ScalarType::Float; | ||
} | ||
}; | ||
|
||
template <> | ||
struct CutlassToAtenDtype<cutlass::float_e4m3_t> { | ||
using scalar_t = float; | ||
|
||
static constexpr __host__ at::ScalarType atScalarType() { | ||
return at::ScalarType::Float8_e4m3fn; | ||
} | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import os | ||
from pathlib import Path | ||
import subprocess | ||
import torch | ||
|
||
CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"] | ||
REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent | ||
FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu") | ||
FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass") | ||
COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels") | ||
COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels") | ||
|
||
TORCH_BASE_PATH = Path(torch.__file__).parent | ||
|
||
NVCC_GENCODE = "-gencode=arch=compute_90a,code=[sm_90a]" | ||
|
||
NVCC_FLAGS = [ | ||
NVCC_GENCODE, | ||
"--use_fast_math", | ||
"-forward-unknown-to-host-compiler", | ||
"-O3", | ||
"--expt-relaxed-constexpr", | ||
"--expt-extended-lambda", | ||
"-forward-unknown-to-host-compiler", | ||
"--use_fast_math", | ||
"-Xcompiler=-fno-strict-aliasing", | ||
"-Xcompiler=-fPIE", | ||
"-Xcompiler=-lcuda", | ||
"-DNDEBUG", | ||
"-DCUTLASS_TEST_LEVEL=0", | ||
"-DCUTLASS_DEBUG_TRACE_LEVEL=0", | ||
"-DCUTLASS_TEST_ENABLE_CACHED_RESULTS=1", | ||
"-DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1", | ||
"-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1", | ||
"-D_GLIBCXX_USE_CXX11_ABI=0", | ||
] | ||
COMPILER_FLAGS = [ | ||
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}", | ||
f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}", | ||
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}", | ||
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}", | ||
f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}", | ||
f"-I{CUDA_HOME}/include", | ||
f"-I{str(TORCH_BASE_PATH.joinpath('include').resolve())}", | ||
f"-I{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('include').resolve())}", | ||
f"-Wl,-rpath,'{CUDA_HOME}/lib64'", | ||
f"-Wl,-rpath,'{CUDA_HOME}/lib'", | ||
] | ||
LINKER_FLAGS = [ | ||
"--shared", | ||
"-fPIC", | ||
f"-L{str(TORCH_BASE_PATH.joinpath('lib').resolve())}", | ||
"-ltorch", | ||
"-ltorch_cuda", | ||
"-lc10", | ||
"-lc10_cuda", | ||
"-lcuda", | ||
"-lcudadevrt", | ||
"-lcudart_static", | ||
"-lcublas", | ||
"-lrt", | ||
"-lpthread", | ||
"-ldl", | ||
] | ||
FMHA_SOURCES = [ | ||
# Source 1 | ||
f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}", | ||
# Source 2 | ||
f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}", | ||
"-o", | ||
"fmha_forward_lib.so", | ||
] | ||
|
||
|
||
def test_colfax_cutlass(colfax_cutlass_lib: str): | ||
assert os.path.exists(colfax_cutlass_lib), \ | ||
f"{colfax_cutlass_lib} should exist as the built cutlass kernel." | ||
torch.ops.load_library(colfax_cutlass_lib) | ||
|
||
def install_colfax_cutlass(): | ||
# compile colfax_cutlass kernels | ||
output_dir = COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath(".data") | ||
output_dir.mkdir(parents=True, exist_ok=True) | ||
cmd = ["nvcc"] | ||
cmd.extend(COMPILER_FLAGS) | ||
cmd.extend(NVCC_FLAGS) | ||
cmd.extend(FMHA_SOURCES) | ||
cmd.extend(LINKER_FLAGS) | ||
print(" ".join(cmd)) | ||
print(str(output_dir.resolve())) | ||
subprocess.check_call(cmd, cwd=str(output_dir.resolve())) | ||
colfax_cutlass_lib = str(output_dir.joinpath(FMHA_SOURCES[-1]).resolve()) | ||
test_colfax_cutlass(colfax_cutlass_lib) |
Oops, something went wrong.