Skip to content

Commit

Permalink
StringSplit operator (#18016)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
Closes #17596
  • Loading branch information
adityagoel4512 authored and mszhanyi committed Jan 15, 2024
1 parent 48d163a commit 7b6a615
Show file tree
Hide file tree
Showing 7 changed files with 248 additions and 7 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ Do not modify directly.*
|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|StringConcat|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|20+|**T** = tensor(string)|
|StringNormalizer|*in* X:**tensor(string)**<br> *out* Y:**tensor(string)**|10+|**X** = tensor(string)|
|StringSplit|*in* X:**T1**<br> *out* Y:**T2**<br> *out* Z:**T3**|20+|**T1** = tensor(string)<br/> **T2** = tensor(string)<br/> **T3** = tensor(int64)|
|Sub|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "qlinear_util.h"
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -2451,6 +2452,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
98 changes: 98 additions & 0 deletions onnxruntime/core/providers/cpu/text/string_split.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_split.h"
#include <algorithm>
#include <limits>
#include <string>
#include "core/common/common.h"
namespace onnxruntime {

// Registers the StringSplit class as the CPU kernel for the StringSplit operator, version 20.
// Type constraints: T1 and T2 are string tensors, T3 is an int64 tensor. In Compute() below, output 0
// is written as strings (the substrings) and output 1 as int64 values (the substring counts).
ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int64_t>()),
StringSplit);

/// Splits ``str`` into substrings and appends them to ``out``.
///
/// The appended elements are views into ``str``; the caller must keep the storage behind ``str`` alive
/// for as long as the views are used.
///
/// Two modes are supported:
///  - ``delimiter`` is empty: the string is split on runs of ASCII whitespace (space, \t, \n, \v, \f, \r).
///    Consecutive whitespace counts as a single separator and leading/trailing whitespace is dropped, so
///    no empty substrings are produced.
///  - ``delimiter`` is non-empty: the string is split on every occurrence of ``delimiter``. Consecutive
///    delimiters are not merged and yield empty substrings.
///
/// At most ``max_splits`` splits are performed; once the budget is used up the remainder of the string is
/// emitted as the final substring (with trailing whitespace trimmed in whitespace mode). A negative
/// ``max_splits`` never matches the split counter and therefore behaves as "unlimited".
///
/// An empty ``str`` appends nothing to ``out``.
void ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits, InlinedVector<std::string_view>& out) {
  if (str.empty()) {
    return;
  }

  int64_t splits_done = 0;

  if (delimiter.empty()) {
    // Whitespace mode. The operator specification talks about "whitespace characters", not only the
    // space character, so tabs and line breaks must separate tokens as well.
    constexpr std::string_view kWhitespace = " \t\n\v\f\r";

    size_t pos = str.find_first_not_of(kWhitespace);
    while (pos != std::string_view::npos) {
      if (splits_done == max_splits) {
        // Split budget exhausted: everything from ``pos`` on is one token, minus trailing whitespace.
        // ``str[pos]`` is not whitespace, so ``last`` is always a valid index >= ``pos``.
        const size_t last = str.find_last_not_of(kWhitespace);
        out.push_back(str.substr(pos, last - pos + 1));
        return;
      }
      ++splits_done;

      const size_t token_end = str.find_first_of(kWhitespace, pos);
      if (token_end == std::string_view::npos) {
        // Last token runs to the end of the string.
        out.push_back(str.substr(pos));
        return;
      }
      out.push_back(str.substr(pos, token_end - pos));
      // Skip the whole whitespace run; npos here means only trailing whitespace was left.
      pos = str.find_first_not_of(kWhitespace, token_end);
    }
    return;
  }

  // Explicit delimiter mode.
  size_t pos = 0;
  for (;;) {
    const size_t next = str.find(delimiter, pos);
    if (splits_done == max_splits || next == std::string_view::npos) {
      // Either no further delimiter exists or no further split is allowed: emit the remainder as is.
      // ``pos`` may equal ``str.size()`` here (input ended with a delimiter), which yields "".
      out.push_back(str.substr(pos));
      return;
    }
    ++splits_done;
    out.push_back(str.substr(pos, next - pos));
    pos = next + delimiter.size();
  }
}

