Skip to content

Commit

Permalink
Implement dft(20) (#17821)
Browse files Browse the repository at this point in the history
### Description
dft is updated in opset20. implement it in ort



### Motivation and Context
this is for ort 1.17.0 release

Fixes #17723

---------

Signed-off-by: Liqun Fu <[email protected]>
  • Loading branch information
liqunfu authored Dec 19, 2023
1 parent 5f00bc9 commit 32fcf73
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 37 deletions.
3 changes: 2 additions & 1 deletion docs/OperatorKernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ Do not modify directly.*
|Crop|*in* input:**T**<br> *out* output:**T**|1+|**T** = tensor(float)|
|CumSum|*in* x:**T**<br> *in* axis:**T2**<br> *out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|17+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DFT|*in* input:**T1**<br> *in* dft_length:**T2**<br> *in* axis:**tensor(int64)**<br> *out* output:**T1**<br><br>or<br><br>*in* input:**T1**<br> *in* dft_length:**T2**<br> *out* output:**T1**|20+|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|||[17, 19]|**T1** = tensor(double), tensor(float)<br/> **T2** = tensor(int32), tensor(int64)|
|DepthToSpace|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float)|
|||[11, 12]|**T** = tensor(double), tensor(float)|
|||[1, 10]|**T** = tensor(double), tensor(float)|
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 16, int64_t, LessOrEqual);

// Opset 17
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow);
Expand Down Expand Up @@ -960,6 +960,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh

// Opset 20
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid);
Expand Down Expand Up @@ -2217,7 +2218,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 17
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, BlackmanWindow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, DFT)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, 19, DFT)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HammingWindow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, HannWindow)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 17, MelWeightMatrix)>,
Expand Down Expand Up @@ -2403,6 +2404,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) {

// Opset 20
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, DFT)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid)>,
Expand Down
18 changes: 16 additions & 2 deletions onnxruntime/core/providers/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,15 @@

