[ROCm] prefer hip interfaces over roc during hipify (#22394)
### Description
Change the hipify step to remove the `-roc` option passed to hipify-perl. This
makes the hipified code prefer hipblas over rocblas. rocblas can still be called
directly where needed, such as in TunableOp.

### Motivation and Context
The hip interfaces are preferred over the roc interfaces when porting from CUDA
to HIP. Calling the roc interfaces directly is intended only for ROCm-specific
enhancements or extensions.
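
For illustration, a minimal sketch (not part of this commit; the function, handle and dimension names are hypothetical) of the same GEMM issued through each interface. hipblas mirrors the cuBLAS API, which is why hipify targets it by default once `-roc` is dropped, while rocblas keeps AMD-specific naming and extensions for direct use:

```cpp
// Sketch only: contrast the hip and roc BLAS interfaces for a plain SGEMM.
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>

// hipBLAS: cuBLAS-style naming, the default hipify target after this change.
void gemm_via_hipblas(hipblasHandle_t h, int m, int n, int k,
                      const float* a, const float* b, float* c) {
  const float alpha = 1.0f, beta = 0.0f;
  hipblasSgemm(h, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
               &alpha, a, m, b, k, &beta, c, m);
}

// rocBLAS: ROCm-native naming, still used directly where ROCm-specific
// behavior is wanted (e.g. TunableOp).
void gemm_via_rocblas(rocblas_handle h, int m, int n, int k,
                      const float* a, const float* b, float* c) {
  const float alpha = 1.0f, beta = 0.0f;
  rocblas_sgemm(h, rocblas_operation_none, rocblas_operation_none, m, n, k,
                &alpha, a, m, b, k, &beta, c, m);
}
```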
jeffdaily authored Oct 15, 2024
1 parent ec7aa63 commit 8c21680
Showing 42 changed files with 689 additions and 242 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_kernel_explorer.cmake
@@ -64,7 +64,7 @@ elseif (onnxruntime_USE_ROCM)
)
auto_set_source_files_hip_language(${kernel_explorer_kernel_srcs} ${kernel_explorer_rocm_kernel_srcs})
target_sources(kernel_explorer PRIVATE ${kernel_explorer_rocm_kernel_srcs})
target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1)
target_compile_definitions(kernel_explorer PRIVATE __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 HIPBLAS_V2)
if (onnxruntime_USE_COMPOSABLE_KERNEL)
target_compile_definitions(kernel_explorer PRIVATE USE_COMPOSABLE_KERNEL)
if (onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE)
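
A note on the new `HIPBLAS_V2` define: as I understand the ROCm 6.x hipBLAS headers, it selects the v2 API in which the `*Ex` GEMM entry points describe matrices with `hipDataType` and the accumulation type with `hipblasComputeType_t` (mirroring `cudaDataType_t` / `cublasComputeType_t`), rather than the legacy `hipblasDatatype_t`. A minimal sketch under that assumption; the function and argument names are hypothetical:

```cpp
// Sketch, assuming hipBLAS headers that honor HIPBLAS_V2 (normally passed as a
// compile definition, as the CMake change above does).
#define HIPBLAS_V2
#include <hipblas/hipblas.h>

hipblasStatus_t fp16_gemm_ex(hipblasHandle_t h, int m, int n, int k,
                             const void* alpha, const void* A, const void* B,
                             const void* beta, void* C) {
  // A/B/C types are hipDataType; the compute type is hipblasComputeType_t.
  return hipblasGemmEx(h, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
                       alpha, A, HIP_R_16F, m,
                       B, HIP_R_16F, k,
                       beta, C, HIP_R_16F, m,
                       HIPBLAS_COMPUTE_16F, HIPBLAS_GEMM_DEFAULT);
}
```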
5 changes: 3 additions & 2 deletions cmake/onnxruntime_providers_rocm.cmake
@@ -8,7 +8,7 @@

find_package(HIP)
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(hipblas REQUIRED)
find_package(MIOpen REQUIRED)
find_package(hipfft REQUIRED)

@@ -50,7 +50,7 @@
find_library(RCCL_LIB rccl REQUIRED)
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
find_package(rocm_smi REQUIRED)
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
set(ONNXRUNTIME_ROCM_LIBS roc::hipblas MIOpen hip::hipfft ${ROCM_SMI_LIBRARY} ${RCCL_LIB} ${ROCTRACER_LIB})
include_directories(${ROCM_SMI_INCLUDE_DIR})
link_directories(${ROCM_SMI_LIB_DIR})

@@ -155,6 +155,7 @@

set_target_properties(onnxruntime_providers_rocm PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_rocm PROPERTIES FOLDER "ONNXRuntime")
target_compile_definitions(onnxruntime_providers_rocm PRIVATE HIPBLAS_V2)

if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_providers_rocm PRIVATE ${ORTTRAINING_ROOT} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining ${MPI_CXX_INCLUDE_DIRS})
10 changes: 5 additions & 5 deletions include/onnxruntime/core/providers/rocm/rocm_context.h
@@ -7,7 +7,7 @@
#include "core/providers/custom_op_context.h"
#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rocblas/rocblas.h>
#include <hipblas/hipblas.h>

namespace Ort {

@@ -16,7 +16,7 @@ namespace Custom {
struct RocmContext : public CustomOpContext {
hipStream_t hip_stream = {};
miopenHandle_t miopen_handle = {};
rocblas_handle rblas_handle = {};
hipblasHandle_t blas_handle = {};

void Init(const OrtKernelContext& kernel_ctx) {
const auto& ort_api = Ort::GetApi();
@@ -40,11 +40,11 @@ struct RocmContext : public CustomOpContext {

resource = {};
status = ort_api.KernelContext_GetResource(
&kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::rocblas_handle_t, &resource);
&kernel_ctx, ORT_ROCM_RESOURCE_VERSION, RocmResource::hipblas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch rocblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
ORT_CXX_API_THROW("failed to fetch hipblas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
rblas_handle = reinterpret_cast<rocblas_handle>(resource);
blas_handle = reinterpret_cast<hipblasHandle_t>(resource);
}
};

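
For custom-op authors, the practical effect of this header change is that `RocmContext` now hands out a `hipblasHandle_t` (member `blas_handle`) instead of a `rocblas_handle`. A minimal usage sketch; the kernel function name is hypothetical and the include path may need adjusting to your build:

```cpp
// Sketch of a custom-op compute function using the updated RocmContext.
#include <hipblas/hipblas.h>
#include "core/providers/rocm/rocm_context.h"  // adjust include path as needed

void MyCustomKernelCompute(OrtKernelContext* kernel_ctx) {
  Ort::Custom::RocmContext ctx;
  ctx.Init(*kernel_ctx);  // fills hip_stream, miopen_handle and blas_handle

  // blas_handle is a hipblasHandle_t; use it exactly like a cuBLAS handle.
  hipblasSetStream(ctx.blas_handle, ctx.hip_stream);
  // ... enqueue hipBLAS / MIOpen work on ctx.hip_stream ...
}
```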
2 changes: 1 addition & 1 deletion include/onnxruntime/core/providers/rocm/rocm_resource.h
@@ -8,5 +8,5 @@
enum RocmResource : int {
hip_stream_t = rocm_resource_offset,
miopen_handle_t,
rocblas_handle_t
hipblas_handle_t
};
@@ -396,7 +396,7 @@ Status LaunchLongformerSoftmaxKernel(
cudaDataType_t Atype;
cudaDataType_t Btype;
cudaDataType_t Ctype;
cudaDataType_t resultType;
cublasComputeType_t resultType;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

__half one_fp16, zero_fp16;
@@ -412,7 +412,7 @@ Status LaunchLongformerSoftmaxKernel(
Atype = CUDA_R_16F;
Btype = CUDA_R_16F;
Ctype = CUDA_R_16F;
resultType = CUDA_R_16F;
resultType = CUBLAS_COMPUTE_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
} else {
one_fp32 = 1.f;
@@ -423,7 +423,7 @@ Status LaunchLongformerSoftmaxKernel(
Atype = CUDA_R_32F;
Btype = CUDA_R_32F;
Ctype = CUDA_R_32F;
resultType = CUDA_R_32F;
resultType = CUBLAS_COMPUTE_32F;
}

// Strided batch matrix multiply
@@ -221,7 +221,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
cudaDataType_t Atype;
cudaDataType_t Btype;
cudaDataType_t Ctype;
cudaDataType_t resultType;
cublasComputeType_t resultType;
cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT;

__half one_fp16, zero_fp16;
@@ -237,7 +237,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
Atype = CUDA_R_16F;
Btype = CUDA_R_16F;
Ctype = CUDA_R_16F;
resultType = CUDA_R_16F;
resultType = CUBLAS_COMPUTE_16F;
algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
} else {
one_fp32 = 1.f;
@@ -248,7 +248,7 @@ Status LaunchLongformerSoftmaxSimpleKernel(
Atype = CUDA_R_32F;
Btype = CUDA_R_32F;
Ctype = CUDA_R_32F;
resultType = CUDA_R_32F;
resultType = CUBLAS_COMPUTE_32F;
}

// Strided batch matrix multiply
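
The two longformer CUDA hunks above tighten `resultType` to `cublasComputeType_t`, which is the type of the `computeType` parameter of `cublasGemmStridedBatchedEx` since CUDA 11 (and presumably what lets hipify map it cleanly onto `hipblasComputeType_t`). A condensed sketch of that call shape, with hypothetical argument names:

```cpp
// Sketch only: FP16 strided-batched GEMM with an explicit cuBLAS compute type.
#include <cublas_v2.h>
#include <cuda_fp16.h>

cublasStatus_t fp16_strided_batched_gemm(cublasHandle_t h, int m, int n, int k,
                                         const __half* alpha,
                                         const __half* A, long long strideA,
                                         const __half* B, long long strideB,
                                         const __half* beta,
                                         __half* C, long long strideC, int batch) {
  return cublasGemmStridedBatchedEx(
      h, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k,
      alpha, A, CUDA_R_16F, m, strideA,
      B, CUDA_R_16F, k, strideB,
      beta, C, CUDA_R_16F, m, strideC,
      batch,
      CUBLAS_COMPUTE_16F,              // cublasComputeType_t, not CUDA_R_16F
      CUBLAS_GEMM_DEFAULT_TENSOR_OP);  // cublasGemmAlgo_t
}
```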
6 changes: 3 additions & 3 deletions onnxruntime/contrib_ops/rocm/bert/attention.cu
@@ -84,7 +84,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* present = context->Output(kPresentOutputIndex, present_shape);

auto stream = Stream(context);
rocblas_handle rocblas = GetRocblasHandle(context);
hipblasHandle_t hipblas = GetHipblasHandle(context);

using HipT = typename ToHipType<T>::MappedType;
using QkvProjectGeneric = GemmPermuteGenericPipeline<HipT>;
@@ -113,7 +113,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
auto& params = gemm_permute_params;
params.tuning_ctx = GetTuningContext();
params.stream = context->GetComputeStream();
params.handle = rocblas;
params.handle = hipblas;
params.attention = &attn;
params.device_prop = &device_prop;

@@ -179,7 +179,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
auto& params = gemm_softmax_gemm_permute_params;
params.tuning_ctx = GetTuningContext();
params.stream = context->GetComputeStream();
params.handle = rocblas;
params.handle = hipblas;
params.attention = &attn;
params.device_prop = &device_prop;
// FIXME: the params.scale seems to be different from AttentionParameters::scale;
16 changes: 8 additions & 8 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.cu
@@ -182,7 +182,7 @@ Status DecoderQkvToContext(
const hipDeviceProp_t& prop,
RocmTuningContext* tuning_ctx,
Stream* ort_stream,
rocblas_handle& rocblas,
hipblasHandle_t& hipblas,
const size_t element_size,
const int batch_size,
const int sequence_length,
@@ -284,7 +284,7 @@ Status DecoderQkvToContext(
const int strideB = sequence_length * head_size;
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning_ctx, ort_stream, rocblas,
tuning_ctx, ort_stream, hipblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -295,7 +295,7 @@ Status DecoderQkvToContext(
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning_ctx, ort_stream, rocblas,
tuning_ctx, ort_stream, hipblas,
blas::BlasOp::Trans, blas::BlasOp::NonTrans,
kv_sequence_length, sequence_length, head_size,
/*alpha=*/rsqrt_head_size,
@@ -320,7 +320,7 @@
// compute P*V (as V*P), and store in scratch3: BxNxSxH
if (use_past && static_kv) {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning_ctx, ort_stream, rocblas,
tuning_ctx, ort_stream, hipblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -331,7 +331,7 @@
BN));
} else {
ORT_RETURN_IF_ERROR(blas::column_major::StridedBatchedGemm(
tuning_ctx, ort_stream, rocblas,
tuning_ctx, ort_stream, hipblas,
blas::BlasOp::NonTrans, blas::BlasOp::NonTrans,
head_size, sequence_length, kv_sequence_length,
/*alpha=*/1.0f,
@@ -351,7 +351,7 @@ Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop,
RocmTuningContext* tuning_ctx,
Stream* stream,
rocblas_handle& rocblas,
hipblasHandle_t& hipblas,
const size_t element_size,
const int batch_size,
const int sequence_length,
@@ -378,7 +378,7 @@
prop,
tuning_ctx,
stream,
rocblas,
hipblas,
element_size,
batch_size,
sequence_length,
@@ -405,7 +405,7 @@
prop,
tuning_ctx,
stream,
rocblas,
hipblas,
element_size,
batch_size,
sequence_length,
105 changes: 50 additions & 55 deletions onnxruntime/contrib_ops/rocm/bert/attention_impl.h
@@ -4,7 +4,7 @@
#pragma once

#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include <hipblas/hipblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
@@ -70,64 +70,59 @@ Status LaunchConcatTensorToTensor(hipStream_t stream,
const half* tensor_add,
half* tensor_out);

inline rocblas_status _compat_rocblas_gemm_strided_batched_ex(rocblas_handle handle,
rocblas_operation transa,
rocblas_operation transb,
int m,
int n,
int k,
const void* alpha,
const void* A,
rocblas_datatype a_type,
rocblas_int lda,
rocblas_stride stride_A,
const void* b,
rocblas_datatype b_type,
rocblas_int ldb,
rocblas_stride stride_b,
const void* beta,
void* c,
rocblas_datatype c_type,
rocblas_int ldc,
rocblas_stride stride_c,
rocblas_int batch_count,
rocblas_datatype compute_type,
rocblas_gemm_algo algo) {
return rocblas_gemm_strided_batched_ex(handle,
transa,
transb,
m, // m
n, // n
k, // k
alpha, // alpha
A, // A
a_type, // A type
lda, // lda
stride_A, // strideA
b, // B
b_type, // B type
ldb, // ldb
stride_b, // strideB
beta, // beta
c, // C
c_type, // C type
ldc, // ldc
stride_c, // strideC
c, // D = C
c_type, // D type = C type
ldc, // ldd = ldc
stride_c, // strideD = strideC
batch_count, // batch count
compute_type,
algo,
0, 0);
inline hipblasStatus_t _compat_hipblas_gemm_strided_batched_ex(hipblasHandle_t handle,
hipblasOperation_t transa,
hipblasOperation_t transb,
int m,
int n,
int k,
const void* alpha,
const void* A,
hipDataType a_type,
int lda,
hipblasStride stride_A,
const void* b,
hipDataType b_type,
int ldb,
hipblasStride stride_b,
const void* beta,
void* c,
hipDataType c_type,
int ldc,
hipblasStride stride_c,
int batch_count,
hipblasComputeType_t compute_type,
hipblasGemmAlgo_t algo) {
return hipblasGemmStridedBatchedEx(handle,
transa,
transb,
m, // m
n, // n
k, // k
alpha, // alpha
A, // A
a_type, // A type
lda, // lda
stride_A, // strideA
b, // B
b_type, // B type
ldb, // ldb
stride_b, // strideB
beta, // beta
c, // C
c_type, // C type
ldc, // ldc
stride_c, // strideC
batch_count, // batch count
compute_type,
algo);
}

// Compatible for CublasMathModeSetter
class CompatRocblasMathModeSetter {
class CompatHipblasMathModeSetter {
public:
CompatRocblasMathModeSetter(const hipDeviceProp_t&,
rocblas_handle,
CompatHipblasMathModeSetter(const hipDeviceProp_t&,
hipblasHandle_t,
int) {
}
};
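
One subtlety in the `_compat_hipblas_gemm_strided_batched_ex` shim above: `rocblas_gemm_strided_batched_ex` takes a separate output matrix D plus trailing `solution_index`/`flags` arguments (the old wrapper passed C for D and `0, 0`), whereas `hipblasGemmStridedBatchedEx` updates C in place and has no such tail, so those arguments simply disappear. A hypothetical caller, assuming `HIPBLAS_V2` is defined as the provider's CMake does:

```cpp
// Sketch only: calling the compat shim the same way the old rocBLAS code did.
#include "contrib_ops/rocm/bert/attention_impl.h"

hipblasStatus_t half_gemm_batched(hipblasHandle_t handle, int m, int n, int k,
                                  const __half* alpha,
                                  const __half* A, hipblasStride strideA,
                                  const __half* B, hipblasStride strideB,
                                  const __half* beta,
                                  __half* C, hipblasStride strideC, int batch) {
  return _compat_hipblas_gemm_strided_batched_ex(
      handle, HIPBLAS_OP_N, HIPBLAS_OP_N, m, n, k,
      alpha, A, HIP_R_16F, m, strideA,
      B, HIP_R_16F, k, strideB,
      beta, C, HIP_R_16F, m, strideC,
      batch, HIPBLAS_COMPUTE_16F, HIPBLAS_GEMM_DEFAULT);
}
```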
@@ -32,7 +32,7 @@ struct GemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
return MakeString("M", m, "_N", n, "_K", k, "_B", batch);
}

rocblas_handle handle;
hipblasHandle_t handle;
const AttentionParameters* attention;
const hipDeviceProp_t* device_prop;

@@ -388,7 +388,7 @@ struct GemmSoftmaxGemmPermuteParams : onnxruntime::rocm::tunable::OpParams {
return {m, n, k, o, batch};
}

rocblas_handle handle;
hipblasHandle_t handle;
const RocmAttentionParameters* attention;
const hipDeviceProp_t* device_prop;

4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/rocm/bert/decoder_attention_impl.h
@@ -4,7 +4,7 @@
#pragma once

#include <hip/hip_fp16.h>
#include <rocblas/rocblas.h>
#include <hipblas/hipblas.h>
#include "contrib_ops/cpu/bert/attention_common.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/tunable/rocm_tunable.h"
@@ -17,7 +17,7 @@ Status LaunchDecoderAttentionKernel(
const hipDeviceProp_t& prop, // Device Properties
RocmTuningContext* tuning_ctx, // context for tuning
Stream* stream, // ORT Stream
rocblas_handle& rocblas, // Rocblas handle
hipblasHandle_t& hipblas, // hipblas handle
const size_t element_size, // Element size of input tensor
const int batch_size, // Batch size (B)
const int sequence_length, // Sequence length (S)
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu.cc
@@ -58,7 +58,7 @@ Status GemmFastGelu<T>::ComputeInternal(OpKernelContext* ctx) const {
using onnxruntime::rocm::tunable::blas::BlasOp;

return blas::row_major::GemmFastGelu(
GetTuningContext(), ctx->GetComputeStream(), GetRocblasHandle(ctx),
GetTuningContext(), ctx->GetComputeStream(), GetHipblasHandle(ctx),
transa ? BlasOp::Trans : BlasOp::NonTrans,
transb ? BlasOp::Trans : BlasOp::NonTrans,
helper.M(), helper.N(), helper.K(),
(Remaining changed files not shown.)
