From 4694edcd4199419ec9315cb94305f906c1df49ba Mon Sep 17 00:00:00 2001 From: Aditya Goel <48102515+adityagoel4512@users.noreply.github.com> Date: Thu, 11 Jan 2024 18:01:43 +0000 Subject: [PATCH] String concat operator (#17994) ### Description ### Motivation and Context Closes https://github.com/microsoft/onnxruntime/issues/17595. --------- Signed-off-by: Aditya Goel --- docs/OperatorKernels.md | 1 + .../providers/cpu/cpu_execution_provider.cc | 2 + .../core/providers/cpu/text/string_concat.cc | 60 +++++++++++++++ .../core/providers/cpu/text/string_concat.h | 17 +++++ .../cpu/{nn => text}/string_normalizer.cc | 0 .../cpu/{nn => text}/string_normalizer.h | 0 .../providers/cpu/text/string_concat_test.cc | 76 +++++++++++++++++++ .../{nn => text}/string_normalizer_test.cc | 0 .../onnx_backend_test_series_filters.jsonc | 5 -- 9 files changed, 156 insertions(+), 5 deletions(-) create mode 100644 onnxruntime/core/providers/cpu/text/string_concat.cc create mode 100644 onnxruntime/core/providers/cpu/text/string_concat.h rename onnxruntime/core/providers/cpu/{nn => text}/string_normalizer.cc (100%) rename onnxruntime/core/providers/cpu/{nn => text}/string_normalizer.h (100%) create mode 100644 onnxruntime/test/providers/cpu/text/string_concat_test.cc rename onnxruntime/test/providers/cpu/{nn => text}/string_normalizer_test.cc (100%) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index f985cf10ded60..c856d12141c9c 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -382,6 +382,7 @@ Do not modify directly.* |Squeeze|*in* data:**T**
*in* axes:**tensor(int64)**
*out* squeezed:**T**

or

*in* data:**T**
*out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|StringConcat|*in* X:**T**
*in* Y:**T**
*out* Z:**T**|20+|**T** = tensor(string)| |StringNormalizer|*in* X:**tensor(string)**
*out* Y:**tensor(string)**|10+|**X** = tensor(string)| |Sub|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index f60c7ddac5c05..ba7738b405795 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); #endif class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -2447,6 +2448,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, #endif BuildKernelCreateInfo, + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/text/string_concat.cc b/onnxruntime/core/providers/cpu/text/string_concat.cc new file mode 100644 index 0000000000000..bc626f8e055aa --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/string_concat.cc @@ -0,0 +1,60 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "string_concat.h" +#include "core/providers/cpu/math/element_wise_ops.h" +#include "core/common/common.h" + +namespace onnxruntime { +ONNX_CPU_OPERATOR_KERNEL(StringConcat, 20, + KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + StringConcat); + +Status StringConcat::Compute(OpKernelContext* context) const { + ProcessBroadcastSpanFuncs broadcast_funcs{[](BroadcastHelper& broadcast_helper) { + auto x = broadcast_helper.ScalarInput0(); + auto y = broadcast_helper.SpanInput1(); + auto y_iter = y.begin(); + auto output_iter = broadcast_helper.OutputSpan().begin(); + const auto x_size = x.length(); + while (y_iter != y.end()) { + output_iter->reserve(x_size + y_iter->length()); + output_iter->append(x); + output_iter->append(*y_iter); + y_iter++; + output_iter++; + } + }, + [](BroadcastHelper& broadcast_helper) { + auto x = broadcast_helper.SpanInput0(); + auto x_iter = x.begin(); + auto y = broadcast_helper.ScalarInput1(); + auto output_iter = broadcast_helper.OutputSpan().begin(); + const auto y_size = y.length(); + while (x_iter != x.end()) { + output_iter->reserve(y_size + x_iter->length()); + output_iter->append(*x_iter); + output_iter->append(y); + x_iter++; + output_iter++; + } + }, + [](BroadcastHelper& broadcast_helper) { + auto x_iter = broadcast_helper.SpanInput0().begin(); + auto y_iter = broadcast_helper.SpanInput1().begin(); + auto output = broadcast_helper.OutputSpan(); + auto output_iter = output.begin(); + while (output_iter != output.end()) { + output_iter->reserve(x_iter->length() + y_iter->length()); + output_iter->append(*x_iter); + output_iter->append(*y_iter); + x_iter++; + y_iter++; + output_iter++; + } + }}; + UntypedBroadcastTwo(*context, broadcast_funcs); + return Status::OK(); +} + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/text/string_concat.h b/onnxruntime/core/providers/cpu/text/string_concat.h new file mode 100644 index 0000000000000..63c1ea8a41146 --- /dev/null +++ b/onnxruntime/core/providers/cpu/text/string_concat.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +class StringConcat final : public OpKernel { + public: + StringConcat(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* context) const override; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.cc b/onnxruntime/core/providers/cpu/text/string_normalizer.cc similarity index 100% rename from onnxruntime/core/providers/cpu/nn/string_normalizer.cc rename to onnxruntime/core/providers/cpu/text/string_normalizer.cc diff --git a/onnxruntime/core/providers/cpu/nn/string_normalizer.h b/onnxruntime/core/providers/cpu/text/string_normalizer.h similarity index 100% rename from onnxruntime/core/providers/cpu/nn/string_normalizer.h rename to onnxruntime/core/providers/cpu/text/string_normalizer.h diff --git a/onnxruntime/test/providers/cpu/text/string_concat_test.cc b/onnxruntime/test/providers/cpu/text/string_concat_test.cc new file mode 100644 index 0000000000000..2bfa3dc5615e1 --- /dev/null +++ b/onnxruntime/test/providers/cpu/text/string_concat_test.cc @@ -0,0 +1,76 @@ +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest(const std::vector& dims, const std::vector& input1, + const std::vector& input2, const std::vector& output) { + OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain); + test.AddInput("X", dims, input1); + test.AddInput("Y", dims, input2); + test.AddOutput("Z", dims, output); + test.Run(); +} + +TEST(StringConcat, BasicConcatenation) { + RunTest({1, 2}, {"Hello", "World"}, {"Hello", "World"}, {"HelloHello", "WorldWorld"}); +} + +TEST(StringConcat, TwoDimensionalConcatenation) { + RunTest({2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}, {"Hello", "World", "ONNX", "onnxruntime"}, + {"HelloHello", "WorldWorld", "ONNXONNX", "onnxruntimeonnxruntime"}); +} + +TEST(StringConcat, LeftToRightBroadcastingConcatenation) { + OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain); + test.AddInput("X", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}); + test.AddInput("Y", {1}, {"!"}); + test.AddOutput("Z", {2, 2}, {"Hello!", "World!", "ONNX!", "onnxruntime!"}); + test.Run(); +} + +TEST(StringConcat, RightToLeftBroadcastingConcatenation) { + OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain); + test.AddInput("X", {1}, {"!"}); + test.AddInput("Y", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}); + test.AddOutput("Z", {2, 2}, {"!Hello", "!World", "!ONNX", "!onnxruntime"}); + test.Run(); +} + +TEST(StringConcat, BidirectionalBroadcastingConcatenation) { + OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain); + test.AddInput("X", {2, 1, 3}, {"a", "b", "c", "d", "e", "f"}); + test.AddInput("Y", {1, 4, 3}, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m"}); + test.AddOutput("Z", {2, 4, 3}, + { + "aa", + "bb", + "cc", + "ad", + "be", + "cf", + "ag", + "bh", + "ci", + "ak", + "bl", + "cm", + "da", + "eb", + "fc", + "dd", + "ee", + "ff", + "dg", + "eh", + "fi", + "dk", + "el", + "fm", + }); + test.Run(); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc b/onnxruntime/test/providers/cpu/text/string_normalizer_test.cc similarity index 100% rename from onnxruntime/test/providers/cpu/nn/string_normalizer_test.cc rename to onnxruntime/test/providers/cpu/text/string_normalizer_test.cc diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 3a13e39702904..fb33ef0a1da3c 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -251,11 +251,6 @@ "^test_regex_full_match_basic", "^test_regex_full_match_email_domain", "^test_regex_full_match_empty", - "^test_string_concat_broadcasting", - "^test_string_concat", - "^test_string_concat_empty_string", - "^test_string_concat_utf8", - "^test_string_concat_zero_dimensional", "^test_string_split_basic", "^test_string_split_consecutive_delimiters", "^test_string_split_empty_string_delimiter",