From 0a2ff2231c1030852dacfcaa5a4eb2a9d12fc197 Mon Sep 17 00:00:00 2001
From: Xu Zhao
Date: Thu, 1 Aug 2024 07:02:49 -0700
Subject: [PATCH] Fix colfax_cutlass flash_attention operator (#2401)

Summary:
colfax_cutlass kernels fail to load because of missing C++ template instantiations. We need to explicitly include the file that defines the templates so that all required specializations are instantiated.

Pull Request resolved: https://github.com/pytorch/benchmark/pull/2401

Test Plan:
Install the colfax_cutlass operators:
```
python install.py --userbenchmark triton --cutlass
/home/xz/git/benchmark/submodules/cutlass-kernels/src/fmha/fmha_forward.cu(826): warning #117-D: non-void function "main" should return a value
      return;
      ^

Remark: The warnings can be suppressed with "-diag-suppress <error-number>"

/home/xz/git/benchmark/submodules/cutlass-kernels/src/fmha/fmha_forward.cu(826): warning #117-D: non-void function "main" should return a value
      return;
      ^

Remark: The warnings can be suppressed with "-diag-suppress <error-number>"
```

Run the flash_attention operator from colfax_cutlass:
```
python run_benchmark.py triton --op flash_attention --only colfax_cutlass --num-inputs 1
  (Batch, Heads, SeqLen, Dhead)    colfax_cutlass-latency
-------------------------------  ------------------------
              (32, 32, 512, 64)                  0.001024
```

Reviewed By: manman-ren

Differential Revision: D60557212

Pulled By: xuzhao9

fbshipit-source-id: 25b216f850d2e82815041059d372627806bfd3ca
---
 .../operators/flash_attention/operator.py    |  4 +-
 .../cutlass_kernels/include/fmha_forward.h   | 24 ------------
 .../triton/cutlass_kernels/install.py        |  7 ++--
 .../cutlass_kernels/src/fmha/register_op.cu  | 37 +++++++++----------
 4 files changed, 22 insertions(+), 50 deletions(-)
 delete mode 100644 userbenchmark/triton/cutlass_kernels/include/fmha_forward.h

diff --git a/torchbenchmark/operators/flash_attention/operator.py b/torchbenchmark/operators/flash_attention/operator.py
index 3e08f79eca..53e77670a8 100644
--- a/torchbenchmark/operators/flash_attention/operator.py
+++ b/torchbenchmark/operators/flash_attention/operator.py
@@ -83,8 +83,8 @@
         # colfax Flash Attention V2 for Hopper
         torch.ops.load_library("//ai_codesign/gen_ai/cutlass-kernels:fmha_forward_lib")
     else:
-        from userbenchmark.triton.utils import load_library
-        load_library("colfax_cutlass/fmha_forward_lib.so")
+        from userbenchmark.triton.loader import load_library
+        load_library("cutlass_kernels/fmha_forward_lib.so")
     colfax_cutlass_fmha = torch.ops.cutlass.fmha_forward
 except (ImportError, IOError, AttributeError):
     colfax_cutlass_fmha = None
diff --git a/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h b/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
deleted file mode 100644
index 35caedaf3a..0000000000
--- a/userbenchmark/triton/cutlass_kernels/include/fmha_forward.h
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright (c) Meta Platforms, Inc. and affiliates.
-// All rights reserved.
-//
-// This source code is licensed under the BSD-style license found in the
-// LICENSE file in the root directory of this source tree.
-
-#pragma once
-
-template
-void fmhaForwardDevice(
-    int SEQLEN,
-    int KEYLEN,
-    int NUMHEADS,
-    int BATCH,
-    PrecType const* tensorQ,
-    PrecType const* tensorK,
-    OutputType const* tensorV,
-    OutputType* tensorS,
-    OutputType* tensorO,
-    AccumType* miOut,
-    AccumType* sPrimeOut,
-    int iterations,
-    float scale,
-    cudaStream_t stream = 0);
diff --git a/userbenchmark/triton/cutlass_kernels/install.py b/userbenchmark/triton/cutlass_kernels/install.py
index 471853a9d8..a4b9aaf39d 100644
--- a/userbenchmark/triton/cutlass_kernels/install.py
+++ b/userbenchmark/triton/cutlass_kernels/install.py
@@ -6,7 +6,7 @@
 CUDA_HOME = "/usr/local/cuda" if not "CUDA_HOME" in os.environ else os.environ["CUDA_HOME"]
 REPO_PATH = Path(os.path.abspath(__file__)).parent.parent.parent.parent
 FBGEMM_PATH = REPO_PATH.joinpath("submodules", "FBGEMM", "fbgemm_gpu")
-FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("third_party", "cutlass")
+FBGEMM_CUTLASS_PATH = FBGEMM_PATH.parent.joinpath("external", "cutlass")
 COLFAX_CUTLASS_PATH = REPO_PATH.joinpath("submodules", "cutlass-kernels")
 COLFAX_CUTLASS_TRITONBENCH_PATH = REPO_PATH.joinpath("userbenchmark", "triton", "cutlass_kernels")
@@ -37,6 +37,7 @@ COMPILER_FLAGS = [
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('lib').resolve())}",
     f"-I{str(COLFAX_CUTLASS_PATH.joinpath('include').resolve())}",
+    f"-I{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('include').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('examples', 'commmon').resolve())}",
     f"-I{str(FBGEMM_CUTLASS_PATH.joinpath('tools', 'util', 'include').resolve())}",
@@ -63,9 +64,7 @@
     "-ldl",
 ]
 FMHA_SOURCES = [
-    # Source 1
-    f"{str(COLFAX_CUTLASS_PATH.joinpath('src', 'fmha', 'fmha_forward.cu').resolve())}",
-    # Source 2
+    # Source
     f"{str(COLFAX_CUTLASS_TRITONBENCH_PATH.joinpath('src', 'fmha', 'register_op.cu').resolve())}",
     "-o",
     "fmha_forward_lib.so",
diff --git a/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu b/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
index 94b1c4c484..1a590acf4f 100644
--- a/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
+++ b/userbenchmark/triton/cutlass_kernels/src/fmha/register_op.cu
@@ -21,9 +21,9 @@
 // #include "autogen/cutlassF.h"
 #include "pytorch_utils.h"
-#include "fmha_forward.h"
+#include "fmha_forward.cu"
-template
+template
 std::tuple
 fmha_forward(
     const int64_t& seq_length,
@@ -31,7 +31,7 @@
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const float& scale) {
   TORCH_CHECK(query.dim() == 4);
   TORCH_CHECK(key.dim() == 4);
@@ -70,7 +70,7 @@ fmha_forward(
       query.options().dtype(CutlassToAtenDtype::atScalarType()));
   at::Tensor ret = at::empty(
       {B, M, num_heads, Kv},
-      query.options().dtype(CutlassToAtenDtype::atScalarType()));
+      query.options().dtype(CutlassToAtenDtype::atScalarType()));
   using AccumType = float; // AccumType is always float.
  at::Tensor devMiOut = at::empty(
@@ -80,16 +80,16 @@
       {B, M, num_heads},
       query.options().dtype(CutlassToAtenDtype::atScalarType()));
-  fmhaForwardDevice(
+  fmhaForwardDevice(
       seq_length,
       key_length,
       num_heads,
       B,
       reinterpret_cast(query.data_ptr()),
       reinterpret_cast(key.data_ptr()),
-      reinterpret_cast(value.data_ptr()),
-      reinterpret_cast(S.data_ptr()),
-      reinterpret_cast(ret.data_ptr()),
+      reinterpret_cast(value.data_ptr()),
+      reinterpret_cast(S.data_ptr()),
+      reinterpret_cast(ret.data_ptr()),
       reinterpret_cast(devMiOut.data_ptr()),
       reinterpret_cast(devSprimeOut.data_ptr()),
       1,
@@ -99,7 +99,7 @@
   return std::make_tuple(S, ret, devMiOut, devSprimeOut);
 }
-template
+template
 std::tuple
 launch_forward(
     const int64_t& seq_length,
@@ -107,17 +107,17 @@
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale,
     const int64_t& Kdim) {
   if (Kdim == 64) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   } else if (Kdim == 128) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   } else if (Kdim == 256) {
-    return fmha_forward(
+    return fmha_forward(
        seq_length, key_length, batch, query, key, value, scale);
   }
   throw std::runtime_error("Kdim wrong");
@@ -131,18 +131,15 @@ fmha_forward_dispatch(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale) {
   int64_t Kdim = query.size(-1);
   if (query.scalar_type() == at::kHalf){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
+    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
   }
   else if (query.scalar_type() == at::kBFloat16){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
-  }
-  else if (query.scalar_type() == at::kFloat8_e4m3fn){
-    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
+    return launch_forward(seq_length, key_length, batch, query, key, value, scale, Kdim);
   }
   else {
     std::cout << "unsupported data type: " << query.scalar_type() << std::endl;
@@ -159,7 +156,7 @@ fmha_forward_dispatch_meta(
     const int64_t& batch,
     const at::Tensor& query, // [b, seqlen, num_heads, K]
     const at::Tensor& key, // [b, seqlen, num_heads, K]
-    const at::Tensor& value, // [b, seqlen, num_heads, Kv]
+    at::Tensor& value, // [b, seqlen, num_heads, Kv]
     const double& scale) {
   TORCH_CHECK(query.dim() == 4);
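
For reference, the failure mode this patch works around is the standard C++ rule that a function template's definition must be visible in the translation unit where it is instantiated (or be explicitly instantiated in the defining file); a declaration-only header compiles but leaves an undefined symbol at load time. The sketch below is a minimal, self-contained illustration of that rule, not code from this repo; the names `scaled` and the stand-in file comments are purely illustrative.
```cpp
#include <cstdio>

// Stand-in for fmha_forward.h: a declaration alone lets a call site compile,
// but it does not instantiate the template, so the link (or dlopen) fails.
template <typename T>
T scaled(T x, float s);

// Stand-in for including fmha_forward.cu (what register_op.cu now does):
// with the definition visible here, the compiler instantiates scaled<float>
// implicitly.  The alternative is an explicit instantiation in the defining
// file, e.g.  template float scaled<float>(float, float);
template <typename T>
T scaled(T x, float s) {
  return static_cast<T>(x * s);
}

int main() {
  // Without the definition above, this call would compile but fail to link
  // with "undefined reference to `float scaled<float>(float, float)'".
  std::printf("%f\n", scaled(2.0f, 0.5f));
  return 0;
}
```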