diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 16df788c284ee..edf249a816923 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -373,7 +373,7 @@ Do not modify directly.*
|||[13, 17]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[2, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string)|
+|SplitToSequence|*in* input:**T**
*in* split:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))
**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(string)|
|Sqrt|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float)|
|||[6, 12]|**T** = tensor(double), tensor(float)|
|Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**
or
*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
index 4759938cd8250..8064bc0a58cb1 100644
--- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
+++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.cc
@@ -334,27 +334,14 @@ Status SequenceConstruct::Compute(OpKernelContext* context) const {
// SplitToSequence
-namespace op_kernel_type_control {
-ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS(
- kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0,
- float, double, int32_t, int64_t, std::string);
-} // namespace op_kernel_type_control
-
-namespace {
-using EnabledSplitToSequenceDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(
- kCpuExecutionProvider, kOnnxDomain, SplitToSequence, Input, 0);
-} // namespace
-
ONNX_CPU_OPERATOR_KERNEL(
SplitToSequence,
11,
KernelDefBuilder()
.TypeConstraint("T",
- BuildKernelDefConstraintsFromTypeList())
+ BuildKernelDefConstraints())
.TypeConstraint("S", DataTypeImpl::AllSequenceTensorTypes())
- .TypeConstraint("I", std::vector{
- DataTypeImpl::GetTensorType(),
- DataTypeImpl::GetTensorType()}),
+ .TypeConstraint("I", BuildKernelDefConstraints()),
SplitToSequence);
SplitToSequence::SplitToSequence(const OpKernelInfo& info) : OpKernel(info) {
@@ -366,29 +353,14 @@ Status SplitToSequence::Compute(OpKernelContext* context) const {
const Tensor& input = *context->Input(0);
const Tensor* p_split_input = context->Input(1);
- Status status;
-
- if (input.IsDataType())
- status = ComputeImpl(*context, input, p_split_input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input, p_split_input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input, p_split_input);
- else if (input.IsDataType())
- status = ComputeImpl(*context, input, p_split_input);
- else if (input.IsDataTypeString())
- status = ComputeImpl(*context, input, p_split_input);
- else
- status = ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "SplitToSequence operator does not support ", input.DataType(), " yet");
-
- return status;
+ return ComputeImpl(*context, input, p_split_input);
}
Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar,
int64_t& num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
bool& is_uneven_split, int& num_remaining_splits,
- std::vector& split_sizes) const {
+ InlinedVector& split_sizes) const {
auto input_dims = input_shape.GetDims();
const auto num_dimensions = gsl::narrow_cast(input_shape.NumDimensions());
axis = HandleNegativeAxis(axis_, num_dimensions); // handle negative and enforce axis is valid
@@ -416,7 +388,7 @@ Status SplitToSequence::PrepareForCompute(const TensorShape& input_shape, int64_
// populate split_sizes with the same size for each output
num_outputs = split_dim_size;
// https://github.com/onnx/onnx/issues/2396
- split_sizes = std::vector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_);
+ split_sizes = InlinedVector(static_cast(num_outputs), DEFAULT_LENGTH_EACH_OUTPUT_);
} else {
auto split_size_sum = std::accumulate(split_sizes.cbegin(), split_sizes.cend(), 0LL);
if (split_size_sum != split_dim_size) {
@@ -453,7 +425,7 @@ static int64_t GetScalarSplitInput(const Tensor& tensor) {
return retval;
}
-static void GetSplitSizesInput(const Tensor& tensor, std::vector& split_sizes) {
+static void GetSplitSizesInput(const Tensor& tensor, InlinedVector& split_sizes) {
auto num_elems = tensor.Shape().Size();
split_sizes.reserve(onnxruntime::narrow(num_elems));
if (tensor.IsDataType()) {
@@ -467,13 +439,8 @@ static void GetSplitSizesInput(const Tensor& tensor, std::vector& split
}
}
-template
Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& input,
const Tensor* p_split_input) const {
- if (!utils::HasType()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type is not supported in this build.");
- }
-
auto& input_shape = input.Shape();
int64_t num_outputs = 0;
int64_t axis = axis_;
@@ -484,7 +451,9 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
bool is_split_input_scalar = false;
bool is_uneven_split = false;
int num_remaining_splits = 0;
- std::vector split_sizes;
+ InlinedVector split_sizes;
+ const bool is_string_type = input.IsDataTypeString();  // strings are copied element-wise, everything else via memcpy
+ const size_t element_size = input.DataType()->Size();  // must be non-zero for strings too: it scales input_offset into the byte offset of the source pointer A
// figure out split_scalar or split_sizes
if (p_split_input) {
@@ -520,8 +489,8 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
// copy dimensions so we can update the selected axis in place
auto output_dimensions = input_shape.AsShapeVector();
- int64_t input_offset = 0;
- const T* input_data = input.Data();
+ SafeInt input_offset = 0;
+ const void* input_data = input.DataRaw();
for (int i = 0; i < num_outputs; ++i) {
// update size of dimension for axis we're splitting on while considering uneven split
int split_size;
@@ -535,20 +504,50 @@ Status SplitToSequence::ComputeImpl(OpKernelContext& context, const Tensor& inpu
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(context.GetTempSpaceAllocator(&alloc));
Tensor output_tensor(input.DataType(), onnxruntime::TensorShape(output_dimensions), alloc);
- T* output_data = output_tensor.MutableData();
-
- ::onnxruntime::math::CopyMatrix(
- before_dims, // M
- split_size * after_dims_excluding_split, // N
- static_cast(input_data + input_offset), // A
- after_dims_including_split_axis, // lda
- static_cast(output_data), // B
- split_size * after_dims_excluding_split, // ldb
- [](const T* src, T* dst, size_t count) {
- copy_data(src, dst, count);
- });
-
- input_offset += static_cast(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
+ void* output_data = output_tensor.MutableDataRaw();
+
+ const auto M = before_dims;
+ const auto* A = static_cast(input_data) + static_cast(input_offset * element_size);
+ const auto lda = after_dims_including_split_axis;
+ auto* B = output_data;
+
+ const auto N = split_size * after_dims_excluding_split;
+ const auto ldb = N;
+
+ if (is_string_type) {
+ const auto* src = reinterpret_cast(A);
+ auto* dst = reinterpret_cast(B);
+ if (lda == N) {
+ copy_data(src, dst, static_cast(M * N));
+ } else {
+ size_t lda_offset = 0;
+ size_t ldb_offset = 0;
+ for (size_t idx = 0; idx < static_cast(M); ++idx,
+ lda_offset += lda, ldb_offset += ldb) {
+ copy_data(src + lda_offset, dst + ldb_offset, static_cast(N));
+ }
+ }
+ } else {
+ if (lda == N) {
+ // if the data is contiguous, we can just copy the data
+ const size_t bytes_to_copy = static_cast(N) * static_cast(M) * element_size;
+ memcpy(B, A, bytes_to_copy);
+ } else {
+ // otherwise we need to copy each row
+ const size_t row_bytes = SafeInt(N) * element_size;
+ const auto lda_bytes_inc = SafeInt(lda) * element_size;
+ const auto ldb_bytes_inc = SafeInt(ldb) * element_size;
+ SafeInt lda_bytes_offset = 0;
+ SafeInt ldb_bytes_offset = 0;
+ for (size_t idx = 0; idx < static_cast(M); ++idx,
+ lda_bytes_offset += lda_bytes_inc, ldb_bytes_offset += ldb_bytes_inc) {
+ memcpy(reinterpret_cast(B) + static_cast(ldb_bytes_offset),
+ reinterpret_cast(A) + static_cast(lda_bytes_offset), row_bytes);
+ }
+ }
+ }
+
+ input_offset += SafeInt(split_size) * after_dims_excluding_split; // offset by the N data we used in this iteration
// if keep_dims = 0, reshape the tensor by dropping the dimension corresponding to 'axis'
if (use_keep_dims && keepdims_ == 0) {
diff --git a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h
index 9466d3f0fd108..ccca226fb07ee 100644
--- a/onnxruntime/core/providers/cpu/sequence/sequence_ops.h
+++ b/onnxruntime/core/providers/cpu/sequence/sequence_ops.h
@@ -60,13 +60,12 @@ class SplitToSequence final : public OpKernel {
Status Compute(OpKernelContext* context) const override;
private:
- template
Status ComputeImpl(OpKernelContext& context, const Tensor& input, const Tensor* p_split_input) const;
Status PrepareForCompute(const TensorShape& input_shape, int64_t split_scalar, bool is_split_input_scalar,
int64_t& num_outputs, int64_t& axis, int& before_dims,
int& after_dims_including_split_axis, int& after_dims_excluding_split,
bool& is_uneven_split, int& num_remaining_splits,
- std::vector& split_sizes) const;
+ InlinedVector& split_sizes) const;
int64_t axis_{};
int64_t keepdims_{1};
const int64_t DEFAULT_LENGTH_EACH_OUTPUT_ = 1;
diff --git a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
index d29aac81150c5..60e75811e4333 100644
--- a/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/sequence/sequence_ops_test.cc
@@ -330,15 +330,26 @@ TEST(SequenceOpsTest, SequenceConstructPositive) {
// SplitToSequence
template
-static std::vector GetConsequtiveVector(T start, int num) {
+static std::vector GetConsecutiveVector(T start, size_t num) {
std::vector inputv(num);
std::iota(inputv.begin(), inputv.end(), start);
return inputv;
}
+template <>  // MLFloat16 overload: returns `num` values start, start+1, ... built via float arithmetic
+std::vector GetConsecutiveVector(MLFloat16 start, size_t num) {
+ std::vector inputv;
+ inputv.reserve(num);  // exactly `num` elements are pushed below
+ float start_f = start.ToFloat();  // do the additions in float, then convert each value back
+ for (size_t i = 0; i < num; ++i) {
+ inputv.push_back(MLFloat16{start_f + static_cast(i)});  // i-th element is start + i
+ }
+ return inputv;
+}
+
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8));
+ test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8));
test.AddInput("split", {1, 2}, {2, 2});
SeqTensors output;
output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f});
@@ -347,9 +358,31 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloat) {
test.Run();
}
+TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitMLFloat16) {  // float16 variant of the equal-split test above
+ OpTester test("SplitToSequence", 11);
+ test.AddInput("input", {4, 2}, GetConsecutiveVector(MLFloat16::One, 8));  // 4x2 input holding 1..8
+ test.AddInput("split", {1, 2}, {2, 2});  // two split sizes of 2; no axis attribute is set
+ SeqTensors output;
+
+ std::vector tensor_1;  // expected contents of the first output tensor
+ const auto data_1 = {1.f, 2.f, 3.f, 4.f};
+ for (auto f : data_1)
+ tensor_1.push_back(MLFloat16{f});
+
+ std::vector tensor_2;  // expected contents of the second output tensor
+ const auto data_2 = {5.f, 6.f, 7.f, 8.f};
+ for (auto f : data_2)
+ tensor_2.push_back(MLFloat16{f});
+
+ output.AddTensor({2, 2}, tensor_1);
+ output.AddTensor({2, 2}, tensor_2);
+ test.AddSeqOutput("S2", output);  // expected sequence of two 2x2 tensors
+ test.Run();
+}
+
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {4, 2}, GetConsequtiveVector(1, 8));
+ test.AddInput("input", {4, 2}, GetConsecutiveVector(1, 8));
test.AddInput("split", {1, 2}, {2, 2});
SeqTensors output;
output.AddTensor({2, 2}, {1, 2, 3, 4});
@@ -360,7 +393,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitLong) {
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8));
+ test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8));
test.AddInput("split", {}, {2});
SeqTensors output;
output.AddTensor({2, 2}, {1.f, 2.f, 3.f, 4.f});
@@ -371,7 +404,7 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0EqualSplitFloatScalarSplit) {
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {4, 2}, GetConsequtiveVector(1.f, 8));
+ test.AddInput("input", {4, 2}, GetConsecutiveVector(1.f, 8));
int64_t axis = 0;
test.AddAttribute("axis", axis);
SeqTensors output;
@@ -385,7 +418,7 @@ TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitly) {
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {2, 2, 6}, GetConsequtiveVector(1.f, 2 * 2 * 6));
+ test.AddInput("input", {2, 2, 6}, GetConsecutiveVector(1.f, 2 * 2 * 6));
int64_t axis = 2;
test.AddAttribute("axis", axis);
test.AddInput("split", {}, {2});
@@ -411,11 +444,11 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisScalarSplit) {
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {5, 2}, GetConsequtiveVector(1.f, 10));
+ test.AddInput("input", {5, 2}, GetConsecutiveVector(1.f, 10));
test.AddInput("split", {}, {2});
SeqTensors output;
- output.AddTensor({2, 2}, GetConsequtiveVector(1.f, 4));
- output.AddTensor({2, 2}, GetConsequtiveVector(5.f, 4));
+ output.AddTensor({2, 2}, GetConsecutiveVector(1.f, 4));
+ output.AddTensor({2, 2}, GetConsecutiveVector(5.f, 4));
output.AddTensor({1, 2}, {9.f, 10.f});
test.AddSeqOutput("S2", output);
test.Run();
@@ -423,22 +456,22 @@ TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat) {
TEST(SequenceOpsTest, SplitToSequence_DefaultAxis0UnevenSplitFloat2) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {17, 2}, GetConsequtiveVector(1.f, 34));
+ test.AddInput("input", {17, 2}, GetConsecutiveVector(1.f, 34));
test.AddInput("split", {}, {3});
SeqTensors output;
- output.AddTensor({3, 2}, GetConsequtiveVector(1.f, 6));
- output.AddTensor({3, 2}, GetConsequtiveVector(7.f, 6));
- output.AddTensor({3, 2}, GetConsequtiveVector(13.f, 6));
- output.AddTensor({3, 2}, GetConsequtiveVector(19.f, 6));
- output.AddTensor({3, 2}, GetConsequtiveVector(25.f, 6));
- output.AddTensor({2, 2}, GetConsequtiveVector(31.f, 4));
+ output.AddTensor({3, 2}, GetConsecutiveVector(1.f, 6));
+ output.AddTensor({3, 2}, GetConsecutiveVector(7.f, 6));
+ output.AddTensor({3, 2}, GetConsecutiveVector(13.f, 6));
+ output.AddTensor({3, 2}, GetConsecutiveVector(19.f, 6));
+ output.AddTensor({3, 2}, GetConsecutiveVector(25.f, 6));
+ output.AddTensor({2, 2}, GetConsecutiveVector(31.f, 4));
test.AddSeqOutput("S2", output);
test.Run();
}
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {2, 5}, GetConsequtiveVector(1.f, 10));
+ test.AddInput("input", {2, 5}, GetConsecutiveVector(1.f, 10));
test.AddInput("split", {}, {2});
int64_t axis = 1;
test.AddAttribute("axis", axis);
@@ -452,33 +485,33 @@ TEST(SequenceOpsTest, SplitToSequence_PositiveAxisUnevenSplit) {
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims3Dim) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {2, 3, 4}, GetConsequtiveVector(1.f, 2 * 3 * 4));
+ test.AddInput("input", {2, 3, 4}, GetConsecutiveVector(1.f, 2 * 3 * 4));
test.AddAttribute("keepdims", 0);
int64_t axis = 0;
test.AddAttribute("axis", axis);
SeqTensors output;
- output.AddTensor({3, 4}, GetConsequtiveVector(1.f, 12));
- output.AddTensor({3, 4}, GetConsequtiveVector(13.f, 12));
+ output.AddTensor({3, 4}, GetConsecutiveVector(1.f, 12));
+ output.AddTensor({3, 4}, GetConsecutiveVector(13.f, 12));
test.AddSeqOutput("S2", output);
test.Run();
}
TEST(SequenceOpsTest, SplitToSequence_Axis0DefaultSplitFloatSetAxisExplicitlyDontKeepDims2Dim) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {2, 3}, GetConsequtiveVector(1.f, 2 * 3));
+ test.AddInput("input", {2, 3}, GetConsecutiveVector(1.f, 2 * 3));
test.AddAttribute("keepdims", 0);
int64_t axis = 0;
test.AddAttribute("axis", axis);
SeqTensors output;
- output.AddTensor({3}, GetConsequtiveVector(1.f, 3));
- output.AddTensor({3}, GetConsequtiveVector(4.f, 3));
+ output.AddTensor({3}, GetConsecutiveVector(1.f, 3));
+ output.AddTensor({3}, GetConsecutiveVector(4.f, 3));
test.AddSeqOutput("S2", output);
test.Run();
}
TEST(SequenceOpsTest, SplitToSequence_PositiveAxisDontKeepDims) {
OpTester test("SplitToSequence", 11);
- test.AddInput("input", {2, 3, 4}, GetConsequtiveVector(1.f, 2 * 3 * 4));
+ test.AddInput("input", {2, 3, 4}, GetConsecutiveVector(1.f, 2 * 3 * 4));
test.AddAttribute("keepdims", 0);
int64_t axis = 2;
test.AddAttribute("axis", axis);