Skip to content

Commit

Permalink
Revert "add split3inner (#19886)"
Browse files Browse the repository at this point in the history
This reverts commit a1bbfeb.
  • Loading branch information
cloudhan authored Jul 2, 2024
1 parent beb2496 commit bdc894e
Show file tree
Hide file tree
Showing 4 changed files with 0 additions and 193 deletions.
25 changes: 0 additions & 25 deletions onnxruntime/core/providers/cuda/tensor/split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,31 +76,6 @@ Status SplitKernel::ComputeInternal(OpKernelContext* ctx) const {
auto input_dims = input_shape.GetDims();
auto output_dimensions{input_shape.AsShapeVector()};

if (split_sizes.size() == 3 && ((axis + 1) == gsl::narrow_cast<int64_t>(input_shape.NumDimensions()))) {
// we use (axis + 1) == num_dimensions to check if we are splitting on inner most axis.
// only when split on inner axis and output size is 3, we can use Split3Inner.
// this kernel is not using pin_memory, so it is ok for using cuda graph.
output_dimensions[axis] = split_sizes[0];
Tensor* output0 = ctx->Output(0, TensorShape{output_dimensions});
output_dimensions[axis] = split_sizes[1];
Tensor* output1 = ctx->Output(1, TensorShape{output_dimensions});
output_dimensions[axis] = split_sizes[2];
Tensor* output2 = ctx->Output(2, TensorShape{output_dimensions});

// if input tensor is empty, we don't need to launch kernel, but still need to set output tensor.
if (input_tensor->Shape().Size() <= 0) return Status::OK();

return Split3Inner(Stream(ctx),
input_tensor->DataType()->Size(),
split_sizes[0], split_sizes[1],
split_sizes[2],
input_tensor->DataRaw(),
output0->MutableDataRaw(),
output1->MutableDataRaw(),
output2->MutableDataRaw(),
input_dims);
}

CudaAsyncBuffer<void*> output_ptr(this, num_outputs);
gsl::span<void*> output_ptr_span = output_ptr.CpuSpan();
TensorShapeVector axis_dimension_input_output_mapping(input_dims[axis]);
Expand Down
107 changes: 0 additions & 107 deletions onnxruntime/core/providers/cuda/tensor/split_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,112 +157,5 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
return Status::OK();
}

// Copies one input row per block into three output rows (3-way split on the
// innermost axis). All *_in_byte arguments must be multiples of sizeof(T),
// and every pointer must be aligned to sizeof(T), so the row can be copied
// in T-sized chunks. Launch layout: gridDim.x == number of rows,
// blockDim.x threads cooperatively stride over one row.
template <typename T>
__global__ void _Split3InnerKernel(const int64_t size0_in_byte,
                                   const int64_t size1_in_byte,
                                   const int64_t size2_in_byte,
                                   const void* input_data,
                                   void* output_data0,
                                   void* output_data1,
                                   void* output_data2,
                                   const int64_t inner_size_in_byte) {
  // Convert byte counts to element counts of the copy unit T.
  const int64_t elems0 = size0_in_byte / sizeof(T);
  const int64_t elems1 = size1_in_byte / sizeof(T);
  const int64_t elems2 = size2_in_byte / sizeof(T);
  const int64_t row_elems = inner_size_in_byte / sizeof(T);

  // Base of this block's row in the input and in each output.
  const T* src = reinterpret_cast<const T*>(input_data) + blockIdx.x * row_elems;
  T* dst0 = reinterpret_cast<T*>(output_data0) + blockIdx.x * elems0;
  T* dst1 = reinterpret_cast<T*>(output_data1) + blockIdx.x * elems1;
  T* dst2 = reinterpret_cast<T*>(output_data2) + blockIdx.x * elems2;

  // Strided copy: each thread walks the row with stride blockDim.x and
  // routes every element to whichever output owns that column range.
  for (int64_t col = threadIdx.x; col < row_elems; col += blockDim.x) {
    const T value = src[col];
    if (col < elems0) {
      dst0[col] = value;
    } else if (col < elems0 + elems1) {
      dst1[col - elems0] = value;
    } else {
      dst2[col - (elems0 + elems1)] = value;
    }
  }
}

