From 32fcf73740aba943d4dc3a436838428c6b20163d Mon Sep 17 00:00:00 2001 From: liqun Fu Date: Tue, 19 Dec 2023 10:42:54 -0800 Subject: [PATCH] Implement dft(20) (#17821) ### Description dft is updated in opset20. implement it in ort ### Motivation and Context this is for ort 1.17.0 release Fixes #17723 --------- Signed-off-by: Liqun Fu --- docs/OperatorKernels.md | 3 +- .../providers/cpu/cpu_execution_provider.cc | 6 +- onnxruntime/core/providers/cpu/signal/dft.cc | 18 +++- onnxruntime/core/providers/cpu/signal/dft.h | 7 +- .../providers/cpu/signal/signal_ops_test.cc | 101 +++++++++++++----- .../onnx_backend_test_series_filters.jsonc | 3 - 6 files changed, 101 insertions(+), 37 deletions(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index edf249a816923..1ce9b3254d91f 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -80,7 +80,8 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|||[17, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 4553e7ee18913..1390f60243174 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -823,7 +823,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual); // Opset 17 -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow); @@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); @@ -2217,7 +2218,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 17 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -2403,6 +2404,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 20 BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cpu/signal/dft.cc b/onnxruntime/core/providers/cpu/signal/dft.cc index 8634e393b43d0..15bf633579e5f 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.cc +++ b/onnxruntime/core/providers/cpu/signal/dft.cc @@ -19,7 +19,15 @@ namespace onnxruntime { -ONNX_CPU_OPERATOR_KERNEL(DFT, 17, +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( + DFT, + 17, 19, + KernelDefBuilder() + .TypeConstraint("T1", BuildKernelDefConstraints()) + .TypeConstraint("T2", BuildKernelDefConstraints()), + DFT); + +ONNX_CPU_OPERATOR_KERNEL(DFT, 20, KernelDefBuilder() .TypeConstraint("T1", BuildKernelDefConstraints()) .TypeConstraint("T2", BuildKernelDefConstraints()), @@ -442,7 +450,13 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo } Status DFT::Compute(OpKernelContext* ctx) const { - ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis_, is_onesided_, is_inverse_)); + int64_t axis = axis_; + if (opset_ >= 20 && ctx->InputCount() >= 3) { + const Tensor* axes_tensor = ctx->Input(2); + axis = axes_tensor->Data()[0]; + } + + ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis, is_onesided_, is_inverse_)); return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/signal/dft.h b/onnxruntime/core/providers/cpu/signal/dft.h index 71cac52e37e8f..967d4ec15524b 100644 --- a/onnxruntime/core/providers/cpu/signal/dft.h +++ b/onnxruntime/core/providers/cpu/signal/dft.h @@ -7,6 +7,7 @@ namespace onnxruntime { class DFT final : public OpKernel { + int opset_; bool is_onesided_ = true; int64_t axis_ = 0; bool is_inverse_ = false; @@ -14,7 +15,11 @@ class DFT final : public OpKernel { public: explicit DFT(const OpKernelInfo& info) : OpKernel(info) { is_onesided_ = static_cast(info.GetAttrOrDefault("onesided", 0)); - axis_ = info.GetAttrOrDefault("axis", 1); + opset_ = info.node().SinceVersion(); + if (opset_ < 20) + axis_ = info.GetAttrOrDefault("axis", 1); + else + axis_ = -2; // default axis of DFT(20) is_inverse_ = info.GetAttrOrDefault("inverse", 0); } Status Compute(OpKernelContext* ctx) const override; diff --git a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc index 3d4324189d463..54d725defe5ee 100644 --- a/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc +++ b/onnxruntime/test/providers/cpu/signal/signal_ops_test.cc @@ -16,9 +16,10 @@ namespace onnxruntime { namespace test { static constexpr int kMinOpsetVersion = 17; +static constexpr int kOpsetVersion20 = 20; -static void TestNaiveDFTFloat(bool onesided) { - OpTester test("DFT", kMinOpsetVersion); +static void TestNaiveDFTFloat(bool onesided, int since_version) { + OpTester test("DFT", since_version); vector shape = {1, 5, 1}; vector output_shape = {1, 5, 2}; @@ -37,8 +38,8 @@ static void TestNaiveDFTFloat(bool onesided) { test.Run(); } -static void TestRadix2DFTFloat(bool onesided) { - OpTester test("DFT", kMinOpsetVersion); +static void TestRadix2DFTFloat(bool onesided, int since_version) { + OpTester test("DFT", since_version); vector shape = {1, 8, 1}; vector output_shape = {1, 8, 2}; @@ -57,20 +58,8 @@ static void TestRadix2DFTFloat(bool onesided) { test.Run(); } -TEST(SignalOpsTest, DFTFloat_naive) { - TestNaiveDFTFloat(false); -} - -TEST(SignalOpsTest, DFTFloat_naive_onesided) { - TestNaiveDFTFloat(true); -} - -TEST(SignalOpsTest, DFTFloat_radix2) { TestRadix2DFTFloat(false); } - -TEST(SignalOpsTest, DFTFloat_radix2_onesided) { TestRadix2DFTFloat(true); } - -TEST(SignalOpsTest, DFTFloat_inverse) { - OpTester test("DFT", kMinOpsetVersion); +static void TestInverseFloat(int since_version) { + OpTester test("DFT", since_version); vector shape = {1, 5, 2}; vector input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f, @@ -83,12 +72,44 @@ TEST(SignalOpsTest, DFTFloat_inverse) { test.Run(); } +TEST(SignalOpsTest, DFT17_Float_naive) { + TestNaiveDFTFloat(false, kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_Float_naive) { + TestNaiveDFTFloat(false, kOpsetVersion20); +} + +TEST(SignalOpsTest, DFT17_Float_naive_onesided) { + TestNaiveDFTFloat(true, kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_Float_naive_onesided) { + TestNaiveDFTFloat(true, kOpsetVersion20); +} + +TEST(SignalOpsTest, DFT17_Float_radix2) { TestRadix2DFTFloat(false, kMinOpsetVersion); } + +TEST(SignalOpsTest, DFT20_Float_radix2) { TestRadix2DFTFloat(false, kOpsetVersion20); } + +TEST(SignalOpsTest, DFT17_Float_radix2_onesided) { TestRadix2DFTFloat(true, kMinOpsetVersion); } + +TEST(SignalOpsTest, DFT20_Float_radix2_onesided) { TestRadix2DFTFloat(true, kOpsetVersion20); } + +TEST(SignalOpsTest, DFT17_Float_inverse) { + TestInverseFloat(kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_Float_inverse) { + TestInverseFloat(kOpsetVersion20); +} + // Tests that FFT(FFT(x), inverse=true) == x -static void TestDFTInvertible(bool complex) { +static void TestDFTInvertible(bool complex, int since_version) { // TODO: test dft_length class DFTInvertibleTester : public OpTester { public: - DFTInvertibleTester(int64_t axis) : OpTester("DFT", kMinOpsetVersion), axis_(axis) {} + DFTInvertibleTester(int64_t axis, int since_version) : OpTester("DFT", since_version), axis_(axis) {} protected: void AddNodes(Graph& graph, vector& graph_inputs, vector& graph_outputs, @@ -98,11 +119,20 @@ static void TestDFTInvertible(bool complex) { // call base implementation to add the DFT node. OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs); - OpTester::AddAttribute("axis", axis_); + if (this->Opset() < kOpsetVersion20) { + OpTester::AddAttribute("axis", axis_); + } else { + assert(intermediate_outputs.size() == 1); + assert(graph_inputs.size() == 3); + intermediate_outputs.push_back(graph_inputs[1]); + intermediate_outputs.push_back(graph_inputs[2]); + } Node& inverse = graph.AddNode("inverse", "DFT", "inverse", intermediate_outputs, graph_outputs); inverse.AddAttribute("inverse", static_cast(true)); - inverse.AddAttribute("axis", axis_); + if (this->Opset() < kOpsetVersion20) { + inverse.AddAttribute("axis", axis_); + } } private: @@ -112,14 +142,21 @@ static void TestDFTInvertible(bool complex) { RandomValueGenerator random(GetTestRandomSeed()); // TODO(smk2007): Add tests for different dft_length values. constexpr int64_t num_batches = 2; - for (int64_t axis = 1; axis < 2; axis += 1) { + for (int64_t axis = 0; axis < 2; axis += 1) { for (int64_t signal_dim1 = 2; signal_dim1 <= 5; signal_dim1 += 1) { for (int64_t signal_dim2 = 2; signal_dim2 <= 5; signal_dim2 += 1) { - DFTInvertibleTester test(axis); + if (axis == 0 && since_version < kOpsetVersion20) + continue; + DFTInvertibleTester test(axis, since_version); vector input_shape{num_batches, signal_dim1, signal_dim2, 1 + (complex ? 1 : 0)}; vector input_data = random.Uniform(input_shape, -100.f, 100.f); test.AddInput("input", input_shape, input_data); + if (since_version >= kOpsetVersion20) { + test.AddInput("", {0}, {}); + test.AddInput("axis", {1}, {axis}); + } + vector output_shape(input_shape); vector* output_data_p; vector output_data; @@ -141,12 +178,20 @@ static void TestDFTInvertible(bool complex) { } } -TEST(SignalOpsTest, DFT_invertible_real) { - TestDFTInvertible(false); +TEST(SignalOpsTest, DFT17_invertible_real) { + TestDFTInvertible(false, kMinOpsetVersion); +} + +TEST(SignalOpsTest, DFT20_invertible_real) { + TestDFTInvertible(false, kOpsetVersion20); +} + +TEST(SignalOpsTest, DFT17_invertible_complex) { + TestDFTInvertible(true, kMinOpsetVersion); } -TEST(SignalOpsTest, DFT_invertible_complex) { - TestDFTInvertible(true); +TEST(SignalOpsTest, DFT20_invertible_complex) { + TestDFTInvertible(true, kOpsetVersion20); } TEST(SignalOpsTest, STFTFloat) { diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index bfdc0b1d26953..49d8d7150a117 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -262,9 +262,6 @@ "^test_string_split_empty_tensor", "^test_string_split_maxsplit", "^test_string_split_no_delimiter", - "^test_dft_axis", - "^test_dft", - "^test_dft_inverse", "^test_reduce_max_bool_inputs", "^test_reduce_min_bool_inputs", "^test_reduce_min_empty_set",