From c4354297ea6bad1caf82f4b8514cccac2eb1a35b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 10 Oct 2023 20:08:45 -0700 Subject: [PATCH] [CUDA/ROCm] Remove limitation of BiasAdd (#17848) Previously, BiasAdd only supported hidden dimensions of 320, 640 and 1280 for stable diffusion. This adds a kernel that could support any number of channels. ### Motivation and Context Stable Diffusion XL refiner model uses hidden dimensions of 768 or 1536, which were not supported in BiasAdd. --- .../contrib_ops/cuda/diffusion/bias_add.cc | 5 ----- .../cuda/diffusion/bias_add_impl.cu | 21 ++++++++++++++----- .../test/contrib_ops/bias_add_op_test.cc | 14 +++++++++++++ 3 files changed, 30 insertions(+), 10 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc index a38dfd34cc977..274bc9a730d87 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc @@ -44,11 +44,6 @@ Status BiasAdd::ComputeInternal(OpKernelContext* context) const { "The input is expected to have 3 dimensions, got ", input_dims.size()); } - if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Number of channels should be 320, 640 or 1280, got ", input_dims[2]); - } - const Tensor* bias = context->Input(1); const auto& bias_dims = bias->Shape().GetDims(); if (bias_dims.size() != 1) { diff --git a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu index 2983cc99e30b1..8e8068b5e56ca 100644 --- a/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu +++ b/onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu @@ -42,6 +42,17 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual, } } +template +__global__ void BiasAddLargeKernel( + int32_t const ld, const T* input, const T* bias, const T* residual, T* 
output) { + int32_t const offset = blockIdx.x * ld; + + for (int32_t i = threadIdx.x; i < ld; i += TPB) { + int32_t const base_offset = offset + i; + output[base_offset] = input[base_offset] + bias[i] + residual[base_offset]; + } +} + template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); template __global__ void BiasAddKernel(float const*, float const*, float const*, float*); @@ -52,19 +63,19 @@ template __global__ void BiasAddKernel(half const*, half const* template void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels, T const* input, T const* bias, T const* residual, T* output) { - constexpr int32_t TPB = 320; // thread per block switch (num_channels) { case 320: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 640: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; case 1280: - (BiasAddKernel)<<>>(input, bias, residual, output); + (BiasAddKernel)<<>>(input, bias, residual, output); break; default: - ORT_NOT_IMPLEMENTED("Not implemented"); + BiasAddLargeKernel<<>>(num_channels, input, bias, residual, output); + break; } } diff --git a/onnxruntime/test/contrib_ops/bias_add_op_test.cc b/onnxruntime/test/contrib_ops/bias_add_op_test.cc index 7699f4479caa7..6fd091ef66110 100644 --- a/onnxruntime/test/contrib_ops/bias_add_op_test.cc +++ b/onnxruntime/test/contrib_ops/bias_add_op_test.cc @@ -107,6 +107,20 @@ TEST(BiasAddTest, BiasAddTest_HiddenSize_1280) { constexpr int64_t num_channels = 1280; RunBiasAddTest(batch_size, image_size, num_channels); } + +TEST(BiasAddTest, BiasAddTest_HiddenSize_768) { + constexpr int64_t batch_size = 2; + constexpr int64_t image_size = 5; + constexpr int64_t num_channels = 768; + RunBiasAddTest(batch_size, image_size, num_channels); +} + 
+TEST(BiasAddTest, BiasAddTest_HiddenSize_1536) { + constexpr int64_t batch_size = 1; + constexpr int64_t image_size = 3; + constexpr int64_t num_channels = 1536; + RunBiasAddTest(batch_size, image_size, num_channels); +} #endif } // namespace test