From 9fd92aafb98d75fb8790c63d288a67424c4ba2e0 Mon Sep 17 00:00:00 2001 From: Aditya Goel Date: Wed, 18 Oct 2023 02:48:52 +0100 Subject: [PATCH] Normal case --- cgmanifests/generated/cgmanifest.json | 2 +- cmake/deps.txt | 6 +- cmake/external/onnx | 2 +- .../providers/cpu/cpu_execution_provider.cc | 2 + .../core/providers/cpu/nn/string_split.cc | 88 +++++++++++++++++++ .../core/providers/cpu/nn/string_split.h | 18 ++++ .../providers/cpu/nn/string_split_test.cc | 17 ++++ 7 files changed, 130 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/nn/string_split.cc create mode 100644 onnxruntime/core/providers/cpu/nn/string_split.h create mode 100644 onnxruntime/test/providers/cpu/nn/string_split_test.cc diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 08ca90d7c3b7f..38c104c365b12 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -166,7 +166,7 @@ "component": { "type": "git", "git": { - "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", + "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "onnx" diff --git a/cmake/deps.txt b/cmake/deps.txt index 7cf49f02333a4..a5daaa01521d7 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -16,15 +16,15 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip; fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494 fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1 google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908 -google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752 -googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc +google_nsync;https://github.com/google/nsync/archive/refs/tags/1.23.0.zip;f3233450cf7156fc0bedd1b0e884eddec264897c +googletest;https://github.com/google/googletest/archive/519beb0e52c842729b4b53731d27c0e0c32ab4a2.zip;4b3c37972e4c1bef1185d46f702082f8772ee73f googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 +onnx;https://github.com/onnx/onnx/archive/0c296085f9f65f0f8ef7aec7b9eed55faf37dc40.zip;01ca9e955a03a9183e3d278e96f975f1a762cef1 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index e2525550194ce..0c296085f9f65 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e +Subproject commit 0c296085f9f65f0f8ef7aec7b9eed55faf37dc40 diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3d03abf5b7ebc..da7e2951cabb3 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2389,6 +2390,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/nn/string_split.cc b/onnxruntime/core/providers/cpu/nn/string_split.cc new file mode 100644 index 0000000000000..3ea8d31ae0536 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/string_split.cc @@ -0,0 +1,88 @@ +#include "string_split.h" +#include "core/common/common.h" +#include +namespace onnxruntime { + +ONNX_CPU_OPERATOR_KERNEL( + StringSplit, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) + .TypeConstraint("T3", DataTypeImpl::GetTensorType()), + StringSplit); + +int64_t countSubstrings(const std::string& str, const std::string& substr) { + if (str.empty()) { + return 0; + } + int64_t count = 1; + size_t pos = 0; + while ((pos = str.find(substr, pos)) != std::string::npos) { + ++count; + pos += substr.length(); + } + + return count; +} + +size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span output, int64_t output_index, int64_t max_tokens) { + // Fill output with substrings of str, delimited by delimiter and place into output starting at output_index and incrementing up. + // Up to max_tokens substrings should be reached. If we are done before max_tokens, fill the rest with "". If we would not be done after max_tokens, make sure the max_tokensth substring is the remainder of the string. + auto pos = 0; + size_t token_index = 0; + while (token_index < max_tokens) { + auto next_pos = str.find(delimiter, pos); + output[output_index + token_index] = str.substr(pos, next_pos - pos); + pos = next_pos + delimiter.size(); + ++token_index; + if (next_pos == std::string::npos) { + break; + } + } + return token_index; +} + +StringSplit::StringSplit(const OpKernelInfo& info): OpKernel(info) { + info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast(-1)); // TODO is this the right thing to do here? + info.GetAttrOrDefault("delimiter", &delimiter_, std::string("")); +} + +Status StringSplit::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + if (nullptr == input) { + return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); + } + auto* num_substrings = context->Output(1, input->Shape()); + if (nullptr == num_substrings) { + return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch"); + } + if ("" == delimiter_) { + // TODO: takes consecutive whitespace as delimiter + } else { + + auto input_data = input->template DataAsSpan(); + auto last_dim = maxsplit_; + for (auto i = 0; i < input->Shape().Size(); ++i) { + auto substring_count = countSubstrings(input_data[i], delimiter_); + last_dim = std::max(last_dim, substring_count); + } + // 1. instantiate output shape to be shape input + (last_dim,) + // 2. maintain output tensor index/pointer. Iterate over input tensor; split tensor into last_dim substrings (with "" at end for extra); copy into output tensor and output pointer/index. advance output pointer/index by last_dim. + + // Set up num_substrings output + auto num_substrings_data = num_substrings->template MutableDataAsSpan(); + // Set up splits output + auto splits_shape = input->Shape().AsShapeVector(); + splits_shape.push_back(last_dim); + auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan(); + auto splits_index = 0; + for (auto i = 0; i < input->Shape().Size(); ++i) { + num_substrings_data[i] = static_cast(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim)); + splits_index += last_dim; + } + } + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/string_split.h b/onnxruntime/core/providers/cpu/nn/string_split.h new file mode 100644 index 0000000000000..96c61b74c5a36 --- /dev/null +++ b/onnxruntime/core/providers/cpu/nn/string_split.h @@ -0,0 +1,18 @@ +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class StringSplit final : public OpKernel { + public: + explicit StringSplit(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + private: + std::string delimiter_; + int64_t maxsplit_; +}; + + +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/string_split_test.cc b/onnxruntime/test/providers/cpu/nn/string_split_test.cc new file mode 100644 index 0000000000000..e71c284e290c1 --- /dev/null +++ b/onnxruntime/test/providers/cpu/nn/string_split_test.cc @@ -0,0 +1,17 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +TEST(StringSplit, BasicSplitTest) { + OpTester test("StringSplit", 20); + test.AddInput("X", {3}, {"hello world", "hello", "world"}); + test.AddAttribute("delimiter", " "); + test.AddOutput("Y", {3, 2}, {"hello", "world", "hello", "", "world", ""}); + test.AddOutput("Z", {3}, {2, 1, 1}); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime