From 82c1cb416b8054f67fe1f73928ad4c276d80afdb Mon Sep 17 00:00:00 2001
From: PeixuanZuo <94887879+PeixuanZuo@users.noreply.github.com>
Date: Mon, 29 Jan 2024 09:15:10 +0800
Subject: [PATCH] [CUDA] Refactor GroupNorm and add common vectorize implementation (#19158)

Co-authored-by: Peixuan Zuo
---
 .../contrib_ops/cuda/diffusion/group_norm.cc  |   9 +-
 .../cuda/diffusion/group_norm_common_base.h   |   4 +-
 .../cuda/diffusion/group_norm_impl.cu         |  61 +++--
 .../cuda/diffusion/group_norm_impl.h          |   5 +-
 .../cuda/diffusion/group_norm_impl_kernel.cuh | 240 ++++++++++++------
 5 files changed, 217 insertions(+), 102 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index 87e88ac31c998..dea5391c7629b 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -24,7 +24,8 @@ namespace {
 template <typename T>
 struct DispatchGroupNorm {
-  Status operator()(cudaStream_t stream,
+  Status operator()(CudaTuningContext* tuning_ctx,
+                    Stream* ort_stream,
                     Tensor* output,
                     Tensor* add_out,
                     const Tensor* input,
@@ -44,7 +45,8 @@ struct DispatchGroupNorm {
                     int channels_per_block) {
     typedef typename ToCudaType<T>::MappedType CudaT;
     return LaunchGroupNormKernel<CudaT>(
-        stream,
+        tuning_ctx,
+        ort_stream,
         reinterpret_cast<CudaT*>(output->MutableData<T>()),
         add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
         reinterpret_cast<const CudaT*>(input->Data<T>()),
@@ -209,7 +211,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
                                               context->GetComputeStream());

   utils::MLTypeCallDispatcher dispatcher(input->GetElementType());
-  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
+  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(GetTuningContext(),
+                                                         context->GetComputeStream(), output, add_out, input, skip, bias,
                                                          gamma, beta, workspace.get(),
                                                          epsilon_,
                                                          batch_size,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
index 84f3403b8d5ae..ea87d0c29111e 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h
@@ -126,7 +126,7 @@ struct GroupNormNHWCParams {
                       const T* bias,
                       const float* gamma,
                       const float* beta,
-                      void* workspace,
+                      float* workspace,
                       float epsilon,
                       int batch_size,
                       int num_channels,
@@ -151,7 +151,7 @@ struct GroupNormNHWCParams {
     this->bias = bias;
     this->gamma = gamma;
    this->beta = beta;
-    this->group_sum_buffer = reinterpret_cast<float*>(workspace);
+    this->group_sum_buffer = workspace;
    this->n = batch_size;
    this->h = height;
    this->w = width;
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
index d7b2cc2379f4f..4909dc5e3897b 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -49,23 +49,26 @@ void GroupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   // The number of instances.
   grid.z = params.n;

+#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize)                                               \
+  GroupNormNHWCSumKernel<T, ThreadsPerBlock, VecSize>                                                \
+      <<<grid, ThreadsPerBlock, 0, stream>>>(                                                        \
+          params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias,      \
+          params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c,           \
+          params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \
+  break;
+
   // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
   switch (params.threads_per_block) {
     case 256:
-      GroupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD)
     case 192:
-      GroupNormNHWCSumKernel<T, 192><<<grid, 192, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD)
     case 160:
-      GroupNormNHWCSumKernel<T, 160><<<grid, 160, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD)
    case 128:
-      GroupNormNHWCSumKernel<T, 128><<<grid, 128, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD)
    case 64:
-      GroupNormNHWCSumKernel<T, 64><<<grid, 64, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD)
   }
 }

@@ -80,29 +83,34 @@ void GroupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t strea
   // The number of instances.
   grid.z = params.n;

+#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize)                                                            \
+  GroupNormNHWCScaleKernel<T, ThreadsPerBlock, VecSize>                                                             \
+      <<<grid, ThreadsPerBlock, 0, stream>>>(                                                                       \
+          params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace,                    \
+          params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group,  \
+          params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block,              \
+          params.use_silu);                                                                                         \
+  break;
+
   // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
   switch (params.threads_per_block) {
     case 256:
-      GroupNormNHWCScaleKernel<T, 256><<<grid, 256, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD)
     case 192:
-      GroupNormNHWCScaleKernel<T, 192><<<grid, 192, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD)
     case 160:
-      GroupNormNHWCScaleKernel<T, 160><<<grid, 160, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD)
    case 128:
-      GroupNormNHWCScaleKernel<T, 128><<<grid, 128, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD)
    case 64:
-      GroupNormNHWCScaleKernel<T, 64><<<grid, 64, 0, stream>>>(params);
-      break;
+      LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD)
   }
 }

 template <typename T>
 Status LaunchGroupNormKernel(
-    cudaStream_t stream,
+    CudaTuningContext* tuning_ctx,
+    Stream* ort_stream,
     T* output,
     T* add_out,
     const T* input,
@@ -120,7 +128,11 @@ Status LaunchGroupNormKernel(
     bool use_silu,
     bool broadcast_skip,
     int channels_per_block) {
-  GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon,
+
+  // tuning_ctx only used for ROCm EP.
+  ORT_UNUSED_PARAMETER(tuning_ctx);
+
+  GroupNormNHWCParams<T> params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast<float*>(workspace), epsilon,
                                 batch_size, num_channels, height, width, num_groups, use_silu,
                                 broadcast_skip, channels_per_block);

@@ -135,6 +147,7 @@ Status LaunchGroupNormKernel(
                            " groups=", num_groups);
   }

+  auto stream = static_cast<cudaStream_t>(ort_stream->GetHandle());
   CUDA_RETURN_IF_ERROR(cudaMemsetAsync(
       params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));

@@ -150,14 +163,14 @@ Status LaunchGroupNormKernel(
   return Status::OK();
 }

-template Status LaunchGroupNormKernel<half>(cudaStream_t stream, half* output, half* add_out,
+template Status LaunchGroupNormKernel<half>(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out,
                                             const half* input, const half* skip, const half* bias,
                                             const float* gamma, const float* beta, void* workspace,
                                             float epsilon, int batch_size, int num_channels,
                                             int height, int width, int num_groups, bool silu,
                                             bool broadcast_skip, int channels_per_block);

-template Status LaunchGroupNormKernel<float>(cudaStream_t stream, float* output, float* add_out,
+template Status LaunchGroupNormKernel<float>(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out,
                                              const float* input, const float* skip, const float* bias,
                                              const float* gamma, const float* beta, void* workspace,
                                              float epsilon, int batch_size, int num_channels,
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
index 9532aeecb2f57..98f38a1475eee 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h
@@ -8,6 +8,8 @@
 #include
 #include

+#include "core/providers/cuda/tunable/cuda_tunable.h"
+
 namespace onnxruntime {
 namespace contrib {
 namespace cuda {
@@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups);

 template <typename T>
 Status LaunchGroupNormKernel(
-    cudaStream_t stream,
+    CudaTuningContext* tuning_ctx,
+    Stream* ort_stream,
     T* output,       // normalized output tensor. Shape is (n, h, w, c)
     T* add_out,      // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c)
     const T* input,  // input tensor. Shape is (n, h, w, c)
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh
index 081e9a3de578c..ecd06315e3708 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh
@@ -21,9 +21,9 @@
 // Licensed under the MIT License.
 #pragma once
 #include <cub/cub.cuh>
+#include
 #include "core/providers/cuda/cuda_common.h"
 #include "core/providers/cuda/cu_inc/common.cuh"
-#include "contrib_ops/cuda/diffusion/group_norm_impl.h"

 using namespace onnxruntime::cuda;

@@ -54,11 +54,21 @@ struct GroupSumsOp {
   }
 };

-template <typename T>
-inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq);
+template <typename T, int32_t ILP>
+inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq) {
+  using VecT = onnxruntime::cuda::aligned_vector<T, ILP>;
+  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
+
+#pragma unroll
+  for (int i = 0; i < ILP; i++) {
+    const float val = static_cast<float>(input_v.val[i]);
+    sum += val;
+    sum_sq += val * val;
+  }
+}

 template <>
-inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) {
+inline __device__ void UpdateSum<half, 2>(const half* src, int64_t offset, float& sum, float& sum_sq) {
   // Fetch two channels per thread.
   __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);

@@ -72,7 +82,7 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl
 }

 template <>
-inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) {
+inline __device__ void UpdateSum<float, 2>(const float* src, int64_t offset, float& sum, float& sum_sq) {
   // Fetch two channels per thread.
   float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);

@@ -84,13 +94,28 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f
 }

 // Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset]
-template <typename T>
+template <typename T, int32_t ILP>
 inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias,
-                                   int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq);
+                                   int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
+  using VecT = onnxruntime::cuda::aligned_vector<T, ILP>;
+  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
+  const VecT skip_v = *reinterpret_cast<const VecT*>(skip + skip_offset);
+  const VecT bias_v = *reinterpret_cast<const VecT*>(bias + bias_offset);
+  VecT output_v = *reinterpret_cast<VecT*>(add_out + offset);
+
+#pragma unroll
+  for (int i = 0; i < ILP; i++) {
+    output_v.val[i] = input_v.val[i] + skip_v.val[i] + bias_v.val[i];
+    const float val = static_cast<float>(output_v.val[i]);
+    sum += val;
+    sum_sq += val * val;
+  }
+  *(reinterpret_cast<VecT*>(add_out + offset)) = output_v;
+}

 template <>
-inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias,
-                                   int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
+inline __device__ void AddSkipBias<half, 2>(half* add_out, const half* src, const half* skip, const half* bias,
+                                            int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
   // Fetch two channels per thread.
   __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
   __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);

@@ -106,8 +131,8 @@ inline __device__ void AddSkipBias(half* add_out, const half* src, const half* s
 }

 template <>
-inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias,
-                                   int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
+inline __device__ void AddSkipBias<float, 2>(float* add_out, const float* src, const float* skip, const float* bias,
+                                             int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) {
   float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
   float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
   float2 b = *reinterpret_cast<float2 const*>(&bias[bias_offset]);
@@ -121,13 +146,27 @@ inline __device__ void AddSkipBias(float* add_out, const float* src, const float
 }

 // Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset]
-template <typename T>
+template <typename T, int32_t ILP>
 inline __device__ void AddSkip(T* add_out, const T* src, const T* skip,
-                               int64_t offset, int64_t skip_offset, float& sum, float& sum_sq);
+                               int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
+  using VecT = onnxruntime::cuda::aligned_vector<T, ILP>;
+  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
+  const VecT skip_v = *reinterpret_cast<const VecT*>(skip + skip_offset);
+  VecT output_v = *reinterpret_cast<VecT*>(add_out + offset);
+
+#pragma unroll
+  for (int i = 0; i < ILP; i++) {
+    output_v.val[i] = input_v.val[i] + skip_v.val[i];
+    const float val = static_cast<float>(output_v.val[i]);
+    sum += val;
+    sum_sq += val * val;
+  }
+  *(reinterpret_cast<VecT*>(add_out + offset)) = output_v;
+}

 template <>
-inline __device__ void AddSkip(half* add_out, const half* src, const half* skip,
-                               int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
+inline __device__ void AddSkip<half, 2>(half* add_out, const half* src, const half* skip,
+                                        int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
   __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
   __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]);
   h2 = h2 + s;
@@ -140,8 +179,8 @@ inline __device__ void AddSkip(half* add_out, const half* src, const half* skip,
 }

 template <>
-inline __device__ void AddSkip(float* add_out, const float* src, const float* skip,
-                               int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
+inline __device__ void AddSkip<float, 2>(float* add_out, const float* src, const float* skip,
+                                         int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) {
   float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
   float2 s = *reinterpret_cast<float2 const*>(&skip[skip_offset]);
   f2.x += s.x;
@@ -151,8 +190,10 @@ inline __device__ void AddSkip(float* add_out, const float* src, const float* sk
   sum_sq += f2.x * f2.x + f2.y * f2.y;
 }

-template <typename T, int32_t ThreadsPerBlock>
-__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
+template <typename T, int32_t ThreadsPerBlock, int32_t ILP>
+__global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffer, const T* src, const T* skip, const T* bias,
+                                       int32_t channels_per_block, int32_t hw_per_block, int32_t hw, int32_t hwc, int32_t c,
+                                       int32_t channels_per_group, int32_t groups, int32_t groups_per_block, bool broadcast_skip) {
   // The object in charge of doing the sums for the different blocks.
   typedef cub::BlockScan<GroupSums, ThreadsPerBlock> BlockScan;

@@ -166,60 +207,60 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
   int32_t ni = blockIdx.z;

   // The channel loaded by that thread.
-  int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
+  int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP;

-  if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
+  if (ci >= c || threadIdx.x * ILP >= channels_per_block) {
     return;
   }

   // The first activation loaded by that block.
-  int32_t hw_begin = blockIdx.y * params.hw_per_block;
+  int32_t hw_begin = blockIdx.y * hw_per_block;
   // The last activation loaded by that block.
-  int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
+  int32_t hw_end = min(hw_begin + hw_per_block, hw);

   // The sums.
   float sum = 0.F;
   float sum_sq = 0.F;

   // Iterate over the activations to compute the sums.
-  int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
-  if (params.skip != nullptr) {
+  int64_t offset = static_cast<int64_t>(ni) * hwc + static_cast<int64_t>(hw_begin) * c + ci;
+  if (skip != nullptr) {
     // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c)
     const int64_t bias_offset = static_cast<int64_t>(ci);
-    T* add_out = params.skip_workspace;
-    if (params.broadcast_skip) {
-      const int64_t skip_offset = static_cast<int64_t>(ni) * params.c + ci;
+    T* add_out = skip_workspace;
+    if (broadcast_skip) {
+      const int64_t skip_offset = static_cast<int64_t>(ni) * c + ci;

-      if (params.bias != nullptr) {
-        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-          AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq);
+      if (bias != nullptr) {
+        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+          AddSkipBias<T, ILP>(add_out, src, skip, bias, offset, skip_offset, bias_offset, sum, sum_sq);
         }
       } else {
-        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-          AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq);
+        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+          AddSkip<T, ILP>(add_out, src, skip, offset, skip_offset, sum, sum_sq);
         }
       }
     } else {
-      if (params.bias != nullptr) {
-        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-          AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq);
+      if (bias != nullptr) {
+        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+          AddSkipBias<T, ILP>(add_out, src, skip, bias, offset, offset, bias_offset, sum, sum_sq);
         }
       } else {
-        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-          AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq);
+        for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+          AddSkip<T, ILP>(add_out, src, skip, offset, offset, sum, sum_sq);
         }
       }
     }
   } else {  // GroupNorm
-    for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-      UpdateSum(params.src, offset, sum, sum_sq);
+    for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+      UpdateSum<T, ILP>(src, offset, sum, sum_sq);
     }
   }

   // The group index relative to the first group within the same block.
-  int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group;
+  int32_t gi = threadIdx.x * ILP / channels_per_group;
   // The channel in the group.
-  int32_t cj = ci % params.channels_per_group;
+  int32_t cj = ci % channels_per_group;

   // The data for the summations.
   GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq};
   GroupSums out;
   BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp());

   // Store the results for the groups in shared memory (to produce coalesced stores later).
   // For each group, only the last thread of that group is picked to save sum to shared memory.
-  if (cj == params.channels_per_group - CHANNELS_PER_THREAD) {
+  if (cj == channels_per_group - ILP) {
     smem[gi] = make_float2(out.sum, out.sum_sq);
   }

   // Make sure the data is in shared memory.
   __syncthreads();

   // Threads that have nothing left to do, exit.
-  if (threadIdx.x >= params.groups_per_block) {
+  if (threadIdx.x >= groups_per_block) {
     return;
   }

   // The global group index.
   // Use neighboring threads for coalesced write.
-  int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x;
+  int32_t gj = blockIdx.x * groups_per_block + threadIdx.x;

-  if (gj < params.groups) {
+  if (gj < groups) {
     float2 sums = smem[threadIdx.x];
-    const int index = (2 * ni) * params.groups + gj;
-    atomicAdd(&params.group_sum_buffer[index], sums.x);
-    atomicAdd(&params.group_sum_buffer[index + params.groups], sums.y);
+    const int index = (2 * ni) * groups + gj;
+    atomicAdd(&group_sum_buffer[index], sums.x);
+    atomicAdd(&group_sum_buffer[index + groups], sums.y);
+  }
+}
+
+template <typename T, int32_t ILP>
+__device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev,
+                                    const float* gamma_v, const float* beta_v, bool silu) {
+  using VecT = onnxruntime::cuda::aligned_vector<T, ILP>;
+  const VecT input_v = *reinterpret_cast<const VecT*>(src + offset);
+  VecT output_v;
+
+#pragma unroll
+  for (int i = 0; i < ILP; i++) {
+    float val = static_cast<float>(input_v.val[i]);
+    val = (val - mean) * inv_std_dev;
+    val = gamma_v[i] * val + beta_v[i];
+
+    if (silu) {
+      val = val * sigmoid(val);
+    }
+    output_v.val[i] = static_cast<T>(val);
   }
+  *(reinterpret_cast<VecT*>(dst + offset)) = output_v;
 }

 template <typename T>
@@ -307,11 +369,51 @@ __device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, f
   *reinterpret_cast<float2*>(&dst[offset]) = f2;
 }

-template <typename T, int32_t ThreadsPerBlock>
-__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
+template <typename T, int32_t ILP>
+__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev,
+                                       const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) {
+  using VecF = onnxruntime::cuda::aligned_vector<float, ILP>;
+
+  const VecF gamma_v = *reinterpret_cast<const VecF*>(gamma + ci);
+  const VecF beta_v = *reinterpret_cast<const VecF*>(beta + ci);
+  // Iterate over the activations to compute the sums.
+  for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+    // Fetch ILP channels per thread.
+    computeGroupNormVec<T, ILP>(input, dst, offset, mean, inv_std_dev, gamma_v.val, beta_v.val, use_silu);
+  }
+}
+
+template <>
+__device__ void ComputeGroupNormKernel<float, 2>(const float* input, float* dst, int64_t offset, float mean, float inv_std_dev,
+                                                 const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) {
+  // Load gamma/beta. Fetch two per thread.
+  float2 gamma_f2 = *reinterpret_cast<float2 const*>(&gamma[ci]);
+  float2 beta_f2 = *reinterpret_cast<float2 const*>(&beta[ci]);
+  for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+    ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu);
+  }
+}
+
+template <>
+__device__ void ComputeGroupNormKernel<half, 2>(const half* input, half* dst, int64_t offset, float mean, float inv_std_dev,
+                                                const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) {
+  // Load gamma/beta. Fetch two per thread.
+  float2 gamma_f2 = *reinterpret_cast<float2 const*>(&gamma[ci]);
+  float2 beta_f2 = *reinterpret_cast<float2 const*>(&beta[ci]);
+  for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) {
+    ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu);
+  }
+}
+
+template <typename T, int32_t ThreadsPerBlock, int32_t ILP>
+__global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, const float* gamma, const float* beta,
+                                         const T* skip_workspace, const float* group_sum_buffer, float epsilon,
+                                         int32_t c, int32_t channels_per_block, int32_t channels_per_group,
+                                         int32_t groups, int32_t hwc, float inv_hw_channels_per_group,
+                                         int32_t hw, int32_t hw_per_block, bool use_silu) {
   // The channel loaded by that thread.
-  int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD;
-  if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) {
+  int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP;
+  if (ci >= c || threadIdx.x * ILP >= channels_per_block) {
     return;
   }

@@ -319,35 +421,29 @@ __global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
   int32_t ni = blockIdx.z;

   // The group that thread works on.
-  int32_t gi = ci / params.channels_per_group;
+  int32_t gi = ci / channels_per_group;

   // Load the sum and sum of squares for the group.
   float sum = 0.F, sum_sq = 0.F;
-  if (gi < params.groups) {
-    const int index = (2 * ni) * params.groups + gi;
-    sum = params.group_sum_buffer[index];
-    sum_sq = params.group_sum_buffer[index + params.groups];
+  if (gi < groups) {
+    const int index = (2 * ni) * groups + gi;
+    sum = group_sum_buffer[index];
+    sum_sq = group_sum_buffer[index + groups];
   }

-  // Load gamma/beta. Fetch two per thread.
-  float2 gamma_f2 = *reinterpret_cast<float2 const*>(&params.gamma[ci]);
-  float2 beta_f2 = *reinterpret_cast<float2 const*>(&params.beta[ci]);
-
   // Compute the mean.
-  float mean = sum * params.inv_hw_channels_per_group;
+  float mean = sum * inv_hw_channels_per_group;
   // Compute the variance.
-  float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean);
+  float var = sum_sq * inv_hw_channels_per_group - (mean * mean);
   // Compute the inverse of the stddev.
-  float inv_std_dev = rsqrtf(var + params.epsilon);
+  float inv_std_dev = rsqrtf(var + epsilon);

-  int32_t hw_begin = blockIdx.y * params.hw_per_block;
-  int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw);
+  int32_t hw_begin = blockIdx.y * hw_per_block;
+  int32_t hw_end = min(hw_begin + hw_per_block, hw);

-  const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src;
-  int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hw_begin) * params.c + ci;
-  for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) {
-    ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu);
-  }
+  const T* input = (skip != nullptr) ? skip_workspace : src;
+  int64_t offset = static_cast<int64_t>(ni) * hwc + static_cast<int64_t>(hw_begin) * c + ci;
+  ComputeGroupNormKernel<T, ILP>(input, dst, offset, mean, inv_std_dev, gamma, beta, use_silu, c, ci, hw_begin, hw_end);
 }

 }  // namespace cuda
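
Reviewer note (not part of the patch): the refactor above replaces the fixed two-channels-per-thread helpers with ILP-wide vectorized loads built on onnxruntime::cuda::aligned_vector. The standalone CUDA sketch below illustrates that load-and-accumulate pattern in isolation, assuming only that an aligned, ILP-sized vector struct is what enables the wide loads. AlignedVec, UpdateSumSketch, and ToySumKernel are illustrative stand-ins invented for this sketch, not names from the repository.

// sketch.cu -- minimal illustration of the vectorized UpdateSum pattern.
#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-in for onnxruntime::cuda::aligned_vector<T, ILP>:
// the alignas lets the compiler issue one wide load/store for val[ILP].
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) AlignedVec {
  T val[ILP];
};

// Mirrors the shape of UpdateSum<T, ILP> in the patch: one vectorized load
// fetches ILP contiguous channels, then each lane is accumulated in float.
template <typename T, int ILP>
__device__ void UpdateSumSketch(const T* src, int64_t offset, float& sum, float& sum_sq) {
  using VecT = AlignedVec<T, ILP>;
  const VecT v = *reinterpret_cast<const VecT*>(src + offset);
#pragma unroll
  for (int i = 0; i < ILP; ++i) {
    const float x = static_cast<float>(v.val[i]);
    sum += x;
    sum_sq += x * x;
  }
}

// Each thread reduces ILP channels of a toy NHWC-style (hw, c) tensor.
template <typename T, int ILP>
__global__ void ToySumKernel(const T* src, int hw, int c, float* out_sum, float* out_sum_sq) {
  const int ci = (blockIdx.x * blockDim.x + threadIdx.x) * ILP;
  if (ci >= c) return;
  float sum = 0.f, sum_sq = 0.f;
  for (int hwi = 0; hwi < hw; ++hwi) {
    UpdateSumSketch<T, ILP>(src, static_cast<int64_t>(hwi) * c + ci, sum, sum_sq);
  }
  atomicAdd(out_sum, sum);
  atomicAdd(out_sum_sq, sum_sq);
}

int main() {
  constexpr int kHW = 4, kC = 8, kILP = 4;
  float h_src[kHW * kC];
  for (int i = 0; i < kHW * kC; ++i) h_src[i] = 1.0f;  // all ones -> sum == hw * c

  float *d_src, *d_sum, *d_sum_sq;
  cudaMalloc(&d_src, sizeof(h_src));
  cudaMalloc(&d_sum, sizeof(float));
  cudaMalloc(&d_sum_sq, sizeof(float));
  cudaMemcpy(d_src, h_src, sizeof(h_src), cudaMemcpyHostToDevice);
  cudaMemset(d_sum, 0, sizeof(float));
  cudaMemset(d_sum_sq, 0, sizeof(float));

  // Two threads, each owning kILP = 4 contiguous channels.
  ToySumKernel<float, kILP><<<1, kC / kILP>>>(d_src, kHW, kC, d_sum, d_sum_sq);

  float sum = 0.f, sum_sq = 0.f;
  cudaMemcpy(&sum, d_sum, sizeof(float), cudaMemcpyDeviceToHost);
  cudaMemcpy(&sum_sq, d_sum_sq, sizeof(float), cudaMemcpyDeviceToHost);
  printf("sum=%f sum_sq=%f (expected %d)\n", sum, sum_sq, kHW * kC);

  cudaFree(d_src);
  cudaFree(d_sum);
  cudaFree(d_sum_sq);
  return 0;
}

Compiled with nvcc and run on this toy 4x8 slice of ones, the sketch should print sum=32 and sum_sq=32, the same per-channel accumulation GroupNormNHWCSumKernel performs before its cub::BlockScan reduction. For float with ILP = 4 the alignas(16) vector lets the compiler emit a single 128-bit load per iteration, which is the motivation for the extra ThreadsPerBlock/ILP template parameters introduced by this PR.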