From d303e5bcfca249aed666a531a1a68539a8a34a45 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 23:21:14 +0000 Subject: [PATCH] Obtain string slices in first pass and remove CountSubstrings function entirely --- .../core/providers/cpu/nn/string_split.cc | 95 ++++++++----------- .../providers/cpu/nn/string_split_test.cc | 17 ++++ 2 files changed, 55 insertions(+), 57 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 0c788c3fc54c0..eadd189c1e6ee 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "string_split.h" +#include "core/providers/cpu/nn/string_split.h" #include +#include +#include #include "core/common/common.h" namespace onnxruntime { @@ -13,75 +15,47 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the -/// number of whitespace delimited words. -int64_t CountSubstrings(std::string_view str, std::string_view substr) { - if (substr.empty()) { - // Count consecutive whitespace as one delimiter - int64_t count = 0; - size_t pos = str.find_first_not_of(" "); - while (pos != std::string::npos) { - ++count; - pos = str.find_first_not_of(" ", str.find_first_of(" ", pos)); - } - return count; - } else { - int64_t count = 0; - size_t pos = 0; - while (pos != std::string::npos) { - ++count; - pos = str.find(substr, pos); - if (pos != std::string::npos) { - pos += substr.length(); - } - } - return count; - } -} - /// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of /// generated substrings to ``max_tokens``. The function returns the number of substrings generated (this is less or /// equal to ``max_tokens``). -int64_t FillSubstrings(std::string_view str, std::string_view delimiter, - gsl::details::span_iterator output, size_t max_tokens) { +InlinedVector FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { + InlinedVector output; if (str.empty()) { - return 0; + return output; } if (delimiter.empty()) { // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. size_t pos = str.find_first_not_of(" "); int64_t token_count = 0; while (pos != std::string::npos) { - if (++token_count == max_tokens) { + if (token_count++ == max_splits) { // trim down last substring as required in specification size_t next_pos = str.length() - 1; while (str[next_pos] == ' ') { next_pos--; } - *output = str.substr(pos, next_pos - pos + 1); + output.push_back(str.substr(pos, next_pos - pos + 1)); break; } else { auto next_pos = str.find_first_of(" ", pos); - *output = str.substr(pos, next_pos - pos); + output.push_back(str.substr(pos, next_pos - pos)); pos = str.find_first_not_of(" ", next_pos); } - - output++; } - return token_count; + return output; } else { size_t pos = 0; int64_t token_count = 0; while (pos != std::string::npos) { - auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); - *output++ = str.substr(pos, next_pos - pos); - token_count++; - if (next_pos == std::string::npos) { + auto next_pos = str.find(delimiter, pos); + if (token_count++ == max_splits || next_pos == std::string::npos) { + output.push_back(str.substr(pos)); break; } + output.push_back(str.substr(pos, next_pos - pos)); pos = next_pos + delimiter.size(); } - return token_count; + return output; } } @@ -94,29 +68,36 @@ Status StringSplit::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); auto input_data = input->template DataAsSpan(); - int64_t last_dim = 0; - for (const auto& str : input_data) { - last_dim = std::max(last_dim, CountSubstrings(str, delimiter_)); + // Set up number of tokens output + auto num_tokens_data = context->Output(1, input->Shape())->template MutableDataAsSpan(); + auto num_tokens_iter = num_tokens_data.begin(); + + int64_t last_dim = 1; + + InlinedVector> input_slices; + input_slices.reserve(input_data.size()); + auto input_slice_iterator = input_slices.begin(); + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) { + auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_); + auto substr_count = static_cast(substrs.size()); + input_slices.push_back(substrs); + last_dim = std::max(last_dim, substr_count); + *num_tokens_iter = substr_count; } + last_dim = std::min(last_dim, maxsplit_ + 1); // Set up splits output auto splits_shape = input->Shape().AsShapeVector(); - if (last_dim > 0) { - splits_shape.push_back(last_dim); - } - auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); - auto output_splits_iter = splits_data.begin(); - - // Set up number of tokens output - auto* num_substrings = context->Output(1, input->Shape()); - auto num_substrings_data = num_substrings->template MutableDataAsSpan(); - auto output_num_tokens_iter = num_substrings_data.begin(); + splits_shape.push_back(last_dim); - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); - input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { - *output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim); + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto slices_iter = input_slices.begin(); + for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, slices_iter++) { + const auto output_slices = *slices_iter; + std::copy(output_slices.begin(), output_slices.end(), output_splits_iter); } + return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index 286b4e41b384c..d2f37168babe3 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -87,5 +87,22 @@ TEST(StringSplit, EmptyInputTest) { test.Run(); } +TEST(StringSplit, OnlyEmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + +TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + } // namespace test } // namespace onnxruntime