diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ccb2476e..ead52dd5 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -1,23 +1,15 @@ __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] from dataclasses import dataclass -from typing import Optional, Tuple, Union, Literal -from enum import Enum +from typing import Optional, Tuple from librosa import filters import torch from torch import nn -import numpy as np -from numpy.typing import DTypeLike from i6_models.config import ModelConfiguration -class SpectrumType(Enum): - STFT = 1 - RFFTN = 2 - - @dataclass class LogMelFeatureExtractionV1Config(ModelConfiguration): """ @@ -30,10 +22,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): min_amp: minimum amplitude for safe log num_filters: number of mel windows center: centered STFT with automatic padding - periodic: whether the window is assumed to be periodic - htk: whether use HTK formula instead of Slaney - norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html - spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's """ sample_rate: int @@ -45,11 +33,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): num_filters: int center: bool n_fft: Optional[int] = None - periodic: bool = True - htk: bool = False - norm: Optional[Union[Literal["slaney"], float]] = "slaney" - dtype: DTypeLike = np.float32 - spectrum_type: SpectrumType = SpectrumType.STFT def __post_init__(self) -> None: super().__post_init__() @@ -79,7 +62,6 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): self.min_amp = cfg.min_amp self.n_fft = cfg.n_fft self.win_length = int(cfg.win_size * cfg.sample_rate) - self.spectrum_type = cfg.spectrum_type self.register_buffer( "mel_basis", @@ -90,13 +72,10 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, - htk=cfg.htk, - norm=cfg.norm, - dtype=cfg.dtype, - ), + ) ), ) - self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic)) + self.register_buffer("window", torch.hann_window(self.win_length)) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -104,46 +83,31 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - if self.spectrum_type == SpectrumType.STFT: - power_spectrum = ( - torch.abs( - torch.stft( - raw_audio, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="constant", - return_complex=True, - ) + power_spectrum = ( + torch.abs( + torch.stft( + raw_audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="constant", + return_complex=True, ) - ** 2 ) - elif self.spectrum_type == SpectrumType.RFFTN: - windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] - smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] - - # Compute power spectrum using torch.fft.rfftn - power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] - power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] - else: - raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") - + ** 2 + ) if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again - power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T'] - melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T'] + power_spectrum = torch.unsqueeze(power_spectrum, 0) + melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) - feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] + feature_data = torch.transpose(log_melspec, 1, 2) - if self.spectrum_type == SpectrumType.STFT: - if self.center: - length = (length // self.hop_length) + 1 - else: - length = ((length - self.n_fft) // self.hop_length) + 1 - elif self.spectrum_type == SpectrumType.RFFTN: - length = ((length - self.win_length) // self.hop_length) + 1 + if self.center: + length = (length // self.hop_length) + 1 else: - raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") + length = ((length - self.n_fft) // self.hop_length) + 1 + return feature_data, length.int()