[CUDA] Refactor GroupNorm and add common vectorize implementation (#19158)

Co-authored-by: Peixuan Zuo <[email protected]@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
PeixuanZuo and Peixuan Zuo authored Jan 29, 2024
1 parent 6d7ac9c commit 82c1cb4
Showing 5 changed files with 217 additions and 102 deletions.
9 changes: 6 additions & 3 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -24,7 +24,8 @@ namespace {
 
 template <typename T>
 struct DispatchGroupNorm {
-  Status operator()(cudaStream_t stream,
+  Status operator()(CudaTuningContext* tuning_ctx,
+                    Stream* ort_stream,
                     Tensor* output,
                     Tensor* add_out,
                     const Tensor* input,
@@ -44,7 +45,8 @@ struct DispatchGroupNorm {
                     int channels_per_block) {
     typedef typename ToCudaType<T>::MappedType CudaT;
     return LaunchGroupNormKernel<CudaT>(
-        stream,
+        tuning_ctx,
+        ort_stream,
         reinterpret_cast<CudaT*>(output->MutableData<T>()),
         add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
         reinterpret_cast<const CudaT*>(input->Data<T>()),
@@ -209,7 +211,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
                                                     context->GetComputeStream());
 
   utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
-  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
+  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(),
+                                                         context->GetComputeStream(), output, add_out, input, skip, bias,
                                                          gamma, beta, workspace.get(),
                                                          epsilon_,
                                                          batch_size,
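
The hunk above threads two new arguments (a tuning context and a provider-neutral stream) through the type-dispatch functor: every launch argument flows through DispatchGroupNorm::operator(), so the interface change touches one signature plus its call sites. A minimal sketch of that functor-dispatch pattern, using hypothetical stand-ins (DispatchExample, InvokeForType) rather than ORT's real MLTypeCallDispatcher:

#include <cstdio>
#include <utility>
#include <cuda_fp16.h>

// Hypothetical functor mirroring DispatchGroupNorm's shape: all launch
// arguments are forwarded through operator(), so adding tuning_ctx and
// ort_stream only changes this signature and its call sites.
template <typename T>
struct DispatchExample {
  int operator()(void* tuning_ctx, void* ort_stream, const void* data, int n) {
    (void)tuning_ctx;  // ignored on the CUDA path, as in the diff above
    (void)ort_stream;
    (void)data;
    std::printf("dispatch: %d elements, %zu bytes each\n", n, sizeof(T));
    return 0;
  }
};

// Hypothetical stand-in for MLTypeCallDispatcher::InvokeRet: selects the
// instantiation from a runtime element-type tag (ONNX-style: FLOAT=1, FLOAT16=10).
template <template <typename> class Fn, typename... Args>
int InvokeForType(int elem_type, Args&&... args) {
  switch (elem_type) {
    case 1:  return Fn<float>()(std::forward<Args>(args)...);
    case 10: return Fn<half>()(std::forward<Args>(args)...);
    default: return -1;
  }
}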
(second changed file; file name not captured in this view)

@@ -126,7 +126,7 @@ struct GroupNormNHWCParams
       const T* bias,
       const float* gamma,
       const float* beta,
-      void* workspace,
+      float* workspace,
       float epsilon,
       int batch_size,
       int num_channels,
@@ -151,7 +151,7 @@ struct GroupNormNHWCParams
     this->bias = bias;
     this->gamma = gamma;
     this->beta = beta;
-    this->group_sum_buffer = reinterpret_cast<float*>(workspace);
+    this->group_sum_buffer = workspace;
     this->n = batch_size;
     this->h = height;
     this->w = width;
61 changes: 37 additions & 24 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -49,23 +49,26 @@ void GroupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   // The number of instances.
   grid.z = params.n;
 
+#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize)                                          \
+  GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>                                           \
+      <<<grid, ThreadsPerBlock, 0, stream>>>(                                                   \
+          params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \
+          params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c,      \
+          params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \
+  break;
+
   // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
   switch (params.threads_per_block) {
     case 256:
-      GroupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
     case 192:
-      GroupNormNHWCSumKernel<T, 192><<<grid, 192, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
     case 160:
-      GroupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
     case 128:
-      GroupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
     case 64:
-      GroupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
   }
 }
 
@@ -80,29 +83,34 @@ void GroupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   // The number of instances.
   grid.z = params.n;
 
+#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                                                       \
+  GroupNormNHWCScaleKernel<T, VecSize>                                                                         \
+      <<<grid, ThreadsPerBlock, 0, stream>>>(                                                                  \
+          params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace,               \
+          params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \
+          params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block,         \
+          params.use_silu);                                                                                    \
+  break;
+
   // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
   switch (params.threads_per_block) {
     case 256:
-      GroupNormNHWCScaleKernel<T><<<grid, 256, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
     case 192:
-      GroupNormNHWCScaleKernel<T><<<grid, 192, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
     case 160:
-      GroupNormNHWCScaleKernel<T><<<grid, 160, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
     case 128:
-      GroupNormNHWCScaleKernel<T><<<grid, 128, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
     case 64:
-      GroupNormNHWCScaleKernel<T><<<grid, 64, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
   }
 }
 
 template <typename T>
 Status LaunchGroupNormKernel(
-    cudaStream_t stream,
+    CudaTuningContext* tuning_ctx,
+    Stream* ort_stream,
     T* output,
     T* add_out,
     const T* input,
@@ -120,7 +128,11 @@ Status LaunchGroupNormKernel(
     bool use_silu,
     bool broadcast_skip,
     int channels_per_block) {
-  GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon,
+
+  // tuning_ctx only used for ROCm EP.
+  ORT_UNUSED_PARAMETER(tuning_ctx);
+
+  GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast<float*>(workspace), epsilon,
                                 batch_size, num_channels, height, width, num_groups, use_silu,
                                 broadcast_skip, channels_per_block);
 
@@ -135,6 +147,7 @@
                            " groups=", num_groups);
   }
 
+  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
   CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
       params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));
 
@@ -150,14 +163,14 @@
   return Status::OK();
 }
 
-template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, half* add_out,
+template Status LaunchGroupNormKernel<half>(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out,
                                             const half* input, const half* skip, const half* bias,
                                             const float* gamma, const float* beta, void* workspace,
                                             float epsilon, int batch_size, int num_channels,
                                             int height, int width, int num_groups, bool silu,
                                             bool broadcast_skip, int channels_per_block);
 
-template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
+template Status LaunchGroupNormKernel<float>(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out,
                                              const float* input, const float* skip, const float* bias,
                                              const float* gamma, const float* beta, void* workspace,
                                              float epsilon, int batch_size, int num_channels,
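
The two switch blocks above are the common vectorized launch path this commit introduces: the kernels are now templated on ThreadsPerBlock and a vector width, a launch macro expands each case to a single line, and the trailing break; inside the macro exits the switch. With CHANNELS_PER_THREAD = 2, each thread covers two channels, which is why threads_per_block is half of the channels-per-block sizes. A self-contained sketch of the same pattern, with a hypothetical ScaleKernel standing in for the real GroupNorm kernels:

#include <cuda_runtime.h>

// Each thread processes VecSize consecutive elements, so one block of
// ThreadsPerBlock threads covers ThreadsPerBlock * VecSize values.
template <typename T, int ThreadsPerBlock, int VecSize>
__global__ void ScaleKernel(T* dst, const T* src, T factor, int n) {
  int base = (blockIdx.x * ThreadsPerBlock + threadIdx.x) * VecSize;
#pragma unroll
  for (int i = 0; i < VecSize; ++i) {
    if (base + i < n) dst[base + i] = src[base + i] * factor;
  }
}

// The macro ends in `break;` so each switch case below stays a one-liner,
// mirroring LAUNCH_GROUPNORM_SUM / LAUNCH_GROUPNORM_SCALE.
#define LAUNCH_SCALE(ThreadsPerBlock, VecSize)                      \
  ScaleKernel<T, ThreadsPerBlock, VecSize>                          \
      <<<grid, ThreadsPerBlock, 0, stream>>>(dst, src, factor, n);  \
  break;

template <typename T>
void DispatchScale(int threads_per_block, dim3 grid, cudaStream_t stream,
                   T* dst, const T* src, T factor, int n) {
  constexpr int kVecSize = 2;  // stand-in for CHANNELS_PER_THREAD
  switch (threads_per_block) {
    case 256: LAUNCH_SCALE(256, kVecSize)
    case 128: LAUNCH_SCALE(128, kVecSize)
    case 64:  LAUNCH_SCALE(64, kVecSize)
  }
}

Only the block size is a runtime decision here; the vector width stays a compile-time constant, so the per-element loop can be fully unrolled.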
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
@@ -8,6 +8,8 @@
 #include <cuda.h>
 #include <cuda_fp16.h>
 
+#include "core/providers/cuda/tunable/cuda_tunable.h"
+
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
@@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups);
 
 template <typename T>
 Status LaunchGroupNormKernel(
-    cudaStream_t stream,
+    CudaTuningContext* tuning_ctx,
+    Stream* ort_stream,
     T* output,       // normalized output tensor. Shape is (n, h, w, c)
     T* add_out,      // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c)
     const T* input,  // input tensor. Shape is (n, h, w, c)
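
Relative to the old cudaStream_t parameter, the entry point now accepts an opaque Stream* plus a CudaTuningContext* and unwraps the native handle itself (the static_cast<cudaStream_t>(ort_stream->GetHandle()) line in the .cu hunk). A hedged sketch of that boundary, with hypothetical wrapper types (FakeStream, FakeTuningContext) in place of ORT's own:

#include <cuda_runtime.h>

// Hypothetical stand-ins: an opaque stream wrapper exposing a native handle,
// and a tuning context that the CUDA path currently ignores.
struct FakeStream {
  void* handle;
  void* GetHandle() const { return handle; }
};
struct FakeTuningContext {};

template <typename T>
int LaunchExample(FakeTuningContext* tuning_ctx, FakeStream* ort_stream,
                  T* dst, const T* src, int n) {
  (void)tuning_ctx;  // reserved for tunable ops (the ROCm EP uses it today)
  // Unwrap the provider-neutral stream into the native CUDA stream,
  // mirroring the cast performed inside LaunchGroupNormKernel.
  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
  cudaError_t err = cudaMemcpyAsync(dst, src, n * sizeof(T),
                                    cudaMemcpyDeviceToDevice, stream);
  return err == cudaSuccess ? 0 : 1;
}

Keeping the stream opaque at the API boundary and pairing it with a tuning context is what lets the CUDA and ROCm execution providers share one signature; the CUDA implementation simply marks the context unused for now.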