From 3c80aa9feed4ee1249fdcadfd50ad66b8e039ac1 Mon Sep 17 00:00:00 2001
From: mindest <30493312+mindest@users.noreply.github.com>
Date: Sat, 12 Oct 2024 01:44:18 +0900
Subject: [PATCH] Add CPU kernels for DynamicTimeWarping and UnfoldTensor.
(#22033)
### Description
Add CPU kernels for DynamicTimeWarping and UnfoldTensor, register them in the CPU contrib-kernel registry, and extend the previously CUDA-only unit tests so they also run against the CPU execution provider.
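
For reviewers, the dynamic-programming recurrence implemented in `dynamic_time_warping.cc` can be sketched as a small standalone C++ program (illustrative only, not part of this patch; the 3x3 `input` costs are made up, and the tie-breaking with strict `<` mirrors the kernel):

```cpp
// Sketch of the DTW recurrence: cost[i][j] = input[i-1][j-1] + min(diagonal, up, left),
// followed by a backtrace from (rows, cols) to recover the cheapest path
// from (0, 0) to (M-1, N-1).
#include <array>
#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

int main() {
  constexpr size_t rows = 3, cols = 3;
  const float input[rows][cols] = {{3, 1, 4}, {1, 5, 9}, {2, 6, 5}};  // made-up costs

  std::vector<std::vector<float>> cost(rows + 1, std::vector<float>(cols + 1, std::numeric_limits<float>::infinity()));
  std::vector<std::vector<int8_t>> trace(rows + 1, std::vector<int8_t>(cols + 1, -1));
  cost[0][0] = 0.0f;
  for (size_t j = 1; j <= cols; ++j) {
    for (size_t i = 1; i <= rows; ++i) {
      const float c0 = cost[i - 1][j - 1], c1 = cost[i - 1][j], c2 = cost[i][j - 1];
      const int8_t t = (c0 < c1 && c0 < c2) ? 0 : (c1 < c0 && c1 < c2) ? 1 : 2;
      cost[i][j] = (t == 0 ? c0 : (t == 1 ? c1 : c2)) + input[i - 1][j - 1];
      trace[i][j] = t;  // 0 = diagonal, 1 = up, 2 = left
    }
  }

  // Backtrace; the path is collected in reverse order.
  std::vector<std::array<int, 2>> path;
  int i = static_cast<int>(rows), j = static_cast<int>(cols);
  while (i > 0 && j > 0) {
    path.push_back({i - 1, j - 1});
    switch (trace[i][j]) {
      case 0: --i; --j; break;
      case 1: --i; break;
      default: --j; break;
    }
  }
  // The op's output has shape [2, len]: row indices first, then column indices.
  for (auto it = path.rbegin(); it != path.rend(); ++it)
    std::printf("(%d, %d) ", (*it)[0], (*it)[1]);  // prints: (0, 0) (1, 1) (2, 2)
  std::printf("\n");
  return 0;
}
```

The kernel deliberately stays single-threaded: every cell depends on its (i-1, j-1), (i-1, j), and (i, j-1) neighbors, so there is no easy data parallelism to hand to the thread pool.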
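Similarly, the index mapping used by `LaunchUnfoldTensor` in `unfold.cc` is equivalent to the following single-threaded reference (`UnfoldRef` is a hypothetical name used only for this illustration):

```cpp
// Treat the input as [leading, unfold_dim, tailing]; produce
// [leading, (unfold_dim - size) / step + 1, tailing, size].
#include <cstdint>
#include <cstdio>
#include <vector>

template <typename T>
std::vector<T> UnfoldRef(const std::vector<T>& input,
                         int64_t leading, int64_t unfold_dim, int64_t tailing,
                         int64_t size, int64_t step) {
  const int64_t unfold_dim_dst = (unfold_dim - size) / step + 1;
  std::vector<T> out(static_cast<size_t>(leading * unfold_dim_dst * tailing * size));
  for (int64_t l = 0; l < leading; ++l)
    for (int64_t f = 0; f < unfold_dim_dst; ++f)    // which slice
      for (int64_t t = 0; t < tailing; ++t)
        for (int64_t s = 0; s < size; ++s)          // position inside the slice
          out[((l * unfold_dim_dst + f) * tailing + t) * size + s] =
              input[(l * unfold_dim + f * step + s) * tailing + t];
  return out;
}

int main() {
  // 1-D example: a dim of 6 elements, size 3, step 2 -> 2 slices of 3.
  std::vector<int> x = {0, 1, 2, 3, 4, 5};
  auto y = UnfoldRef(x, /*leading=*/1, /*unfold_dim=*/6, /*tailing=*/1, /*size=*/3, /*step=*/2);
  for (int v : y) std::printf("%d ", v);  // prints: 0 1 2 2 3 4
  std::printf("\n");
  return 0;
}
```

This matches the shape rule in the operator doc: dimension `dim` shrinks to `(sizedim - size) / step + 1` and a new trailing dimension of length `size` is appended. The real kernel flattens these four loops into a single index and splits it across the thread pool.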
---
docs/ContribOperators.md | 4 +-
docs/OperatorKernels.md | 2 +
.../contrib_ops/cpu/cpu_contrib_kernels.cc | 4 +
.../cpu/tensor/dynamic_time_warping.cc | 101 ++++++++++++++++
.../cpu/tensor/dynamic_time_warping.h | 24 ++++
onnxruntime/contrib_ops/cpu/tensor/unfold.cc | 109 ++++++++++++++++++
onnxruntime/contrib_ops/cpu/tensor/unfold.h | 47 ++++++++
.../core/graph/contrib_ops/contrib_defs.cc | 12 +-
.../dynamic_time_warping_op_test.cc | 9 +-
.../test/contrib_ops/tensor_op_test.cc | 20 ++--
10 files changed, 312 insertions(+), 20 deletions(-)
create mode 100644 onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc
create mode 100644 onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h
create mode 100644 onnxruntime/contrib_ops/cpu/tensor/unfold.cc
create mode 100644 onnxruntime/contrib_ops/cpu/tensor/unfold.h
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index da4d0b7f66c37..c791c426176df 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1545,7 +1545,7 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.DynamicTimeWarping**
- Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])have the lowest cost among all paths from (0, 0) to (M-1, N-1).
+ Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return dynamic time warping of shape [2, x], where the path made by all points (output[0][t], output[1][t])have the lowest cost among all paths from (0, 0) to (M-1, N-1).
#### Version
@@ -5974,7 +5974,7 @@ This version of the operator has been available since version 1 of the 'com.micr
### **com.microsoft.UnfoldTensor**
- Returns a tensor which contains all slices of size size from input tensor in the dimension dim. Step between two slices is given by step. If sizedim is the size of dimension dim for input tensor, the size of dimension dim in the returned tensor will be (sizedim - size) / step + 1. An additional dimension of size size is appended in the returned tensor.
+ Returns a tensor which contains all slices of size `size` from input tensor in the dimension `dim`. Step between two slices is given by `step`. If `sizedim` is the size of dimension `dim` for input tensor, the size of dimension `dim` in the returned tensor will be `(sizedim - size) / step + 1`. An additional dimension of size `size` is appended in the returned tensor.
#### Version
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index a9176605d9175..5963f1a03b7fa 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -471,6 +471,7 @@ Do not modify directly.*
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int16), tensor(int32), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)<br> **T2** = tensor(float)|
|DynamicQuantizeLSTM|*in* X:**T**<br> *in* W:**T2**<br> *in* R:**T2**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* W_scale:**T**<br> *in* W_zero_point:**T2**<br> *in* R_scale:**T**<br> *in* R_zero_point:**T2**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(float)<br> **T1** = tensor(int32)<br> **T2** = tensor(int8), tensor(uint8)|
|DynamicQuantizeMatMul|*in* A:**T1**<br> *in* B:**T2**<br> *in* b_scale:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br> **T2** = tensor(int8), tensor(uint8)|
+|DynamicTimeWarping|*in* input:**F**<br> *out* output:**I**|1+|**F** = tensor(float)<br> **I** = tensor(int32)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float)|
|ExpandDims|*in* X:**T**<br> *in* axis:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br> **axis** = tensor(int32)|
|FastGelu|*in* X:**T**<br> *in* bias:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -518,6 +519,7 @@ Do not modify directly.*
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)|
+|UnfoldTensor|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Unique|*in* x:**T**<br> *out* y:**T**<br> *out* idx:**tensor(int64)**<br> *out* counts:**tensor(int64)**|1+|**T** = tensor(float)|
|WhisperBeamSearch|*in* input_ids:**F**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *in* prefix_vocab_mask:**M**<br> *in* attention_mask:**I**<br> *in* decoder_input_ids:**I**<br> *in* logits_processor:**I**<br> *in* cross_qk_layer_head:**I**<br> *in* extra_decoding_ids:**I**<br> *in* temperature:**T**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**<br> *out* cross_qk:**V**<br> *out* non_speech_probs:**T**|1+|**T** = tensor(float)|
|WordConvEmbedding|*in* Sequence:**T**<br> *in* W:**T1**<br> *in* B:**T1**<br> *in* C:**T1**<br> *out* Y:**T1**|1+|**T** = tensor(int32)<br> **T1** = tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 6ffe861d19931..5f5919566713f 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -148,6 +148,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -358,6 +360,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UnfoldTensor)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, DynamicTimeWarping)>,
#ifdef ENABLE_ATEN
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kPytorchAtenDomain, 1, ATen)>,
diff --git a/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc
new file mode 100644
index 0000000000000..9f1d4d6e20307
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.cc
@@ -0,0 +1,101 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/tensor/dynamic_time_warping.h"
+#include "core/providers/cpu/tensor/utils.h"
+
+#include <vector>
+#include <limits>
+
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_KERNEL_EX(
+ DynamicTimeWarping,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("F", DataTypeImpl::GetTensorType())
+ .TypeConstraint("I", DataTypeImpl::GetTensorType()),
+ DynamicTimeWarping);
+
+Status DynamicTimeWarping::Compute(OpKernelContext* ctx) const {
+ const Tensor& input_tensor = *ctx->Input(0);
+ const auto& input_dims = input_tensor.Shape().GetDims();
+ int rank = SafeInt(input_dims.size());
+ ORT_ENFORCE(rank == 2 || (rank == 3 && input_dims[0] == 1),
+ "Currently input rank must be 2, or (3 with first dim equal to 1), but got:", rank);
+
+ const size_t rows = SafeInt(input_dims[rank == 3 ? 1 : 0]);
+ const size_t cols = SafeInt(input_dims[rank == 3 ? 2 : 1]);
+
+ std::vector> cost(rows + 1, std::vector(cols + 1, std::numeric_limits::infinity()));
+ std::vector> trace(rows + 1, std::vector(cols + 1, -1));
+ std::vector> path_helper;
+
+ // Compute the cost and trace matrices
+ cost[0][0] = 0;
+ for (size_t j = 1; j <= cols; ++j) {
+ for (size_t i = 1; i <= rows; ++i) {
+ const float c0 = cost[i - 1][j - 1];
+ const float c1 = cost[i - 1][j];
+ const float c2 = cost[i][j - 1];
+
+ float cur_cost;
+ int8_t cur_trace;
+ if (c0 < c1 && c0 < c2) {
+ cur_cost = c0;
+ cur_trace = 0;
+ } else if (c1 < c0 && c1 < c2) {
+ cur_cost = c1;
+ cur_trace = 1;
+ } else {
+ cur_cost = c2;
+ cur_trace = 2;
+ }
+
+ cost[i][j] = cur_cost + input_tensor.Data<float>()[(i - 1) * cols + j - 1];
+ trace[i][j] = cur_trace;
+ }
+ }
+
+ // Back-tracing to find the optimal path
+ int i = static_cast<int>(rows);
+ int j = static_cast<int>(cols);
+ int result_len = 0;
+ while (i > 0 && j > 0) {
+ path_helper.push_back({i - 1, j - 1});
+ ++result_len;
+ int8_t cur_trace = trace[i][j];
+ switch (cur_trace) {
+ case 0:
+ --i;
+ --j;
+ break;
+ case 1:
+ --i;
+ break;
+ case 2:
+ --j;
+ break;
+ default:
+ ORT_THROW("Invalid trace value: ", cur_trace);
+ }
+ }
+
+ // Update the output tensor
+ Tensor* output_tensor = ctx->Output(0, TensorShape{2LL, SafeInt<int64_t>(result_len)});
+ auto* output_data = output_tensor->MutableData<int32_t>();
+ for (int k = 0; k < result_len; ++k) {
+ output_data[k] = path_helper[static_cast<size_t>(result_len) - k - 1][0];
+ output_data[k + result_len] = path_helper[static_cast<size_t>(result_len) - k - 1][1];
+ }
+
+ return Status::OK();
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h
new file mode 100644
index 0000000000000..76083d426a58a
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/tensor/dynamic_time_warping.h
@@ -0,0 +1,24 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/framework/op_kernel.h"
+#include <cstdint>
+
+namespace onnxruntime {
+namespace contrib {
+
+using onnxruntime::OpKernelContext;
+using onnxruntime::OpKernelInfo;
+
+class DynamicTimeWarping : public OpKernel {
+ public:
+ DynamicTimeWarping(const OpKernelInfo& info) : OpKernel(info) {}
+
+ ~DynamicTimeWarping() = default;
+
+ Status Compute(OpKernelContext* context) const override;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/tensor/unfold.cc b/onnxruntime/contrib_ops/cpu/tensor/unfold.cc
new file mode 100644
index 0000000000000..edafa538be219
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/tensor/unfold.cc
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/tensor/unfold.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/common.h"
+#include "core/platform/threadpool.h"
+
+#include <vector>
+#include <numeric>
+
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+ONNX_OPERATOR_KERNEL_EX(
+ UnfoldTensor,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ (*KernelDefBuilder::Create())
+ .TypeConstraint("T", DataTypeImpl::AllTensorTypes()),
+ UnfoldTensor);
+
+template <typename T>
+Status LaunchUnfoldTensor(const T* input,
+ T* output,
+ int64_t leading_dims_size,
+ int64_t unfold_dim_size,
+ int64_t tailing_dims_size,
+ int64_t unfold_size,
+ int64_t step_size,
+ concurrency::ThreadPool* tp) {
+ int64_t unfold_dim_size_dst = (unfold_dim_size - unfold_size) / step_size + 1;
+ int64_t N = leading_dims_size * unfold_dim_size_dst * tailing_dims_size * unfold_size;
+
+ int64_t stride_leading_dst = unfold_size * tailing_dims_size * unfold_dim_size_dst;
+ int64_t stride_fold_dim_src = tailing_dims_size * step_size;
+ int64_t stride_leading_src = tailing_dims_size * unfold_dim_size;
+
+ static constexpr double cost = 1.0;
+ concurrency::ThreadPool::TryParallelFor(tp, static_cast<std::ptrdiff_t>(N), cost,
+ [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t i = begin; i != end; ++i) {
+ const int64_t idx = static_cast<int64_t>(i);
+ const int64_t idx_leading = idx / stride_leading_dst;
+ int64_t n = idx % stride_leading_dst;
+ const int64_t stride_fold_dim_dst = tailing_dims_size * unfold_size;
+ const int64_t idx_fold = n / stride_fold_dim_dst;
+ n %= stride_fold_dim_dst;
+ const int64_t idx_tailing = n / unfold_size;
+ const int64_t idx_append = n % unfold_size;
+
+ int64_t idx_src = idx_leading * stride_leading_src +
+ idx_fold * stride_fold_dim_src + idx_tailing +
+ idx_append * tailing_dims_size;
+ output[idx] = input[idx_src];
+ }
+ });
+
+ return Status::OK();
+}
+
+Status UnfoldTensor::Compute(OpKernelContext* ctx) const {
+ const Tensor& input_tensor = *ctx->Input(0);
+ const auto& input_dims = input_tensor.Shape().GetDims();
+ int rank = SafeInt(input_dims.size());
+
+ int dim = SafeInt(HandleNegativeAxis(dim_, rank));
+ ORT_ENFORCE(dim < rank, "input rank:", rank, " is not bigger than attribut specified dim: ", dim);
+ ORT_ENFORCE(input_dims[dim] >= size_, "dimsize:", input_dims[dim], " is less than unfold size:", size_);
+
+ int64_t leading_dims = std::accumulate(input_dims.begin(), input_dims.begin() + static_cast(dim),
+ 1LL, std::multiplies());
+ int64_t tailing_dims = std::accumulate(input_dims.begin() + (static_cast(dim) + 1),
+ input_dims.end(), 1LL, std::multiplies());
+
+ std::vector output_dims(static_cast(rank) + 1, 0);
+ std::copy(input_dims.begin(), input_dims.end(), output_dims.begin());
+ output_dims[dim] = (input_dims[dim] - size_) / step_ + 1;
+ output_dims.back() = size_;
+ TensorShape output_shape(output_dims);
+ Tensor* output_tensor = ctx->Output(0, output_shape);
+
+ auto* tp = ctx->GetOperatorThreadPool();
+
+ Status status;
+ if (input_tensor.IsDataType<float>()) {
+ status = LaunchUnfoldTensor<float>(input_tensor.Data<float>(), output_tensor->MutableData<float>(),
+ leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
+ } else if (input_tensor.IsDataType<double>()) {
+ status = LaunchUnfoldTensor<double>(input_tensor.Data<double>(), output_tensor->MutableData<double>(),
+ leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
+ } else if (input_tensor.IsDataType<int32_t>()) {
+ status = LaunchUnfoldTensor<int32_t>(input_tensor.Data<int32_t>(), output_tensor->MutableData<int32_t>(),
+ leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
+ } else if (input_tensor.IsDataType<int64_t>()) {
+ status = LaunchUnfoldTensor<int64_t>(input_tensor.Data<int64_t>(), output_tensor->MutableData<int64_t>(),
+ leading_dims, input_dims[dim], tailing_dims, size_, step_, tp);
+ } else {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported data type: ", input_tensor.DataType());
+ }
+
+ return status;
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/tensor/unfold.h b/onnxruntime/contrib_ops/cpu/tensor/unfold.h
new file mode 100644
index 0000000000000..6c48d0f67fcc2
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/tensor/unfold.h
@@ -0,0 +1,47 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/framework/op_kernel.h"
+#include <cstdint>
+
+namespace onnxruntime {
+namespace contrib {
+
+using onnxruntime::OpKernelContext;
+using onnxruntime::OpKernelInfo;
+
+template <typename T>
+Status LaunchUnfoldTensor(
+ const T* input,
+ T* output,
+ int64_t leading_dims_size,
+ int64_t unfold_dim_size,
+ int64_t tailing_dims_size,
+ int64_t unfold_size,
+ int64_t step_size,
+ concurrency::ThreadPool* tp);
+
+class UnfoldTensor final : public OpKernel {
+ public:
+ UnfoldTensor(const OpKernelInfo& info) : OpKernel(info) {
+ dim_ = SafeInt<int>(info.GetAttrOrDefault<int64_t>("dim", -1LL));
+ step_ = SafeInt<int>(info.GetAttrOrDefault<int64_t>("step", 1LL));
+ ORT_ENFORCE(step_ > 0, "step must be greater than zero!");
+
+ int64_t temp_size;
+ ORT_ENFORCE(info.GetAttr("size", &temp_size).IsOK());
+ size_ = SafeInt<int>(temp_size);
+ }
+
+ ~UnfoldTensor() = default;
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int dim_;
+ int size_;
+ int step_;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 115db369d2af0..09a4a77780916 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1067,11 +1067,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GridSample, 1,
ONNX_MS_OPERATOR_SET_SCHEMA(
UnfoldTensor, 1,
OpSchema()
- .SetDoc("Returns a tensor which contains all slices of size size from input tensor in the dimension dim. "
- "Step between two slices is given by step. "
- "If sizedim is the size of dimension dim for input tensor, the size of dimension dim in "
- "the returned tensor will be (sizedim - size) / step + 1. "
- "An additional dimension of size size is appended in the returned tensor.")
+ .SetDoc("Returns a tensor which contains all slices of size `size` from input tensor in the dimension `dim`. "
+ "Step between two slices is given by `step`. "
+ "If `sizedim` is the size of dimension `dim` for input tensor, the size of dimension `dim` in "
+ "the returned tensor will be `(sizedim - size) / step + 1`. "
+ "An additional dimension of size `size` is appended in the returned tensor.")
.Attr("dim", "specify the dimension to unfold", AttributeProto::INT, static_cast(-1))
.Attr("size", "specify the size", AttributeProto::INT)
.Attr("step", "specify the step.", AttributeProto::INT, static_cast(1))
@@ -1122,7 +1122,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema()
.SetDoc("Input is cost matrix where each value in input[r][c] is the cost for pass the point (r, c). From current point"
"(r, c), points (r+1, c), (r+1, c+1) or (r, c+1) could be arrived in next move. Given such cost matrix, return "
- "dynamic time wrapping of shape [2, x], where the path made by all points (output[0][t], output[1][t])"
+ "dynamic time warping of shape [2, x], where the path made by all points (output[0][t], output[1][t])"
"have the lowest cost among all paths from (0, 0) to (M-1, N-1).")
.Input(0, "input", "Input cost tensor, it must be 2D tensor of shape M x N, or 1 x M x N", "F")
.Output(0, "output", "Output tensor. shape is [2, x], where max(M, N) <= x < M + N", "I")
diff --git a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc
index ea6f93a273055..4754f3a520694 100644
--- a/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc
+++ b/onnxruntime/test/contrib_ops/dynamic_time_warping_op_test.cc
@@ -11,12 +11,12 @@ using namespace ONNX_NAMESPACE;
namespace onnxruntime {
namespace test {
+TEST(DynamicTimeWarping, simple) {
#ifdef USE_CUDA
-
-TEST(DynamicTimeWarp, simple) {
if (NeedSkipIfCudaArchLowerThan(530)) {
return;
}
+#endif
std::vector<float> X = {
3.0f,
@@ -113,11 +113,12 @@ TEST(DynamicTimeWarp, simple) {
tester.AddOutput("output", {2, 12}, Y);
std::vector> execution_providers;
+#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
+#endif
+ execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
-#endif
-
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/contrib_ops/tensor_op_test.cc b/onnxruntime/test/contrib_ops/tensor_op_test.cc
index 81c8641f450f6..bc2ff5f4f724d 100644
--- a/onnxruntime/test/contrib_ops/tensor_op_test.cc
+++ b/onnxruntime/test/contrib_ops/tensor_op_test.cc
@@ -205,12 +205,12 @@ TEST(MVNContribOpTest, MeanVarianceNormalizationCPUTest_Version1_TO_8) {
MeanVarianceNormalizationPerChannel(false, true);
}
-#ifdef USE_CUDA
-
TEST(UnfoldTensorOpTest, LastDim) {
+#ifdef USE_CUDA
if (NeedSkipIfCudaArchLowerThan(530)) {
return;
}
+#endif
std::vector<float> X = {
1.0f, 2.0f, 3.0f, 4.0f,
@@ -229,7 +229,10 @@ TEST(UnfoldTensorOpTest, LastDim) {
tester.AddOutput("output", {3, 2, 3}, output);
std::vector> execution_providers;
+#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
+#endif
+ execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
@@ -238,13 +241,13 @@ TEST(UnfoldTensorOpTest, NormalDim) {
return;
}
- std::vector<float> X = {
+ std::vector<int64_t> X = {
1, 2, 3, 4, 2, 2, 3, 4, 3, 2, 3, 4,
4, 6, 7, 8, 5, 6, 7, 8, 6, 6, 7, 8,
6, 7, 8, 9, 7, 7, 8, 9, 8, 7, 8, 9,
9, 7, 8, 9, 10, 7, 8, 9, 11, 7, 8, 9};
- std::vector<float> output = {
+ std::vector<int64_t> output = {
1, 2, 3,
2, 2, 2,
3, 3, 3,
@@ -269,15 +272,16 @@ TEST(UnfoldTensorOpTest, NormalDim) {
tester.AddAttribute("dim", 1LL);
tester.AddAttribute("size", 3LL);
tester.AddAttribute("step", 2LL);
- tester.AddInput("input", {2, 6, 4}, X);
- tester.AddOutput("output", {2, 2, 4, 3}, output);
+ tester.AddInput("input", {2, 6, 4}, X);
+ tester.AddOutput("output", {2, 2, 4, 3}, output);
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+#ifdef USE_CUDA
execution_providers.push_back(DefaultCudaExecutionProvider());
+#endif
+ execution_providers.push_back(DefaultCpuExecutionProvider());
tester.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
-#endif
-
} // namespace test
} // namespace onnxruntime