From ff9fa2362d9db31849247a1edaa97259a4eb99fe Mon Sep 17 00:00:00 2001
From: Tianlei Wu
Date: Mon, 9 Oct 2023 19:07:45 +0000
Subject: [PATCH] Update BiasSplitGelu for SDXL Refiner

---
 .../contrib_ops/cuda/diffusion/bias_split_gelu.cc      |  8 ++++++--
 .../contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu | 10 ++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
index 2b13cdbd803ef..cb02bd8541623 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
@@ -39,9 +39,13 @@ Status BiasSplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
                            "input is expected to have 3 dimensions, got ", input_dims.size());
   }
 
-  if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) {
+  if (input_dims[2] != 2560 &&
+      input_dims[2] != 5120 &&
+      input_dims[2] != 6144 &&
+      input_dims[2] != 10240 &&
+      input_dims[2] != 12288) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]);
+                           "hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]);
   }
 
   const Tensor* bias = context->Input<Tensor>(1);
diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
index 19e05a9573f7c..3ae9611d4dfad 100644
--- a/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h
     case 5120:
       (biasSplitGeluKernel<T, 5120, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
       break;
+    case 3072:
+      (biasSplitGeluKernel<T, 3072, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
+    case 6144:
+      (biasSplitGeluKernel<T, 6144, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
     default:
       ORT_NOT_IMPLEMENTED("Not implemented");
   }
@@ -73,9 +79,13 @@
 template __global__ void biasSplitGeluKernel<float, 1280, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<float, 2560, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<float, 5120, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 3072, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 6144, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<half, 1280, 256>(half const*, half const*, half*);
 template __global__ void biasSplitGeluKernel<half, 2560, 256>(half const*, half const*, half*);
 template __global__ void biasSplitGeluKernel<half, 5120, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 3072, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 6144, 256>(half const*, half const*, half*);
 
 template void LaunchBiasSplitGeluKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
                                                float const* input, float const* bias, float* output);
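
For reference, BiasSplitGelu adds the bias to the input, splits the last (hidden) dimension into two halves, and multiplies the left half by GELU of the right half, which is why the new hidden sizes 6144 and 12288 correspond to the new half_hidden_size cases 3072 and 6144. Below is a minimal, self-contained CUDA sketch of that computation, not the kernel changed in this patch: the BiasSplitGeluSketch name and the runtime hhs argument are illustrative only, whereas the operator above compiles each supported half hidden size into its own template instantiation.

// Hypothetical standalone sketch (not the ONNX Runtime kernel): one thread
// block per row of the [batch, seq, hidden] input; hhs = hidden / 2.
#include <cuda_runtime.h>
#include <cmath>
#include <cstdio>

// y[row][i] = (x[row][i] + bias[i]) * GELU(x[row][i + hhs] + bias[i + hhs])
__global__ void BiasSplitGeluSketch(const float* x, const float* bias, float* y, int hhs) {
  int row = blockIdx.x;                 // one block per (batch * seq) row
  const float* in = x + row * hhs * 2;  // each input row holds 2 * hhs values
  float* out = y + row * hhs;           // each output row holds hhs values
  for (int i = threadIdx.x; i < hhs; i += blockDim.x) {
    float left = in[i] + bias[i];
    float right = in[i + hhs] + bias[i + hhs];
    // erf-based GELU: g(v) = v * 0.5 * (1 + erf(v / sqrt(2)))
    float gelu_right = right * 0.5f * (1.0f + erff(right * 0.70710678f));
    out[i] = left * gelu_right;
  }
}

int main() {
  // Example: hidden size 12288 (one of the sizes added above) -> hhs = 6144.
  const int rows = 2, hidden = 12288, hhs = hidden / 2;
  float *x, *bias, *y;
  cudaMallocManaged(&x, rows * hidden * sizeof(float));
  cudaMallocManaged(&bias, hidden * sizeof(float));
  cudaMallocManaged(&y, rows * hhs * sizeof(float));
  for (int i = 0; i < rows * hidden; ++i) x[i] = 0.01f * (i % 100);
  for (int i = 0; i < hidden; ++i) bias[i] = 0.001f * i;

  BiasSplitGeluSketch<<<rows, 256>>>(x, bias, y, hhs);  // grid = rows, 256 threads per block
  cudaDeviceSynchronize();
  printf("y[0] = %f\n", y[0]);

  cudaFree(x);
  cudaFree(bias);
  cudaFree(y);
  return 0;
}

In the sketch, hhs is a runtime argument and the per-row loop strides by blockDim.x; the operator instead fixes the half hidden size as a template parameter, presumably so the inner loop can be fully unrolled, and grid_size is expected to be the number of rows (batch times sequence length).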