Skip to content

Commit

Permalink
Bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Oct 18, 2023
1 parent 5309a68 commit c29ba03
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 46 deletions.
68 changes: 27 additions & 41 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,76 +18,60 @@ int64_t countSubstrings(const std::string& str, const std::string& substr) {
}
if (substr.empty()) {
// Count consecutive whitespace as one delimiter
bool in_whitespace = false;
int64_t count = 1;
for (const auto& c : str) {
if (isspace(c)) {
if (!in_whitespace) {
in_whitespace = true;
count += 1;
}
} else {
in_whitespace = false;
}
size_t pos = str.find_first_not_of(" ");
while (pos != std::string::npos) {
++count;
pos = str.find_first_not_of(" ", str.find_first_of(" ", pos));
}
return count;
} else {
int64_t count = 1;
size_t pos = 0;
while ((pos = str.find(substr, pos)) != std::string::npos) {
size_t pos = str.find(substr);
while (pos != std::string::npos) {
++count;
pos += substr.length();
pos = str.find(substr, pos + substr.length());
}
return count;
}
}

size_t fill_substrings(const std::string& str, const std::string& delimiter, gsl::span<std::string> output, int64_t output_index, size_t max_tokens) {
int64_t fillSubstrings(const std::string& str, const std::string& delimiter, gsl::span<std::string> output, int64_t output_index, size_t max_tokens) {
if (str.empty()) {
return 0;
}
if (delimiter.empty()) {
// Count consecutive whitespace as one delimiter
bool in_whitespace = false;
// Count consecutive whitespace as one delimiter. Preceding and trailing whitespace is meant to be ignored.
size_t pos = str.find_first_not_of(" ");
size_t token_index = 0;
size_t substr_start_index = 0;

for (size_t i = 0; i < str.size(); ++i) {
if (token_index == max_tokens - 1) {
// if we are at the max_tokens-1 substring, the next and final substring should be the remainder of the string
output[output_index + token_index] = str.substr(i);
++token_index;
while (token_index < max_tokens && pos != std::string::npos) {
auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find_first_of(" ", pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
++token_index;
if (next_pos == std::string::npos) {
break;
}
if (!isspace(str[i])) {
// if currently in a whitespace, this index marks the start of the current substring
if (in_whitespace) {
substr_start_index = i;
in_whitespace = false;
}
} else if (!in_whitespace) {
// if currently not in whitespace, this index is the end of a substring
output[output_index + token_index] = str.substr(substr_start_index, i - substr_start_index);
in_whitespace = true;
++token_index;
}
pos = str.find_first_not_of(" ", next_pos);
}
return token_index;
} else {
size_t pos = 0;
size_t token_index = 0;
while (token_index < max_tokens) {
auto next_pos = str.find(delimiter, pos);
while (token_index < max_tokens && pos != std::string::npos) {
auto next_pos = token_index == max_tokens - 1 ? std::string::npos : str.find(delimiter, pos);
output[output_index + token_index] = str.substr(pos, next_pos - pos);
pos = next_pos + delimiter.size();
++token_index;
if (next_pos == std::string::npos) {
break;
}
pos = next_pos + delimiter.size();
}
return token_index;
}
}

StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) {
info.GetAttrOrDefault("maxsplit", &maxsplit_, static_cast<int64_t>(-1)); // TODO is this the right thing to do here?
info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits<int64_t>::max() - 1); // TODO is this the right thing to do here?
info.GetAttrOrDefault("delimiter", &delimiter_, std::string(""));
}

Expand All @@ -101,12 +85,14 @@ Status StringSplit::Compute(OpKernelContext* context) const {
return Status(common::ONNXRUNTIME, common::FAIL, "output count mismatch");
}
auto input_data = input->template DataAsSpan<std::string>();
auto last_dim = maxsplit_;
int64_t last_dim = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
auto substring_count = countSubstrings(input_data[i], delimiter_);
last_dim = std::max(last_dim, substring_count);
}

last_dim = std::min(last_dim, maxsplit_ + 1);

// Set up num_substrings output
auto num_substrings_data = num_substrings->template MutableDataAsSpan<int64_t>();
// Set up splits output
Expand All @@ -117,7 +103,7 @@ Status StringSplit::Compute(OpKernelContext* context) const {
auto splits_data = context->Output(0, splits_shape)->template MutableDataAsSpan<std::string>();
auto splits_index = 0;
for (auto i = 0; i < input->Shape().Size(); ++i) {
num_substrings_data[i] = static_cast<int64_t>(fill_substrings(input_data[i], delimiter_, splits_data, splits_index, last_dim));
num_substrings_data[i] = fillSubstrings(input_data[i], delimiter_, splits_data, splits_index, last_dim);
splits_index += last_dim;
}
return Status::OK();
Expand Down
14 changes: 9 additions & 5 deletions onnxruntime/test/providers/cpu/nn/string_split_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,34 @@ TEST(StringSplit, MaxSplitTest) {
test.AddAttribute<std::string>("delimiter", ";");
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>("Y", {2, 2, 2}, {"eggs", "milk;chesse", "pepper", "salt", "chicken", "fish;pork", "spinach", ""});
test.AddOutput<int64_t>("Z", {2, 2}, {2, 1, 2, 1});
test.AddOutput<int64_t>("Z", {2, 2}, {2, 2, 2, 1});
test.Run();
}

TEST(StringSplit, EmptyStringDelimiterTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddAttribute<std::string>("delimiter", "");
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""});
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}

TEST(StringSplit, SubsequentWhitespaceDefaultTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "", "hello", "world", "hello", "world", ""});
test.AddInput<std::string>("X", {1, 4}, {"hello world", "hello world", " hello world", "hello world "});
test.AddOutput<std::string>("Y", {1, 4, 3}, {"hello", "world", "", "hello", "world", "", "hello", "world", "", "hello", "world", ""});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 2, 2});
test.Run();
}

TEST(StringSplit, SubsequentWhitespaceWithLimitTest) {
OpTester test("StringSplit", 20);
test.AddInput<std::string>("X", {1, 4}, {"lorem ipsum doler", "Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "});
test.AddInput<std::string>("X", {1, 4}, {"lorem ipsum doler", " Open Neural Network Exchange (ONNX)", "onnx", "ONNX runtime "});
test.AddAttribute<int64_t>("maxsplit", 1);
test.AddOutput<std::string>("Y", {1, 4, 2}, {"lorem", "ipsum doler", "Open", "Neural Network Exchange (ONNX)", "onnx", "", "ONNX", "runtime "});
test.AddOutput<int64_t>("Z", {1, 4}, {2, 2, 1, 2});
test.Run();
}

} // namespace test
Expand Down

0 comments on commit c29ba03

Please sign in to comment.