Skip to content

Commit

Permalink
Extend feature extraction module to allow for RASR compatible logmel …
Browse files Browse the repository at this point in the history
…features (#40)
  • Loading branch information
kuacakuaca authored Dec 7, 2023
1 parent 8f4c364 commit c363a01
Showing 1 changed file with 59 additions and 23 deletions.
82 changes: 59 additions & 23 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
# SpectrumType is exported as well because it is the type of
# LogMelFeatureExtractionV1Config.spectrum_type and is needed to select the RFFTN mode.
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config", "SpectrumType"]

from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, Literal
from enum import Enum

from librosa import filters
import torch
from torch import nn
import numpy as np
from numpy.typing import DTypeLike

from i6_models.config import ModelConfiguration


class SpectrumType(Enum):
    """
    Selects how the power spectrum is computed during feature extraction.

    STFT: apply torch.stft on the raw audio input
    RFFTN: apply torch.fft.rfftn on explicitly windowed audio frames
    """

    STFT = 1
    RFFTN = 2


@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
Expand All @@ -22,6 +30,10 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
htk: whether to use the HTK formula instead of the Slaney formula
norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
dtype: data type passed to librosa.filters.mel for the mel filter basis
spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make the features compatible with RASR
"""

sample_rate: int
Expand All @@ -33,6 +45,11 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
periodic: bool = True
htk: bool = False
norm: Optional[Union[Literal["slaney"], float]] = "slaney"
dtype: DTypeLike = np.float32
spectrum_type: SpectrumType = SpectrumType.STFT

def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -62,6 +79,7 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
self.spectrum_type = cfg.spectrum_type

self.register_buffer(
"mel_basis",
Expand All @@ -72,42 +90,60 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
)
htk=cfg.htk,
norm=cfg.norm,
dtype=cfg.dtype,
),
),
)
self.register_buffer("window", torch.hann_window(self.win_length))
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))

def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute log-mel features from raw audio.

    :param raw_audio: [B, T]
    :param length: length in samples, [B]
    :return: features as [B, T', F'] and length in frames [B]
    :raises ValueError: if self.spectrum_type is not a known SpectrumType
    """
    if self.spectrum_type == SpectrumType.STFT:
        power_spectrum = (
            torch.abs(
                torch.stft(
                    raw_audio,
                    n_fft=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    window=self.window,
                    center=self.center,
                    pad_mode="constant",
                    return_complex=True,
                )
            )
            ** 2
        )
    elif self.spectrum_type == SpectrumType.RFFTN:
        # Frame the signal manually (no padding), then apply the window to each frame
        windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length)  # [B, T', W=win_length]
        smoothed = windowed * self.window.unsqueeze(0)  # [B, T', W]

        # Compute power spectrum using torch.fft.rfftn
        power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2  # [B, T', F=n_fft//2+1]
        power_spectrum = power_spectrum.transpose(1, 2)  # [B, F, T']
    else:
        raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")

    if len(power_spectrum.size()) == 2:
        # For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
        power_spectrum = torch.unsqueeze(power_spectrum, 0)  # [B, F, T']
    melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)  # [B, F'=num_filters, T']
    # Clamp before the log so that zero energy does not produce -inf
    log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
    feature_data = torch.transpose(log_melspec, 1, 2)  # [B, T', F']

    # Convert the lengths from samples to frames, matching the framing used above
    if self.spectrum_type == SpectrumType.STFT:
        if self.center:
            length = (length // self.hop_length) + 1
        else:
            length = ((length - self.n_fft) // self.hop_length) + 1
    elif self.spectrum_type == SpectrumType.RFFTN:
        length = ((length - self.win_length) // self.hop_length) + 1
    else:
        raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
    return feature_data, length.int()

0 comments on commit c363a01

Please sign in to comment.