From 56740dd099eaba65ce5780aaa334215de1a99d61 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Tue, 16 Jan 2024 10:07:53 +0000 Subject: [PATCH 1/5] refactor --- .../cuda/diffusion/group_norm_impl.cu | 47 ++--- .../cuda/diffusion/group_norm_impl_kernel.cuh | 163 ++++++++++++++---- 2 files changed, 156 insertions(+), 54 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index d7b2cc2379f4f..82f7565af4c99 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -49,23 +49,26 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ + GroupNormNHWCSumKernel \ + <<>>( \ + params.skip_workspace, params.group_sum_bufer, params.src, params.skip, params.bias, \ + params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ + params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. switch (params.threads_per_block) { case 256: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCSumKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SUM(64, CHANNELS_PER_THREAD) } } @@ -80,23 +83,27 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea // The number of instances. grid.z = params.n; +#define LAUNCH_GROUPNORM_SCALE(ThreadsPerBlock, VecSize) \ + GroupNormNHWCScaleKernel \ + <<>>( \ + params.dst, params.src, params.skip, params.gamma, params.beta, params.skip_workspace, \ + params.group_sum_buffer, params.epsilon, params.c, params.channels_per_block, params.channels_per_group, \ + params.groups, params.hwc, params.inv_hw_channels_per_group, params.hw, params.hw_per_block, \ + params.use_silu); \ + break; + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. 
switch (params.threads_per_block) { case 256: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(256, CHANNELS_PER_THREAD) case 192: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(192, CHANNELS_PER_THREAD) case 160: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(160, CHANNELS_PER_THREAD) case 128: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(128, CHANNELS_PER_THREAD) case 64: - GroupNormNHWCScaleKernel<<>>(params); - break; + LAUNCH_GROUPNORM_SCALE(64, CHANNELS_PER_THREAD) } } diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh index 081e9a3de578c..edd04a55c7d94 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -23,7 +23,6 @@ #include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" -#include "contrib_ops/cuda/diffusion/group_norm_impl.h" using namespace onnxruntime::cuda; @@ -54,11 +53,21 @@ struct GroupSumsOp { } }; -template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); +template +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + const float val = static_cast(input_v.val[i]); + sum += val; + sum_sq += val * val; + } +} template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -72,7 +81,7 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. 
float2 f2 = *reinterpret_cast(&src[offset]); @@ -84,13 +93,28 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f } // Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] -template +template inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + const VecT bias_v = *reinterpret_cast(bias + bias_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i] + bias_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} template <> -inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); @@ -106,8 +130,8 @@ inline __device__ void AddSkipBias(half* add_out, const half* src, const half* s } template <> -inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, - int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { float2 f2 = *reinterpret_cast(&src[offset]); float2 s = *reinterpret_cast(&skip[skip_offset]); float2 b = *reinterpret_cast(&bias[bias_offset]); @@ -121,13 +145,27 @@ inline __device__ void AddSkipBias(float* add_out, const float* src, const float } // Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] -template +template inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + using VecT = onnxruntime::cuda::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + const VecT skip_v = *reinterpret_cast(skip + skip_offset); + VecT output_v = *reinterpret_cast(add_out + offset); + +#pragma unroll + for (int i = 0; i < ILP; i++) { + output_v.val[i] = input_v.val[i] + skip_v.val[i]; + const float val = static_cast(output_v.val[i]); + sum += val; + sum_sq += val * val; + } + *(reinterpret_cast(add_out + offset)) = output_v; +} template <> -inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { __half2 h2 = *reinterpret_cast<__half2 
const*>(&src[offset]); __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); h2 = h2 + s; @@ -140,8 +178,8 @@ inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, } template <> -inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, - int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { float2 f2 = *reinterpret_cast(&src[offset]); float2 s = *reinterpret_cast(&skip[skip_offset]); f2.x += s.x; @@ -151,8 +189,10 @@ inline __device__ void AddSkip(float* add_out, const float* src, const float* sk sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffer, const T* src, const T* skip, const T* bias, + int32_t channels_per_block, int32_t hw_per_block, int32_t hw, int32_t hwc, int32_t c, + int32_t channels_per_group, int32_t groups, int32_t groups_per_block, bool broadcast_skip) { // The object in charge of doing the sums for the different blocks. typedef cub::BlockScan BlockScan; @@ -166,9 +206,9 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { int32_t ni = blockIdx.z; // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * ILP; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + if (ci >= params.c || threadIdx.x * ILP >= params.channels_per_block) { return; } @@ -217,7 +257,7 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { } // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + int32_t gi = threadIdx.x * ILP / params.channels_per_group; // The channel in the group. int32_t cj = ci % params.channels_per_group; @@ -230,7 +270,7 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // Store the results for the groups in shared memory (to produce coalesced stores later). // For each group, only the last thread of that group is picked to save sum to shared memory. 
- if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + if (cj == params.channels_per_group - ILP) { smem[gi] = make_float2(out.sum, out.sum_sq); } @@ -254,6 +294,27 @@ __global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { } } +template +__device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma_v, const float* beta_v, bool silu) { + using VecT = onnxruntime::rocm::aligned_vector; + const VecT input_v = *reinterpret_cast(src + offset); + VecT output_v; + +#pragma unroll + for (int i = 0; i < ILP; i++) { + float val = static_cast(input_v.val[i]); + val = (val - mean) * inv_std_dev; + val = gamma_v[i] * val + beta_v[i]; + + if (silu) { + val = val * sigmoid(val); + } + output_v.val[i] = static_cast(val); + } + *(reinterpret_cast(dst + offset)) = output_v; +} + template __device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, float2& gamma_f2, float2& beta_f2, bool silu); @@ -307,11 +368,51 @@ __device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { +template +__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { + using VecF = onnxruntime::rocm::aligned_vector; + + const VecF gamma_v = *reinterpret_cast(gamma + ci); + const VecF beta_v = *reinterpret_cast(beta + ci); + // Iterate over the activations to compute the sums. + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + // Fetch ILP channels per thread. + computeGroupNormVec(input, dst, offset, mean, inv_std_dev, gamma_v.val, beta_v.val, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template <> +__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); + float2 beta_f2 = *reinterpret_cast(&beta[ci]); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + ComputeGroupNorm(input, dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, use_silu); + } +} + +template +__global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, const float* gamma, const float* beta, + const T* skip_workspace, const float* group_sum_buffer, float epsilon, + int32_t c, int32_t channels_per_block, int32_t channels_per_group, + int32_t groups, int32_t hwc, float inv_hw_channels_per_group, + int32_t hw, int32_t hw_per_block, bool use_silu) { // The channel loaded by that thread. 
- int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; - if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * ILP; + if (ci >= params.c || threadIdx.x * ILP >= params.channels_per_block) { return; } @@ -329,10 +430,6 @@ __global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. Fetch two per thread. - float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); - // Compute the mean. float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. @@ -345,9 +442,7 @@ __global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); - } + ComputeGroupNormKernel(input, params.dst, offset, mean, inv_std_dev, params.gamma, params.beta, params.use_silu, ci, hw_begin, hw_end); } } // namespace cuda From ce0c19fef12a90764b7b447ca14cc2c885c56bc3 Mon Sep 17 00:00:00 2001 From: Peixuan Zuo Date: Tue, 16 Jan 2024 10:52:43 +0000 Subject: [PATCH 2/5] update --- .../cuda/diffusion/group_norm_impl.cu | 2 +- .../cuda/diffusion/group_norm_impl_kernel.cuh | 106 +++++++++--------- 2 files changed, 54 insertions(+), 54 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 82f7565af4c99..534add5196808 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -52,7 +52,7 @@ void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) #define LAUNCH_GROUPNORM_SUM(ThreadsPerBlock, VecSize) \ GroupNormNHWCSumKernel \ <<>>( \ - params.skip_workspace, params.group_sum_bufer, params.src, params.skip, params.bias, \ + params.skip_workspace, params.group_sum_buffer, params.src, params.skip, params.bias, \ params.channels_per_block, params.hw_per_block, params.hw, params.hwc, params.c, \ params.channels_per_group, params.groups, params.groups_per_block, params.broadcast_skip); \ break; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh index edd04a55c7d94..cfeeac2049355 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -81,7 +81,7 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -206,60 +206,60 @@ __global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffe int32_t ni = blockIdx.z; // The channel loaded by that thread. 
- int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * ILP; + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; - if (ci >= params.c || threadIdx.x * ILP >= params.channels_per_block) { + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { return; } // The first activation loaded by that block. - int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_begin = blockIdx.y * hw_per_block; // The last activation loaded by that block. - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + int32_t hw_end = min(hw_begin + hw_per_block, hw); // The sums. float sum = 0.F; float sum_sq = 0.F; // Iterate over the activations to compute the sums. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - if (params.skip != nullptr) { + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + if (skip != nullptr) { // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) const int64_t bias_offset = static_cast(ci); - T* add_out = params.skip_workspace; - if (params.broadcast_skip) { - const int64_t skip_offset = static_cast(ni) * params.c + ci; + T* add_out = skip_workspace; + if (broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * c + ci; - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, skip_offset, bias_offset, sum, sum_sq); } } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, skip_offset, sum, sum_sq); } } } else { - if (params.bias != nullptr) { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + if (bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkipBias(add_out, src, skip, bias, offset, offset, bias_offset, sum, sum_sq); } } else { - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + AddSkip(add_out, src, skip, offset, offset, sum, sum_sq); } } } } else { // GroupNorm - for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { - UpdateSum(params.src, offset, sum, sum_sq); + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += c) { + UpdateSum(src, offset, sum, sum_sq); } } // The group index relative to the first group within the same block. - int32_t gi = threadIdx.x * ILP / params.channels_per_group; + int32_t gi = threadIdx.x * ILP / channels_per_group; // The channel in the group. - int32_t cj = ci % params.channels_per_group; + int32_t cj = ci % channels_per_group; // The data for the summations. GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; @@ -270,7 +270,7 @@ __global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffe // Store the results for the groups in shared memory (to produce coalesced stores later). 
// For each group, only the last thread of that group is picked to save sum to shared memory. - if (cj == params.channels_per_group - ILP) { + if (cj == channels_per_group - ILP) { smem[gi] = make_float2(out.sum, out.sum_sq); } @@ -278,26 +278,26 @@ __global__ void GroupNormNHWCSumKernel(T* skip_workspace, float* group_sum_buffe __syncthreads(); // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groups_per_block) { + if (threadIdx.x >= groups_per_block) { return; } // The global group index. // Use neighboring threads for coalesced write. - int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + int32_t gj = blockIdx.x * groups_per_block + threadIdx.x; - if (gj < params.groups) { + if (gj < groups) { float2 sums = smem[threadIdx.x]; - const int index = (2 * ni) * params.groups + gj; - atomicAdd(¶ms.group_sum_buffer[index], sums.x); - atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + const int index = (2 * ni) * groups + gj; + atomicAdd(&group_sum_buffer[index], sums.x); + atomicAdd(&group_sum_buffer[index + groups], sums.y); } } template __device__ void computeGroupNormVec(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, const float* gamma_v, const float* beta_v, bool silu) { - using VecT = onnxruntime::rocm::aligned_vector; + using VecT = onnxruntime::cuda::aligned_vector; const VecT input_v = *reinterpret_cast(src + offset); VecT output_v; @@ -370,8 +370,8 @@ __device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, f template __device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, - const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { - using VecF = onnxruntime::rocm::aligned_vector; + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { + using VecF = onnxruntime::cuda::aligned_vector; const VecF gamma_v = *reinterpret_cast(gamma + ci); const VecF beta_v = *reinterpret_cast(beta + ci); @@ -383,8 +383,8 @@ __device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, f } template <> -__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, - const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { +__device__ void ComputeGroupNormKernel(const float* input, float* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { // Load gamma/beta. Fetch two per thread. float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); float2 beta_f2 = *reinterpret_cast(&beta[ci]); @@ -394,8 +394,8 @@ __device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t } template <> -__device__ void ComputeGroupNormKernel(const T* input, T* dst, int64_t offset, float mean, float inv_std_dev, - const float* gamma, const float* beta, bool use_silu, int32_t ci, int32_t hw_begin, int32_t hw_end) { +__device__ void ComputeGroupNormKernel(const half* input, half* dst, int64_t offset, float mean, float inv_std_dev, + const float* gamma, const float* beta, bool use_silu, int32_t c, int32_t ci, int32_t hw_begin, int32_t hw_end) { // Load gamma/beta. Fetch two per thread. 
float2 gamma_f2 = *reinterpret_cast(&gamma[ci]); float2 beta_f2 = *reinterpret_cast(&beta[ci]); @@ -411,8 +411,8 @@ __global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, co int32_t groups, int32_t hwc, float inv_hw_channels_per_group, int32_t hw, int32_t hw_per_block, bool use_silu) { // The channel loaded by that thread. - int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * ILP; - if (ci >= params.c || threadIdx.x * ILP >= params.channels_per_block) { + int32_t ci = blockIdx.x * channels_per_block + threadIdx.x * ILP; + if (ci >= c || threadIdx.x * ILP >= channels_per_block) { return; } @@ -420,29 +420,29 @@ __global__ void GroupNormNHWCScaleKernel(T* dst, const T* src, const T* skip, co int32_t ni = blockIdx.z; // The group that thread works on. - int32_t gi = ci / params.channels_per_group; + int32_t gi = ci / channels_per_group; // Load the sum and sum of squares for the group. float sum = 0.F, sum_sq = 0.F; - if (gi < params.groups) { - const int index = (2 * ni) * params.groups + gi; - sum = params.group_sum_buffer[index]; - sum_sq = params.group_sum_buffer[index + params.groups]; + if (gi < groups) { + const int index = (2 * ni) * groups + gi; + sum = group_sum_buffer[index]; + sum_sq = group_sum_buffer[index + groups]; } // Compute the mean. - float mean = sum * params.inv_hw_channels_per_group; + float mean = sum * inv_hw_channels_per_group; // Compute the variance. - float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); + float var = sum_sq * inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float inv_std_dev = rsqrtf(var + params.epsilon); + float inv_std_dev = rsqrtf(var + epsilon); - int32_t hw_begin = blockIdx.y * params.hw_per_block; - int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); + int32_t hw_begin = blockIdx.y * hw_per_block; + int32_t hw_end = min(hw_begin + hw_per_block, hw); - const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; - int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; - ComputeGroupNormKernel(input, params.dst, offset, mean, inv_std_dev, params.gamma, params.beta, params.use_silu, ci, hw_begin, hw_end); + const T* input = (skip != nullptr) ? 
skip_workspace : src; + int64_t offset = static_cast(ni) * hwc + static_cast(hw_begin) * c + ci; + ComputeGroupNormKernel(input, dst, offset, mean, inv_std_dev, gamma, beta, use_silu, c, ci, hw_begin, hw_end); } } // namespace cuda From 0145d8ce0594d8a109eee5e3b9645ca8d8f6f6b2 Mon Sep 17 00:00:00 2001 From: Peixuan Zuo Date: Wed, 17 Jan 2024 07:07:37 +0000 Subject: [PATCH 3/5] add tuning_ctx --- onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc | 9 ++++++--- .../contrib_ops/cuda/diffusion/group_norm_impl.cu | 12 +++++++++--- .../contrib_ops/cuda/diffusion/group_norm_impl.h | 5 ++++- 3 files changed, 19 insertions(+), 7 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 87e88ac31c998..dea5391c7629b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -24,7 +24,8 @@ namespace { template struct DispatchGroupNorm { - Status operator()(cudaStream_t stream, + Status operator()(CudaTuningContext* tuning_ctx, + Stream* ort_stream, Tensor* output, Tensor* add_out, const Tensor* input, @@ -44,7 +45,8 @@ struct DispatchGroupNorm { int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( - stream, + tuning_ctx, + ort_stream, reinterpret_cast(output->MutableData()), add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), @@ -209,7 +211,8 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + return dispatcher.InvokeRet(GetTuningContext(), + context->GetComputeStream(), output, add_out, input, skip, bias, gamma, beta, workspace.get(), epsilon_, batch_size, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 534add5196808..847a1ff8b595c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -109,7 +109,8 @@ void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t strea template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, T* add_out, const T* input, @@ -127,6 +128,10 @@ Status LaunchGroupNormKernel( bool use_silu, bool broadcast_skip, int channels_per_block) { + + // tuning_ctx only used for ROCm EP. 
+ ORT_UNUSED_PARAMETER(tuning_ctx); + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block); @@ -142,6 +147,7 @@ Status LaunchGroupNormKernel( " groups=", num_groups); } + auto stream = static_cast(ort_stream->GetHandle()); CUDA_RETURN_IF_ERROR(cudaMemsetAsync( params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); @@ -157,14 +163,14 @@ Status LaunchGroupNormKernel( return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, half* output, half* add_out, const half* input, const half* skip, const half* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, int height, int width, int num_groups, bool silu, bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, +template Status LaunchGroupNormKernel(CudaTuningContext* tuning_ctx, Stream* stream, float* output, float* add_out, const float* input, const float* skip, const float* bias, const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index 9532aeecb2f57..98f38a1475eee 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -8,6 +8,8 @@ #include #include +#include "core/providers/cuda/tunable/cuda_tunable.h" + namespace onnxruntime { namespace contrib { namespace cuda { @@ -21,7 +23,8 @@ int GetChannelsPerBlock(int num_channels, int num_groups); template Status LaunchGroupNormKernel( - cudaStream_t stream, + CudaTuningContext* tuning_ctx, + Stream* ort_stream, T* output, // normalized output tensor. Shape is (n, h, w, c) T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) const T* input, // input tensor. 
Shape is (n, h, w, c) From 0c095e884809cb63c5482940d4a36586c0855c45 Mon Sep 17 00:00:00 2001 From: Peixuan Zuo Date: Wed, 17 Jan 2024 07:42:09 +0000 Subject: [PATCH 4/5] update --- .../contrib_ops/cuda/diffusion/group_norm_common_base.h | 4 ++-- onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h index 84f3403b8d5ae..ea87d0c29111e 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_common_base.h @@ -126,7 +126,7 @@ struct GroupNormNHWCParams { const T* bias, const float* gamma, const float* beta, - void* workspace, + float* workspace, float epsilon, int batch_size, int num_channels, @@ -151,7 +151,7 @@ struct GroupNormNHWCParams { this->bias = bias; this->gamma = gamma; this->beta = beta; - this->group_sum_buffer = reinterpret_cast(workspace); + this->group_sum_buffer = workspace; this->n = batch_size; this->h = height; this->w = width; diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 847a1ff8b595c..4909dc5e3897b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -132,7 +132,7 @@ Status LaunchGroupNormKernel( // tuning_ctx only used for ROCm EP. ORT_UNUSED_PARAMETER(tuning_ctx); - GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, workspace, epsilon, + GroupNormNHWCParams params(output, add_out, input, skip, bias, gamma, beta, reinterpret_cast(workspace), epsilon, batch_size, num_channels, height, width, num_groups, use_silu, broadcast_skip, channels_per_block); From 852e40877e0fcb15c02543aa7347ec2bfb5320d1 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Wed, 17 Jan 2024 08:30:41 +0000 Subject: [PATCH 5/5] update --- .../contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh index cfeeac2049355..ecd06315e3708 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl_kernel.cuh @@ -21,6 +21,7 @@ // Licensed under the MIT License. #pragma once #include +#include #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh"
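
Note: the sketch below is not part of the patch series. It is a minimal, standalone CUDA program that illustrates the two patterns the refactor relies on: (1) templating the per-thread vector width (the ILP / VecSize parameter) so that each thread loads ILP contiguous channels with a single aligned vector load and accumulates sum and sum-of-squares in float, and (2) dispatching on a runtime threads_per_block value through a macro so that ThreadsPerBlock becomes a compile-time template argument of the kernel. The local aligned_vector, PartialSumKernel, and LAUNCH_PARTIAL_SUM are simplified stand-ins that only mirror the shape of onnxruntime::cuda::aligned_vector, GroupNormNHWCSumKernel, and LAUNCH_GROUPNORM_SUM; the buffer names, sizes, and launch configuration are made up for the example.

#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>

// Simplified stand-in for onnxruntime::cuda::aligned_vector: ILP elements,
// aligned so the vectorized load/store compiles to a single wide transaction.
template <typename T, int ILP>
struct alignas(sizeof(T) * ILP) aligned_vector {
  T val[ILP];
};

// Accumulate sum and sum-of-squares for ILP contiguous channels with one
// vector load; same shape as the templated UpdateSum<T, ILP> in the patch.
template <typename T, int ILP>
__device__ void UpdateSumVec(const T* src, int64_t offset, float& sum, float& sum_sq) {
  using VecT = aligned_vector<T, ILP>;
  const VecT v = *reinterpret_cast<const VecT*>(src + offset);
#pragma unroll
  for (int i = 0; i < ILP; ++i) {
    const float x = static_cast<float>(v.val[i]);
    sum += x;
    sum_sq += x * x;
  }
}

// Each thread reduces ILP adjacent channels into per-thread partial sums.
// ThreadsPerBlock is a template parameter so block-wide primitives (e.g.
// cub::BlockScan in the real kernel) can be instantiated at compile time.
template <typename T, int ThreadsPerBlock, int ILP>
__global__ void PartialSumKernel(const T* src, float* out_sum, float* out_sum_sq, int c) {
  const int ci = (blockIdx.x * ThreadsPerBlock + threadIdx.x) * ILP;
  if (ci >= c) return;
  float sum = 0.f, sum_sq = 0.f;
  UpdateSumVec<T, ILP>(src, ci, sum, sum_sq);
  out_sum[ci / ILP] = sum;
  out_sum_sq[ci / ILP] = sum_sq;
}

// Dispatch macro in the same style as LAUNCH_GROUPNORM_SUM: the switch below
// maps the runtime block size onto a compile-time template instantiation.
#define LAUNCH_PARTIAL_SUM(ThreadsPerBlock, VecSize)                          \
  PartialSumKernel<float, ThreadsPerBlock, VecSize>                           \
      <<<grid, ThreadsPerBlock, 0, stream>>>(d_src, d_sum, d_sum_sq, c);      \
  break;

int main() {
  const int c = 512;       // number of channels
  const int vec_size = 2;  // ILP: channels handled per thread
  float *d_src, *d_sum, *d_sum_sq;
  cudaMalloc(&d_src, c * sizeof(float));
  cudaMalloc(&d_sum, (c / vec_size) * sizeof(float));
  cudaMalloc(&d_sum_sq, (c / vec_size) * sizeof(float));
  cudaMemset(d_src, 0, c * sizeof(float));

  cudaStream_t stream = nullptr;  // default stream
  const int threads_per_block = 256;
  dim3 grid((c / vec_size + threads_per_block - 1) / threads_per_block);

  switch (threads_per_block) {
    case 256:
      LAUNCH_PARTIAL_SUM(256, 2)
    case 128:
      LAUNCH_PARTIAL_SUM(128, 2)
  }
  cudaDeviceSynchronize();
  printf("launch status: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(d_src);
  cudaFree(d_sum);
  cudaFree(d_sum_sq);
  return 0;
}

The same dispatch shape is what LAUNCH_GROUPNORM_SCALE uses for the scale kernel; keeping the block size and vector width as template parameters is what lets one kernel body serve both the half path (vectorized __half loads) and the float path without runtime branching on the element type.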