Skip to content

Commit

Permalink
String concat operator (#17994)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->



### Motivation and Context
Closes #17595.

---------

Signed-off-by: Aditya Goel <[email protected]>
  • Loading branch information
adityagoel4512 authored Jan 11, 2024
1 parent f68dfcd commit 4694edc
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 5 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ Do not modify directly.*
|Squeeze|*in* data:**T**<br> *in* axes:**tensor(int64)**<br> *out* squeezed:**T**<br><br>or<br><br>*in* data:**T**<br> *out* squeezed:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[1, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|StringConcat|*in* X:**T**<br> *in* Y:**T**<br> *out* Z:**T**|20+|**T** = tensor(string)|
|StringNormalizer|*in* X:**tensor(string)**<br> *out* Y:**tensor(string)**|10+|**X** = tensor(string)|
|Sub|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN);
#endif
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -2447,6 +2448,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN)>,
#endif
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringConcat)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
60 changes: 60 additions & 0 deletions onnxruntime/core/providers/cpu/text/string_concat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_concat.h"
#include "core/providers/cpu/math/element_wise_ops.h"
#include "core/common/common.h"

namespace onnxruntime {
ONNX_CPU_OPERATOR_KERNEL(StringConcat, 20,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<std::string>()),
StringConcat);

Status StringConcat::Compute(OpKernelContext* context) const {
ProcessBroadcastSpanFuncs broadcast_funcs{[](BroadcastHelper& broadcast_helper) {
auto x = broadcast_helper.ScalarInput0<std::string>();
auto y = broadcast_helper.SpanInput1<std::string>();
auto y_iter = y.begin();
auto output_iter = broadcast_helper.OutputSpan<std::string>().begin();
const auto x_size = x.length();
while (y_iter != y.end()) {
output_iter->reserve(x_size + y_iter->length());
output_iter->append(x);
output_iter->append(*y_iter);
y_iter++;
output_iter++;
}
},
[](BroadcastHelper& broadcast_helper) {
auto x = broadcast_helper.SpanInput0<std::string>();
auto x_iter = x.begin();
auto y = broadcast_helper.ScalarInput1<std::string>();
auto output_iter = broadcast_helper.OutputSpan<std::string>().begin();
const auto y_size = y.length();
while (x_iter != x.end()) {
output_iter->reserve(y_size + x_iter->length());
output_iter->append(*x_iter);
output_iter->append(y);
x_iter++;
output_iter++;
}
},
[](BroadcastHelper& broadcast_helper) {
auto x_iter = broadcast_helper.SpanInput0<std::string>().begin();
auto y_iter = broadcast_helper.SpanInput1<std::string>().begin();
auto output = broadcast_helper.OutputSpan<std::string>();
auto output_iter = output.begin();
while (output_iter != output.end()) {
output_iter->reserve(x_iter->length() + y_iter->length());
output_iter->append(*x_iter);
output_iter->append(*y_iter);
x_iter++;
y_iter++;
output_iter++;
}
}};
UntypedBroadcastTwo(*context, broadcast_funcs);
return Status::OK();
}

} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/core/providers/cpu/text/string_concat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"

namespace onnxruntime {

class StringConcat final : public OpKernel {
public:
StringConcat(const OpKernelInfo& info) : OpKernel(info) {}

Status Compute(OpKernelContext* context) const override;
};

} // namespace onnxruntime
76 changes: 76 additions & 0 deletions onnxruntime/test/providers/cpu/text/string_concat_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

static void RunTest(const std::vector<int64_t>& dims, const std::vector<std::string>& input1,
const std::vector<std::string>& input2, const std::vector<std::string>& output) {
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", dims, input1);
test.AddInput<std::string>("Y", dims, input2);
test.AddOutput<std::string>("Z", dims, output);
test.Run();
}

TEST(StringConcat, BasicConcatenation) {
RunTest({1, 2}, {"Hello", "World"}, {"Hello", "World"}, {"HelloHello", "WorldWorld"});
}

TEST(StringConcat, TwoDimensionalConcatenation) {
RunTest({2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}, {"Hello", "World", "ONNX", "onnxruntime"},
{"HelloHello", "WorldWorld", "ONNXONNX", "onnxruntimeonnxruntime"});
}

TEST(StringConcat, LeftToRightBroadcastingConcatenation) {
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"});
test.AddInput<std::string>("Y", {1}, {"!"});
test.AddOutput<std::string>("Z", {2, 2}, {"Hello!", "World!", "ONNX!", "onnxruntime!"});
test.Run();
}

TEST(StringConcat, RightToLeftBroadcastingConcatenation) {
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", {1}, {"!"});
test.AddInput<std::string>("Y", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"});
test.AddOutput<std::string>("Z", {2, 2}, {"!Hello", "!World", "!ONNX", "!onnxruntime"});
test.Run();
}

TEST(StringConcat, BidirectionalBroadcastingConcatenation) {
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", {2, 1, 3}, {"a", "b", "c", "d", "e", "f"});
test.AddInput<std::string>("Y", {1, 4, 3}, {"a", "b", "c", "d", "e", "f", "g", "h", "i", "k", "l", "m"});
test.AddOutput<std::string>("Z", {2, 4, 3},
{
"aa",
"bb",
"cc",
"ad",
"be",
"cf",
"ag",
"bh",
"ci",
"ak",
"bl",
"cm",
"da",
"eb",
"fc",
"dd",
"ee",
"ff",
"dg",
"eh",
"fi",
"dk",
"el",
"fm",
});
test.Run();
}

} // namespace test
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,6 @@
"^test_regex_full_match_basic",
"^test_regex_full_match_email_domain",
"^test_regex_full_match_empty",
"^test_string_concat_broadcasting",
"^test_string_concat",
"^test_string_concat_empty_string",
"^test_string_concat_utf8",
"^test_string_concat_zero_dimensional",
"^test_string_split_basic",
"^test_string_split_consecutive_delimiters",
"^test_string_split_empty_string_delimiter",
Expand Down

0 comments on commit 4694edc

Please sign in to comment.