/// Reads the kernel's attributes from ``info``.
///  - "delimiter": defaults to the empty string, which ComputeSubstrings treats as whitespace splitting.
///  - "maxsplit": defaults to int64 max - 1, i.e. effectively unlimited.
StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) {
  // NOTE(review): the "- 1" presumably keeps the split counter in ComputeSubstrings away from
  // int64 overflow; confirm before changing it.
  constexpr int64_t kDefaultMaxSplit = std::numeric_limits<int64_t>::max() - 1;
  const std::string kDefaultDelimiter;

  info.GetAttrOrDefault("delimiter", &delimiter_, kDefaultDelimiter);
  info.GetAttrOrDefault("maxsplit", &maxsplit_, kDefaultMaxSplit);
}

/// Splits every string of input 0 and produces two outputs:
///  - output 1: same shape as the input, holding the number of substrings found per element;
///  - output 0: input shape plus one trailing dimension whose extent is the largest substring count.
///    Only the substrings actually found are written per element; the remaining slots are left untouched.
Status StringSplit::Compute(OpKernelContext* context) const {
  const Tensor* input_tensor = context->Input<Tensor>(0);
  const auto& input_shape = input_tensor->Shape();
  auto strings = input_tensor->DataAsSpan<std::string>();

  // Per-element substring counts (second operator output).
  auto token_counts = context->Output(1, input_shape)->MutableDataAsSpan<int64_t>();
  auto count_out = token_counts.begin();

  // The substrings are kept as views into the input strings until the widest row is known,
  // because that width determines the shape of the first output.
  InlinedVector<InlinedVector<std::string_view>> pieces_per_input;
  pieces_per_input.reserve(strings.size());
  size_t widest = 0;

  for (const auto& text : strings) {
    pieces_per_input.emplace_back();
    auto& pieces = pieces_per_input.back();
    ComputeSubstrings(text, delimiter_, maxsplit_, pieces);

    const size_t piece_count = pieces.size();
    if (piece_count > widest) {
      widest = piece_count;
    }
    *count_out = static_cast<int64_t>(piece_count);
    ++count_out;
  }

  // Substrings (first operator output): input dims followed by the widest substring count.
  auto output_dims = input_shape.AsShapeVector();
  output_dims.push_back(static_cast<int64_t>(widest));

  auto splits = context->Output(0, output_dims)->MutableDataAsSpan<std::string>();
  auto row_pieces = pieces_per_input.begin();
  auto row_out = splits.begin();
  // Each input element owns a row of ``widest`` slots; copy its substrings to the front of that row.
  // When ``widest`` is 0 the output span is empty and the loop body never runs.
  while (row_out != splits.end()) {
    std::copy(row_pieces->begin(), row_pieces->end(), row_out);
    row_out += widest;
    ++row_pieces;
  }

  return Status::OK();
}

} // namespace onnxruntime
20 changes: 20 additions & 0 deletions onnxruntime/core/providers/cpu/text/string_split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"

namespace onnxruntime {

/// Kernel class for the StringSplit operator: splits each string of the input tensor into substrings.
class StringSplit final : public OpKernel {
public:
/// Initializes the kernel's split settings from the node attributes available through ``info``.
explicit StringSplit(const OpKernelInfo& info);
/// Runs the split for the inputs held by ``context`` and writes the operator outputs.
Status Compute(OpKernelContext* context) const override;

private:
// Separator to split on; an empty string selects whitespace splitting.
std::string delimiter_;
// Maximum number of splits performed per input string.
int64_t maxsplit_;
};

} // namespace onnxruntime
126 changes: 126 additions & 0 deletions onnxruntime/test/providers/cpu/text/string_split_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

// Splits a 1-D input on an explicit single-space delimiter. Rows with fewer substrings than the
// widest row are expected to be padded with "" in Y, while Z reports the real substring count.
TEST(StringSplit, BasicSplitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {3}, {"hello world", "hello", "world"});
test.AddAttribute<std::string>("delimiter", " ");
test.AddOutput<std::string>("Y", {3, 2}, {"hello", "world", "hello", "", "world", ""});
test.AddOutput<int64_t>("Z", {3}, {2, 1, 1});
test.Run();
}

