diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index 301b2e76b1b2d..6a8faece2034b 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -53,7 +53,7 @@ struct DispatchGroupNorm {
 
 }  // namespace
 
-GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+GroupNormBase::GroupNormBase(const OpKernelInfo& op_info) {
   epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
   ORT_ENFORCE(epsilon_ >= 0);
 
@@ -70,6 +70,10 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
   channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
 }
 
+
+GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info), GroupNormBase(op_info) {
+}
+
 Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
   const Tensor* gamma = context->Input<Tensor>(1);
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
index 52c006e6bdb96..024f8541fbd18 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -11,18 +11,39 @@ namespace cuda {
 
 using namespace onnxruntime::cuda;
 
-class GroupNorm final : public CudaKernel {
+
+class GroupNormBase {
  public:
-  GroupNorm(const OpKernelInfo& op_kernel_info);
-  Status ComputeInternal(OpKernelContext* context) const override;
+  GroupNormBase(const OpKernelInfo& op_kernel_info);
+  Status CheckInputs(OpKernelContext* context) const;
 
- private:
+ protected:
   bool use_swish_activation_;
   float epsilon_;
   int num_groups_;
   bool channels_last_;
 };
 
+class GroupNorm final : public CudaKernel, public GroupNormBase {
+ public:
+  GroupNorm(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+
+class SkipGroupNorm final : public CudaKernel, public GroupNormBase {
+ public:
+  SkipGroupNorm(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+class BiasGroupNorm final : public CudaKernel, public GroupNormBase {
+ public:
+  BiasGroupNorm(const OpKernelInfo& op_kernel_info);
+  Status ComputeInternal(OpKernelContext* context) const override;
+};
+
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
index 8252a6a2d7cb7..f1fd4be1053aa 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -78,12 +78,31 @@ struct GroupSumsOp {
   }
 };
 
+enum GroupNormOperatorType {
+  GroupNormOp,
+  SkipGroupNormOp,
+  BiasGroupNormOp,
+};
+
+// This implementation supports 3 operators:
+// (1) GroupNorm: skip, bias and add_out do not exist.
+// (2) SkipGroupNorm: skip is (n, h, w, c) and bias is (c); add_out is (n, h, w, c).
+//     The additional output add_out = src + skip + bias is also used as the input of group normalization.
+// (3) BiasGroupNorm: bias is (n, 1, 1, c); skip and add_out do not exist.
+
 template <typename T>
 struct GroupNormNHWCParams {
-  // The output buffer. Layout NHWC.
+  // The output buffer. Shape is (n, h, w, c).
   T* dst;
-  // The input buffer. Layout NHWC.
+  // Optional output of element-wise add of src, skip and bias. Shape is (n, h, w, c).
+  T* add_out;
+  // The input buffer. Shape is (n, h, w, c).
   T const* src;
+  // Optional input buffer for skip. Shape is (n, h, w, c).
+  T const* skip;
+  // Optional input buffer for bias. Shape is (c) or (n, 1, 1, c).
+  T const* bias;
+
   // The gamma scaling factor.
   float const* gamma;
   // The beta term to add in GN.
@@ -123,7 +142,7 @@ struct GroupNormNHWCParams {
   int32_t groupsPerBlock;
 
   // Number of threads per block
-  int32_t threads_per_block;
+  int32_t threadsPerBlock;
 
   float epsilon;
 };
 
@@ -157,7 +176,7 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) {
   sumSq += f2.x * f2.x + f2.y * f2.y;
 }
 
-template <typename T, int32_t THREADS_PER_BLOCK>
+template <typename T, int32_t THREADS_PER_BLOCK, GroupNormOperatorType OP>
 __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
   // The object in charge of doing the sums for the different blocks.
   typedef cub::BlockScan<GroupSums, THREADS_PER_BLOCK> BlockScan;
@@ -188,9 +207,30 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
   float sumSq = 0.F;
 
   // Iterate over the activations to compute the sums.
-  for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
-    int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
-    UpdateSum(params.src, offset, sum, sumSq);
+  if (OP == SkipGroupNormOp) {
+    int64_t bias_offset = static_cast<int64_t>(ci);
+    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+      UpdateSum(params.bias, bias_offset, sum, sumSq);
+      int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
+      UpdateSum(params.src, offset, sum, sumSq);
+      UpdateSum(params.skip, offset, sum, sumSq);
+    }
+  }
+
+  if (OP == BiasGroupNormOp) {
+    int64_t bias_offset = static_cast<int64_t>(ni) * params.c + ci;
+    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+      UpdateSum(params.bias, bias_offset, sum, sumSq);
+      int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
+      UpdateSum(params.src, offset, sum, sumSq);
+    }
+  }
+
+  if (OP == GroupNormOp) {
+    for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
+      int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwi) * params.c + ci;
+      UpdateSum(params.src, offset, sum, sumSq);
+    }
   }
 
   // The group that thread works on and the channel in the group (modulus).
@@ -233,7 +273,7 @@ void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   dim3 grid;
 
   // The number of blocks to compute all the channels.
-  grid.x = divUp(params.c, params.cPerBlock);
+  grid.x = params.c / params.cPerBlock;
   // The number of blocks to compute all the activations in a given instance.
   grid.y = divUp(params.hw, params.hwPerBlock);
@@ -242,7 +282,7 @@ void groupNormNHWCSum(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   grid.z = params.n;
 
   // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2.
-  switch (params.threads_per_block) {
+  switch (params.threadsPerBlock) {
     case 256:
       groupNormNHWCSumKernel<T, 256><<<grid, 256, 0, stream>>>(params);
       break;
@@ -368,13 +408,13 @@ void groupNormNHWCScale(GroupNormNHWCParams<T> const& params, cudaStream_t stream)
   dim3 grid;
 
   // The number of blocks to compute all the channels.
-  grid.x = divUp(params.c, params.cPerBlock);
+  grid.x = params.c / params.cPerBlock;
   // The number of blocks to compute all the activations in a given instance.
   grid.y = divUp(params.hw, params.hwPerBlock);
   // The number of instances.
   grid.z = params.n;
 
-  switch (params.threads_per_block) {
+  switch (params.threadsPerBlock) {
     case 256:
       groupNormNHWCScaleKernel<T, 256><<<grid, 256, 0, stream>>>(params);
       break;
@@ -511,7 +551,7 @@ Status LaunchGroupNormKernel(
     ORT_NOT_IMPLEMENTED("Not implemented");
   }
 
-  params.threads_per_block = nextSize(cPerBlock) / CHANNELS_PER_THREAD;
+  params.threadsPerBlock = nextSize(cPerBlock) / CHANNELS_PER_THREAD;
 
   cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream);
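Note on the indexing used by the new sum-kernel branches above. The sketch below is illustrative only and is not part of the patch; the struct and method names are invented for the example. It only restates the offset math visible in the hunk: every NHWC activation lives at ni * hwc + hwi * c + ci, SkipGroupNorm reads its bias at ci because that bias has shape (c), and BiasGroupNorm reads its bias at ni * c + ci because that bias has shape (n, 1, 1, c).

#include <cstdint>

// Illustrative helper only (not in the patch): the offset arithmetic that the
// three branches of groupNormNHWCSumKernel rely on for NHWC tensors.
struct NhwcIndexer {
  int64_t hwc;  // h * w * c, the stride between instances
  int64_t c;    // number of channels

  // Offset of activation (ni, hwi, ci); hwi enumerates the h * w positions.
  int64_t Activation(int64_t ni, int64_t hwi, int64_t ci) const {
    return ni * hwc + hwi * c + ci;
  }

  // SkipGroupNorm bias has shape (c): shared by all instances and positions.
  int64_t SkipGroupNormBias(int64_t ci) const { return ci; }

  // BiasGroupNorm bias has shape (n, 1, 1, c): one bias vector per instance.
  int64_t BiasGroupNormBias(int64_t ni, int64_t ci) const { return ni * c + ci; }
};

For example, with h = w = 32 and c = 320, hwc is 327680; instance 1, position 0, channel 5 maps to activation offset 327685, its SkipGroupNorm bias to offset 5, and its BiasGroupNorm bias to offset 325.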
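The launch-configuration changes (exact division for grid.x and the threadsPerBlock rename) can be checked with a small host-side calculation. The sketch below is not code from the patch: kSizes and the behavior of nextSize are assumptions based on the comment "Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2", and the tensor sizes are made up.

#include <cassert>
#include <cstdint>
#include <cstdio>

// Assumed list of supported channels-per-block values (not taken from the patch).
constexpr int32_t kSizes[] = {128, 256, 320, 384, 512};
constexpr int32_t CHANNELS_PER_THREAD = 2;

int32_t divUp(int32_t m, int32_t n) { return (m + n - 1) / n; }

// Assumed behavior of nextSize: round cPerBlock up to the next supported size.
int32_t nextSize(int32_t cPerBlock) {
  for (int32_t s : kSizes) {
    if (cPerBlock <= s) return s;
  }
  return kSizes[sizeof(kSizes) / sizeof(kSizes[0]) - 1];
}

int main() {
  // Made-up example: 2 images of 64x64 with 512 channels, 512 channels per block.
  int32_t n = 2, h = 64, w = 64, c = 512, cPerBlock = 512, hwPerBlock = 256;

  // grid.x now uses exact division, so cPerBlock must evenly divide c.
  assert(c % cPerBlock == 0);
  int32_t gridX = c / cPerBlock;             // 1 block covers all 512 channels
  int32_t gridY = divUp(h * w, hwPerBlock);  // 16 blocks cover the 4096 positions
  int32_t gridZ = n;                         // one slice of blocks per instance

  // Each thread handles CHANNELS_PER_THREAD channels: 512 channels -> 256 threads,
  // which matches the "case 256" branch of the launch switch in the patch.
  int32_t threadsPerBlock = nextSize(cPerBlock) / CHANNELS_PER_THREAD;

  printf("grid = (%d, %d, %d), threadsPerBlock = %d\n", gridX, gridY, gridZ, threadsPerBlock);
  return 0;
}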