Skip to content

Commit

Permalink
Update substring calculation function name
Browse files Browse the repository at this point in the history
  • Loading branch information
adityagoel4512 committed Jan 10, 2024
1 parent 7109712 commit cf91ede
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions onnxruntime/core/providers/cpu/nn/string_split.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ONNX_CPU_OPERATOR_KERNEL(StringSplit, 20,
/// Calculate substrings in ``str`` delimited by ``delimiter``. A maximum of ``max_splits`` splits are permitted.
/// Returns a vector of string slices into ``str`` representing the substrings as string views. The user must ensure
/// the returned views' lifetime does not exceed ``str``'s.
InlinedVector<std::string_view> FillSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) {
InlinedVector<std::string_view> ComputeSubstrings(std::string_view str, std::string_view delimiter, int64_t max_splits) {
InlinedVector<std::string_view> output;
if (str.empty()) {
return output;
Expand Down Expand Up @@ -61,7 +61,7 @@ InlinedVector<std::string_view> FillSubstrings(std::string_view str, std::string

StringSplit::StringSplit(const OpKernelInfo& info) : OpKernel(info) {
info.GetAttrOrDefault("maxsplit", &maxsplit_, std::numeric_limits<int64_t>::max() - 1);
info.GetAttrOrDefault("delimiter", &delimiter_, std::string(""));
info.GetAttrOrDefault("delimiter", &delimiter_, std::string());
}

Status StringSplit::Compute(OpKernelContext* context) const {
Expand All @@ -76,11 +76,10 @@ Status StringSplit::Compute(OpKernelContext* context) const {
input_slices.reserve(input_data.size());
int64_t last_dim = 1;

auto input_slice_iterator = input_slices.begin();
for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, input_slice_iterator++, num_tokens_iter++) {
auto substrs = FillSubstrings(*input_iter, delimiter_, maxsplit_);
for (auto input_iter = input_data.begin(); input_iter != input_data.end(); input_iter++, num_tokens_iter++) {
auto substrs = ComputeSubstrings(*input_iter, delimiter_, maxsplit_);
auto substr_count = static_cast<int64_t>(substrs.size());
input_slices.push_back(substrs);
input_slices.push_back(std::move(substrs));
last_dim = std::max(last_dim, substr_count);
*num_tokens_iter = substr_count;
}
Expand Down

0 comments on commit cf91ede

Please sign in to comment.