Test SkipGroupNorm
tianleiwu committed Oct 27, 2023
1 parent aca36a4 commit 5d5c14e
Showing 6 changed files with 4,428 additions and 94 deletions.
4,229 changes: 4,229 additions & 0 deletions .gitignore

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -174,7 +174,7 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 2 dimension, got ", bias_dims.size());
}
- if (bias_dims[0] != num_channels) {
+ if (bias_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"First dimension (batch size) in bias and input does not match");
}
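
A note on this one-line change: the surrounding messages say the 2-D bias is validated against the batch dimension, so the comparison is updated from num_channels to batch_size. A minimal NumPy sketch of the corrected check; the function and variable names here are illustrative only, not the operator's actual API:

```python
import numpy as np

def validate_bias(x_nhwc: np.ndarray, bias: np.ndarray) -> None:
    """Mirror of the corrected check: a 2-D bias must have batch_size rows."""
    batch_size = x_nhwc.shape[0]
    if bias.ndim != 2:
        raise ValueError(f"bias is expected to have 2 dimensions, got {bias.ndim}")
    if bias.shape[0] != batch_size:
        raise ValueError("First dimension (batch size) in bias and input does not match")

# An NHWC input with batch size 2 and a per-sample (2, 32) bias passes the check.
validate_bias(np.zeros((2, 8, 8, 32), np.float32), np.zeros((2, 32), np.float32))
```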
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -23,7 +23,7 @@ class GroupNorm final : public CudaKernel {
GroupNorm(const OpKernelInfo& op_kernel_info);
Status ComputeInternal(OpKernelContext* context) const override;

- protected:
+ private:
bool use_swish_activation_;
float epsilon_;
int num_groups_;
49 changes: 29 additions & 20 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -16,7 +16,7 @@
*/

// The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
- // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
- // Modifications: support more cPerBlock
+ // The CUDA kernel is modified from GroupNorm plugin of TensorRT 8.5
+ // Modifications: heuristic cPerBlock; support epsilon; support skip and bias etc.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

@@ -28,6 +28,8 @@
#include "contrib_ops/cuda/diffusion/group_norm_impl.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"

+ using namespace onnxruntime::cuda;

Check warning (GitHub Actions / cpplint) on line 31 of onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]

namespace onnxruntime {
namespace contrib {
namespace cuda {
@@ -96,7 +98,7 @@ struct GroupNormNHWCParams {
float const* gamma;
// The beta term to add in GN.
float const* beta;
- // The temporary buffer to do the global parallel reduction. Size is n x g x 2, where g is number of groups.
+ // The temporary buffer to do the global parallel reduction. Shape is (n, 2, g), where g is number of groups.
float* redBuffer;

// The number of instances in the batch.
@@ -203,27 +205,31 @@ inline __device__ void AddSkipBias(const float* src, const float* skip, const fl

// Sum for BiasGroupNorm
template <typename T>
- inline __device__ void AddBias(const T* src, const T* bias,
+ inline __device__ void AddBias(const T* src, const T* bias, T* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq);

template <>
- inline __device__ void AddBias(const half* src, const half* bias,
+ inline __device__ void AddBias(const half* src, const half* bias, half* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq) {
__half2 h2 = *reinterpret_cast<__half2 const*>(&src[offset]);
__half2 b = *reinterpret_cast<__half2 const*>(&bias[bias_offset]);
h2 += b;

+ *reinterpret_cast<__half2*>(&add_out[offset]) = h2;

float2 f2 = __half22float2(h2);
sum += f2.x + f2.y;
sumSq += f2.x * f2.x + f2.y * f2.y;
}

template <>
- inline __device__ void AddBias(const float* src, const float* bias,
+ inline __device__ void AddBias(const float* src, const float* bias, float* add_out,
int64_t offset, int32_t bias_offset, float& sum, float& sumSq) {
float2 f2 = *reinterpret_cast<float2 const*>(&src[offset]);
float2 b = *reinterpret_cast<float2 const*>(&bias[bias_offset]);
f2.x += b.x;
f2.y += b.y;
+ *reinterpret_cast<float2*>(&add_out[offset]) = f2;
sum += f2.x + f2.y;
sumSq += f2.x * f2.x + f2.y * f2.y;
}
@@ -263,6 +269,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
// (1) SkipGroupNorm: skip is (n, h, w, c) and bias is (c), add_out is (n, h, w, c)
// The additional output add_out = src + skip + bias.
// (2) BiasGroupNorm: bias is (n, c), add_out and skip are empty
+ // We will use dst as temp storage to store src + bias.
// (3) GroupNorm: skip, bias and add_out not exists

int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
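
The comments in this hunk describe the three variants the sum kernel has to feed. Below is a minimal NumPy sketch of those element-wise semantics in NHWC layout, offered as a reading aid only and not the CUDA kernel itself; the function name is made up for illustration:

```python
import numpy as np

def stats_source(src, skip=None, bias=None):
    """Return (stats_input, add_out) for the three cases described above.

    (1) SkipGroupNorm: skip is (n, h, w, c), bias is (c); add_out = src + skip + bias.
    (2) BiasGroupNorm: bias is (n, c), broadcast over h and w; dst doubles as temp storage.
    (3) GroupNorm:     statistics are taken directly from src.
    """
    if skip is not None:                       # SkipGroupNorm
        add_out = src + skip + (bias if bias is not None else 0.0)
        return add_out, add_out
    if bias is not None:                       # BiasGroupNorm
        return src + bias[:, None, None, :], None
    return src, None                           # GroupNorm
```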
@@ -274,7 +281,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
} else if (params.bias != nullptr) { // BiasGroupNorm
const int64_t bias_offset = static_cast<int64_t>(ni) * params.c + ci;
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
- AddBias(params.src, params.bias, offset, bias_offset, sum, sumSq);
+ AddBias(params.src, params.bias, params.dst, offset, bias_offset, sum, sumSq);
}
} else { // GroupNorm
for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
@@ -307,8 +314,9 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
if (is_last_of_a_group) {
int32_t gj = ci / params.cPerGroup; // absolute group index
float2 sums = smem[gi];
- atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
- atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
+ const int index = (2 * ni) * params.groups + gj;
+ atomicAdd(&params.redBuffer[index], sums.x);
+ atomicAdd(&params.redBuffer[index + params.groups], sums.y);
}
}
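
The rewritten indexing matches the updated redBuffer comment: the workspace is viewed as (n, 2, groups), with per-group sums in the first slice of the middle axis and sums of squares in the second, params.groups elements apart. A small NumPy sketch of that addressing, assuming the buffer is flat and row-major:

```python
import numpy as np

def accumulate(red_buffer, ni, gj, groups, group_sum, group_sum_sq):
    """Write per-group statistics into a flat buffer viewed as shape (n, 2, groups)."""
    index = (2 * ni) * groups + gj
    red_buffer[index] += group_sum              # element (ni, 0, gj): sum
    red_buffer[index + groups] += group_sum_sq  # element (ni, 1, gj): sum of squares

buf = np.zeros(4 * 2 * 32, dtype=np.float32)    # n = 4, groups = 32
accumulate(buf, ni=1, gj=3, groups=32, group_sum=2.5, group_sum_sq=6.25)
assert buf.reshape(4, 2, 32)[1, 0, 3] == 2.5
assert buf.reshape(4, 2, 32)[1, 1, 3] == 6.25
```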

@@ -411,14 +419,15 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// The instance in the batch.
int32_t ni = blockIdx.z;

- // The group that thread works on and the channel in the group (modulus).
+ // The group that thread works on.
int32_t gi = ci / params.cPerGroup;

// Load the sum and sum of squares for the group.
float sum = 0.F, sumSq = 0.F;
if (gi < params.groups) {
- sum = params.redBuffer[(2 * ni + 0) * params.groups + gi];
- sumSq = params.redBuffer[(2 * ni + 1) * params.groups + gi];
+ const int index = (2 * ni) * params.groups + gi;
+ sum = params.redBuffer[index];
+ sumSq = params.redBuffer[index + params.groups];
}

// Load gamma/beta.
@@ -432,18 +441,15 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// Compute the inverse of the stddev.
float invStdDev = rsqrtf(var + params.epsilon);

// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
// The last activation loaded by that block.
int32_t hwEnd = min(hwBegin + params.hwPerBlock, params.hw);

- // Iterate over the activations to compute the sums.
- for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi) {
- // The src/dst offset.
- int64_t offset = (int64_t)ni * params.hwc + hwi * params.c + ci;
-
- // Fetch two channels per thread.
- computeGroupNorm<T>(params.src, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
+ // For SkipGroupNorm, the source is sum of src + skip + bias, which was stored in add_out.
+ // For BiasGroupNorm, the source is src + bias, which was stored in dst as intermediate data.
+ const T* source = (params.skip != nullptr) ? params.add_out : (params.bias != nullptr ? params.dst : params.src);
+ int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
+ for (int32_t hwi = hwBegin; hwi < hwEnd; ++hwi, offset += params.c) {
+ computeGroupNorm<T>(source, params.dst, offset, mean, invStdDev, gammaF2, betaF2, params.withSwish);
}
}
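
For orientation, here is a minimal NumPy sketch of the arithmetic the scale kernel applies once the group statistics are loaded. Recovering mean and variance from sum and sumSq with count = h * w * channels-per-group is an assumption (those lines sit outside this hunk); the parts visible above are invStdDev = rsqrtf(var + epsilon) and the optional Swish:

```python
import numpy as np

def scale(x, group_sum, group_sum_sq, count, gamma, beta, epsilon, with_swish=False):
    """Normalize x with precomputed group statistics, then apply gamma/beta (and Swish)."""
    mean = group_sum / count
    var = group_sum_sq / count - mean * mean
    inv_std_dev = 1.0 / np.sqrt(var + epsilon)   # rsqrtf(var + epsilon) in the kernel
    y = (x - mean) * inv_std_dev * gamma + beta
    if with_swish:
        y = y * (1.0 / (1.0 + np.exp(-y)))       # swish(y) = y * sigmoid(y)
    return y
```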

@@ -609,6 +615,9 @@ Status LaunchGroupNormKernel(
groupNormNHWCSum<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());

+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR("workspace", params.redBuffer, batch_size, 2, num_groups);

groupNormNHWCScale<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());

6 changes: 3 additions & 3 deletions onnxruntime/python/tools/transformers/io_binding_helper.py
@@ -1,6 +1,6 @@
import logging
from collections import OrderedDict
- from typing import Any, Dict, List
+ from typing import Any, Dict

import numpy
import torch
@@ -112,7 +112,7 @@ def prepare_io_binding(
input_ids: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
- past: List[torch.Tensor],
+ past: list[torch.Tensor],
output_buffers,
output_shapes,
name_to_np_type=None,
@@ -229,7 +229,7 @@ def __del__(self):
del self.io_binding
del self.ort_session

- def allocate_buffers(self, shape_dict: Dict[str, tuple]):
+ def allocate_buffers(self, shape_dict: Dict[str, tuple[int] | list[int]]):
"""Allocate tensors for I/O Binding"""
if self.enable_cuda_graph:
for name, shape in shape_dict.items():
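
A small self-contained sketch of what the widened annotation admits: a shape may now be given as either a tuple of ints or a list of ints. The names and shapes below are hypothetical, not taken from this commit:

```python
from __future__ import annotations

from typing import Dict

def check_shapes(shape_dict: Dict[str, tuple[int] | list[int]]) -> None:
    """Shapes may be tuples or lists of ints, per the widened annotation above."""
    for name, shape in shape_dict.items():
        assert all(isinstance(dim, int) for dim in shape), f"{name} has a non-integer dimension"

check_shapes({"logits": (1, 16, 50257), "present_0": [1, 12, 16, 64]})
```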