Skip to content

Commit

Permalink
Normal case
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 18, 2023
1 parent 2efab54 commit 9fd92aa
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cgmanifests/generated/cgmanifest.json
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f",
"commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40",
"repositoryUrl": "https://github.com/onnx/onnx.git"
},
"comments": "onnx"
Expand Down
6 changes: 3 additions & 3 deletions cmake/deps.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ flatbuffers;https://github.com/google/flatbuffers/archive/refs/tags/v1.12.0.zip;
fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b64b145d91.zip;b985f6985a05a1c03ff1bb71190f66d8f98a1494
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.23.0.zip;f3233450cf7156fc0bedd1b0e884eddec264897c
googletest;https://github.com/google/googletest/archive/519beb0e52c842729b4b53731d27c0e0c32ab4a2.zip;4b3c37972e4c1bef1185d46f702082f8772ee73f
googlexnnpack;https://github.com/google/XNNPACK/archive/003c580e696a774afdc984996ee909b7c8d8128c.zip;9f192e3f15e1e37ae9c78d53eeea47e45c5eb31c
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442
onnx;https://github.com/onnx/onnx/archive/0c296085f9f65f0f8ef7aec7b9eed55faf37dc40.zip;01ca9e955a03a9183e3d278e96f975f1a762cef1
#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/onnx
Submodule onnx updated 956 files
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh

// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit);

// !!PLEASE READ BELOW!! Following that, add new entries above this comment

Expand Down Expand Up @@ -2389,6 +2390,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 20
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, StringSplit)>,
};

for (auto& function_table_entry : function_table) {
Expand Down
88 changes: 88 additions & 0 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "string_split.h"
#include "core/common/common.h"
#include <algorithm>
namespace onnxruntime {

ONNX_CPU_OPERATOR_KERNEL(
StringSplit,
20,
KernelDefBuilder()
.TypeConstraint("T1", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<std::string>())
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int64_t>()),
StringSplit);

int64_t countSubstrings(const std::string& str, const std::string& substr) {
if (str.empty()) {
return 0;
}
int64_t count = 1;
size_t pos = 0;
while ((pos = str.find(substr, pos)) != std::string::npos) {
++count;
pos += substr.length();
}

return count;
}

size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span<std::string> output, int64_t output_index, int64_t max_tokens) {
// Fill output with substrings of str, delimited by delimiter and place into output starting at output_index and incrementing up.
// Up to max_tokens substrings should be reached. If we are done before max_tokens, fill the rest with "". If we would not be done after max_tokens, make sure the max_tokensth substring is the remainder of the string.
auto pos = 0;
size_t token_index = 0;
while (token_index < max_tokens) {
auto next_pos = str.find(delimiter, pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
pos = next_pos + delimiter.size();
++token_index;
if (next_pos == std::string::npos) {
break;
}
}
return token_index;
}

StringSplit::StringSplit(const OpKernelInfo& info): OpKernel(info) {
info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast<int64_t>(-1)); // TODO is this the right thing to do here?
info.GetAttrOrDefault("delimiter", &delimiter_, std::string(""));
}

Status StringSplit::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
if (nullptr == input) {
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
}
auto* num_substrings = context->Output(1, input->Shape());
if (nullptr == num_substrings) {
return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch");
}
if ("" == delimiter_) {
// TODO: takes consecutive whitespace as delimiter
} else {

auto input_data = input->template DataAsSpan<std::string>();
auto last_dim = maxsplit_;
for (auto i = 0; i < input->Shape().Size(); ++i) {
auto substring_count = countSubstrings(input_data[i], delimiter_);
last_dim = std::max(last_dim, substring_count);
}
// 1. instantiate output shape to be shape input + (last_dim,)
// 2. maintain output tensor index/pointer. Iterate over input tensor; split tensor into last_dim substrings (with "" at end for extra); copy into output tensor and output pointer/index. advance output pointer/index by last_dim.

// Set up num_substrings output
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
// Set up splits output
auto splits_shape = input->Shape().AsShapeVector();
splits_shape.push_back(last_dim);
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto splits_index = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
num_substrings_data[i] = static_cast<int64_t>(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim));
splits_index += last_dim;
}
}
return Status::OK();
}

} // namespace onnxruntime
18 changes: 18 additions & 0 deletions onnxruntime/core/providers/cpu/nn/string_split.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#pragma once

#include "core/framework/op_kernel.h"

namespace onnxruntime {

class StringSplit final : public OpKernel {
public:
explicit StringSplit(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;

private:
std::string delimiter_;
int64_t maxsplit_;
};


} // namespace onnxruntime
17 changes: 17 additions & 0 deletions onnxruntime/test/providers/cpu/nn/string_split_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "gtest/gtest.h"
#include "test/providers/provider_test_utils.h"

namespace onnxruntime {
namespace test {

TEST(StringSplit, BasicSplitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {3}, {"hello world", "hello", "world"});
test.AddAttribute<std::string>("delimiter", " ");
test.AddOutput<std::string>("Y", {3, 2}, {"hello", "world", "hello", "", "world", ""});
test.AddOutput<int64_t>("Z", {3}, {2, 1, 1});
test.Run();
}

} // namespace test
} // namespace onnxruntime

0 comments on commit 9fd92aa

Please sign in to comment.