namespace onnxruntime {

ONNX_CPU_OPERATOR_KERNEL(DFT, 17,
ONNX_CPU_OPERATOR_VERSIONED_KERNEL(
DFT,
17, 19,
KernelDefBuilder()
.TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
.TypeConstraint("T2", BuildKernelDefConstraints<int32_t, int64_t>()),
DFT);

ONNX_CPU_OPERATOR_KERNEL(DFT, 20,
KernelDefBuilder()
.TypeConstraint("T1", BuildKernelDefConstraints<float, double>())
.TypeConstraint("T2", BuildKernelDefConstraints<int32_t, int64_t>()),
Expand Down Expand Up @@ -442,7 +450,13 @@ static Status discrete_fourier_transform(OpKernelContext* ctx, int64_t axis, boo
}

Status DFT::Compute(OpKernelContext* ctx) const {
ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis_, is_onesided_, is_inverse_));
int64_t axis = axis_;
if (opset_ >= 20 && ctx->InputCount() >= 3) {
const Tensor* axes_tensor = ctx->Input<Tensor>(2);
axis = axes_tensor->Data<int64_t>()[0];
}

ORT_RETURN_IF_ERROR(discrete_fourier_transform(ctx, axis, is_onesided_, is_inverse_));
return Status::OK();
}

Expand Down
7 changes: 6 additions & 1 deletion onnxruntime/core/providers/cpu/signal/dft.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,19 @@
namespace onnxruntime {

class DFT final : public OpKernel {
int opset_;
bool is_onesided_ = true;
int64_t axis_ = 0;
bool is_inverse_ = false;

public:
explicit DFT(const OpKernelInfo& info) : OpKernel(info) {
is_onesided_ = static_cast<bool>(info.GetAttrOrDefault<int64_t>("onesided", 0));
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
opset_ = info.node().SinceVersion();
if (opset_ < 20)
axis_ = info.GetAttrOrDefault<int64_t>("axis", 1);
else
axis_ = -2; // default axis of DFT(20)
is_inverse_ = info.GetAttrOrDefault<int64_t>("inverse", 0);
}
Status Compute(OpKernelContext* ctx) const override;
Expand Down
101 changes: 73 additions & 28 deletions onnxruntime/test/providers/cpu/signal/signal_ops_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@ namespace onnxruntime {
namespace test {

static constexpr int kMinOpsetVersion = 17;
static constexpr int kOpsetVersion20 = 20;

static void TestNaiveDFTFloat(bool onesided) {
OpTester test("DFT", kMinOpsetVersion);
static void TestNaiveDFTFloat(bool onesided, int since_version) {
OpTester test("DFT", since_version);

vector<int64_t> shape = {1, 5, 1};
vector<int64_t> output_shape = {1, 5, 2};
Expand All @@ -37,8 +38,8 @@ static void TestNaiveDFTFloat(bool onesided) {
test.Run();
}

static void TestRadix2DFTFloat(bool onesided) {
OpTester test("DFT", kMinOpsetVersion);
static void TestRadix2DFTFloat(bool onesided, int since_version) {
OpTester test("DFT", since_version);

vector<int64_t> shape = {1, 8, 1};
vector<int64_t> output_shape = {1, 8, 2};
Expand All @@ -57,20 +58,8 @@ static void TestRadix2DFTFloat(bool onesided) {
test.Run();
}

TEST(SignalOpsTest, DFTFloat_naive) {
TestNaiveDFTFloat(false);
}

TEST(SignalOpsTest, DFTFloat_naive_onesided) {
TestNaiveDFTFloat(true);
}

TEST(SignalOpsTest, DFTFloat_radix2) { TestRadix2DFTFloat(false); }

TEST(SignalOpsTest, DFTFloat_radix2_onesided) { TestRadix2DFTFloat(true); }

TEST(SignalOpsTest, DFTFloat_inverse) {
OpTester test("DFT", kMinOpsetVersion);
static void TestInverseFloat(int since_version) {
OpTester test("DFT", since_version);

vector<int64_t> shape = {1, 5, 2};
vector<float> input = {15.000000f, 0.0000000f, -2.499999f, 3.4409550f, -2.500000f,
Expand All @@ -83,12 +72,44 @@ TEST(SignalOpsTest, DFTFloat_inverse) {
test.Run();
}

TEST(SignalOpsTest, DFT17_Float_naive) {
TestNaiveDFTFloat(false, kMinOpsetVersion);
}

TEST(SignalOpsTest, DFT20_Float_naive) {
TestNaiveDFTFloat(false, kOpsetVersion20);
}

TEST(SignalOpsTest, DFT17_Float_naive_onesided) {
TestNaiveDFTFloat(true, kMinOpsetVersion);
}

TEST(SignalOpsTest, DFT20_Float_naive_onesided) {
TestNaiveDFTFloat(true, kOpsetVersion20);
}

TEST(SignalOpsTest, DFT17_Float_radix2) { TestRadix2DFTFloat(false, kMinOpsetVersion); }

TEST(SignalOpsTest, DFT20_Float_radix2) { TestRadix2DFTFloat(false, kOpsetVersion20); }

TEST(SignalOpsTest, DFT17_Float_radix2_onesided) { TestRadix2DFTFloat(true, kMinOpsetVersion); }

TEST(SignalOpsTest, DFT20_Float_radix2_onesided) { TestRadix2DFTFloat(true, kOpsetVersion20); }

TEST(SignalOpsTest, DFT17_Float_inverse) {
TestInverseFloat(kMinOpsetVersion);
}

TEST(SignalOpsTest, DFT20_Float_inverse) {
TestInverseFloat(kOpsetVersion20);
}

// Tests that FFT(FFT(x), inverse=true) == x
static void TestDFTInvertible(bool complex) {
static void TestDFTInvertible(bool complex, int since_version) {
// TODO: test dft_length
class DFTInvertibleTester : public OpTester {
public:
DFTInvertibleTester(int64_t axis) : OpTester("DFT", kMinOpsetVersion), axis_(axis) {}
DFTInvertibleTester(int64_t axis, int since_version) : OpTester("DFT", since_version), axis_(axis) {}

protected:
void AddNodes(Graph& graph, vector<NodeArg*>& graph_inputs, vector<NodeArg*>& graph_outputs,
Expand All @@ -98,11 +119,20 @@ static void TestDFTInvertible(bool complex) {

// call base implementation to add the DFT node.
OpTester::AddNodes(graph, graph_inputs, intermediate_outputs, add_attribute_funcs);
OpTester::AddAttribute("axis", axis_);
if (this->Opset() < kOpsetVersion20) {
OpTester::AddAttribute("axis", axis_);
} else {
assert(intermediate_outputs.size() == 1);
assert(graph_inputs.size() == 3);
intermediate_outputs.push_back(graph_inputs[1]);
intermediate_outputs.push_back(graph_inputs[2]);
}

Node& inverse = graph.AddNode("inverse", "DFT", "inverse", intermediate_outputs, graph_outputs);
inverse.AddAttribute("inverse", static_cast<int64_t>(true));
inverse.AddAttribute("axis", axis_);
if (this->Opset() < kOpsetVersion20) {
inverse.AddAttribute("axis", axis_);
}
}

private:
Expand All @@ -112,14 +142,21 @@ static void TestDFTInvertible(bool complex) {
RandomValueGenerator random(GetTestRandomSeed());
// TODO(smk2007): Add tests for different dft_length values.
constexpr int64_t num_batches = 2;
for (int64_t axis = 1; axis < 2; axis += 1) {
for (int64_t axis = 0; axis < 2; axis += 1) {
for (int64_t signal_dim1 = 2; signal_dim1 <= 5; signal_dim1 += 1) {
for (int64_t signal_dim2 = 2; signal_dim2 <= 5; signal_dim2 += 1) {
DFTInvertibleTester test(axis);
if (axis == 0 && since_version < kOpsetVersion20)
continue;
DFTInvertibleTester test(axis, since_version);
vector<int64_t> input_shape{num_batches, signal_dim1, signal_dim2, 1 + (complex ? 1 : 0)};
vector<float> input_data = random.Uniform<float>(input_shape, -100.f, 100.f);
test.AddInput("input", input_shape, input_data);

if (since_version >= kOpsetVersion20) {
test.AddInput<int64_t>("", {0}, {});
test.AddInput<int64_t>("axis", {1}, {axis});
}

vector<int64_t> output_shape(input_shape);
vector<float>* output_data_p;
vector<float> output_data;
Expand All @@ -141,12 +178,20 @@ static void TestDFTInvertible(bool complex) {
}
}

TEST(SignalOpsTest, DFT_invertible_real) {
TestDFTInvertible(false);
TEST(SignalOpsTest, DFT17_invertible_real) {
TestDFTInvertible(false, kMinOpsetVersion);
}

TEST(SignalOpsTest, DFT20_invertible_real) {
TestDFTInvertible(false, kOpsetVersion20);
}

TEST(SignalOpsTest, DFT17_invertible_complex) {
TestDFTInvertible(true, kMinOpsetVersion);
}

TEST(SignalOpsTest, DFT_invertible_complex) {
TestDFTInvertible(true);
TEST(SignalOpsTest, DFT20_invertible_complex) {
TestDFTInvertible(true, kOpsetVersion20);
}

TEST(SignalOpsTest, STFTFloat) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,6 @@
"^test_string_split_empty_tensor",
"^test_string_split_maxsplit",
"^test_string_split_no_delimiter",
"^test_dft_axis",
"^test_dft",
"^test_dft_inverse",
"^test_reduce_max_bool_inputs",
"^test_reduce_min_bool_inputs",
"^test_reduce_min_empty_set",
Expand Down

0 comments on commit 32fcf73

Please sign in to comment.