// With maxsplit = 1 and delimiter ";" only the first delimiter splits; the remainder of the string
// (including any further delimiters) is expected as the second substring. Input is 2-D.
TEST(StringSplit, MaxSplitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {2, 2}, {"eggs;milk;chesse", "pepper;salt", "chicken;fish;pork", "spinach"});
test.AddAttribute<std::string>("delimiter", ";");
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>("Y", {2, 2, 2},
{"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""});
test.AddOutput<int64_t>("Z", {2, 2}, {2, 2, 2, 1});
test.Run();
}

// An explicitly empty delimiter selects whitespace splitting: leading and trailing spaces in the
// inputs must not produce empty substrings.
TEST(StringSplit, EmptyStringDelimiterTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddAttribute<std::string>("delimiter", "");
test.AddOutput<std::string>("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}

// Same inputs and expectations as the empty-delimiter case, but with the delimiter attribute left
// unset, so the default delimiter is exercised.
TEST(StringSplit, SubsequentWhitespaceDefaultTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddOutput<std::string>("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}

// Default (whitespace) delimiter combined with maxsplit = 1: the remainder after the first split keeps
// its interior spaces, while leading whitespace of the input and trailing whitespace of the remainder
// are expected to be stripped. A single-word input yields one substring plus "" padding.
TEST(StringSplit, SubsequentWhitespaceWithLimitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4},
{"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "});
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>(
"Y", {1, 4, 2},
{"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 1, 2});
test.Run();
}

// The delimiter "*" does not occur in the input, so the whole string is expected back as the only
// substring; Y gains a trailing dimension of extent 1.
TEST(StringSplit, SingleTokenTest) {
OpTester test("StringSplit", 20);
test.AddAttribute<std::string>("delimiter", "*");
test.AddInput<std::string>("X", {1, 1, 1}, {"lorem"});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

// Single word without any whitespace and no delimiter attribute: one substring is expected.
TEST(StringSplit, SingleTokenWhitespaceTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 1, 1}, {"lorem"});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

// Single word surrounded by spaces and no delimiter attribute: the surrounding whitespace is expected
// to be dropped, leaving exactly one substring.
TEST(StringSplit, EdgeWhitespaceTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 1, 1}, {" lorem "});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

// Mix of edge cases with delimiter "*": an empty string yields 0 substrings, a string without the
// delimiter yields 1, and a string consisting only of the delimiter yields 2 empty substrings.
TEST(StringSplit, EmptyInputTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 3, 1}, {"", "+", "*"});
test.AddAttribute<std::string>("delimiter", "*");
test.AddOutput<std::string>("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""});
test.AddOutput<int64_t>("Z", {1, 3, 1}, {0, 1, 2});
test.Run();
}

// Every input string is empty (explicit delimiter "*"): all counts are 0 and the trailing dimension
// of Y is expected to have extent 0.
TEST(StringSplit, OnlyEmptyInputTest) {
OpTester test("StringSplit", 20);
test.AddAttribute<std::string>("delimiter", "*");
test.AddInput<std::string>("X", {1, 2, 1}, {"", ""});
test.AddOutput<std::string>("Y", {1, 2, 1, 0}, {});
test.AddOutput<int64_t>("Z", {1, 2, 1}, {0, 0});
test.Run();
}

// Every input string is empty and no delimiter attribute is set: same expectations as with an
// explicit delimiter (counts of 0, trailing dimension of extent 0).
TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 2, 1}, {"", ""});
test.AddOutput<std::string>("Y", {1, 2, 1, 0}, {});
test.AddOutput<int64_t>("Z", {1, 2, 1}, {0, 0});
test.Run();
}

// Input tensor with zero elements (shape {0}): Y is expected with shape {0, 0} and Z with shape {0},
// both without data.
TEST(StringSplit, NoInputTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {
0,
},
{});
test.AddOutput<std::string>("Y", {
0,
0,
},
{});
test.AddOutput<int64_t>("Z", {
0,
},
{});
test.Run();
}

} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,6 @@
"^test_image_decoder_decode_pnm_rgb",
"^test_image_decoder_decode_tiff_rgb",
"^test_image_decoder_decode_webp_rgb",
"^test_string_split_basic",
"^test_string_split_consecutive_delimiters",
"^test_string_split_empty_string_delimiter",
"^test_string_split_empty_tensor",
"^test_string_split_maxsplit",
"^test_string_split_no_delimiter",
"^test_reduce_l1_empty_set_cuda",
"^test_reduce_l1_empty_set_expanded_cuda",
"^test_reduce_l2_empty_set_cuda",
Expand Down

0 comments on commit 7b6a615

Please sign in to comment.