Commit aca36a4: fix kernel

tianleiwu committed Oct 26, 2023
1 parent 3123d66 commit aca36a4

Showing 5 changed files with 106 additions and 58 deletions.
61 changes: 46 additions & 15 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -67,7 +67,7 @@ struct DispatchGroupNorm {

} // namespace

template<GroupNormOperatorType T>
template <GroupNormOperatorType T>
GroupNorm<T>::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
ORT_ENFORCE(epsilon_ >= 0);
@@ -85,7 +85,7 @@ GroupNorm<T>::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}

template<GroupNormOperatorType T>
template <GroupNormOperatorType T>
Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* gamma = context->Input<Tensor>(1);
@@ -103,12 +103,23 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
"input is expected to have 4 dimensions, got ", input_dims.size());
}

// Input and output format is NHWC
int batch_size = static_cast<int>(input_dims[0]);
int num_channels = static_cast<int>(input_dims[3]);
int height = static_cast<int>(input_dims[1]);
int width = static_cast<int>(input_dims[2]);

if (num_channels % num_groups_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"number of channels should be divisiable by num_groups");
}

const auto& gamma_dims = gamma->Shape().GetDims();
if (gamma_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"gamma is expected to have 1 dimension, got ", gamma_dims.size());
}
if (gamma_dims[0] != input_dims[3]) {
if (gamma_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in gamma and input does not match");
}
@@ -118,22 +129,11 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"beta is expected to have 1 dimension, got ", beta_dims.size());
}
if (beta_dims[0] != input_dims[3]) {
if (beta_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in beta and input does not match");
}

// Input and output format is NHWC
int batch_size = static_cast<int>(input_dims[0]);
int num_channels = static_cast<int>(input_dims[3]);
int height = static_cast<int>(input_dims[1]);
int width = static_cast<int>(input_dims[2]);

if (num_channels % num_groups_ != 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"number of channels should be divisiable by num_groups");
}

if (context->GetUseDeterministicCompute()) {
static std::once_flag log_warning;
std::call_once(log_warning, []() {
@@ -149,8 +149,39 @@ Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
bias = context->Input<Tensor>(3);
skip = context->Input<Tensor>(4);
add_out = context->Output(1, input->Shape());

// For SkipGroupNorm, bias has shape (C)
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 1 dimension, got ", bias_dims.size());
}
if (bias_dims[0] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in bias and input does not match");
}

if (skip->Shape() != input->Shape()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"skip is expected to have same shape as input");
}
} else if (T == BiasGroupNormOp) {
bias = context->Input<Tensor>(3);

// For BiasGroupNorm, bias has shape (N, C)
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"bias is expected to have 2 dimension, got ", bias_dims.size());
}
if (bias_dims[0] != batch_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"First dimension (batch size) in bias and input does not match");
}
if (bias_dims[1] != num_channels) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels in bias and input does not match");
}
}

auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_),
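To make the shape contracts above concrete, here is a minimal NumPy sketch of the pre-normalization step for the three operator variants; pre_add and its signature are illustrative, not part of the ONNX Runtime API:

import numpy as np

def pre_add(x, variant, skip=None, bias=None):
    # x is NHWC: (N, H, W, C)
    if variant == "SkipGroupNorm":
        # skip: (N, H, W, C); bias: (C,), broadcast over N, H, W
        add_out = x + skip + bias
        return add_out, add_out  # normalize add_out; it is also the second output
    if variant == "BiasGroupNorm":
        # bias: (N, C), broadcast over H, W
        return x + bias[:, None, None, :], None
    return x, None  # plain GroupNorm: no skip, bias, or add_out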
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -17,7 +17,7 @@ enum GroupNormOperatorType {
BiasGroupNormOp
};

template<GroupNormOperatorType opType>
template <GroupNormOperatorType opType>
class GroupNorm final : public CudaKernel {
public:
GroupNorm(const OpKernelInfo& op_kernel_info);
66 changes: 33 additions & 33 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm_impl.cu
@@ -262,7 +262,7 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
// We have 3 operators:
// (1) SkipGroupNorm: skip is (n, h, w, c) and bias is (c), add_out is (n, h, w, c)
// The additional output add_out = src + skip + bias.
// (2) BiasGroupNorm: bias is (n, 1, 1, c), add_out and skip are empty
// (2) BiasGroupNorm: bias is (n, c), add_out and skip are empty
// (3) GroupNorm: skip, bias and add_out do not exist

int64_t offset = static_cast<int64_t>(ni) * params.hwc + static_cast<int64_t>(hwBegin) * params.c + ci;
@@ -282,8 +282,9 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
}
}

// The group that thread works on and the channel in the group (modulus).
// The group index relative to the first group within the same block.
int32_t gi = threadIdx.x * CHANNELS_PER_THREAD / params.cPerGroup;
// The channel in the group.
int32_t cj = ci % params.cPerGroup;

// The data for the summations.
@@ -294,27 +295,21 @@ __global__ void groupNormNHWCSumKernel(GroupNormNHWCParams<T> params) {
BlockScan(tempStorage).InclusiveScan(inp, out, GroupSumsOp());

// Store the results for the groups in shared memory (to produce coalesced stores later).
if (cj == params.cPerGroup - CHANNELS_PER_THREAD) {
// For each group, only the last thread of the group saves its sum to shared memory and updates the reduction buffer.
const bool is_last_of_a_group = (cj == params.cPerGroup - CHANNELS_PER_THREAD);
if (is_last_of_a_group) {
smem[gi] = make_float2(out.sum, out.sumSq);
}

// Make sure the data is in shared memory.
__syncthreads();

// The global group index.
int32_t gj = blockIdx.x * params.groupsPerBlock + threadIdx.x;

// Threads that have nothing left to do, exit.
if (threadIdx.x >= params.groupsPerBlock || gj >= params.groups) {
return;
if (is_last_of_a_group) {
int32_t gj = ci / params.cPerGroup; // absolute group index
float2 sums = smem[gi];
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}

// The first threads (those storing to global memory, load the values).
float2 sums = smem[threadIdx.x];

// Store to global memory.
atomicAdd(&params.redBuffer[(2 * ni + 0) * params.groups + gj], sums.x);
atomicAdd(&params.redBuffer[(2 * ni + 1) * params.groups + gj], sums.y);
}
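As a reference for the reduction above, here is a NumPy sketch of what the sum pass accumulates, assuming redBuffer holds one (sum, sum of squares) pair per (batch, group) as the indexing suggests; group_sums is illustrative:

import numpy as np

def group_sums(x, num_groups):
    # x is NHWC: (N, H, W, C), with C divisible by num_groups
    n, h, w, c = x.shape
    xg = x.reshape(n, h * w, num_groups, c // num_groups)
    sums = xg.sum(axis=(1, 3))            # (N, groups): what lands in redBuffer row 2*ni
    sum_sqs = (xg * xg).sum(axis=(1, 3))  # (N, groups): what lands in redBuffer row 2*ni + 1
    return sums, sum_sqs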

template <typename T>
@@ -409,7 +404,7 @@ template <typename T>
__global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// The channel loaded by that thread.
int32_t ci = blockIdx.x * params.cPerBlock + threadIdx.x * CHANNELS_PER_THREAD;
if (ci >= params.c) {
if (ci >= params.c || threadIdx.x * CHANNELS_PER_THREAD >= params.cPerBlock) {
return;
}

@@ -435,7 +430,7 @@ __global__ void groupNormNHWCScaleKernel(GroupNormNHWCParams<T> params) {
// Compute the variance.
float var = sumSq * params.invHWC - (mean * mean);
// Compute the inverse of the stddev.
float invStdDev = var <= 0.F ? 1.F : rsqrtf(var + params.epsilon);
float invStdDev = rsqrtf(var + params.epsilon);

// The first activation loaded by that block.
int32_t hwBegin = blockIdx.y * params.hwPerBlock;
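The scale kernel then turns those accumulators into mean, variance, and normalized output. A NumPy sketch under the same assumptions (gamma and beta of shape (C,); group_norm_scale is illustrative):

def group_norm_scale(x, sums, sum_sqs, gamma, beta, num_groups, epsilon, swish=False):
    n, h, w, c = x.shape
    c_per_group = c // num_groups
    inv_hwc = 1.0 / (h * w * c_per_group)      # params.invHWC
    mean = sums * inv_hwc                      # (N, groups)
    var = sum_sqs * inv_hwc - mean * mean
    inv_std = 1.0 / np.sqrt(var + epsilon)     # epsilon keeps this finite when var == 0
    xg = x.reshape(n, h * w, num_groups, c_per_group)
    y = (xg - mean[:, None, :, None]) * inv_std[:, None, :, None]
    y = y.reshape(n, h, w, c) * gamma + beta
    if swish:
        y = y * (1.0 / (1.0 + np.exp(-y)))     # Swish: y * sigmoid(y)
    return y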
@@ -523,6 +518,8 @@ Status LaunchGroupNormKernel(
bool use_swish_activation) {
GroupNormNHWCParams<T> params;

int32_t cPerGroup = num_channels / num_groups;

int32_t cPerBlock;
switch (num_channels) {
case 2560:
@@ -553,13 +550,11 @@
break;
default:
cPerBlock = 320;
}

// Find a maximum cPerBlock that num_channels could be divisible by it.
// Try to be close to 512 since we have multiple kSizes values are within [256, 512] range that could act as fallback.
int32_t cPerGroup = num_channels / num_groups;
if (cPerBlock % cPerGroup != 0) {
cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
if (num_channels % cPerBlock != 0 || cPerBlock % cPerGroup != 0) {
// Find the largest cPerBlock that num_channels is divisible by.
// Try to be close to 512 since multiple kSizes values within [256, 512] range could act as fallback.
cPerBlock = findMaxDivisor(num_groups, kMaxSize / cPerGroup) * cPerGroup;
}
}

params.withSwish = use_swish_activation;
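A Python sketch of the block-size selection above, assuming findMaxDivisor(n, maxAllowed) returns the largest divisor of n not exceeding maxAllowed and that kMaxSize is 512; both are inferred from usage here, not checked against the headers:

def find_max_divisor(n, max_allowed):
    best = 1
    for d in range(1, max_allowed + 1):
        if n % d == 0:
            best = d
    return best

def choose_c_per_block(num_channels, num_groups, switch_value=320, k_max_size=512):
    c_per_group = num_channels // num_groups
    c_per_block = switch_value  # value picked by the switch above (320 in the default case)
    if num_channels % c_per_block != 0 or c_per_block % c_per_group != 0:
        # Any divisor d of num_groups makes d * c_per_group divide num_channels,
        # so the fallback satisfies both constraints by construction.
        c_per_block = find_max_divisor(num_groups, k_max_size // c_per_group) * c_per_group
    return c_per_block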
@@ -578,6 +573,7 @@ Status LaunchGroupNormKernel(
params.groups = num_groups;
params.hw = params.h * params.w;

// This will allocate as many blocks as possible to partition HW.
constexpr int32_t maxBlocksPerHW = 1024;
const int32_t blocksPerHW = findMaxDivisor(params.hw, maxBlocksPerHW);
params.hwPerBlock = divUp(params.hw, blocksPerHW);
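A worked example of the HW partitioning, reusing find_max_divisor from the sketch above: for a 64x64 feature map, hw = 4096, and the largest divisor of 4096 not exceeding maxBlocksPerHW = 1024 is 1024, so each block covers 4 pixels:

def div_up(a, b):
    return (a + b - 1) // b

hw = 64 * 64                                 # 4096
blocks_per_hw = find_max_divisor(hw, 1024)   # 1024
hw_per_block = div_up(hw, blocks_per_hw)     # 4 pixels per block in the y-grid dimension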
@@ -587,9 +583,13 @@ Status LaunchGroupNormKernel(
params.hwc = params.hw * params.c;
params.invHWC = 1.F / (float)(params.hw * params.cPerGroup);
params.groupsPerBlock = cPerBlock / params.cPerGroup;
params.epsilon = epsilon;

// TODO: Update the kernel to support CHANNELS_PER_THREAD==1
if (cPerBlock > 512 || (params.cPerGroup % CHANNELS_PER_THREAD != 0)) {
// TODO: Update the kernel to support CHANNELS_PER_THREAD==1 and other corner cases

[cpplint warning at group_norm_impl.cu:588] Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
if (params.c % params.cPerBlock != 0 ||
params.cPerBlock % params.cPerGroup != 0 ||
cPerBlock > 512 ||
(params.cPerGroup % CHANNELS_PER_THREAD != 0)) {
printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d\n",
params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
params.cPerBlock, params.cPerGroup);
@@ -598,13 +598,13 @@ Status LaunchGroupNormKernel(

params.threadsPerBlock = nextSize(cPerBlock) / CHANNELS_PER_THREAD;

cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream);

// Make sure the values are as we expect.
ORT_ENFORCE(params.c % params.cPerBlock == 0);
#ifdef DUMP_GROUP_NORM
printf("n=%d h=%d w=%d c=%d groups=%d hw=%d hwPerBlock=%d cPerBlock=%d cPerGroup=%d threadsPerBlock=%d\n",
params.n, params.h, params.w, params.c, params.groups, params.hw, params.hwPerBlock,
params.cPerBlock, params.cPerGroup, params.threadsPerBlock);
#endif

// Make sure a group does not span multiple blocks.
ORT_ENFORCE(params.cPerBlock % params.cPerGroup == 0);
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(params.redBuffer, 0, GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups), stream));

[cpplint warning at group_norm_impl.cu:607] Lines should be <= 120 characters long [whitespace/line_length] [2]

groupNormNHWCSum<T>(params, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
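The memset above matters because the sum kernel accumulates into redBuffer with atomicAdd. A hedged sketch of the workspace size implied by the indexing (two float32 accumulators per (batch, group)); the real GetGroupNormWorkspaceSizeInBytes may add alignment or extra space:

def group_norm_workspace_bytes(batch_size, num_groups):
    return batch_size * num_groups * 2 * 4  # (sum, sum of squares) x 4 bytes per float32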
7 changes: 3 additions & 4 deletions onnxruntime/core/graph/contrib_ops/diffusion_defs.cc
@@ -111,7 +111,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"beta",
"1D beta tensor for normalization with shape (C), where C is number of channels",
"M")
.Input(3,
.Input(3,
"bias",
"Bias data tensor. Dimensions are (N x C), where N is the batch size and C is the number of channels",
"T")
@@ -123,7 +123,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.TypeConstraint("M", {"tensor(float16)", "tensor(float)"}, "Constrain gamma and beta to float tensors.")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput));


constexpr const char* SkipGroupNorm_ver1_doc = R"DOC(
This operator element-wise adds input x, skip and bias, then applies group normalization and an optional activation.
@@ -190,8 +189,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
propagateElemTypeFromInputToOutput(ctx, 0, 0);
propagateElemTypeFromInputToOutput(ctx, 0, 1);
if (hasNInputShapes(ctx, 1)) {
propagateShapeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 1);
propagateShapeFromInputToOutput(ctx, 0, 0);
propagateShapeFromInputToOutput(ctx, 0, 1);
}
}));

28 changes: 23 additions & 5 deletions onnxruntime/test/python/transformers/test_group_norm.py
@@ -180,13 +180,12 @@ def run_parity(config, measure_latency=True):
" G:",
config.num_groups,
" activation:",
config.activation,
int(config.activation),
" channels_last:",
config.channels_last,
int(config.channels_last),
" fp16:",
config.fp16,
" Latency(ms):",
latency * 1000 if isinstance(latency, float) else latency,
int(config.fp16),
f" Latency(ms): {latency * 1000}" if isinstance(latency, float) else "",
" AvgDiff:",
numpy.mean(numpy.abs(ort_result - torch_result)),
" Pass:",
@@ -250,6 +249,23 @@ def run_odd_channels(fp16, measure_latency=True):
run_parity(config, measure_latency=measure_latency)


def run_small_inputs():
# Test small values of N, H, W, C.
config = GroupNormConfig(2, 2, 2, 16, fp16=True, activation=False, num_groups=4)
run_parity(config, measure_latency=False)

config.fp16 = False
config.activation = True
run_parity(config, measure_latency=False)

config = GroupNormConfig(1, 1, 1, 64, fp16=True, activation=False)
run_parity(config, measure_latency=False)

config.fp16 = False
config.activation = True
run_parity(config, measure_latency=False)


def run_performance(fp16):
# Run perf test to tune parameters for given number of channels.
for h, w in get_latent_height_width()[2:3]:
@@ -261,6 +277,8 @@ def run_all():
def run_all():
run_performance(True)

run_small_inputs()

measure_latency = False
run_odd_channels(True, measure_latency=measure_latency)
run_odd_channels(False, measure_latency=measure_latency)
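For context on run_parity: the ORT output is checked against a PyTorch reference, which for NHWC inputs looks roughly like the sketch below; the actual helper in test_group_norm.py may differ:

import torch

def torch_group_norm_nhwc(x_nhwc, gamma, beta, num_groups, eps=1e-5, swish=False):
    x = x_nhwc.permute(0, 3, 1, 2)  # torch.nn.functional.group_norm expects NCHW
    y = torch.nn.functional.group_norm(x, num_groups, weight=gamma, bias=beta, eps=eps)
    if swish:
        y = y * torch.sigmoid(y)
    return y.permute(0, 2, 3, 1)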
