Skip to content

Commit

Permalink
Implement STFT Decomposition transformer (#19725)
Browse files Browse the repository at this point in the history
Implement STFT Decomposition transformer.

Certain hardware does not support DXIL, and therefore existing operator
should be mapped to hardware supported functions.
Optimized convolution can be used to implement STFT.

---------

Co-authored-by: Sheil Kumar <[email protected]>
  • Loading branch information
smk2007 and Sheil Kumar authored Mar 8, 2024
1 parent 069d2d6 commit 7deee94
Show file tree
Hide file tree
Showing 8 changed files with 864 additions and 434 deletions.
381 changes: 381 additions & 0 deletions onnxruntime/core/optimizer/stft_decomposition.cc

Large diffs are not rendered by default.

30 changes: 30 additions & 0 deletions onnxruntime/core/optimizer/stft_decomposition.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/optimizer/graph_transformer.h"
#include "core/framework/ort_value.h"
#include <memory>
#include "core/framework/execution_provider.h"

namespace onnxruntime {

/**
@class STFTDecomposition
Transformer that traverses the graph top-down and decomposes
STFT into convolution.
*/
class STFTDecomposition : public GraphTransformer {
public:
/*! STFT decomposition .
\param execution_provider Execution provider instance to execute constant folding.
*/
STFTDecomposition(const InlinedHashSet<std::string_view>& compatible_execution_providers = {}) noexcept;

private:
Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override;
};

} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/signal/dft.cc
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ static Status short_time_fourier_transform(OpKernelContext* ctx, bool is_oneside

// Calculate the window size with preference to the window input.
const auto window_size = window ? window->Shape()[0] : frame_length;
ORT_ENFORCE(window_size < signal_size, "Ensure that the dft size is smaller than the signal.");
ORT_ENFORCE(window_size <= signal_size, "Ensure that the dft size is smaller than the signal.");

// Calculate the number of dfts to run
const auto n_dfts =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -771,8 +771,14 @@ namespace Dml
!native16BitShaderOpsSupported &&
IsCustomOpShader(node))
{
nodeContainsSupportedDataTypes = false;
return;
// STFT is a special case since it has a dml ep registered
// graph transformation that will decompose fp16 STFT into convolution
// and so it is OK to register for fp16.
if (strcmp("STFT", node.OpType().c_str()) != 0)
{
nodeContainsSupportedDataTypes = false;
return;
}
}

// Allow nodeArgs that are SequenceTensor when they are actually implemented by CPU Kernels.
Expand Down
Loading

0 comments on commit 7deee94

Please sign in to comment.