Skip to content

Commit

Permalink
Revert "Extend feature extraction module to allow for RASR compatible logmel features (#40)" (#41)
Browse files Browse the repository at this point in the history

This reverts commit c363a01.

The previous commit breaks hashes/setups.
  • Loading branch information
JackTemaki authored Dec 7, 2023
1 parent c363a01 commit 933c6c1
Showing 1 changed file with 23 additions and 59 deletions.
82 changes: 23 additions & 59 deletions i6_models/primitives/feature_extraction.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,15 @@
__all__ = ["LogMelFeatureExtractionV1", "LogMelFeatureExtractionV1Config"]

from dataclasses import dataclass
from typing import Optional, Tuple, Union, Literal
from enum import Enum
from typing import Optional, Tuple

from librosa import filters
import torch
from torch import nn
import numpy as np
from numpy.typing import DTypeLike

from i6_models.config import ModelConfiguration


class SpectrumType(Enum):
STFT = 1
RFFTN = 2


@dataclass
class LogMelFeatureExtractionV1Config(ModelConfiguration):
"""
Expand All @@ -30,10 +22,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
min_amp: minimum amplitude for safe log
num_filters: number of mel windows
center: centered STFT with automatic padding
periodic: whether the window is assumed to be periodic
htk: whether use HTK formula instead of Slaney
norm: how to normalize the filters, cf. https://librosa.org/doc/main/generated/librosa.filters.mel.html
spectrum_type: apply torch.stft on raw audio input (default) or torch.fft.rfftn on windowed audio to make features compatible to RASR's
"""

sample_rate: int
Expand All @@ -45,11 +33,6 @@ class LogMelFeatureExtractionV1Config(ModelConfiguration):
num_filters: int
center: bool
n_fft: Optional[int] = None
periodic: bool = True
htk: bool = False
norm: Optional[Union[Literal["slaney"], float]] = "slaney"
dtype: DTypeLike = np.float32
spectrum_type: SpectrumType = SpectrumType.STFT

def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -79,7 +62,6 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
self.min_amp = cfg.min_amp
self.n_fft = cfg.n_fft
self.win_length = int(cfg.win_size * cfg.sample_rate)
self.spectrum_type = cfg.spectrum_type

self.register_buffer(
"mel_basis",
Expand All @@ -90,60 +72,42 @@ def __init__(self, cfg: LogMelFeatureExtractionV1Config):
n_mels=cfg.num_filters,
fmin=cfg.f_min,
fmax=cfg.f_max,
htk=cfg.htk,
norm=cfg.norm,
dtype=cfg.dtype,
),
)
),
)
self.register_buffer("window", torch.hann_window(self.win_length, periodic=cfg.periodic))
self.register_buffer("window", torch.hann_window(self.win_length))

def forward(self, raw_audio, length) -> Tuple[torch.Tensor, torch.Tensor]:
"""
:param raw_audio: [B, T]
:param length in samples: [B]
:return features as [B,T,F] and length in frames [B]
"""
if self.spectrum_type == SpectrumType.STFT:
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
)
power_spectrum = (
torch.abs(
torch.stft(
raw_audio,
n_fft=self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.window,
center=self.center,
pad_mode="constant",
return_complex=True,
)
** 2
)
elif self.spectrum_type == SpectrumType.RFFTN:
windowed = raw_audio.unfold(1, size=self.win_length, step=self.hop_length) # [B, T', W=win_length]
smoothed = windowed * self.window.unsqueeze(0) # [B, T', W]

# Compute power spectrum using torch.fft.rfftn
power_spectrum = torch.abs(torch.fft.rfftn(smoothed, s=self.n_fft)) ** 2 # [B, T', F=n_fft//2+1]
power_spectrum = power_spectrum.transpose(1, 2) # [B, F, T']
else:
raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")

** 2
)
if len(power_spectrum.size()) == 2:
# For some reason torch.stft removes the batch axis for batch sizes of 1, so we need to add it again
power_spectrum = torch.unsqueeze(power_spectrum, 0) # [B, F, T']
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis) # [B, F'=num_filters, T']
power_spectrum = torch.unsqueeze(power_spectrum, 0)
melspec = torch.einsum("...ft,mf->...mt", power_spectrum, self.mel_basis)
log_melspec = torch.log10(torch.clamp(melspec, min=self.min_amp))
feature_data = torch.transpose(log_melspec, 1, 2) # [B, T', F']
feature_data = torch.transpose(log_melspec, 1, 2)

if self.spectrum_type == SpectrumType.STFT:
if self.center:
length = (length // self.hop_length) + 1
else:
length = ((length - self.n_fft) // self.hop_length) + 1
elif self.spectrum_type == SpectrumType.RFFTN:
length = ((length - self.win_length) // self.hop_length) + 1
if self.center:
length = (length // self.hop_length) + 1
else:
raise ValueError(f"Invalid spectrum type {self.spectrum_type!r}.")
length = ((length - self.n_fft) // self.hop_length) + 1

return feature_data, length.int()

0 comments on commit 933c6c1

Please sign in to comment.