From 7b6a615f56f4d648bc92089cae34eb57aafe3897 Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Fri, 12 Jan 2024 17:46:23 +0000 Subject: [PATCH] StringSplit operator (#18016) ### Description ### Motivation and Context Closes https://github.com/microsoft/onnxruntime/issues/17596 --- docs/OperatorKernels.md | 1 + .../cpu/quantization/qlinear_concat.cc | 2 +- .../providers/cpu/cpu_execution_provider.cc | 2 + .../core/providers/cpu/text/string_split.cc | 98 ++++++++++++++ .../core/providers/cpu/text/string_split.h | 20 +++ .../providers/cpu/text/string_split_test.cc | 126 ++++++++++++++++++ .../onnx_backend_test_series_filters.jsonc | 6 - 7 files changed, 248 insertions(+), 7 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/text/string_split.cc create mode 100644 onnxruntime/core/providers/cpu/text/string_split.h create mode 100644 onnxruntime/test/providers/cpu/text/string_split_test.cc diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f0b79eb9e429f..a2bb39da76235 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -385,6 +385,7 @@ Do not modify directly.* |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |StringConcat|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|20+|**T** = tensor(string)| |StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)| +|StringSplit|*in* X:**T1**
*out* Y:**T2**
*out* Z:**T3**|20+|**T1** = tensor(string)
**T2** = tensor(string)
**T3** = tensor(int64)| |Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| diff --git a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc index ee9ae7167945c..af163b6be702b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc @@ -1,4 +1,4 @@ -// Copyright (c Microsoft Corporation. All rights reserved. +// Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. #include "qlinear_util.h" diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 9cd0b3d0620af..6aef03a32db09 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -991,6 +991,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, RegexFullMatch); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2451,6 +2452,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/text/string_split.cc b/onnxruntime/core/providers/cpu/text/string_split.cc new file mode 100644 index 0000000000000..2b82309838464 --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/string_split.cc @@ -0,0 +1,98 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "string_split.h" +#include +#include +#include +#include "core/common/common.h" +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + StringSplit); + +/// Calculate substrings in ``str`` delimited by ``delimiter``. A maximum of ``max_splits`` splits are permitted. +/// Returns a vector of string slices into ``str`` representing the substrings as string views. The user must ensure +/// the returned views' lifetime does not exceed ``str``'s. +void ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits, InlinedVector& out) { + if (str.empty()) { + return; + } + if (delimiter.empty()) { + // Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored. + size_t pos = str.find_first_not_of(" "); + int64_t token_count = 0; + while (pos != std::string::npos) { + if (token_count++ == max_splits) { + // Trim down last substring as required in specification + size_t next_pos = str.length() - 1; + while (str[next_pos] == ' ') { + next_pos--; + } + out.push_back(str.substr(pos, next_pos - pos + 1)); + break; + } else { + auto next_pos = str.find_first_of(" ", pos); + out.push_back(str.substr(pos, next_pos - pos)); + pos = str.find_first_not_of(" ", next_pos); + } + } + } else { + size_t pos = 0; + int64_t token_count = 0; + while (pos != std::string::npos) { + auto next_pos = str.find(delimiter, pos); + if (token_count++ == max_splits || next_pos == std::string::npos) { + out.push_back(str.substr(pos)); + break; + } + out.push_back(str.substr(pos, next_pos - pos)); + pos = next_pos + delimiter.size(); + } + } +} + +StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) { + info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits::max() - 1); + info.GetAttrOrDefault("delimiter", &delimiter_, std::string()); +} + +Status StringSplit::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + auto input_data = input->template DataAsSpan(); + + // Set up number of tokens output + auto num_tokens_data = context->Output(1, input->Shape())->template MutableDataAsSpan(); + auto num_tokens_iter = num_tokens_data.begin(); + + InlinedVector> input_slices; + input_slices.reserve(input_data.size()); + size_t last_dim = 0; + + for (const auto& s : input_data) { + auto& substrs = input_slices.emplace_back(); + ComputeSubstrings(s, delimiter_, maxsplit_, substrs); + auto substr_count = substrs.size(); + last_dim = std::max(last_dim, substr_count); + *num_tokens_iter = static_cast(substr_count); + ++num_tokens_iter; + } + + // Set up splits output + auto splits_shape = input->Shape().AsShapeVector(); + splits_shape.push_back(last_dim); + + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto slices_iter = input_slices.begin(); + for (auto output_splits_iter = splits_data.begin(); output_splits_iter != splits_data.end(); output_splits_iter += last_dim, ++slices_iter) { + std::copy(slices_iter->begin(), slices_iter->end(), output_splits_iter); + } + + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/text/string_split.h b/onnxruntime/core/providers/cpu/text/string_split.h new file mode 100644 index 0000000000000..6be249261d4e3 --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/string_split.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class StringSplit final : public OpKernel { + public: + explicit StringSplit(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + std::string delimiter_; + int64_t maxsplit_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/text/string_split_test.cc b/onnxruntime/test/providers/cpu/text/string_split_test.cc new file mode 100644 index 0000000000000..d5e1c296d0b25 --- /dev/null +++ b/onnxruntime/test/providers/cpu/text/string_split_test.cc @@ -0,0 +1,126 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(StringSplit, BasicSplitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {3}, {"hello world", "hello", "world"}); + test.AddAttribute("delimiter", " "); + test.AddOutput("Y", {3, 2}, {"hello", "world", "hello", "", "world", ""}); + test.AddOutput("Z", {3}, {2, 1, 1}); + test.Run(); +} + +TEST(StringSplit, MaxSplitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {2, 2}, {"eggs;milk;chesse", "pepper;salt", "chicken;fish;pork", "spinach"}); + test.AddAttribute("delimiter", ";"); + test.AddAttribute("maxsplit", 1); + test.AddOutput("Y", {2, 2, 2}, + {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""}); + test.AddOutput("Z", {2, 2}, {2, 2, 2, 1}); + test.Run(); +} + +TEST(StringSplit, EmptyStringDelimiterTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddAttribute("delimiter", ""); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); + test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); +} + +TEST(StringSplit, SubsequentWhitespaceDefaultTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "}); + test.AddOutput("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"}); + test.AddOutput("Z", {1, 4}, {2, 2, 2, 2}); + test.Run(); +} + +TEST(StringSplit, SubsequentWhitespaceWithLimitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 4}, + {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "}); + test.AddAttribute("maxsplit", 1); + test.AddOutput( + "Y", {1, 4, 2}, + {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"}); + test.AddOutput("Z", {1, 4}, {2, 2, 1, 2}); + test.Run(); +} + +TEST(StringSplit, SingleTokenTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, SingleTokenWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {"lorem"}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EdgeWhitespaceTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 1, 1}, {" lorem "}); + test.AddOutput("Y", {1, 1, 1, 1}, {"lorem"}); + test.AddOutput("Z", {1, 1, 1}, {1}); + test.Run(); +} + +TEST(StringSplit, EmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 3, 1}, {"", "+", "*"}); + test.AddAttribute("delimiter", "*"); + test.AddOutput("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""}); + test.AddOutput("Z", {1, 3, 1}, {0, 1, 2}); + test.Run(); +} + +TEST(StringSplit, OnlyEmptyInputTest) { + OpTester test("StringSplit", 20); + test.AddAttribute("delimiter", "*"); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 0}, {}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + +TEST(StringSplit, OnlyEmptyNoDelimiterInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {1, 2, 1}, {"", ""}); + test.AddOutput("Y", {1, 2, 1, 0}, {}); + test.AddOutput("Z", {1, 2, 1}, {0, 0}); + test.Run(); +} + +TEST(StringSplit, NoInputTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", { + 0, + }, + {}); + test.AddOutput("Y", { + 0, + 0, + }, + {}); + test.AddOutput("Z", { + 0, + }, + {}); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index c2ca5f860a107..ed263515d6dd6 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -248,12 +248,6 @@ "^test_image_decoder_decode_pnm_rgb", "^test_image_decoder_decode_tiff_rgb", "^test_image_decoder_decode_webp_rgb", - "^test_string_split_basic", - "^test_string_split_consecutive_delimiters", - "^test_string_split_empty_string_delimiter", - "^test_string_split_empty_tensor", - "^test_string_split_maxsplit", - "^test_string_split_no_delimiter", "^test_reduce_l1_empty_set_cuda", "^test_reduce_l1_empty_set_expanded_cuda", "^test_reduce_l2_empty_set_cuda",