Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUDA] Refactor GroupNorm and add common vectorize implementation #19158

Merged
merged 5 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@

template <typename T>
struct DispatchGroupNorm {
Status operator()(cudaStream_t stream,
Status operator()(CudaTuningContext* tuning_ctx,
Stream* ort_stream,
Tensor* output,
Tensor* add_out,
const Tensor* input,
Expand All @@ -44,7 +45,8 @@
int channels_per_block) {
typedef typename ToCudaType<T>::MappedType CudaT;
return LaunchGroupNormKernel<CudaT>(
stream,
tuning_ctx,
ort_stream,
reinterpret_cast<CudaT*>(output->MutableData<T>()),
add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
reinterpret_cast<const CudaT*>(input->Data<T>()),
Expand Down Expand Up @@ -209,7 +211,8 @@
context->GetComputeStream());

utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(),
context->GetComputeStream(), output, add_out, input, skip, bias,

Check warning on line 215 in onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc#L215

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc:215:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
gamma, beta, workspace.get(),
epsilon_,
batch_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ struct GroupNormNHWCParams {
const T* bias,
const float* gamma,
const float* beta,
void* workspace,
float* workspace,
float epsilon,
int batch_size,
int num_channels,
Expand All @@ -151,7 +151,7 @@ struct GroupNormNHWCParams {
this->bias = bias;
this->gamma = gamma;
this->beta = beta;
this->group_sum_buffer = reinterpret_cast<float*>(workspace);
this->group_sum_buffer = workspace;
this->n = batch_size;
this->h = height;
this->w = width;
Expand Down
61 changes: 37 additions & 24 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,26 @@
// The number of instances.
grid.z = params.n;

#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \
GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize> \
<<<grid, ThreadsPerBlock, 0, stream>>>( \
params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \
params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \
params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \
break;

// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params.threads_per_block) {
case 256:
GroupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
case 192:
GroupNormNHWCSumKernel<T, 192><<<grid, 192, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
case 160:
GroupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
case 128:
GroupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
case 64:
GroupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
}
}

Expand All @@ -80,29 +83,34 @@
// The number of instances.
grid.z = params.n;

#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \
GroupNormNHWCScaleKernel<T, VecSize> \
<<<grid, ThreadsPerBlock, 0, stream>>>( \
params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace, \
params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \
params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block, \
params.use_silu); \
break;

// Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
switch (params.threads_per_block) {
case 256:
GroupNormNHWCScaleKernel<T><<<grid, 256, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
case 192:
GroupNormNHWCScaleKernel<T><<<grid, 192, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
case 160:
GroupNormNHWCScaleKernel<T><<<grid, 160, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
case 128:
GroupNormNHWCScaleKernel<T><<<grid, 128, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
case 64:
GroupNormNHWCScaleKernel<T><<<grid, 64, 0, stream>>>(params);
break;
LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
}
}

template <typename T>
Status LaunchGroupNormKernel(
cudaStream_t stream,
CudaTuningContext* tuning_ctx,
Stream* ort_stream,
T* output,
T* add_out,
const T* input,
Expand All @@ -120,7 +128,11 @@
bool use_silu,
bool broadcast_skip,
int channels_per_block) {
GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon,

// tuning_ctx only used for ROCm EP.
ORT_UNUSED_PARAMETER(tuning_ctx);

GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast<float*>(workspace), epsilon,

Check warning on line 135 in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu#L135

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu:135:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
batch_size, num_channels, height, width, num_groups, use_silu,
broadcast_skip, channels_per_block);

Expand All @@ -135,6 +147,7 @@
" groups=", num_groups);
}

auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));

Expand All @@ -150,14 +163,14 @@
return Status::OK();
}

template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, half* add_out,
template Status LaunchGroupNormKernel<half>(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out,
const half* input, const half* skip, const half* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool silu,
bool broadcast_skip, int channels_per_block);

template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
template Status LaunchGroupNormKernel<float>(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out,

Check warning on line 173 in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu#L173

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu:173:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
const float* input, const float* skip, const float* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
Expand Down
5 changes: 4 additions & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
#include <cuda.h>
#include <cuda_fp16.h>

#include "core/providers/cuda/tunable/cuda_tunable.h"

namespace onnxruntime {
namespace contrib {
namespace cuda {
Expand All @@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups);

template <typename T>
Status LaunchGroupNormKernel(
cudaStream_t stream,
CudaTuningContext* tuning_ctx,
Stream* ort_stream,
T* output, // normalized output tensor. Shape is (n, h, w, c)
T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c)
const T* input, // input tensor. Shape is (n, h, w, c)
Expand Down
Loading
Loading