Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add split3inner #19886

Merged
merged 15 commits into from
Jun 27, 2024
22 changes: 22 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const {
auto input_dims = input_shape.GetDims();
auto output_dimensions{input_shape.AsShapeVector()};

if (split_sizes.size() == 3 && (block_size_inside_axis_dim == 1)) {
kailums marked this conversation as resolved.
Show resolved Hide resolved
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
// we use block_size_inside_axis_dim == 1 to check if we are splitting on the innermost axis.
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
// Split3Inner can be used only when splitting on the inner axis into exactly 3 outputs.
tianleiwu marked this conversation as resolved.
Show resolved Hide resolved
// this kernel does not use pinned memory, so it is safe to use with CUDA graphs.
output_dimensions[axis] = split_sizes[0];
Tensor* output0 = ctx->Output(0, TensorShape{output_dimensions});
output_dimensions[axis] = split_sizes[1];
Tensor* output1 = ctx->Output(1, TensorShape{output_dimensions});
output_dimensions[axis] = split_sizes[2];
Tensor* output2 = ctx->Output(2, TensorShape{output_dimensions});

return Split3Inner(Stream(ctx),
input_tensor->DataType()->Size(),
split_sizes[0], split_sizes[1],
split_sizes[2],
input_tensor->DataRaw(),
output0->MutableDataRaw(),
output1->MutableDataRaw(),
output2->MutableDataRaw(),
input_dims);
}

CudaAsyncBuffer<void*> output_ptr(this, num_outputs);
gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
TensorShapeVector axis_dimension_input_output_mapping(input_dims[axis]);
Expand Down
62 changes: 62 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,5 +157,67 @@
return Status::OK();
}

// Splits a [outer_size, inner_size] view of the input along its inner axis
// into three outputs with inner extents size0/size1/size2
// (inner_size == size0 + size1 + size2 — enforced by the host launcher).
// Launch: 1D grid, one thread per input element; tail threads exit early.
template <typename T>
__global__ void _Split3InnerKernel(const int64_t size0,
                                   const int64_t size1,
                                   const int64_t size2,
                                   const T* input_data,
                                   T* output_data0,
                                   T* output_data1,
                                   T* output_data2,
                                   const int64_t outer_size,
                                   const int64_t inner_size) {
  // Widen before multiplying: blockIdx.x * blockDim.x is otherwise a 32-bit
  // unsigned multiply that silently wraps for tensors with > 2^32 elements.
  const int64_t data_id = static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const int64_t row_id = data_id / inner_size;
  const int64_t col_id = data_id % inner_size;

  // Tail guard: the grid is rounded up, so some threads map past the input.
  // (col_id is a modulo result and is always < inner_size; no check needed.)
  if (row_id >= outer_size) {
    return;
  }

  // Route the element to the output whose column range contains col_id.
  if (col_id < size0) {
    output_data0[row_id * size0 + col_id] = input_data[data_id];
  } else if (col_id < size0 + size1) {
    output_data1[row_id * size1 + (col_id - size0)] = input_data[data_id];
  } else {
    output_data2[row_id * size2 + (col_id - size0 - size1)] = input_data[data_id];
  }
}

// Launches _Split3InnerKernel to split `input_data` along its innermost axis
// into three outputs of inner extents size0/size1/size2.
// Preconditions: input_shape is non-empty and its last dimension equals
// size0 + size1 + size2. element_size is the byte width of one element;
// dispatch is by size only, so e.g. float reuses the int32_t instantiation.
Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape) {
  ORT_ENFORCE(!input_shape.empty(), "input_shape must not be empty");
  int64_t outer_size = 1;
  for (size_t i = 0; i < input_shape.size() - 1; ++i) {
    outer_size *= input_shape[i];
  }
  const int64_t inner_size = input_shape[input_shape.size() - 1];
  // ORT_ENFORCE instead of assert: the precondition stays active in release
  // builds (assert compiles out under NDEBUG). Also resolves the cpplint
  // "extra space before (" warning on the previous `assert (...)`.
  ORT_ENFORCE(inner_size == size0 + size1 + size2,
              "inner dimension must equal size0 + size1 + size2");

  const int64_t N = outer_size * inner_size;
  const int blocksPerGrid = CeilDiv(N, kNumThreadsPerBlock);
  const dim3 block(kNumThreadsPerBlock);
  const dim3 grid(blocksPerGrid);

  switch (element_size) {
#define CASE_ELEMENT_TYPE(type)                                                                                       \
  case sizeof(type): {                                                                                                \
    _Split3InnerKernel<<<grid, block, 0, stream>>>(size0, size1, size2,                                               \
                                                   reinterpret_cast<const ToCudaType<type>::MappedType*>(input_data), \
                                                   reinterpret_cast<ToCudaType<type>::MappedType*>(output_data0),     \
                                                   reinterpret_cast<ToCudaType<type>::MappedType*>(output_data1),     \
                                                   reinterpret_cast<ToCudaType<type>::MappedType*>(output_data2),     \
                                                   outer_size, inner_size);                                           \
  } break
    CASE_ELEMENT_TYPE(int8_t);
    CASE_ELEMENT_TYPE(int16_t);
    CASE_ELEMENT_TYPE(int32_t);
    CASE_ELEMENT_TYPE(int64_t);
#undef CASE_ELEMENT_TYPE
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Type not supported for Split3Inner operator");
  }

  // Surface launch-configuration errors (e.g. grid dimension overflow) now
  // rather than at the next synchronizing call.
  CUDA_RETURN_IF_ERROR(cudaGetLastError());
  return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/split_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,9 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
void** output_data, const size_t input_size);

// Specialized split for exactly 3 outputs along the innermost axis.
// element_size is the byte width of one element; size0/size1/size2 are the
// inner-axis extents of the three outputs (their sum must equal the last
// entry of input_shape). All data pointers are device pointers.
Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape);

} // namespace cuda
} // namespace onnxruntime
35 changes: 35 additions & 0 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,5 +815,40 @@
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

// Splits a 2x6 input on the innermost axis into widths {3, 2, 1}; on CUDA this
// exercises the Split3Inner fast path (3 outputs, inner-axis split).
TEST(SplitOperatorTest, Split3Inner) {
  constexpr int64_t axis = -1;  // innermost axis
  std::vector<ShapeAndFloatData> outputs;

  ShapeAndFloatData input = {{2, 6},  // shape
                             {1.f, 2.f, 3.f, 4.f, 5.f, 6.f,
                              7.f, 8.f, 9.f, 10.f, 11.f, 12.f}};

  outputs.push_back({{2, 3},
                     {1.f, 2.f, 3.f,
                      7.f, 8.f, 9.f}});
  outputs.push_back({{2, 2},
                     {4.f, 5.f,
                      10.f, 11.f}});
  outputs.push_back({{2, 1},
                     {6.f, 12.f}});

  // When explicit split_sizes are provided, num_outputs must not be set.
  int64_t num_outputs = -1;
  // Calls are wrapped to keep lines <= 120 chars (fixes the cpplint warnings).
  RunTest<float>(axis, {3, 2, 1}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider},
                 false, true, num_outputs);
  RunTest<float>(axis, {3, 2, 1}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider},
                 false, true, num_outputs, false);
}

} // namespace test
} // namespace onnxruntime
Loading