[STFT][PT FE] Add support for normalize attr in STFT conversion (#28005)
### Details:
- Add support for the `normalized` attribute of torch `aten::stft` via a post-op subgraph in the PT Frontend (see the sketch below for the attribute's semantics)

### Tickets:
 - 159159
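
For context, PyTorch defines `normalized=True` as scaling the STFT output by `frame_length ** -0.5`, i.e. dividing the plain STFT by `sqrt(n_fft)`. A minimal sketch of that equivalence (the configuration values are illustrative, not taken from this PR):

```python
import torch

# Illustrative configuration; any valid STFT setup behaves the same.
n_fft, hop = 16, 4
x = torch.randn(2, 64)
window = torch.hann_window(n_fft)

common = dict(hop_length=hop, window=window, center=False,
              onesided=True, return_complex=True)
plain = torch.stft(x, n_fft, normalized=False, **common)
norm = torch.stft(x, n_fft, normalized=True, **common)

# normalized=True equals the plain STFT divided by sqrt(n_fft),
# which is what the post-op subgraph added in this commit computes.
assert torch.allclose(norm, plain / n_fft ** 0.5)
```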
mitruska authored Dec 11, 2024
1 parent c4daa25 commit 328feb6
Showing 2 changed files with 19 additions and 14 deletions.
13 changes: 10 additions & 3 deletions src/frontends/pytorch/src/op/stft.cpp
@@ -10,6 +10,7 @@
 #include "openvino/op/convert_like.hpp"
 #include "openvino/op/divide.hpp"
 #include "openvino/op/shape_of.hpp"
+#include "openvino/op/sqrt.hpp"
 #include "openvino/op/unsqueeze.hpp"
 #include "utils.hpp"

@@ -66,8 +67,6 @@ OutputVector translate_stft(const NodeContext& context) {
     if (!context.input_is_none(5)) {
         normalized = context.const_input<bool>(5);
     }
-    PYTORCH_OP_CONVERSION_CHECK(!normalized,
-                                "aten::stft conversion is currently supported with normalized=False only.");
 
     bool onesided = true;
     if (!context.input_is_none(6)) {
@@ -85,7 +84,15 @@ OutputVector translate_stft(const NodeContext& context) {
     // Perform STFT
     constexpr bool transpose_frames = true;
     auto stft = context.mark_node(std::make_shared<v15::STFT>(input, window, n_fft, hop_length, transpose_frames));
-    return {stft};
+
+    if (normalized) {
+        const auto nfft_convert = context.mark_node(std::make_shared<v1::ConvertLike>(n_fft, stft));
+        const auto divisor = context.mark_node(std::make_shared<v0::Sqrt>(nfft_convert));
+        const auto norm_stft = context.mark_node(std::make_shared<v1::Divide>(stft, divisor));
+        return {norm_stft};
+    } else {
+        return {stft};
+    }
 };
 }  // namespace op
 }  // namespace pytorch
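
The new branch builds `Divide(stft, Sqrt(ConvertLike(n_fft, stft)))` as a post-op on the STFT output. A NumPy mirror of that subgraph for intuition (the helper name `normalize_stft` is hypothetical, not part of this commit):

```python
import numpy as np

def normalize_stft(stft_out: np.ndarray, n_fft: int) -> np.ndarray:
    # ConvertLike: bring n_fft to the STFT output's element type.
    n = np.asarray(n_fft, dtype=stft_out.dtype)
    # Sqrt, then Divide: scale the whole spectrogram by 1/sqrt(n_fft).
    return stft_out / np.sqrt(n)

# The shape here is illustrative; the normalization is elementwise either way.
frames = np.random.rand(1, 9, 10, 2).astype(np.float32)
assert np.allclose(normalize_stft(frames, 16), frames / 4.0)  # sqrt(16) == 4
```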
20 changes: 9 additions & 11 deletions tests/layer_tests/pytorch_tests/test_stft.py
@@ -24,16 +24,17 @@ def _prepare_input(self, win_length, signal_shape, rand_data=False, out_dtype="f
 
         return (signal, window.astype(out_dtype))
 
-    def create_model(self, n_fft, hop_length, win_length):
+    def create_model(self, n_fft, hop_length, win_length, normalized):
         import torch
 
         class aten_stft(torch.nn.Module):
 
-            def __init__(self, n_fft, hop_length, win_length):
+            def __init__(self, n_fft, hop_length, win_length, normalized):
                 super(aten_stft, self).__init__()
                 self.n_fft = n_fft
                 self.hop_length = hop_length
                 self.win_length = win_length
+                self.normalized = normalized
 
             def forward(self, x, window):
                 return torch.stft(
@@ -44,14 +45,14 @@ def forward(self, x, window):
                     window=window,
                     center=False,
                     pad_mode="reflect",
-                    normalized=False,
+                    normalized=self.normalized,
                     onesided=True,
                     return_complex=False,
                 )
 
         ref_net = None
 
-        return aten_stft(n_fft, hop_length, win_length), ref_net, "aten::stft"
+        return aten_stft(n_fft, hop_length, win_length, normalized), ref_net, "aten::stft"
 
     @pytest.mark.nightly
     @pytest.mark.precommit
@@ -64,10 +65,11 @@ def forward(self, x, window):
         [24, 32, 20],
         [128, 128, 128],
     ])
-    def test_stft(self, n_fft, hop_length, window_size, signal_shape, ie_device, precision, ir_version, trace_model):
+    @pytest.mark.parametrize(("normalized"), [True, False])
+    def test_stft(self, n_fft, hop_length, window_size, signal_shape, normalized, ie_device, precision, ir_version, trace_model):
         if ie_device == "GPU":
             pytest.xfail(reason="STFT op is not supported on GPU yet")
-        self._test(*self.create_model(n_fft, hop_length, window_size), ie_device, precision,
+        self._test(*self.create_model(n_fft, hop_length, window_size, normalized), ie_device, precision,
                    ir_version, kwargs_to_prepare_input={"win_length": window_size, "signal_shape": signal_shape}, trace_model=trace_model)
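
The new `@pytest.mark.parametrize(("normalized"), [True, False])` marker stacks with the existing shape/hop parametrizations, so every case above now runs with both attribute values. A standalone sketch of that stacking behavior (an illustrative test, not from this repo):

```python
import pytest

@pytest.mark.parametrize("normalized", [True, False])
@pytest.mark.parametrize("n_fft", [16, 24, 128])
def test_param_matrix(n_fft, normalized):
    # Stacked parametrize markers take the cross product:
    # 3 n_fft values x 2 normalized values = 6 generated cases.
    assert n_fft in (16, 24, 128) and normalized in (True, False)
```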


@@ -125,8 +127,8 @@ def forward(self, x):
         [16, None, 16, False, "reflect", False, True, False],  # hop_length None
         [16, None, None, False, "reflect", False, True, False],  # hop & win length None
         [16, 4, None, False, "reflect", False, True, False],  # win_length None
-        # Unsupported cases:
         [16, 4, 16, False, "reflect", True, True, False],  # normalized True
+        # Unsupported cases:
         [16, 4, 16, False, "reflect", False, False, False],  # onesided False
         [16, 4, 16, False, "reflect", False, True, True],  # return_complex True
     ])
@@ -138,10 +140,6 @@ def test_stft_not_supported_attrs(self, n_fft, hop_length, win_length, center, p
             pytest.xfail(
                 reason="torch stft uses list() for `center` subgraph before aten::stft, that leads to error: No conversion rule found for operations: aten::list")
 
-        if normalized is True:
-            pytest.xfail(
-                reason="aten::stft conversion is currently supported with normalized=False only")
-
         if onesided is False:
             pytest.xfail(
                 reason="aten::stft conversion is currently supported with onesided=True only")
