Add SkipGroupNorm and BiasGroupNorm
tianleiwu committed Oct 27, 2023
1 parent aca36a4 commit ffce23f
Showing 7 changed files with 286 additions and 151 deletions.
9 changes: 7 additions & 2 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -93,10 +93,15 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
Tensor* output = context->Output(0, input->Shape());

if (!channels_last_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"only the channels_last layout is supported");
}

if (!gamma->IsDataType<float>() || !beta->IsDataType<float>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GroupNorm only supports gamma and beta in float type");
}

const auto& input_dims = input->Shape().GetDims();
if (input_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -174,7 +179,7 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 2 dimension, got ", bias_dims.size());
}
if (bias_dims[0] != num_channels) {
if (bias_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"First dimension (batch size) in bias and input does not match");
}
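
For context, the bias validated above can take two layouts, matching the kernel comments later in this commit: a 1-D (C) bias used by SkipGroupNorm (added together with skip), or a 2-D (N, C) bias used by BiasGroupNorm (added to each instance before normalization). A minimal stand-alone sketch of that contract; CheckBiasShape and its error messages are illustrative, not onnxruntime APIs:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical helper mirroring the shape checks implied by the diff above.
// bias: (C) for SkipGroupNorm, or (N, C) for BiasGroupNorm.
void CheckBiasShape(const std::vector<int64_t>& bias_dims,
                    int64_t batch_size, int64_t num_channels) {
  if (bias_dims.size() == 1) {
    if (bias_dims[0] != num_channels)
      throw std::invalid_argument("1-D bias must have num_channels elements");
  } else if (bias_dims.size() == 2) {
    if (bias_dims[0] != batch_size)
      throw std::invalid_argument("first dimension of bias must match batch size");
    if (bias_dims[1] != num_channels)
      throw std::invalid_argument("second dimension of bias must match num_channels");
  } else {
    throw std::invalid_argument("bias is expected to have 1 or 2 dimensions");
  }
}
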
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -23,8 +23,8 @@ class GroupNorm final : public CudaKernel {
GroupNorm(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;

protected:
bool use_swish_activation_;
private:
bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization?
float epsilon_;
int num_groups_;
bool channels_last_;
145 changes: 84 additions & 61 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -16,7 +16,7 @@
*/

// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
// Modifications: support more cPerBlock
// Modifications: heuristic cPerBlock; support epsilon; support skip and bias etc.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

@@ -28,12 +28,14 @@
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"

using namespace onnxruntime::cuda;

Check warning (GitHub Actions / cpplint) on line 31 in onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace contrib {
namespace cuda {

namespace {
constexpr static int32_t CHANNELS_PER_THREAD = 2; // 2 channels per thread
constexpr static int32_t CHANNELS_PER_THREAD = 2;

constexpr static int kSizes[] = {64, 128, 256, 320, 384, 512};
constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
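
These constants drive the launch configuration: each thread processes CHANNELS_PER_THREAD = 2 channels (one __half2 or float2 load), and cPerBlock is rounded up to one of kSizes to pick the thread count. A small sketch, under the assumption that nextSize (used further down in this file) rounds up to the next kSizes entry:

#include <cstdio>

constexpr int kSizes[] = {64, 128, 256, 320, 384, 512};
constexpr int kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]);
constexpr int kChannelsPerThread = 2;

// Assumed behavior of nextSize: round cPerBlock up to the next supported block size.
int NextSize(int cPerBlock) {
  for (int i = 0; i < kNumOfSizes; ++i) {
    if (kSizes[i] >= cPerBlock) return kSizes[i];
  }
  return kSizes[kNumOfSizes - 1];
}

int main() {
  const int cPerBlock = 448;
  // One thread covers two channels, so the block needs NextSize(448) / 2 = 512 / 2 = 256 threads.
  printf("threadsPerBlock = %d\n", NextSize(cPerBlock) / kChannelsPerThread);
  return 0;
}
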
@@ -83,55 +85,69 @@ template <typename T>
struct GroupNormNHWCParams {
// The output buffer. Shape is (n, h, w, c)
T* dst;
// Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c) for SkipGroupNorm

// Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c) for SkipGroupNorm.
T* add_out;

// The input buffer. Shape is (n, h, w, c)
T const* src;
// Optional input buffer for skip. Shape is (n, h, w, c) for SkipGroupNorm

// Optional input buffer for skip tensor. Shape is (n, h, w, c) for SkipGroupNorm.
T const* skip;
// Optional input buffer for bias. Shape is (c) for SkipGroupNorm or (n, 1, 1, c) for BiasGroupNorm

// Optional input buffer for bias tensor. Shape is (c) for SkipGroupNorm or (n, 1, 1, c) for BiasGroupNorm.
T const* bias;

// The gamma scaling factor.
float const* gamma;

// The beta term to add in GN.
float const* beta;
// The temporary buffer to do the global parallel reduction. Size is n x g x 2, where g is number of groups.

// The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups.
float* redBuffer;

// The number of instances in the batch.
int32_t n;

// The height and width of each activation map.
int32_t h;
int32_t w;
// The number of channels.

// Number of channels.
int32_t c;
// The number of groups.

// Number of groups.
int32_t groups;
// Do we apply the Swish activation function?
bool withSwish;

// Do we apply the SiLU activation function?
bool withSilu;

// Precomputed values and parameters to control the execution of the kernels.

// The number of activations per instance (h * w) and the number of
// activations per block.
// Number of activations per instance (h * w)
int32_t hw;

// Number of activations per block
int32_t hwPerBlock;
// The number of channels per group and blocks per activation in the C
// dimension.

// Number of channels per block in the C dimension.
int32_t cPerBlock;

// Number of channels per group in the C dimension.
int32_t cPerGroup;

// The precomputed stride between instances.
int32_t hwc;
// The inverse of hwc in floats (to compute mean/var).
// The inverse of hw*cPerGroup to compute mean of a group.
float invHWC;
// The precomputed number of groups per block.
int32_t groupsPerBlock;

// Number of threads per block
int32_t threadsPerBlock;

// Epsilon to get stable variance in normalization.
float epsilon;
};
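
Taken together, these fields describe one group-normalization pass per (instance, group) pair. A hedged summary of the math the two kernels implement, with invHWC = 1 / (hw * cPerGroup) and the SiLU step applied only when withSilu is set:

\mu_{n,g} = \mathrm{invHWC} \sum_{i \in g} x_{n,i}, \qquad
\sigma_{n,g}^{2} = \mathrm{invHWC} \sum_{i \in g} x_{n,i}^{2} - \mu_{n,g}^{2}

y_{n,i} = \gamma_{c} \, \frac{x_{n,i} - \mu_{n,g}}{\sqrt{\sigma_{n,g}^{2} + \epsilon}} + \beta_{c}, \qquad
\mathrm{SiLU}(y) = y \cdot \mathrm{sigmoid}(y) = \frac{y}{1 + e^{-y}}

Here x is src for GroupNorm, src + skip + bias for SkipGroupNorm (also written to add_out), and src + bias for BiasGroupNorm; the sums run over the hw * cPerGroup activations of group g in instance n.
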

@@ -203,27 +219,31 @@ inline __device__ void AddSkipBias(const float* src, const float* skip, const fl

// Sum for BiasGroupNorm
template <typename T>
inline __device__ void AddBias(const T* src, const T* bias,
inline __device__ void AddBias(const T* src, const T* bias, T* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq);

template <>
inline __device__ void AddBias(const half* src, const half* bias,
inline __device__ void AddBias(const half* src, const half* bias, half* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq) {
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
__half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]);
h2 += b;

*reinterpret_cast<__half2*>(&add_out[offset]) = h2;

float2 f2 = __half22float2(h2);
sum += f2.x + f2.y;
sumSq += f2.x * f2.x + f2.y * f2.y;
}

template <>
inline __device__ void AddBias(const float* src, const float* bias,
inline __device__ void AddBias(const float* src, const float* bias, float* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq) {
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
float2 b = *reinterpret_cast<float2 const*>(&bias[bias_offset]);
f2.x += b.x;
f2.y += b.y;
*reinterpret_cast<float2*>(&add_out[offset]) = f2;
sum += f2.x + f2.y;
sumSq += f2.x * f2.x + f2.y * f2.y;
}
@@ -263,6 +283,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
// (1) SkipGroupNorm: skip is (n, h, w, c) and bias is (c), add_out is (n, h, w, c)
// The additional output add_out = src + skip + bias.
// (2) BiasGroupNorm: bias is (n, c), add_out and skip are empty
// We will use dst as temp storage to store src + bias.
// (3) GroupNorm: skip, bias and add_out do not exist

int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
@@ -274,7 +295,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
} else if (params.bias != nullptr) { // BiasGroupNorm
const int64_t bias_offset = static_cast<int64_t>(ni) * params.c + ci;
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
AddBias(params.src, params.bias, offset, bias_offset, sum, sumSq);
AddBias(params.src, params.bias, params.dst, offset, bias_offset, sum, sumSq);
}
} else { // GroupNorm
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
@@ -307,8 +328,9 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
if (is_last_of_a_group) {
int32_t gj = ci / params.cPerGroup; // absolute group index
float2 sums = smem[gi];
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
const int index = (2 * ni) * params.groups + gj;
atomicAdd(&params.redBuffer[index], sums.x);
atomicAdd(&params.redBuffer[index + params.groups], sums.y);
}
}
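
The atomicAdd indexing above fixes the redBuffer layout as (n, 2, groups): the per-group sums sit in the first half of each instance's slice and the sums of squares in the second half. A hedged host-side sketch of how the second kernel can recover the mean and inverse standard deviation from that buffer, assuming the usual E[x^2] - E[x]^2 form (consistent with the invHWC comment in the params struct):

#include <cmath>
#include <vector>

// redBuffer layout: (n, 2, groups); sum at [(2*ni)*groups + gi], sumSq at that index + groups.
float GroupMean(const std::vector<float>& redBuffer, int ni, int gi, int groups, float invHWC) {
  const int index = (2 * ni) * groups + gi;
  return redBuffer[index] * invHWC;
}

float GroupInvStdDev(const std::vector<float>& redBuffer, int ni, int gi, int groups,
                     float invHWC, float epsilon) {
  const int index = (2 * ni) * groups + gi;
  const float mean = redBuffer[index] * invHWC;
  const float var = redBuffer[index + groups] * invHWC - mean * mean;
  return 1.0f / std::sqrt(var + epsilon);
}
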

@@ -349,11 +371,12 @@ void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
}

template <typename T>
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish);
__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev,
float2& gammaF2, float2& betaF2, bool silu);

template <>
__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev,
float2& gammaF2, float2& betaF2, bool swish) {
float2& gammaF2, float2& betaF2, bool silu) {
// Fetch two channels per thread.
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);

@@ -368,8 +391,8 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;

// Apply SiLU (also known as Swish) if needed.
if (swish) {
// Apply SiLU activation if needed.
if (silu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
Expand All @@ -379,7 +402,7 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo

template <>
__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev,
float2& gammaF2, float2& betaF2, bool swish) {
float2& gammaF2, float2& betaF2, bool silu) {
// Fetch two channels per thread.
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);

@@ -391,8 +414,8 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f
f2.x = gammaF2.x * f2.x + betaF2.x;
f2.y = gammaF2.y * f2.y + betaF2.y;

// Apply SiLU (also known as Swish) if needed.
if (swish) {
// Apply SiLU activation if needed.
if (silu) {
f2.x = f2.x * sigmoid(f2.x);
f2.y = f2.y * sigmoid(f2.y);
}
Expand All @@ -411,17 +434,18 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// The instance in the batch.
int32_t ni = blockIdx.z;

// The group that thread works on and the channel in the group (modulus).
// The group that thread works on.
int32_t gi = ci / params.cPerGroup;

// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
const int index = (2 * ni) * params.groups + gi;
sum = params.redBuffer[index];
sumSq = params.redBuffer[index + params.groups];
}

// Load gamma/beta.
// Load gamma/beta. Fetch two per thread.
float2 gammaF2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
float2 betaF2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);

@@ -432,18 +456,15 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.epsilon);

// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

// Iterate over the activations to compute the sums.
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
// The src/dst offset.
int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;

// Fetch two channels per thread.
computeGroupNorm<T>(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
// For SkipGroupNorm, the source is sum of src + skip + bias, which was stored in add_out.
// For BiasGroupNorm, the source is src + bias, which was stored in dst as intermediate data.
const T* source = (params.skip != nullptr) ? params.add_out : (params.bias != nullptr ? params.dst : params.src);
int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
computeGroupNorm<T>(source, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSilu);
}
}

Expand All @@ -458,6 +479,7 @@ void groupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t strea
// The number of instances.
grid.z = params.n;

// threadsPerBlock is always half of one of the kSizes values since CHANNELS_PER_THREAD = 2.
switch (params.threadsPerBlock) {
case 256:
groupNormNHWCScaleKernel<T><<<grid, 256, 0, stream>>>(params);
@@ -515,7 +537,7 @@ Status LaunchGroupNormKernel(
int height,
int width,
int num_groups,
bool use_swish_activation) {
bool use_silu) {
GroupNormNHWCParams<T> params;

int32_t cPerGroup = num_channels / num_groups;
@@ -550,14 +572,15 @@
break;
default:
cPerBlock = 320;
if (num_channels % cPerBlock != 0 || cPerBlock % cPerGroup != 0) {
// Find a maximum cPerBlock that num_channels could be divisible by it.
// Try to be close to 512 since multiple kSizes values within [256, 512] range could act as fallback.
cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
}
}

params.withSwish = use_swish_activation;
if (num_channels % cPerBlock != 0 || cPerBlock % cPerGroup != 0) {
// Find the largest cPerBlock (a multiple of cPerGroup, at most kMaxSize) that num_channels is divisible by.
// Aim close to 512, since several kSizes values in the [256, 512) range can then serve as fallbacks.
cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
}
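// Worked example of this fallback (hedged): take num_channels = 1792 and num_groups = 32, so
// cPerGroup = 56, and assume 1792 falls through to the default case above (cPerBlock = 320).
// 320 % 56 != 0, so this branch runs. Assuming findMaxDivisor(a, b) returns the largest divisor
// of a that does not exceed b: kMaxSize / cPerGroup = 512 / 56 = 9, findMaxDivisor(32, 9) = 8,
// so cPerBlock = 8 * 56 = 448. 448 divides 1792, is a multiple of 56, and nextSize(448) = 512,
// giving threadsPerBlock = 256.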

params.withSilu = use_silu;
params.dst = output;
params.add_out = add_out;
params.src = input;
@@ -590,25 +613,25 @@ Status LaunchGroupNormKernel(
params.cPerBlock % params.cPerGroup != 0 ||
cPerBlock > 512 ||
(params.cPerGroup % CHANNELS_PER_THREAD != 0)) {
printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d\n",
params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
params.cPerBlock, params.cPerGroup);
ORT_NOT_IMPLEMENTED("Not implemented");
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"GroupNorm in CUDA does not support the input: n=", params.n,
" h=", params.h,
" w=", params.w,
" c=", params.c,
" groups=", params.groups);
}

params.threadsPerBlock = nextSize(cPerBlock) / CHANNELS_PER_THREAD;

#ifdef DUMP_GROUP_NORM
printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d threadsPerBlock=%d\n",
params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
params.cPerBlock, params.cPerGroup, params.threadsPerBlock);
#endif

CUDA_RETURN_IF_ERROR(cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));

groupNormNHWCSum<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());

DUMP_TENSOR_INIT();
DUMP_TENSOR("workspace", params.redBuffer, batch_size, 2, num_groups);

groupNormNHWCScale<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());

@@ -619,13 +642,13 @@ template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, h
const half* input, const half* skip, const half* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
int height, int width, int num_groups, bool silu);

template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
const float* input, const float* skip, const float* bias,
const float* gamma, const float* beta, void* workspace,
float epsilon, int batch_size, int num_channels,
int height, int width, int num_groups, bool swish);
int height, int width, int num_groups, bool silu);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
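
For completeness, a hedged usage sketch of the fp16 plain-GroupNorm path, with the parameter order read off the template instantiation above. The wrapper name is illustrative, the launcher and the return type onnxruntime::common::Status are assumed to be available through group_norm_impl.h (as the include at the top of this file suggests), and workspace must point to at least GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups) bytes of device memory:

#include <cuda_fp16.h>

#include "contrib_ops/cuda/diffusion/group_norm_impl.h"

// Hedged sketch: plain GroupNorm (no skip, no bias, no add_out) on NHWC fp16 buffers
// that the caller has already placed on the device.
onnxruntime::common::Status RunGroupNormFp16(
    cudaStream_t stream, half* output, const half* input,
    const float* gamma, const float* beta, void* workspace, float epsilon,
    int batch_size, int num_channels, int height, int width, int num_groups) {
  return onnxruntime::contrib::cuda::LaunchGroupNormKernel<half>(
      stream, output, /*add_out=*/nullptr,
      input, /*skip=*/nullptr, /*bias=*/nullptr,
      gamma, beta, workspace, epsilon,
      batch_size, num_channels, height, width, num_groups,
      /*use_silu=*/false);
}
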
