[CUDA] Sparse Attention support 128k sequence length (#20614)
### Description
When the sequence length is 128K, block_mask has 2048 rows, which is not
supported by the previous kernel.
(1) Add a new kernel to handle more than 1024 rows, where each thread
handles two rows.
(2) Add a test for sequence length 128K.
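
For context, a rough sketch of the sizes involved (assuming the usual CUDA limit of 1024 threads per block; the names are illustrative, mirroring the dispatch logic in the diff below):

```python
# Why one thread per row no longer fits at 128K sequence length.
sequence_length = 131072                 # 128K tokens
sparse_block_size = 64                   # one block_mask row per 64-token block
num_rows = sequence_length // sparse_block_size            # 2048 rows
max_threads_per_block = 1024             # typical CUDA limit on threads per block
rows_per_thread = (num_rows + max_threads_per_block - 1) // max_threads_per_block  # ceiling -> 2
print(num_rows, rows_per_thread)         # 2048 2
```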
tianleiwu authored May 9, 2024
1 parent a0db218 commit 69cfcba
Showing 2 changed files with 133 additions and 23 deletions.
90 changes: 73 additions & 17 deletions onnxruntime/contrib_ops/cuda/sparse/block_mask.cu
@@ -27,27 +27,28 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
}
}

- extern __shared__ int non_zero_counts[];
- non_zero_counts[threadIdx.x] = count;
+ extern __shared__ int shared_row_indices[];
+ shared_row_indices[row + 1] = count;
__syncthreads();

// The first thread will calculate the accumulated partial sum of non-zero counts.
+ // The result is csr_row_indices stored in shared memory.
if (row == 0) {
+ shared_row_indices[0] = 0;
for (int i = 1; i < num_rows; i++) {
- non_zero_counts[i] += non_zero_counts[i - 1];
+ shared_row_indices[i + 1] += shared_row_indices[i];
}
+
+ // The first thread outputs the last element.
+ csr_row_indices[num_rows] = shared_row_indices[num_rows];
}
__syncthreads();

- // The starting index of current row in csr_col_indices
- int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
+ // The starting index of current row in csr_col_indices
+ int offset = shared_row_indices[row];

// Output row indices.
csr_row_indices[row] = offset;
- if (row == 0) {
- // The first thread output the last element.
- csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
- }

for (int col = 0; col < num_cols; col++) {
if (mask[row * num_cols + col] == 1) {
@@ -60,6 +61,59 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
// The last element of csr_row_indices is the total number of non-zero elements.
}

__global__ void MaskToCSR_Large(const int* mask,
int* csr_row_indices,
int* csr_col_indices,
int num_rows,
int num_cols,
int rows_per_thread // Each thread handles multiple rows
) {
extern __shared__ int shared_row_indices[];

// Update input and output data pointers to the start of current head
int head = blockIdx.x;
mask += head * num_rows * num_cols;
csr_row_indices += head * (num_rows + 1);
csr_col_indices += head * num_rows * num_cols;

int tid = threadIdx.x;
for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
int count = 0;
for (int col = 0; col < num_cols; col++) {
if (mask[row * num_cols + col] == 1) {
count++;
}
}
shared_row_indices[row + 1] = count;
}

__syncthreads();

// The first thread will calculate the accumulated partial sum of non-zero counts.
if (tid == 0) {
shared_row_indices[0] = 0;
for (int i = 1; i < num_rows; i++) {
shared_row_indices[i + 1] += shared_row_indices[i];
}

csr_row_indices[num_rows] = shared_row_indices[num_rows];
}

__syncthreads();

for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
int offset = shared_row_indices[row];
csr_row_indices[row] = offset;

for (int col = 0; col < num_cols; col++) {
if (mask[row * num_cols + col] == 1) {
csr_col_indices[offset] = col;
offset++;
}
}
}
}

void ConvertMaskToCSR(cudaStream_t stream,
const int* mask, // input mask with shape (num_layout, num_rows, num_cols)
int num_layout, // number of layouts
@@ -68,15 +122,17 @@ void ConvertMaskToCSR(cudaStream_t stream,
int* csr_row_indices, // output CSR row indices
int* csr_col_indices, // output CSR column indices
int max_threads_per_block) {
- int threads_per_block = (num_rows + 31) / 32 * 32;
-
- // Each thread handle one row. The kernel assumes that all rows of one head can be handled in one block.
- if (threads_per_block > max_threads_per_block) {
- ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
+ if (num_rows <= max_threads_per_block) {
+ // Each thread handle one row.
+ MaskToCSR<<<num_layout, num_rows, (num_rows + 1) * sizeof(int), stream>>>(
+ mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
+ } else {
+ // Each thread will handle multiple rows when number of rows > max_threads_per_block.
+ // For example 128K length with sparse block size 64 will have 2048 rows. Each thread will handle 2 rows.
+ int rows_per_thread = (num_rows + max_threads_per_block - 1) / max_threads_per_block;
+ MaskToCSR_Large<<<num_layout, max_threads_per_block, (num_rows + 1) * sizeof(int), stream>>>(
+ mask, csr_row_indices, csr_col_indices, num_rows, num_cols, rows_per_thread);
}
-
- MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
- mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
}

} // namespace cuda
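
For reference, a minimal NumPy sketch (illustration only, not part of this commit) of the per-head conversion that MaskToCSR and MaskToCSR_Large perform: an exclusive prefix sum of the per-row non-zero counts yields the CSR row offsets, and the positions of the ones yield the column indices.

```python
import numpy as np

def mask_to_csr(mask: np.ndarray):
    """Convert a 0/1 block mask of shape (num_rows, num_cols) to CSR indices."""
    num_rows, _ = mask.shape
    counts = (mask == 1).sum(axis=1)                 # non-zero count per row
    csr_row_indices = np.zeros(num_rows + 1, dtype=np.int32)
    csr_row_indices[1:] = np.cumsum(counts)          # exclusive prefix sum; last entry = total nnz
    csr_col_indices = np.concatenate(
        [np.nonzero(row)[0] for row in mask]).astype(np.int32)
    return csr_row_indices, csr_col_indices

mask = np.array([[1, 0, 1],
                 [0, 0, 0],
                 [1, 1, 1]], dtype=np.int32)
rows, cols = mask_to_csr(mask)   # rows -> [0 2 2 5], cols -> [0 2 0 1 2]
```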
66 changes: 60 additions & 6 deletions onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -4,9 +4,7 @@
# --------------------------------------------------------------------------

"""
- Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 8.x.
- Install required packages before running this script:
- pip install matplotlib pandas onnx torch onnxruntime-gpu
+ Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 7.5 or above.
"""
import math
import unittest
@@ -726,13 +724,13 @@ def test_sparse_attention(self):
self.run_relevance_test(sm)

def run_one_relevance_test(self, config: SparseAttentionConfig):
- if not config.do_rotary:
- # Run QGA by Torch
+ if (not config.do_rotary) and config.total_sequence_length <= 2048:
+ # Run QGA by Torch (support mask, but not packed QKV, rotary and very long sequence)
gqa_config: GroupQueryAttentionConfig = config.get_comparable_torch_gqa_config(use_sparse=True)
obj = TorchGroupQueryAttention(gqa_config)
expected_out = obj.infer()
else:
- # Run QGA by ORT
+ # Run QGA by ORT (support packed QKV, rotary and very long sequence, but no mask so dense only).
gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
obj = OrtGroupQueryAttention(gqa_config)
ort_qga_outputs = obj.infer()
@@ -820,10 +818,66 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool):
config.dtype = torch.bfloat16
self.run_one_relevance_test(config)

def run_relevance_no_past_128k(self, sm: int, device):
"""Test kernel could support up to 128K sequence length."""
for seq_len in [131072]:
for packed_qkv in [False, True]:
config = SparseAttentionConfig(
batch_size=1,
sequence_length=seq_len,
max_sequence_length=131072,
past_sequence_length=0,
num_heads=1,
kv_num_heads=1,
head_size=128,
sparse_block_size=64,
num_layout=1,
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
device=device,
is_packed_qkv=packed_qkv,
)
self.run_one_relevance_test(config)

if sm >= 80 and not packed_qkv:
config.dtype = torch.bfloat16
self.run_one_relevance_test(config)

def run_relevance_past_128k(self, sm: int, device):
"""Test kernel could support up to 128K sequence length."""
for past_seq_len in [131071]:
for packed_qkv in [False, True]:
config = SparseAttentionConfig(
batch_size=1,
sequence_length=1,
max_sequence_length=131072,
past_sequence_length=past_seq_len,
num_heads=1,
kv_num_heads=1,
head_size=128,
sparse_block_size=64,
num_layout=1,
local_blocks=2048, # use dense to compare with GQA
vert_stride=8,
softmax_scale=None,
device=device,
is_packed_qkv=packed_qkv,
)
self.run_one_relevance_test(config)

if sm >= 80 and not packed_qkv:
config.dtype = torch.bfloat16
self.run_one_relevance_test(config)

def run_relevance_test(self, sm: int):
device_id = torch.cuda.current_device()
device = torch.device("cuda", device_id)
with torch.no_grad():
# Test long sequence when GPU memory is enough (need about 12 GB for 128K sequence length)
if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024:
self.run_relevance_no_past_128k(sm, device)
self.run_relevance_past_128k(sm, device)
self.run_relevance_no_past(sm, device)
self.run_relevance_past(sm, device, do_rotary=False)
self.run_relevance_past(sm, device, do_rotary=True)
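
A note on the 128K test configurations above: with max_sequence_length 131072 and sparse_block_size 64 the layout has 2048 blocks per row, so local_blocks=2048 covers every block and the pattern is effectively dense, which is why the (dense) GQA path can serve as the reference. A quick check of that arithmetic (illustrative only):

```python
max_sequence_length = 131072
sparse_block_size = 64
num_blocks = max_sequence_length // sparse_block_size   # 2048 blocks per row of the layout
local_blocks = 2048                                      # value used in the 128K tests
assert local_blocks >= num_blocks   # local window spans all blocks -> effectively dense attention
```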