// Specialized split for exactly 3 outputs along the innermost axis.
// size0/size1/size2 are the element counts of the three splits; input_shape
// is the full input shape (the last dimension is the one being split).
// Selects the widest copy unit (1/2/4/8/16 bytes) compatible with the
// alignment of every pointer and every per-output byte count, then launches
// one block per outer row. No pinned-memory staging is involved.
Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape) {
  // Number of rows = product of all dimensions except the innermost one.
  CUDA_LONG num_rows = 1;
  for (size_t dim = 0; dim + 1 < input_shape.size(); ++dim) {
    num_rows *= static_cast<CUDA_LONG>(input_shape[dim]);
  }
  const CUDA_LONG row_bytes = static_cast<CUDA_LONG>(input_shape[input_shape.size() - 1] * element_size);

  // Largest power-of-two width (capped at 16) that evenly divides `value`.
  auto alignment_of = [](size_t value) {
    for (int width = 16; width > 1; width >>= 1) {
      if (value % width == 0) return width;
    }
    return 1;
  };

  const auto in_addr = reinterpret_cast<size_t>(input_data);
  const auto out_addr0 = reinterpret_cast<size_t>(output_data0);
  const auto out_addr1 = reinterpret_cast<size_t>(output_data1);
  const auto out_addr2 = reinterpret_cast<size_t>(output_data2);
  const auto bytes0 = size0 * element_size;
  const auto bytes1 = size1 * element_size;
  const auto bytes2 = size2 * element_size;

  // The copy unit must divide every split's byte size and every address,
  // otherwise a vectorized access would straddle a split boundary or be
  // misaligned.
  const int vec_bytes = std::min({alignment_of(bytes0), alignment_of(bytes1), alignment_of(bytes2),
                                  alignment_of(in_addr), alignment_of(out_addr0),
                                  alignment_of(out_addr1), alignment_of(out_addr2)});

  // Shrink the block for short rows so we don't launch mostly-idle threads.
  auto threads_per_block = kNumThreadsPerBlock;
  if ((row_bytes / vec_bytes) <= 128) {
    threads_per_block = 128;
  }

  // Instantiate the kernel with the integer type whose size matches the
  // chosen copy unit; int4 covers the 16-byte case.
  auto launch = [&](auto vec_tag) {
    using VecT = decltype(vec_tag);
    _Split3InnerKernel<VecT><<<num_rows, threads_per_block, 0, stream>>>(
        bytes0, bytes1, bytes2, input_data, output_data0, output_data1, output_data2, row_bytes);
  };
  switch (vec_bytes) {
    case 16:
      launch(int4{});
      break;
    case 8:
      launch(int64_t{});
      break;
    case 4:
      launch(int32_t{});
      break;
    case 2:
      launch(int16_t{});
      break;
    default:
      launch(int8_t{});
      break;
  }

  return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
4 changes: 0 additions & 4 deletions onnxruntime/core/providers/cuda/tensor/split_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,5 @@ Status SplitImpl(cudaStream_t stream, const size_t element_size, const int block
const int64_t* axis_dimension_input_output_mapping, const int num_outputs, const void* input_data,
void** output_data, const size_t input_size);

// Fast path for Split with exactly 3 outputs along the innermost axis:
// one kernel copies each input row into the three outputs.
// size0/size1/size2 are the element counts of the three splits (elements,
// not bytes); element_size is the byte size of one element; input_shape is
// the full input shape — its last dimension is presumably size0+size1+size2
// (not verified here; the caller is expected to guarantee it).
Status Split3Inner(cudaStream_t stream, const size_t element_size, const int64_t size0, const int64_t size1,
                   const int64_t size2, const void* input_data, void* output_data0, void* output_data1,
                   void* output_data2, const gsl::span<const int64_t>& input_shape);

} // namespace cuda
} // namespace onnxruntime
57 changes: 0 additions & 57 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -815,62 +815,5 @@ TEST(SplitOperatorTest, Split18_NumOutputsUnevenSplitAxis1) {
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs, false);
}

// Exercises the CUDA Split3Inner fast path: a 3-way split on the innermost
// axis, with split points chosen so the per-output data alignment seen by
// the kernel is 16, 8, 4, 2 and finally 1 byte(s).
TEST(SplitOperatorTest, Split3Inner) {
  constexpr int64_t axis = -1;  // innermost axis triggers the fast path
  using ShapeAndDataT = ShapeAndData<uint8_t>;
  std::vector<ShapeAndDataT> outputs;
  int64_t num_outputs = -1;  // explicit split sizes are provided, so num_outputs must stay unset
  const int batch = 16;
  const int data_len = 96;  // must be a multiple of 3 for the even-split case

  // Input of shape {batch, data_len} filled with sequential values.
  auto input = CreateInput<uint8_t>({batch, data_len});

  // Expected output holding columns [begin, end) of every row.
  auto slice_cols = [&](int begin, int end) {
    std::vector<uint8_t> expected;
    const auto& values = input.second;
    for (int row = 0; row < batch; ++row) {
      for (int col = begin; col < end; ++col) {
        expected.push_back(values[row * data_len + col]);
      }
    }
    return ShapeAndDataT{{batch, end - begin}, expected};
  };

  // Runs one split with cut points cut[0] and cut[1] along the last axis.
  auto run_case = [&](const std::vector<int>& cut) {
    outputs.clear();
    outputs.push_back(slice_cols(0, cut[0]));
    outputs.push_back(slice_cols(cut[0], cut[1]));
    outputs.push_back(slice_cols(cut[1], data_len));

    RunTest<uint8_t>(axis, {cut[0], cut[1] - cut[0], data_len - cut[1]}, input, outputs, {kTensorrtExecutionProvider, kQnnExecutionProvider}, false, true, num_outputs);
  };

  // Three equal parts, 16-byte aligned split points.
  std::vector<int> cuts{data_len / 3, data_len / 3 * 2};
  run_case(cuts);

  // Progressively shift both cut points to reduce the alignment to
  // 8, 4, 2 and then 1 byte(s), covering every vectorization width.
  for (int shift : {8, 4, 2, 1}) {
    cuts[0] += shift;
    cuts[1] += shift;
    run_case(cuts);
  }
}

} // namespace test
} // namespace onnxruntime

0 comments on commit bdc894e

Please sign in to comment.