diff --git a/i6_models/primitives/feature_extraction.py b/i6_models/primitives/feature_extraction.py index ead52dd5..ccb2476e 100644 --- a/i6_models/primitives/feature_extraction.py +++ b/i6_models/primitives/feature_extraction.py @@ -1,15 +1,23 @@ __all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"] from dataclasses import dataclass -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, Literal +from enum import Enum from librosa import filters import torch from torch import nn +import numpy as np +from numpy.typing import DTypeLike from i6_models.config import ModelConfiguration +class SpectrumType(Enum): + STFT = 1 + RFFTN = 2 + + @dataclass class LogMelFeatureExtractionV1Config(ModelConfiguration): """ @@ -22,6 +30,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): min_amp: minimum amplitude for safe log num_filters: number of mel windows center: centered STFT with automatic padding + periodic: whether the window is assumed to be periodic + htk: whether to use the HTK formula instead of Slaney + norm: how to normalize the filters, cf. 
https://librosa.org/doc/main/generated/librosa.filters.mel.html + spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible with RASR's """ sample_rate: int @@ -33,6 +45,11 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration): num_filters: int center: bool n_fft: Optional[int] = None + periodic: bool = True + htk: bool = False + norm: Optional[Union[Literal["slaney"], float]] = "slaney" + dtype: DTypeLike = np.float32 + spectrum_type: SpectrumType = SpectrumType.STFT def __post_init__(self) -> None: super().__post_init__() @@ -62,6 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): self.min_amp = cfg.min_amp self.n_fft = cfg.n_fft self.win_length = int(cfg.win_size * cfg.sample_rate) + self.spectrum_type = cfg.spectrum_type self.register_buffer( "mel_basis", @@ -72,10 +90,13 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config): n_mels=cfg.num_filters, fmin=cfg.f_min, fmax=cfg.f_max, - ) + htk=cfg.htk, + norm=cfg.norm, + dtype=cfg.dtype, + ), ), ) - self.register_buffer("window", torch.hann_window(self.win_length)) + self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic)) def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: """ @@ -83,31 +104,46 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]: :param length in samples: [B] :return features as [B,T,F] and length in frames [B] """ - power_spectrum = ( - torch.abs( - torch.stft( - raw_audio, - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window, - center=self.center, - pad_mode="constant", - return_complex=True, + if self.spectrum_type == SpectrumType.STFT: + power_spectrum = ( + torch.abs( + torch.stft( + raw_audio, + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window=self.window, + center=self.center, + pad_mode="constant", + return_complex=True, + ) ) + ** 
2 ) - ** 2 - ) + elif self.spectrum_type == SpectrumType.RFFTN: + windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length] + smoothed = windowed * self.window.unsqueeze(0) # [B, T', W] + + # Compute power spectrum using torch.fft.rfftn + power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1] + power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T'] + else: + raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") + if len(power_spectrum.size()) == 2: # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again - power_spectrum = torch.unsqueeze(power_spectrum, 0) - melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) + power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T'] + melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T'] log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp)) - feature_data = torch.transpose(log_melspec, 1, 2) + feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F'] - if self.center: - length = (length // self.hop_length) + 1 + if self.spectrum_type == SpectrumType.STFT: + if self.center: + length = (length // self.hop_length) + 1 + else: + length = ((length - self.n_fft) // self.hop_length) + 1 + elif self.spectrum_type == SpectrumType.RFFTN: + length = ((length - self.win_length) // self.hop_length) + 1 else: - length = ((length - self.n_fft) // self.hop_length) + 1 - + raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.") return feature_data, length.int()