diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
index a726c4a200ced..1bfc08cfcfe98 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh
@@ -49,8 +49,10 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
     auto type_string = onnxruntime::MakeString(impl->GetTypeString()) + swish_suffix;
     auto invoker = impl->MakeInvokerPointer();
-    auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](const GroupNormNHWCTunableParams<T>* params) -> Status {
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), "Skip is not supported");
+    auto ck_group_norm_op = [impl = std::move(impl), invoker = std::move(invoker)](
+                                const GroupNormNHWCTunableParams<T>* params) -> Status {
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
+                                                "Skip is not supported");
       if constexpr (WithSwish) {
         TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
             !params->use_silu, "Swish version only support groupnorm with swish");
@@ -59,7 +61,8 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
             params->use_silu, "Pass version only support groupnorm without swish");
       }

       std::vector<ck::index_t> in_lengths{params->n, params->h, params->w, params->groups, params->channels_per_group};
-      std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c, params->c, params->channels_per_group, 1};
+      std::vector<ck::index_t> in_out_strides{params->h * params->w * params->c, params->w * params->c,
+                                              params->c, params->channels_per_group, 1};
       std::vector<ck::index_t> gamma_beta_strides{0, 0, 0, params->channels_per_group, 1};
       std::vector<ck::index_t> reduce_dims{1, 2, 4};
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
index 0db8696941fe9..5679bb54e9c61 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_common.h
@@ -41,7 +41,8 @@ struct GroupNormNHWCTunableParams : OpParams, GroupNormNHWCParams<T> {

   std::string Signature() const override {
     std::string swish_suffix = this->use_silu ? "_silu" : "_pass";
-    std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" + std::to_string(this->c) + "_" + std::to_string(this->groups) + swish_suffix;
+    std::string sig = std::to_string(this->n) + "_" + std::to_string(this->h * this->w) + "_" +
+                      std::to_string(this->c) + "_" + std::to_string(this->groups) + swish_suffix;
     return sig;
   }
 };
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
index ca9e19c240205..142aaf14e8d2d 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_impl.cu
@@ -34,8 +34,8 @@ Status LaunchGroupNormKernel(
     bool broadcast_skip,
     int channels_per_block) {
   GroupNormNHWCTunableParams<T> params(tuning_ctx, ort_stream, output, add_out, input, skip, bias, gamma, beta,
-                                       reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels, height, width,
-                                       num_groups, use_silu, broadcast_skip, channels_per_block);
+                                       reinterpret_cast<float*>(workspace), epsilon, batch_size, num_channels,
+                                       height, width, num_groups, use_silu, broadcast_skip, channels_per_block);

   if (params.channels_per_block % params.channels_per_group != 0 ||
       params.channels_per_block > kMaxSize ||
@@ -59,17 +59,17 @@ Status LaunchGroupNormKernel(
   return GroupNormNHWCStaticSelection(&params);
 }

-template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out,
-                                            const half* input, const half* skip, const half* bias, const float* gamma,
-                                            const float* beta, void* workspace, float epsilon, int batch_size,
-                                            int num_channels, int height, int width, int num_groups, bool use_silu,
-                                            bool broadcast_skip, int channels_per_block);
+template Status LaunchGroupNormKernel<half>(RocmTuningContext* tuning_ctx, Stream* stream, half* output,
+                                            half* add_out, const half* input, const half* skip, const half* bias,
+                                            const float* gamma, const float* beta, void* workspace, float epsilon,
+                                            int batch_size, int num_channels, int height, int width, int num_groups,
+                                            bool use_silu, bool broadcast_skip, int channels_per_block);

-template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out,
-                                             const float* input, const float* skip, const float* bias, const float* gamma,
-                                             const float* beta, void* workspace, float epsilon, int batch_size,
-                                             int num_channels, int height, int width, int num_groups, bool use_silu,
-                                             bool broadcast_skip, int channels_per_block);
+template Status LaunchGroupNormKernel<float>(RocmTuningContext* tuning_ctx, Stream* stream, float* output,
+                                             float* add_out, const float* input, const float* skip, const float* bias,
+                                             const float* gamma, const float* beta, void* workspace, float epsilon,
+                                             int batch_size, int num_channels, int height, int width, int num_groups,
+                                             bool use_silu, bool broadcast_skip, int channels_per_block);

 }  // namespace rocm
 }  // namespace contrib
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
index d8d909daa0417..551109f407d7a 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_triton.cuh
@@ -46,10 +46,12 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
     auto block_size = metadata->constants.at("BLOCK_SIZE");
     auto hw_size = metadata->constants.at("HW_SIZE");
     auto impl = [i, block_size, hw_size](const GroupNormNHWCTunableParams<T>* params) -> Status {
-      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr), "Skip is not supported");
+      TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF((params->skip != nullptr || params->bias != nullptr),
+                                                "Skip is not supported");
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->channels_per_group > block_size || params->channels_per_group * 2 <= block_size,
-          "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (", params->channels_per_group, ").");
+          "Arg block_size (", block_size, ") is not the next power of 2 of channels_per_group (",
+          params->channels_per_group, ").");
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->hw % hw_size != 0, "Arg hw_size (", hw_size, ") is not a divisor of hw (", params->hw, ").");
       if constexpr (WithSwish) {
diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
index 6120a81900990..679b8a0ae5aed 100644
--- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
+++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm_tunable_op.h
@@ -81,13 +81,13 @@ void groupNormNHWCScale(const GroupNormNHWCTunableParams<T>* params) {
   // The number of instances.
   grid.z = params->n;

-#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                                                           \
-  GroupNormNHWCScaleKernel<T, VecSize>                                                                             \
-      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(                                                      \
-          params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace,             \
-          params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block, params->channels_per_group, \
-          params->groups, params->hwc, params->inv_hw_channels_per_group, params->hw, params->hw_per_block,        \
-          params->use_silu);                                                                                       \
+#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                                                \
+  GroupNormNHWCScaleKernel<T, VecSize>                                                                  \
+      <<<grid, ThreadsPerBlock, 0, params->StreamHandle()>>>(                                           \
+          params->dst, params->src, params->skip, params->gamma, params->beta, params->skip_workspace,  \
+          params->group_sum_buffer, params->epsilon, params->c, params->channels_per_block,             \
+          params->channels_per_group, params->groups, params->hwc, params->inv_hw_channels_per_group,   \
+          params->hw, params->hw_per_block, params->use_silu);                                          \
   break;

 // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
@@ -127,8 +127,10 @@ template <typename T, int ThreadsPerBlock, int VecSize>
 class GroupNormNHWCOp {
  public:
   Status operator()(const GroupNormNHWCTunableParams<T>* params) {
-    HIP_RETURN_IF_ERROR(hipMemsetAsync(
-        params->group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), params->StreamHandle()));
+    HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
+                                       0,
+                                       GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
+                                       params->StreamHandle()));
     auto status = GroupNormNHWCSumOp<T, ThreadsPerBlock, VecSize>(params);
     ORT_RETURN_IF_ERROR(status);
     HIP_RETURN_IF_ERROR(hipGetLastError());
@@ -155,8 +157,10 @@ class GroupNormNHWCOp {

 template <typename T>
 Status GroupNormNHWCStaticSelection(const GroupNormNHWCTunableParams<T>* params) {
-  HIP_RETURN_IF_ERROR(hipMemsetAsync(
-      params->group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params->n, params->groups), params->StreamHandle()));
+  HIP_RETURN_IF_ERROR(hipMemsetAsync(params->group_sum_buffer,
+                                     0,
+                                     GetGroupNormWorkspaceSizeInBytes(params->n, params->groups),
+                                     params->StreamHandle()));
   groupNormNHWCSum(params);
   HIP_RETURN_IF_ERROR(hipGetLastError());
   groupNormNHWCScale(params);
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
index 9ff22ec2bde56..5f6482b895c7a 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/group_norm.cu
@@ -20,11 +20,13 @@ class GroupNormNHWC : public IKernelExplorer {
  public:
   GroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
                 DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
-                int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, bool broadcast_skip, int channels_per_block)
+                int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu,
+                bool broadcast_skip, int channels_per_block)
       : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
                 static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
                 static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
-                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     type_string_ = "GroupNormNHWC_" + std::to_string(ThreadsPerBlock) + "_" + std::to_string(VecSize);
   }

@@ -51,19 +53,23 @@ class GroupNormNHWC : public IKernelExplorer {

 template <typename T>
 class GroupNormNHWCStaticSelection : public IKernelExplorer {
  public:
-  GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
-                               DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
-                               int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, bool broadcast_skip, int channels_per_block)
+  GroupNormNHWCStaticSelection(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                               DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                               float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                               bool use_silu, bool broadcast_skip, int channels_per_block)
       : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
                 static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
                 static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
-                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     type_string_ = "GroupNormNHWCStaticSelection";
   }

   void Run() override {
-    HIP_CALL_THROW(hipMemsetAsync(
-        params_.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups), params_.StreamHandle()));
+    HIP_CALL_THROW(hipMemsetAsync(params_.group_sum_buffer,
+                                  0,
+                                  GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups),
+                                  params_.StreamHandle()));
     ORT_THROW_IF_ERROR((contrib::rocm::GroupNormNHWCStaticSelection(&params_)));
   }

@@ -72,8 +78,10 @@ class GroupNormNHWCStaticSelection : public IKernelExplorer {
   }

   bool SelectOp(const std::string& name) {
-    HIP_CALL_THROW(hipMemsetAsync(
-        params_.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups), params_.StreamHandle()));
+    HIP_CALL_THROW(hipMemsetAsync(params_.group_sum_buffer,
+                                  0,
+                                  GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups),
+                                  params_.StreamHandle()));
     Status status = contrib::rocm::GroupNormNHWCStaticSelection(&params_);
     return status.IsOK() && name == type_string_;
   }
@@ -87,19 +95,23 @@

 template <typename T>
 class GroupNormNHWCTunable : public IKernelExplorer {
  public:
-  GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
-                       DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
-                       int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, bool broadcast_skip, int channels_per_block)
+  GroupNormNHWCTunable(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                       DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                       float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                       bool use_silu, bool broadcast_skip, int channels_per_block)
       : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
                 static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
                 static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
-                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     params_.TuningContext()->EnableTunableOpAndTuning();
   }

   void Run() override {
-    HIP_CALL_THROW(hipMemsetAsync(
-        params_.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups), params_.StreamHandle()));
+    HIP_CALL_THROW(hipMemsetAsync(params_.group_sum_buffer,
+                                  0,
+                                  GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups),
+                                  params_.StreamHandle()));
     ORT_THROW_IF_ERROR(op_(&params_));
   }

@@ -121,13 +133,15 @@
 template <typename T, bool WithSwish>
 class CKGroupNormNHWC : public IKernelExplorer {
  public:
-  CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
-                  DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
-                  int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, bool broadcast_skip, int channels_per_block)
+  CKGroupNormNHWC(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                  DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                  float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                  bool use_silu, bool broadcast_skip, int channels_per_block)
       : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
                 static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
                 static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
-                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     for (auto&& [type_string, op] : contrib::rocm::GetCKGroupNormNHWCTypeStringAndOps<T, WithSwish>()) {
       type_strings_.emplace_back(std::move(type_string));
       ops_.emplace_back(std::move(op));
@@ -135,8 +149,10 @@ class CKGroupNormNHWC : public IKernelExplorer {
   }

   void Run() override {
-    HIP_CALL_THROW(hipMemsetAsync(
-        params_.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups), params_.StreamHandle()));
+    HIP_CALL_THROW(hipMemsetAsync(params_.group_sum_buffer,
+                                  0,
+                                  GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups),
+                                  params_.StreamHandle()));
     ORT_THROW_IF_ERROR(ops_[selected_op_](&params_));
   }

@@ -148,8 +164,10 @@ class CKGroupNormNHWC : public IKernelExplorer {
     for (size_t i = 0; i < ops_.size(); i++) {
       if (type_strings_[i] == name) {
         selected_op_ = i;
-        HIP_CALL_THROW(hipMemsetAsync(
-            params_.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups), params_.StreamHandle()));
+        HIP_CALL_THROW(hipMemsetAsync(params_.group_sum_buffer,
+                                      0,
+                                      GetGroupNormWorkspaceSizeInBytes(params_.n, params_.groups),
+                                      params_.StreamHandle()));
         Status status = ops_[i](&params_);
         return status.IsOK();
       }
@@ -172,13 +190,15 @@
 template <typename T, bool WithSwish>
 class GroupNormNHWCTriton : public IKernelExplorer {
  public:
-  GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip, DeviceArray& bias,
-                      DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace, float epsilon,
-                      int batch_size, int num_channels, int height, int width, int num_groups, bool use_silu, bool broadcast_skip, int channels_per_block)
+  GroupNormNHWCTriton(DeviceArray& output, DeviceArray& add_output, DeviceArray& input, DeviceArray& skip,
+                      DeviceArray& bias, DeviceArray& gamma, DeviceArray& beta, DeviceArray& workspace,
+                      float epsilon, int batch_size, int num_channels, int height, int width, int num_groups,
+                      bool use_silu, bool broadcast_skip, int channels_per_block)
       : params_(TuningContext(), Stream(), static_cast<T*>(output.ptr()), static_cast<T*>(add_output.ptr()),
                 static_cast<T*>(input.ptr()), static_cast<T*>(skip.ptr()), static_cast<T*>(bias.ptr()),
                 static_cast<float*>(gamma.ptr()), static_cast<float*>(beta.ptr()), static_cast<float*>(workspace.ptr()),
-                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block) {
+                epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip,
+                channels_per_block) {
     for (auto&& [name, op] : contrib::rocm::GetTritonGroupNormNHWCTypeStringAndOps<T, WithSwish>()) {
       name_strings_.emplace_back(name);
       ops_.emplace_back(std::move(op));
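For reference, a minimal sketch of how the LaunchGroupNormKernel entry point reformatted above is driven. This is not part of the diff: the shapes are illustrative, tuning_ctx and ort_stream stand in for objects the ROCm EP normally supplies, and passing 0 for channels_per_block is an assumed default rather than documented behavior.

// Hypothetical caller for LaunchGroupNormKernel<half>; the signature is taken
// verbatim from the explicit instantiation in group_norm_impl.cu above, and the
// surrounding ORT ROCm headers (Status, RocmTuningContext, Stream, half) are assumed.
Status RunGroupNormNHWC(RocmTuningContext* tuning_ctx, Stream* ort_stream,
                        half* output, const half* input,
                        const float* gamma, const float* beta, void* workspace) {
  // Illustrative NHWC shape for a diffusion-model block.
  const int batch_size = 2, num_channels = 320, height = 64, width = 64, num_groups = 32;
  // workspace must hold at least GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups)
  // bytes; it backs the group_sum_buffer that the ops zero with hipMemsetAsync.
  // skip and bias are null and use_silu is false, so the CK and Triton ops that
  // reject skip inputs stay eligible when the tunable op profiles candidates.
  return LaunchGroupNormKernel<half>(tuning_ctx, ort_stream, output,
                                     /*add_out=*/nullptr, input,
                                     /*skip=*/nullptr, /*bias=*/nullptr,
                                     gamma, beta, workspace, /*epsilon=*/1e-5f,
                                     batch_size, num_channels, height, width,
                                     num_groups, /*use_silu=*/false,
                                     /*broadcast_skip=*/false,
                                     /*channels_per_block=*/0);
}

In the EP itself these arguments come from the GroupNorm contrib op's inputs and attributes; the sketch only fixes them to concrete values so the call shape of the reformatted declaration is visible in one place.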