Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend feature extraction module to allow for RASR compatible logmel features #40

Merged
merged 10 commits into from
Dec 7, 2023
55 changes: 35 additions & 20 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

from dataclasses import dataclass
from typing import Optional, Tuple, Any, Dict
from typing import Optional, Tuple, Union, Literal
from enum import Enum

from librosa import filters
import torch
from torch import nn
import numpy as np
from numpy.typing import DTypeLike

from i6_models.config import ModelConfiguration


class SpectrumType(Enum):
    """
    Selects how the power spectrum is computed from the audio input.
    """

    # torch.stft applied directly on the raw audio input
    STFT = 1
    # torch.fft.rfftn applied on explicitly windowed audio frames (RASR-compatible features)
    RFFTN = 2


@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
Expand All @@ -23,8 +31,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
mel_options: extra options for mel filters
rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT
htk: whether to use the HTK formula instead of Slaney
norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible with RASR's
"""

sample_rate: int
Expand All @@ -37,8 +46,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
center: bool
n_fft: Optional[int] = None
periodic: bool = True
mel_options: Optional[Dict[str, Any]] = None
rasr_compatible: bool = False
htk: bool = False
norm: Optional[Union[Literal["slaney"], float]] = "slaney"
dtype: DTypeLike = np.float32
spectrum_type: SpectrumType = SpectrumType.STFT

def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -68,8 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
self.mel_options = cfg.mel_options or {}
self.rasr_compatible = cfg.rasr_compatible
self.spectrum_type = cfg.spectrum_type

self.register_buffer(
"mel_basis",
Expand All @@ -80,8 +90,10 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
**self.mel_options,
)
htk=cfg.htk,
norm=cfg.norm,
dtype=cfg.dtype,
),
),
)
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))
Expand All @@ -92,14 +104,7 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
:param length: length in samples [B]
:return features as [B,T,F] and length in frames [B]
"""
if self.rasr_compatible:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
if self.spectrum_type == SpectrumType.STFT:
power_spectrum = (
torch.abs(
torch.stft(
Expand All @@ -115,6 +120,13 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
)
** 2
)
elif self.spectrum_type == SpectrumType.RFFTN:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
albertz marked this conversation as resolved.
Show resolved Hide resolved

if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
Expand All @@ -123,9 +135,12 @@ def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']

if self.center and not self.rasr_compatible:
length = (length // self.hop_length) + 1
else:
if self.spectrum_type == SpectrumType.STFT:
if self.center:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
albertz marked this conversation as resolved.
Show resolved Hide resolved

return feature_data, length.int()
Loading