From 1559a09df0776642bbde7baffc1dc97baa2a70f1 Mon Sep 17 00:00:00 2001 From: Ping Zheng Date: Wed, 15 Nov 2023 09:16:50 -0500 Subject: [PATCH] reorder --- i6_models/primitives/feature_extraction.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index 6df10574..dadf6e41 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -92,7 +92,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - if not self.rasr_compatible: + if self.rasr_compatible: + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) + smoothed = windowed * self.window.unsqueeze(0) + + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] + power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] + else: power_spectrum = ( torch.abs( torch.stft( @@ -108,13 +115,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: ) ** 2 ) - else: - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) - smoothed = windowed * self.window.unsqueeze(0) - - # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T] - power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F] if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again