diff --git a/onnxruntime/core/providers/cuda/tensor/split.cc b/onnxruntime/core/providers/cuda/tensor/split.cc
index 5f73512ab8696..ca82387600085 100644
--- a/onnxruntime/core/providers/cuda/tensor/split.cc
+++ b/onnxruntime/core/providers/cuda/tensor/split.cc
@@ -76,6 +76,31 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const {
   auto input_dims = input_shape.GetDims();
   auto output_dimensions{input_shape.AsShapeVector()};
 
+  if (split_sizes.size() == 3 && ((axis + 1) == gsl::narrow_cast<int64_t>(input_shape.NumDimensions()))) {
+    // (axis + 1) == num_dimensions means we are splitting on the innermost axis.
+    // Split3Inner can only be used when splitting on the innermost axis into exactly 3 outputs.
+    // It does not use pinned memory, so it is safe to use with CUDA graphs.
+    output_dimensions[axis] = split_sizes[0];
+    Tensor* output0 = ctx->Output(0, TensorShape{output_dimensions});
+    output_dimensions[axis] = split_sizes[1];
+    Tensor* output1 = ctx->Output(1, TensorShape{output_dimensions});
+    output_dimensions[axis] = split_sizes[2];
+    Tensor* output2 = ctx->Output(2, TensorShape{output_dimensions});
+
+    // If the input tensor is empty there is no kernel to launch, but the output tensors still need to be set.
+    if (input_tensor->Shape().Size() <= 0) return Status::OK();
+
+    return Split3Inner(Stream(ctx),
+                       input_tensor->DataType()->Size(),
+                       split_sizes[0], split_sizes[1],
+                       split_sizes[2],
+                       input_tensor->DataRaw(),
+                       output0->MutableDataRaw(),
+                       output1->MutableDataRaw(),
+                       output2->MutableDataRaw(),
+                       input_dims);
+  }
+
   CudaAsyncBuffer<void*> output_ptr(this, num_outputs);
   gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
   TensorShapeVector axis_dimension_input_output_mapping(input_dims[axis]);
diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.cu b/onnxruntime/core/providers/cuda/tensor/split_impl.cu
index b0ff856a43970..00f94694f83c0 100644
--- a/onnxruntime/core/providers/cuda/tensor/split_impl.cu
+++ b/onnxruntime/core/providers/cuda/tensor/split_impl.cu
@@ -157,5 +157,112 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
   return Status::OK();
 }
 
+template <typename T>
+__global__ void _Split3InnerKernel(const int64_t size0_in_byte,
+                                   const int64_t size1_in_byte,
+                                   const int64_t size2_in_byte,
+                                   const void* input_data,
+                                   void* output_data0,
+                                   void* output_data1,
+                                   void* output_data2,
+                                   const int64_t inner_size_in_byte) {
+  // each block copies one row of input data
+  auto size0 = size0_in_byte / sizeof(T);
+  auto size1 = size1_in_byte / sizeof(T);
+  auto size2 = size2_in_byte / sizeof(T);
+  auto inner_size = inner_size_in_byte / sizeof(T);
+  auto output0_vec = reinterpret_cast<T*>(output_data0) + blockIdx.x * size0;
+  auto output1_vec = reinterpret_cast<T*>(output_data1) + blockIdx.x * size1;
+  auto output2_vec = reinterpret_cast<T*>(output_data2) + blockIdx.x * size2;
+  auto input_vec = reinterpret_cast<const T*>(input_data) + blockIdx.x * inner_size;
+  // all sizes and pointers are aligned to sizeof(T),
+  // so all threads in the block can do a vectorized copy
+
+  for (auto tid = threadIdx.x; tid < inner_size; tid += blockDim.x) {
+    auto data = input_vec[tid];
+    if (tid < size0) {
+      output0_vec[tid] = data;
+    } else if (tid < (size0 + size1)) {
+      output1_vec[tid - size0] = data;
+    } else {
+      output2_vec[tid - size0 - size1] = data;
+    }
+  }
+}
+
+Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
+                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
+                   void* output_data2, const gsl::span<const int64_t>& input_shape) {
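+  // All leading dimensions are collapsed into outer_size below; the last (split) dimension forms
+  // one row per outer index, and the kernel above copies one row per thread block.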
+  CUDA_LONG outer_size = 1;
+  for (size_t i = 0; i < input_shape.size() - 1; ++i) {
+    outer_size *= static_cast<CUDA_LONG>(input_shape[i]);
+  }
+  CUDA_LONG inner_size_in_byte = static_cast<CUDA_LONG>(input_shape[input_shape.size() - 1] * element_size);
+
+  auto select = [](size_t value) {
+    if (value % 16 == 0) {
+      return 16;
+    } else if (value % 8 == 0) {
+      return 8;
+    } else if (value % 4 == 0) {
+      return 4;
+    } else if (value % 2 == 0) {
+      return 2;
+    } else {
+      return 1;
+    }
+  };
+
+  auto input_v = reinterpret_cast<size_t>(input_data);
+  auto output_v0 = reinterpret_cast<size_t>(output_data0);
+  auto output_v1 = reinterpret_cast<size_t>(output_data1);
+  auto output_v2 = reinterpret_cast<size_t>(output_data2);
+  auto size0_in_byte = size0 * element_size;
+  auto size1_in_byte = size1 * element_size;
+  auto size2_in_byte = size2 * element_size;
+
+  auto VEC_SIZE = std::min(select(size0_in_byte), std::min(select(size1_in_byte), select(size2_in_byte)));
+  auto min_output_vec_size = std::min(select(output_v0), std::min(select(output_v1), select(output_v2)));
+  VEC_SIZE = std::min(VEC_SIZE, std::min(select(input_v), min_output_vec_size));
+
+  // determine the number of threads from the size of one input row
+  auto threadsPerBlock = kNumThreadsPerBlock;
+  if ((inner_size_in_byte / VEC_SIZE) <= 128) {
+    // use fewer threads when the row is small
+    threadsPerBlock = 128;
+  }
+
+  switch (VEC_SIZE) {
+#define CASE_ELEMENT_TYPE(type)                                          \
+  _Split3InnerKernel<type><<<outer_size, threadsPerBlock, 0, stream>>>(  \
+      size0_in_byte,                                                     \
+      size1_in_byte,                                                     \
+      size2_in_byte,                                                     \
+      input_data,                                                        \
+      output_data0,                                                      \
+      output_data1,                                                      \
+      output_data2,                                                      \
+      inner_size_in_byte)
+    case 16:
+      CASE_ELEMENT_TYPE(int4);
+      break;
+    case 8:
+      CASE_ELEMENT_TYPE(int64_t);
+      break;
+    case 4:
+      CASE_ELEMENT_TYPE(int32_t);
+      break;
+    case 2:
+      CASE_ELEMENT_TYPE(int16_t);
+      break;
+    default:
+      CASE_ELEMENT_TYPE(int8_t);
+      break;
+#undef CASE_ELEMENT_TYPE
+  }
+
+  return Status::OK();
+}
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/tensor/split_impl.h b/onnxruntime/core/providers/cuda/tensor/split_impl.h
index 16961cfb7d22d..62e0da7716608 100644
--- a/onnxruntime/core/providers/cuda/tensor/split_impl.h
+++ b/onnxruntime/core/providers/cuda/tensor/split_impl.h
@@ -19,5 +19,9 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
                  const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
                  void** output_data, const size_t input_size);
 
+Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
+                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
+                   void* output_data2, const gsl::span<const int64_t>& input_shape);
+
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
index 15a7d7cd9fdbf..066302a4a37d1 100644
--- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -815,5 +815,62 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
   RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
 }
 
+TEST(SplitOperatorTest, Split3Inner) {
+  constexpr int64_t axis = -1;
+  using ShapeAndDataT = ShapeAndData<float>;
+  std::vector<ShapeAndDataT> outputs;
+  int64_t num_outputs = -1;  // when split_sizes are provided, num_outputs should not be provided
+  const int batch = 16;
+  const int data_len = 96;  // should be a multiple of 3
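+
+  // The cases below start from an evenly aligned split and then shift the split boundaries by
+  // 8/4/2/1 elements to cover progressively smaller data alignments.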
+
+  // create input with shape {b, l}, and data from 1 ~ b*l
+  auto input = CreateInput<float>({batch, data_len});  // input data is 1.f ~ 1536.f
+
+  // slice the input data from start to end along the last axis
+  auto gen_output = [&](int start, int end) {
+    std::vector<float> data0;
+    auto input_data = input.second;
+    for (int b = 0; b < batch; b++) {
+      for (int i = start; i < end; i++) {
+        data0.push_back(input_data[b * data_len + i]);
+      }
+    }
+    return ShapeAndDataT{{batch, end - start}, data0};
+  };
+
+  auto do_test = [&](std::vector<int>& splits) {
+    outputs.clear();
+    outputs.push_back(gen_output(0, splits[0]));
+    outputs.push_back(gen_output(splits[0], splits[1]));
+    outputs.push_back(gen_output(splits[1], data_len));
+
+    RunTest<float>(axis, {splits[0], splits[1] - splits[0], data_len - splits[1]}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs);
+  };
+
+  // split into 3 equal parts, each aligned to 16
+  std::vector<int> splits{data_len / 3, data_len / 3 * 2};
+  do_test(splits);
+
+  // test split with data alignment of 8
+  splits[0] = splits[0] + 8;
+  splits[1] = splits[1] + 8;
+  do_test(splits);
+
+  // test split with data alignment of 4
+  splits[0] = splits[0] + 4;
+  splits[1] = splits[1] + 4;
+  do_test(splits);
+
+  // test split with data alignment of 2
+  splits[0] = splits[0] + 2;
+  splits[1] = splits[1] + 2;
+  do_test(splits);
+
+  // test split with data alignment of 1
+  splits[0] = splits[0] + 1;
+  splits[1] = splits[1] + 1;
+  do_test(splits);
+}
+
 }  // namespace test
 }  // namespace onnxruntime