CUDA EP vs ROCM EP hipify audit #17776

Merged: 18 commits, Oct 13, 2023
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers.cmake
@@ -202,4 +202,4 @@ endif()

if (onnxruntime_USE_AZURE)
include(onnxruntime_providers_azure.cmake)
endif()
endif()
4 changes: 2 additions & 2 deletions cmake/onnxruntime_providers_migraphx.cmake
@@ -42,7 +42,7 @@
onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR})
target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime)
set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX)
set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime")
target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1)
@@ -72,4 +72,4 @@
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
)
)
5 changes: 3 additions & 2 deletions cmake/onnxruntime_providers_rocm.cmake
@@ -10,6 +10,7 @@
find_package(hiprand REQUIRED)
find_package(rocblas REQUIRED)
find_package(MIOpen REQUIRED)
find_package(hipfft REQUIRED)

# MIOpen version
if(NOT DEFINED ENV{MIOPEN_PATH})
@@ -48,7 +48,7 @@

find_library(RCCL_LIB rccl REQUIRED)
find_library(ROCTRACER_LIB roctracer64 REQUIRED)
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen ${RCCL_LIB} ${ROCTRACER_LIB})
set(ONNXRUNTIME_ROCM_LIBS roc::rocblas MIOpen hip::hipfft ${RCCL_LIB} ${ROCTRACER_LIB})

file(GLOB_RECURSE onnxruntime_providers_rocm_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/rocm/*.h"
@@ -219,4 +220,4 @@
install(TARGETS onnxruntime_providers_rocm
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
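Context for the new hipfft dependency above: the FFT contrib ops removed from the hipify exclusion list below are now translated from cuFFT to hipFFT, so the ROCm provider must link hip::hipfft. A minimal sketch of the hipFFT call pattern involved, assuming a recent ROCm header layout (`<hipfft/hipfft.h>`); the function name here is illustrative, not from the PR:

```cpp
#include <hip/hip_runtime.h>
#include <hipfft/hipfft.h>

// Run a single-batch forward complex-to-complex FFT on the caller's stream.
hipfftResult Forward1dFft(hipfftComplex* d_in, hipfftComplex* d_out,
                          int nx, hipStream_t stream) {
  hipfftHandle plan;
  hipfftResult status = hipfftPlan1d(&plan, nx, HIPFFT_C2C, /*batch=*/1);
  if (status != HIPFFT_SUCCESS) return status;
  hipfftSetStream(plan, stream);  // enqueue on the provided stream, not stream 0
  status = hipfftExecC2C(plan, d_in, d_out, HIPFFT_FORWARD);
  hipfftDestroy(plan);
  return status;
}
```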
27 changes: 0 additions & 27 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -48,15 +48,6 @@ set(contrib_ops_excluded_files
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
"math/complex_mul.cc"
"math/complex_mul.h"
"math/complex_mul_impl.cu"
"math/complex_mul_impl.h"
"math/cufft_plan_cache.h"
"math/fft_ops.cc"
"math/fft_ops.h"
"math/fft_ops_impl.cu"
"math/fft_ops_impl.h"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
@@ -86,19 +77,6 @@ set(contrib_ops_excluded_files
"quantization/qordered_ops/qordered_unary_ops.cc"
"quantization/qordered_ops/qordered_unary_ops_impl.h"
"quantization/qordered_ops/qordered_unary_ops_impl.cu"
"tensor/crop.cc"
"tensor/crop.h"
"tensor/crop_impl.cu"
"tensor/crop_impl.h"
"tensor/dynamicslice.cc"
"tensor/image_scaler.cc"
"tensor/image_scaler.h"
"tensor/image_scaler_impl.cu"
"tensor/image_scaler_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"conv_transpose_with_dynamic_pads.cc"
"conv_transpose_with_dynamic_pads.h"
"cuda_contrib_kernels.cc"
"cuda_contrib_kernels.h"
"inverse.cc"
@@ -119,10 +97,6 @@ endif()

set(provider_excluded_files
"atomic/common.cuh"
"controlflow/loop.cc"
"controlflow/loop.h"
"controlflow/scan.cc"
"controlflow/scan.h"
"cu_inc/common.cuh"
"math/einsum_utils/einsum_auxiliary_ops.cc"
"math/einsum_utils/einsum_auxiliary_ops.h"
@@ -170,7 +144,6 @@ set(provider_excluded_files
"cuda_memory_check.h"
"cuda_fence.cc"
"cuda_fence.h"
"cuda_fwd.h"
"cuda_kernel.h"
"cuda_pch.cc"
"cuda_pch.h"
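To make the effect of shrinking these exclusion lists concrete: every file taken off the list is now run through hipify, which mechanically rewrites CUDA API names into their HIP equivalents and writes the output under `${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime` (the directory added to the MIGraphX include path above). A hedged before/after sketch, not code from this PR:

```cpp
#include <hip/hip_runtime.h>

// CUDA original: cudaMemcpyAsync(dst, src, bytes, cudaMemcpyDeviceToDevice, s);
// hipify output (what actually compiles into the ROCm provider):
hipError_t DeviceToDeviceCopy(void* dst, const void* src, size_t bytes,
                              hipStream_t s) {
  return hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToDevice, s);
}
```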
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cuda/transformers/greedy_search.cc
@@ -48,10 +48,12 @@ GreedySearch::GreedySearch(const OpKernelInfo& info)

SetConsoleDumper(&g_cuda_dumper_greedysearch);

#ifndef USE_ROCM
cuda_device_prop_ = &reinterpret_cast<const CUDAExecutionProvider*>(info.GetExecutionProvider())->GetDeviceProp();

cuda_device_arch_ = static_cast<const cudaDeviceProp*>(cuda_device_prop_)->major * 100 +
static_cast<const cudaDeviceProp*>(cuda_device_prop_)->minor * 10;
#endif
}

Status GreedySearch::ComputeInternal(OpKernelContext* context) const {
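A small aside on the guarded arithmetic (illustrative, not from the PR): `cuda_device_arch_` packs the CUDA compute capability into one integer; AMD GPUs have no equivalent numbering, hence the `#ifndef USE_ROCM` above.

```cpp
#include <cassert>

// Compute capability major.minor packed as an int, matching the guarded code:
// an sm_86 device (major 8, minor 6) yields 8 * 100 + 6 * 10 = 860.
constexpr int ArchFromComputeCapability(int major, int minor) {
  return major * 100 + minor * 10;
}
static_assert(ArchFromComputeCapability(8, 6) == 860, "sm_86");
static_assert(ArchFromComputeCapability(7, 0) == 700, "sm_70");

int main() { return 0; }
```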
146 changes: 91 additions & 55 deletions onnxruntime/contrib_ops/rocm/rocm_contrib_kernels.cc

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cuda/cuda_provider_factory.cc
@@ -243,7 +243,7 @@ struct CUDA_Provider : Provider {
cuda_options.arena_extend_strategy = internal_options.arena_extend_strategy;
cuda_options.do_copy_in_default_stream = internal_options.do_copy_in_default_stream;
cuda_options.has_user_compute_stream = internal_options.has_user_compute_stream;
// The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set byC API UpdateCUDAProviderOptionsWithValue() as well.
// The 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance can be set by C API UpdateCUDAProviderOptionsWithValue() as well.
// We only set the 'has_user_compute_stream' of the OrtCUDAProviderOptionsV2 instance if it is provided in options
if (options.find("has_user_compute_stream") != options.end()) {
cuda_options.user_compute_stream = internal_options.user_compute_stream;
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -102,8 +102,9 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
}
s_.y_dims = gsl::make_span(y_dims);

if (w_dims_changed)
if (w_dims_changed) {
ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
}

// Special case when there is a dim value of 0 in the shape.
// Return only after we have cached the following for subsequent runs :
20 changes: 12 additions & 8 deletions onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -2,8 +2,6 @@
// Licensed under the MIT License.

#pragma once
#include <type_traits>
#include <memory>
#include <stdint.h>
#include <vector>
#include <mutex>
@@ -294,6 +292,14 @@ __device__ __inline__ T _Gelu(T a) {
return a * _Normcdf(a);
}

template <>
__device__ __inline__ half _Gelu(half a) {
const half kHalf = half(0.5);
const half kOne = half(1.0);
const half kAlpha = half(M_SQRT1_2);
return a * kHalf * (kOne + _Erf(kAlpha * a));
}

template <typename T>
__device__ __inline__ T _Mod(T a, T b) {
T r = a % b;
@@ -348,21 +354,19 @@ struct GridDim {
};
};

// aligned vector generates vectorized load/store
// aligned vector generates vectorized load/store on ROCM
template <typename T, int vec_size>
struct alignas(sizeof(T) * vec_size) aligned_vector {
T val[vec_size];
};

#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \
#define CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(id, N) \
HIP_LONG id = blockDim.x * blockIdx.x + threadIdx.x; \
if (id >= N) \
if (id >= N) \
return;

// HIP_KERNEL_ASSERT is a macro that wraps an assert() call inside rocm kernels.
// TODO ROCM added support recently, should verify.
#define HIP_KERNEL_ASSERT(...)
// #define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)
#define HIP_KERNEL_ASSERT(...) assert(__VA_ARGS__)

// WARP related definitions and functions
constexpr int GPU_WARP_SIZE = warpSize;
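The new `half` specialization of `_Gelu` evaluates the erf-based formula directly in half precision instead of routing through `_Normcdf`, presumably to mirror the CUDA EP's overload. A double-precision reference for the same formula, assuming a POSIX `<cmath>` that defines `M_SQRT1_2` (illustrative only):

```cpp
#include <cmath>
#include <cstdio>

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))), the formula the half
// specialization computes with kHalf, kOne, and kAlpha = M_SQRT1_2.
double GeluRef(double x) {
  return 0.5 * x * (1.0 + std::erf(x * M_SQRT1_2));
}

int main() {
  for (double x : {-2.0, -0.5, 0.0, 0.5, 2.0}) {
    std::printf("gelu(%+.2f) = %.6f\n", x, GeluRef(x));
  }
  return 0;
}
```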
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/rocm/fpgeneric.cu
@@ -68,14 +68,14 @@ rocblas_status rocblasTransposeHelper(hipStream_t stream, rocblas_handle, rocbla
rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const half* x, int incx, half* y, int incy) {
dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
CopyVectorHalf<<<dim3(dimGrid), dim3(dimBlock), 0, stream>>>(x, incx, y, incy, n);
CopyVectorHalf<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
return rocblas_status_success;
}

rocblas_status rocblasCopyHelper(hipStream_t stream, rocblas_handle, int n, const onnxruntime::BFloat16* x, int incx,
onnxruntime::BFloat16* y, int incy) {
dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
CopyVectorBFloat16<<<dim3(dimGrid), dim3(dimBlock), 0, stream>>>(x, incx, y, incy, n);
CopyVectorBFloat16<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
return rocblas_status_success;
}
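The fpgeneric.cu change is purely cosmetic: `dimGrid` and `dimBlock` are already `dim3`, so the extra `dim3(...)` wrap was redundant. A hedged sketch of the launch pattern, with an assumed `COPY_BLOCK_DIM` value and a stand-in kernel (the real `CopyVectorHalf` is defined elsewhere in the file):

```cpp
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

constexpr int COPY_BLOCK_DIM = 256;  // assumed; the real value lives in fpgeneric.cu

// Stand-in for CopyVectorHalf: one thread copies one strided element.
__global__ void CopyVectorSketch(const half* x, int incx, half* y, int incy, int n) {
  int id = blockDim.x * blockIdx.x + threadIdx.x;
  if (id < n) y[id * incy] = x[id * incx];
}

void LaunchCopy(const half* x, int incx, half* y, int incy, int n, hipStream_t stream) {
  // Ceiling division: enough blocks to cover all n elements.
  dim3 dimGrid((unsigned int)(n + COPY_BLOCK_DIM - 1) / COPY_BLOCK_DIM, 1, 1);
  dim3 dimBlock(COPY_BLOCK_DIM, 1, 1);
  CopyVectorSketch<<<dimGrid, dimBlock, 0, stream>>>(x, incx, y, incy, n);
}
```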
34 changes: 15 additions & 19 deletions onnxruntime/core/providers/rocm/gpu_data_transfer.cc
@@ -2,14 +2,15 @@
// Licensed under the MIT License.

#include "core/providers/shared_library/provider_api.h"
#include "core/providers/rocm/rocm_common.h"

#include "core/providers/rocm/gpu_data_transfer.h"
#include "core/providers/rocm/rocm_common.h"

// use default stream for copy for now, to avoid racing in BFC arena as in issue #4829
// note this may cause some models to run slower if there are ops running on CPU
// so we leave it as optional, in case user need the previous behavior
// a full fix to BFC arena is being looked at, and once it's in, we can revert this change
namespace onnxruntime {
GPUDataTransfer::GPUDataTransfer() {}

GPUDataTransfer::~GPUDataTransfer() {}

bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const {
return src_device.Type() == OrtDevice::GPU || src_device.MemType() == OrtDevice::MemType::HIP_PINNED ||
dst_device.Type() == OrtDevice::GPU || dst_device.MemType() == OrtDevice::MemType::HIP_PINNED;
@@ -34,12 +35,12 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
}
} else if (src_device.Type() == OrtDevice::GPU) {
// copying from GPU to CPU memory, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr)); // TODO: still need stream sync? since already blocking
HIP_RETURN_IF_ERROR(hipStreamSynchronize(nullptr));
} else {
// copying between cpu memory
memcpy(dst_data, src_data, bytes);
@@ -57,34 +58,29 @@ common::Status GPUDataTransfer::CopyTensorAsync(const Tensor& src, Tensor& dst,
auto& dst_device = dst.Location().device;

if (dst_device.Type() == OrtDevice::GPU) {
if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
if (src_device.Type() == OrtDevice::CPU) {
// copy from pinned memory to GPU, this is non-blocking
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
} else if (src_device.Type() == OrtDevice::GPU) {
// copying between GPU, this is non-blocking
// Copy only if the two addresses are different.
if (dst_data != src_data) {
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, static_cast<hipStream_t>(stream.GetHandle())));
}
} else {
// copy from other CPU memory to GPU, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, static_cast<hipStream_t>(stream.GetHandle())));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
}
} else if (src_device.Type() == OrtDevice::GPU) {
if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) {
if (dst_device.Type() == OrtDevice::CPU) {
// copying from GPU to pinned memory, this is non-blocking
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
} else {
// copying from GPU to CPU memory, this is blocking
HIP_RETURN_IF_ERROR(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, static_cast<hipStream_t>(stream.GetHandle())));
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
}
} else {
// copying between cpu memory
if (src_device.MemType() == OrtDevice::MemType::CUDA_PINNED) {
// sync the stream first to make sure the data arrived
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream.GetHandle())));
}
memcpy(dst_data, src_data, bytes);
}

return Status::OK();
}

} // namespace onnxruntime
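The data-transfer rewrite above drops the explicit `HIP_PINNED` checks so the async paths match the CUDA EP's structure. The underlying HIP contract, as a hedged sketch (assumption: a pageable host buffer must not be reused until the copy completes, which is what the blocking branches enforce):

```cpp
#include <hip/hip_runtime.h>

// Async H2D copy; sync afterwards only when the source is pageable, since a
// pinned source can safely outlive the enqueue without an immediate sync.
hipError_t CopyHostToDevice(void* dst, const void* src, size_t bytes,
                            hipStream_t stream, bool pinned_src) {
  hipError_t status = hipMemcpyAsync(dst, src, bytes, hipMemcpyHostToDevice, stream);
  if (status != hipSuccess) return status;
  return pinned_src ? hipSuccess : hipStreamSynchronize(stream);
}
```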
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/rocm/gpu_data_transfer.h
@@ -10,8 +10,8 @@ namespace onnxruntime {

class GPUDataTransfer : public IDataTransfer {
public:
GPUDataTransfer() = default;
~GPUDataTransfer() = default;
GPUDataTransfer();
~GPUDataTransfer();

bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;

21 changes: 12 additions & 9 deletions onnxruntime/core/providers/rocm/integer_gemm.cc
@@ -5,13 +5,14 @@
#include <rocblas/rocblas.h>
#include "core/providers/rocm/shared_inc/integer_gemm.h"

#include "core/common/safeint.h"
#include "core/providers/rocm/rocm_common.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"

namespace onnxruntime {
namespace rocm {

inline int roundoff(int v, int d) {
constexpr int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}

@@ -21,35 +22,37 @@ Status GemmInt8(int m, int n, int k,
const RocmKernel* rocm_kernel, onnxruntime::Stream* ort_stream) {
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
ORT_ENFORCE(rocm_kernel != nullptr, "kernel is null");
ORT_ENFORCE(ort_stream != nullptr, "Rocm kernel must have the stream instance");

hipStream_t stream = ort_stream ? static_cast<hipStream_t>(ort_stream->GetHandle()) : nullptr;
hipStream_t stream = static_cast<hipStream_t>(ort_stream->GetHandle());

// pad A and B to make their leading dimension be multiples of 32
// because cublasGemmEx requires:
// because rocblas_gemm_ex requires:
// 1. leading dimension is multiples of 4
// 2. A, B is 32-bit aligned

const int mask = 0x1F;
constexpr int mask = 0x1F;
int lda_aligned = lda;
IAllocatorUniquePtr<int8_t> a_padded;
if ((mask & lda_aligned) != 0) {
lda_aligned = roundoff(lda, 32);
a_padded = rocm_kernel->GetScratchBuffer<int8_t>(m * lda_aligned, ort_stream);
a_padded = rocm_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(m) * lda_aligned, ort_stream);
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, hipMemcpyDeviceToDevice, stream));
}

int ldb_aligned = ldb;
IAllocatorUniquePtr<int8_t> b_padded;
if ((mask & ldb_aligned) != 0) {
ldb_aligned = roundoff(ldb, 32);
b_padded = rocm_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned, ort_stream);
b_padded = rocm_kernel->GetScratchBuffer<int8_t>(SafeInt<size_t>(k) * ldb_aligned, ort_stream);
HIP_RETURN_IF_ERROR(hipMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, hipMemcpyDeviceToDevice, stream));
}

RocmStream* ort_rocm_stream = static_cast<RocmStream*>(ort_stream);
auto handle = ort_rocm_stream->rocblas_handle_;
auto* ort_rocm_stream = dynamic_cast<RocmStream*>(ort_stream);
auto rocblas = ort_rocm_stream->rocblas_handle_;

ROCBLAS_RETURN_IF_ERROR(rocblas_gemm_ex(
handle,
rocblas,
rocblas_operation_none, rocblas_operation_none,
n, m, k,
&alpha,
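The alignment logic above pads `lda`/`ldb` up to multiples of 32 so `rocblas_gemm_ex` sees 32-bit-aligned int8 operands, and the new `SafeInt` guards the buffer-size multiplications against overflow. The rounding rule, verified in isolation:

```cpp
#include <cassert>

// Round v up to the next multiple of d (d > 0); (v & 0x1F) != 0 in the code
// above is the "not already a multiple of 32" test that triggers the padding.
constexpr int roundoff(int v, int d) {
  return (v + d - 1) / d * d;
}

static_assert(roundoff(32, 32) == 32, "already aligned");
static_assert(roundoff(33, 32) == 64, "rounded up");
static_assert(roundoff(1, 32) == 32, "at least one block");

int main() { return 0; }
```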
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/rocm/math/einsum.h
@@ -17,8 +17,7 @@ class Einsum final : public onnxruntime::Einsum {
Einsum(const OpKernelInfo& info) : onnxruntime::Einsum(info) {
// We need to cast away the const as PerThreadRocblasHandle() is currently a non-const method
// TODO: Clean up the ROCMExecutionProvider interface to avoid this
rocm_ep_ = const_cast<ROCMExecutionProvider*>(
static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider()));
rocm_ep_ = static_cast<const ROCMExecutionProvider*>(info.GetExecutionProvider());
}

Status Compute(OpKernelContext* context) const override;
@@ -32,7 +31,7 @@ class Einsum final : public onnxruntime::Einsum {
using onnxruntime::Einsum::equation_;

// We need to access to the ROCM EP instance to get the rocblas/miopen handles
ROCMExecutionProvider* rocm_ep_;
const ROCMExecutionProvider* rocm_ep_;
};

} // namespace rocm
onnxruntime/core/providers/rocm/math/einsum_utils/einsum_auxiliary_ops.h
@@ -21,19 +21,18 @@ namespace EinsumOp {
// Holds ROCM assets required for ROCM ops that need to be executed as part of the Einsum flow
struct EinsumRocmAssets {
explicit EinsumRocmAssets(rocblas_handle rocblas_handle,
ROCMExecutionProvider* rocm_ep,
Stream* ort_stream,
AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle),
rocm_ep_(rocm_ep),
ort_stream_(ort_stream),
gpu_allocator_(gpu_allocator) {}
const ROCMExecutionProvider* rocm_ep,
Stream* ort_stream, AllocatorPtr gpu_allocator) : rocblas_handle_(rocblas_handle),
rocm_ep_(rocm_ep),
ort_stream_(ort_stream),
gpu_allocator_(gpu_allocator) {}

hipStream_t GetRocmStream() {
return ort_stream_ ? static_cast<hipStream_t>(ort_stream_->GetHandle()) : nullptr;
}

rocblas_handle rocblas_handle_;
ROCMExecutionProvider* rocm_ep_;
const ROCMExecutionProvider* rocm_ep_;
Stream* ort_stream_;
AllocatorPtr gpu_allocator_;
};
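One usage note on `EinsumRocmAssets::GetRocmStream()`: it maps a null `ort_stream_` to a null `hipStream_t`, which HIP interprets as the default stream, so call sites can pass the result straight to HIP APIs. An illustrative sketch, not code from the PR:

```cpp
#include <hip/hip_runtime.h>

// A null stream is valid: HIP treats hipStream_t(nullptr) as the default
// stream, matching the fallback in EinsumRocmAssets::GetRocmStream().
hipError_t CopyOnStream(void* dst, const void* src, size_t bytes,
                        hipStream_t stream /* may be nullptr */) {
  return hipMemcpyAsync(dst, src, bytes, hipMemcpyDeviceToDevice, stream);
}
```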