Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add GQA support for ROCm #21032

Merged
merged 18 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF)

# composable kernel is managed automatically; it should not be set manually unless the user wants to explicitly disable it
option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON)
option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON)
cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF)
option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF)
option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF)
option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF)
Expand Down
1 change: 0 additions & 1 deletion cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,6 @@ set(contrib_ops_excluded_files
"cuda_contrib_kernels.h"
"inverse.cc"
"fused_conv.cc"
"bert/group_query_attention_helper.h"
"bert/group_query_attention.h"
"bert/group_query_attention.cc"
"bert/group_query_attention_impl.h"
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/attention_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,13 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream,
const T* qkv_buffer,
T* present);

// Launches a strided copy from `in` to `out`, both addressed in (b,n,s,h) coordinates.
//
// in_shape:           x = batch size, y = number of heads, z = sequence length
//                     (w is presumably the head size -- confirm against the definition).
// in_strides /        per-coordinate strides for the source / destination tensor
// out_strides:        (x -> b, y -> n, z -> s, w -> h).
// in_seqlens_offset / optional arrays indexed by batch; the value for batch b is added to the
// out_seqlens_offset: sequence index before the sequence stride is applied. nullptr means no offset.
// max_threads_per_block: upper bound used when choosing the thread-block shape.
//
// Returns the status of the kernel launch (the definition checks cudaGetLastError()).
template <typename T>
Status LaunchStridedCopy(
cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
Expand Down
63 changes: 45 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,27 @@ namespace cuda {

// Copies one element per thread between two strided tensors addressed in (b,n,s,h) coordinates.
//
// Launch layout, as read by the kernel:
//   threadIdx.x -> h  (element index inside a head; threads with h >= H do nothing)
//   threadIdx.y -> n  (head index)
//   blockIdx.x  -> s  (sequence index)
//   blockIdx.y  -> b  (batch index)
//
// in_seqlens_offset / out_seqlens_offset are optional arrays indexed by b. When non-null, the
// entry for the current batch is added to s before the sequence stride is applied; nullptr
// means an offset of 0.
//
// The element offsets are kept in 64 bits. The strides are longlong4, so the sum is already
// computed in 64-bit arithmetic; storing it in an int would silently truncate it once a
// tensor spans more than INT_MAX elements and make the copy touch the wrong addresses.
template <typename T>
__global__ void StridedCopy(const T* in, const int H, longlong4 in_strides,  // coord (b,n,s,h)
                            T* out, longlong4 out_strides,                   // coord (b,n,s,h)
                            const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;

  const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b];
  const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b];

  if (h < H) {
    const int64_t in_offset = b * in_strides.x + n * in_strides.y +
                              (s + s_offset_i) * in_strides.z + h * in_strides.w;
    const int64_t out_offset = b * out_strides.x + n * out_strides.y +
                               (s + s_offset_o) * out_strides.z + h * out_strides.w;
    out[out_offset] = in[in_offset];
  }
}

template <typename T>
__global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides // coord (b,n,s,h)
) {
T* out, longlong4 out_strides, // coord (b,n,s,h)
const int* in_seqlens_offset, const int* out_seqlens_offset) {
// Use when (H*)*num_heads > 1024
int h = threadIdx.x;
const int n = threadIdx.y;
Expand All @@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides,

const int h_step = blockDim.x;

const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b];
const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b];

while (h < H) {
const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w;
const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w;
const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w;
const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w;
out[out_offset] = in[in_offset];
h += h_step;
}
Expand Down Expand Up @@ -77,10 +84,11 @@ template <int NumBytes>
using ToBytes = typename ToByteType<NumBytes>::T;

template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
T* out, longlong4 out_strides, // coord (b,n,s,h)
int max_threads_per_block) {
Status LaunchStridedCopy(
cudaStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block) {
int batch_size = in_shape.x;
int num_heads = in_shape.y;
int sequence_length = in_shape.z;
Expand All @@ -102,11 +110,13 @@ Status LaunchStridedCopy(cudaStream_t stream,
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
}
} else if (0 == (head_size % 2)) { // pack 2 element together
using Bytes = ToBytes<sizeof(T) * 2>;
Expand All @@ -120,27 +130,44 @@ Status LaunchStridedCopy(cudaStream_t stream,
if (H * num_heads <= max_threads_per_block) {
const dim3 block(H, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), H, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
}
} else {
using Bytes = ToBytes<sizeof(T)>;
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
StridedCopy<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
} else {
const dim3 block(max_threads_per_block / num_heads, num_heads, 1);
StridedCopyLarge<Bytes><<<grid, block, 0, stream>>>(reinterpret_cast<const Bytes*>(in), head_size, in_strides,
reinterpret_cast<Bytes*>(out), out_strides);
reinterpret_cast<Bytes*>(out), out_strides,
in_seqlens_offset, out_seqlens_offset);
}
}
return CUDA_CALL(cudaGetLastError());
}

// Overload of LaunchStridedCopy for callers that have no per-batch sequence offsets.
// It forwards every argument to the offset-aware overload and passes a null pointer for both
// the input and the output offset array.
template <typename T>
Status LaunchStridedCopy(cudaStream_t stream,
                         const T* in, int4 in_shape, longlong4 in_strides,  // coord (b,n,s,h)
                         T* out, longlong4 out_strides,                     // coord (b,n,s,h)
                         int max_threads_per_block) {
  constexpr const int* kNoSeqlensOffset = nullptr;
  return LaunchStridedCopy<T>(stream,
                              in, in_shape, in_strides, kNoSeqlensOffset,
                              out, out_strides, kNoSeqlensOffset,
                              max_threads_per_block);
}

template Status LaunchStridedCopy<float>(
cudaStream_t stream,
const float* in, int4 in_shape, longlong4 in_strides,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp
}

// Kernel to convert seqlens_k to position_ids
__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
const int batch_size) {
int tid = blockDim.x * blockIdx.x + threadIdx.x;
int b = tid / seqlen;
Expand All @@ -592,15 +592,15 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids,
}

// Kernel to convert seqlens_k to position_ids: each thread whose flat index is below
// batch_size copies seqlens_k[index] into position_ids[index] (widening int32 -> int64).
// Threads at or beyond batch_size do nothing.
__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) {
  const int batch_idx = blockDim.x * blockIdx.x + threadIdx.x;
  if (batch_idx >= batch_size) {
    return;
  }
  position_ids[batch_idx] = seqlens_k[batch_idx];
}

// Convert seqlens_k to position_ids
Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k,
int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) {
const int seqlen = parameters.sequence_length;
const int batch_size = parameters.batch_size;
Expand Down
59 changes: 41 additions & 18 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
const int64_t* position_ids, // (1) or BxS
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int position_ids_format,
const bool interleaved, const int batch_stride, const int seq_stride,
const int head_stride) {
const bool interleaved,
int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous
) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently

Expand All @@ -40,10 +41,8 @@
return;
}

const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;

const T* input_data = input + block_offset;
T* output_data = output + block_offset;
const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y;
T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y;

if (i >= rotary_embedding_dim) {
output_data[i] = input_data[i];
Expand Down Expand Up @@ -77,34 +76,58 @@
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int rotary_embedding_dim, const int max_sequence_length,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block, const bool is_input_bnsh_format) {
int4 in_strides;
int4 out_strides;
if (is_input_bnsh_format) {
int in_head_stride = sequence_length * head_size;
int out_head_stride = sequence_length * head_size;
in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1};
out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1};
} else {
int in_head_stride = head_size;
int out_head_stride = head_size;
in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1};
out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1};
}
return LaunchRotaryEmbeddingKernel<T>(
stream, output, input, position_ids,
cos_cache, sin_cache, batch_size,
sequence_length, num_heads, head_size,
rotary_embedding_dim, max_sequence_length,
position_ids_format, interleaved,
max_threads_per_block,
in_strides, out_strides);
}

template <typename T>
Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids,
const T* cos_cache, const T* sin_cache, const int batch_size,
const int sequence_length, const int num_heads, const int head_size,
const int rotary_embedding_dim, const int /*max_sequence_length*/,
const int position_ids_format, const bool interleaved,
const int max_threads_per_block,
int4 in_strides, int4 out_strides // strides in bnsh coord
) {
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block");
// strides in canonical bnsh coord, h is always contiguous (dim_stride == 1)

Check warning on line 119 in onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "cannoical" is a misspelling of "canonical" Raw Output: ./onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu:119:16: "cannoical" is a misspelling of "canonical"
ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous");

int tpb = (head_size + 31) / 32 * 32;

const dim3 block(tpb);
const dim3 grid(sequence_length, batch_size, num_heads);

// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (is_input_bnsh_format) {
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
batch_stride = num_heads * head_stride;
}

assert(head_size <= max_threads_per_block);
RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(output, input, cos_cache, sin_cache, position_ids, sequence_length,
num_heads, head_size, rotary_embedding_dim, position_ids_format,
interleaved, batch_stride, seq_stride, head_stride);
interleaved, in_strides, out_strides);

return CUDA_CALL(cudaGetLastError());
}
Expand Down
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@ Status LaunchRotaryEmbeddingKernel(
const int max_threads_per_block,
const bool is_input_bnsh_format);

// Launches the rotary-embedding kernel with explicit input and output strides.
//
// in_strides / out_strides are strides in (b,n,s,h) coordinates: x is applied to the batch
// index, y to the head index, z to the sequence index, and w to the element inside a head.
// The definition enforces in_strides.w == 1 and out_strides.w == 1 (head dim contiguous)
// and head_size <= max_threads_per_block.
// max_sequence_length is left unnamed in the definition, i.e. it is not used there.
//
// Returns the status of the kernel launch.
template <typename T>
Status LaunchRotaryEmbeddingKernel(
cudaStream_t stream,
T* output,
const T* input,
const int64_t* position_ids,
const T* cos_cache,
const T* sin_cache,
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
int4 in_strides,
int4 out_strides);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,13 @@ Status ClassifyAttentionMode(AttentionType type,
const std::vector<const Tensor*>& past,
const std::vector<Tensor*>& present);

// ROCm declaration of the strided copy with optional per-batch sequence offsets; tensors are
// addressed in (b,n,s,h) coordinates.
// NOTE(review): the definition is not visible here. Presumably it is the hipified CUDA
// implementation, where the offset for batch b is added to the sequence index and nullptr
// means no offset -- confirm against the ROCm build.
template <typename T>
Status LaunchStridedCopy(
hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h)
T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h)
int max_threads_per_block);

template <typename T>
Status LaunchStridedCopy(hipStream_t stream,
const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h)
Expand Down
Loading
Loading