diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 0a1abcef30117..592575327cae8 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -12,82 +12,66 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -int64_t countSubstrings(const std::string& str, const std::string& substr) { +int64_t countSubstrings(std::string_view str, std::string_view substr) { if (str.empty()) { return 0; } if (substr.empty()) { // Count consecutive whitespace as one delimiter - bool in_whitespace = false; int64_t count = 1; - for (const auto& c : str) { - if (isspace(c)) { - if (!in_whitespace) { - in_whitespace = true; - count += 1; - } - } else { - in_whitespace = false; - } + size_t pos = str.find_first_not_of(" "); + while (pos != std::string::npos) { + ++count; + pos = str.find_first_not_of(" ", str.find_first_of(" ", pos)); } return count; } else { int64_t count = 1; - size_t pos = 0; - while ((pos = str.find(substr, pos)) != std::string::npos) { + size_t pos = str.find(substr); + while (pos != std::string::npos) { ++count; - pos += substr.length(); + pos = str.find(substr, pos + substr.length()); } return count; } } -size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { +int64_t fillSubstrings(std::string_view str, std::string_view delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { + if (str.empty()) { + return 0; + } if (delimiter.empty()) { - // Count consecutive whitespace as one delimiter - bool in_whitespace = false; + // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. + size_t pos = str.find_first_not_of(" "); size_t token_index = 0; - size_t substr_start_index = 0; - - for (size_t i = 0; i < str.size(); ++i) { - if (token_index == max_tokens - 1) { - // if we are at the max_tokens-1 substring, the next and final substring should be the remainder of the string - output[output_index + token_index] = str.substr(i); - ++token_index; + while (token_index < max_tokens && pos != std::string::npos) { + auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos); + output[output_index + token_index] = str.substr(pos, next_pos - pos); + ++token_index; + if (next_pos == std::string::npos) { break; } - if (!isspace(str[i])) { - // if currently in a whitespace, this index marks the start of the current substring - if (in_whitespace) { - substr_start_index = i; - in_whitespace = false; - } - } else if (!in_whitespace) { - // if currently not in whitespace, this index is the end of a substring - output[output_index + token_index] = str.substr(substr_start_index, i - substr_start_index); - in_whitespace = true; - ++token_index; - } + pos = str.find_first_not_of(" ", next_pos); } return token_index; } else { size_t pos = 0; size_t token_index = 0; - while (token_index < max_tokens) { - auto next_pos = str.find(delimiter, pos); + while (token_index < max_tokens && pos != std::string::npos) { + auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); output[output_index + token_index] = str.substr(pos, next_pos - pos); - pos = next_pos + delimiter.size(); ++token_index; if (next_pos == std::string::npos) { break; } + pos = next_pos + delimiter.size(); } return token_index; } } StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast(-1)); // TODO is this the right thing to do here? + info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); // TODO is this the right thing to do here? info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); } @@ -101,12 +85,14 @@ Status StringSplit::Compute(OpKernelContext* context) const { return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); } auto input_data = input->template DataAsSpan(); - auto last_dim = maxsplit_; + int64_t last_dim = 0; for (auto i = 0; i < input->Shape().Size(); ++i) { auto substring_count = countSubstrings(input_data[i], delimiter_); last_dim = std::max(last_dim, substring_count); } + last_dim = std::min(last_dim, maxsplit_ + 1); + // Set up num_substrings output auto num_substrings_data = num_substrings->template MutableDataAsSpan(); // Set up splits output @@ -117,7 +103,7 @@ Status StringSplit::Compute(OpKernelContext* context) const { auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); auto splits_index = 0; for (auto i = 0; i < input->Shape().Size(); ++i) { - num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); + num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim); splits_index += last_dim; } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index a0a426bd40e2e..e4aa5a8161339 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -19,30 +19,34 @@ TEST(StringSplit, MaxSplitTest) { test.AddAttribute("delimiter", ";"); test.AddAttribute("maxsplit", 1); test.AddOutput("Y", {2, 2, 2}, {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); - test.AddOutput("Z", {2, 2}, {2, 1, 2, 1}); + test.AddOutput("Z", {2, 2}, {2, 2, 2, 1}); + test.Run(); } TEST(StringSplit, EmptyStringDelimiterTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); test.AddAttribute("delimiter", ""); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); } TEST(StringSplit, SubsequentWhitespaceDefaultTest) { OpTester test("StringSplit", 20); - test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); } TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { OpTester test("StringSplit", 20); - test.AddInput("X", {1, 4}, {"lorem ipsum doler", "Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); + test.AddInput("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); test.AddAttribute("maxsplit", 1); test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "}); test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); + test.Run(); } } // namespace test