Skip to content

Commit

Permalink
[CUDA/ROCm] Remove limitation of BiasAdd (microsoft#17848)
Browse files Browse the repository at this point in the history
Previously, BiasAdd only supported hidden dimensions of 320, 640 and 1280
for stable diffusion. This adds a kernel that can support any number
of channels.

### Motivation and Context
Stable Diffusion XL refiner model uses hidden dimensions of 768 or 1536,
which was not supported in BiasAdd.
  • Loading branch information
tianleiwu authored and kleiti committed Mar 22, 2024
1 parent 09f7d2d commit c435429
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 10 deletions.
5 changes: 0 additions & 5 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ Status BiasAdd<T>::ComputeInternal(OpKernelContext* context) const {
"The input is expected to have 3 dimensions, got ", input_dims.size());
}

if (input_dims[2] != 320 && input_dims[2] != 640 && input_dims[2] != 1280) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Number of channels should be 320, 640 or 1280, got ", input_dims[2]);
}

const Tensor* bias = context->Input<Tensor>(1);
const auto& bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
Expand Down
21 changes: 16 additions & 5 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_add_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,17 @@ __global__ void BiasAddKernel(T const* input, T const* bias, T const* residual,
}
}

// Fallback bias-add kernel for channel counts that have no specialized
// BiasAddKernel instantiation. Supports any number of channels (ld).
//
// Expected launch layout: one block per "row" (grid.x = number of rows),
// TPB threads per block; each thread strides across the channel dimension.
// Computes, element-wise: output[row, c] = input[row, c] + bias[c] + residual[row, c].
template <typename T, unsigned TPB>
__global__ void BiasAddLargeKernel(
    int32_t const ld, const T* input, const T* bias, const T* residual, T* output) {
  // Compute the row base offset in 64-bit: blockIdx.x * ld overflows int32_t
  // when the tensor has more than 2^31 elements.
  int64_t const offset = static_cast<int64_t>(blockIdx.x) * ld;

  for (int32_t i = threadIdx.x; i < ld; i += TPB) {
    int64_t const base_offset = offset + i;
    output[base_offset] = input[base_offset] + bias[i] + residual[base_offset];
  }
}

template __global__ void BiasAddKernel<float, 320, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 640, 320>(float const*, float const*, float const*, float*);
template __global__ void BiasAddKernel<float, 1280, 320>(float const*, float const*, float const*, float*);
Expand All @@ -52,19 +63,19 @@ template __global__ void BiasAddKernel<half, 1280, 320>(half const*, half const*
// Dispatches the bias-add computation on the given stream.
//
// For the channel counts used by stable diffusion (320, 640, 1280) a
// specialized kernel with a compile-time channel count is launched; any other
// channel count falls back to the generic BiasAddLargeKernel.
//
// grid_size:    number of rows (one thread block per row).
// num_channels: channels per row (the contiguous innermost dimension).
template <typename T>
void LaunchBiasAddKernel(cudaStream_t stream, int32_t grid_size, int32_t num_channels,
                         T const* input, T const* bias, T const* residual, T* output) {
  switch (num_channels) {
    case 320:
      (BiasAddKernel<T, 320, 320>)<<<grid_size, 320, 0, stream>>>(input, bias, residual, output);
      break;
    case 640:
      (BiasAddKernel<T, 640, 320>)<<<grid_size, 320, 0, stream>>>(input, bias, residual, output);
      break;
    case 1280:
      (BiasAddKernel<T, 1280, 320>)<<<grid_size, 320, 0, stream>>>(input, bias, residual, output);
      break;
    default:
      // Generic path: 256 threads per block, each striding over num_channels.
      BiasAddLargeKernel<T, 256><<<grid_size, 256, 0, stream>>>(num_channels, input, bias, residual, output);
      break;
  }
}

Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/test/contrib_ops/bias_add_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ TEST(BiasAddTest, BiasAddTest_HiddenSize_1280) {
constexpr int64_t num_channels = 1280;
RunBiasAddTest(batch_size, image_size, num_channels);
}

// 768 channels (SD XL refiner) is not a specialized kernel size, so this
// exercises the generic fallback path.
TEST(BiasAddTest, BiasAddTest_HiddenSize_768) {
  RunBiasAddTest(/*batch_size=*/2, /*image_size=*/5, /*num_channels=*/768);
}

// 1536 channels (SD XL refiner) is not a specialized kernel size, so this
// exercises the generic fallback path.
TEST(BiasAddTest, BiasAddTest_HiddenSize_1536) {
  RunBiasAddTest(/*batch_size=*/1, /*image_size=*/3, /*num_channels=*/1536);
}
#endif

} // namespace test
Expand Down

0 comments on commit c435429

Please sign in to comment.