From 6ab138a297780fcaddbf9f515cf77f366017dcac Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Tue, 9 Jan 2024 03:55:02 +0000 Subject: [PATCH] Implement feedback --- .../cpu/quantization/qlinear_concat.cc | 2 +- .../core/providers/cpu/nn/string_split.cc | 91 ++++++++++--------- .../core/providers/cpu/nn/string_split.h | 3 + .../providers/cpu/nn/string_split_test.cc | 40 +++++++- 4 files changed, 91 insertions(+), 45 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc index ee9ae7167945c..af163b6be702b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc @@ -1,4 +1,4 @@ -// Copyright (c Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "qlinear_util.h" diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc index 592575327cae8..eaf20075ceea4 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.cc +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "string_split.h" #include "core/common/common.h" #include @@ -12,13 +15,11 @@ ONNX_CPU_OPERATOR_KERNEL( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), StringSplit); -int64_t countSubstrings(std::string_view str, std::string_view substr) { - if (str.empty()) { - return 0; - } +/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the number of whitespace delimited words. +int64_t CountSubstrings(std::string_view str, std::string_view substr) { if (substr.empty()) { // Count consecutive whitespace as one delimiter - int64_t count = 1; + int64_t count = 0; size_t pos = str.find_first_not_of(" "); while (pos != std::string::npos) { ++count; @@ -26,85 +27,93 @@ int64_t countSubstrings(std::string_view str, std::string_view substr) { } return count; } else { - int64_t count = 1; - size_t pos = str.find(substr); + int64_t count = 0; + size_t pos = 0; while (pos != std::string::npos) { ++count; - pos = str.find(substr, pos + substr.length()); + pos = str.find(substr, pos); + if (pos != std::string::npos) { + pos += substr.length(); + } } return count; } } -int64_t fillSubstrings(std::string_view str, std::string_view delimiter, gsl::span output, int64_t output_index, size_t max_tokens) { +/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of generated substrings to ``max_tokens``. +/// The function returns the number of substrings generated (this is less or equal to ``max_tokens``). +int64_t FillSubstrings(std::string_view str, std::string_view delimiter, gsl::details::span_iterator output, size_t max_tokens) { if (str.empty()) { return 0; } if (delimiter.empty()) { // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. size_t pos = str.find_first_not_of(" "); - size_t token_index = 0; - while (token_index < max_tokens && pos != std::string::npos) { - auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos); - output[output_index + token_index] = str.substr(pos, next_pos - pos); - ++token_index; - if (next_pos == std::string::npos) { + int64_t token_count = 0; + while (pos != std::string::npos) { + if (++token_count == max_tokens) { + // trim down last substring as required in specification + size_t next_pos = str.length() - 1; + while (str[next_pos] == ' ') { + next_pos--; + } + *output = str.substr(pos, next_pos - pos + 1); break; + } else { + auto next_pos = str.find_first_of(" ", pos); + *output = str.substr(pos, next_pos - pos); + pos = str.find_first_not_of(" ", next_pos); } - pos = str.find_first_not_of(" ", next_pos); + + output++; } - return token_index; + return token_count; } else { size_t pos = 0; - size_t token_index = 0; - while (token_index < max_tokens && pos != std::string::npos) { - auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); - output[output_index + token_index] = str.substr(pos, next_pos - pos); - ++token_index; + int64_t token_count = 0; + while (pos != std::string::npos) { + auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos); + *output++ = str.substr(pos, next_pos - pos); + token_count++; if (next_pos == std::string::npos) { break; } pos = next_pos + delimiter.size(); } - return token_index; + return token_count; } } StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { - info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); // TODO is this the right thing to do here? + info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); } Status StringSplit::Compute(OpKernelContext* context) const { const Tensor* input = context->Input(0); - if (nullptr == input) { - return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); - } - auto* num_substrings = context->Output(1, input->Shape()); - if (nullptr == num_substrings) { - return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); - } auto input_data = input->template DataAsSpan(); + int64_t last_dim = 0; - for (auto i = 0; i < input->Shape().Size(); ++i) { - auto substring_count = countSubstrings(input_data[i], delimiter_); - last_dim = std::max(last_dim, substring_count); + for (const auto& str : input_data) { + last_dim = std::max(last_dim, CountSubstrings(str, delimiter_)); } - last_dim = std::min(last_dim, maxsplit_ + 1); - // Set up num_substrings output - auto num_substrings_data = num_substrings->template MutableDataAsSpan(); // Set up splits output auto splits_shape = input->Shape().AsShapeVector(); if (last_dim > 0) { splits_shape.push_back(last_dim); } auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); - auto splits_index = 0; - for (auto i = 0; i < input->Shape().Size(); ++i) { - num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim); - splits_index += last_dim; + auto output_splits_iter = splits_data.begin(); + + // Set up number of tokens output + auto* num_substrings = context->Output(1, input->Shape()); + auto num_substrings_data = num_substrings->template MutableDataAsSpan(); + auto output_num_tokens_iter = num_substrings_data.begin(); + + for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) { + *output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim); } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/nn/string_split.h b/onnxruntime/core/providers/cpu/nn/string_split.h index f9ded9cafc760..6be249261d4e3 100644 --- a/onnxruntime/core/providers/cpu/nn/string_split.h +++ b/onnxruntime/core/providers/cpu/nn/string_split.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/framework/op_kernel.h" diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc index e4aa5a8161339..a45fe736c3df4 100644 --- a/onnxruntime/test/providers/cpu/nn/string_split_test.cc +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -27,7 +27,7 @@ TEST(StringSplit, EmptyStringDelimiterTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); test.AddAttribute("delimiter", ""); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); test.Run(); } @@ -35,7 +35,7 @@ TEST(StringSplit, EmptyStringDelimiterTest) { TEST(StringSplit, SubsequentWhitespaceDefaultTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); - test.AddOutput("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""}); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); test.Run(); } @@ -44,10 +44,44 @@ TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { OpTester test("StringSplit", 20); test.AddInput("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); test.AddAttribute("maxsplit", 1); - test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "}); + test.AddOutput("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"}); test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); test.Run(); } +TEST(StringSplit, SingleTokenTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, SingleTokenWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EdgeWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {" lorem "}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 3, 1}, {"", "+", "*"}); + test.AddOutput("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""}); + test.AddOutput("Z", {1, 3, 1}, {0, 1, 2}); + test.Run(); +} + } // namespace test } // namespace onnxruntime