Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend feature extraction module to allow for RASR compatible logmel features #40

Merged
merged 10 commits into from
Dec 7, 2023
58 changes: 38 additions & 20 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Any, Dict

from librosa import filters
import torch
Expand All @@ -22,6 +22,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
mel_options: extra options for mel filters
rasr_compatible: apply FFT to make features compatible to RASR's, otherwise (default) apply STFT
curufinwe marked this conversation as resolved.
Show resolved Hide resolved
"""

sample_rate: int
Expand All @@ -33,6 +36,9 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
periodic: bool = True
mel_options: Optional[Dict[str, Any]] = None
rasr_compatible: bool = False

def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -62,6 +68,8 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
self.mel_options = cfg.mel_options or {}
self.rasr_compatible = cfg.rasr_compatible

self.register_buffer(
"mel_basis",
Expand All @@ -72,42 +80,52 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
**self.mel_options,
curufinwe marked this conversation as resolved.
Show resolved Hide resolved
)
),
)
self.register_buffer("window", torch.hann_window(self.win_length))
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))

def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param raw_audio: [B, T]
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
if self.rasr_compatible:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
)
)
** 2
)
** 2
)

if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
power_spectrum = torch.unsqueeze(power_spectrum, 0)
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T']
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2)
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']

if self.center:
if self.center and not self.rasr_compatible:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1
length = ((length - self.win_length) // self.hop_length) + 1
albertz marked this conversation as resolved.
Show resolved Hide resolved
albertz marked this conversation as resolved.
Show resolved Hide resolved

return feature_data, length.int()
Loading