From d637111e9f3f816a81a8bfc27d7f183b824fc6a4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 10 Oct 2023 11:07:27 -0700 Subject: [PATCH] [CUDA/ROCm] Update BiasSplitGelu for SD XL Refiner model (#17849) SD XL Refiner model has new hidden dimension sizes not supported by BiasSplitGelu. This update the kernel to support them. ### Motivation and Context Current BiasSplitGelu does not support optimization for SD XL refiner model. --- .../contrib_ops/cuda/diffusion/bias_split_gelu.cc | 8 ++++++-- .../cuda/diffusion/bias_split_gelu_impl.cu | 10 ++++++++++ .../test/contrib_ops/bias_split_gelu_op_test.cc | 14 ++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc index 2b13cdbd803ef..cb02bd8541623 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc @@ -39,9 +39,13 @@ Status BiasSplitGelu::ComputeInternal(OpKernelContext* context) const { "input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) { + if (input_dims[2] != 2560 && + input_dims[2] != 5120 && + input_dims[2] != 6144 && + input_dims[2] != 10240 && + input_dims[2] != 12288) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]); + "hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]); } const Tensor* bias = context->Input(1); diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu index 19e05a9573f7c..3ae9611d4dfad 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu @@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h case 5120: (biasSplitGeluKernel)<<>>(input, bias, output); break; + case 3072: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; + case 6144: + (biasSplitGeluKernel)<<>>(input, bias, output); + break; default: ORT_NOT_IMPLEMENTED("Not implemented"); } @@ -73,9 +79,13 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); +template __global__ void biasSplitGeluKernel(float const*, float const*, float*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); +template __global__ void biasSplitGeluKernel(half const*, half const*, half*); template void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size, float const* input, float const* bias, float* output); diff --git a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc index db14eb3da42cd..a979717d23573 100644 --- a/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc @@ -152,6 +152,20 @@ TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_10240) { RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); } +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_6144) { + constexpr int64_t batch_size = 2; + constexpr int64_t sequence_length = 3; + constexpr int64_t hidden_size = 6144; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + +TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_12288) { + constexpr int64_t batch_size = 1; + constexpr int64_t sequence_length = 2; + constexpr int64_t hidden_size = 12288; + RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size); +} + #endif } // namespace test