diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 890403556cc47..ed1049b0bd73a 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -95,6 +95,7 @@ Do not modify directly.* * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling + * com.microsoft.SkipGroupNorm * com.microsoft.SkipLayerNormalization * com.microsoft.SkipSimplifiedLayerNormalization * com.microsoft.Snpe @@ -2342,7 +2343,7 @@ This version of the operator has been available since version 1 of the 'com.micr
activation : int (required)
-
Activation after group normalization: 0 for None, 1 for Swish
+
Activation after group normalization: 0 for None, 1 for SiLU
channels_last : int
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
epsilon : float
@@ -2582,6 +2583,7 @@ This version of the operator has been available since version 1 of the 'com.micr Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + #### Version This version of the operator has been available since version 1 of the 'com.microsoft' operator set. @@ -5083,6 +5085,72 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.SkipGroupNorm** + + This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + + This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + + The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. + The num_channels must be divisible by num_groups. + The mean and standard-deviation of s are calculated separately over the each group. + The weight and bias are per-channel affine transform parameter vectors of size num_channels. + + The activation attribute can be used to enable activation after group normalization. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : int (required)
+
Activation after group normalization: 0 for None, 1 for SiLU
+
channels_last : int
+
1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.
+
epsilon : float
+
The epsilon value to use to avoid division by zero
+
groups : int (required)
+
The number of groups of channels. It should be a divisor of the number of channels C
+
+ +#### Inputs (4 - 5) + +
+
X : T
+
Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels, and H and W are the height and width of the data
+
gamma : M
+
1D gamma tensor for normalization with shape (C), where C is number of channels
+
beta : M
+
1D beta tensor for normalization with shape (C), where C is number of channels
+
skip : T
+
4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)
+
bias (optional) : T
+
1D bias tensor. Dimensions are (C), where C is number of channels
+
+ +#### Outputs (1 - 2) + +
+
Y : T
+
The output tensor of the same shape as X
+
S (optional) : T
+
The element-wise sum of input x, skip and bias tensors. It has the same shape as X
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float)
+
Constrain input X, skip, bias and output Y, S types to float tensors.
+
M : tensor(float16), tensor(float)
+
Constrain gamma and beta to float tensors.
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index bfb7716dc5cea..dcdf73cbdbf08 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -861,6 +861,7 @@ Do not modify directly.* |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| +|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 2618fe4a238bd..d51915b85095f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -97,6 +97,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); @@ -269,6 +270,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc index 301b2e76b1b2d..87e88ac31c998 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc @@ -1,6 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. - #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/diffusion/group_norm.h" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" @@ -15,14 +14,22 @@ ONNX_OPERATOR_KERNEL_EX( GroupNorm, kMSDomain, 1, kCudaExecutionProvider, (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); +ONNX_OPERATOR_KERNEL_EX( + SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints()), GroupNorm); + using namespace ONNX_NAMESPACE; namespace { + template struct DispatchGroupNorm { Status operator()(cudaStream_t stream, Tensor* output, + Tensor* add_out, const Tensor* input, + const Tensor* skip, + const Tensor* bias, const Tensor* gamma, const Tensor* beta, void* workspace, @@ -32,12 +39,17 @@ struct DispatchGroupNorm { int height, int width, int num_groups, - bool use_swish_activation) { + bool use_swish_activation, + bool broadcast_skip, + int channels_per_block) { typedef typename ToCudaType::MappedType CudaT; return LaunchGroupNormKernel( stream, reinterpret_cast(output->MutableData()), + add_out == nullptr ? nullptr : reinterpret_cast(add_out->MutableData()), reinterpret_cast(input->Data()), + skip == nullptr ? nullptr : reinterpret_cast(skip->Data()), + bias == nullptr ? nullptr : reinterpret_cast(bias->Data()), gamma->Data(), beta->Data(), workspace, @@ -47,13 +59,21 @@ struct DispatchGroupNorm { height, width, num_groups, - use_swish_activation); + use_swish_activation, + broadcast_skip, + channels_per_block); } }; } // namespace GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { + has_skip_ = false; + const std::string& op_name = op_info.GetKernelDef().OpName(); + if (op_name == "SkipGroupNorm") { + has_skip_ = true; + } + epsilon_ = op_info.GetAttrOrDefault("epsilon", 1e-5f); ORT_ENFORCE(epsilon_ >= 0); @@ -68,6 +88,23 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) { use_swish_activation_ = (activation == 1); channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); + + channels_per_block_ = 0; +} + +Status GroupNorm::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + + // Compute and cache cPerBlock using number of channels from gamma tensor shape. + if (input_idx == 1) { + auto gamma_shape = tensor.Shape(); + if (gamma_shape.NumDimensions() == 1) { + channels_per_block_ = GetChannelsPerBlock(static_cast(gamma_shape[0]), num_groups_); + } + } + + return Status::OK(); } Status GroupNorm::ComputeInternal(OpKernelContext* context) const { @@ -77,22 +114,38 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { Tensor* output = context->Output(0, input->Shape()); if (!channels_last_) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "only the channels_last layout is supported"); } + if (!gamma->IsDataType() || !beta->IsDataType()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm only supports gamma and beta in float type"); + } + const auto& input_dims = input->Shape().GetDims(); if (input_dims.size() != 4) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input is expected to have 4 dimensions, got ", input_dims.size()); } + // Only support NHWC format right now. + int batch_size = static_cast(input_dims[0]); + int height = static_cast(input_dims[1]); + int width = static_cast(input_dims[2]); + int num_channels = static_cast(input_dims[3]); + + if (num_channels % num_groups_ != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "number of channels should be divisiable by num_groups"); + } + const auto& gamma_dims = gamma->Shape().GetDims(); if (gamma_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "gamma is expected to have 1 dimension, got ", gamma_dims.size()); } - if (gamma_dims[0] != input_dims[3]) { + if (gamma_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in gamma and input does not match"); } @@ -102,22 +155,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "beta is expected to have 1 dimension, got ", beta_dims.size()); } - if (beta_dims[0] != input_dims[3]) { + if (beta_dims[0] != num_channels) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Number of channels in beta and input does not match"); } - // Input and output format is NHWC - int batch_size = static_cast(input_dims[0]); - int num_channels = static_cast(input_dims[3]); - int height = static_cast(input_dims[1]); - int width = static_cast(input_dims[2]); - - if (num_channels % num_groups_ != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "number of channels should be divisiable by num_groups"); - } - if (context->GetUseDeterministicCompute()) { static std::once_flag log_warning; std::call_once(log_warning, []() { @@ -125,17 +167,59 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const { }); } - auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream()); + const Tensor* skip = nullptr; + const Tensor* bias = nullptr; + Tensor* add_out = nullptr; + + bool broadcast_skip = false; + if (has_skip_) { + skip = context->Input(3); + bias = context->Input(4); + add_out = context->Output(1, input->Shape()); + + if (bias != nullptr) { // Bias is optional + // If provided, bias has shape (C). + const auto& bias_dims = bias->Shape().GetDims(); + if (bias_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "bias is expected to have 1 dimension, got ", bias_dims.size()); + } + if (bias_dims[0] != num_channels) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Number of channels in bias and input does not match"); + } + } + + // Check whether skip can be broadcasted to input shape. + if (skip->Shape() != input->Shape()) { + const auto& dims = skip->Shape().GetDims(); + // The shape of ship can be (N, C) or (N, 1, 1, C) for broadcast. + const bool b2 = (dims.size() == 2 && dims[0] == batch_size && dims[1] == num_channels); + const bool b4 = (dims.size() == 4 && dims[0] == batch_size && + dims[1] == 1 && dims[2] == 1 && dims[3] == num_channels); + broadcast_skip = b2 || b4; + if (!broadcast_skip) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "skip shape is expected to be (N, H, W, C) or (N, 1, 1, C) or (N, C)"); + } + } + } + + auto workspace = GetScratchBuffer(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_), + context->GetComputeStream()); utils::MLTypeCallDispatcher dispatcher(input->GetElementType()); - return dispatcher.InvokeRet(Stream(context), output, input, gamma, beta, workspace.get(), + return dispatcher.InvokeRet(Stream(context), output, add_out, input, skip, bias, + gamma, beta, workspace.get(), epsilon_, batch_size, num_channels, height, width, num_groups_, - use_swish_activation_); + use_swish_activation_, + broadcast_skip, + channels_per_block_); } } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h index 52c006e6bdb96..b408b3c1ee79b 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h @@ -16,11 +16,16 @@ class GroupNorm final : public CudaKernel { GroupNorm(const OpKernelInfo& op_kernel_info); Status ComputeInternal(OpKernelContext* context) const override; + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + bool& is_packed, PrePackedWeights* prepacked_weights) override; + private: - bool use_swish_activation_; + bool use_swish_activation_; // use SiLU (also known as Swish) activation after group normalization? float epsilon_; int num_groups_; bool channels_last_; + bool has_skip_; // true for SkipGroupNorm operator; false for GroupNorm + int channels_per_block_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu index 01ba078b4be77..48b161552ce0c 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu @@ -16,18 +16,45 @@ */ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5 +// Modifications: heuristic channels per block; support epsilon; support skip and bias; update coding style. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include #include #include #include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" #include "contrib_ops/cuda/diffusion/group_norm_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +using namespace onnxruntime::cuda; + namespace onnxruntime { namespace contrib { namespace cuda { -static inline int32_t divUp(int32_t m, int32_t n) { +namespace { + +// TODO: Similar to SkipLayerNorm kernel, read/write up to 8 channels at same time. +constexpr static int32_t CHANNELS_PER_THREAD = 2; + +constexpr static int kSizes[] = {128, 256, 320, 384, 512}; +constexpr static size_t kNumOfSizes = sizeof(kSizes) / sizeof(kSizes[0]); +constexpr static int kMaxSize = kSizes[kNumOfSizes - 1]; + +int NextSize(int x) { + for (size_t i = 0; i < kNumOfSizes; ++i) { + if (x <= kSizes[i]) { + return kSizes[i]; + } + } + + return x; +} +} // namespace + +static inline int32_t DivUp(int32_t m, int32_t n) { return (m + n - 1) / n; } @@ -41,14 +68,14 @@ struct GroupSums { // The sum. float sum; // The sum of squares. - float sumSq; + float sum_sq; }; struct GroupSumsOp { inline __device__ GroupSums operator()(GroupSums const& a, GroupSums const& b) { GroupSums dst; dst.sum = b.flag ? b.sum : (a.sum + b.sum); - dst.sumSq = b.flag ? b.sumSq : (a.sumSq + b.sumSq); + dst.sum_sq = b.flag ? b.sum_sq : (a.sum_sq + b.sum_sq); dst.flag = a.flag + b.flag; return dst; } @@ -56,54 +83,85 @@ struct GroupSumsOp { template struct GroupNormNHWCParams { - // The output buffer. Layout NHWC. + // The output buffer. Shape is (n, h, w, c). T* dst; - // The input buffer. Layout NHWC. + + // Optional output of element-wise add result of src, skip and bias. Shape is (n, h, w, c). + T* add_out; + + // The input buffer. Shape is (n, h, w, c). T const* src; + + // Optional input buffer for skip tensor. Shape is (n, h, w, c) or (n, 1, 1, c) or (n, c). + T const* skip; + + // Optional input buffer for bias tensor. Shape is (c). + T const* bias; + // The gamma scaling factor. float const* gamma; + // The beta term to add in GN. float const* beta; - // The temporary buffer to do the global parallel reduction. Size: - // BLOCKS_PER_BATCH x C x 2. - float* redBuffer; + + // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups. + float* group_sum_buffer; // The number of instances in the batch. int32_t n; + // The height and width of each activation map. int32_t h; int32_t w; - // The number of channels. + + // Number of channels. int32_t c; - // The number of groups. + + // Number of groups. int32_t groups; - // Do we apply the Swish activation function? - bool withSwish; + + // Do we apply the SiLU activation function? + bool use_silu; // Precomputed values and parameters to control the execution of the kernels. - // The number of activations per instance (h * w) and the number of - // activations per block. + // Number of activations per instance (h * w) int32_t hw; - int32_t hwPerBlock; - // The number of channels per group and blocks per activation in the C - // dimension. - int32_t cPerBlock; - int32_t cPerGroup; + + // Number of activations per block + int32_t hw_per_block; + + // Number of channels per block in the C dimension. + int32_t channels_per_block; + + // Number of channels per group in the C dimension. + int32_t channels_per_group; // The precomputed stride between instances. int32_t hwc; - // The inverse of hwc in floats (to compute mean/var). - float invHWC; + // The inverse of hw*channels_per_group to compute mean of a group. + float inv_hw_channels_per_group; // The precomputed number of groups per block. - int32_t groupsPerBlock; + int32_t groups_per_block; + + // Number of threads per block + int32_t threads_per_block; + + // Epsilon to get stable variance in normalization. + float epsilon; + + // Whether skip need broadcast. True if shape of skip is (N, C) or (N, 1, 1, C); False otherwise. + bool broadcast_skip; + + // For SkipGroupNorm, it points to the intermediate result of adding skip and bias. + T* skip_workspace; }; template -inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sumSq); +inline __device__ void UpdateSum(const T* src, int64_t offset, float& sum, float& sum_sq); template <> -inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -113,11 +171,11 @@ inline __device__ void UpdateSum(const half* src, int64_t offset, float& sum, fl sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } template <> -inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sumSq) { +inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, float& sum_sq) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); @@ -125,119 +183,220 @@ inline __device__ void UpdateSum(const float* src, int64_t offset, float& sum, f sum += f2.x + f2.y; // Update the sum of squares. - sumSq += f2.x * f2.x + f2.y * f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm: add_out[offset] = src[offset] + skip[skip_offset] + bias[bias_offset] +template +inline __device__ void AddSkipBias(T* add_out, const T* src, const T* skip, const T* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkipBias(half* add_out, const half* src, const half* skip, const half* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + // Fetch two channels per thread. + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + __half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]); + h2 = h2 + b; + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkipBias(float* add_out, const float* src, const float* skip, const float* bias, + int64_t offset, int64_t skip_offset, int64_t bias_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + float2 b = *reinterpret_cast(&bias[bias_offset]); + f2.x += s.x + b.x; + f2.y += s.y + b.y; + + *reinterpret_cast(&add_out[offset]) = f2; + + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +// Sum for SkipGroupNorm without bias: add_out[offset] = src[offset] + skip[skip_offset] +template +inline __device__ void AddSkip(T* add_out, const T* src, const T* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq); + +template <> +inline __device__ void AddSkip(half* add_out, const half* src, const half* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); + __half2 s = *reinterpret_cast<__half2 const*>(&skip[skip_offset]); + h2 = h2 + s; + + *reinterpret_cast<__half2*>(&add_out[offset]) = h2; + + float2 f2 = __half22float2(h2); + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; +} + +template <> +inline __device__ void AddSkip(float* add_out, const float* src, const float* skip, + int64_t offset, int64_t skip_offset, float& sum, float& sum_sq) { + float2 f2 = *reinterpret_cast(&src[offset]); + float2 s = *reinterpret_cast(&skip[skip_offset]); + f2.x += s.x; + f2.y += s.y; + *reinterpret_cast(&add_out[offset]) = f2; + sum += f2.x + f2.y; + sum_sq += f2.x * f2.x + f2.y * f2.y; } -template -__global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) { +template +__global__ void GroupNormNHWCSumKernel(GroupNormNHWCParams params) { // The object in charge of doing the sums for the different blocks. - typedef cub::BlockScan BlockScan; + typedef cub::BlockScan BlockScan; // Allocate shared memory for BlockScan. - __shared__ typename BlockScan::TempStorage tempStorage; - // Allocate shared memory for the groups. We could reduce the amount of shared - // memory reserved. - __shared__ float2 smem[tTHREADS_PER_BLOCK]; + __shared__ typename BlockScan::TempStorage temp_storage; + + // Allocate shared memory for the groups. We could reduce the amount of shared memory reserved. + __shared__ float2 smem[THREADS_PER_BLOCK]; // The instance in the batch. int32_t ni = blockIdx.z; - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; + + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { + return; + } // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; + int32_t hw_begin = blockIdx.y * params.hw_per_block; // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); // The sums. float sum = 0.F; - float sumSq = 0.F; + float sum_sq = 0.F; // Iterate over the activations to compute the sums. - if (ci < params.c) { - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The offset. - int64_t offset = static_cast(ni) * params.hwc + static_cast(hwi) * params.c + ci; - UpdateSum(params.src, offset, sum, sumSq); + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + if (params.skip != nullptr) { + // SkipGroupNorm: skip is (n, h, w, c) or (n, 1, 1, c) or (n, c), bias is (c), and add_out is (n, h, w, c) + const int64_t bias_offset = static_cast(ci); + T* add_out = params.skip_workspace; + if (params.broadcast_skip) { + const int64_t skip_offset = static_cast(ni) * params.c + ci; + + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, skip_offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, skip_offset, sum, sum_sq); + } + } + } else { + if (params.bias != nullptr) { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkipBias(add_out, params.src, params.skip, params.bias, offset, offset, bias_offset, sum, sum_sq); + } + } else { + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + AddSkip(add_out, params.src, params.skip, offset, offset, sum, sum_sq); + } + } + } + } else { // GroupNorm + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + UpdateSum(params.src, offset, sum, sum_sq); } } - // The group that thread works on and the channel in the group (modulus). - int32_t gi = threadIdx.x * 2 / params.cPerGroup; - int32_t cj = threadIdx.x * 2 - params.cPerGroup * gi; + // The group index relative to the first group within the same block. + int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.channels_per_group; + // The channel in the group. + int32_t cj = ci % params.channels_per_group; // The data for the summations. - GroupSums inp{cj == 0 ? 1 : 0, sum, sumSq}; + GroupSums inp{cj == 0 ? 1 : 0, sum, sum_sq}; - // Do the segmented scan. + // Do the segmented scan. InclusiveScan is not deterministic. GroupSums out; - BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp()); + BlockScan(temp_storage).InclusiveScan(inp, out, GroupSumsOp()); - // Store the results for the groups in shared memory (to produce coalesced - // stores later). - if (cj == params.cPerGroup - 2) { //2 channels per thread - smem[gi] = make_float2(out.sum, out.sumSq); + // Store the results for the groups in shared memory (to produce coalesced stores later). + // For each group, only the last thread of that group is picked to save sum to shared memory. + if (cj == params.channels_per_group - CHANNELS_PER_THREAD) { + smem[gi] = make_float2(out.sum, out.sum_sq); } // Make sure the data is in shared memory. __syncthreads(); - // The global group index. - int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x; - // Threads that have nothing left to do, exit. - if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) { + if (threadIdx.x >= params.groups_per_block) { return; } - // The first threads (those storing to global memory, load the values). - float2 sums = smem[threadIdx.x]; - - // Store to global memory. - atomicAdd(¶ms.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x); - atomicAdd(¶ms.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y); + // The global group index. + // Use neighboring threads for coalesced write. + int32_t gj = blockIdx.x * params.groups_per_block + threadIdx.x; + + if (gj < params.groups) { + float2 sums = smem[threadIdx.x]; + const int index = (2 * ni) * params.groups + gj; + atomicAdd(¶ms.group_sum_buffer[index], sums.x); + atomicAdd(¶ms.group_sum_buffer[index + params.groups], sums.y); + } } template -void groupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the values are as we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0 && params.hw % params.hwPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCSum(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); + // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); + // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCSumKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCSumKernel<<>>(params); break; - case 480: - groupNormNHWCSumKernel<<>>(params); + case 192: + GroupNormNHWCSumKernel<<>>(params); break; - case 256: - groupNormNHWCSumKernel<<>>(params); + case 160: + GroupNormNHWCSumKernel<<>>(params); break; case 128: - groupNormNHWCSumKernel<<>>(params); + GroupNormNHWCSumKernel<<>>(params); + break; + case 64: + GroupNormNHWCSumKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } template -__device__ void computeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float invStdDev, float2& gammaF2, float2& betaF2, bool swish); +__device__ void ComputeGroupNorm(const T* src, T* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu); template <> -__device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const half* src, half* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. __half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]); @@ -245,15 +404,15 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo float2 f2 = __half22float2(h2); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -262,21 +421,21 @@ __device__ void computeGroupNorm(const half* src, half* dst, int64_t offset, flo } template <> -__device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float invStdDev, - float2& gammaF2, float2& betaF2, bool swish) { +__device__ void ComputeGroupNorm(const float* src, float* dst, int64_t offset, float mean, float inv_std_dev, + float2& gamma_f2, float2& beta_f2, bool silu) { // Fetch two channels per thread. float2 f2 = *reinterpret_cast(&src[offset]); // Normalize the channels. - f2.x = (f2.x - mean) * invStdDev; - f2.y = (f2.y - mean) * invStdDev; + f2.x = (f2.x - mean) * inv_std_dev; + f2.y = (f2.y - mean) * inv_std_dev; // Scale by gamma and add beta. - f2.x = gammaF2.x * f2.x + betaF2.x; - f2.y = gammaF2.y * f2.y + betaF2.y; + f2.x = gamma_f2.x * f2.x + beta_f2.x; + f2.y = gamma_f2.y * f2.y + beta_f2.y; - // Apply Swish if needed. - if (swish) { + // Apply SiLU activation if needed. + if (silu) { f2.x = f2.x * sigmoid(f2.x); f2.y = f2.y * sigmoid(f2.y); } @@ -284,110 +443,142 @@ __device__ void computeGroupNorm(const float* src, float* dst, int64_t offset, f *reinterpret_cast(&dst[offset]) = f2; } -template -__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) { - // The channel loaded by that thread (2 channels per thread for F16x2). - int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * 2; - if (ci >= params.c) { +template +__global__ void GroupNormNHWCScaleKernel(GroupNormNHWCParams params) { + // The channel loaded by that thread. + int32_t ci = blockIdx.x * params.channels_per_block + threadIdx.x * CHANNELS_PER_THREAD; + if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.channels_per_block) { return; } // The instance in the batch. int32_t ni = blockIdx.z; - // The group that thread works on and the channel in the group (modulus). - int32_t gi = ci / params.cPerGroup; + // The group that thread works on. + int32_t gi = ci / params.channels_per_group; // Load the sum and sum of squares for the group. - float sum = 0.F, sumSq = 0.F; + float sum = 0.F, sum_sq = 0.F; if (gi < params.groups) { - sum = params.redBuffer[(2 * ni + 0) * params.groups + gi]; - sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi]; + const int index = (2 * ni) * params.groups + gi; + sum = params.group_sum_buffer[index]; + sum_sq = params.group_sum_buffer[index + params.groups]; } - // Load gamma/beta. - float2 gammaF2 = *reinterpret_cast(¶ms.gamma[ci]); - float2 betaF2 = *reinterpret_cast(¶ms.beta[ci]); + // Load gamma/beta. Fetch two per thread. + float2 gamma_f2 = *reinterpret_cast(¶ms.gamma[ci]); + float2 beta_f2 = *reinterpret_cast(¶ms.beta[ci]); // Compute the mean. - float mean = sum * params.invHWC; + float mean = sum * params.inv_hw_channels_per_group; // Compute the variance. - float var = sumSq * params.invHWC - (mean * mean); + float var = sum_sq * params.inv_hw_channels_per_group - (mean * mean); // Compute the inverse of the stddev. - float invStdDev = var <= 0.F ? 1.F : rsqrtf(var); + float inv_std_dev = rsqrtf(var + params.epsilon); - // The first activation loaded by that block. - int32_t hwBegin = blockIdx.y * params.hwPerBlock; - // The last activation loaded by that block. - int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw); + int32_t hw_begin = blockIdx.y * params.hw_per_block; + int32_t hw_end = min(hw_begin + params.hw_per_block, params.hw); - // Iterate over the activations to compute the sums. - for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) { - // The src/dst offset. - int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci; - - // Fetch two channels per thread. - computeGroupNorm(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish); + const T* input = (params.skip != nullptr) ? params.skip_workspace : params.src; + int64_t offset = static_cast(ni) * params.hwc + static_cast(hw_begin) * params.c + ci; + for (int32_t hwi = hw_begin; hwi < hw_end; ++hwi, offset += params.c) { + ComputeGroupNorm(input, params.dst, offset, mean, inv_std_dev, gamma_f2, beta_f2, params.use_silu); } } template -void groupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { - // Make sure the dimensions are aligned with what we expect. - ORT_ENFORCE(params.c % params.cPerBlock == 0); - // Make sure a group does not span multiple blocks. - ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0); - +void GroupNormNHWCScale(GroupNormNHWCParams const& params, cudaStream_t stream) { dim3 grid; // The number of blocks to compute all the channels. - grid.x = params.c / params.cPerBlock; + grid.x = DivUp(params.c, params.channels_per_block); // The number of blocks to compute all the activations in a given instance. - grid.y = divUp(params.hw, params.hwPerBlock); + grid.y = DivUp(params.hw, params.hw_per_block); // The number of instances. grid.z = params.n; - switch (params.cPerBlock) { - case 320: - groupNormNHWCScaleKernel<<>>(params); + // Threads_per_block is half of values in kSizes since CHANNELS_PER_THREAD = 2. + switch (params.threads_per_block) { + case 256: + GroupNormNHWCScaleKernel<<>>(params); break; - case 480: - groupNormNHWCScaleKernel<<>>(params); + case 192: + GroupNormNHWCScaleKernel<<>>(params); break; - case 256: - groupNormNHWCScaleKernel<<>>(params); + case 160: + GroupNormNHWCScaleKernel<<>>(params); break; case 128: - groupNormNHWCScaleKernel<<>>(params); + GroupNormNHWCScaleKernel<<>>(params); + break; + case 64: + GroupNormNHWCScaleKernel<<>>(params); break; - default: - ORT_NOT_IMPLEMENTED("Not implemented"); } } -int32_t findMaxDivisor(int32_t n, int32_t maxAllowedDivisor) { - int32_t maxDivisor = -1; +int32_t FindMaxDivisor(int32_t n, int32_t max_allowed_divisor) { + int32_t max_divisor = -1; for (int32_t i = 1; i <= std::sqrt(n); i++) { if (n % i == 0) { int32_t divisor1 = n / i; int32_t divisor2 = i; - if (divisor1 > maxDivisor && divisor1 < maxAllowedDivisor) { - maxDivisor = divisor1; + if (divisor1 > max_divisor && divisor1 < max_allowed_divisor) { + max_divisor = divisor1; } - if (divisor2 > maxDivisor && divisor2 < maxAllowedDivisor) { - maxDivisor = divisor2; + if (divisor2 > max_divisor && divisor2 < max_allowed_divisor) { + max_divisor = divisor2; } } } - return maxDivisor; + return max_divisor; +} + +// Find proper channels per block based on a cost function: The cost is number of channels corresponding to +// extra threads allocated but no channels assigned to them to work on. If cost is zero, every thread has +// work to do so it is ideal case. +int FindChannelsPerBlock(int num_channels, int channels_per_group) { + int min_cost = -1; + int best_candidate = -1; + for (size_t i = kNumOfSizes; i > 0; --i) { + if (kSizes[i - 1] < channels_per_group) { + break; + } + + int channels_per_block = kSizes[i - 1] / channels_per_group * channels_per_group; + int blocks = (num_channels + channels_per_block - 1) / channels_per_block; + int cost = blocks * kSizes[i - 1] - num_channels; + if (cost == 0) { + return channels_per_block; + } + + if (min_cost == -1 || cost < min_cost) { + min_cost = cost; + best_candidate = channels_per_block; + } + } + + return best_candidate; +} + +int GetChannelsPerBlock(int num_channels, int num_groups) { + int32_t channels_per_group = num_channels / num_groups; + int32_t channels_per_block = channels_per_group; + if (channels_per_group < kMaxSize / 2) { + channels_per_block = FindChannelsPerBlock(num_channels, channels_per_group); + } + return channels_per_block; } template Status LaunchGroupNormKernel( cudaStream_t stream, T* output, + T* add_out, const T* input, + const T* skip, + const T* bias, const float* gamma, const float* beta, void* workspace, @@ -397,79 +588,94 @@ Status LaunchGroupNormKernel( int height, int width, int num_groups, - bool use_swish_activation) { - if (batch_size > static_cast(kMaxGroupNormBatchSize)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only support batch_size <= 32. Got", batch_size); - } + bool use_silu, + bool broadcast_skip, + int channels_per_block) { + GroupNormNHWCParams params; - if (num_groups != static_cast(kGroupNormNumberOfGroups)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, StatusCode::NOT_IMPLEMENTED, - "only num_groups=32 is supported. Got", num_groups); + int32_t channels_per_group = num_channels / num_groups; + // channels_per_block is computed in PrePack. + // If the gamma is not initializer, channels_per_block might be zero after PrePack. In that happens, compute it here. + if (channels_per_block < channels_per_group) { + channels_per_block = GetChannelsPerBlock(num_channels, num_groups); } - GroupNormNHWCParams params; - int32_t cPerBlock = 320; - int32_t maxBlocksPerHW = 1024; - switch (num_channels) { - case 960: - case 1920: - cPerBlock = 480; - break; - case 512: - case 256: - cPerBlock = 256; - break; - case 128: - cPerBlock = 128; - break; - default: - cPerBlock = 320; + // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases + if (channels_per_block % channels_per_group != 0 || + channels_per_block > kMaxSize || + (channels_per_group % CHANNELS_PER_THREAD != 0)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "GroupNorm in CUDA does not support the input: n=", batch_size, + " h=", height, + " w=", width, + " c=", num_channels, + " groups=", num_groups); } - params.withSwish = use_swish_activation; + params.use_silu = use_silu; params.dst = output; + params.add_out = add_out; params.src = input; + params.skip = skip; + params.bias = bias; params.gamma = gamma; params.beta = beta; - params.redBuffer = reinterpret_cast(workspace); + params.group_sum_buffer = reinterpret_cast(workspace); params.n = batch_size; params.h = height; params.w = width; params.c = num_channels; params.groups = num_groups; params.hw = params.h * params.w; - const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW); - params.hwPerBlock = divUp(params.hw, blocksPerHW); - params.cPerBlock = cPerBlock; - params.cPerGroup = params.c / params.groups; + + // This will allocate as many blocks as possible to partition HW. + // For Stable Diffusion, latent hw is 4K ~ 16K. This will allocate 1024 blocks, and each handles 4~16 hw. + // TODO: tune this logic to find proper blocks when hw is small. + constexpr int32_t max_blocks_per_hw = 1024; + const int32_t blocks_per_hw = FindMaxDivisor(params.hw, max_blocks_per_hw); + params.hw_per_block = DivUp(params.hw, blocks_per_hw); + + params.channels_per_block = channels_per_block; + params.channels_per_group = channels_per_group; params.hwc = params.hw * params.c; - params.invHWC = 1.F / (float)(params.hw * params.cPerGroup); - params.groupsPerBlock = cPerBlock / params.cPerGroup; + params.inv_hw_channels_per_group = 1.F / (float)(params.hw * params.channels_per_group); + params.groups_per_block = channels_per_block / params.channels_per_group; + params.epsilon = epsilon; + params.broadcast_skip = broadcast_skip; - DUMP_TENSOR_INIT(); - DUMP_TENSOR("input", input, batch_size, num_channels, height * width); - DUMP_TENSOR("gamma", gamma, 1, num_channels); - DUMP_TENSOR("beta", beta, 1, num_channels); - cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(), stream); - groupNormNHWCSum(params, stream); - DUMP_TENSOR("workspace", params.redBuffer, batch_size, num_groups, 2); + // Workspace for SkipGroupNorm to store intermediate results of src+skip+bias. + params.skip_workspace = (params.add_out != nullptr) ? params.add_out : params.dst; + + params.threads_per_block = NextSize(channels_per_block) / CHANNELS_PER_THREAD; + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync( + params.group_sum_buffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream)); + + GroupNormNHWCSum(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - groupNormNHWCScale(params, stream); + + DUMP_TENSOR_INIT(); + DUMP_TENSOR("workspace", params.group_sum_buffer, batch_size, 2, num_groups); + + GroupNormNHWCScale(params, stream); CUDA_RETURN_IF_ERROR(cudaGetLastError()); - DUMP_TENSOR("output", output, batch_size, num_channels, height * width); + return Status::OK(); } -template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, - const half* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, half* output, half* add_out, + const half* input, const half* skip, const half* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); -template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, - const float* input, const float* gamma, const float* beta, void* workspace, +template Status LaunchGroupNormKernel(cudaStream_t stream, float* output, float* add_out, + const float* input, const float* skip, const float* bias, + const float* gamma, const float* beta, void* workspace, float epsilon, int batch_size, int num_channels, - int height, int width, int num_groups, bool swish); + int height, int width, int num_groups, bool silu, + bool broadcast_skip, int channels_per_block); } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h index c7e9245050ee6..9532aeecb2f57 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h +++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.h @@ -12,29 +12,33 @@ namespace onnxruntime { namespace contrib { namespace cuda { -constexpr size_t kMaxGroupNormBatchSize = 32; -constexpr size_t kGroupNormNumberOfGroups = 32; - -constexpr size_t GetGroupNormWorkspaceSizeInBytes() { +constexpr size_t GetGroupNormWorkspaceSizeInBytes(size_t batch_size, size_t num_groups) { // Two buffers for sum and squared sum - return (sizeof(float) * 2) * kMaxGroupNormBatchSize * kGroupNormNumberOfGroups; + return (sizeof(float) * 2) * batch_size * num_groups; } +int GetChannelsPerBlock(int num_channels, int num_groups); + template Status LaunchGroupNormKernel( cudaStream_t stream, - T* output, // normalized output tensor - const T* input, // input tensor - const float* gamma, // gamma (also known as weight or scale) - const float* beta, // beta (also known as bias) - void* workspace, // Work space - float epsilon, // epsilon used normalization - int batch_size, // N - int num_channels, // C - int height, // H - int width, // W - int num_groups, // number of groups - bool use_swish_activation // Whether there is Swish activation after group normalization + T* output, // normalized output tensor. Shape is (n, h, w, c) + T* add_out, // optional output tensor for element-wise sum of input + skip + bias. Shape is (n, h, w, c) + const T* input, // input tensor. Shape is (n, h, w, c) + const T* skip, // optional skip tensor. Shape is (n, h, w, c) + const T* bias, // optional bias tensor. Shape is (c) for SkipGroupNorm or (n, c) for BiasGroupNorm + const float* gamma, // gamma (also known as weight or scale). Shape is (c) + const float* beta, // beta (also known as bias). Shape is (c) + void* workspace, // Work space + float epsilon, // epsilon used normalization + int batch_size, // N + int num_channels, // C + int height, // H + int width, // W + int num_groups, // number of groups + bool use_silu, // Whether there is Sigmoid Linear Unit (SiLU) activation after group normalization + bool broadcast_skip, // Whether skip need broadcast. When skip has shape (n, c) or (n, 1, 1, c), it need broadcast. + int channels_per_block // Pre-computed channels per block. ); } // namespace cuda diff --git a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc index c665da89af36c..e82e15a304f4c 100644 --- a/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc +++ b/onnxruntime/contrib_ops/rocm/diffusion/group_norm.cc @@ -72,6 +72,12 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : RocmKernel(op_info) { channels_last_ = (op_info.GetAttrOrDefault("channels_last", static_cast(1)) != 0); } +Status GroupNorm::PrePack(const Tensor& /*tensor*/, int /*input_idx*/, AllocatorPtr /*alloc*/, + bool& is_packed, PrePackedWeights* /*prepacked_weights*/) { + is_packed = false; + return Status::OK(); +} + Status GroupNorm::ComputeInternal(OpKernelContext* context) const { const Tensor* input = context->Input(0); const Tensor* gamma = context->Input(1); diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc index c2f5edaa6149b..f81c3b8e0182c 100644 --- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc @@ -42,7 +42,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "The number of groups of channels. It should be a divisor of the number of channels C", AttributeProto::INT) .Attr("activation", - "Activation after group normalization: 0 for None, 1 for Swish", + "Activation after group normalization: 0 for None, 1 for SiLU", AttributeProto::INT) .Attr("channels_last", "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", @@ -68,6 +68,85 @@ ONNX_MS_OPERATOR_SET_SCHEMA( .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput)); +constexpr const char* SkipGroupNorm_ver1_doc = R"DOC( +This operator element-wise adds x, skip and bias, then apply group normalization and optional activation. + +This operator transforms input according to + s = x + skip + bias + y = gamma * (s - mean) / sqrt(variance + epsilon) + beta + +The input channels are separated into num_groups groups, each containing num_channels / num_groups channels. +The num_channels must be divisible by num_groups. +The mean and standard-deviation of s are calculated separately over the each group. +The weight and bias are per-channel affine transform parameter vectors of size num_channels. + +The activation attribute can be used to enable activation after group normalization. +)DOC"; + +ONNX_MS_OPERATOR_SET_SCHEMA( + SkipGroupNorm, 1, + OpSchema() + .SetDoc(SkipGroupNorm_ver1_doc) + .Attr("epsilon", "The epsilon value to use to avoid division by zero", + AttributeProto::FLOAT, static_cast(1e-5)) + .Attr("groups", + "The number of groups of channels. It should be a divisor of the number of channels C", + AttributeProto::INT) + .Attr("activation", + "Activation after group normalization: 0 for None, 1 for SiLU", + AttributeProto::INT) + .Attr("channels_last", + "1 if the input and output are in the NHWC layout, 0 if it is in the NCHW layout. Defaults to 1.", + AttributeProto::INT, + static_cast(1)) + .Input(0, + "X", + "Input data tensor. Dimensions are (N x H x W x C) when channels_last is 1 " + " or (N x C x H x W) otherwise, where N is the batch size, C is the number of channels," + " and H and W are the height and width of the data", + "T") + .Input(1, + "gamma", + "1D gamma tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(2, + "beta", + "1D beta tensor for normalization with shape (C), where C is number of channels", + "M") + .Input(3, + "skip", + "4D or 2D skip tensor. The shape can be (N x H x W x C) or (N x 1 x 1 x C) or (N x C)", + "T") + .Input(4, + "bias", + "1D bias tensor. Dimensions are (C), where C is number of channels", + "T", + OpSchema::Optional) + .Output(0, + "Y", + "The output tensor of the same shape as X", + "T") + .Output(1, + "S", + "The element-wise sum of input x, skip and bias tensors. It has the same shape as X", + "T", + OpSchema::Optional) + .TypeConstraint("T", {"tensor(float16)", "tensor(float)"}, "Constrain input X, skip, bias and output Y, S types to float tensors.") + .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateElemTypeFromInputToOutput(ctx, 0, 1); + } + + if (hasInputShape(ctx, 0)) { + propagateShapeFromInputToOutput(ctx, 0, 0); + if (ctx.getNumOutputs() > 1) { + propagateShapeFromInputToOutput(ctx, 0, 1); + } + } + })); + constexpr const char* BiasSplitGelu_ver1_doc = R"DOC( A fusion used in diffusion model that after adding bias, hidden state is sliced into two tensors of same size, then left tensor multiplies the Gelu activation result of right tensor. diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index aa31f3b5a7c62..b35cfc5d12f36 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -98,6 +98,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipGroupNorm); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); @@ -205,6 +206,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index ef1c46b83946a..9b68aef57656e 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -200,6 +200,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GemmFastGelu": self._infer_GemmFastGelu, "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, + "SkipGroupNorm": self._infer_SkipGroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, "MultiHeadAttention": self._infer_MultiHeadAttention, @@ -2376,6 +2377,11 @@ def _infer_SkipLayerNormalization(self, node): # noqa: N802 def _infer_GroupNorm(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_SkipGroupNorm(self, node): # noqa: N802 + self._propagate_shape_and_type(node, 0, 0) + if len(node.output) > 1: + self._propagate_shape_and_type(node, 0, 1) + def _infer_BiasSplitGelu(self, node): # noqa: N802 input_shape = self._get_shape(node, 0) bias_shape = self._get_shape(node, 1) diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index de17f195c99cc..50703b9c17e03 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,6 +1,6 @@ import logging from collections import OrderedDict -from typing import Any, Dict, List +from typing import Any, Dict, List, Tuple, Union import numpy import torch @@ -229,7 +229,7 @@ def __del__(self): del self.io_binding del self.ort_session - def allocate_buffers(self, shape_dict: Dict[str, tuple]): + def allocate_buffers(self, shape_dict: Dict[str, Union[Tuple[int], List[int]]]): """Allocate tensors for I/O Binding""" if self.enable_cuda_graph: for name, shape in shape_dict.items(): diff --git a/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc new file mode 100644 index 0000000000000..fefd5722054de --- /dev/null +++ b/onnxruntime/test/contrib_ops/skip_group_norm_op_test.cc @@ -0,0 +1,286 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/providers/provider_test_utils.h" + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +using namespace std; + +namespace onnxruntime { +namespace test { + +TEST(SkipGroupNormTest, SkipGroupNorm_with_bias) { + constexpr int64_t B = 2; + constexpr int64_t C = 16; + constexpr int64_t H = 2; + constexpr int64_t W = 2; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + -0.768555f, 1.575195f, -0.698242f, 1.587891f, 0.371826f, -0.280029f, -1.328125f, 0.127197f, + -0.197144f, 0.982422f, -0.671387f, -1.925781f, 1.800781f, -0.020218f, -0.782227f, 1.291992f, + -0.935059f, 1.782227f, -0.674316f, -1.943359f, -0.218994f, 0.054138f, -1.539062f, -0.546387f, + -2.160156f, 1.195312f, 1.653320f, -0.674316f, 0.224731f, -0.093262f, 1.160156f, -0.389404f, + 1.748047f, 0.766113f, 0.234375f, 0.011177f, -0.055847f, -0.930664f, -0.490234f, -0.655762f, + -0.382568f, -0.554688f, 0.910645f, -0.227295f, 1.687500f, 0.028397f, -0.241699f, -0.480957f, + -0.355713f, -2.095703f, -0.443359f, -0.126221f, -0.815918f, 0.792969f, -0.450439f, -0.952148f, + -1.174805f, 0.242798f, 0.138550f, -0.237061f, -0.994141f, 0.346436f, 0.147705f, 0.125854f, + -0.517090f, 0.253906f, 0.400146f, -0.540039f, -0.788574f, 0.146606f, -0.409668f, 0.281982f, + 1.444336f, 0.044434f, -0.366699f, 2.250000f, -0.453613f, -0.652344f, 1.828125f, -0.244751f, + 0.307129f, -0.051361f, 0.106384f, 0.844727f, 1.648438f, -0.904785f, -0.353760f, 0.510742f, + 0.074829f, -0.311279f, 0.274902f, 1.594727f, 1.367188f, 0.098755f, 0.043304f, -0.207397f, + 0.068298f, -0.601074f, 0.083008f, 0.264893f, -0.659180f, -0.216797f, -0.086548f, -0.683594f, + -0.964844f, -2.591797f, -0.817383f, -0.461914f, -1.840820f, -0.712402f, -0.052094f, -0.583008f, + 1.114258f, 0.190308f, 1.087891f, 0.005146f, 1.041992f, 1.363281f, -0.273682f, -0.465576f, + -0.027618f, 1.345703f, 0.789551f, -0.015991f, 0.401611f, 0.726562f, 0.598633f, 0.133667f}; + + std::vector gamma_data = { + 0.241255f, 0.556660f, -0.835532f, 0.564596f, -1.338308f, -0.278924f, 0.357326f, -1.745484f, + 0.277184f, 0.101415f, -0.018637f, -0.526188f, -0.011698f, -2.349411f, 0.206578f, 0.357679f}; + + std::vector beta_data = { + -1.194839f, 0.209146f, -0.677225f, -0.547338f, 1.275685f, -1.099577f, 0.470916f, 0.293907f, + -1.094209f, 2.350204f, -1.633769f, 0.248753f, -0.180166f, 0.365134f, -0.555731f, 1.843083f}; + + std::vector skip_data_nhwc = { + 0.892578f, -0.471924f, -0.423096f, 1.277344f, 0.257080f, -1.366211f, 1.552734f, 0.441406f, + -0.033142f, -0.059418f, 1.536133f, -0.225464f, 1.472656f, 0.591309f, -0.386230f, -2.197266f, + 0.089600f, -0.256592f, -1.873047f, 0.916992f, 0.392090f, 0.015526f, -0.949219f, 0.566895f, + -0.220459f, 1.262695f, -0.437744f, -2.283203f, -0.264893f, -0.660156f, 2.353516f, 1.992188f, + 0.865723f, -0.854004f, -1.014648f, 0.899414f, -1.041016f, 1.378906f, -0.075073f, -2.541016f, + -0.883789f, -0.428711f, 0.981934f, -0.072754f, 2.214844f, 0.658203f, 0.170166f, -1.727539f, + -0.672363f, -1.373047f, 0.318115f, 0.422363f, 0.260742f, -0.547852f, 0.545898f, -0.155762f, + 0.679688f, 2.861328f, -0.300781f, -0.504883f, 1.548828f, 0.353760f, -0.387695f, -1.595703f, + -0.170166f, -0.002897f, 0.273193f, -0.383545f, -1.082031f, -0.894043f, -1.048828f, -0.044708f, + 0.049286f, 0.220215f, 0.272705f, -0.853027f, -0.489258f, 0.513672f, 0.977051f, 0.310547f, + -0.577148f, -0.479004f, 0.838867f, 0.872559f, -0.510254f, 0.101807f, -0.299805f, -1.179688f, + -1.555664f, 0.668457f, 0.939453f, 0.118103f, -0.376709f, 0.735352f, -0.214233f, -1.987305f, + -0.931152f, 1.268555f, 1.427734f, -0.757812f, -1.324219f, 0.375488f, 1.364258f, -1.708008f, + 0.976562f, -0.037659f, -1.779297f, -0.196655f, 1.636719f, 0.690430f, 0.941895f, -1.882812f, + 0.431641f, 0.203857f, 1.306641f, -0.126343f, 1.408203f, 1.188477f, 0.432861f, -2.296875f, + -0.475342f, 1.517578f, -0.824219f, 1.288086f, -0.028244f, 1.918945f, 0.352295f, 0.693359f}; + + std::vector bias_data = { + -0.537598f, 0.500488f, -0.252441f, -0.460693f, -1.640625f, -1.298828f, 0.331787f, -1.588867f, + 1.000977f, 1.458984f, 0.702637f, 0.147827f, 1.143555f, 0.533691f, -0.072510f, 0.511230f}; + + std::vector norm_data_nhwc = { + -1.213867f, 0.856445f, -0.119141f, 0.386475f, 0.714355f, -0.804688f, + 1.048828f, -0.426270f, -1.091797f, 2.435547f, -1.641602f, 0.989746f, + -0.200928f, 0.267334f, -0.800781f, 1.577148f, -1.357422f, 1.000977f, + 0.613281f, -0.963867f, 1.179688f, -1.169922f, 0.308350f, 0.304199f, + -1.396484f, 2.513672f, -1.644531f, 1.206055f, -0.180664f, 1.896484f, + -0.294678f, 2.046875f, -0.844238f, 0.448486f, -0.294189f, -0.291504f, + 2.480469f, -1.250977f, 0.833008f, 4.593750f, -1.238281f, 2.335938f, + -1.651367f, 0.491943f, -0.204834f, 0.125610f, -0.682129f, 1.333984f, + -1.384766f, -0.708008f, -0.630859f, -0.504883f, 1.924805f, -1.208008f, + 1.013672f, 1.809570f, -1.128906f, 2.546875f, -1.631836f, 0.610840f, + -0.184326f, 0.110046f, -0.700195f, 1.471680f, -1.511719f, 0.492188f, + -0.847168f, -1.373047f, 2.837891f, -0.998047f, 0.521484f, 0.262207f, + -0.810547f, 2.400391f, -1.628906f, 0.049896f, -0.174927f, 1.076172f, + -0.252197f, 1.784180f, -1.418945f, 0.090820f, -1.056641f, 0.002945f, + 0.627441f, -0.989746f, 0.679199f, 1.130859f, -1.371094f, 2.408203f, + -1.645508f, -0.062988f, -0.192017f, -0.655762f, -0.718262f, 1.170898f, + -1.550781f, 0.706055f, -1.492188f, -1.148438f, 2.921875f, -1.136719f, + 1.058594f, 2.781250f, -1.089844f, 2.201172f, -1.597656f, 0.785645f, + -0.181396f, 0.868164f, -0.552246f, 1.097656f, -1.015625f, 0.565430f, + -2.173828f, -0.955078f, -0.336426f, -1.503906f, 0.838867f, 3.136719f, + -1.186523f, 2.580078f, -1.629883f, 0.094604f, -0.186523f, -3.884766f, + -0.542480f, 1.990234f}; + + std::vector add_out_data_nhwc = { + -0.414062f, 1.604492f, -1.374023f, 2.404297f, -1.011719f, -2.945312f, 0.556641f, -1.020508f, + 0.770508f, 2.382812f, 1.567383f, -2.003906f, 4.417969f, 1.105469f, -1.240234f, -0.394531f, + -1.382812f, 2.027344f, -2.800781f, -1.487305f, -1.466797f, -1.229492f, -2.156250f, -1.568359f, + -1.379883f, 3.917969f, 1.917969f, -2.808594f, 1.103516f, -0.219727f, 3.441406f, 2.113281f, + 2.076172f, 0.412598f, -1.033203f, 0.449951f, -2.738281f, -0.851562f, -0.233521f, -4.785156f, + -0.265625f, 0.475586f, 2.595703f, -0.152222f, 5.046875f, 1.220703f, -0.144043f, -1.697266f, + -1.566406f, -2.968750f, -0.377686f, -0.164551f, -2.195312f, -1.053711f, 0.427246f, -2.697266f, + 0.505859f, 4.562500f, 0.540527f, -0.594238f, 1.698242f, 1.233398f, -0.312500f, -0.958496f, + -1.224609f, 0.751465f, 0.420898f, -1.384766f, -3.511719f, -2.046875f, -1.126953f, -1.351562f, + 2.494141f, 1.724609f, 0.608398f, 1.544922f, 0.200684f, 0.395020f, 2.732422f, 0.577148f, + -0.807617f, -0.029785f, 0.692871f, 1.256836f, -0.502441f, -2.101562f, -0.321777f, -2.257812f, + -0.479492f, 1.816406f, 1.916992f, 1.860352f, 2.134766f, 1.367188f, -0.243408f, -1.683594f, + -1.400391f, 1.167969f, 1.257812f, -0.953613f, -3.625000f, -1.140625f, 1.609375f, -3.980469f, + 1.012695f, -1.170898f, -1.894531f, -0.510742f, 0.939453f, 0.511719f, 0.817383f, -1.955078f, + 1.007812f, 0.894531f, 2.142578f, -0.582031f, 0.809570f, 1.252930f, 0.490967f, -4.351562f, + 0.497803f, 4.320312f, 0.667969f, 1.419922f, 1.516602f, 3.179688f, 0.878906f, 1.337891f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array channels_last_values = {-1, 1}; + + for (const int channels_last : channels_last_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 4); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + test.AddInput("skip", dims_nhwc, ToFloat16(skip_data_nhwc)); + test.AddInput("bias", {C}, ToFloat16(bias_data)); + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } +} + +TEST(SkipGroupNormTest, SkipGroupNorm_no_bias_broadcast_skip) { + constexpr int64_t B = 1; + constexpr int64_t C = 64; + constexpr int64_t H = 1; + constexpr int64_t W = 1; + + std::vector dims_nhwc{B, H, W, C}; + std::vector input_data_nhwc = { + 0.588867f, 0.896484f, -0.213623f, 0.803223f, 0.659180f, -0.216187f, 1.197266f, -0.486084f, + -0.718750f, 0.332031f, -0.364746f, -0.831543f, -0.031219f, -1.059570f, 0.161621f, 1.519531f, + 0.169312f, 1.048828f, 1.330078f, 0.450195f, -2.867188f, -1.456055f, 0.708496f, -1.120117f, + -1.208984f, -1.199219f, -1.505859f, -0.549316f, 0.505371f, 0.723145f, -0.359131f, -0.250977f, + -0.879883f, -0.305664f, 0.709473f, 0.815430f, 0.617676f, -0.638672f, 0.066772f, -2.330078f, + -1.316406f, 1.744141f, 1.122070f, -0.633789f, -1.802734f, -0.825684f, 0.622559f, -0.481689f, + -1.364258f, -0.536621f, -0.464111f, 0.247437f, -0.213989f, 0.384521f, 0.556641f, -0.303711f, + -0.160034f, 0.882324f, -0.212036f, -0.796387f, 0.153076f, -1.311523f, 2.212891f, 0.685059f}; + + std::vector gamma_data = { + 0.789682f, 0.869051f, -0.010169f, -0.021685f, 0.506611f, 1.267444f, -0.312695f, 0.877844f, + 0.598637f, 0.598314f, -1.721544f, -0.593328f, 0.986705f, -0.419391f, -0.852584f, -0.572351f, + 0.912797f, -0.586863f, 0.477761f, -0.484418f, -0.193835f, 0.347757f, 0.327637f, -1.100304f, + 1.233108f, -0.272569f, -0.688656f, 0.687245f, 0.398386f, 0.888089f, -0.792587f, -0.769029f, + -0.427778f, 0.100768f, -2.187060f, 1.279301f, 1.109054f, 0.375992f, 1.514775f, 1.271436f, + 0.822896f, -0.476750f, 0.475507f, -1.011297f, 1.177197f, 1.586540f, -1.059944f, -0.145351f, + 0.841555f, -2.014113f, -0.230498f, 0.302128f, -0.180508f, 0.980534f, -0.126871f, 0.203151f, + -0.754841f, 0.420570f, -1.085798f, 1.335042f, -0.674930f, 2.453507f, 2.139259f, 1.087436f}; + + std::vector beta_data = { + -0.064518f, -0.262683f, 0.827528f, -0.960938f, 1.062519f, 2.417941f, 0.212789f, -1.638430f, + 1.875453f, -0.883058f, -0.006704f, 0.424894f, -0.869972f, 0.727008f, 0.879303f, -3.024141f, + -2.610873f, 1.269641f, 0.883006f, 0.804167f, -1.510324f, 2.258091f, -0.006750f, -1.553668f, + -1.659453f, 0.579603f, 0.652358f, 0.007077f, 0.099180f, 0.418658f, -0.273778f, -1.036199f, + -1.128691f, -0.296022f, -0.224056f, 1.476306f, 0.577624f, -0.372049f, -0.581659f, -1.841807f, + -0.361721f, 0.051160f, -0.749332f, -2.634807f, 0.562719f, -0.738667f, 0.024864f, -1.135937f, + -1.368144f, -1.458886f, -0.946683f, 1.953936f, -1.198661f, 0.166648f, 0.447206f, -0.458140f, + -0.553395f, 0.112900f, 0.255989f, -0.184551f, 1.254163f, -0.260479f, -1.232429f, 1.902575f}; + + std::vector skip_data = { + 0.952148f, 1.342773f, -0.172974f, -0.395264f, 1.119141f, 0.330566f, + 0.281494f, 0.472900f, -0.692871f, -0.634766f, 0.013504f, -1.866211f, + -0.428223f, 0.669922f, -0.323486f, 0.713867f, -0.350586f, 0.659180f, + -0.288574f, 0.324219f, -0.300781f, -0.789551f, -0.216431f, -0.221436f, + -0.086670f, 0.366211f, -0.643555f, -0.977051f, 0.001021f, 0.415527f, + -0.271729f, 0.836426f, 0.035370f, -0.806152f, 0.936035f, -0.021332f, + -1.095703f, 0.971680f, 1.648438f, 0.840820f, 0.837402f, 0.607910f, + -1.894531f, 0.666016f, -0.171143f, 1.625977f, -0.620117f, -0.039581f, + 1.702148f, -2.410156f, 1.565430f, -0.756348f, 1.446289f, 0.583496f, + -0.497559f, -0.271729f, -0.956055f, -1.642578f, 0.833496f, -1.136719f, + 1.248047f, -2.515625f, 0.080383f, 0.376221f}; + + std::vector norm_data_nhwc = { + 0.494873f, 1.017578f, 0.841797f, -0.949219f, 1.552734f, 1.333984f, 0.012703f, -2.511719f, + 1.424805f, -0.818359f, -0.128418f, 1.462891f, -0.882812f, 0.709961f, 0.693848f, -4.210938f, + -2.505859f, 0.513184f, 1.300781f, 0.460938f, -1.172852f, 1.851562f, 0.167969f, -0.885254f, + -2.535156f, 0.656738f, 1.683594f, -0.627441f, 0.478271f, 1.782227f, -0.196777f, -1.824219f, + -0.791016f, -0.398682f, -3.197266f, 2.275391f, 0.052704f, -0.286865f, 1.567383f, -3.552734f, + -0.646973f, -0.927734f, -1.032227f, -2.722656f, -1.337891f, 0.432129f, -0.040253f, -1.080078f, + -1.118164f, 3.123047f, -1.153320f, 1.843750f, -1.378906f, 0.941406f, 0.437256f, -0.542969f, + -0.218872f, 0.006115f, -0.265869f, -1.356445f, 0.649902f, -4.882812f, 1.696289f, 2.679688f}; + + std::vector add_out_data_nhwc = { + 1.541016f, 2.238281f, -0.386719f, 0.407959f, 1.778320f, 0.114380f, + 1.478516f, -0.013184f, -1.412109f, -0.302734f, -0.351318f, -2.697266f, + -0.459473f, -0.389648f, -0.161865f, 2.234375f, -0.181274f, 1.708008f, + 1.041016f, 0.774414f, -3.167969f, -2.246094f, 0.492188f, -1.341797f, + -1.295898f, -0.833008f, -2.148438f, -1.526367f, 0.506348f, 1.138672f, + -0.630859f, 0.585449f, -0.844727f, -1.111328f, 1.645508f, 0.793945f, + -0.478027f, 0.333008f, 1.714844f, -1.489258f, -0.479004f, 2.351562f, + -0.772461f, 0.032227f, -1.973633f, 0.800293f, 0.002441f, -0.521484f, + 0.337891f, -2.947266f, 1.101562f, -0.508789f, 1.232422f, 0.967773f, + 0.059082f, -0.575195f, -1.116211f, -0.760254f, 0.621582f, -1.933594f, + 1.401367f, -3.828125f, 2.292969f, 1.061523f}; + + int min_cuda_architecture = 530; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + std::array has_add_out_values = {true, false}; + std::array skip_dims = {2, 4}; + + constexpr int channels_last = 1; + for (const int skip_dim : skip_dims) { + for (const bool has_add_out : has_add_out_values) { + if (enable_cuda) { + std::vector> execution_providers; + if (enable_cuda && channels_last != 0) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + + // Don't run the test if no providers are supported + if (execution_providers.empty()) { + continue; + } + + OpTester test("SkipGroupNorm", 1, onnxruntime::kMSDomain); + test.AddAttribute("epsilon", 1e-05f); + test.AddAttribute("groups", 8); + test.AddAttribute("activation", 0); + + // We interpret channels_last==-1 as the attribute not being provided + if (channels_last != -1) { + test.AddAttribute("channels_last", channels_last); + } + + test.AddInput("X", dims_nhwc, ToFloat16(input_data_nhwc)); + test.AddInput("gamma", {C}, gamma_data); + test.AddInput("beta", {C}, beta_data); + if (skip_dim == 2) { + test.AddInput("skip", {B, C}, ToFloat16(skip_data)); + } else { + test.AddInput("skip", {B, 1, 1, C}, ToFloat16(skip_data)); + } + // no bias + + constexpr float rel_error = 0.0f; + constexpr float abs_error = 0.02f; + test.AddOutput("Y", dims_nhwc, ToFloat16(norm_data_nhwc), false, rel_error, abs_error); + + if (has_add_out) { + test.AddOutput("S", dims_nhwc, ToFloat16(add_out_data_nhwc), false, rel_error, abs_error); + } + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + } + } +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py new file mode 100644 index 0000000000000..bf295a65c8b53 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_group_norm.py @@ -0,0 +1,541 @@ +# -------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +import statistics +from dataclasses import dataclass +from enum import Enum +from time import perf_counter +from typing import Optional, Tuple + +import numpy +import torch +from onnx import TensorProto, helper + +from onnxruntime import InferenceSession +from onnxruntime.transformers.io_binding_helper import CudaSession + +torch.manual_seed(0) + + +class GroupNormOpType(Enum): + GROUP_NORM = 1 + SKIP_GROUP_NORM = 2 + + +@dataclass +class GroupNormConfig: + batch_size: int + height: int + width: int + channels: int + epsilon: float = 1e-5 + num_groups: int = 32 + activation: bool = False + channels_last: bool = True + fp16: bool = False + + op_type: GroupNormOpType = GroupNormOpType.GROUP_NORM + has_bias: bool = False + has_add_out: bool = False + broadcast_skip: int = 0 # 2 for (N, C), 4 for (N, 1, 1, C) + + def get_skip_symbolic_shape(self): + skip_shape = {0: ["N", "H", "W", "C"], 2: ["N", "C"], 4: ["N", 1, 1, "C"]} + return skip_shape[self.broadcast_skip] + + def get_skip_shape(self): + skip_shape = { + 0: [self.batch_size, self.height, self.width, self.channels], + 2: [self.batch_size, self.channels], + 4: [self.batch_size, 1, 1, self.channels], + } + return skip_shape[self.broadcast_skip] + + def broadcast(self, skip: torch.Tensor): + if self.broadcast_skip == 2: + return skip.reshape(self.batch_size, 1, 1, self.channels) + + return skip + + @staticmethod + def create( + b: int, + h: int, + w: int, + c: int, + fp16: bool = False, + activation: bool = False, + template: int = 0, + num_groups: int = 32, + ): + if template == 0: + return GroupNormConfig( + b, h, w, c, fp16=fp16, activation=activation, op_type=GroupNormOpType.GROUP_NORM, num_groups=num_groups + ) + + if template == 1: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 2: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=2, + num_groups=num_groups, + ) + + if template == 3: + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=True, + has_add_out=False, + broadcast_skip=4, + num_groups=num_groups, + ) + + if template == 4: # No bias + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=True, + broadcast_skip=0, + num_groups=num_groups, + ) + + if template == 5: # No bias, no add_out + return GroupNormConfig( + b, + h, + w, + c, + fp16=fp16, + activation=activation, + op_type=GroupNormOpType.SKIP_GROUP_NORM, + has_bias=False, + has_add_out=False, + broadcast_skip=0, + num_groups=num_groups, + ) + + return None + + +def create_group_norm_graph(config: GroupNormConfig) -> bytes: + inputs = ["input", "gamma", "beta"] + outputs = ["output"] + op_type = "GroupNorm" + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + op_type = "SkipGroupNorm" + inputs = [*inputs, "skip"] + if config.has_bias: + inputs = [*inputs, "bias"] + if config.has_add_out: + outputs = [*outputs, "add_out"] + + nodes = [ + helper.make_node( + op_type, + inputs, + outputs, + op_type + "_0", + activation=int(config.activation), + channels_last=int(config.channels_last), + epsilon=config.epsilon, + groups=config.num_groups, + domain="com.microsoft", + ), + ] + + float_type = TensorProto.FLOAT16 if config.fp16 else TensorProto.FLOAT + + input_shapes = [ + helper.make_tensor_value_info("input", float_type, ["N", "H", "W", "C"]), + helper.make_tensor_value_info("gamma", TensorProto.FLOAT, ["C"]), + helper.make_tensor_value_info("beta", TensorProto.FLOAT, ["C"]), + ] + output_shapes = [ + helper.make_tensor_value_info("output", float_type, ["N", "H", "W", "C"]), + ] + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + input_shapes = [ + *input_shapes, + helper.make_tensor_value_info("skip", float_type, config.get_skip_symbolic_shape()), + ] + if config.has_bias: + input_shapes = [*input_shapes, helper.make_tensor_value_info("bias", float_type, ["C"])] + if config.has_add_out: + output_shapes = [*output_shapes, helper.make_tensor_value_info("add_out", float_type, ["N", "H", "W", "C"])] + + graph = helper.make_graph( + nodes, + "Group_Norm_Graph", + input_shapes, + output_shapes, + ) + + model = helper.make_model(graph) + return model.SerializeToString() + + +def group_norm_ort( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, + measure_latency=False, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[float]]: + onnx_model_str = create_group_norm_graph(config) + ort_session = InferenceSession(onnx_model_str, providers=["CUDAExecutionProvider"]) + + session = CudaSession(ort_session, device=torch.device("cuda:0")) + + io_shape = { + "input": [config.batch_size, config.height, config.width, config.channels], + "gamma": [config.channels], + "beta": [config.channels], + "output": [config.batch_size, config.height, config.width, config.channels], + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + io_shape["skip"] = config.get_skip_shape() + if config.has_bias: + io_shape["bias"] = [config.channels] + if config.has_add_out: + io_shape["add_out"] = [config.batch_size, config.height, config.width, config.channels] + + session.allocate_buffers(io_shape) + + ort_inputs = { + "input": src, + "gamma": gamma, + "beta": beta, + } + + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + ort_inputs["skip"] = skip + if config.has_bias: + ort_inputs["bias"] = bias + + ort_outputs = session.infer(ort_inputs) + output = ort_outputs["output"] + + add_out = ( + ort_outputs["add_out"] if config.op_type == GroupNormOpType.SKIP_GROUP_NORM and config.has_add_out else None + ) + + if measure_latency: + latency_list = [] + for _ in range(10000): + start_time = perf_counter() + session.infer(ort_inputs) + end_time = perf_counter() + latency_list.append(end_time - start_time) + average_latency = statistics.mean(latency_list) + return output, add_out, average_latency + + return output, add_out, None + + +def group_norm_torch( + src: torch.Tensor, + gamma: torch.Tensor, + beta: torch.Tensor, + skip: Optional[torch.Tensor], + bias: Optional[torch.Tensor], + config: GroupNormConfig, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + add_out = src + + if skip is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + config.broadcast(skip) + + if bias is not None: + assert config.op_type == GroupNormOpType.SKIP_GROUP_NORM + add_out = add_out + bias.reshape(1, 1, 1, bias.shape[0]) + + x = add_out + if config.channels_last: + x = add_out.clone().permute(0, 3, 1, 2) # from NHWC to NCHW + + weight = gamma.to(x.dtype) + bias = beta.to(x.dtype) + output = torch.nn.functional.group_norm(x, config.num_groups, weight=weight, bias=bias, eps=config.epsilon) + + if config.activation: + torch.nn.functional.silu(output, inplace=True) + + if config.channels_last: + output = output.permute(0, 2, 3, 1) # from NCHW to NHWC + + return output, add_out + + +def print_tensor(name, tensor): + # Print in the format that could be directly added to unit tests in C++. + torch.set_printoptions(precision=6, sci_mode=False, linewidth=100, profile="full", threshold=1000) + print(name) + if tensor is not None: + print("shape", tensor.shape) + text = str(tensor.clone().flatten()) + print(text.replace("[", "[\n").replace("]", ",\n]").replace(",", "f,")) + else: + print(tensor) + + +def run_parity(config, measure_latency=True, verbose=False): + float_type = torch.float16 if config.fp16 else torch.float32 + + input_tensor = torch.randn( + config.batch_size, + config.height, + config.width, + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + gamma = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + beta = torch.randn( + config.channels, + device="cuda", + dtype=torch.float32, + requires_grad=False, + ) + + skip = None + bias = None + if config.op_type == GroupNormOpType.SKIP_GROUP_NORM: + skip = torch.randn( + *config.get_skip_shape(), + device="cuda", + dtype=float_type, + requires_grad=False, + ) + if config.has_bias: + bias = torch.randn( + config.channels, + device="cuda", + dtype=float_type, + requires_grad=False, + ) + + if verbose: + print(config) + print_tensor("input", input_tensor) + print_tensor("gamma", gamma) + print_tensor("beta", beta) + print_tensor("skip", skip) + print_tensor("bias", bias) + + out_ort, ort_add_out, latency = group_norm_ort( + input_tensor, gamma, beta, skip, bias, config, measure_latency=measure_latency + ) + + if verbose: + print_tensor("out_ort", out_ort) + print_tensor("ort_add_out", ort_add_out) + + torch_out, torch_add_out = group_norm_torch(input_tensor, gamma, beta, skip, bias, config) + + if verbose: + print_tensor("torch_out", torch_out) + print_tensor("torch_add_out", torch_add_out) + + average_diff = numpy.mean(numpy.abs(out_ort.detach().cpu().numpy() - torch_out.detach().cpu().numpy())) + + is_close = numpy.allclose( + out_ort.detach().cpu().numpy(), + torch_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + + is_add_out_close = ( + numpy.allclose( + ort_add_out.detach().cpu().numpy(), + torch_add_out.detach().cpu().numpy(), + rtol=1e-1 if config.fp16 else 1e-3, + atol=1e-1 if config.fp16 else 1e-3, + equal_nan=True, + ) + if ort_add_out is not None + else "" + ) + + # Compare results + print( + config.op_type.name, + " B:", + config.batch_size, + " H:", + config.height, + " W:", + config.width, + " C:", + config.channels, + " G:", + config.num_groups, + " activation:", + int(config.activation), + " channels_last:", + int(config.channels_last), + " fp16:", + int(config.fp16), + f" Latency(μs): {int(latency * 1e6)}" if isinstance(latency, float) else "", + " AvgDiff:", + average_diff, + " Pass:", + is_close, + is_add_out_close, + ) + + +def get_latent_height_width(): + default_size = [(512, 512), (768, 768), (1024, 1024)] + small_img_size = [(512, 768), (768, 512)] + xl_img_size = [ + (1152, 896), + (896, 1152), + (1216, 832), + (832, 1216), + (1344, 768), + (768, 1344), + (1536, 640), + (640, 1536), + ] + return [(int(h / 8), int(w / 8)) for (h, w) in default_size + small_img_size + xl_img_size] + + +def get_channels(): + return [128, 256, 512, 1024, 2048, 320, 640, 960, 1920, 2560, 384, 768, 1536, 3072, 1152, 2304] + + +def run_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm with Silu Activation for ", "fp16" if fp16 else "fp32") + for b in [2]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, activation=True, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_no_activation(template: int, fp16, measure_latency=False): + print("Test GroupNorm without Activation for ", "fp16" if fp16 else "fp32") + for b in [1, 2, 4]: + for h, w in get_latent_height_width(): + for c in get_channels(): + config = GroupNormConfig.create(b, h, w, c, fp16=fp16, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_all_groups(template: int, fp16, measure_latency=False): + group_sizes = [1, 2, 4, 8, 16, 32] + print("Test GroupNorm for different group sizes:", group_sizes) + for group_size in group_sizes: + for h, w in get_latent_height_width()[:3]: + for c in get_channels()[:2]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=group_size, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_odd_channels(template: int, fp16, measure_latency=False): + # Test some random number of channels that can be divisible by 2 * num_groups + for h, w in get_latent_height_width(): + for c in [448, 704, 832, 1664, 2240, 2688, 2880, 3008]: + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=template) + run_parity(config, measure_latency=measure_latency) + + +def run_small_inputs(template: int, fp16): + config = GroupNormConfig.create(2, 2, 2, 16, fp16=fp16, activation=False, num_groups=4, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=False, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + config = GroupNormConfig.create(1, 1, 1, 64, fp16=fp16, activation=True, num_groups=8, template=template) + run_parity(config, measure_latency=False) + + +def run_performance(fp16): + # Run perf test to tune parameters for given number of channels. + for h, w in get_latent_height_width()[:3]: + for c in get_channels(): + config = GroupNormConfig.create(2, h, w, c, fp16=fp16, num_groups=32, template=0) + run_parity(config, measure_latency=True) + + +def run_all(template: int): + for fp16 in [True, False]: + run_small_inputs(template, fp16) + run_odd_channels(template, fp16) + run_all_groups(template, fp16) + run_activation(template, fp16) + run_no_activation(template, fp16) + + +def run_not_implemented(): + # Expect failure. Check whether the error message is expected. + try: + config = GroupNormConfig(1, 2, 2, 513, num_groups=3) + run_parity(config) + except RuntimeError as e: + assert "GroupNorm in CUDA does not support the input: n=1 h=2 w=2 c=513 groups=3" in str(e) + + +def main(): + run_performance(True) + + run_not_implemented() + + for template in range(6): + run_all(template) + + +if __name__ == "__main__": + main()