From d254203e21c02a491f74a47b0de342825ea933ea Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 18 Oct 2023 02:48:52 +0100 Subject: [PATCH 01/15] Normal case --- .../providers/cpu/cpu_execution_provider.cc | 2 + .../core/providers/cpu/nn/string_split.cc | 87 +++++++++++++++++++ .../core/providers/cpu/nn/string_split.h | 17 ++++ .../providers/cpu/nn/string_split_test.cc | 26 ++++++ 4 files changed, 132 insertions(+) create mode 100644 onnxruntime/core/providers/cpu/nn/string_split.cc create mode 100644 onnxruntime/core/providers/cpu/nn/string_split.h create mode 100644 onnxruntime/test/providers/cpu/nn/string_split_test.cc diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f60c7ddac5c05..5663c6d5cbb8b 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2447,6 +2448,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc new file mode 100644 index 0000000000000..5ef22ce137d1f --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -0,0 +1,87 @@ +#include "string_split.h" +#include "core/common/common.h" +#include +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL( + StringSplit, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + StringSplit); + +int64_t countSubstrings(const std::string& str, const std::string& substr) { + if (str.empty()) { + return 0; + } + int64_t count = 1; + size_t pos = 0; + while ((pos = str.find(substr, pos)) != std::string::npos) { + ++count; + pos += substr.length(); + } + + return count; +} + +size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, int64_t max_tokens) { + // Fill output with substrings of str, delimited by delimiter and place into output starting at output_index and incrementing up. + // Up to max_tokens substrings should be reached. If we are done before max_tokens, fill the rest with "". If we would not be done after max_tokens, make sure the max_tokensth substring is the remainder of the string. + auto pos = 0; + size_t token_index = 0; + while (token_index < max_tokens) { + auto next_pos = str.find(delimiter, pos); + output[output_index + token_index] = str.substr(pos, next_pos - pos); + pos = next_pos + delimiter.size(); + ++token_index; + if (next_pos == std::string::npos) { + break; + } + } + return token_index; +} + +StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast(-1)); // TODO is this the right thing to do here? + info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); +} + +Status StringSplit::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + if (nullptr == input) { + return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + } + auto* num_substrings = context->Output(1, input->Shape()); + if (nullptr == num_substrings) { + return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); + } + if ("" == delimiter_) { + // TODO: takes consecutive whitespace as delimiter + } else { + auto input_data = input->template DataAsSpan(); + auto last_dim = maxsplit_; + for (auto i = 0; i < input->Shape().Size(); ++i) { + auto substring_count = countSubstrings(input_data[i], delimiter_); + last_dim = std::max(last_dim, substring_count); + } + // 1. instantiate output shape to be shape input + (last_dim,) + // 2. maintain output tensor index/pointer. Iterate over input tensor; split tensor into last_dim substrings (with "" at end for extra); copy into output tensor and output pointer/index. advance output pointer/index by last_dim. + + // Set up num_substrings output + auto num_substrings_data = num_substrings->template MutableDataAsSpan(); + // Set up splits output + auto splits_shape = input->Shape().AsShapeVector(); + splits_shape.push_back(last_dim); + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto splits_index = 0; + for (auto i = 0; i < input->Shape().Size(); ++i) { + num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); + splits_index += last_dim; + } + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/string_split.h b/onnxruntime/core/providers/cpu/nn/string_split.h new file mode 100644 index 0000000000000..f9ded9cafc760 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/string_split.h @@ -0,0 +1,17 @@ +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class StringSplit final : public OpKernel { + public: + explicit StringSplit(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + std::string delimiter_; + int64_t maxsplit_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc new file mode 100644 index 0000000000000..8d6dffaf76ce0 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -0,0 +1,26 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(StringSplit, BasicSplitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {3}, {"hello world", "hello", "world"}); + test.AddAttribute("delimiter", " "); + test.AddOutput("Y", {3, 2}, {"hello", "world", "hello", "", "world", ""}); + test.AddOutput("Z", {3}, {2, 1, 1}); + test.Run(); +} + +TEST(StringSplit, MaxSplitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {2, 2}, {"eggs;milk;chesse", "pepper;salt", "chicken;fish;pork", "spinach"}); + test.AddAttribute("delimiter", ";"); + test.AddAttribute("maxsplit", 1); + test.AddOutput("Y", {2, 2, 2}, {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); + test.AddOutput("Z", {2, 2}, {2, 1, 2, 1}); +} + +} // namespace test +} // namespace onnxruntime From f51fd2a9890d136798aaf7e4b94e385c4b9239b3 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 18 Oct 2023 13:17:30 +0100 Subject: [PATCH 02/15] Implement empty string case --- .../core/providers/cpu/nn/string_split.cc | 121 ++++++++++++------ .../providers/cpu/nn/string_split_test.cc | 23 ++++ 2 files changed, 103 insertions(+), 41 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 5ef22ce137d1f..0a1abcef30117 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -16,31 +16,74 @@ int64_t countSubstrings(const std::string& str, const std::string& substr) { if (str.empty()) { return 0; } - int64_t count = 1; - size_t pos = 0; - while ((pos = str.find(substr, pos)) != std::string::npos) { - ++count; - pos += substr.length(); + if (substr.empty()) { + // Count consecutive whitespace as one delimiter + bool in_whitespace = false; + int64_t count = 1; + for (const auto& c : str) { + if (isspace(c)) { + if (!in_whitespace) { + in_whitespace = true; + count += 1; + } + } else { + in_whitespace = false; + } + } + return count; + } else { + int64_t count = 1; + size_t pos = 0; + while ((pos = str.find(substr, pos)) != std::string::npos) { + ++count; + pos += substr.length(); + } + return count; } - - return count; } -size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, int64_t max_tokens) { - // Fill output with substrings of str, delimited by delimiter and place into output starting at output_index and incrementing up. - // Up to max_tokens substrings should be reached. If we are done before max_tokens, fill the rest with "". If we would not be done after max_tokens, make sure the max_tokensth substring is the remainder of the string. - auto pos = 0; - size_t token_index = 0; - while (token_index < max_tokens) { - auto next_pos = str.find(delimiter, pos); - output[output_index + token_index] = str.substr(pos, next_pos - pos); - pos = next_pos + delimiter.size(); - ++token_index; - if (next_pos == std::string::npos) { - break; +size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { + if (delimiter.empty()) { + // Count consecutive whitespace as one delimiter + bool in_whitespace = false; + size_t token_index = 0; + size_t substr_start_index = 0; + + for (size_t i = 0; i < str.size(); ++i) { + if (token_index == max_tokens - 1) { + // if we are at the max_tokens-1 substring, the next and final substring should be the remainder of the string + output[output_index + token_index] = str.substr(i); + ++token_index; + break; + } + if (!isspace(str[i])) { + // if currently in a whitespace, this index marks the start of the current substring + if (in_whitespace) { + substr_start_index = i; + in_whitespace = false; + } + } else if (!in_whitespace) { + // if currently not in whitespace, this index is the end of a substring + output[output_index + token_index] = str.substr(substr_start_index, i - substr_start_index); + in_whitespace = true; + ++token_index; + } + } + return token_index; + } else { + size_t pos = 0; + size_t token_index = 0; + while (token_index < max_tokens) { + auto next_pos = str.find(delimiter, pos); + output[output_index + token_index] = str.substr(pos, next_pos - pos); + pos = next_pos + delimiter.size(); + ++token_index; + if (next_pos == std::string::npos) { + break; + } } + return token_index; } - return token_index; } StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { @@ -57,29 +100,25 @@ Status StringSplit::Compute(OpKernelContext* context) const { if (nullptr == num_substrings) { return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); } - if ("" == delimiter_) { - // TODO: takes consecutive whitespace as delimiter - } else { - auto input_data = input->template DataAsSpan(); - auto last_dim = maxsplit_; - for (auto i = 0; i < input->Shape().Size(); ++i) { - auto substring_count = countSubstrings(input_data[i], delimiter_); - last_dim = std::max(last_dim, substring_count); - } - // 1. instantiate output shape to be shape input + (last_dim,) - // 2. maintain output tensor index/pointer. Iterate over input tensor; split tensor into last_dim substrings (with "" at end for extra); copy into output tensor and output pointer/index. advance output pointer/index by last_dim. + auto input_data = input->template DataAsSpan(); + auto last_dim = maxsplit_; + for (auto i = 0; i < input->Shape().Size(); ++i) { + auto substring_count = countSubstrings(input_data[i], delimiter_); + last_dim = std::max(last_dim, substring_count); + } - // Set up num_substrings output - auto num_substrings_data = num_substrings->template MutableDataAsSpan(); - // Set up splits output - auto splits_shape = input->Shape().AsShapeVector(); + // Set up num_substrings output + auto num_substrings_data = num_substrings->template MutableDataAsSpan(); + // Set up splits output + auto splits_shape = input->Shape().AsShapeVector(); + if (last_dim > 0) { splits_shape.push_back(last_dim); - auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); - auto splits_index = 0; - for (auto i = 0; i < input->Shape().Size(); ++i) { - num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); - splits_index += last_dim; - } + } + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto splits_index = 0; + for (auto i = 0; i < input->Shape().Size(); ++i) { + num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); + splits_index += last_dim; } return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index 8d6dffaf76ce0..a0a426bd40e2e 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -22,5 +22,28 @@ TEST(StringSplit, MaxSplitTest) { test.AddOutput("Z", {2, 2}, {2, 1, 2, 1}); } +TEST(StringSplit, EmptyStringDelimiterTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddAttribute("delimiter", ""); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); +} + +TEST(StringSplit, SubsequentWhitespaceDefaultTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); +} + +TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, {"lorem ipsum doler", "Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); + test.AddAttribute("maxsplit", 1); + test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "}); + test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); +} + } // namespace test } // namespace onnxruntime From 45edebffc587160a009561a6c18172cf19a13a1c Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 18 Oct 2023 16:12:22 +0100 Subject: [PATCH 03/15] Bug fix --- .../core/providers/cpu/nn/string_split.cc | 70 ++++++++----------- .../providers/cpu/nn/string_split_test.cc | 14 ++-- 2 files changed, 37 insertions(+), 47 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 0a1abcef30117..592575327cae8 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -12,82 +12,66 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -int64_t countSubstrings(const std::string& str, const std::string& substr) { +int64_t countSubstrings(std::string_view str, std::string_view substr) { if (str.empty()) { return 0; } if (substr.empty()) { // Count consecutive whitespace as one delimiter - bool in_whitespace = false; int64_t count = 1; - for (const auto& c : str) { - if (isspace(c)) { - if (!in_whitespace) { - in_whitespace = true; - count += 1; - } - } else { - in_whitespace = false; - } + size_t pos = str.find_first_not_of(" "); + while (pos != std::string::npos) { + ++count; + pos = str.find_first_not_of(" ", str.find_first_of(" ", pos)); } return count; } else { int64_t count = 1; - size_t pos = 0; - while ((pos = str.find(substr, pos)) != std::string::npos) { + size_t pos = str.find(substr); + while (pos != std::string::npos) { ++count; - pos += substr.length(); + pos = str.find(substr, pos + substr.length()); } return count; } } -size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { +int64_t fillSubstrings(std::string_view str, std::string_view delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { + if (str.empty()) { + return 0; + } if (delimiter.empty()) { - // Count consecutive whitespace as one delimiter - bool in_whitespace = false; + // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. + size_t pos = str.find_first_not_of(" "); size_t token_index = 0; - size_t substr_start_index = 0; - - for (size_t i = 0; i < str.size(); ++i) { - if (token_index == max_tokens - 1) { - // if we are at the max_tokens-1 substring, the next and final substring should be the remainder of the string - output[output_index + token_index] = str.substr(i); - ++token_index; + while (token_index < max_tokens && pos != std::string::npos) { + auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos); + output[output_index + token_index] = str.substr(pos, next_pos - pos); + ++token_index; + if (next_pos == std::string::npos) { break; } - if (!isspace(str[i])) { - // if currently in a whitespace, this index marks the start of the current substring - if (in_whitespace) { - substr_start_index = i; - in_whitespace = false; - } - } else if (!in_whitespace) { - // if currently not in whitespace, this index is the end of a substring - output[output_index + token_index] = str.substr(substr_start_index, i - substr_start_index); - in_whitespace = true; - ++token_index; - } + pos = str.find_first_not_of(" ", next_pos); } return token_index; } else { size_t pos = 0; size_t token_index = 0; - while (token_index < max_tokens) { - auto next_pos = str.find(delimiter, pos); + while (token_index < max_tokens && pos != std::string::npos) { + auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); output[output_index + token_index] = str.substr(pos, next_pos - pos); - pos = next_pos + delimiter.size(); ++token_index; if (next_pos == std::string::npos) { break; } + pos = next_pos + delimiter.size(); } return token_index; } } StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast(-1)); // TODO is this the right thing to do here? + info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); // TODO is this the right thing to do here? info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); } @@ -101,12 +85,14 @@ Status StringSplit::Compute(OpKernelContext* context) const { return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); } auto input_data = input->template DataAsSpan(); - auto last_dim = maxsplit_; + int64_t last_dim = 0; for (auto i = 0; i < input->Shape().Size(); ++i) { auto substring_count = countSubstrings(input_data[i], delimiter_); last_dim = std::max(last_dim, substring_count); } + last_dim = std::min(last_dim, maxsplit_ + 1); + // Set up num_substrings output auto num_substrings_data = num_substrings->template MutableDataAsSpan(); // Set up splits output @@ -117,7 +103,7 @@ Status StringSplit::Compute(OpKernelContext* context) const { auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); auto splits_index = 0; for (auto i = 0; i < input->Shape().Size(); ++i) { - num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); + num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim); splits_index += last_dim; } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index a0a426bd40e2e..e4aa5a8161339 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -19,30 +19,34 @@ TEST(StringSplit, MaxSplitTest) { test.AddAttribute("delimiter", ";"); test.AddAttribute("maxsplit", 1); test.AddOutput("Y", {2, 2, 2}, {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); - test.AddOutput("Z", {2, 2}, {2, 1, 2, 1}); + test.AddOutput("Z", {2, 2}, {2, 2, 2, 1}); + test.Run(); } TEST(StringSplit, EmptyStringDelimiterTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); test.AddAttribute("delimiter", ""); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); } TEST(StringSplit, SubsequentWhitespaceDefaultTest) { OpTester test("StringSplit", 20); - test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""}); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); } TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { OpTester test("StringSplit", 20); - test.AddInput("X", {1, 4}, {"lorem ipsum doler", "Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); + test.AddInput("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); test.AddAttribute("maxsplit", 1); test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "}); test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); + test.Run(); } } // namespace test From 6ab138a297780fcaddbf9f515cf77f366017dcac Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 03:55:02 +0000 Subject: [PATCH 04/15] Implement feedback --- .../cpu/quantization/qlinear_concat.cc | 2 +- .../core/providers/cpu/nn/string_split.cc | 91 ++++++++++--------- .../core/providers/cpu/nn/string_split.h | 3 + .../providers/cpu/nn/string_split_test.cc | 40 +++++++- 4 files changed, 91 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc index ee9ae7167945c..af163b6be702b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc @@ -1,4 +1,4 @@ -// Copyright (c Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "qlinear_util.h" diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 592575327cae8..eaf20075ceea4 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "string_split.h" #include "core/common/common.h" #include @@ -12,13 +15,11 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -int64_t countSubstrings(std::string_view str, std::string_view substr) { - if (str.empty()) { - return 0; - } +/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the number of whitespace delimited words. +int64_t CountSubstrings(std::string_view str, std::string_view substr) { if (substr.empty()) { // Count consecutive whitespace as one delimiter - int64_t count = 1; + int64_t count = 0; size_t pos = str.find_first_not_of(" "); while (pos != std::string::npos) { ++count; @@ -26,85 +27,93 @@ int64_t countSubstrings(std::string_view str, std::string_view substr) { } return count; } else { - int64_t count = 1; - size_t pos = str.find(substr); + int64_t count = 0; + size_t pos = 0; while (pos != std::string::npos) { ++count; - pos = str.find(substr, pos + substr.length()); + pos = str.find(substr, pos); + if (pos != std::string::npos) { + pos += substr.length(); + } } return count; } } -int64_t fillSubstrings(std::string_view str, std::string_view delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { +/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of generated substrings to ``max_tokens``. +/// The function returns the number of substrings generated (this is less or equal to ``max_tokens``). +int64_t FillSubstrings(std::string_view str, std::string_view delimiter, gsl::details::span_iterator output, size_t max_tokens) { if (str.empty()) { return 0; } if (delimiter.empty()) { // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. size_t pos = str.find_first_not_of(" "); - size_t token_index = 0; - while (token_index < max_tokens && pos != std::string::npos) { - auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos); - output[output_index + token_index] = str.substr(pos, next_pos - pos); - ++token_index; - if (next_pos == std::string::npos) { + int64_t token_count = 0; + while (pos != std::string::npos) { + if (++token_count == max_tokens) { + // trim down last substring as required in specification + size_t next_pos = str.length() - 1; + while (str[next_pos] == ' ') { + next_pos--; + } + *output = str.substr(pos, next_pos - pos + 1); break; + } else { + auto next_pos = str.find_first_of(" ", pos); + *output = str.substr(pos, next_pos - pos); + pos = str.find_first_not_of(" ", next_pos); } - pos = str.find_first_not_of(" ", next_pos); + + output++; } - return token_index; + return token_count; } else { size_t pos = 0; - size_t token_index = 0; - while (token_index < max_tokens && pos != std::string::npos) { - auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); - output[output_index + token_index] = str.substr(pos, next_pos - pos); - ++token_index; + int64_t token_count = 0; + while (pos != std::string::npos) { + auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); + *output++ = str.substr(pos, next_pos - pos); + token_count++; if (next_pos == std::string::npos) { break; } pos = next_pos + delimiter.size(); } - return token_index; + return token_count; } } StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); // TODO is this the right thing to do here? + info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); } Status StringSplit::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); - if (nullptr == input) { - return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - } - auto* num_substrings = context->Output(1, input->Shape()); - if (nullptr == num_substrings) { - return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); - } auto input_data = input->template DataAsSpan(); + int64_t last_dim = 0; - for (auto i = 0; i < input->Shape().Size(); ++i) { - auto substring_count = countSubstrings(input_data[i], delimiter_); - last_dim = std::max(last_dim, substring_count); + for (const auto& str : input_data) { + last_dim = std::max(last_dim, CountSubstrings(str, delimiter_)); } - last_dim = std::min(last_dim, maxsplit_ + 1); - // Set up num_substrings output - auto num_substrings_data = num_substrings->template MutableDataAsSpan(); // Set up splits output auto splits_shape = input->Shape().AsShapeVector(); if (last_dim > 0) { splits_shape.push_back(last_dim); } auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); - auto splits_index = 0; - for (auto i = 0; i < input->Shape().Size(); ++i) { - num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim); - splits_index += last_dim; + auto output_splits_iter = splits_data.begin(); + + // Set up number of tokens output + auto* num_substrings = context->Output(1, input->Shape()); + auto num_substrings_data = num_substrings->template MutableDataAsSpan(); + auto output_num_tokens_iter = num_substrings_data.begin(); + + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { + *output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/nn/string_split.h b/onnxruntime/core/providers/cpu/nn/string_split.h index f9ded9cafc760..6be249261d4e3 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.h +++ b/onnxruntime/core/providers/cpu/nn/string_split.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/framework/op_kernel.h" diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index e4aa5a8161339..a45fe736c3df4 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -27,7 +27,7 @@ TEST(StringSplit, EmptyStringDelimiterTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); test.AddAttribute("delimiter", ""); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); test.Run(); } @@ -35,7 +35,7 @@ TEST(StringSplit, EmptyStringDelimiterTest) { TEST(StringSplit, SubsequentWhitespaceDefaultTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); test.Run(); } @@ -44,10 +44,44 @@ TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); test.AddAttribute("maxsplit", 1); - test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "}); + test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"}); test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); test.Run(); } +TEST(StringSplit, SingleTokenTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, SingleTokenWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EdgeWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {" lorem "}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 3, 1}, {"", "+", "*"}); + test.AddOutput("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""}); + test.AddOutput("Z", {1, 3, 1}, {0, 1, 2}); + test.Run(); +} + } // namespace test } // namespace onnxruntime From 052617e6d59e629476a681755ca75236dfe565d3 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 04:16:11 +0000 Subject: [PATCH 05/15] Drop backend filter --- .../test/testdata/onnx_backend_test_series_filters.jsonc | 6 ------ 1 file changed, 6 deletions(-) diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 3a13e39702904..1108e814dd9d7 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -256,12 +256,6 @@ "^test_string_concat_empty_string", "^test_string_concat_utf8", "^test_string_concat_zero_dimensional", - "^test_string_split_basic", - "^test_string_split_consecutive_delimiters", - "^test_string_split_empty_string_delimiter", - "^test_string_split_empty_tensor", - "^test_string_split_maxsplit", - "^test_string_split_no_delimiter", "^test_reduce_l1_empty_set_cuda", "^test_reduce_l1_empty_set_expanded_cuda", "^test_reduce_l2_empty_set_cuda", From c333ad5675120dfa36a6b4083b1cd32de7d9d04b Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 17:38:59 +0000 Subject: [PATCH 06/15] Set 120 column limit in .clang-format --- .../core/providers/cpu/nn/string_split.cc | 30 ++++++++++--------- .../providers/cpu/nn/string_split_test.cc | 10 +++++-- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index eaf20075ceea4..0c788c3fc54c0 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -2,20 +2,19 @@ // Licensed under the MIT License. #include "string_split.h" -#include "core/common/common.h" #include +#include "core/common/common.h" namespace onnxruntime { -ONNX_CPU_OPERATOR_KERNEL( - StringSplit, - 20, - KernelDefBuilder() - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) - .TypeConstraint("T2", DataTypeImpl::GetTensorType()) - .TypeConstraint("T3", DataTypeImpl::GetTensorType()), - StringSplit); +ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + StringSplit); -/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the number of whitespace delimited words. +/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the +/// number of whitespace delimited words. int64_t CountSubstrings(std::string_view str, std::string_view substr) { if (substr.empty()) { // Count consecutive whitespace as one delimiter @@ -40,9 +39,11 @@ int64_t CountSubstrings(std::string_view str, std::string_view substr) { } } -/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of generated substrings to ``max_tokens``. -/// The function returns the number of substrings generated (this is less or equal to ``max_tokens``). -int64_t FillSubstrings(std::string_view str, std::string_view delimiter, gsl::details::span_iterator output, size_t max_tokens) { +/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of +/// generated substrings to ``max_tokens``. The function returns the number of substrings generated (this is less or +/// equal to ``max_tokens``). +int64_t FillSubstrings(std::string_view str, std::string_view delimiter, + gsl::details::span_iterator output, size_t max_tokens) { if (str.empty()) { return 0; } @@ -112,7 +113,8 @@ Status StringSplit::Compute(OpKernelContext* context) const { auto num_substrings_data = num_substrings->template MutableDataAsSpan(); auto output_num_tokens_iter = num_substrings_data.begin(); - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); + input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { *output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim); } return Status::OK(); diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index a45fe736c3df4..286b4e41b384c 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -18,7 +18,8 @@ TEST(StringSplit, MaxSplitTest) { test.AddInput("X", {2, 2}, {"eggs;milk;chesse", "pepper;salt", "chicken;fish;pork", "spinach"}); test.AddAttribute("delimiter", ";"); test.AddAttribute("maxsplit", 1); - test.AddOutput("Y", {2, 2, 2}, {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); + test.AddOutput("Y", {2, 2, 2}, + {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); test.AddOutput("Z", {2, 2}, {2, 2, 2, 1}); test.Run(); } @@ -42,9 +43,12 @@ TEST(StringSplit, SubsequentWhitespaceDefaultTest) { TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { OpTester test("StringSplit", 20); - test.AddInput("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); + test.AddInput("X", {1, 4}, + {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); test.AddAttribute("maxsplit", 1); - test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"}); + test.AddOutput( + "Y", {1, 4, 2}, + {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"}); test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); test.Run(); } From d303e5bcfca249aed666a531a1a68539a8a34a45 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 23:21:14 +0000 Subject: [PATCH 07/15] Obtain string slices in first pass and remove CountSubstrings function entirely --- .../core/providers/cpu/nn/string_split.cc | 95 ++++++++----------- .../providers/cpu/nn/string_split_test.cc | 17 ++++ 2 files changed, 55 insertions(+), 57 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 0c788c3fc54c0..eadd189c1e6ee 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -1,8 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "string_split.h" +#include "core/providers/cpu/nn/string_split.h" #include +#include +#include #include "core/common/common.h" namespace onnxruntime { @@ -13,75 +15,47 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the -/// number of whitespace delimited words. -int64_t CountSubstrings(std::string_view str, std::string_view substr) { - if (substr.empty()) { - // Count consecutive whitespace as one delimiter - int64_t count = 0; - size_t pos = str.find_first_not_of(" "); - while (pos != std::string::npos) { - ++count; - pos = str.find_first_not_of(" ", str.find_first_of(" ", pos)); - } - return count; - } else { - int64_t count = 0; - size_t pos = 0; - while (pos != std::string::npos) { - ++count; - pos = str.find(substr, pos); - if (pos != std::string::npos) { - pos += substr.length(); - } - } - return count; - } -} - /// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of /// generated substrings to ``max_tokens``. The function returns the number of substrings generated (this is less or /// equal to ``max_tokens``). -int64_t FillSubstrings(std::string_view str, std::string_view delimiter, - gsl::details::span_iterator output, size_t max_tokens) { +InlinedVector FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { + InlinedVector output; if (str.empty()) { - return 0; + return output; } if (delimiter.empty()) { // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. size_t pos = str.find_first_not_of(" "); int64_t token_count = 0; while (pos != std::string::npos) { - if (++token_count == max_tokens) { + if (token_count++ == max_splits) { // trim down last substring as required in specification size_t next_pos = str.length() - 1; while (str[next_pos] == ' ') { next_pos--; } - *output = str.substr(pos, next_pos - pos + 1); + output.push_back(str.substr(pos, next_pos - pos + 1)); break; } else { auto next_pos = str.find_first_of(" ", pos); - *output = str.substr(pos, next_pos - pos); + output.push_back(str.substr(pos, next_pos - pos)); pos = str.find_first_not_of(" ", next_pos); } - - output++; } - return token_count; + return output; } else { size_t pos = 0; int64_t token_count = 0; while (pos != std::string::npos) { - auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); - *output++ = str.substr(pos, next_pos - pos); - token_count++; - if (next_pos == std::string::npos) { + auto next_pos = str.find(delimiter, pos); + if (token_count++ == max_splits || next_pos == std::string::npos) { + output.push_back(str.substr(pos)); break; } + output.push_back(str.substr(pos, next_pos - pos)); pos = next_pos + delimiter.size(); } - return token_count; + return output; } } @@ -94,29 +68,36 @@ Status StringSplit::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); auto input_data = input->template DataAsSpan(); - int64_t last_dim = 0; - for (const auto& str : input_data) { - last_dim = std::max(last_dim, CountSubstrings(str, delimiter_)); + // Set up number of tokens output + auto num_tokens_data = context->Output(1, input->Shape())->template MutableDataAsSpan(); + auto num_tokens_iter = num_tokens_data.begin(); + + int64_t last_dim = 1; + + InlinedVector> input_slices; + input_slices.reserve(input_data.size()); + auto input_slice_iterator = input_slices.begin(); + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) { + auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_); + auto substr_count = static_cast(substrs.size()); + input_slices.push_back(substrs); + last_dim = std::max(last_dim, substr_count); + *num_tokens_iter = substr_count; } + last_dim = std::min(last_dim, maxsplit_ + 1); // Set up splits output auto splits_shape = input->Shape().AsShapeVector(); - if (last_dim > 0) { - splits_shape.push_back(last_dim); - } - auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); - auto output_splits_iter = splits_data.begin(); - - // Set up number of tokens output - auto* num_substrings = context->Output(1, input->Shape()); - auto num_substrings_data = num_substrings->template MutableDataAsSpan(); - auto output_num_tokens_iter = num_substrings_data.begin(); + splits_shape.push_back(last_dim); - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); - input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { - *output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim); + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto slices_iter = input_slices.begin(); + for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, slices_iter++) { + const auto output_slices = *slices_iter; + std::copy(output_slices.begin(), output_slices.end(), output_splits_iter); } + return Status::OK(); } diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index 286b4e41b384c..d2f37168babe3 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -87,5 +87,22 @@ TEST(StringSplit, EmptyInputTest) { test.Run(); } +TEST(StringSplit, OnlyEmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + +TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + } // namespace test } // namespace onnxruntime From 14b88436e2c53c96bb881976d7b9b8647a683a8c Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 23:25:58 +0000 Subject: [PATCH 08/15] Fix docstring --- onnxruntime/core/providers/cpu/nn/string_split.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index eadd189c1e6ee..c6add27eb7484 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -15,9 +15,9 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of -/// generated substrings to ``max_tokens``. The function returns the number of substrings generated (this is less or -/// equal to ``max_tokens``). +/// Calculate substrings in ``str`` delimited by ``delimiter``. A maximum of ``max_splits`` splits are permitted. +/// Returns a vector of string slices into ``str`` representing the substrings as string views. The user must ensure +/// the returned views' lifetime does not exceed ``str``'s. InlinedVector FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { InlinedVector output; if (str.empty()) { @@ -29,7 +29,7 @@ InlinedVector FillSubstrings(std::string_view str, std::string int64_t token_count = 0; while (pos != std::string::npos) { if (token_count++ == max_splits) { - // trim down last substring as required in specification + // Trim down last substring as required in specification size_t next_pos = str.length() - 1; while (str[next_pos] == ' ') { next_pos--; @@ -72,10 +72,10 @@ Status StringSplit::Compute(OpKernelContext* context) const { auto num_tokens_data = context->Output(1, input->Shape())->template MutableDataAsSpan(); auto num_tokens_iter = num_tokens_data.begin(); - int64_t last_dim = 1; - InlinedVector> input_slices; input_slices.reserve(input_data.size()); + int64_t last_dim = 1; + auto input_slice_iterator = input_slices.begin(); for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) { auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_); From 71097121533116a94214ee11eab851a04640640d Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 23:30:48 +0000 Subject: [PATCH 09/15] Inline variable --- onnxruntime/core/providers/cpu/nn/string_split.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index c6add27eb7484..aee6c9bcd0a8b 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -94,8 +94,7 @@ Status StringSplit::Compute(OpKernelContext* context) const { auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); auto slices_iter = input_slices.begin(); for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, slices_iter++) { - const auto output_slices = *slices_iter; - std::copy(output_slices.begin(), output_slices.end(), output_splits_iter); + std::copy(slices_iter->begin(), slices_iter->end(), output_splits_iter); } return Status::OK(); From cf91ede550a87578e19518db62d0a53ebf7841bf Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 23:34:56 +0000 Subject: [PATCH 10/15] Update substring calculation function name --- onnxruntime/core/providers/cpu/nn/string_split.cc | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index aee6c9bcd0a8b..55fc8eb6202f4 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -18,7 +18,7 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, /// Calculate substrings in ``str`` delimited by ``delimiter``. A maximum of ``max_splits`` splits are permitted. /// Returns a vector of string slices into ``str`` representing the substrings as string views. The user must ensure /// the returned views' lifetime does not exceed ``str``'s. -InlinedVector FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { +InlinedVector ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { InlinedVector output; if (str.empty()) { return output; @@ -61,7 +61,7 @@ InlinedVector FillSubstrings(std::string_view str, std::string StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); - info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); + info.GetAttrOrDefault("delimiter", &delimiter_, std::string()); } Status StringSplit::Compute(OpKernelContext* context) const { @@ -76,11 +76,10 @@ Status StringSplit::Compute(OpKernelContext* context) const { input_slices.reserve(input_data.size()); int64_t last_dim = 1; - auto input_slice_iterator = input_slices.begin(); - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) { - auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_); + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, num_tokens_iter++) { + auto substrs = ComputeSubstrings(*input_iter, delimiter_, maxsplit_); auto substr_count = static_cast(substrs.size()); - input_slices.push_back(substrs); + input_slices.push_back(std::move(substrs)); last_dim = std::max(last_dim, substr_count); *num_tokens_iter = substr_count; } From ff69ae8f4143f56574a4b8a5af930d1286b0fbe7 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 10 Jan 2024 00:53:29 +0000 Subject: [PATCH 11/15] Drop redundant last_dim min --- onnxruntime/core/providers/cpu/nn/string_split.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 55fc8eb6202f4..a1f8803117ccf 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -76,7 +76,7 @@ Status StringSplit::Compute(OpKernelContext* context) const { input_slices.reserve(input_data.size()); int64_t last_dim = 1; - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, num_tokens_iter++) { + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); ++input_iter, ++num_tokens_iter) { auto substrs = ComputeSubstrings(*input_iter, delimiter_, maxsplit_); auto substr_count = static_cast(substrs.size()); input_slices.push_back(std::move(substrs)); @@ -84,15 +84,13 @@ Status StringSplit::Compute(OpKernelContext* context) const { *num_tokens_iter = substr_count; } - last_dim = std::min(last_dim, maxsplit_ + 1); - // Set up splits output auto splits_shape = input->Shape().AsShapeVector(); splits_shape.push_back(last_dim); auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); auto slices_iter = input_slices.begin(); - for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, slices_iter++) { + for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, ++slices_iter) { std::copy(slices_iter->begin(), slices_iter->end(), output_splits_iter); } From 7fdd6dc60655e3c2a05c9e8eeadc4d2faf2eabe8 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 10 Jan 2024 01:02:00 +0000 Subject: [PATCH 12/15] Use range-based loop --- onnxruntime/core/providers/cpu/nn/string_split.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index a1f8803117ccf..377650949b39f 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -76,12 +76,13 @@ Status StringSplit::Compute(OpKernelContext* context) const { input_slices.reserve(input_data.size()); int64_t last_dim = 1; - for (auto input_iter = input_data.begin(); input_iter != input_data.end(); ++input_iter, ++num_tokens_iter) { - auto substrs = ComputeSubstrings(*input_iter, delimiter_, maxsplit_); + for (const auto& s : input_data) { + auto substrs = ComputeSubstrings(s, delimiter_, maxsplit_); auto substr_count = static_cast(substrs.size()); input_slices.push_back(std::move(substrs)); last_dim = std::max(last_dim, substr_count); *num_tokens_iter = substr_count; + ++num_tokens_iter; } // Set up splits output From 434546a3078b372c739181e9941025795a5be179 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 10 Jan 2024 11:30:52 +0000 Subject: [PATCH 13/15] Move to text directory --- .../core/providers/cpu/{nn => text}/string_normalizer.cc | 0 onnxruntime/core/providers/cpu/{nn => text}/string_normalizer.h | 0 onnxruntime/core/providers/cpu/{nn => text}/string_split.cc | 2 +- onnxruntime/core/providers/cpu/{nn => text}/string_split.h | 0 .../test/providers/cpu/{nn => text}/string_normalizer_test.cc | 0 .../test/providers/cpu/{nn => text}/string_split_test.cc | 0 6 files changed, 1 insertion(+), 1 deletion(-) rename onnxruntime/core/providers/cpu/{nn => text}/string_normalizer.cc (100%) rename onnxruntime/core/providers/cpu/{nn => text}/string_normalizer.h (100%) rename onnxruntime/core/providers/cpu/{nn => text}/string_split.cc (98%) rename onnxruntime/core/providers/cpu/{nn => text}/string_split.h (100%) rename onnxruntime/test/providers/cpu/{nn => text}/string_normalizer_test.cc (100%) rename onnxruntime/test/providers/cpu/{nn => text}/string_split_test.cc (100%) diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.cc b/onnxruntime/core/providers/cpu/text/string_normalizer.cc similarity index 100% rename from onnxruntime/core/providers/cpu/nn/string_normalizer.cc rename to onnxruntime/core/providers/cpu/text/string_normalizer.cc diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.h b/onnxruntime/core/providers/cpu/text/string_normalizer.h similarity index 100% rename from onnxruntime/core/providers/cpu/nn/string_normalizer.h rename to onnxruntime/core/providers/cpu/text/string_normalizer.h diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/text/string_split.cc similarity index 98% rename from onnxruntime/core/providers/cpu/nn/string_split.cc rename to onnxruntime/core/providers/cpu/text/string_split.cc index 377650949b39f..a6f2b3cd4a8da 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/text/string_split.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/providers/cpu/nn/string_split.h" +#include "string_split.h" #include #include #include diff --git a/onnxruntime/core/providers/cpu/nn/string_split.h b/onnxruntime/core/providers/cpu/text/string_split.h similarity index 100% rename from onnxruntime/core/providers/cpu/nn/string_split.h rename to onnxruntime/core/providers/cpu/text/string_split.h diff --git a/onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc b/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc similarity index 100% rename from onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc rename to onnxruntime/test/providers/cpu/text/string_normalizer_test.cc diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/text/string_split_test.cc similarity index 100% rename from onnxruntime/test/providers/cpu/nn/string_split_test.cc rename to onnxruntime/test/providers/cpu/text/string_split_test.cc From 0acf4728b831cf70b283319c8309feb45b1400ec Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 10 Jan 2024 23:37:33 +0000 Subject: [PATCH 14/15] Doc build + pass vector as out argument --- docs/OperatorKernels.md | 1 + .../core/providers/cpu/text/string_split.cc | 21 +++++++--------- .../providers/cpu/text/string_split_test.cc | 24 ++++++++++++++++--- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f985cf10ded60..3db2cff6f5447 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -383,6 +383,7 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)| +|StringSplit|*in* X:**T1**
*out* Y:**T2**
*out* Z:**T3**|20+|**T1** = tensor(string)
**T2** = tensor(string)
**T3** = tensor(int64)| |Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cpu/text/string_split.cc b/onnxruntime/core/providers/cpu/text/string_split.cc index a6f2b3cd4a8da..8aa63036a6f85 100644 --- a/onnxruntime/core/providers/cpu/text/string_split.cc +++ b/onnxruntime/core/providers/cpu/text/string_split.cc @@ -18,10 +18,9 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, /// Calculate substrings in ``str`` delimited by ``delimiter``. A maximum of ``max_splits`` splits are permitted. /// Returns a vector of string slices into ``str`` representing the substrings as string views. The user must ensure /// the returned views' lifetime does not exceed ``str``'s. -InlinedVector ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) { - InlinedVector output; +void ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits, InlinedVector& out) { if (str.empty()) { - return output; + return; } if (delimiter.empty()) { // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. @@ -34,28 +33,26 @@ InlinedVector ComputeSubstrings(std::string_view str, std::str while (str[next_pos] == ' ') { next_pos--; } - output.push_back(str.substr(pos, next_pos - pos + 1)); + out.push_back(str.substr(pos, next_pos - pos + 1)); break; } else { auto next_pos = str.find_first_of(" ", pos); - output.push_back(str.substr(pos, next_pos - pos)); + out.push_back(str.substr(pos, next_pos - pos)); pos = str.find_first_not_of(" ", next_pos); } } - return output; } else { size_t pos = 0; int64_t token_count = 0; while (pos != std::string::npos) { auto next_pos = str.find(delimiter, pos); if (token_count++ == max_splits || next_pos == std::string::npos) { - output.push_back(str.substr(pos)); + out.push_back(str.substr(pos)); break; } - output.push_back(str.substr(pos, next_pos - pos)); + out.push_back(str.substr(pos, next_pos - pos)); pos = next_pos + delimiter.size(); } - return output; } } @@ -74,12 +71,12 @@ Status StringSplit::Compute(OpKernelContext* context) const { InlinedVector> input_slices; input_slices.reserve(input_data.size()); - int64_t last_dim = 1; + int64_t last_dim = 0; for (const auto& s : input_data) { - auto substrs = ComputeSubstrings(s, delimiter_, maxsplit_); + auto& substrs = input_slices.emplace_back(); + ComputeSubstrings(s, delimiter_, maxsplit_, substrs); auto substr_count = static_cast(substrs.size()); - input_slices.push_back(std::move(substrs)); last_dim = std::max(last_dim, substr_count); *num_tokens_iter = substr_count; ++num_tokens_iter; diff --git a/onnxruntime/test/providers/cpu/text/string_split_test.cc b/onnxruntime/test/providers/cpu/text/string_split_test.cc index d2f37168babe3..d5e1c296d0b25 100644 --- a/onnxruntime/test/providers/cpu/text/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/text/string_split_test.cc @@ -80,8 +80,8 @@ TEST(StringSplit, EdgeWhitespaceTest) { TEST(StringSplit, EmptyInputTest) { OpTester test("StringSplit", 20); - test.AddAttribute("delimiter", "*"); test.AddInput("X", {1, 3, 1}, {"", "+", "*"}); + test.AddAttribute("delimiter", "*"); test.AddOutput("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""}); test.AddOutput("Z", {1, 3, 1}, {0, 1, 2}); test.Run(); @@ -91,7 +91,7 @@ TEST(StringSplit, OnlyEmptyInputTest) { OpTester test("StringSplit", 20); test.AddAttribute("delimiter", "*"); test.AddInput("X", {1, 2, 1}, {"", ""}); - test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 0}, {}); test.AddOutput("Z", {1, 2, 1}, {0, 0}); test.Run(); } @@ -99,10 +99,28 @@ TEST(StringSplit, OnlyEmptyInputTest) { TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 2, 1}, {"", ""}); - test.AddOutput("Y", {1, 2, 1, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 0}, {}); test.AddOutput("Z", {1, 2, 1}, {0, 0}); test.Run(); } +TEST(StringSplit, NoInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", { + 0, + }, + {}); + test.AddOutput("Y", { + 0, + 0, + }, + {}); + test.AddOutput("Z", { + 0, + }, + {}); + test.Run(); +} + } // namespace test } // namespace onnxruntime From f269f77148ac26eae9cbdc64fcc255a14c8585f8 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Thu, 11 Jan 2024 02:38:12 +0000 Subject: [PATCH 15/15] Make last_dim size_t --- onnxruntime/core/providers/cpu/text/string_split.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/cpu/text/string_split.cc b/onnxruntime/core/providers/cpu/text/string_split.cc index 8aa63036a6f85..2b82309838464 100644 --- a/onnxruntime/core/providers/cpu/text/string_split.cc +++ b/onnxruntime/core/providers/cpu/text/string_split.cc @@ -71,14 +71,14 @@ Status StringSplit::Compute(OpKernelContext* context) const { InlinedVector> input_slices; input_slices.reserve(input_data.size()); - int64_t last_dim = 0; + size_t last_dim = 0; for (const auto& s : input_data) { auto& substrs = input_slices.emplace_back(); ComputeSubstrings(s, delimiter_, maxsplit_, substrs); - auto substr_count = static_cast(substrs.size()); + auto substr_count = substrs.size(); last_dim = std::max(last_dim, substr_count); - *num_tokens_iter = substr_count; + *num_tokens_iter = static_cast(substr_count); ++num_tokens_iter; }