
add split3inner #19886

Merged
merged 15 commits on Jun 27, 2024
25 changes: 25 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split.cc
@@ -76,6 +76,31 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const {
  auto input_dims = input_shape.GetDims();
  auto output_dimensions{input_shape.AsShapeVector()};

  if (split_sizes.size() == 3 && (block_size_inside_axis_dim == 1)) {
    // block_size_inside_axis_dim == 1 means the split happens on the innermost axis.
    // Only when splitting the innermost axis into exactly 3 outputs can we use Split3Inner.
    // This kernel does not use pinned memory, so it is safe to use with CUDA graph capture.
    output_dimensions[axis] = split_sizes[0];
    Tensor* output0 = ctx->Output(0, TensorShape{output_dimensions});
    output_dimensions[axis] = split_sizes[1];
    Tensor* output1 = ctx->Output(1, TensorShape{output_dimensions});
    output_dimensions[axis] = split_sizes[2];
    Tensor* output2 = ctx->Output(2, TensorShape{output_dimensions});

    // If the input tensor is empty, there is no kernel to launch, but the output tensors above still had to be set.
    if (input_tensor->Shape().Size() <= 0) return Status::OK();

    return Split3Inner(Stream(ctx),
                       input_tensor->DataType()->Size(),
                       split_sizes[0], split_sizes[1], split_sizes[2],
                       input_tensor->DataRaw(),
                       output0->MutableDataRaw(),
                       output1->MutableDataRaw(),
                       output2->MutableDataRaw(),
                       input_dims);
  }

  CudaAsyncBuffer<void*> output_ptr(this, num_outputs);
  gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
  TensorShapeVector axis_dimension_input_output_mapping(input_dims[axis]);
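For context, the element mapping that this fast path relies on can be written as a minimal host-side reference sketch (an illustration only, not part of the PR; the name Split3InnerReference is hypothetical):

#include <cstdint>
#include <vector>

// Reference semantics: view the input as [outer, size0 + size1 + size2] and
// route each inner-axis element to one of three contiguous outputs.
void Split3InnerReference(const std::vector<float>& input, int64_t outer,
                          int64_t size0, int64_t size1, int64_t size2,
                          std::vector<float>& out0, std::vector<float>& out1,
                          std::vector<float>& out2) {
  const int64_t inner = size0 + size1 + size2;
  for (int64_t row = 0; row < outer; ++row) {
    for (int64_t col = 0; col < inner; ++col) {
      const float v = input[row * inner + col];
      if (col < size0) {
        out0[row * size0 + col] = v;                    // first split
      } else if (col < size0 + size1) {
        out1[row * size1 + (col - size0)] = v;          // second split
      } else {
        out2[row * size2 + (col - size0 - size1)] = v;  // third split
      }
    }
  }
}

The CUDA kernel below performs exactly this mapping, with one thread per input element.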
59 changes: 59 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split_impl.cu
@@ -157,5 +157,64 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
  return Status::OK();
}

template <typename T>
__global__ void _Split3InnerKernel(const int64_t size0,
                                   const int64_t size1,
                                   const int64_t size2,
                                   const T* input_data,
                                   T* output_data0,
                                   T* output_data1,
                                   T* output_data2,
                                   const int64_t outer_size,
                                   const int64_t inner_size) {
  int64_t data_id = blockIdx.x * blockDim.x + threadIdx.x;
  int64_t row_id = data_id / inner_size;
  int64_t col_id = data_id % inner_size;

  if (row_id >= outer_size || col_id >= inner_size) {
    return;
  }
  if (col_id < size0) {
    output_data0[row_id * size0 + col_id] = input_data[data_id];
  } else if (col_id < size0 + size1) {
    output_data1[row_id * size1 + col_id - size0] = input_data[data_id];
  } else {
    output_data2[row_id * size2 + col_id - size0 - size1] = input_data[data_id];
  }
}
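As a concrete trace of the index math (values taken from the unit test below): with split widths {3, 2, 1}, inner_size = 6 and outer_size = 2, the thread with data_id = 10 computes row_id = 1 and col_id = 4; since size0 <= 4 < size0 + size1, it writes input_data[10] (the value 11.f) to output_data1[1 * 2 + (4 - 3)] = output_data1[3], the last element of the second output.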

Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape) {
  CUDA_LONG outer_size = 1;
  for (size_t i = 0; i < input_shape.size() - 1; ++i) {
    outer_size *= static_cast<CUDA_LONG>(input_shape[i]);
  }
  CUDA_LONG inner_size = static_cast<CUDA_LONG>(input_shape[input_shape.size() - 1]);

  CUDA_LONG N = outer_size * inner_size;
  int blocksPerGrid = CeilDiv(N, kNumThreadsPerBlock);

  switch (element_size) {
#define CASE_ELEMENT_TYPE(type)                                              \
  case sizeof(type): {                                                       \
    _Split3InnerKernel<<<blocksPerGrid, kNumThreadsPerBlock, 0, stream>>>(   \
        size0, size1, size2,                                                 \
        reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data),   \
        reinterpret_cast<ToCudaType<type>::MappedType*>(output_data0),       \
        reinterpret_cast<ToCudaType<type>::MappedType*>(output_data1),       \
        reinterpret_cast<ToCudaType<type>::MappedType*>(output_data2),       \
        outer_size, inner_size);                                             \
  } break
    CASE_ELEMENT_TYPE(int8_t);
    CASE_ELEMENT_TYPE(int16_t);
    CASE_ELEMENT_TYPE(int32_t);
    CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Split3Inner operator");
  }

  return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
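One detail worth noting about the switch above: it dispatches on byte width rather than logical type, so any 4-byte element type (float included) takes the sizeof(int32_t) instantiation; the kernel only moves bytes. A hypothetical call-site sketch (the stream and the d_* device pointers are assumed valid; the split widths match the unit test below):

  const int64_t dims[] = {2, 6};
  ORT_RETURN_IF_ERROR(Split3Inner(stream, sizeof(float),
                                  /*size0*/ 3, /*size1*/ 2, /*size2*/ 1,
                                  d_input, d_out0, d_out1, d_out2,
                                  gsl::span<const int64_t>(dims, 2)));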
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split_impl.h
@@ -19,5 +19,9 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
                 const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
                 void** output_data, const size_t input_size);

Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape);

} // namespace cuda
} // namespace onnxruntime
35 changes: 35 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
@@ -815,5 +815,40 @@
  RunTest<float>(axis, {}, input, outputs,
                 {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

TEST(SplitOperatorTest, Split3Inner) {
  constexpr int64_t axis = -1;
  std::vector<ShapeAndFloatData> outputs;

  // input shape and data
  ShapeAndFloatData input = {{2, 6},  // shape
                             {1.f, 2.f, 3.f, 4.f, 5.f, 6.f,
                              7.f, 8.f, 9.f, 10.f, 11.f, 12.f}};

  outputs.push_back({{2, 3},
                     {1.f, 2.f, 3.f,
                      7.f, 8.f, 9.f}});
  outputs.push_back({{2, 2},
                     {4.f, 5.f,
                      10.f, 11.f}});
  outputs.push_back({{2, 1},
                     {6.f, 12.f}});
  int64_t num_outputs = -1;  // when split_sizes are provided, num_outputs must not be provided
  RunTest<float>(axis, {3, 2, 1}, input, outputs,
                 {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs);
  RunTest<float>(axis, {3, 2, 1}, input, outputs,
                 {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

} // namespace test
} // namespace onnxruntime