Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend feature extraction module to allow for RASR compatible logmel features #40

Merged
merged 10 commits into from
Dec 7, 2023
79 changes: 56 additions & 23 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, Literal
from enum import Enum

from librosa import filters
import torch
from torch import nn
import numpy as np
from numpy.typing import DTypeLike

from i6_models.config import ModelConfiguration


class SpectrumType(Enum):
STFT = 1
RFFTN = 2


@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
Expand All @@ -22,6 +30,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
htk: whether use HTK formula instead of Slaney
norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
"""

sample_rate: int
Expand All @@ -33,6 +45,11 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
periodic: bool = True
htk: bool = False
norm: Optional[Union[Literal["slaney"], float]] = "slaney"
dtype: DTypeLike = np.float32
spectrum_type: SpectrumType = SpectrumType.STFT
Comment on lines +48 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we agreed on not setting any defaults unless it is immediately obvious that that should be a default or for "structural parts"

But I now wonder that the original values had no defaults and what you add does. (Wich is required, I understand, to keep backwards compatibility - But then did we not agree to deprecate the old class and create a new one V2 in cases like this?)
Do we actually need to configure all this? Or are there ever only 2 sets of parameters that are going to be used? ("default", "rasr-compatible"). Is then not a better choice to have a single parameter for "default" vs. "rasr-compatible" that then sets all of these never otherwise touched parameters to the appropriate values?
And if the answer is yes, would it not rather make sense in this case to do create a new class - but not LogMelFeatureExtractionV2 but instead RasrCompatibleLogMelFeatureExtractionV1..? (and then we have no config options, just the two classes do use the correct thing)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curufinwe approved these changes 8 minutes ago
curufinwe merged commit c363a01 into main 8 minutes ago

oh well 🤷

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the defaults are problematic here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not the defaults that are problematic, but adding new parameter automatically breaks setups. I will revert the PR. And then I would be in favor of Willis RasrCompatibleLogMelFeatureExtraction instead of the V2. But I would be fine with both.

Copy link
Member

@albertz albertz Dec 7, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does adding parameters with default values break anything?

Edit See discussion in #41 about that.


def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -62,6 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
self.spectrum_type = cfg.spectrum_type

self.register_buffer(
"mel_basis",
Expand All @@ -72,42 +90,57 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
)
htk=cfg.htk,
norm=cfg.norm,
dtype=cfg.dtype,
),
),
)
self.register_buffer("window", torch.hann_window(self.win_length))
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))

def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param raw_audio: [B, T]
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
if self.spectrum_type == SpectrumType.STFT:
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
)
)
** 2
)
** 2
)
elif self.spectrum_type == SpectrumType.RFFTN:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
albertz marked this conversation as resolved.
Show resolved Hide resolved

if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
power_spectrum = torch.unsqueeze(power_spectrum, 0)
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T']
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2)
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']

if self.center:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1
if self.spectrum_type == SpectrumType.STFT:
if self.center:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
albertz marked this conversation as resolved.
Show resolved Hide resolved

return feature_data, length.int()
Loading