Add sparse attention kernel for H100 (sm90) (#20553)
### Description
Follow-up of #20216 to add a sparse attention kernel compiled by Triton for H100 (sm90).
- [x] Refine sparse attention v1 kernel compilation (remove some parameter combinations)
- [x] Compile v1 kernels
- [x] Compile kernels for H100
- [x] Run performance tests

### Performance

Test settings: `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`.

We compare sparse attention to the corresponding GQA with a local attention window of size 1024, and to GQA with a dense causal mask. Note that ORT-GQA-Dense does more computation than ORT-SparseAtt, while ORT-GQA-Local does less computation (no vertical strides) than ORT-SparseAtt. They are included for reference: it is not an apples-to-apples comparison, but it shows the benefit of sparsity over dense attention.
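
For a rough sense of how much compute each variant does (an illustration, not code from this PR), the sketch below builds a causal block mask under a common "local + vertical stride" convention and counts active blocks for the settings above. The helper name and the exact layout convention are assumptions of this note; the kernel's real layout generation may differ in details.

```python
import numpy as np

def count_active_blocks(num_blocks: int, local_blocks: int | None = None, vert_stride: int = 0) -> int:
    """Count (query_block, key_block) pairs computed under an assumed local + vertical-stride layout."""
    q = np.arange(num_blocks)[:, None]  # query block index
    k = np.arange(num_blocks)[None, :]  # key block index
    mask = k <= q                       # dense causal baseline
    if local_blocks is not None:
        local = (q - k) < local_blocks  # most recent `local_blocks` block columns
        vertical = ((k + 1) % vert_stride == 0) if vert_stride else np.zeros_like(mask)
        mask = mask & (local | vertical)
    return int(mask.sum())

num_blocks = 8192 // 64  # max_seq_len / sparse_block_size = 128 block rows/columns

dense = count_active_blocks(num_blocks)                                   # ORT-GQA-Dense (causal)
local = count_active_blocks(num_blocks, local_blocks=1024 // 64)          # ORT-GQA-Local (1024-token window)
sparse = count_active_blocks(num_blocks, local_blocks=16, vert_stride=8)  # ORT-SparseAtt test setting
print(dense, local, sparse)  # 8256 1928 2670 -> dense > sparse > local, matching the ordering above
```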

Example results on an Azure Standard_ND96isr_H100_v5 VM with an NVIDIA H100-80GB-HBM3 GPU (sm=90):
```
    prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0             16.0   0.079877       0.006362       0.006403       0.042758
    1             32.0   0.086920       0.016404       0.016686       0.044183
    2             64.0   0.090727       0.020429       0.020409       0.045343
    3            128.0   0.128148       0.032009       0.031984       0.051516
    4            256.0   0.323933       0.074110       0.073920       0.068308
    5            512.0   1.021856       0.162167       0.161951       0.109226
    6           1024.0   3.596002       0.452629       0.452780       0.231653
    7           2048.0  13.865088       1.499534       1.195749       0.515488
    8           4096.0   0.000000       5.454785       2.669682       1.163233
    9           8192.0   0.000000      22.068159       6.018604       2.772873

    token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
       past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
    0                  16.0   0.104460       0.012652       0.012661       0.069549
    1                  32.0   0.113866       0.012776       0.012765       0.069024
    2                  64.0   0.124600       0.016791       0.012672       0.069397
    3                 128.0   0.108658       0.017900       0.018294       0.074844
    4                 256.0   0.115463       0.029409       0.029608       0.078911
    5                 512.0   0.149824       0.033968       0.033701       0.092998
    6                1024.0   0.234050       0.042930       0.042951       0.116920
    7                2048.0   0.390695       0.061462       0.043008       0.121555
    8                4096.0   0.000000       0.097505       0.042948       0.134757
    9                8191.0   0.000000       0.165861       0.043542       0.158796
```
The following could further improve performance at short sequence lengths, but it requires an operator spec update: fall back to flash attention when `total_sequence_length < local_blocks * block_size`.
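
A minimal sketch of that proposed fallback condition, with a hypothetical helper name (this is not part of the change):

```python
def should_use_flash_attention(total_sequence_length: int,
                               local_blocks: int = 16,
                               block_size: int = 64) -> bool:
    # When the whole sequence fits inside the local window, the sparse mask is effectively
    # dense causal, so a plain (flash) attention kernel could be used instead.
    return total_sequence_length < local_blocks * block_size

print(should_use_flash_attention(512))   # True  -> 512 < 16 * 64 = 1024
print(should_use_flash_attention(4096))  # False -> sparse kernel path
```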

### Motivation and Context
tianleiwu authored May 5, 2024
1 parent cb37b1b commit baaef59
Showing 47 changed files with 832 additions and 1,307 deletions.
10 changes: 5 additions & 5 deletions onnxruntime/contrib_ops/cuda/sparse/sparse_attention.cc
@@ -106,14 +106,15 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
                                   total_seq_len));
   // Some limitations of CUDA kernels
   // The v1 and v2 kernels have same coverage, so only check one of them to see whether it is supported.
-  if (!sparse_attention_v1::is_supported_device(device_prop)) {
+  int sm = device_prop.major * 10 + device_prop.minor;
+  if (!sparse_attention_v1::is_supported_device(sm)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                           "SparseAttention only support CUDA device with compute capacity 8.*. Got ",
-                           device_prop.major);
+                           "SparseAttention only supports CUDA device with compute capacity 7.5, 8.0, 8.6, 8.9 and 9.0. Got sm=",
+                           sm);
   }
   if (!sparse_attention_v1::is_supported_sparse_attention(parameters.head_size, sparse_block_size_)) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
-                           "SparseAttention only support head_size=128 and sparse_block_size=64. Got head_size=",
+                           "SparseAttention only supports head_size=128 and sparse_block_size=64. Got head_size=",
                            parameters.head_size,
                            ",sparse_block_size=",
                            sparse_block_size_);
@@ -149,7 +150,6 @@ Status SparseAttention<T>::ComputeInternal(OpKernelContext* context) const {
   }

   if (!kernel_loaded_) {
-    int sm = device_prop.major * 10 + device_prop.minor;
     if constexpr (std::is_same<T, MLFloat16>::value) {
       // std::call_once is used in load_sparse_attention_fp16 so no need to use mutex here.
       // After kernel is loaded, it will stay in memory until the process exits. We do not unload explicitly.
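
The check above encodes compute capability as `sm = major * 10 + minor` (so 9.0 on H100 becomes 90). Below is a tiny, hypothetical Python mirror of the condition implied by the new error message; the real `sparse_attention_v1::is_supported_device` is a C++ helper elsewhere in the codebase and may be implemented differently.

```python
def is_supported_device(major: int, minor: int) -> bool:
    sm = major * 10 + minor  # same encoding as device_prop.major * 10 + device_prop.minor
    return sm in (75, 80, 86, 89, 90)  # 7.5, 8.0, 8.6, 8.9, 9.0 per the new error message

print(is_supported_device(9, 0))  # True  -> H100 (sm90)
print(is_supported_device(7, 0))  # False -> V100 (sm70)
```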
@@ -5,7 +5,7 @@

 # Use triton AoT compiler to convert sparse_attention_triton.py to C source files including cubin and dispatcher.
 # Example to use this script (Tested with Python 3.10 and CUDA 12.3 in Ubuntu 20.04):
-# python3 -m pip install torch==2.3.0 triton==2.3.0
+# python3 -m pip install numpy==1.26.4 torch==2.3.0 triton==2.3.0
 # python3 compile_sparse_attention.py | sh
 #
 # Note that sparse_attention_v1_*.cc and sparse_attention_dispatcher_*.h are the generated files.
@@ -35,31 +35,35 @@ def generate_triton_compile_shell_script(sm, dtype="fp16"):
     print(f"rm -rf {out_dir}")
     print(f"mkdir -p {out_dir}")

     # Note that block_n * num_block_d is the head_size. We support head_size = 128 for now.
     block_n_values = [64]
     block_d_values = [64]
     num_block_d_values = [2]
-    even_m_values = [True, False]
     even_n_values = [True, False]

     # Use triton compiler to compile the kernel of different combinations of constant parameters.
-    for block_n, block_d, num_blocks_d, even_m, even_n in product(
-        block_n_values, block_d_values, num_block_d_values, even_m_values, even_n_values
+    for block_n, block_d, num_blocks_d, even_n in product(
+        block_n_values, block_d_values, num_block_d_values, even_n_values
     ):
-        block_m_values = [16, block_n] if block_n != 16 else [block_n]
-        for block_m in block_m_values:
-            scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
-            sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
-            prefix = "python compile.py sparse_attention_triton.py"
-            filename = f"sparse_attention_v1_{dtype}_m{block_m}_{int(even_m)}_n{block_n}_{int(even_n)}_d{block_d}_{num_blocks_d}_sm{sm}"
-            name = f"sparse_attention_{dtype}_sm{sm}"
-            num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
-            num_stages = 2
-            # TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
-            print(
-                f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
-                f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
-            )
+        head_size = block_d * num_blocks_d
+        block_m = block_n
+        even_m = even_n
+        scalar_params = "i32,i32,i32,fp32,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32:16,i32,i32,i32"
+        sig = f"*{dtype}:16,*{dtype}:16,*{dtype}:16,*{dtype}:16,*i32:16,*i32:16,{scalar_params},{block_m},{int(even_m)},{block_n},{int(even_n)},{block_d},{num_blocks_d}"
+        prefix = "python compile.py sparse_attention_triton.py"
+        filename = f"sparse_attention_v1_{dtype}_d{head_size}_n{block_n}_e{int(even_n)}_sm{sm}"
+        name = f"sparse_attention_{dtype}_sm{sm}"
+        num_warps = max(1, 2 ** int(math.log2(min(block_m, block_n, block_d) / 16)))
+
+        # Shared memory is 96KB for V100 (sm70), 64KB for T4 (sm75), 164KB for A100 (sm80), 228KB for H100 (sm90).
+        # Adjust stages so that shared memory size is within limit, and choose the one with best performance.
+        sm_to_stages = {90: 3, 80: 2, 75: 2}
+
+        num_stages = sm_to_stages[sm]
+
+        # TODO: use different kernel name (change the name in sparse_attention_triton.py before running compile.py)
+        print(
+            f"{prefix} -n block_sparse_attention_kernel -o {out_dir}/{filename} --out-name {name} "
+            f'-w {num_warps} -ns {num_stages} -s "{sig}" -g "(total_seq_len - past_seq_len + {block_m} - 1) / {block_m}, batch_size * num_heads, 1"'
        )

     # Generate the dispatcher.
     dispatcher = f"sparse_attention_dispatcher_{dtype}_sm{sm}"
@@ -11,82 +11,6 @@ namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {

// launcher for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_ba65ff9c(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_ba65ff9c(params);
}

// load for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_ba65ff9c();
void load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_ba65ff9c();
}

// unload for: sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_ba65ff9c();
void unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_ba65ff9c();
}

// launcher for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_f951a16d(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_f951a16d(params);
}

// load for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_f951a16d();
void load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_f951a16d();
}

// unload for: sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_f951a16d();
void unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_f951a16d();
}

// launcher for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_646fefc8(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_646fefc8(params);
}

// load for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_646fefc8();
void load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_646fefc8();
}

// unload for: sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_646fefc8();
void unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_646fefc8();
}

// launcher for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
Status sparse_attention_bf16_sm80_21cac990(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_21cac990(params);
}

// load for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void load_sparse_attention_bf16_sm80_21cac990();
void load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
load_sparse_attention_bf16_sm80_21cac990();
}

// unload for: sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2
void unload_sparse_attention_bf16_sm80_21cac990();
void unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2() {
unload_sparse_attention_bf16_sm80_21cac990();
}

// launcher for: sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_31acb592(SparseAttentionParams& params);

@@ -106,44 +30,6 @@ void unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_31acb592();
}

// launcher for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_d55ab166(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_d55ab166(params);
}

// load for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_d55ab166();
void load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_d55ab166();
}

// unload for: sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_d55ab166();
void unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_d55ab166();
}

// launcher for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_b0560d11(SparseAttentionParams& params);

Status sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2(SparseAttentionParams& params) {
return sparse_attention_bf16_sm80_b0560d11(params);
}

// load for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void load_sparse_attention_bf16_sm80_b0560d11();
void load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
load_sparse_attention_bf16_sm80_b0560d11();
}

// unload for: sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2
void unload_sparse_attention_bf16_sm80_b0560d11();
void unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2() {
unload_sparse_attention_bf16_sm80_b0560d11();
}

// launcher for: sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2
Status sparse_attention_bf16_sm80_c777f3f5(SparseAttentionParams& params);

@@ -165,13 +51,7 @@ void unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2() {

typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm80_kernels[] = {
sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2,
sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2,
sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2,
sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2,
};

@@ -185,24 +65,12 @@ Status sparse_attention_bf16_sm80(SparseAttentionParams& params, int algo_id) {
}

void load_sparse_attention_bf16_sm80(void) {
load_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
load_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
load_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}

void unload_sparse_attention_bf16_sm80(void) {
unload_sparse_attention_bf16_sm80_16x0x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x0x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x0x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_16x1x64x1x64x2_warps1xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x0x64x1x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x0x64x2_warps4xstages2();
unload_sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2();
}
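
For readers scanning the remaining kernels: the generated names appear to pack the compile-time constants in the order block_m x even_m x block_n x even_n x block_d x num_blocks_d, followed by the warp and stage counts, matching the signature order in compile_sparse_attention.py. A small illustrative parser; the regex and field names are assumptions of this note, not anything shipped in this PR:

```python
import re

def decode_kernel_name(name: str) -> dict:
    """Decode names like 'sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2'."""
    m = re.match(
        r"sparse_attention_(?P<dtype>\w+?)_sm(?P<sm>\d+)_"
        r"(?P<block_m>\d+)x(?P<even_m>\d)x(?P<block_n>\d+)x(?P<even_n>\d)x(?P<block_d>\d+)x(?P<num_blocks_d>\d+)_"
        r"warps(?P<warps>\d+)xstages(?P<stages>\d+)",
        name,
    )
    return {k: (v if k == "dtype" else int(v)) for k, v in m.groupdict().items()}

print(decode_kernel_name("sparse_attention_bf16_sm80_64x1x64x1x64x2_warps4xstages2"))
# {'dtype': 'bf16', 'sm': 80, 'block_m': 64, 'even_m': 1, 'block_n': 64, 'even_n': 1,
#  'block_d': 64, 'num_blocks_d': 2, 'warps': 4, 'stages': 2}
```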

@@ -0,0 +1,84 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file is generated by compile_sparse_attention.py using triton AoT compiler

#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {

// launcher for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_eb17c351(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90_eb17c351(params);
}

// load for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_eb17c351();
void load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
load_sparse_attention_bf16_sm90_eb17c351();
}

// unload for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_eb17c351();
void unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
unload_sparse_attention_bf16_sm90_eb17c351();
}

// launcher for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_d7dba852(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90_d7dba852(params);
}

// load for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_d7dba852();
void load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
load_sparse_attention_bf16_sm90_d7dba852();
}

// unload for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_d7dba852();
void unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
unload_sparse_attention_bf16_sm90_d7dba852();
}

typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm90_kernels[] = {
sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3,
sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3,
};

int sparse_attention_bf16_sm90_get_num_algos(void) {
return (int)sizeof(sparse_attention_bf16_sm90_kernels);
}

Status sparse_attention_bf16_sm90(SparseAttentionParams& params, int algo_id) {
assert(algo_id < (int)sizeof(sparse_attention_bf16_sm90_kernels));
return sparse_attention_bf16_sm90_kernels[algo_id](params);
}

void load_sparse_attention_bf16_sm90(void) {
load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

void unload_sparse_attention_bf16_sm90(void) {
unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

Status sparse_attention_bf16_sm90_default(SparseAttentionParams& params) {
return sparse_attention_bf16_sm90(params, 0);
}

} // namespace sparse_attention_v1
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime