diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index ce22def914851..4dd4cd7d34bbf 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -240,7 +240,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) +cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2be68146b5e94..2966a4624a966 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -88,7 +88,6 @@ set(contrib_ops_excluded_files "cuda_contrib_kernels.h" "inverse.cc" "fused_conv.cc" - "bert/group_query_attention_helper.h" "bert/group_query_attention.h" "bert/group_query_attention.cc" "bert/group_query_attention_impl.h" diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 36fd7708de04b..fda7ac2784129 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -176,6 +176,13 @@ Status LaunchAddBiasTransAppendKvToPresent(cudaStream_t stream, const T* qkv_buffer, T* present); +template +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(cudaStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu index 1466f5fcfe0be..66e56e701c558 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_strided_copy.cu @@ -12,23 +12,27 @@ namespace cuda { template __global__ void StridedCopy(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int32_t* in_seqlens_offset, const int32_t* out_seqlens_offset) { const int h = threadIdx.x; const int n = threadIdx.y; const int s = blockIdx.x; const int b = blockIdx.y; + + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 
0 : out_seqlens_offset[b]; + if (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; } } template __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides // coord (b,n,s,h) -) { + T* out, longlong4 out_strides, // coord (b,n,s,h) + const int* in_seqlens_offset, const int* out_seqlens_offset) { // Use when (H*)*num_heads > 1024 int h = threadIdx.x; const int n = threadIdx.y; @@ -37,9 +41,12 @@ __global__ void StridedCopyLarge(const T* in, const int H, longlong4 in_strides, const int h_step = blockDim.x; + const int s_offset_i = in_seqlens_offset == nullptr ? 0 : in_seqlens_offset[b]; + const int s_offset_o = out_seqlens_offset == nullptr ? 0 : out_seqlens_offset[b]; + while (h < H) { - const int in_offset = b * in_strides.x + n * in_strides.y + s * in_strides.z + h * in_strides.w; - const int out_offset = b * out_strides.x + n * out_strides.y + s * out_strides.z + h * out_strides.w; + const int in_offset = b * in_strides.x + n * in_strides.y + (s + s_offset_i) * in_strides.z + h * in_strides.w; + const int out_offset = b * out_strides.x + n * out_strides.y + (s + s_offset_o) * out_strides.z + h * out_strides.w; out[out_offset] = in[in_offset]; h += h_step; } @@ -77,10 +84,11 @@ template using ToBytes = typename ToByteType::T; template -Status LaunchStridedCopy(cudaStream_t stream, - const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) - T* out, longlong4 out_strides, // coord (b,n,s,h) - int max_threads_per_block) { +Status LaunchStridedCopy( + cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block) { int batch_size = in_shape.x; int num_heads = in_shape.y; int sequence_length = in_shape.z; @@ -102,11 +110,13 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } else if (0 == (head_size % 2)) { // pack 2 element together using Bytes = ToBytes; @@ -120,27 +130,44 @@ Status LaunchStridedCopy(cudaStream_t stream, if (H * num_heads <= max_threads_per_block) { const dim3 block(H, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), H, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } 
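+  // Fallback when head_size is odd: no 2- or 4-element packing is possible, so each thread copies single T elements.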
else { using Bytes = ToBytes; if (head_size * num_heads <= max_threads_per_block) { const dim3 block(head_size, num_heads, 1); StridedCopy<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } else { const dim3 block(max_threads_per_block / num_heads, num_heads, 1); StridedCopyLarge<<>>(reinterpret_cast(in), head_size, in_strides, - reinterpret_cast(out), out_strides); + reinterpret_cast(out), out_strides, + in_seqlens_offset, out_seqlens_offset); } } return CUDA_CALL(cudaGetLastError()); } +template +Status LaunchStridedCopy(cudaStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) + T* out, longlong4 out_strides, // coord (b,n,s,h) + int max_threads_per_block) { + const int* in_seqlens_offset = nullptr; + const int* out_seqlens_offset = nullptr; + return LaunchStridedCopy( + stream, in, in_shape, in_strides, in_seqlens_offset, + out, out_strides, out_seqlens_offset, + max_threads_per_block); +} + template Status LaunchStridedCopy( cudaStream_t stream, const float* in, int4 in_shape, longlong4 in_strides, diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 62974d12003fe..77c85afffb66c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -577,7 +577,7 @@ Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unp } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen, +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; int b = tid / seqlen; @@ -592,7 +592,7 @@ __global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, } // Kernel to convert seqlens_k to position_ids -__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { int tid = blockDim.x * blockIdx.x + threadIdx.x; if (tid < batch_size) { position_ids[tid] = seqlens_k[tid]; @@ -600,7 +600,7 @@ __global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, } // Convert seqlens_k to position_ids -Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) { const int seqlen = parameters.sequence_length; const int batch_size = parameters.batch_size; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index 1b28b288f3d7c..ad0a83c9cde65 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -25,8 +25,9 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int64_t* position_ids, // (1) or BxS const int sequence_length, const int num_heads, const int head_size, const int rotary_embedding_dim, const int position_ids_format, - const bool interleaved, const int batch_stride, 
const int seq_stride, - const int head_stride) { + const bool interleaved, + int4 in_strides, int4 out_strides // strides in bnsh coord, h is always contiguous +) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently @@ -40,10 +41,8 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH return; } - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; - - const T* input_data = input + block_offset; - T* output_data = output + block_offset; + const T* input_data = input + b * in_strides.x + s * in_strides.z + n * in_strides.y; + T* output_data = output + b * out_strides.x + s * out_strides.z + n * out_strides.y; if (i >= rotary_embedding_dim) { output_data[i] = input_data[i]; @@ -77,34 +76,58 @@ template Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, const T* cos_cache, const T* sin_cache, const int batch_size, const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool is_input_bnsh_format) { + int4 in_strides; + int4 out_strides; + if (is_input_bnsh_format) { + int in_head_stride = sequence_length * head_size; + int out_head_stride = sequence_length * head_size; + in_strides = int4{num_heads * in_head_stride, in_head_stride, in_head_stride / sequence_length, 1}; + out_strides = int4{num_heads * out_head_stride, out_head_stride, out_head_stride / sequence_length, 1}; + } else { + int in_head_stride = head_size; + int out_head_stride = head_size; + in_strides = int4{sequence_length * num_heads * in_head_stride, in_head_stride, num_heads * in_head_stride, 1}; + out_strides = int4{sequence_length * num_heads * out_head_stride, out_head_stride, num_heads * out_head_stride, 1}; + } + return LaunchRotaryEmbeddingKernel( + stream, output, input, position_ids, + cos_cache, sin_cache, batch_size, + sequence_length, num_heads, head_size, + rotary_embedding_dim, max_sequence_length, + position_ids_format, interleaved, + max_threads_per_block, + in_strides, out_strides); +} + +template +Status LaunchRotaryEmbeddingKernel(cudaStream_t stream, T* output, const T* input, const int64_t* position_ids, + const T* cos_cache, const T* sin_cache, const int batch_size, + const int sequence_length, const int num_heads, const int head_size, + const int rotary_embedding_dim, const int /*max_sequence_length*/, + const int position_ids_format, const bool interleaved, + const int max_threads_per_block, + int4 in_strides, int4 out_strides // strides in bnsh coord +) { // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. 
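+  // In (b,n,s,h) order the strides computed by the overload above are {S*N*H, H, N*H, 1} for BSNH
+  // input and {N*S*H, S*H, H, 1} for BNSH input; in both layouts the head dim stays contiguous.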
ORT_ENFORCE(head_size <= max_threads_per_block, "Rotary embedding dim must be <= max_threads_per_block"); + // strides in canonical bnsh coord, h is always contiguous (dim_stride == 1) + ORT_ENFORCE(in_strides.w == 1 && out_strides.w == 1, "head dim must contiguous"); int tpb = (head_size + 31) / 32 * 32; const dim3 block(tpb); const dim3 grid(sequence_length, batch_size, num_heads); - // Default input tensor shape is [batch, seq, hidden_size] - int head_stride = head_size; - int seq_stride = num_heads * head_stride; - int batch_stride = sequence_length * seq_stride; - if (is_input_bnsh_format) { - seq_stride = head_size; - head_stride = sequence_length * seq_stride; - batch_stride = num_heads * head_stride; - } - assert(head_size <= max_threads_per_block); RotaryEmbeddingBSNH<<>>(output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, rotary_embedding_dim, position_ids_format, - interleaved, batch_stride, seq_stride, head_stride); + interleaved, in_strides, out_strides); return CUDA_CALL(cudaGetLastError()); } diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 6053814b835bb..dd0ac6a6e3274 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -28,6 +28,26 @@ Status LaunchRotaryEmbeddingKernel( const int max_threads_per_block, const bool is_input_bnsh_format); +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int rotary_embedding_dim, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block, + int4 in_strides, + int4 out_strides); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h index 3164e8c211099..349df045becf2 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_impl.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_impl.h @@ -169,6 +169,13 @@ Status ClassifyAttentionMode(AttentionType type, const std::vector& past, const std::vector& present); +template +Status LaunchStridedCopy( + hipStream_t stream, + const T* in, int4 in_shape, longlong4 in_strides, const int* in_seqlens_offset, // coord (b,n,s,h) + T* out, longlong4 out_strides, const int* out_seqlens_offset, // coord (b,n,s,h) + int max_threads_per_block); + template Status LaunchStridedCopy(hipStream_t stream, const T* in, int4 in_shape, longlong4 in_strides, // coord (b,n,s,h) diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu new file mode 100644 index 0000000000000..92c780d4a9d41 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.cu @@ -0,0 +1,526 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
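+// GroupQueryAttention for the ROCm EP, implemented on top of the ck_tile fmha_fwd kernels.
+// ComputeInternal (1) optionally applies rotary embedding to Q and K using position ids derived
+// from seqlens_k, (2) copies/appends K and V into the present KV cache with LaunchStridedCopy,
+// using seqlens_k as the per-batch output offset in the token-generation case, and (3) builds the
+// seqstart_q/seqstart_k offset arrays and dispatches to fmha_fwd. Only the MLFloat16 kernel is
+// registered, and the fmha path is compiled only when USE_COMPOSABLE_KERNEL_CK_TILE is defined;
+// otherwise ComputeInternal returns an error status.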
+ +#include "core/providers/shared_library/provider_api.h" +#include "core/providers/rocm/rocm_common.h" +#include "core/platform/env_var_utils.h" +#include "contrib_ops/rocm/bert/group_query_attention.h" +#include "contrib_ops/rocm/bert/group_query_attention_helper.h" +#include "contrib_ops/rocm/bert/rotary_embedding_impl.h" +#include "contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh" + +#ifdef USE_COMPOSABLE_KERNEL_CK_TILE +#include "ck_tile/core/numeric/integer.hpp" +#include "fmha_fwd.hpp" +#endif + +using namespace onnxruntime::rocm; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + GroupQueryAttention, \ + kMSDomain, \ + 1, \ + T, \ + kRocmExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()) \ + .MayInplace(3, 1) \ + .MayInplace(4, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 6), \ + GroupQueryAttention); + +// REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) +// REGISTER_KERNEL_TYPED(BFloat16) + +template +std::string GetCkFmhaDataTypeString(); + +template <> +std::string GetCkFmhaDataTypeString() { + return "fp16"; +} + +template <> +std::string GetCkFmhaDataTypeString() { + return "bf16"; +} + +__global__ void seqlens_inc_kernel(const int* seqlens, int* out, int num_elems, int inc) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = seqlens[idx] + inc; + } +} + +Status LaunchSeqlensInc(hipStream_t stream, const int* seqlens, int* out, int num_elems, int inc) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqlens_inc_kernel<<>>(seqlens, out, num_elems, inc); + return HIP_CALL(hipGetLastError()); +} + +__global__ void seqstart_init_kernel(int* out, int num_elems, int length_per_seq) { + int idx = blockDim.x * blockIdx.x + threadIdx.x; + if (idx < num_elems) { + out[idx] = idx * length_per_seq; + } + if (idx == 0) { + out[num_elems] = num_elems * length_per_seq; + } +} + +Status LaunchSeqStartInit(hipStream_t stream, int* out, int num_elems, int length_per_seq) { + constexpr int NumThreads = 128; + int num_blks = CeilDiv(num_elems, NumThreads); + seqstart_init_kernel<<>>(out, num_elems, length_per_seq); + return HIP_CALL(hipGetLastError()); +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsPrompt(const int32_t* seqlens_k, int64_t* position_ids, const int seqlen, + const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + int b = tid / seqlen; + int s = tid % seqlen; + if (b < batch_size) { + if (s < seqlens_k[b] + 1) { + position_ids[tid] = s; + } else { + position_ids[tid] = 1; + } + } +} + +// Kernel to convert seqlens_k to position_ids +__global__ void SeqlensToPosIdsToken(const int32_t* seqlens_k, int64_t* position_ids, const int batch_size) { + int tid = blockDim.x * blockIdx.x + threadIdx.x; + if (tid < batch_size) { + position_ids[tid] = seqlens_k[tid]; + } +} + +// Convert seqlens_k to position_ids +Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, const int32_t* seqlens_k, + int64_t* position_ids, hipStream_t stream, const int max_threads_per_block) { + const int seqlen = parameters.sequence_length; + const int batch_size = parameters.batch_size; + const int threads = max_threads_per_block; + const int blocks 
= (batch_size * seqlen + threads - 1) / threads; + if (parameters.is_prompt) { + SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size); + } else { + SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size); + } + return HIP_CALL(hipGetLastError()); +} + +template +GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) + : RocmKernel(info) { + int64_t num_heads = 0; + int64_t kv_num_heads = 0; + ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); + ORT_ENFORCE(info.GetAttr("kv_num_heads", &kv_num_heads).IsOK() && kv_num_heads > 0 && num_heads % kv_num_heads == 0); + num_heads_ = static_cast(num_heads); + kv_num_heads_ = static_cast(kv_num_heads); + is_past_bsnh_ = false; + is_unidirectional_ = true; + local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; + scale_ = info.GetAttrOrDefault("scale", 0.0f); +} + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template <> +std::once_flag GroupQueryAttention::arch_checking_{}; + +template +Status GroupQueryAttention::ComputeInternal(OpKernelContext* ctx) const { +#if USE_COMPOSABLE_KERNEL_CK_TILE + auto hip_stream = static_cast(ctx->GetComputeStream()->GetHandle()); + const Tensor* query = ctx->Input(0); + const Tensor* key = ctx->Input(1); + const Tensor* value = ctx->Input(2); + const Tensor* past_key = ctx->Input(3); + const Tensor* past_value = ctx->Input(4); + const Tensor* seqlens_k = ctx->Input(5); + const Tensor* total_seqlen = ctx->Input(6); + const Tensor* cos_cache = ctx->Input(7); + const Tensor* sin_cache = ctx->Input(8); + + auto& device_prop = GetDeviceProp(); + std::call_once( + arch_checking_, + [](const hipDeviceProp_t& device_prop) { + if (std::string_view(device_prop.gcnArchName).find("gfx90a") == std::string_view::npos && + std::string_view(device_prop.gcnArchName).find("gfx942") == std::string_view::npos) { + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention currently only supports ck_tile fmha backend which only supports " + << "CDNA2 and CDNA3 archs."; + LOGS_DEFAULT(WARNING) + << "GroupQueryAttention running on an unsuppoted GPU may result in " + << "hipErrorNoBinaryForGpu or hipErrorSharedObjectInitFailedshared error."; + } + }, + device_prop); + + GroupQueryAttentionParameters parameters; + using HipT = typename ToHipType::MappedType; + + const int max_thr_per_blk = device_prop.maxThreadsPerBlock; + + ORT_RETURN_IF_ERROR(group_query_attention_helper::CheckInputs(query, + key, + value, + past_key, + past_value, + cos_cache, + sin_cache, + ¶meters, + num_heads_, + kv_num_heads_, + seqlens_k, + total_seqlen, + is_past_bsnh_, + scale_, + max_thr_per_blk)); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int kv_sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int kv_num_heads = parameters.kv_num_heads; + const int head_size = parameters.head_size; + AttentionQkvFormat past_kv_format = parameters.past_kv_format; + + parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + // parameters.zeros_count = kZerosCount; + // parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; + + if (do_rotary_ && (cos_cache == 
nullptr || sin_cache == nullptr)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache and sin_cache must be passed to GroupQueryAttention when do_rotary = 1"); + } + + TensorShapeVector output_shape(3); + output_shape[0] = static_cast(batch_size); + output_shape[1] = static_cast(sequence_length); + output_shape[2] = static_cast(parameters.hidden_size); + Tensor* output = ctx->Output(0, output_shape); + Strides output_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + + int4 past_shape; + std::vector present_dims; + Strides present_strides; + Strides past_strides; + if (past_kv_format == AttentionQkvFormat::Q_K_V_BSNH) { + past_shape = { + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size}; + past_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_past_kv_cache, kv_num_heads, head_size); + present_dims = { + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size}; + present_strides = Strides::BSNHMemory( + batch_size, parameters.seqlen_present_kv_cache, kv_num_heads, head_size); + } else { // BNSH + past_shape = { + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size}; + past_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_past_kv_cache, head_size); + present_dims = { + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size}; + present_strides = Strides::BNSHMemory( + batch_size, kv_num_heads, parameters.seqlen_present_kv_cache, head_size); + } + TensorShape present_shape(present_dims); + Tensor* present_key = ctx->Output(1, present_shape); + Tensor* present_value = ctx->Output(2, present_shape); + + Strides query_strides; + Strides key_strides; + Strides value_strides; + int4 kv_shape{batch_size, kv_num_heads, kv_sequence_length, head_size}; // BNSH coord + const HipT* query_ptr = reinterpret_cast(query->DataRaw()); + const HipT* key_ptr; + const HipT* value_ptr; + if (!parameters.is_packed_qkv) { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, kv_sequence_length, kv_num_heads, head_size); + value_strides = key_strides; + key_ptr = reinterpret_cast(key->DataRaw()); + value_ptr = reinterpret_cast(value->DataRaw()); + } else { + query_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + key_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads + 2 * kv_num_heads, head_size); + value_strides = query_strides; + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key_ptr = query_ptr + key_offset; + value_ptr = key_ptr + value_offset; + } + + IAllocatorUniquePtr rotary_q_tmp; + IAllocatorUniquePtr rotary_k_tmp; + if (parameters.do_rotary) { + size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size); + size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size); + auto rotary_q_strides = Strides::BSNHMemory(batch_size, sequence_length, num_heads, head_size); + auto rotary_k_strides = Strides::BSNHMemory(batch_size, sequence_length, kv_num_heads, head_size); + + rotary_q_tmp = GetScratchBuffer(q_size, ctx->GetComputeStream()); + rotary_k_tmp = GetScratchBuffer(k_size, ctx->GetComputeStream()); + auto rotary_position_ids_tmp = GetScratchBuffer(sequence_length * batch_size, ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, + 
reinterpret_cast(seqlens_k->DataRaw()), + reinterpret_cast(rotary_position_ids_tmp.get()), + hip_stream, max_thr_per_blk)); + // Launch rotary embedding kernel + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_q_tmp.get(), query_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + query_strides.ForBNSHCoord(), + rotary_q_strides.ForBNSHCoord())); + ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(hip_stream, rotary_k_tmp.get(), key_ptr, + reinterpret_cast(rotary_position_ids_tmp.get()), + reinterpret_cast(cos_cache->DataRaw()), + reinterpret_cast(sin_cache->DataRaw()), + parameters.batch_size, parameters.sequence_length, + parameters.kv_num_heads, parameters.head_size, + parameters.rotary_dim, parameters.seqlen_present_kv_cache, + /*position_ids_format*/ 1, parameters.rotary_interleaved, + max_thr_per_blk, + key_strides.ForBNSHCoord(), + rotary_k_strides.ForBNSHCoord())); + query_ptr = reinterpret_cast(rotary_q_tmp.get()); + key_ptr = reinterpret_cast(rotary_k_tmp.get()); + query_strides = rotary_q_strides; + key_strides = rotary_k_strides; + } + + const int* seqlens_k_ptr = seqlens_k ? reinterpret_cast(seqlens_k->DataRaw()) : nullptr; + IAllocatorUniquePtr seqlens_k_tmp; + + // build present kv cache + auto* present_key_ptr = reinterpret_cast(present_key->MutableDataRaw()); + auto* present_value_ptr = reinterpret_cast(present_value->MutableDataRaw()); + if (parameters.is_prompt) { + // copy prompt kv to present kv + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + const auto* past_key_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_key->DataRaw()); + const auto* past_value_ptr = past_key == nullptr ? nullptr : reinterpret_cast(past_value->DataRaw()); + parameters.kv_share_buffer = past_key_ptr == present_key_ptr; // FIXME: + if (!parameters.kv_share_buffer) { + // copy past to present, + // NOTE: we do a low perf full buffer copy due to the seqlens_k indicate the seqlen of different seqs are + // not the same, aka, can not be as simple as strided + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_key_ptr, past_shape, past_strides.ForBNSHCoord(), + present_key_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy(hip_stream, past_value_ptr, past_shape, past_strides.ForBNSHCoord(), + present_value_ptr, present_strides.ForBNSHCoord(), max_thr_per_blk)); + } else { + // In the case of share buffer + ORT_ENFORCE(past_key_ptr == nullptr || past_key_ptr == present_key_ptr); + ORT_ENFORCE(past_key_ptr == nullptr || past_value_ptr == present_value_ptr); + } + // then append new kv to present + size_t buffer_offset = seqlens_k ? 
0 : present_strides.OffsetAt(0, 0, kv_sequence_length, 0); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, key_ptr, kv_shape, key_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_key_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + ORT_RETURN_IF_ERROR(LaunchStridedCopy( + hip_stream, value_ptr, kv_shape, value_strides.ForBNSHCoord(), /*in_seqlens_offset=*/nullptr, + present_value_ptr + buffer_offset, present_strides.ForBNSHCoord(), seqlens_k_ptr, + max_thr_per_blk)); + + // NOTE: ORT: seqlens_k Indicates past sequence lengths for token generation case. + // we should call fmha with total sequence lenghts + seqlens_k_tmp = GetScratchBuffer(batch_size * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqlensInc(hip_stream, seqlens_k_ptr, seqlens_k_tmp.get(), batch_size, sequence_length)); + seqlens_k_ptr = seqlens_k_tmp.get(); + } + static_assert(std::is_same_v); + + const float scale = parameters.scale == 0.0f + ? 1.f / sqrt(static_cast(parameters.head_size)) + : parameters.scale; + bias_enum bias_type = bias_enum::no_bias; + + mask_info mask = [&]() { + if (local_window_size_ != -1) { + mask_info ret; + ret.type = mask_enum::window_generic; + ret.left = local_window_size_; + ret.right = parameters.is_unidirectional ? 0 : -1; + // ret.x = kv_sequence_length - (sequence_length - ret.left); + // ret.y = sequence_length + (ret.right - kv_sequence_length); + return ret; + } + + if (parameters.is_prompt && is_unidirectional_) { + return mask_info::decode("t", sequence_length, kv_sequence_length); + } + + return mask_info::decode("0", sequence_length, kv_sequence_length); + }(); + + auto seqstart_q_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + auto seqstart_k_tmp = GetScratchBuffer((batch_size + 1) * sizeof(int), ctx->GetComputeStream()); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_q_tmp.get(), batch_size, + query_strides.strides_for_bnsh_coord.x / query_strides.strides_for_bnsh_coord.z)); + ORT_RETURN_IF_ERROR(LaunchSeqStartInit( + hip_stream, seqstart_k_tmp.get(), batch_size, + present_strides.strides_for_bnsh_coord.x / present_strides.strides_for_bnsh_coord.z)); + + fmha_fwd_args args{ + query_ptr, + present_key->DataRaw(), + present_value->DataRaw(), + nullptr, // bias, alibi/element + nullptr, // lse, logsumexp buffer + output->MutableDataRaw(), + seqstart_q_tmp.get(), // seqstart_q_ptr, for group mode + seqstart_k_tmp.get(), // seqstart_k_ptr, for group mode + seqlens_k_ptr, // seqlen_k_ptr, for group mode + sequence_length, // seqlen_q, for batch mode + kv_sequence_length, // seqlen_k, for batch mode + parameters.batch_size, // batch + parameters.sequence_length, // max_seqlen_q + parameters.head_size, // hdim_q + parameters.head_size, // hdim_v + parameters.num_heads, + parameters.kv_num_heads, + scale, + 1.0f, // scale_p of squant, useless + 1.0f, // scale_o of squant, useless + static_cast(query_strides.strides_for_bnsh_coord.z), // stride_q, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_k, to be regarded as stride of dim S + static_cast(present_strides.strides_for_bnsh_coord.z), // stride_v, to be regarded as stride of dim S + batch_size, // stride_bias, if alibi, b*h need set this to h, 1*h need set this to 0 + static_cast(output_strides.strides_for_bnsh_coord.z), // stride_o, to be regarded as stride of dim S + static_cast(query_strides.strides_for_bnsh_coord.y), // nhead_stride_q, to be 
regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_k, to be regarded as stride of dim N + static_cast(present_strides.strides_for_bnsh_coord.y), // nhead_stride_v, to be regarded as stride of dim N + 0, // nhead_stride_bias + batch_size, // nhead_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.y), // batch_stride_o, to be regarded as stride of dim B + static_cast(query_strides.strides_for_bnsh_coord.x), // batch_stride_q, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_k, to be regarded as stride of dim B + static_cast(present_strides.strides_for_bnsh_coord.x), // batch_stride_v, to be regarded as stride of dim B + 0, // batch_stride_bias + num_heads * batch_size, // batch_stride_lse + static_cast(output_strides.strides_for_bnsh_coord.x), // batch_stride_o, to be regarded as stride of dim B + mask.left, // window_size_left + mask.right, // window_size_right + static_cast(mask.type)}; + +#if 0 + std::cout + << "\n sequence_length:" << sequence_length + << "\n kv_sequence_length:" << kv_sequence_length + << "\n seqlen_past_kv_cache:" << parameters.seqlen_past_kv_cache + << "\n seqlen_present_kv_cache:" << parameters.seqlen_present_kv_cache << std::endl; + + std::cout + << "\n q_ptr:" << args.q_ptr + << "\n k_ptr:" << args.k_ptr + << "\n v_ptr:" << args.v_ptr + << "\n bias_ptr:" << args.bias_ptr + << "\n lse_ptr:" << args.lse_ptr + << "\n o_ptr:" << args.o_ptr + << "\n seqstart_q_ptr:" << args.seqstart_q_ptr + << "\n seqstart_k_ptr:" << args.seqstart_k_ptr + << "\n seqlen_k_ptr:" << args.seqlen_k_ptr + << "\n seqlen_q:" << args.seqlen_q + << "\n seqlen_k:" << args.seqlen_k + << "\n batch:" << args.batch + << "\n max_seqlen_q:" << args.max_seqlen_q + << "\n hdim_q:" << args.hdim_q + << "\n hdim_v:" << args.hdim_v + << "\n nhead_q:" << args.nhead_q + << "\n nhead_k:" << args.nhead_k + << "\n scale_s:" << args.scale_s + << "\n scale_p:" << args.scale_p + << "\n scale_o:" << args.scale_o + << "\n stride_q:" << args.stride_q + << "\n stride_k:" << args.stride_k + << "\n stride_v:" << args.stride_v + << "\n stride_bias:" << args.stride_bias + << "\n stride_o:" << args.stride_o + << "\n nhead_stride_q:" << args.nhead_stride_q + << "\n nhead_stride_k:" << args.nhead_stride_k + << "\n nhead_stride_v:" << args.nhead_stride_v + << "\n nhead_stride_bias:" << args.nhead_stride_bias + << "\n nhead_stride_lse:" << args.nhead_stride_lse + << "\n nhead_stride_o:" << args.nhead_stride_o + << "\n batch_stride_q:" << args.batch_stride_q + << "\n batch_stride_k:" << args.batch_stride_k + << "\n batch_stride_v:" << args.batch_stride_v + << "\n batch_stride_bias:" << args.batch_stride_bias + << "\n batch_stride_lse:" << args.batch_stride_lse + << "\n batch_stride_o:" << args.batch_stride_o + << "\n window_size_left:" << args.window_size_left + << "\n window_size_right:" << args.window_size_right + << "\n mask_type:" << args.mask_type + << std::endl; +#endif + + fmha_fwd_traits traits{ + parameters.head_size, + parameters.head_size, // v head size + GetCkFmhaDataTypeString(), + !parameters.is_prompt, // true, // is_group_mode + true, // is_v_rowmajor ? 
dim is fastest : seq is fastest + mask.type, + bias_type, + false, // has_lse + false, // do_fp8_static_quant, aka, squant + }; + + ck_tile::stream_config stream_config{ + hip_stream, + false // time_kernel + }; + + auto duration = fmha_fwd(traits, args, stream_config); + if (duration < 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "fmha_fwd internal error"); + } + HIP_RETURN_IF_ERROR(hipGetLastError()); + + return Status::OK(); +#else + return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "GroupQueryAttention requires ck_tile to be enabled"); +#endif +} + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h new file mode 100644 index 0000000000000..ce0de1f761aa5 --- /dev/null +++ b/onnxruntime/contrib_ops/rocm/bert/group_query_attention.h @@ -0,0 +1,38 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include "core/providers/rocm/rocm_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace rocm { + +using namespace onnxruntime::rocm; + +template +class GroupQueryAttention final : public RocmKernel { + public: + GroupQueryAttention(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + int num_heads_; // number of attention heads + int kv_num_heads_; // different for k and v for group query attention + int local_window_size_; + bool is_unidirectional_; + bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; + float scale_; + + private: + static std::once_flag arch_checking_; +}; + +} // namespace rocm +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc index 7e5e7d7ee076d..4284b4254f485 100644 --- a/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc @@ -71,6 +71,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MLFloat16, Crop); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, GroupQueryAttention); +// class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16, GroupQueryAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, DecoderAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, int32_t, DynamicSlice); @@ -227,6 +229,8 @@ Status RegisterRocmContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + // BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/test/python/transformers/test_flash_attn_rocm.py b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py new file mode 100644 index 0000000000000..fe7e39722237f --- /dev/null 
+++ b/onnxruntime/test/python/transformers/test_flash_attn_rocm.py @@ -0,0 +1,86 @@ +import platform +import unittest + +import torch +from parameterized import parameterized +from test_flash_attn_cuda import ( + Formats, + gqa_no_past_flash_attention_test_cases, + gqa_past_flash_attention_test_cases, + parity_check_gqa_past, + parity_check_gqa_past_no_buff, + parity_check_gqa_prompt, + parity_check_gqa_prompt_no_buff, +) + +import onnxruntime + + +class TestGQA(unittest.TestCase): + @parameterized.expand(gqa_no_past_flash_attention_test_cases()) + def test_gqa_no_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (PROMPT CASE) --------") + + parity_check_gqa_prompt( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + + @parameterized.expand(gqa_past_flash_attention_test_cases()) + def test_gqa_past_flash_attention(self, _, config, local, rotary, rotary_interleaved, packed): + config.ep = "ROCMExecutionProvider" + if not torch.cuda.is_available(): + return + if platform.system() != "Linux": + return + if "CUDAExecutionProvider" in onnxruntime.get_available_providers(): + return + print("------- FLASH ATTENTION (TOKEN GEN) -------") + + parity_check_gqa_past( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=Formats.BNSH, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + rtol=0.002, + atol=0.002, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 7ada4ee6757c9..001062452644e 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -255,6 +255,33 @@ jobs: arguments: -n $(Agent.Name) -d $HIP_VISIBLE_DEVICES -r $DRIVER_RENDER displayName: 'Check ROCm Environment' + # TODO: move to use ci_build/build.py driven tests + - task: CmdLine@2 + inputs: + script: |- + docker run --rm \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --device=/dev/kfd \ + --device=/dev/dri/renderD$DRIVER_RENDER \ + --group-add $(video) \ + --group-add $(render) \ + --user onnxruntimedev \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + -e OPENBLAS_NUM_THREADS=1 \ + -e OPENMP_NUM_THREADS=1 \ + -e MKL_NUM_THREADS=1 \ + -e PYTHONPATH=/build/$(BuildConfig) \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-test \ + /bin/bash -c " + set -ex; \ + pip install -r /onnxruntime_src/tools/ci_build/requirements-transformers-test.txt; \ + pytest /onnxruntime_src/onnxruntime/test/python/transformers/test_flash_attn_rocm.py -v -n 4 --reruns 1" + workingDirectory: $(Build.SourcesDirectory) + displayName: 'Run tranformers tests' + 
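+          # Shards the run across 4 workers (pytest-xdist, -n 4) and retries flaky failures once (pytest-rerunfailures, --reruns 1).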
condition: succeededOrFailed() + - task: CmdLine@2 inputs: script: |- diff --git a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile index 59f6c0ab2136c..b94826ae0e4bc 100644 --- a/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile +++ b/tools/ci_build/github/pai/rocm-ci-pipeline-env.Dockerfile @@ -77,7 +77,11 @@ RUN ln -sf /usr/lib/x86_64-linux-gnu/libstdc++.so.6 ${CONDA_ENVIRONMENT_PATH}/bi RUN export MAJOR=$(cut -d '.' -f 1 <<< "$ROCM_VERSION") && \ export MINOR=$(cut -d '.' -f 2 <<< "$ROCM_VERSION") && \ export PATCH=$(cut -d '.' -f 3 <<< "$ROCM_VERSION") && \ - pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ && \ + if (( MAJOR >= 6 )); then \ + pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm${MAJOR}.${MINOR} ; \ + else \ + pip install torch==2.0.1 torchvision==0.15.2 -f https://repo.radeon.com/rocm/manylinux/rocm-rel-${MAJOR}.${MINOR}/ ; \ + fi && \ pip install torch-ort --no-dependencies ##### Install Cupy to decrease CPU utilization
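The per-batch sequence offsets threaded through LaunchStridedCopy above are the mechanism that lets GroupQueryAttention append the newly generated K/V tokens behind each sequence's past length in a single strided copy. The sketch below is a minimal NumPy reference of that semantic under stated assumptions, not ORT code; the function name and shapes are illustrative only.

import numpy as np

def strided_copy_with_seqlens_offset(src, dst, in_seqlens_offset=None, out_seqlens_offset=None):
    # Mirrors the StridedCopy kernel: for every (b, n, s, h) in src's shape,
    #   dst[b, n, s + out_off[b], h] = src[b, n, s + in_off[b], h]
    # where a missing offset array behaves as all zeros.
    batch, _, seq_len, _ = src.shape
    for b in range(batch):
        off_i = 0 if in_seqlens_offset is None else int(in_seqlens_offset[b])
        off_o = 0 if out_seqlens_offset is None else int(out_seqlens_offset[b])
        dst[b, :, off_o:off_o + seq_len, :] = src[b, :, off_i:off_i + seq_len, :]

# Token-generation append as in the ROCm GQA path: one new token per sequence is
# written at position seqlens_k[b] (that sequence's past length) of the present cache.
batch, kv_heads, head, s_max = 2, 4, 8, 16
present_k = np.zeros((batch, kv_heads, s_max, head), dtype=np.float16)
new_k = np.random.rand(batch, kv_heads, 1, head).astype(np.float16)
seqlens_k = np.array([5, 9], dtype=np.int32)  # past lengths per batch
strided_copy_with_seqlens_offset(new_k, present_k, out_seqlens_offset=seqlens_k)

The kernel itself performs the same indexing with raw strides (plus the 2- and 4-element packing paths when head_size allows), and the GQA op then passes seqlens_k incremented by sequence_length to fmha_fwd so the attention spans past plus new tokens.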