Skip to content

Commit

Permalink
Implement feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 9, 2024
1 parent 45edebf commit 6ab138a
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 45 deletions.
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cpu/quantization/qlinear_concat.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c Microsoft Corporation. All rights reserved.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "qlinear_util.h"
Expand Down
91 changes: 50 additions & 41 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "string_split.h"
#include "core/common/common.h"
#include <algorithm>
Expand All @@ -12,99 +15,105 @@ ONNX_CPU_OPERATOR_KERNEL(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int64_t>()),
StringSplit);

int64_t countSubstrings(std::string_view str, std::string_view substr) {
if (str.empty()) {
return 0;
}
/// Count the number of instances of substring ``substr`` in ``str``. If ``substr`` is an empty string it counts the number of whitespace delimited words.
int64_t CountSubstrings(std::string_view str, std::string_view substr) {
if (substr.empty()) {
// Count consecutive whitespace as one delimiter
int64_t count = 1;
int64_t count = 0;
size_t pos = str.find_first_not_of(" ");
while (pos != std::string::npos) {
++count;
pos = str.find_first_not_of(" ", str.find_first_of(" ", pos));
}
return count;
} else {
int64_t count = 1;
size_t pos = str.find(substr);
int64_t count = 0;
size_t pos = 0;
while (pos != std::string::npos) {
++count;
pos = str.find(substr, pos + substr.length());
pos = str.find(substr, pos);
if (pos != std::string::npos) {
pos += substr.length();
}
}
return count;
}
}

int64_t fillSubstrings(std::string_view str, std::string_view delimiter, gsl::span<std::string> output, int64_t output_index, size_t max_tokens) {
/// Fill substrings of ``str`` based on split delimiter ``delimiter`` into ``output`` span. Restrict maximum number of generated substrings to ``max_tokens``.
/// The function returns the number of substrings generated (this is less or equal to ``max_tokens``).
int64_t FillSubstrings(std::string_view str, std::string_view delimiter, gsl::details::span_iterator<std::string> output, size_t max_tokens) {
if (str.empty()) {
return 0;
}
if (delimiter.empty()) {
// Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored.
size_t pos = str.find_first_not_of(" ");
size_t token_index = 0;
while (token_index < max_tokens && pos != std::string::npos) {
auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
++token_index;
if (next_pos == std::string::npos) {
int64_t token_count = 0;
while (pos != std::string::npos) {
if (++token_count == max_tokens) {
// trim down last substring as required in specification
size_t next_pos = str.length() - 1;
while (str[next_pos] == ' ') {
next_pos--;
}
*output = str.substr(pos, next_pos - pos + 1);
break;
} else {
auto next_pos = str.find_first_of(" ", pos);
*output = str.substr(pos, next_pos - pos);
pos = str.find_first_not_of(" ", next_pos);
}
pos = str.find_first_not_of(" ", next_pos);

output++;
}
return token_index;
return token_count;
} else {
size_t pos = 0;
size_t token_index = 0;
while (token_index < max_tokens && pos != std::string::npos) {
auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
++token_index;
int64_t token_count = 0;
while (pos != std::string::npos) {
auto next_pos = token_count == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos);
*output++ = str.substr(pos, next_pos - pos);
token_count++;
if (next_pos == std::string::npos) {
break;
}
pos = next_pos + delimiter.size();
}
return token_index;
return token_count;
}
}

StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) {
info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits<int64_t>::max() - 1); // TODO is this the right thing to do here?
info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits<int64_t>::max() - 1);
info.GetAttrOrDefault("delimiter", &delimiter_, std::string(""));
}

Status StringSplit::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
if (nullptr == input) {
return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
}
auto* num_substrings = context->Output(1, input->Shape());
if (nullptr == num_substrings) {
return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch");
}
auto input_data = input->template DataAsSpan<std::string>();

int64_t last_dim = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
auto substring_count = countSubstrings(input_data[i], delimiter_);
last_dim = std::max(last_dim, substring_count);
for (const auto& str : input_data) {
last_dim = std::max(last_dim, CountSubstrings(str, delimiter_));
}

last_dim = std::min(last_dim, maxsplit_ + 1);

// Set up num_substrings output
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
// Set up splits output
auto splits_shape = input->Shape().AsShapeVector();
if (last_dim > 0) {
splits_shape.push_back(last_dim);
}
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto splits_index = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim);
splits_index += last_dim;
auto output_splits_iter = splits_data.begin();

// Set up number of tokens output
auto* num_substrings = context->Output(1, input->Shape());
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
auto output_num_tokens_iter = num_substrings_data.begin();

for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, output_splits_iter += last_dim, output_num_tokens_iter++) {
*output_num_tokens_iter = FillSubstrings(*input_iter, delimiter_, output_splits_iter, last_dim);
}
return Status::OK();
}
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cpu/nn/string_split.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/op_kernel.h"
Expand Down
40 changes: 37 additions & 3 deletions onnxruntime/test/providers/cpu/nn/string_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ TEST(StringSplit, EmptyStringDelimiterTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddAttribute<std::string>("delimiter", "");
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""});
test.AddOutput<std::string>("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}

TEST(StringSplit, SubsequentWhitespaceDefaultTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""});
test.AddOutput<std::string>("Y", {1, 4, 2}, {"hello", "world", "hello", "world", "hello", "world", "hello", "world"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}
Expand All @@ -44,10 +44,44 @@ TEST(StringSplit, SubsequentWhitespaceWithLimitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "});
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "});
test.AddOutput<std::string>("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime"});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 1, 2});
test.Run();
}

TEST(StringSplit, SingleTokenTest) {
OpTester test("StringSplit", 20);
test.AddAttribute<std::string>("delimiter", "*");
test.AddInput<std::string>("X", {1, 1, 1}, {"lorem"});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

TEST(StringSplit, SingleTokenWhitespaceTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 1, 1}, {"lorem"});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

TEST(StringSplit, EdgeWhitespaceTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 1, 1}, {" lorem "});
test.AddOutput<std::string>("Y", {1, 1, 1, 1}, {"lorem"});
test.AddOutput<int64_t>("Z", {1, 1, 1}, {1});
test.Run();
}

TEST(StringSplit, EmptyInputTest) {
OpTester test("StringSplit", 20);
test.AddAttribute<std::string>("delimiter", "*");
test.AddInput<std::string>("X", {1, 3, 1}, {"", "+", "*"});
test.AddOutput<std::string>("Y", {1, 3, 1, 2}, {"", "", "+", "", "", ""});
test.AddOutput<int64_t>("Z", {1, 3, 1}, {0, 1, 2});
test.Run();
}

} // namespace test
} // namespace onnxruntime

0 comments on commit 6ab138a

Please sign in to comment.