Commit 6d625b0

Merge branch 'k2-fsa:master' into fix/valle2

JinZr authored Dec 8, 2024
2 parents e1e78e6 + 5c04f7b

Showing 27 changed files with 509 additions and 351 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {

curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/inference.py \
./matcha/infer.py \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
4 changes: 2 additions & 2 deletions egs/ljspeech/TTS/README.md
@@ -131,12 +131,12 @@ To inference, use:

wget https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/inference \
./matcha/synth.py \
--exp-dir ./matcha/exp-new-3 \
--epoch 4000 \
--tokens ./data/tokens.txt \
--vocoder ./generator_v1 \
--input-text "how are you doing?"
--input-text "how are you doing?" \
--output-wav ./generated.wav
```

1 change: 1 addition & 0 deletions egs/ljspeech/TTS/local/audio.py
91 changes: 3 additions & 88 deletions egs/ljspeech/TTS/local/compute_fbank_ljspeech.py
@@ -27,102 +27,17 @@
import argparse
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Union

import numpy as np
import torch
from fbank import MatchaFbank, MatchaFbankConfig
from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.supervision import SupervisionSet
from lhotse.utils import Seconds, compute_num_frames
from matcha.audio import mel_spectrogram

from icefall.utils import get_executor


@dataclass
class MyFbankConfig:
    n_fft: int
    n_mels: int
    sampling_rate: int
    hop_length: int
    win_length: int
    f_min: float
    f_max: float


@register_extractor
class MyFbank(FeatureExtractor):

    name = "MyFbank"
    config_type = MyFbankConfig

    def __init__(self, config):
        super().__init__(config=config)

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: np.ndarray,
        sampling_rate: int,
    ) -> torch.Tensor:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        samples = torch.from_numpy(samples)
        assert samples.ndim == 2, samples.shape
        assert samples.shape[0] == 1, samples.shape

        mel = (
            mel_spectrogram(
                samples,
                self.config.n_fft,
                self.config.n_mels,
                self.config.sampling_rate,
                self.config.hop_length,
                self.config.win_length,
                self.config.f_min,
                self.config.f_max,
                center=False,
            )
            .squeeze()
            .t()
        )

        assert mel.ndim == 2, mel.shape
        assert mel.shape[1] == self.config.n_mels, mel.shape

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        return mel.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate


def get_parser():
parser = argparse.ArgumentParser(
formatter_class=argparse.ArgumentDefaultsHelpFormatter
@@ -149,7 +64,7 @@ def compute_fbank_ljspeech(num_jobs: int):
    logging.info(f"num_jobs: {num_jobs}")
    logging.info(f"src_dir: {src_dir}")
    logging.info(f"output_dir: {output_dir}")
    config = MyFbankConfig(
    config = MatchaFbankConfig(
        n_fft=1024,
        n_mels=80,
        sampling_rate=22050,
@@ -170,7 +85,7 @@ def compute_fbank_ljspeech(num_jobs: int):
        src_dir / f"{prefix}_supervisions_{partition}.{suffix}", SupervisionSet
    )

    extractor = MyFbank(config)
    extractor = MatchaFbank(config)

    with get_executor() as ex:  # Initialize the executor only once.
        cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
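For context, below is a minimal sketch of how a Lhotse feature extractor such as `MatchaFbank` is typically driven from a script like this one. The manifest paths, `num_jobs`, and the config fields other than `n_fft`, `n_mels`, and `sampling_rate` are illustrative assumptions, not taken from this diff.

```
# Hypothetical usage sketch; paths and most config values are assumptions.
from pathlib import Path

from lhotse import CutSet, LilcomChunkyWriter, load_manifest
from lhotse.audio import RecordingSet
from lhotse.supervision import SupervisionSet

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,   # with these example values, frame_shift = 256 / 22050 ≈ 11.6 ms
    win_length=1024,
    f_min=0,
    f_max=8000,
)
extractor = MatchaFbank(config)

recordings = load_manifest(
    "data/manifests/ljspeech_recordings_all.jsonl.gz", RecordingSet
)
supervisions = load_manifest(
    "data/manifests/ljspeech_supervisions_all.jsonl.gz", SupervisionSet
)

Path("data/fbank").mkdir(parents=True, exist_ok=True)
cut_set = CutSet.from_manifests(recordings=recordings, supervisions=supervisions)
cut_set = cut_set.compute_and_store_features(
    extractor=extractor,
    storage_path="data/fbank/ljspeech_feats_all",
    num_jobs=1,  # increase for parallel extraction (add a __main__ guard if you do)
    storage_type=LilcomChunkyWriter,
)
cut_set.to_file("data/fbank/ljspeech_cuts_all.jsonl.gz")
```

Moving the extractor into `fbank.py` lets scripts like this one simply import it, rather than each carrying its own copy of the class.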
1 change: 1 addition & 0 deletions egs/ljspeech/TTS/local/fbank.py
1 change: 0 additions & 1 deletion egs/ljspeech/TTS/matcha/compute_fbank_ljspeech.py

This file was deleted.

2 changes: 1 addition & 1 deletion egs/ljspeech/TTS/matcha/export_onnx_hifigan.py
@@ -7,7 +7,7 @@

import onnx
import torch
from inference import load_vocoder
from infer import load_vocoder


def add_meta_data(filename: str, meta_data: Dict[str, Any]):
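The body of `add_meta_data` is collapsed in this view. For orientation, here is a hedged sketch of what such a helper usually does with the `onnx` API; the function name and body below are illustrative, not the file's actual implementation.

```
# Hypothetical sketch, not the code from export_onnx_hifigan.py.
import onnx


def add_meta_data_sketch(filename, meta_data):
    model = onnx.load(filename)
    # Drop any existing metadata entries, then store the given key/value pairs.
    del model.metadata_props[:]
    for key, value in meta_data.items():
        entry = model.metadata_props.add()
        entry.key = key
        entry.value = str(value)
    onnx.save(model, filename)
```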
88 changes: 88 additions & 0 deletions egs/ljspeech/TTS/matcha/fbank.py
@@ -0,0 +1,88 @@
from dataclasses import dataclass
from typing import Union

import numpy as np
import torch
from audio import mel_spectrogram
from lhotse.features.base import FeatureExtractor, register_extractor
from lhotse.utils import Seconds, compute_num_frames


@dataclass
class MatchaFbankConfig:
    n_fft: int
    n_mels: int
    sampling_rate: int
    hop_length: int
    win_length: int
    f_min: float
    f_max: float


@register_extractor
class MatchaFbank(FeatureExtractor):

    name = "MatchaFbank"
    config_type = MatchaFbankConfig

    def __init__(self, config):
        super().__init__(config=config)

    @property
    def device(self) -> Union[str, torch.device]:
        return self.config.device

    def feature_dim(self, sampling_rate: int) -> int:
        return self.config.n_mels

    def extract(
        self,
        samples: np.ndarray,
        sampling_rate: int,
    ) -> torch.Tensor:
        # Check for sampling rate compatibility.
        expected_sr = self.config.sampling_rate
        assert sampling_rate == expected_sr, (
            f"Mismatched sampling rate: extractor expects {expected_sr}, "
            f"got {sampling_rate}"
        )
        samples = torch.from_numpy(samples)
        assert samples.ndim == 2, samples.shape
        assert samples.shape[0] == 1, samples.shape

        mel = (
            mel_spectrogram(
                samples,
                self.config.n_fft,
                self.config.n_mels,
                self.config.sampling_rate,
                self.config.hop_length,
                self.config.win_length,
                self.config.f_min,
                self.config.f_max,
                center=False,
            )
            .squeeze()
            .t()
        )

        assert mel.ndim == 2, mel.shape
        assert mel.shape[1] == self.config.n_mels, mel.shape

        num_frames = compute_num_frames(
            samples.shape[1] / sampling_rate, self.frame_shift, sampling_rate
        )

        if mel.shape[0] > num_frames:
            mel = mel[:num_frames]
        elif mel.shape[0] < num_frames:
            mel = mel.unsqueeze(0)
            mel = torch.nn.functional.pad(
                mel, (0, 0, 0, num_frames - mel.shape[1]), mode="replicate"
            ).squeeze(0)

        return mel.numpy()

    @property
    def frame_shift(self) -> Seconds:
        return self.config.hop_length / self.config.sampling_rate
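
A standalone usage sketch of the new extractor follows; the waveform and all config values are assumptions for illustration, and only the class itself comes from this commit.

```
# Run from egs/ljspeech/TTS/matcha so that fbank.py can import audio.py.
import numpy as np

from fbank import MatchaFbank, MatchaFbankConfig

config = MatchaFbankConfig(
    n_fft=1024,
    n_mels=80,
    sampling_rate=22050,
    hop_length=256,
    win_length=1024,
    f_min=0,
    f_max=8000,
)
extractor = MatchaFbank(config)

# One second of silence, shaped (channels, samples); the extractor accepts mono only.
samples = np.zeros((1, 22050), dtype=np.float32)
mel = extractor.extract(samples, sampling_rate=22050)
print(mel.shape)  # (num_frames, 80); roughly 22050 / 256 ≈ 86 frames for 1 s of audio
```

Note that `extract` pads or truncates the mel spectrogram so that the frame count matches Lhotse's `compute_num_frames`, which keeps features aligned with cut durations.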