Skip to content

Commit

Permalink
[CUDA/ROCm] Update BiasSplitGelu for SD XL Refiner model (#17849)
Browse files Browse the repository at this point in the history
SD XL Refiner model has new hidden dimension sizes not supported by BiasSplitGelu. This update the kernel to support them.

### Motivation and Context
Current BiasSplitGelu does not support optimization for SD XL refiner model.
  • Loading branch information
tianleiwu authored Oct 10, 2023
1 parent 9a1c884 commit d637111
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,13 @@ Status BiasSplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
"input is expected to have 3 dimensions, got ", input_dims.size());
}

if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) {
if (input_dims[2] != 2560 &&
input_dims[2] != 5120 &&
input_dims[2] != 6144 &&
input_dims[2] != 10240 &&
input_dims[2] != 12288) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"hidden size should be 2560, 5120 or 10240, got ", input_dims[2]);
"hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]);
}

const Tensor* bias = context->Input<Tensor>(1);
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h
case 5120:
(biasSplitGeluKernel<T, 5120, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
break;
case 3072:
(biasSplitGeluKernel<T, 3072, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
break;
case 6144:
(biasSplitGeluKernel<T, 6144, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
break;
default:
ORT_NOT_IMPLEMENTED("Not implemented");
}
Expand All @@ -73,9 +79,13 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t h
template __global__ void biasSplitGeluKernel<float, 1280, 256>(float const*, float const*, float*);
template __global__ void biasSplitGeluKernel<float, 2560, 256>(float const*, float const*, float*);
template __global__ void biasSplitGeluKernel<float, 5120, 256>(float const*, float const*, float*);
template __global__ void biasSplitGeluKernel<float, 3072, 256>(float const*, float const*, float*);
template __global__ void biasSplitGeluKernel<float, 6144, 256>(float const*, float const*, float*);
template __global__ void biasSplitGeluKernel<half, 1280, 256>(half const*, half const*, half*);
template __global__ void biasSplitGeluKernel<half, 2560, 256>(half const*, half const*, half*);
template __global__ void biasSplitGeluKernel<half, 5120, 256>(half const*, half const*, half*);
template __global__ void biasSplitGeluKernel<half, 3072, 256>(half const*, half const*, half*);
template __global__ void biasSplitGeluKernel<half, 6144, 256>(half const*, half const*, half*);

template void LaunchBiasSplitGeluKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
float const* input, float const* bias, float* output);
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/test/contrib_ops/bias_split_gelu_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,20 @@ TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_10240) {
RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size);
}

TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_6144) {
constexpr int64_t batch_size = 2;
constexpr int64_t sequence_length = 3;
constexpr int64_t hidden_size = 6144;
RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size);
}

TEST(BiasSplitGeluTest, BiasSplitGeluTest_HiddenSize_12288) {
constexpr int64_t batch_size = 1;
constexpr int64_t sequence_length = 2;
constexpr int64_t hidden_size = 12288;
RunBiasSplitGeluTest(batch_size, sequence_length, hidden_size);
}

#endif

} // namespace test
Expand Down

0 comments on commit d637111

Please sign in to comment.