Skip to content

Commit

Permalink
Implement empty string case
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 18, 2023
1 parent 38f0b99 commit 5309a68
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 47 deletions.
121 changes: 80 additions & 41 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,74 @@ int64_t countSubstrings(const std::string& str, const std::string& substr) {
if (str.empty()) {
return 0;
}
int64_t count = 1;
size_t pos = 0;
while ((pos = str.find(substr, pos)) != std::string::npos) {
++count;
pos += substr.length();
if (substr.empty()) {
// Count consecutive whitespace as one delimiter
bool in_whitespace = false;
int64_t count = 1;
for (const auto& c : str) {
if (isspace(c)) {
if (!in_whitespace) {
in_whitespace = true;
count += 1;
}
} else {
in_whitespace = false;
}
}
return count;
} else {
int64_t count = 1;
size_t pos = 0;
while ((pos = str.find(substr, pos)) != std::string::npos) {
++count;
pos += substr.length();
}
return count;
}

return count;
}

size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span<std::string> output, int64_t output_index, int64_t max_tokens) {
// Fill output with substrings of str, delimited by delimiter and place into output starting at output_index and incrementing up.
// Up to max_tokens substrings should be reached. If we are done before max_tokens, fill the rest with "". If we would not be done after max_tokens, make sure the max_tokensth substring is the remainder of the string.
auto pos = 0;
size_t token_index = 0;
while (token_index < max_tokens) {
auto next_pos = str.find(delimiter, pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
pos = next_pos + delimiter.size();
++token_index;
if (next_pos == std::string::npos) {
break;
size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span<std::string> output, int64_t output_index, size_t max_tokens) {
if (delimiter.empty()) {
// Count consecutive whitespace as one delimiter
bool in_whitespace = false;
size_t token_index = 0;
size_t substr_start_index = 0;

for (size_t i = 0; i < str.size(); ++i) {
if (token_index == max_tokens - 1) {
// if we are at the max_tokens-1 substring, the next and final substring should be the remainder of the string
output[output_index + token_index] = str.substr(i);
++token_index;
break;
}
if (!isspace(str[i])) {
// if currently in a whitespace, this index marks the start of the current substring
if (in_whitespace) {
substr_start_index = i;
in_whitespace = false;
}
} else if (!in_whitespace) {
// if currently not in whitespace, this index is the end of a substring
output[output_index + token_index] = str.substr(substr_start_index, i - substr_start_index);
in_whitespace = true;
++token_index;
}
}
return token_index;
} else {
size_t pos = 0;
size_t token_index = 0;
while (token_index < max_tokens) {
auto next_pos = str.find(delimiter, pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
pos = next_pos + delimiter.size();
++token_index;
if (next_pos == std::string::npos) {
break;
}
}
return token_index;
}
return token_index;
}

StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) {
Expand All @@ -57,29 +100,25 @@ Status StringSplit::Compute(OpKernelContext* context) const {
if (nullptr == num_substrings) {
return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch");
}
if ("" == delimiter_) {
// TODO: takes consecutive whitespace as delimiter
} else {
auto input_data = input->template DataAsSpan<std::string>();
auto last_dim = maxsplit_;
for (auto i = 0; i < input->Shape().Size(); ++i) {
auto substring_count = countSubstrings(input_data[i], delimiter_);
last_dim = std::max(last_dim, substring_count);
}
// 1. instantiate output shape to be shape input + (last_dim,)
// 2. maintain output tensor index/pointer. Iterate over input tensor; split tensor into last_dim substrings (with "" at end for extra); copy into output tensor and output pointer/index. advance output pointer/index by last_dim.
auto input_data = input->template DataAsSpan<std::string>();
auto last_dim = maxsplit_;
for (auto i = 0; i < input->Shape().Size(); ++i) {
auto substring_count = countSubstrings(input_data[i], delimiter_);
last_dim = std::max(last_dim, substring_count);
}

// Set up num_substrings output
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
// Set up splits output
auto splits_shape = input->Shape().AsShapeVector();
// Set up num_substrings output
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
// Set up splits output
auto splits_shape = input->Shape().AsShapeVector();
if (last_dim > 0) {
splits_shape.push_back(last_dim);
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto splits_index = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
num_substrings_data[i] = static_cast<int64_t>(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim));
splits_index += last_dim;
}
}
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto splits_index = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
num_substrings_data[i] = static_cast<int64_t>(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim));
splits_index += last_dim;
}
return Status::OK();
}
Expand Down
23 changes: 23 additions & 0 deletions onnxruntime/test/providers/cpu/nn/string_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,28 @@ TEST(StringSplit, MaxSplitTest) {
test.AddOutput<int64_t>("Z", {2, 2}, {2, 1, 2, 1});
}

TEST(StringSplit, EmptyStringDelimiterTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddAttribute<std::string>("delimiter", "");
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
}

TEST(StringSplit, SubsequentWhitespaceDefaultTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
}

TEST(StringSplit, SubsequentWhitespaceWithLimitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"lorem ipsum doler", "Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "});
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 1, 2});
}

} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -274,12 +274,6 @@
"^test_string_concat_empty_string",
"^test_string_concat_utf8",
"^test_string_concat_zero_dimensional",
"^test_string_split_basic",
"^test_string_split_consecutive_delimiters",
"^test_string_split_empty_string_delimiter",
"^test_string_split_empty_tensor",
"^test_string_split_maxsplit",
"^test_string_split_no_delimiter",
"^test_dft_axis",
"^test_dft",
"^test_dft_inverse",
Expand Down

0 comments on commit 5309a68

Please sign in to comment.