Add sparse attention kernel for H100 (sm90) (#20553)
### Description

Follow-up to #20216: add a sparse attention kernel compiled by Triton for H100 (sm90).

- [x] Refine sparse attention v1 kernel compilation (remove some combinations)
- [x] Compile v1 kernels
- [x] Compile kernels for H100
- [x] Run performance tests

### Performance

Test setting: `batch_size=4, num_heads=32, max_seq_len=8192, head_size=128, sparse_block_size=64, local_blocks=16, vert_stride=8, num_layout=8`

We compare sparse attention to the corresponding GQA with a local attention window of size 1024, and to GQA with dense causal attention. Note that ORT-GQA-Dense does more computation than ORT-SparseAtt, while ORT-GQA-Local does less computation (no vertical strides) than ORT-SparseAtt. They are included for reference: it is not a fair comparison, but it shows the benefit of sparsity versus dense attention.

Example results on an Azure Standard_ND96isr_H100_v5 VM with an NVIDIA H100-80GB-HBM3 GPU (sm=90):

```
prompt-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0             16.0   0.079877       0.006362       0.006403       0.042758
1             32.0   0.086920       0.016404       0.016686       0.044183
2             64.0   0.090727       0.020429       0.020409       0.045343
3            128.0   0.128148       0.032009       0.031984       0.051516
4            256.0   0.323933       0.074110       0.073920       0.068308
5            512.0   1.021856       0.162167       0.161951       0.109226
6           1024.0   3.596002       0.452629       0.452780       0.231653
7           2048.0  13.865088       1.499534       1.195749       0.515488
8           4096.0   0.000000       5.454785       2.669682       1.163233
9           8192.0   0.000000      22.068159       6.018604       2.772873

token-sm90-batch4-head32-d128-local16-vert8-torch.float16:
   past_sequence_length  TORCH-GQA  ORT-GQA-Dense  ORT-GQA-Local  ORT-SparseAtt
0                  16.0   0.104460       0.012652       0.012661       0.069549
1                  32.0   0.113866       0.012776       0.012765       0.069024
2                  64.0   0.124600       0.016791       0.012672       0.069397
3                 128.0   0.108658       0.017900       0.018294       0.074844
4                 256.0   0.115463       0.029409       0.029608       0.078911
5                 512.0   0.149824       0.033968       0.033701       0.092998
6                1024.0   0.234050       0.042930       0.042951       0.116920
7                2048.0   0.390695       0.061462       0.043008       0.121555
8                4096.0   0.000000       0.097505       0.042948       0.134757
9                8191.0   0.000000       0.165861       0.043542       0.158796
```

The following might help performance at short sequence lengths (it needs an operator spec update): fall back to flash attention when `total_sequence_length < local_blocks * block_size`.
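To make the layout parameters and the proposed fallback concrete, here is a minimal sketch, not part of this commit, of one common way to derive a causal block-sparse layout from `local_blocks` and `vert_stride` and to check the fallback condition. The helper names and the exact vertical-stride formula (including any per-head offset implied by `num_layout`) are assumptions; the kernel's real layout generation lives in the Python tooling and may differ in detail.

```cpp
#include <cstdint>
#include <vector>

// Sketch: a query block q may attend to key block k when k is causal (k <= q) and
// either within the local window of `local_blocks` blocks or on a periodic
// "vertical" column selected by `vert_stride`.
std::vector<std::vector<bool>> BuildBlockSparseLayout(int num_blocks, int local_blocks, int vert_stride) {
  std::vector<std::vector<bool>> allowed(num_blocks, std::vector<bool>(num_blocks, false));
  for (int q = 0; q < num_blocks; ++q) {
    for (int k = 0; k <= q; ++k) {  // causal: only current and past key blocks
      bool is_local = (q - k) < local_blocks;           // recent blocks in the local window
      bool is_vertical = ((k + 1) % vert_stride) == 0;  // periodic vertical key blocks
      allowed[q][k] = is_local || is_vertical;
    }
  }
  return allowed;
}

// Proposed short-sequence fallback from the description above: when the whole
// sequence fits inside the local window, the sparse layout is effectively dense,
// so flash attention can be used instead.
bool ShouldFallBackToFlashAttention(int64_t total_sequence_length, int local_blocks, int block_size) {
  return total_sequence_length < static_cast<int64_t>(local_blocks) * block_size;
}
```

With the test setting above (`local_blocks=16`, `sparse_block_size=64`), this fallback would apply to total sequence lengths below 1024, roughly the range where the dense and local GQA kernels outperform ORT-SparseAtt in the benchmarks.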
Showing 47 changed files with 832 additions and 1,307 deletions.
84 changes: 84 additions & 0 deletions in ...ntime/contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h
```cpp
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// This file is generated by compile_sparse_attention.py using triton AoT compiler

#pragma once
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_common.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {
namespace sparse_attention_v1 {

// launcher for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_eb17c351(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3(SparseAttentionParams& params) {
  return sparse_attention_bf16_sm90_eb17c351(params);
}

// load for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_eb17c351();
void load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
  load_sparse_attention_bf16_sm90_eb17c351();
}

// unload for: sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_eb17c351();
void unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3() {
  unload_sparse_attention_bf16_sm90_eb17c351();
}

// launcher for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
Status sparse_attention_bf16_sm90_d7dba852(SparseAttentionParams& params);

Status sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3(SparseAttentionParams& params) {
  return sparse_attention_bf16_sm90_d7dba852(params);
}

// load for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void load_sparse_attention_bf16_sm90_d7dba852();
void load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
  load_sparse_attention_bf16_sm90_d7dba852();
}

// unload for: sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3
void unload_sparse_attention_bf16_sm90_d7dba852();
void unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3() {
  unload_sparse_attention_bf16_sm90_d7dba852();
}

typedef Status (*kernel_func_t)(SparseAttentionParams& params);
kernel_func_t sparse_attention_bf16_sm90_kernels[] = {
    sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3,
    sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3,
};

int sparse_attention_bf16_sm90_get_num_algos(void) {
  return (int)sizeof(sparse_attention_bf16_sm90_kernels);
}

Status sparse_attention_bf16_sm90(SparseAttentionParams& params, int algo_id) {
  assert(algo_id < (int)sizeof(sparse_attention_bf16_sm90_kernels));
  return sparse_attention_bf16_sm90_kernels[algo_id](params);
}

void load_sparse_attention_bf16_sm90(void) {
  load_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
  load_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

void unload_sparse_attention_bf16_sm90(void) {
  unload_sparse_attention_bf16_sm90_64x0x64x0x64x2_warps4xstages3();
  unload_sparse_attention_bf16_sm90_64x1x64x1x64x2_warps4xstages3();
}

Status sparse_attention_bf16_sm90_default(SparseAttentionParams& params) {
  return sparse_attention_bf16_sm90(params, 0);
}

}  // namespace sparse_attention_v1
}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
```
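The generated entry points follow a load / dispatch / unload pattern. For orientation, here is a minimal usage sketch assuming a fully populated `SparseAttentionParams`; the function and parameter names outside the dispatcher are hypothetical, and the actual SparseAttention operator wires these calls up differently.

```cpp
#include "contrib_ops/cuda/sparse/sparse_attention_v1/sparse_attention_dispatcher_bf16_sm90.h"

namespace sa = onnxruntime::contrib::cuda::sparse_attention_v1;

// Sketch: run one bf16 sm90 sparse attention call through the generated dispatcher.
onnxruntime::Status RunBf16Sm90(sa::SparseAttentionParams& params) {
  // Load the cubins for all compiled tile configurations (typically done once at init).
  sa::load_sparse_attention_bf16_sm90();

  // Dispatch with the default configuration (algo_id 0); sparse_attention_bf16_sm90(params, algo_id)
  // selects a specific entry of sparse_attention_bf16_sm90_kernels instead.
  onnxruntime::Status status = sa::sparse_attention_bf16_sm90_default(params);

  // Unload the modules when the kernels are no longer needed (e.g., at teardown).
  sa::unload_sparse_attention_bf16_sm90();
  return status;
}
```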