From 1559a09df0776642bbde7baffc1dc97baa2a70f1 Mon Sep 17 00:00:00 2001
From: Ping Zheng
Date: Wed, 15 Nov 2023 09:16:50 -0500
Subject: [PATCH] reorder: check rasr_compatible positively and swap the if/else branches in forward()
---
i6_models/primitives/feature_extraction.py | 16 ++++++++--------
1 file changed, 8 insertions(+), 8 deletions(-)
diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py
index 6df10574..dadf6e41 100644
--- a/i6_models/primitives/feature_extraction.py
+++ b/i6_models/primitives/feature_extraction.py
@@ -92,7 +92,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
- if not self.rasr_compatible:
+ if self.rasr_compatible:
+ windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
+ smoothed = windowed * self.window.unsqueeze(0)
+
+ # Compute power spectrum using torch.fft.rfftn
+ power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
+ power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
+ else:
power_spectrum = (
torch.abs(
torch.stft(
@@ -108,13 +115,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
** 2
)
- else:
- windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
- smoothed = windowed * self.window.unsqueeze(0)
-
- # Compute power spectrum using torch.fft.rfftn
- power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
- power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again