From aca36a441ba1ebcb16479f311f3c5b414d9984ca Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Thu, 26 Oct 2023 23:03:17 +0000
Subject: [PATCH] fix kernel

---
 .../contrib_ops/cuda/diffusion/group_norm.cc  | 61 ++++++++++++-----
 .../contrib_ops/cuda/diffusion/group_norm.h   |  2 +-
 .../cuda/diffusion/group_norm_impl.cu         | 66 +++++++++----------
 .../core/graph/contrib_ops/diffusion_defs.cc  |  7 +-
 .../python/transformers/test_group_norm.py    | 28 ++++++--
 5 files changed, 106 insertions(+), 58 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
index 4fd5e21636530..bf28384522def 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -67,7 +67,7 @@ struct DispatchGroupNorm {
 
 }  // namespace
 
-template
+template
 GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
   epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
   ORT_ENFORCE(epsilon_ >= 0);
@@ -85,7 +85,7 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
   channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
 }
 
-template
+template
 Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
   const Tensor* input = context->Input<Tensor>(0);
   const Tensor* gamma = context->Input<Tensor>(1);
@@ -103,12 +103,23 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
                            "input is expected to have 4 dimensions, got ", input_dims.size());
   }
 
+  // Input and output format is NHWC
+  int batch_size = static_cast<int>(input_dims[0]);
+  int num_channels = static_cast<int>(input_dims[3]);
+  int height = static_cast<int>(input_dims[1]);
+  int width = static_cast<int>(input_dims[2]);
+
+  if (num_channels % num_groups_ != 0) {
+    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                           "number of channels should be divisible by num_groups");
+  }
+
   const auto& gamma_dims = gamma->Shape().GetDims();
   if (gamma_dims.size() != 1) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "gamma is expected to have 1 dimension, got ", gamma_dims.size());
   }
-  if (gamma_dims[0] != input_dims[3]) {
+  if (gamma_dims[0] != num_channels) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Number of channels in gamma and input does not match");
   }
@@ -118,22 +129,11 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "beta is expected to have 1 dimension, got ", beta_dims.size());
   }
-  if (beta_dims[0] != input_dims[3]) {
+  if (beta_dims[0] != num_channels) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
                            "Number of channels in beta and input does not match");
   }
 
-  // Input and output format is NHWC
-  int batch_size = static_cast<int>(input_dims[0]);
-  int num_channels = static_cast<int>(input_dims[3]);
-  int height = static_cast<int>(input_dims[1]);
-  int width = static_cast<int>(input_dims[2]);
-
-  if (num_channels % num_groups_ != 0) {
-    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "number of channels should be divisiable by num_groups");
-  }
-
   if (context->GetUseDeterministicCompute()) {
     static std::once_flag log_warning;
     std::call_once(log_warning, []() {
@@ -149,8 +149,39 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
     bias = context->Input<Tensor>(3);
     skip = context->Input<Tensor>(4);
     add_out = context->Output(1, input->Shape());
+
+    // For SkipGroupNorm, bias has shape (C)
+    const auto& bias_dims = bias->Shape().GetDims();
+    if (bias_dims.size() != 1) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "bias is expected to have 1 dimension, got ", bias_dims.size());
+    }
+    if (bias_dims[0] != num_channels) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Number of channels in bias and input does not match");
+    }
+
+    if (skip->Shape() != input->Shape()) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "skip is expected to have same shape as input");
+    }
   } else if (T == BiasGroupNormOp) {
     bias = context->Input<Tensor>(3);
+
+    // For BiasGroupNorm, bias has shape (N, C)
+    const auto& bias_dims = bias->Shape().GetDims();
+    if (bias_dims.size() != 2) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "bias is expected to have 2 dimensions, got ", bias_dims.size());
+    }
+    if (bias_dims[0] != batch_size) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "First dimension (batch size) in bias and input does not match");
+    }
+    if (bias_dims[1] != num_channels) {
+      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                             "Number of channels in bias and input does not match");
+    }
   }
 
   auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_),
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
index 9c4bca093316c..4636ee9f29cd9 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -17,7 +17,7 @@ enum GroupNormOperatorType {
   BiasGroupNormOp
 };
 
-template
+template
 class GroupNorm final : public CudaKernel {
  public:
   GroupNorm(const OpKernelInfo& op_kernel_info);
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
index 957e5f3428c55..10f1121ebf55f 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -262,7 +262,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) {
   // We have 3 operators:
   // (1) SkipGroupNorm: skip is (n, h, w, c) and bias is (c), add_out is (n, h, w, c)
   //     The additional output add_out = src + skip + bias.
-  // (2) BiasGroupNorm: bias is (n, 1, 1, c), add_out and skip are empty
+  // (2) BiasGroupNorm: bias is (n, c), add_out and skip are empty
   // (3) GroupNorm: skip, bias and add_out not exists
 
   int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
@@ -282,8 +282,9 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) {
     }
   }
 
-  // The group that thread works on and the channel in the group (modulus).
+  // The group index relative to the first group within the same block.
   int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.cPerGroup;
+  // The channel in the group.
   int32_t cj = ci % params.cPerGroup;
 
   // The data for the summations.
@@ -294,27 +295,21 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams params) {
   BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());
 
   // Store the results for the groups in shared memory (to produce coalesced stores later).
-  if (cj == params.cPerGroup - CHANNELS_PER_THREAD) {
+  // For each group, only the last thread of that group is picked to save the sums to shared memory and update redBuffer.
+  const bool is_last_of_a_group = (cj == params.cPerGroup - CHANNELS_PER_THREAD);
+  if (is_last_of_a_group) {
     smem[gi] = make_float2(out.sum, out.sumSq);
   }
 
   // Make sure the data is in shared memory.
   __syncthreads();
 
-  // The global group index.
-  int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;
-
-  // Threads that have nothing left to do, exit.
-  if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
-    return;
+  if (is_last_of_a_group) {
+    int32_t gj = ci / params.cPerGroup;  // absolute group index
+    float2 sums = smem[gi];
+    atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
+    atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
   }
-
-  // The first threads (those storing to global memory, load the values).
-  float2 sums = smem[threadIdx.x];
-
-  // Store to global memory.
-  atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
-  atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
 }
 
 template
@@ -409,7 +404,7 @@
 template
 __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) {
   // The channel loaded by that thread.
   int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * CHANNELS_PER_THREAD;
-  if (ci >= params.c) {
+  if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.cPerBlock) {
     return;
   }
@@ -435,7 +430,7 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams params) {
   // Compute the variance.
   float var = sumSq * params.invHWC - (mean * mean);
   // Compute the inverse of the stddev.
-  float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + params.epsilon);
+  float invStdDev = rsqrtf(var + params.epsilon);
 
   // The first activation loaded by that block.
   int32_t hwBegin = blockIdx.y * params.hwPerBlock;
@@ -523,6 +518,8 @@ Status LaunchGroupNormKernel(
     bool use_swish_activation) {
   GroupNormNHWCParams params;
 
+  int32_t cPerGroup = num_channels / num_groups;
+
   int32_t cPerBlock;
   switch (num_channels) {
     case 2560:
@@ -553,13 +550,11 @@ Status LaunchGroupNormKernel(
       break;
     default:
       cPerBlock = 320;
-  }
-
-  // Find a maximum cPerBlock that num_channels could be divisible by it.
-  // Try to be close to 512 since we have multiple kSizes values are within [256, 512] range that could act as fallback.
-  int32_t cPerGroup = num_channels / num_groups;
-  if (cPerBlock % cPerGroup != 0) {
-    cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
+      if (num_channels % cPerBlock != 0 || cPerBlock % cPerGroup != 0) {
+        // Find the largest cPerBlock that num_channels is divisible by.
+        // Try to be close to 512 since multiple kSizes values within the [256, 512] range can act as fallbacks.
+        cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
+      }
   }
 
   params.withSwish = use_swish_activation;
@@ -578,6 +573,7 @@ Status LaunchGroupNormKernel(
   params.groups = num_groups;
   params.hw = params.h * params.w;
 
+  // This will allocate as many blocks as possible to partition HW.
   constexpr int32_t maxBlocksPerHW = 1024;
   const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
   params.hwPerBlock = divUp(params.hw, blocksPerHW);
@@ -587,9 +583,13 @@ Status LaunchGroupNormKernel(
   params.hwc = params.hw * params.c;
   params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
   params.groupsPerBlock = cPerBlock / params.cPerGroup;
+  params.epsilon = epsilon;
 
-  // TODO: Update the kernel to support CHANNELS_PER_THREAD==1
-  if (cPerBlock > 512 || (params.cPerGroup % CHANNELS_PER_THREAD != 0)) {
+  // TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases
+  if (params.c % params.cPerBlock != 0 ||
+      params.cPerBlock % params.cPerGroup != 0 ||
+      cPerBlock > 512 ||
+      (params.cPerGroup % CHANNELS_PER_THREAD != 0)) {
     printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d\n",
            params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
            params.cPerBlock, params.cPerGroup);
@@ -598,13 +598,13 @@ Status LaunchGroupNormKernel(
   params.threadsPerBlock = nextSize(cPerBlock) / CHANNELS_PER_THREAD;
 
-  cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream);
-
-  // Make sure the values are as we expect.
-  ORT_ENFORCE(params.c % params.cPerBlock == 0);
+#ifdef DUMP_GROUP_NORM
+  printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d threadsPerBlock=%d\n",
+         params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
+         params.cPerBlock, params.cPerGroup, params.threadsPerBlock);
+#endif
 
-  // Make sure a group does not span multiple blocks.
-  ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
+  CUDA_RETURN_IF_ERROR(cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));
 
   groupNormNHWCSum(params, stream);
   CUDA_RETURN_IF_ERROR(cudaGetLastError());
diff --git a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
index 3e487f65e1b66..a87b7f4d923a3 100644
--- a/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -111,7 +111,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
              "beta",
              "1D beta tensor for normalization with shape (C), where C is number of channels",
              "M")
-      .Input(3,
+      .Input(3,
              "bias",
             "Bias data tensor. Dimensions are (N x C), where N is the batch size and C is the number of channels",
             "T")
@@ -123,7 +123,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
       .TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
       .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));
 
-
 constexpr const char* SkipGroupNorm_ver1_doc = R"DOC(
 This operator element-wise adds input x, skip and bias, then apply group normalization and optional activation.
@@ -190,8 +189,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
           propagateElemTypeFromInputToOutput(ctx, 0, 0);
           propagateElemTypeFromInputToOutput(ctx, 0, 1);
           if (hasNInputShapes(ctx, 1)) {
-          propagateShapeFromInputToOutput(ctx, 0, 0);
-          propagateShapeFromInputToOutput(ctx, 0, 1);
+            propagateShapeFromInputToOutput(ctx, 0, 0);
+            propagateShapeFromInputToOutput(ctx, 0, 1);
           }
         }));
diff --git a/onnxruntime/test/python/transformers/test_group_norm.py b/onnxruntime/test/python/transformers/test_group_norm.py
index e8c8d04eb835f..88157a77cae05 100644
--- a/onnxruntime/test/python/transformers/test_group_norm.py
+++ b/onnxruntime/test/python/transformers/test_group_norm.py
@@ -180,13 +180,12 @@ def run_parity(config, measure_latency=True):
         " G:",
         config.num_groups,
         " activation:",
-        config.activation,
+        int(config.activation),
         " channels_last:",
-        config.channels_last,
+        int(config.channels_last),
         " fp16:",
-        config.fp16,
-        " Latency(ms):",
-        latency * 1000 if isinstance(latency, float) else latency,
+        int(config.fp16),
+        f" Latency(ms): {latency * 1000}" if isinstance(latency, float) else "",
         " AvgDiff:",
         numpy.mean(numpy.abs(ort_result - torch_result)),
         " Pass:",
@@ -250,6 +249,23 @@ def run_odd_channels(fp16, measure_latency=True):
         run_parity(config, measure_latency=measure_latency)
 
 
+def run_small_inputs():
+    # Test small number of N, H, W, C
+    config = GroupNormConfig(2, 2, 2, 16, fp16=True, activation=False, num_groups=4)
+    run_parity(config, measure_latency=False)
+
+    config.fp16 = False
+    config.activation = True
+    run_parity(config, measure_latency=False)
+
+    config = GroupNormConfig(1, 1, 1, 64, fp16=True, activation=False)
+    run_parity(config, measure_latency=False)
+
+    config.fp16 = False
+    config.activation = True
+    run_parity(config, measure_latency=False)
+
+
 def run_performance(fp16):
     # Run perf test to tune parameters for given number of channels.
     for h, w in get_latent_height_width()[2:3]:
@@ -261,6 +277,8 @@ def run_performance(fp16):
 
 def run_all():
     run_performance(True)
+    run_small_inputs()
+
     measure_latency = False
     run_odd_channels(True, measure_latency=measure_latency)
     run_odd_channels(False, measure_latency=measure_latency)
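
Reviewer note (not part of the patch): the shape checks added in group_norm.cc follow the operator contract described in the kernel comments: GroupNorm normalizes an NHWC input over each group of contiguous channels, then applies per-channel gamma/beta and an optional Swish activation; SkipGroupNorm first forms add_out = x + skip + bias with bias of shape (C) and normalizes add_out; BiasGroupNorm adds a per-batch bias of shape (N, C) before normalization. The sketch below is a minimal NumPy reference of that contract under those assumptions. It is not code from ONNX Runtime or from test_group_norm.py, and the function names (group_norm_nhwc, skip_group_norm, bias_group_norm) are illustrative only.

import numpy as np


def group_norm_nhwc(x, gamma, beta, num_groups, epsilon=1e-5, swish=False):
    """GroupNorm over an NHWC input with per-channel gamma/beta and optional Swish."""
    n, h, w, c = x.shape
    assert c % num_groups == 0, "number of channels should be divisible by num_groups"
    # Group contiguous channels and reduce over H, W and the channels in each group.
    g = x.reshape(n, h, w, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 2, 4), keepdims=True)
    var = g.var(axis=(1, 2, 4), keepdims=True)
    y = ((g - mean) / np.sqrt(var + epsilon)).reshape(n, h, w, c)
    y = y * gamma + beta
    if swish:
        y = y / (1.0 + np.exp(-y))  # Swish/SiLU: y * sigmoid(y)
    return y


def skip_group_norm(x, gamma, beta, skip, bias, num_groups, epsilon=1e-5, swish=False):
    """SkipGroupNorm sketch: add_out = x + skip + bias (bias shape (C,)), then GroupNorm(add_out)."""
    add_out = x + skip + bias.reshape(1, 1, 1, -1)
    return group_norm_nhwc(add_out, gamma, beta, num_groups, epsilon, swish), add_out


def bias_group_norm(x, gamma, beta, bias, num_groups, epsilon=1e-5, swish=False):
    """BiasGroupNorm sketch: bias has shape (N, C) and is broadcast over H and W before GroupNorm."""
    n, c = bias.shape
    return group_norm_nhwc(x + bias.reshape(n, 1, 1, c), gamma, beta, num_groups, epsilon, swish)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.standard_normal((2, 4, 4, 16), dtype=np.float32)
    gamma = rng.standard_normal(16, dtype=np.float32)
    beta = rng.standard_normal(16, dtype=np.float32)
    bias = rng.standard_normal((2, 16), dtype=np.float32)
    print(bias_group_norm(x, gamma, beta, bias, num_groups=4).shape)  # (2, 4, 4, 16)

Under these assumptions, the outputs exercised by run_parity should match group_norm_nhwc (and skip_group_norm's add_out) up to floating-point tolerance.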
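A second sketch models the cPerBlock fallback that the patch moves into the default branch of the switch in LaunchGroupNormKernel. It assumes kMaxSize is the 512-channel upper bound mentioned in the comment and that findMaxDivisor(n, bound) returns the largest divisor of n not exceeding bound; both are assumptions, since neither definition appears in this patch, and the dedicated switch cases for specific channel counts are ignored. The point it illustrates: multiplying a divisor of num_groups by cPerGroup always yields a block width that divides num_channels and is a multiple of cPerGroup, which is exactly what the later NOT_IMPLEMENTED guard checks.

K_MAX_SIZE = 512  # assumed upper bound on channels per block


def find_max_divisor(n: int, max_allowed: int) -> int:
    """Largest divisor of n that does not exceed max_allowed (assumed semantics)."""
    return max(d for d in range(1, max_allowed + 1) if n % d == 0)


def pick_c_per_block(num_channels: int, num_groups: int, default_c_per_block: int = 320) -> int:
    """Model of the default-branch fallback only, not of the whole switch statement."""
    c_per_group = num_channels // num_groups
    c_per_block = default_c_per_block
    if num_channels % c_per_block != 0 or c_per_block % c_per_group != 0:
        # A divisor of num_groups times c_per_group divides num_channels
        # (= num_groups * c_per_group) and is a multiple of c_per_group.
        c_per_block = find_max_divisor(num_groups, K_MAX_SIZE // c_per_group) * c_per_group
    return c_per_block


# Example: 1920 channels, 32 groups -> cPerGroup = 60; 320 % 60 != 0, so the
# fallback picks findMaxDivisor(32, 512 // 60 = 8) = 8 and cPerBlock = 480.
assert pick_c_per_block(1920, 32) == 480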