diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 16df788c284ee..3c6771456b12f 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -80,7 +80,8 @@ Do not modify directly.*
|Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)|
-|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**
or
*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
+|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**
or
*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
+|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)|
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
index 4553e7ee18913..1390f60243174 100644
--- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
+++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc
@@ -823,7 +823,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual);
// Opset 17
-class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT);
+class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow);
@@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh
// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
@@ -2217,7 +2218,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// Opset 17
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -2403,6 +2404,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {
// Opset 20
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc
index 8634e393b43d0..15bf633579e5f 100644
--- a/onnxruntime/core/providers/cpu/signal/dft.cc
+++ b/onnxruntime/core/providers/cpu/signal/dft.cc
@@ -19,7 +19,15 @@
namespace onnxruntime {
-ONNX_CPU_OPERATOR_KERNEL(DFT, 17,
+ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
+ DFT,
+ 17, 19,
+ KernelDefBuilder()
+ .TypeConstraint("T1", BuildKernelDefConstraints())
+ .TypeConstraint("T2", BuildKernelDefConstraints()),
+ DFT);
+
+ONNX_CPU_OPERATOR_KERNEL(DFT, 20,
KernelDefBuilder()
.TypeConstraint("T1", BuildKernelDefConstraints())
.TypeConstraint("T2", BuildKernelDefConstraints()),
@@ -442,7 +450,13 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
}
Status DFT::Compute(OpKernelContext* ctx) const {
- ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis_, is_onesided_, is_inverse_));
+ int64_t axis = axis_;
+ if (opset_ >= 20 && ctx->InputCount() >= 3) {
+ const Tensor* axes_tensor = ctx->Input(2);
+ axis = axes_tensor->Data()[0];
+ }
+
+ ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis, is_onesided_, is_inverse_));
return Status::OK();
}
diff --git a/onnxruntime/core/providers/cpu/signal/dft.h b/onnxruntime/core/providers/cpu/signal/dft.h
index 71cac52e37e8f..967d4ec15524b 100644
--- a/onnxruntime/core/providers/cpu/signal/dft.h
+++ b/onnxruntime/core/providers/cpu/signal/dft.h
@@ -7,6 +7,7 @@
namespace onnxruntime {
class DFT final : public OpKernel {
+ int opset_;
bool is_onesided_ = true;
int64_t axis_ = 0;
bool is_inverse_ = false;
@@ -14,7 +15,11 @@ class DFT final : public OpKernel {
public:
explicit DFT(const OpKernelInfo& info) : OpKernel(info) {
is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 0));
- axis_ = info.GetAttrOrDefault("axis", 1);
+ opset_ = info.node().SinceVersion();
+ if (opset_ < 20)
+ axis_ = info.GetAttrOrDefault("axis", 1);
+ else
+ axis_ = -2; // default axis of DFT(20)
is_inverse_ = info.GetAttrOrDefault("inverse", 0);
}
Status Compute(OpKernelContext* ctx) const override;
diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
index 3d4324189d463..54d725defe5ee 100644
--- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
@@ -16,9 +16,10 @@ namespace onnxruntime {
namespace test {
static constexpr int kMinOpsetVersion = 17;
+static constexpr int kOpsetVersion20 = 20;
-static void TestNaiveDFTFloat(bool onesided) {
- OpTester test("DFT", kMinOpsetVersion);
+static void TestNaiveDFTFloat(bool onesided, int since_version) {
+ OpTester test("DFT", since_version);
vector shape = {1, 5, 1};
vector output_shape = {1, 5, 2};
@@ -37,8 +38,8 @@ static void TestNaiveDFTFloat(bool onesided) {
test.Run();
}
-static void TestRadix2DFTFloat(bool onesided) {
- OpTester test("DFT", kMinOpsetVersion);
+static void TestRadix2DFTFloat(bool onesided, int since_version) {
+ OpTester test("DFT", since_version);
vector shape = {1, 8, 1};
vector output_shape = {1, 8, 2};
@@ -57,20 +58,8 @@ static void TestRadix2DFTFloat(bool onesided) {
test.Run();
}
-TEST(SignalOpsTest, DFTFloat_naive) {
- TestNaiveDFTFloat(false);
-}
-
-TEST(SignalOpsTest, DFTFloat_naive_onesided) {
- TestNaiveDFTFloat(true);
-}
-
-TEST(SignalOpsTest, DFTFloat_radix2) { TestRadix2DFTFloat(false); }
-
-TEST(SignalOpsTest, DFTFloat_radix2_onesided) { TestRadix2DFTFloat(true); }
-
-TEST(SignalOpsTest, DFTFloat_inverse) {
- OpTester test("DFT", kMinOpsetVersion);
+static void TestInverseFloat(int since_version) {
+ OpTester test("DFT", since_version);
vector shape = {1, 5, 2};
vector input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f,
@@ -83,12 +72,44 @@ TEST(SignalOpsTest, DFTFloat_inverse) {
test.Run();
}
+TEST(SignalOpsTest, DFT17_Float_naive) {
+ TestNaiveDFTFloat(false, kMinOpsetVersion);
+}
+
+TEST(SignalOpsTest, DFT20_Float_naive) {
+ TestNaiveDFTFloat(false, kOpsetVersion20);
+}
+
+TEST(SignalOpsTest, DFT17_Float_naive_onesided) {
+ TestNaiveDFTFloat(true, kMinOpsetVersion);
+}
+
+TEST(SignalOpsTest, DFT20_Float_naive_onesided) {
+ TestNaiveDFTFloat(true, kOpsetVersion20);
+}
+
+TEST(SignalOpsTest, DFT17_Float_radix2) { TestRadix2DFTFloat(false, kMinOpsetVersion); }
+
+TEST(SignalOpsTest, DFT20_Float_radix2) { TestRadix2DFTFloat(false, kOpsetVersion20); }
+
+TEST(SignalOpsTest, DFT17_Float_radix2_onesided) { TestRadix2DFTFloat(true, kMinOpsetVersion); }
+
+TEST(SignalOpsTest, DFT20_Float_radix2_onesided) { TestRadix2DFTFloat(true, kOpsetVersion20); }
+
+TEST(SignalOpsTest, DFT17_Float_inverse) {
+ TestInverseFloat(kMinOpsetVersion);
+}
+
+TEST(SignalOpsTest, DFT20_Float_inverse) {
+ TestInverseFloat(kOpsetVersion20);
+}
+
// Tests that FFT(FFT(x), inverse=true) == x
-static void TestDFTInvertible(bool complex) {
+static void TestDFTInvertible(bool complex, int since_version) {
// TODO: test dft_length
class DFTInvertibleTester : public OpTester {
public:
- DFTInvertibleTester(int64_t axis) : OpTester("DFT", kMinOpsetVersion), axis_(axis) {}
+ DFTInvertibleTester(int64_t axis, int since_version) : OpTester("DFT", since_version), axis_(axis) {}
protected:
void AddNodes(Graph& graph, vector& graph_inputs, vector& graph_outputs,
@@ -98,11 +119,20 @@ static void TestDFTInvertible(bool complex) {
// call base implementation to add the DFT node.
OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs);
- OpTester::AddAttribute("axis", axis_);
+ if (this->Opset() < kOpsetVersion20) {
+ OpTester::AddAttribute("axis", axis_);
+ } else {
+ assert(intermediate_outputs.size() == 1);
+ assert(graph_inputs.size() == 3);
+ intermediate_outputs.push_back(graph_inputs[1]);
+ intermediate_outputs.push_back(graph_inputs[2]);
+ }
Node& inverse = graph.AddNode("inverse", "DFT", "inverse", intermediate_outputs, graph_outputs);
inverse.AddAttribute("inverse", static_cast(true));
- inverse.AddAttribute("axis", axis_);
+ if (this->Opset() < kOpsetVersion20) {
+ inverse.AddAttribute("axis", axis_);
+ }
}
private:
@@ -112,14 +142,21 @@ static void TestDFTInvertible(bool complex) {
RandomValueGenerator random(GetTestRandomSeed());
// TODO(smk2007): Add tests for different dft_length values.
constexpr int64_t num_batches = 2;
- for (int64_t axis = 1; axis < 2; axis += 1) {
+ for (int64_t axis = 0; axis < 2; axis += 1) {
for (int64_t signal_dim1 = 2; signal_dim1 <= 5; signal_dim1 += 1) {
for (int64_t signal_dim2 = 2; signal_dim2 <= 5; signal_dim2 += 1) {
- DFTInvertibleTester test(axis);
+ if (axis == 0 && since_version < kOpsetVersion20)
+ continue;
+ DFTInvertibleTester test(axis, since_version);
vector input_shape{num_batches, signal_dim1, signal_dim2, 1 + (complex ? 1 : 0)};
vector input_data = random.Uniform(input_shape, -100.f, 100.f);
test.AddInput("input", input_shape, input_data);
+ if (since_version >= kOpsetVersion20) {
+ test.AddInput("", {0}, {});
+ test.AddInput("axis", {1}, {axis});
+ }
+
vector output_shape(input_shape);
vector* output_data_p;
vector output_data;
@@ -141,12 +178,20 @@ static void TestDFTInvertible(bool complex) {
}
}
-TEST(SignalOpsTest, DFT_invertible_real) {
- TestDFTInvertible(false);
+TEST(SignalOpsTest, DFT17_invertible_real) {
+ TestDFTInvertible(false, kMinOpsetVersion);
+}
+
+TEST(SignalOpsTest, DFT20_invertible_real) {
+ TestDFTInvertible(false, kOpsetVersion20);
+}
+
+TEST(SignalOpsTest, DFT17_invertible_complex) {
+ TestDFTInvertible(true, kMinOpsetVersion);
}
-TEST(SignalOpsTest, DFT_invertible_complex) {
- TestDFTInvertible(true);
+TEST(SignalOpsTest, DFT20_invertible_complex) {
+ TestDFTInvertible(true, kOpsetVersion20);
}
TEST(SignalOpsTest, STFTFloat) {
diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
index bfdc0b1d26953..49d8d7150a117 100644
--- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
+++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc
@@ -262,9 +262,6 @@
"^test_string_split_empty_tensor",
"^test_string_split_maxsplit",
"^test_string_split_no_delimiter",
- "^test_dft_axis",
- "^test_dft",
- "^test_dft_inverse",
"^test_reduce_max_bool_inputs",
"^test_reduce_min_bool_inputs",
"^test_reduce_min_empty_set",