Update BiasSplitGelu for SDXL Refiner
tianleiwu committed Oct 9, 2023
1 parent ba72bb6 commit ff9fa23
Showing 2 changed files with 16 additions and 2 deletions.
8 changes: 6 additions & 2 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu.cc
@@ -39,9 +39,13 @@ Status BiasSplitGelu<T>::ComputeInternal(OpKernelContext* context) const {
                        "input is expected to have 3 dimensions, got ", input_dims.size());
   }
 
-  if (input_dims[2] != 2560 && input_dims[2] != 5120 && input_dims[2] != 10240) {
+  if (input_dims[2] != 2560 &&
+      input_dims[2] != 5120 &&
+      input_dims[2] != 6144 &&
+      input_dims[2] != 10240 &&
+      input_dims[2] != 12288) {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
-                           "hidden size should be 2560, 5120 or 10240, got ", input_dims[2]);
+                           "hidden size should be 2560, 5120, 6144, 10240 or 12288, got ", input_dims[2]);
   }
 
   const Tensor* bias = context->Input<Tensor>(1);
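Note on the sizes: BiasSplitGelu implements the bias-add plus GEGLU split used in the Stable Diffusion UNet feed-forward. The last dimension of size 2*d is split into two halves, the second half goes through GeLU, and the two halves are multiplied elementwise, so the output has size d. The operator-level check above therefore validates the full hidden size (6144 and 12288 for the SDXL Refiner), while the kernel switch in the next file dispatches on the half size (3072 and 6144). Below is a minimal host-side reference sketch of this semantics, assuming the erf-based GeLU; it is an illustration with hypothetical names, not the ONNX Runtime implementation:

#include <cmath>
#include <cstdint>

// Exact (erf-based) GeLU; the ORT kernel may use a different variant internally.
static float Gelu(float x) {
  return 0.5f * x * (1.0f + erff(x / sqrtf(2.0f)));
}

// input: (rows, 2*d) row-major, bias: (2*d,), output: (rows, d),
// where rows = batch * sequence length and 2*d is the hidden size checked above.
void BiasSplitGeluRef(const float* input, const float* bias, float* output,
                      int64_t rows, int64_t hidden_size) {
  const int64_t d = hidden_size / 2;  // e.g. 6144 -> 3072, 12288 -> 6144
  for (int64_t r = 0; r < rows; ++r) {
    const float* x = input + r * hidden_size;
    float* y = output + r * d;
    for (int64_t i = 0; i < d; ++i) {
      y[i] = (x[i] + bias[i]) * Gelu(x[i + d] + bias[i + d]);
    }
  }
}

With hidden_size = 12288, d = 6144, which matches the new case 6144 added to the kernel switch below.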
10 changes: 10 additions & 0 deletions onnxruntime/contrib_ops/cuda/diffusion/bias_split_gelu_impl.cu
@@ -65,6 +65,12 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
     case 5120:
       (biasSplitGeluKernel<T, 5120, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
       break;
+    case 3072:
+      (biasSplitGeluKernel<T, 3072, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
+    case 6144:
+      (biasSplitGeluKernel<T, 6144, TPB>)<<<grid_size, TPB, 0, stream>>>(input, bias, output);
+      break;
     default:
       ORT_NOT_IMPLEMENTED("Not implemented");
   }
@@ -73,9 +79,13 @@ void LaunchBiasSplitGeluKernel(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
 template __global__ void biasSplitGeluKernel<float, 1280, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<float, 2560, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<float, 5120, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 3072, 256>(float const*, float const*, float*);
+template __global__ void biasSplitGeluKernel<float, 6144, 256>(float const*, float const*, float*);
 template __global__ void biasSplitGeluKernel<half, 1280, 256>(half const*, half const*, half*);
 template __global__ void biasSplitGeluKernel<half, 2560, 256>(half const*, half const*, half*);
 template __global__ void biasSplitGeluKernel<half, 5120, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 3072, 256>(half const*, half const*, half*);
+template __global__ void biasSplitGeluKernel<half, 6144, 256>(half const*, half const*, half*);
 
 template void LaunchBiasSplitGeluKernel<float>(cudaStream_t stream, int32_t grid_size, int32_t half_hidden_size,
                                                float const* input, float const* bias, float* output);
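For reference, each case above passes the half hidden size as a compile-time template argument (with TPB = 256 threads per block), and grid_size presumably counts output rows (batch * sequence length), matching the LaunchBiasSplitGeluKernel signature shown in the diff. A plausible device-side sketch of that dispatch pattern, assuming one block per row and threads striding over the half hidden dimension; this is my illustration of the pattern, not the actual biasSplitGeluKernel body:

#include <cstdint>
#include <cuda_fp16.h>

// Sketch: one block per (batch, seq) row; each thread strides over HHS in steps of TPB.
template <typename T, int32_t HHS, int32_t TPB>
__global__ void biasSplitGeluSketch(T const* input, T const* bias, T* output) {
  const int64_t in_base = static_cast<int64_t>(blockIdx.x) * HHS * 2;  // row start in (rows, 2*HHS) input
  const int64_t out_base = static_cast<int64_t>(blockIdx.x) * HHS;     // row start in (rows, HHS) output
  for (int32_t i = threadIdx.x; i < HHS; i += TPB) {
    const float a = static_cast<float>(input[in_base + i]) + static_cast<float>(bias[i]);
    const float b = static_cast<float>(input[in_base + HHS + i]) + static_cast<float>(bias[HHS + i]);
    const float gelu_b = 0.5f * b * (1.0f + erff(b * 0.70710678118f));  // erf-based GeLU(b)
    output[out_base + i] = static_cast<T>(a * gelu_b);
  }
}

Baking the half hidden size in as a template parameter lets the compiler fully unroll or vectorize the per-row loop for each supported size, which is why every new size needs both a switch case and explicit template instantiations.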
