Add CPU kernels for DynamicTimeWarping and UnfoldTensor. (#22033)
### Description

Add CPU kernels for DynamicTimeWarping and UnfoldTensor.
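For context (summarized from the kernels below, not from separate documentation): DynamicTimeWarping takes a pairwise-distance matrix D of shape [rows, cols] (optionally with a leading batch dimension of 1) and returns the minimum-cost alignment path as a 2 x N int32 tensor, using the standard DTW recurrence

    cost[i][j] = D[i-1][j-1] + min(cost[i-1][j-1], cost[i-1][j], cost[i][j-1]),  cost[0][0] = 0

followed by a back-trace. UnfoldTensor extracts sliding windows of length `size` with stride `step` along dimension `dim`, so the unfolded dimension becomes (input_dims[dim] - size) / step + 1 and a new trailing dimension of length `size` is appended.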
Showing 10 changed files with 312 additions and 20 deletions.
101 changes: 101 additions & 0 deletions
onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc
@@ -0,0 +1,101 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/tensor/dynamic_time_warping.h"
#include "core/providers/cpu/tensor/utils.h"

#include <vector>
#include <numeric>

using namespace onnxruntime::common;

namespace onnxruntime {
namespace contrib {

ONNX_OPERATOR_KERNEL_EX(
    DynamicTimeWarping,
    kMSDomain,
    1,
    kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("F", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("I", DataTypeImpl::GetTensorType<int32_t>()),
    DynamicTimeWarping);

Status DynamicTimeWarping::Compute(OpKernelContext* ctx) const {
  const Tensor& input_tensor = *ctx->Input<Tensor>(0);
  const auto& input_dims = input_tensor.Shape().GetDims();
  int rank = SafeInt<int>(input_dims.size());
  ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1),
              "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank);

  const size_t rows = SafeInt<size_t>(input_dims[rank == 3 ? 1 : 0]);
  const size_t cols = SafeInt<size_t>(input_dims[rank == 3 ? 2 : 1]);

  std::vector<std::vector<float>> cost(rows + 1, std::vector<float>(cols + 1, std::numeric_limits<float>::infinity()));
  std::vector<std::vector<int8_t>> trace(rows + 1, std::vector<int8_t>(cols + 1, -1));
  std::vector<std::vector<int32_t>> path_helper;

  // Compute the cost and trace matrices
  cost[0][0] = 0;
  for (size_t j = 1; j <= cols; ++j) {
    for (size_t i = 1; i <= rows; ++i) {
      const float c0 = cost[i - 1][j - 1];
      const float c1 = cost[i - 1][j];
      const float c2 = cost[i][j - 1];

      float cur_cost;
      int8_t cur_trace;
      if (c0 < c1 && c0 < c2) {
        cur_cost = c0;
        cur_trace = 0;
      } else if (c1 < c0 && c1 < c2) {
        cur_cost = c1;
        cur_trace = 1;
      } else {
        cur_cost = c2;
        cur_trace = 2;
      }

      cost[i][j] = cur_cost + input_tensor.Data<float>()[(i - 1) * cols + j - 1];
      trace[i][j] = cur_trace;
    }
  }

  // Back-tracing to find the optimal path
  int i = static_cast<int>(rows);
  int j = static_cast<int>(cols);
  int result_len = 0;
  while (i > 0 && j > 0) {
    path_helper.push_back({i - 1, j - 1});
    ++result_len;
    int8_t cur_trace = trace[i][j];
    switch (cur_trace) {
      case 0:
        --i;
        --j;
        break;
      case 1:
        --i;
        break;
      case 2:
        --j;
        break;
      default:
        ORT_THROW("Invalid trace value: ", cur_trace);
    }
  }

  // Update the output tensor
  Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt<int64_t>(result_len)});
  auto* output_data = output_tensor->MutableData<int32_t>();
  for (int k = 0; k < result_len; ++k) {
    output_data[k] = path_helper[static_cast<size_t>(result_len) - k - 1][0];
    output_data[k + result_len] = path_helper[static_cast<size_t>(result_len) - k - 1][1];
  }

  return Status::OK();
}

}  // namespace contrib
}  // namespace onnxruntime
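To make the kernel's contract concrete, here is a minimal standalone sketch (not part of the commit) that runs the same cost/trace recurrence and back-trace on a hypothetical 3x3 distance matrix using plain std::vector instead of ORT tensors; the kernel itself emits the reversed path as a 2 x N int32 tensor (row indices first, then column indices).

// --- Illustrative sketch, not part of this commit ---
#include <cstdio>
#include <limits>
#include <utility>
#include <vector>

int main() {
  const size_t rows = 3, cols = 3;
  // Hypothetical pairwise-distance matrix between two length-3 sequences.
  std::vector<float> dist = {0.f, 2.f, 5.f,
                             3.f, 0.f, 4.f,
                             6.f, 3.f, 0.f};

  std::vector<std::vector<float>> cost(rows + 1, std::vector<float>(cols + 1, std::numeric_limits<float>::infinity()));
  std::vector<std::vector<int>> trace(rows + 1, std::vector<int>(cols + 1, -1));
  cost[0][0] = 0.f;
  for (size_t j = 1; j <= cols; ++j) {
    for (size_t i = 1; i <= rows; ++i) {
      const float c0 = cost[i - 1][j - 1], c1 = cost[i - 1][j], c2 = cost[i][j - 1];
      const int t = (c0 < c1 && c0 < c2) ? 0 : (c1 < c0 && c1 < c2) ? 1 : 2;
      const float c = (t == 0) ? c0 : (t == 1) ? c1 : c2;
      cost[i][j] = c + dist[(i - 1) * cols + j - 1];
      trace[i][j] = t;
    }
  }

  // Back-trace from (rows, cols), collecting path points in reverse order.
  std::vector<std::pair<int, int>> path;
  for (int i = static_cast<int>(rows), j = static_cast<int>(cols); i > 0 && j > 0;) {
    path.emplace_back(i - 1, j - 1);
    if (trace[i][j] == 0) { --i; --j; }
    else if (trace[i][j] == 1) { --i; }
    else { --j; }
  }
  for (auto it = path.rbegin(); it != path.rend(); ++it)
    printf("(%d, %d)\n", it->first, it->second);  // expected: (0,0) (1,1) (2,2)
  return 0;
}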
24 changes: 24 additions & 0 deletions
onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h
@@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/framework/op_kernel.h"
#include <core/common/safeint.h>

namespace onnxruntime {
namespace contrib {

using onnxruntime::OpKernelContext;
using onnxruntime::OpKernelInfo;

class DynamicTimeWarping : public OpKernel {
 public:
  DynamicTimeWarping(const OpKernelInfo& info) : OpKernel(info) {}

  ~DynamicTimeWarping() = default;

  Status Compute(OpKernelContext* context) const override;
};

}  // namespace contrib
}  // namespace onnxruntime
109 changes: 109 additions & 0 deletions
onnxruntime/contrib_ops/cpu/tensor/unfold.cc
@@ -0,0 +1,109 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/tensor/unfold.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/common.h"
#include "core/platform/threadpool.h"

#include <vector>
#include <numeric>

using namespace onnxruntime::common;

namespace onnxruntime {
namespace contrib {

ONNX_OPERATOR_KERNEL_EX(
    UnfoldTensor,
    kMSDomain,
    1,
    kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
    UnfoldTensor);

template <typename T>
Status LaunchUnfoldTensor(const T* input,
                          T* output,
                          int64_t leading_dims_size,
                          int64_t unfold_dim_size,
                          int64_t tailing_dims_size,
                          int64_t unfold_size,
                          int64_t step_size,
                          concurrency::ThreadPool* tp) {
  int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1;
  int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size;

  int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst;
  int64_t stride_fold_dim_src = tailing_dims_size * step_size;
  int64_t stride_leading_src = tailing_dims_size * unfold_dim_size;

  static constexpr double cost = 1.0;
  concurrency::ThreadPool::TryParallelFor(tp, static_cast<ptrdiff_t>(N), cost,
                                          [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
    for (std::ptrdiff_t i = begin; i != end; ++i) {
      const int64_t idx = static_cast<int64_t>(i);
      const int64_t idx_leading = idx / stride_leading_dst;
      int64_t n = idx % stride_leading_dst;
      const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size;
      const int64_t idx_fold = n / stride_fold_dim_dst;
      n %= stride_fold_dim_dst;
      const int64_t idx_tailing = n / unfold_size;
      const int64_t idx_append = n % unfold_size;

      int64_t idx_src = idx_leading * stride_leading_src +
                        idx_fold * stride_fold_dim_src + idx_tailing +
                        idx_append * tailing_dims_size;
      output[idx] = input[idx_src];
    }
  });

  return Status::OK();
}

Status UnfoldTensor::Compute(OpKernelContext* ctx) const {
  const Tensor& input_tensor = *ctx->Input<Tensor>(0);
  const auto& input_dims = input_tensor.Shape().GetDims();
  int rank = SafeInt<int>(input_dims.size());

  int dim = SafeInt<int>(HandleNegativeAxis(dim_, rank));
  ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribute specified dim: ", dim);
  ORT_ENFORCE(input_dims[dim] >= size_, "dim size:", input_dims[dim], " is less than unfold size:", size_);

  int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + static_cast<size_t>(dim),
                                         1LL, std::multiplies<int64_t>());
  int64_t tailing_dims = std::accumulate(input_dims.begin() + (static_cast<size_t>(dim) + 1),
                                         input_dims.end(), 1LL, std::multiplies<int64_t>());

  std::vector<int64_t> output_dims(static_cast<size_t>(rank) + 1, 0);
  std::copy(input_dims.begin(), input_dims.end(), output_dims.begin());
  output_dims[dim] = (input_dims[dim] - size_) / step_ + 1;
  output_dims.back() = size_;
  TensorShape output_shape(output_dims);
  Tensor* output_tensor = ctx->Output(0, output_shape);

  auto* tp = ctx->GetOperatorThreadPool();

  Status status;
  if (input_tensor.IsDataType<float>()) {
    status = LaunchUnfoldTensor<float>(input_tensor.Data<float>(), output_tensor->MutableData<float>(),
                                       leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
  } else if (input_tensor.IsDataType<double>()) {
    status = LaunchUnfoldTensor<double>(input_tensor.Data<double>(), output_tensor->MutableData<double>(),
                                        leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
  } else if (input_tensor.IsDataType<int32_t>()) {
    status = LaunchUnfoldTensor<int32_t>(input_tensor.Data<int32_t>(), output_tensor->MutableData<int32_t>(),
                                         leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
  } else if (input_tensor.IsDataType<int64_t>()) {
    status = LaunchUnfoldTensor<int64_t>(input_tensor.Data<int64_t>(), output_tensor->MutableData<int64_t>(),
                                         leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
  } else {
    return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported data type: ", input_tensor.DataType());
  }

  return status;
}

}  // namespace contrib
}  // namespace onnxruntime
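To illustrate the index arithmetic, here is a minimal standalone sketch (not part of the commit) that performs the same sliding-window copy for a hypothetical int input of shape [2, 5] with dim = 1, size = 3, step = 2, producing an output of shape [2, 2, 3]; the nested loops enumerate destination indices in the same (leading, window, tailing, offset) order that the kernel derives from a flat index.

// --- Illustrative sketch, not part of this commit ---
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int64_t leading = 2, unfold_dim = 5, tailing = 1;  // input shape [2, 5]
  const int64_t size = 3, step = 2;                        // "size" and "step" attributes
  std::vector<int> input(static_cast<size_t>(leading * unfold_dim * tailing));
  for (size_t i = 0; i < input.size(); ++i) input[i] = static_cast<int>(i);

  const int64_t windows = (unfold_dim - size) / step + 1;  // new length of the unfolded dim
  std::vector<int> output(static_cast<size_t>(leading * windows * tailing * size));

  // Same source-index formula as the kernel: window w, offset k reads element w*step + k.
  for (int64_t l = 0; l < leading; ++l)
    for (int64_t w = 0; w < windows; ++w)
      for (int64_t t = 0; t < tailing; ++t)
        for (int64_t k = 0; k < size; ++k) {
          const int64_t dst = ((l * windows + w) * tailing + t) * size + k;
          const int64_t src = l * unfold_dim * tailing + (w * step + k) * tailing + t;
          output[static_cast<size_t>(dst)] = input[static_cast<size_t>(src)];
        }

  // Expected windows: row 0 -> {0,1,2}, {2,3,4}; row 1 -> {5,6,7}, {7,8,9}.
  for (int64_t l = 0; l < leading; ++l)
    for (int64_t w = 0; w < windows; ++w) {
      printf("leading %lld window %lld:", (long long)l, (long long)w);
      for (int64_t k = 0; k < size; ++k)
        printf(" %d", output[static_cast<size_t>(((l * windows + w) * tailing) * size + k)]);
      printf("\n");
    }
  return 0;
}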
47 changes: 47 additions & 0 deletions
onnxruntime/contrib_ops/cpu/tensor/unfold.h
@@ -0,0 +1,47 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "core/framework/op_kernel.h"
#include <core/common/safeint.h>

namespace onnxruntime {
namespace contrib {

using onnxruntime::OpKernelContext;
using onnxruntime::OpKernelInfo;

template <typename T>
Status LaunchUnfoldTensor(
    const T* input,
    T* output,
    int64_t leading_dims_size,
    int64_t unfold_dim_size,
    int64_t tailing_dims_size,
    int64_t unfold_size,
    int64_t step_size);

class UnfoldTensor final : public OpKernel {
 public:
  UnfoldTensor(const OpKernelInfo& info) : OpKernel(info) {
    dim_ = SafeInt<int>(info.GetAttrOrDefault<int64_t>("dim", -1LL));
    step_ = SafeInt<int>(info.GetAttrOrDefault<int64_t>("step", 1LL));
    ORT_ENFORCE(step_ > 0, "step must be greater than zero!");

    int64_t temp_size;
    ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK());
    size_ = SafeInt<int>(temp_size);
  }

  ~UnfoldTensor() = default;

  Status Compute(OpKernelContext* context) const override;

 private:
  int dim_;
  int size_;
  int step_;
};

}  // namespace contrib
}  // namespace onnxruntime