diff --git a/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu b/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu
index 38e89949ac088..1e6461a145144 100644
--- a/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu
+++ b/onnxruntime/contrib_ops/cuda/sparse/block_mask.cu
@@ -27,27 +27,28 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
     }
   }
 
-  extern __shared__ int non_zero_counts[];
-  non_zero_counts[threadIdx.x] = count;
+  extern __shared__ int shared_row_indices[];
+  shared_row_indices[row + 1] = count;
   __syncthreads();
 
   // The first thread will calculate the accumulated partial sum of non-zero counts.
+  // The result is csr_row_indices, stored in shared memory.
   if (row == 0) {
+    shared_row_indices[0] = 0;
     for (int i = 1; i < num_rows; i++) {
-      non_zero_counts[i] += non_zero_counts[i - 1];
+      shared_row_indices[i + 1] += shared_row_indices[i];
     }
+
+    // The first thread outputs the last element.
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
   }
 
   __syncthreads();
 
-  // The starting index of current row in csr_col_indices
-  int offset = (row == 0) ? 0 : non_zero_counts[row - 1];
+  // The starting index of the current row in csr_col_indices
+  int offset = shared_row_indices[row];
   // Output row indices.
   csr_row_indices[row] = offset;
-  if (row == 0) {
-    // The first thread output the last element.
-    csr_row_indices[num_rows] = non_zero_counts[num_rows - 1];
-  }
 
   for (int col = 0; col < num_cols; col++) {
     if (mask[row * num_cols + col] == 1) {
@@ -60,6 +61,59 @@ __global__ void MaskToCSR(const int* mask, int* csr_row_indices, int* csr_col_in
   // The last element of csr_row_indices is the total number of non-zero elements.
 }
 
+__global__ void MaskToCSR_Large(const int* mask,
+                                int* csr_row_indices,
+                                int* csr_col_indices,
+                                int num_rows,
+                                int num_cols,
+                                int rows_per_thread  // Each thread handles multiple rows
+) {
+  extern __shared__ int shared_row_indices[];
+
+  // Update input and output data pointers to the start of the current head.
+  int head = blockIdx.x;
+  mask += head * num_rows * num_cols;
+  csr_row_indices += head * (num_rows + 1);
+  csr_col_indices += head * num_rows * num_cols;
+
+  int tid = threadIdx.x;
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int count = 0;
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        count++;
+      }
+    }
+    shared_row_indices[row + 1] = count;
+  }
+
+  __syncthreads();
+
+  // The first thread will calculate the accumulated partial sum of non-zero counts.
+  if (tid == 0) {
+    shared_row_indices[0] = 0;
+    for (int i = 1; i < num_rows; i++) {
+      shared_row_indices[i + 1] += shared_row_indices[i];
+    }
+
+    csr_row_indices[num_rows] = shared_row_indices[num_rows];
+  }
+
+  __syncthreads();
+
+  for (int row = tid * rows_per_thread; row < num_rows && row < (tid + 1) * rows_per_thread; row++) {
+    int offset = shared_row_indices[row];
+    csr_row_indices[row] = offset;
+
+    for (int col = 0; col < num_cols; col++) {
+      if (mask[row * num_cols + col] == 1) {
+        csr_col_indices[offset] = col;
+        offset++;
+      }
+    }
+  }
+}
+
 void ConvertMaskToCSR(cudaStream_t stream,
                       const int* mask,       // input mask with shape (num_layout, num_rows, num_cols)
                       int num_layout,        // number of layouts
@@ -68,15 +122,17 @@ void ConvertMaskToCSR(cudaStream_t stream,
                       int* csr_row_indices,  // output CSR row indices
                       int* csr_col_indices,  // output CSR column indices
                       int max_threads_per_block) {
-  int threads_per_block = (num_rows + 31) / 32 * 32;
-
-  // Each thread handle one row. The kernel assumes that all rows of one head can be handled in one block.
-  if (threads_per_block > max_threads_per_block) {
-    ORT_THROW("num_rows is too large: num_rows=", num_rows, ", max_threads_per_block=", max_threads_per_block);
+  if (num_rows <= max_threads_per_block) {
+    // Each thread handles one row.
+    MaskToCSR<<<num_layout, num_rows, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
+  } else {
+    // Each thread handles multiple rows when the number of rows exceeds max_threads_per_block.
+    // For example, a 128K sequence with sparse block size 64 has 2048 rows, so each thread handles 2 rows.
+    int rows_per_thread = (num_rows + max_threads_per_block - 1) / max_threads_per_block;
+    MaskToCSR_Large<<<num_layout, max_threads_per_block, (num_rows + 1) * sizeof(int), stream>>>(
+        mask, csr_row_indices, csr_col_indices, num_rows, num_cols, rows_per_thread);
   }
-
-  MaskToCSR<<<num_layout, threads_per_block, threads_per_block * sizeof(int), stream>>>(
-      mask, csr_row_indices, csr_col_indices, num_rows, num_cols);
 }
 
 }  // namespace cuda
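Note: the per-head CSR layout that MaskToCSR and MaskToCSR_Large produce can be mirrored on the host for testing. The sketch below is a hypothetical NumPy reference (mask_to_csr_reference is not part of this change) for a single layout/head: row offsets are the prefix sum of per-row non-zero counts, and the last offset equals the total number of non-zero blocks.

import numpy as np

def mask_to_csr_reference(mask):
    # mask: (num_rows, num_cols) array of 0/1 for one layout/head.
    # Returns row offsets (num_rows + 1 entries) and the column indices of the
    # non-zero entries, matching what the kernels write for one head; the kernels
    # leave the unused tail of the csr_col_indices buffer untouched.
    num_rows, _ = mask.shape
    counts = (mask == 1).sum(axis=1)  # non-zero count per row
    row_indices = np.zeros(num_rows + 1, dtype=np.int32)
    row_indices[1:] = np.cumsum(counts)  # prefix sum of counts -> row offsets
    col_indices = np.concatenate([np.nonzero(mask[r])[0] for r in range(num_rows)]).astype(np.int32)
    return row_indices, col_indices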
diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py
index 58c48aa1338e1..9fcc23288ef4a 100644
--- a/onnxruntime/test/python/transformers/test_sparse_attention.py
+++ b/onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -4,9 +4,7 @@
 # --------------------------------------------------------------------------
 """
-Parity test and benchmark performance of SparseAttention. Requires Nvidia GPU of Compute Capability 8.x.
-Install required packages before running this script:
-    pip install matplotlib pandas onnx torch onnxruntime-gpu
+Parity test and benchmark performance of SparseAttention. Requires an Nvidia GPU of Compute Capability 7.5 or above.
 """
 
 import math
 import unittest
@@ -726,13 +724,13 @@ def test_sparse_attention(self):
         self.run_relevance_test(sm)
 
     def run_one_relevance_test(self, config: SparseAttentionConfig):
-        if not config.do_rotary:
-            # Run QGA by Torch
+        if (not config.do_rotary) and config.total_sequence_length <= 2048:
+            # Run GQA by Torch (supports mask, but not packed QKV, rotary, or very long sequences).
             gqa_config: GroupQueryAttentionConfig = config.get_comparable_torch_gqa_config(use_sparse=True)
             obj = TorchGroupQueryAttention(gqa_config)
             expected_out = obj.infer()
         else:
-            # Run QGA by ORT
+            # Run GQA by ORT (supports packed QKV, rotary, and very long sequences, but no mask, so dense only).
             gqa_config: GroupQueryAttentionConfig = config.get_comparable_ort_gqa_config(use_local=False)
             obj = OrtGroupQueryAttention(gqa_config)
             ort_qga_outputs = obj.infer()
@@ -820,10 +818,66 @@ def run_relevance_past(self, sm: int, device, do_rotary: bool):
                 config.dtype = torch.bfloat16
                 self.run_one_relevance_test(config)
 
+    def run_relevance_no_past_128k(self, sm: int, device):
+        """Test that the kernel supports sequence lengths up to 128K (no past state)."""
+        for seq_len in [131072]:
+            for packed_qkv in [False, True]:
+                config = SparseAttentionConfig(
+                    batch_size=1,
+                    sequence_length=seq_len,
+                    max_sequence_length=131072,
+                    past_sequence_length=0,
+                    num_heads=1,
+                    kv_num_heads=1,
+                    head_size=128,
+                    sparse_block_size=64,
+                    num_layout=1,
+                    local_blocks=2048,  # use dense to compare with GQA
+                    vert_stride=8,
+                    softmax_scale=None,
+                    device=device,
+                    is_packed_qkv=packed_qkv,
+                )
+                self.run_one_relevance_test(config)
+
+                if sm >= 80 and not packed_qkv:
+                    config.dtype = torch.bfloat16
+                    self.run_one_relevance_test(config)
+
+    def run_relevance_past_128k(self, sm: int, device):
+        """Test that the kernel supports total sequence lengths up to 128K with past state."""
+        for past_seq_len in [131071]:
+            for packed_qkv in [False, True]:
+                config = SparseAttentionConfig(
+                    batch_size=1,
+                    sequence_length=1,
+                    max_sequence_length=131072,
+                    past_sequence_length=past_seq_len,
+                    num_heads=1,
+                    kv_num_heads=1,
+                    head_size=128,
+                    sparse_block_size=64,
+                    num_layout=1,
+                    local_blocks=2048,  # use dense to compare with GQA
+                    vert_stride=8,
+                    softmax_scale=None,
+                    device=device,
+                    is_packed_qkv=packed_qkv,
+                )
+                self.run_one_relevance_test(config)
+
+                if sm >= 80 and not packed_qkv:
+                    config.dtype = torch.bfloat16
+                    self.run_one_relevance_test(config)
+
     def run_relevance_test(self, sm: int):
         device_id = torch.cuda.current_device()
         device = torch.device("cuda", device_id)
         with torch.no_grad():
+            # Test long sequences only when GPU memory is sufficient (about 12 GB is needed for 128K sequence length).
+            if torch.cuda.get_device_properties(device_id).total_memory > 13 * 1024 * 1024 * 1024:
+                self.run_relevance_no_past_128k(sm, device)
+                self.run_relevance_past_128k(sm, device)
             self.run_relevance_no_past(sm, device)
             self.run_relevance_past(sm, device, do_rotary=False)
            self.run_relevance_past(sm, device, do_rotary=True)
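Note: the dispatch in ConvertMaskToCSR can be sanity-checked with a little arithmetic. The sketch below mirrors the new branch (pick_mask_to_csr_kernel is a hypothetical helper, and 1024 is an assumed value of max_threads_per_block): a 128K-token sequence with sparse block size 64 gives 131072 / 64 = 2048 rows, so MaskToCSR_Large is selected with each thread covering 2 rows, matching the comment in the CUDA change and the 128K test cases above.

def pick_mask_to_csr_kernel(num_rows, max_threads_per_block=1024):
    # Mirrors the branch in ConvertMaskToCSR: one row per thread when a single
    # block can cover all rows, otherwise multiple rows per thread.
    if num_rows <= max_threads_per_block:
        return "MaskToCSR", 1
    rows_per_thread = (num_rows + max_threads_per_block - 1) // max_threads_per_block
    return "MaskToCSR_Large", rows_per_thread

print(pick_mask_to_csr_kernel(131072 // 64))  # ('MaskToCSR_Large', 2)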