Skip to content

Commit

Permalink
[ROCm] Add SkipGroupNorm for ROCm EP (#19303)
Browse files Browse the repository at this point in the history
Add SkipGroupNorm for ROCm EP.

---------

Co-authored-by: Peixuan Zuo <[email protected]@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
  • Loading branch information
PeixuanZuo and Peixuan Zuo authored Feb 21, 2024
1 parent 8fadc6c commit 6226c5f
Show file tree
Hide file tree
Showing 18 changed files with 382 additions and 732 deletions.
5 changes: 0 additions & 5 deletions cmake/onnxruntime_rocm_hipify.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ set(contrib_ops_excluded_files
"bert/packed_multihead_attention.cc"
"bert/packed_multihead_attention_impl.h"
"bert/packed_multihead_attention_impl.cu"
"diffusion/group_norm.cc"
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/group_norm_impl_kernel.cuh"
"diffusion/group_norm_common_base.h"
"diffusion/group_norm_common_base.cc"
"diffusion/nhwc_conv.cc"
"math/gemm_float8.cc"
"math/gemm_float8.cu"
Expand Down
152 changes: 0 additions & 152 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc

This file was deleted.

35 changes: 22 additions & 13 deletions onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,49 @@ namespace rocm {

using onnxruntime::rocm::CKDataTypeAdaptor;

using Swish = ck::tensor_operation::element_wise::Swish;
// The SiLU function is a special case of the Swish function.
// The Swish function is parametrized by b, which is set to 1.0 for SiLU. They are defined as:
//   SiLU(x)  = x * sigmoid(x)
//   Swish(x) = x * sigmoid(b * x)
// The default value of b is 1.0 in ck::tensor_operation::element_wise::Swish, so we treat the two as the same function here.
using Silu = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank = 5;
constexpr int NumReduceDim = 3;

template <typename T, typename AccT, bool WithSwish>
template <typename T, typename AccT, bool WithSilu>
auto GetCKGroupNormNHWCTypeStringAndOps() {
using XDataType = typename CKDataTypeAdaptor<T>::type;
using YDataType = typename CKDataTypeAdaptor<T>::type;
using SaveMeanInvStdDataType = typename CKDataTypeAdaptor<AccT>::type;
using GammaDataType = float;
using BetaDataType = float;

using Activation = std::conditional_t<WithSwish, Swish, Pass>;
using Activation = std::conditional_t<WithSilu, Silu, Pass>;

std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCParams<T>>>> ret;
std::vector<std::pair<std::string, onnxruntime::rocm::tunable::Op<GroupNormNHWCTunableParams<T>>>> ret;
for (auto&& impl : internal::GetDeviceGroupNormInstances<XDataType, GammaDataType, BetaDataType, YDataType,
SaveMeanInvStdDataType, Activation, Rank, NumReduceDim>()) {
std::string swish_suffix = WithSwish ? "_Swish" : "_Pass";
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
std::string silu_suffix = WithSilu ? "_Silu" : "_Pass";
auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + silu_suffix;
auto invoker = impl->MakeInvokerPointer();

auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCParams<T>* params) -> Status {
if constexpr (WithSwish) {
auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
const GroupNormNHWCTunableParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
"Input skip or bias is not supported by composable kernel.");
if constexpr (WithSilu) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->withSwish, "Swish version only support groupnorm with swish");
!params->use_silu, "Silu version only support groupnorm with silu");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->withSwish, "Pass version only support groupnorm without swish");
params->use_silu, "Pass version only support groupnorm without silu");
}
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->cPerGroup};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->cPerGroup, 1};
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->cPerGroup, 1};
std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
params->c, params->channels_per_group, 1};
std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
std::vector<ck::index_t> reduce_dims{1, 2, 4};

auto activation = Activation{};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace internal {
using F16 = ck::half_t;
using F32 = float;

using Swish = ck::tensor_operation::element_wise::Swish;
using Silu = ck::tensor_operation::element_wise::Swish;
using Pass = ck::tensor_operation::element_wise::PassThrough;

using ck::tensor_operation::device::DeviceNormalizationFwd; // the interface
Expand Down Expand Up @@ -101,9 +101,9 @@ GetDeviceGroupNormInstances() {

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
F16, F32, F32, F16, F32, Swish, 5, 3>>>
F16, F32, F32, F16, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<
F16, F32, F32, F16, F32, Swish, 5, 3>();
F16, F32, F32, F16, F32, Silu, 5, 3>();

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
Expand All @@ -113,9 +113,9 @@ GetDeviceGroupNormInstances<

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
F32, F32, F32, F32, F32, Swish, 5, 3>>>
F32, F32, F32, F32, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<
F32, F32, F32, F32, F32, Swish, 5, 3>();
F32, F32, F32, F32, F32, Silu, 5, 3>();

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace rocm {
namespace internal {

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Swish, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Swish, 5, 3>>> instances;
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<F16, F32, F32, F16, F32, Silu, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F16, F32, F32, F16, F32, Silu, 5, 3>>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_normalization_f16_instances<Swish, 5, 3>{});
device_normalization_f16_instances<Silu, 5, 3>{});

return instances;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ namespace rocm {
namespace internal {

template <>
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Swish, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Swish, 5, 3>>> instances;
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>>
GetDeviceGroupNormInstances<F32, F32, F32, F32, F32, Silu, 5, 3>() {
std::vector<std::unique_ptr<DeviceNormalizationFwd<F32, F32, F32, F32, F32, Silu, 5, 3>>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_normalization_f32_instances<Swish, 5, 3>{});
device_normalization_f32_instances<Silu, 5, 3>{});

return instances;
}
Expand Down
Loading

0 comments on commit 6226c5f

Please sign in to comment.