
Add SkipGroupNorm and BiasGroupNorm
tianleiwu committed Oct 26, 2023
1 parent 1fae3d3 commit 3123d66
Showing 8 changed files with 313 additions and 45 deletions.
4 changes: 4 additions & 0 deletions onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -16,6 +16,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd);
@@ -95,6 +96,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Samp
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
@@ -179,6 +181,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, Gelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGelu)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BiasGroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, BiasSplitGelu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, BiasAdd)>,
@@ -254,6 +257,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, SkipGroupNorm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, SkipLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
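
The file above follows ONNX Runtime's two-step registration idiom: each ONNX_OPERATOR_KERNEL_CLASS_NAME line forward-declares a uniquely named kernel class for an (execution provider, domain, opset version, op) tuple, and the matching BuildKernelCreateInfo entry later hands that kernel's factory to the provider's registry, which is how the new SkipGroupNorm and BiasGroupNorm kernels become discoverable by name. A minimal self-contained sketch of the same idiom — the OpKernel/KernelRegistry types below are stand-ins invented for this note, not ORT's actual classes:

#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>

// Hypothetical stand-ins illustrating the (domain, op, version) -> factory idiom.
struct OpKernel {
  virtual ~OpKernel() = default;
  virtual void Compute() = 0;
};

struct KernelRegistry {
  using Factory = std::function<std::unique_ptr<OpKernel>()>;
  // Key is "domain:op:since_version", mirroring how kernels are looked up.
  std::map<std::string, Factory> factories;

  void Register(const std::string& domain, const std::string& op, int version, Factory f) {
    factories[domain + ":" + op + ":" + std::to_string(version)] = std::move(f);
  }
};

struct SkipGroupNormKernel : OpKernel {
  void Compute() override { std::cout << "SkipGroupNorm\n"; }
};

int main() {
  KernelRegistry registry;
  // Equivalent in spirit to the BuildKernelCreateInfo<...SkipGroupNorm> table entry.
  registry.Register("com.microsoft", "SkipGroupNorm", 1,
                    [] { return std::make_unique<SkipGroupNormKernel>(); });
  registry.factories["com.microsoft:SkipGroupNorm:1"]()->Compute();
  return 0;
}

The payoff of the indirection is that a kernel is resolved at session initialization purely by (domain, op, version), so adding an operator touches only a declaration and a table entry.
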
40 changes: 35 additions & 5 deletions onnxruntime/contrib_ops/cuda/diffusion/group_norm.cc
@@ -13,7 +13,15 @@ namespace cuda {

ONNX_OPERATOR_KERNEL_EX(
    GroupNorm, kMSDomain, 1, kCudaExecutionProvider,
-    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm);
+    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm<GroupNormOp>);
+
+ONNX_OPERATOR_KERNEL_EX(
+    SkipGroupNorm, kMSDomain, 1, kCudaExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm<SkipGroupNormOp>);
+
+ONNX_OPERATOR_KERNEL_EX(
+    BiasGroupNorm, kMSDomain, 1, kCudaExecutionProvider,
+    (*KernelDefBuilder::Create()).TypeConstraint("T", BuildKernelDefConstraints<GROUP_NORM_TYPES>()), GroupNorm<BiasGroupNormOp>);

using namespace ONNX_NAMESPACE;
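
All three registrations share the templated kernel implemented below; the variants differ only in what is summed into the tensor before normalization. With per-group statistics mu_g and sigma_g^2 computed over each group of channels, standard GroupNorm is

    y_c = gamma_c * (x_c - mu_g(c)) / sqrt(sigma_g(c)^2 + epsilon) + beta_c

and, as a hedged summary implied by the extra inputs wired up in ComputeInternal below (the authoritative contracts live in the contrib-op spec, which this diff does not show):

    SkipGroupNorm: s = x + skip + bias, y = GroupNorm(s), with s optionally returned as a second output (add_out)
    BiasGroupNorm: y = GroupNorm(x + bias)
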

@@ -22,7 +30,10 @@ template <typename T>
struct DispatchGroupNorm {
  Status operator()(cudaStream_t stream,
                    Tensor* output,
+                   Tensor* add_out,
                    const Tensor* input,
+                   const Tensor* skip,
+                   const Tensor* bias,
                    const Tensor* gamma,
                    const Tensor* beta,
                    void* workspace,
@@ -37,7 +48,10 @@ struct DispatchGroupNorm {
    return LaunchGroupNormKernel<CudaT>(
        stream,
        reinterpret_cast<CudaT*>(output->MutableData<T>()),
+        add_out == nullptr ? nullptr : reinterpret_cast<CudaT*>(add_out->MutableData<T>()),
        reinterpret_cast<const CudaT*>(input->Data<T>()),
+        skip == nullptr ? nullptr : reinterpret_cast<const CudaT*>(skip->Data<T>()),
+        bias == nullptr ? nullptr : reinterpret_cast<const CudaT*>(bias->Data<T>()),
        gamma->Data<float>(),
        beta->Data<float>(),
        workspace,
@@ -53,7 +67,8 @@ struct DispatchGroupNorm {

}  // namespace

-GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
+template<GroupNormOperatorType T>
+GroupNorm<T>::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
  epsilon_ = op_info.GetAttrOrDefault<float>("epsilon", 1e-5f);
  ORT_ENFORCE(epsilon_ >= 0);

@@ -70,7 +85,8 @@ GroupNorm::GroupNorm(const OpKernelInfo& op_info) : CudaKernel(op_info) {
  channels_last_ = (op_info.GetAttrOrDefault<int64_t>("channels_last", static_cast<int64_t>(1)) != 0);
}

-Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
+template<GroupNormOperatorType T>
+Status GroupNorm<T>::ComputeInternal(OpKernelContext* context) const {
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* gamma = context->Input<Tensor>(1);
  const Tensor* beta = context->Input<Tensor>(2);
@@ -125,10 +141,24 @@ Status GroupNorm::ComputeInternal(OpKernelContext* context) const {
    });
  }

-  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(), context->GetComputeStream());
+  const Tensor* skip = nullptr;
+  const Tensor* bias = nullptr;
+  Tensor* add_out = nullptr;
+
+  if (T == SkipGroupNormOp) {
+    bias = context->Input<Tensor>(3);
+    skip = context->Input<Tensor>(4);
+    add_out = context->Output(1, input->Shape());
+  } else if (T == BiasGroupNormOp) {
+    bias = context->Input<Tensor>(3);
+  }
+
+  auto workspace = GetScratchBuffer<void>(GetGroupNormWorkspaceSizeInBytes(batch_size, num_groups_),
+                                          context->GetComputeStream());

  utils::MLTypeCallDispatcher<GROUP_NORM_TYPES> dispatcher(input->GetElementType());
-  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, input, gamma, beta, workspace.get(),
+  return dispatcher.InvokeRet<Status, DispatchGroupNorm>(Stream(context), output, add_out, input, skip, bias,
+                                                         gamma, beta, workspace.get(),
                                                         epsilon_,
                                                         batch_size,
                                                         num_channels,
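
To make the fused computation concrete, here is a plain-C++ reference sketch of what the SkipGroupNorm path computes for a channels-last (NHWC) float tensor: the elementwise x + skip + bias sum, the optional add_out copy of that sum, and per-group normalization. This is an illustrative reimplementation written for this note (the function name and signature are invented), not the CUDA kernel the commit dispatches to, and it omits the optional swish activation controlled by use_swish_activation_:

#include <cmath>
#include <vector>

// Unoptimized SkipGroupNorm reference for NHWC float tensors.
// s = x + skip + bias (bias broadcast over N, H, W); y = GroupNorm(s).
// If add_out is non-null, the pre-normalization sum s is also written out,
// matching the optional second output added in this commit.
void SkipGroupNormRef(const std::vector<float>& x,
                      const std::vector<float>& skip,
                      const std::vector<float>& bias,   // size C
                      const std::vector<float>& gamma,  // size C
                      const std::vector<float>& beta,   // size C
                      int n, int h, int w, int c, int num_groups, float epsilon,
                      std::vector<float>& y, std::vector<float>* add_out) {
  const int hw = h * w;
  const int channels_per_group = c / num_groups;
  std::vector<float> s(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    s[i] = x[i] + skip[i] + bias[i % c];  // NHWC: flat index i has channel i % c
  }
  if (add_out != nullptr) *add_out = s;

  y.resize(s.size());
  for (int b = 0; b < n; ++b) {
    for (int g = 0; g < num_groups; ++g) {
      // Mean and variance over all H*W positions of this group's channels.
      double sum = 0.0, sq_sum = 0.0;
      for (int p = 0; p < hw; ++p) {
        for (int k = 0; k < channels_per_group; ++k) {
          const int ch = g * channels_per_group + k;
          const float v = s[(static_cast<size_t>(b) * hw + p) * c + ch];
          sum += v;
          sq_sum += static_cast<double>(v) * v;
        }
      }
      const double count = static_cast<double>(hw) * channels_per_group;
      const double mean = sum / count;
      const double var = sq_sum / count - mean * mean;
      const float inv_std = 1.0f / std::sqrt(static_cast<float>(var) + epsilon);
      for (int p = 0; p < hw; ++p) {
        for (int k = 0; k < channels_per_group; ++k) {
          const int ch = g * channels_per_group + k;
          const size_t idx = (static_cast<size_t>(b) * hw + p) * c + ch;
          y[idx] = gamma[ch] * ((s[idx] - static_cast<float>(mean)) * inv_std) + beta[ch];
        }
      }
    }
  }
}
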
9 changes: 8 additions & 1 deletion onnxruntime/contrib_ops/cuda/diffusion/group_norm.h
@@ -11,12 +11,19 @@ namespace cuda {

using namespace onnxruntime::cuda;

+enum GroupNormOperatorType {
+  GroupNormOp,
+  SkipGroupNormOp,
+  BiasGroupNormOp
+};
+
+template<GroupNormOperatorType opType>
class GroupNorm final : public CudaKernel {
 public:
  GroupNorm(const OpKernelInfo& op_kernel_info);
  Status ComputeInternal(OpKernelContext* context) const override;

- private:
+ protected:
  bool use_swish_activation_;
  float epsilon_;
  int num_groups_;
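
The header change is the heart of the commit: a non-type template parameter lets a single GroupNorm class body serve all three operators, and branches such as if (T == SkipGroupNormOp) in ComputeInternal compare a compile-time constant, so the compiler can fold the dead branches away (under C++17, if constexpr would make that explicit). A standalone sketch of the technique — the Kernel struct here is hypothetical, while the enum and the input counts mirror the inputs read in ComputeInternal above:

#include <iostream>

enum GroupNormOperatorType { GroupNormOp, SkipGroupNormOp, BiasGroupNormOp };

template <GroupNormOperatorType T>
struct Kernel {
  // T is a compile-time constant, so each instantiation keeps only one branch.
  int ExpectedInputCount() const {
    if (T == SkipGroupNormOp) return 5;  // input, gamma, beta, bias, skip
    if (T == BiasGroupNormOp) return 4;  // input, gamma, beta, bias
    return 3;                            // input, gamma, beta
  }
};

int main() {
  std::cout << Kernel<GroupNormOp>{}.ExpectedInputCount() << "\n"      // 3
            << Kernel<SkipGroupNormOp>{}.ExpectedInputCount() << "\n"  // 5
            << Kernel<BiasGroupNormOp>{}.ExpectedInputCount() << "\n"; // 4
  return 0;
}
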