Skip to content

Commit

Permalink
Implement broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 17, 2023
1 parent b786943 commit e729b71
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 27 deletions.
45 changes: 31 additions & 14 deletions onnxruntime/core/providers/cpu/nn/string_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,36 @@ ONNX_CPU_OPERATOR_KERNEL(
StringConcat);

Status StringConcat::Compute(OpKernelContext* context) const {
const auto* X = context->Input<Tensor>(0);
const auto* X_data = X->template Data<std::string>();
const auto* Y = context->Input<Tensor>(1);
const auto* Y_data = Y->template Data<std::string>();
auto* Z = context->Output(0, X->Shape());
auto* Z_data = Z->template MutableData<std::string>();
const auto N = X->Shape().Size();

for (int64_t i = 0; i < N; ++i) {
Z_data[i] = X_data[i] + Y_data[i];
}

return Status::OK();
ProcessBroadcastSpanFuncs broadcast_funcs{
[](BroadcastHelper& broadcast_helper) {
auto x = broadcast_helper.ScalarInput0<std::string>();
auto Y = broadcast_helper.SpanInput1<std::string>();
auto output = broadcast_helper.OutputSpan<std::string>();
std::transform(Y.begin(), Y.end(), output.begin(),
[&x](const std::string& y) {
return x + y;
});
},
[](BroadcastHelper& broadcast_helper) {
auto X = broadcast_helper.SpanInput0<std::string>();
auto y = broadcast_helper.ScalarInput1<std::string>();
auto output = broadcast_helper.OutputSpan<std::string>();
std::transform(X.begin(), X.end(), output.begin(),
[&y](const std::string& x) {
return x + y;
});
},
[](BroadcastHelper& broadcast_helper) {
auto X = broadcast_helper.SpanInput0<std::string>();
auto Y = broadcast_helper.SpanInput1<std::string>();
auto output = broadcast_helper.OutputSpan<std::string>();
std::transform(X.begin(), X.end(), Y.begin(), output.begin(),
[](const std::string& x, const std::string& y) {
return x + y;
});
}};
UntypedBroadcastTwo(*context, broadcast_funcs);
return Status::OK();
}

} // namespace onnxruntime
} // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cpu/nn/string_concat.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ namespace onnxruntime {

class StringConcat final : public OpKernel {
public:
StringConcat(const OpKernelInfo& info): OpKernel(info) {}
StringConcat(const OpKernelInfo& info) : OpKernel(info) {}

Status Compute(OpKernelContext* context) const override;
};

} // namespace onnxruntime
} // namespace onnxruntime
32 changes: 21 additions & 11 deletions onnxruntime/test/providers/cpu/nn/stringconcat_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,33 @@ namespace onnxruntime {
namespace test {

static void RunTest(const std::vector<int64_t>& dims, const std::vector<std::string>& input1, const std::vector<std::string>& input2, const std::vector<std::string>& output) {
std::cout << "Running test\n";
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", dims, input1);
std::cout << "add input1" << "\n";
test.AddInput<std::string>("Y", dims, input2);
test.AddOutput<std::string>("Z", dims, output);
test.Run();
std::cout << "Running test\n";
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", dims, input1);
std::cout << "add input1"
<< "\n";
test.AddInput<std::string>("Y", dims, input2);
test.AddOutput<std::string>("Z", dims, output);
test.Run();
}

TEST(StringConcat, BasicConcatenation) {
RunTest({1, 2}, {"Hello", "World"}, {"Hello", "World"}, {"HelloHello", "WorldWorld"});
RunTest({1, 2}, {"Hello", "World"}, {"Hello", "World"}, {"HelloHello", "WorldWorld"});
}

TEST(StringConcat, TwoDimensionalConcatenation) {
RunTest({2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}, {"Hello", "World", "ONNX", "onnxruntime"}, {"HelloHello", "WorldWorld", "ONNXONNX", "onnxruntimeonnxruntime"});
RunTest({2, 2}, {"Hello", "World", "ONNX", "onnxruntime"}, {"Hello", "World", "ONNX", "onnxruntime"}, {"HelloHello", "WorldWorld", "ONNXONNX", "onnxruntimeonnxruntime"});
}

TEST(StringConcat, BroadcastingConcatenation) {
OpTester test("StringConcat", 20, onnxruntime::kOnnxDomain);
test.AddInput<std::string>("X", {2, 2}, {"Hello", "World", "ONNX", "onnxruntime"});
std::cout << "add broadcasting input"
<< "\n";
test.AddInput<std::string>("Y", {1}, {"!"});
test.AddOutput<std::string>("Z", {2, 2}, {"Hello!", "World!", "ONNX!", "onnxruntime!"});
test.Run();
}

} // namespace test
} // namespace onnxruntime
} // namespace test
} // namespace onnxruntime

0 comments on commit e729b71

Please sign in to comment.