[CUDA] upgrade cutlass to 3.5.0 #20940

Merged: 13 commits, Jun 11, 2024
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
@@ -306,7 +306,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "6f47420213f757831fae65c686aa471749fa8d60",
"commitHash": "7d49e6c7e2f8896c47f586706e67e1fb215529dc",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"
5 changes: 5 additions & 0 deletions cmake/CMakeLists.txt
@@ -46,6 +46,11 @@ else()
set(CMAKE_CXX_STANDARD 17)
endif()

if (MSVC)
# Make sure Visual Studio sets __cplusplus macro correctly: https://learn.microsoft.com/en-us/cpp/build/reference/zc-cplusplus
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus")
endif()

set_property(GLOBAL PROPERTY USE_FOLDERS ON)
# NOTE: POSITION INDEPENDENT CODE hurts performance, and it only makes sense on POSIX systems
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
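
For context on the /Zc:__cplusplus additions in this PR: by default MSVC reports the __cplusplus macro as 199711L regardless of the /std setting, so headers that gate C++17 features on that macro (as CUTLASS 3.5 appears to) can mis-detect the language level. A minimal check, not part of this change, that shows the effect:

#include <iostream>

int main() {
  // Without /Zc:__cplusplus, MSVC prints 199711 here even under /std:c++17.
  // With the flag added above, it prints the real value (e.g. 201703).
  std::cout << "__cplusplus = " << __cplusplus << '\n';
#if __cplusplus >= 201703L
  std::cout << "C++17 or later detected\n";
#endif
  return 0;
}

The onnxruntime_providers_cuda.cmake hunk below forwards the same switch to the host compiler for CUDA translation units via -Xcompiler, so nvcc-compiled sources see a consistent value.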
2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -53,7 +53,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/959002f82d7962a473d8b
re2;https://github.com/google/re2/archive/refs/tags/2024-05-01.tar.gz;206cfee5ee0b4c6844680ba66275e9e8faa77405
safeint;https://github.com/dcleblanc/SafeInt/archive/refs/tags/3.0.28.zip;23f252040ff6cb9f1fd18575b32fa8fb5928daac
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.5.0.zip;ae038931b9fc2c416c17d9cda91d9706b343f56d
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
4 changes: 4 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -175,6 +175,10 @@
endif()
endif()

if(MSVC)
target_compile_options(${target} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler /Zc:__cplusplus>")
endif()

onnxruntime_add_include_to_target(${target} onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers)
if (onnxruntime_ENABLE_TRAINING_OPS)
onnxruntime_add_include_to_target(${target} onnxruntime_training)
105 changes: 85 additions & 20 deletions onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h
@@ -16,6 +16,23 @@
namespace contrib {
namespace cuda {

// TODO: remove this flag and unused code after testing.

#define USE_MEMORY_EFFICIENT_TO_BATCH_HOOK 0

#if USE_MEMORY_EFFICIENT_TO_BATCH_HOOK
struct GQAToBatchHook {
template <typename Params>
CUTLASS_DEVICE static bool advance_to_batch(Params& p, int64_t& q_start, int64_t& k_start) {
auto batch_id = blockIdx.z;
q_start = batch_id * p.num_queries;
const int64_t max_sequence_length = p.v_strideB / p.v_strideM;
const bool is_kv_bsnh = (p.k_strideH == p.head_dim && p.k_strideM == p.num_heads * p.head_dim);
k_start = batch_id * (is_kv_bsnh ? max_sequence_length : p.num_heads * max_sequence_length);
return true;
}
};

#else

template <typename AttentionKernel, int kQueriesPerBlock>
struct RightPaddingBatchHook {
using scalar_t = typename AttentionKernel::scalar_t;
@@ -51,18 +68,34 @@
return false;
}

// TODO: use GroupQueryAttentionToBatchHook

bool is_kv_bsnh = (p.k_strideH == p.head_dim && p.k_strideM == p.num_heads * p.head_dim);
int64_t q_start = batch_id * p.num_queries;
const int64_t max_sequence_length = p.v_strideB / p.v_strideM;
int64_t k_start = batch_id * (is_kv_bsnh ? max_sequence_length : p.num_heads * max_sequence_length);

// Advance to the current batch / head / query_start
p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM + head_id * p.head_dim_value;
// p.query_ptr += batch_id * p.q_strideB + query_start * p.q_strideM + head_id * p.q_strideH;
p.query_ptr += (q_start + query_start) * p.q_strideM + head_id * p.q_strideH;

// p.key_ptr += batch_id * p.k_strideB + head_id * p.k_strideH;
p.key_ptr += k_start * p.k_strideM + head_id * p.k_strideH;

// p.value_ptr += batch_id * p.v_strideB + head_id * p.v_strideH;
p.value_ptr += k_start * p.v_strideM + head_id * p.v_strideH;

// p.output_ptr += int64_t(batch_id * p.num_queries) * p.o_strideM + int64_t(query_start) * p.o_strideM
// + head_id * p.head_dim_value;
p.output_ptr += int64_t(q_start + query_start) * p.o_strideM + head_id * p.head_dim_value;

if (kSupportsBias && p.attn_bias_ptr != nullptr) {
p.attn_bias_ptr += (batch_id * p.bias_strideB) + (head_id * p.bias_strideH);
}
if (p.output_accum_ptr != nullptr) {
p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
int64_t(query_start) * (p.head_dim_value * p.num_heads) +
// p.output_accum_ptr += int64_t(batch_id * p.num_queries) * (p.head_dim_value * p.num_heads) +
// int64_t(query_start) * (p.head_dim_value * p.num_heads) +
// head_id * p.head_dim_value;
p.output_accum_ptr += int64_t(q_start + query_start) * (p.head_dim_value * p.num_heads) +
head_id * p.head_dim_value;
} else {
// Accumulate directly in the destination buffer (eg for f32)
@@ -76,11 +109,11 @@
}

// Custom masking
if (p.causal_diagonal_ptr) {
p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
}
// if (p.causal_diagonal_ptr) {
// p.causal_diagonal_offset = p.causal_diagonal_ptr[batch_id];
// }
if (p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
p.causal_diagonal_offset += p.num_keys - p.num_queries;
p.causal_diagonal_offset = p.num_keys - p.num_queries;
}
if (p.custom_mask_type == AttentionKernel::CausalFromTopLeft ||
p.custom_mask_type == AttentionKernel::CausalFromBottomRight) {
@@ -126,8 +159,8 @@
p.num_queries = warp_uniform(p.num_queries);
p.num_keys = warp_uniform(p.num_keys);
p.num_heads = warp_uniform(p.num_heads);
p.head_dim = warp_uniform(p.head_dim);
p.head_dim_value = warp_uniform(p.head_dim_value);
// p.head_dim = warp_uniform(p.head_dim);
// p.head_dim_value = warp_uniform(p.head_dim_value);
p.o_strideM = warp_uniform(p.o_strideM);
p.custom_mask_type = warp_uniform(p.custom_mask_type);
return true;
@@ -142,10 +175,14 @@
}
AK::attention_kernel(p);
}
#endif

template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
#if USE_MEMORY_EFFICIENT_TO_BATCH_HOOK == 0
template <typename T, typename Attention, int queries_per_block>
#else
template <typename T, typename Attention>
#endif
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
typename Attention::Params p;
{ // set parameters
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
@@ -220,9 +257,12 @@
}

auto kernel_fn = attention_kernel_batched_impl<Attention>;

#if USE_MEMORY_EFFICIENT_TO_BATCH_HOOK == 0
if (params.has_custom_right_padding) {
kernel_fn = attention_kernel_batched_impl_right_padding<Attention, queries_per_block>;
}
#endif

int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
@@ -237,20 +277,45 @@
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
}
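
An aside on the smem_bytes check above: 0xc000 is 49152 bytes (48 KB), the default per-block limit for dynamic shared memory, and kernels that need more must opt in before launch; the collapsed lines presumably do that here. A self-contained sketch of the standard opt-in pattern, using a hypothetical kernel rather than this file's attention kernel:

#include <cuda_runtime.h>

// Hypothetical stand-in for attention_kernel_batched_impl<Attention>.
__global__ void dummy_kernel() {
  extern __shared__ char smem[];
  (void)smem;
}

bool LaunchWithLargeSmem(int smem_bytes, cudaStream_t stream) {
  if (smem_bytes > 0xc000) {  // more than 48 KB of dynamic shared memory
    // Raise the per-kernel limit; fails if the device cannot provide it.
    if (cudaFuncSetAttribute(dummy_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                             smem_bytes) != cudaSuccess) {
      return false;
    }
  }
  dummy_kernel<<<1, 128, smem_bytes, stream>>>();
  return cudaGetLastError() == cudaSuccess;
}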

template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, int max_head_size>
void RunCutlassFmha(const MemoryEfficientAttentionParams& params) {
constexpr bool kSupportsDropout = false;
constexpr bool kSupportsBias = true;

#if USE_MEMORY_EFFICIENT_TO_BATCH_HOOK
if (params.has_custom_right_padding) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, max_head_size,
kSupportsDropout, kSupportsBias, GQAToBatchHook>;
LaunchCutlassFmha<T, Attention>(params);
} else {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, max_head_size,
kSupportsDropout, kSupportsBias, DefaultToBatchHook>;
LaunchCutlassFmha<T, Attention>(params);
}
#else
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, max_head_size,
kSupportsDropout, kSupportsBias, DefaultToBatchHook>;
LaunchCutlassFmha<T, Attention, queries_per_block>(params);
#endif
}

template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, int max_head_size>
void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, max_head_size>;
#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(push)
#pragma warning(disable : 6287 4189) // kAligned is used via capture so 4189 warning seems incorrect
#endif

// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;

DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
RunCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, max_head_size>(params);
}));

#if defined(_MSC_VER) && !defined(__clang__)
#pragma warning(pop)
#endif
@@ -259,11 +324,11 @@
template <typename T, typename ArchTag>
void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
if (params.v_head_size <= 64) {
DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
DispatchIsAligned<T, ArchTag, 64, 64, 64>(params);
} else if (params.v_head_size <= 128) {
DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
DispatchIsAligned<T, ArchTag, 32, 128, 128>(params);
} else {
DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
DispatchIsAligned<T, ArchTag, 32, 128, 65536>(params);
}
}
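
The batch-advance arithmetic shared by GQAToBatchHook and the updated RightPaddingBatchHook above is easier to follow with concrete numbers. A minimal host-side sketch (hypothetical dimensions and plain variables, not the kernel's Params type) of how q_start, k_start, and max_sequence_length are derived for a BSNH key/value layout:

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical shapes: 2 batches, 4 heads, head_dim 64, max KV length 128, 16 queries.
  const int64_t num_heads = 4, head_dim = 64, max_seq = 128, num_queries = 16;

  // BSNH layout: one row per sequence position, all heads packed into the row.
  const int64_t k_strideH = head_dim;
  const int64_t k_strideM = num_heads * head_dim;
  const int64_t v_strideM = num_heads * head_dim;
  const int64_t v_strideB = max_seq * v_strideM;  // one whole batch of V

  // Same derivations as the hooks: strideB / strideM recovers the buffer's sequence length.
  const int64_t max_sequence_length = v_strideB / v_strideM;  // == 128
  const bool is_kv_bsnh = (k_strideH == head_dim && k_strideM == num_heads * head_dim);

  for (int64_t batch_id = 0; batch_id < 2; ++batch_id) {
    const int64_t q_start = batch_id * num_queries;  // rows into Q / output
    const int64_t k_start = batch_id * (is_kv_bsnh ? max_sequence_length
                                                   : num_heads * max_sequence_length);
    std::printf("batch %lld: q_start=%lld, k_start=%lld (in rows of strideM)\n",
                static_cast<long long>(batch_id), static_cast<long long>(q_start),
                static_cast<long long>(k_start));
  }
  return 0;
}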

@@ -11,7 +11,7 @@ steps:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
version: 1.0.156
version: 1.0.157
downloadPath: $(Build.BinariesDirectory)/deps

# The private ADO project
@@ -22,7 +22,7 @@
packageType: upack
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
version: 1.0.156
version: 1.0.157
downloadPath: $(Build.BinariesDirectory)/deps

# You can add more ADO accounts at here.