Skip to content

Commit

Permalink
reorder
Browse files Browse the repository at this point in the history
  • Loading branch information
kuacakuaca committed Nov 15, 2023
1 parent 91e8e0b commit 1559a09
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,14 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
if not self.rasr_compatible:
if self.rasr_compatible:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
smoothed = windowed * self.window.unsqueeze(0)

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]
else:
power_spectrum = (
torch.abs(
torch.stft(
Expand All @@ -108,13 +115,6 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
** 2
)
else:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)
smoothed = windowed * self.window.unsqueeze(0)

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, F, T]
power_spectrum = power_spectrum.transpose(1, 2) # [B, T, F]

if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
Expand Down

0 comments on commit 1559a09

Please sign in to comment.