diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh
index 0876cb47f2..93c4e13c05 100755
--- a/.github/scripts/ljspeech/TTS/run-matcha.sh
+++ b/.github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {
 
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
 
-  ./matcha/infer.py \
+  ./matcha/synth.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py
index 64abd8e50b..c0842ee15e 100755
--- a/egs/ljspeech/TTS/matcha/infer.py
+++ b/egs/ljspeech/TTS/matcha/infer.py
@@ -9,14 +9,16 @@
 
 import soundfile as sf
 import torch
-from matcha.hifigan.config import v1, v2, v3
-from matcha.hifigan.denoiser import Denoiser
-from matcha.hifigan.models import Generator as HiFiGAN
+import torch.nn as nn
+from hifigan.config import v1, v2, v3
+from hifigan.denoiser import Denoiser
+from hifigan.models import Generator as HiFiGAN
 from tokenizer import Tokenizer
 from train import get_model, get_params
+from tts_datamodule import LJSpeechTtsDataModule
 
 from icefall.checkpoint import load_checkpoint
-from icefall.utils import AttributeDict
+from icefall.utils import AttributeDict, setup_logger
 
 
 def get_parser():
@@ -63,24 +65,10 @@
         help="""Path to vocabulary.""",
     )
 
-    parser.add_argument(
-        "--input-text",
-        type=str,
-        required=True,
-        help="The text to generate speech for",
-    )
-
-    parser.add_argument(
-        "--output-wav",
-        type=str,
-        required=True,
-        help="The filename of the wave to save the generated speech",
-    )
-
     return parser
 
 
-def load_vocoder(checkpoint_path):
+def load_vocoder(checkpoint_path: Path) -> nn.Module:
     checkpoint_path = str(checkpoint_path)
     if checkpoint_path.endswith("v1"):
         h = AttributeDict(v1)
@@ -100,13 +88,15 @@ def load_vocoder(checkpoint_path):
     return hifigan
 
 
-def to_waveform(mel, vocoder, denoiser):
+def to_waveform(
+    mel: torch.Tensor, vocoder: nn.Module, denoiser: nn.Module
+) -> torch.Tensor:
     audio = vocoder(mel).clamp(-1, 1)
     audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
     return audio.cpu().squeeze()
 
 
-def process_text(text: str, tokenizer):
+def process_text(text: str, tokenizer: Tokenizer) -> dict:
     x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
     x = torch.tensor(x, dtype=torch.long)
     x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
@@ -114,8 +104,14 @@ def process_text(text: str, tokenizer):
 
 
 def synthesise(
-    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
-):
+    model: nn.Module,
+    tokenizer: Tokenizer,
+    n_timesteps: int,
+    text: str,
+    length_scale: float,
+    temperature: float,
+    spks=None,
+) -> dict:
     text_processed = process_text(text, tokenizer)
     start_t = dt.datetime.now()
     output = model.synthesise(
@@ -131,14 +127,102 @@ def synthesise(
     return output
 
 
+def infer_dataset(
+    dl: torch.utils.data.DataLoader,
+    params: AttributeDict,
+    model: nn.Module,
+    vocoder: nn.Module,
+    denoiser: nn.Module,
+    tokenizer: Tokenizer,
+) -> None:
+    """Decode dataset.
+    The ground-truth and generated audio pairs will be saved to `params.save_wav_dir`.
+
+    Args:
+      dl:
+        PyTorch's dataloader containing the dataset to decode.
+      params:
+        It is returned by :func:`get_params`.
+      model:
+        The neural model.
+      tokenizer:
+        Used to convert text to phonemes.
+    """
+
+    device = next(model.parameters()).device
+    num_cuts = 0
+    log_interval = 5
+
+    try:
+        num_batches = len(dl)
+    except TypeError:
+        num_batches = "?"
+
+    for batch_idx, batch in enumerate(dl):
+        batch_size = len(batch["tokens"])
+
+        texts = batch["supervisions"]["text"]
+
+        audio = batch["audio"]
+        audio_lens = batch["audio_lens"].tolist()
+        cut_ids = [cut.id for cut in batch["cut"]]
+
+        for i in range(batch_size):
+            output = synthesise(
+                model=model,
+                tokenizer=tokenizer,
+                n_timesteps=params.n_timesteps,
+                text=texts[i],
+                length_scale=params.length_scale,
+                temperature=params.temperature,
+            )
+            output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_pred.wav",
+                data=output["waveform"],
+                samplerate=params.sampling_rate,
+                subtype="PCM_16"
+            )
+            sf.write(
+                file=params.save_wav_dir / f"{cut_ids[i]}_gt.wav",
+                data=audio[i].numpy(),
+                samplerate=params.sampling_rate,
+                subtype="PCM_16"
+            )
+
+        num_cuts += batch_size
+
+        if batch_idx % log_interval == 0:
+            batch_str = f"{batch_idx}/{num_batches}"
+
+            logging.info(f"batch {batch_str}, cuts processed until now is {num_cuts}")
+
+
 @torch.inference_mode()
 def main():
     parser = get_parser()
+    LJSpeechTtsDataModule.add_arguments(parser)
     args = parser.parse_args()
-    params = get_params()
+    args.exp_dir = Path(args.exp_dir)
 
+    params = get_params()
     params.update(vars(args))
+    params.suffix = f"epoch-{params.epoch}"
+
+    params.res_dir = params.exp_dir / "infer" / params.suffix
+    params.save_wav_dir = params.res_dir / "wav"
+    params.save_wav_dir.mkdir(parents=True, exist_ok=True)
+
+    setup_logger(f"{params.res_dir}/log-infer-{params.suffix}")
+    logging.info("Infer started")
+
+    device = torch.device("cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda", 0)
+    logging.info(f"Device: {device}")
+
 
     tokenizer = Tokenizer(params.tokens)
     params.blank_id = tokenizer.pad_id
     params.vocab_size = tokenizer.vocab_size
@@ -151,49 +235,49 @@ def main():
 
         params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
         params.model_args.data_statistics.mel_std = stats["fbank_std"]
+
+    # Number of ODE Solver steps
+    params.n_timesteps = 2
+
+    # Changes to the speaking rate
+    params.length_scale = 1.0
+
+    # Sampling temperature
+    params.temperature = 0.667
     logging.info(params)
 
     logging.info("About to create model")
     model = get_model(params)
 
-    if not Path(f"{params.exp_dir}/epoch-{params.epoch}.pt").is_file():
-        raise ValueError("{params.exp_dir}/epoch-{params.epoch}.pt does not exist")
-
     load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.to(device)
     model.eval()
 
+    # we need cut ids to organize tts results.
+    args.return_cuts = True
+    ljspeech = LJSpeechTtsDataModule(args)
+
+    test_cuts = ljspeech.test_cuts()
+    test_dl = ljspeech.test_dataloaders(test_cuts)
+
     if not Path(params.vocoder).is_file():
         raise ValueError(f"{params.vocoder} does not exist")
 
     vocoder = load_vocoder(params.vocoder)
-    denoiser = Denoiser(vocoder, mode="zeros")
-
-    # Number of ODE Solver steps
-    n_timesteps = 2
+    vocoder = vocoder.to(device)
 
-    # Changes to the speaking rate
-    length_scale = 1.0
-
-    # Sampling temperature
-    temperature = 0.667
+    denoiser = Denoiser(vocoder, mode="zeros")
+    denoiser = denoiser.to(device)
 
-    output = synthesise(
+    infer_dataset(
+        dl=test_dl,
+        params=params,
         model=model,
+        vocoder=vocoder,
+        denoiser=denoiser,
         tokenizer=tokenizer,
-        n_timesteps=n_timesteps,
-        text=params.input_text,
-        length_scale=length_scale,
-        temperature=temperature,
     )
-    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
-
-    sf.write(params.output_wav, output["waveform"], 22050, "PCM_16")
 
 
 if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
     main()
diff --git a/egs/ljspeech/TTS/matcha/synth.py b/egs/ljspeech/TTS/matcha/synth.py
new file mode 100755
index 0000000000..a4880fd3a2
--- /dev/null
+++ b/egs/ljspeech/TTS/matcha/synth.py
@@ -0,0 +1,193 @@
+#!/usr/bin/env python3
+# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
+
+import argparse
+import datetime as dt
+import json
+import logging
+from pathlib import Path
+
+import soundfile as sf
+import torch
+from matcha.hifigan.config import v1, v2, v3
+from matcha.hifigan.denoiser import Denoiser
+from matcha.hifigan.models import Generator as HiFiGAN
+from tokenizer import Tokenizer
+from train import get_model, get_params
+
+from icefall.checkpoint import load_checkpoint
+from icefall.utils import AttributeDict, setup_logger
+
+
+def get_parser():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--epoch",
+        type=int,
+        default=4000,
+        help="""It specifies the checkpoint to use for decoding.
+        Note: Epoch counts from 1.
+        """,
+    )
+
+    parser.add_argument(
+        "--exp-dir",
+        type=Path,
+        default="matcha/exp-new-3",
+        help="""The experiment dir.
+        It specifies the directory where all training related
+        files, e.g., checkpoints, log, etc, are saved
+        """,
+    )
+
+    parser.add_argument(
+        "--vocoder",
+        type=Path,
+        default="./generator_v1",
+        help="Path to the vocoder",
+    )
+
+    parser.add_argument(
+        "--tokens",
+        type=Path,
+        default="data/tokens.txt",
+    )
+
+    parser.add_argument(
+        "--cmvn",
+        type=str,
+        default="data/fbank/cmvn.json",
+        help="""Path to cmvn.json (mel-fbank mean/std statistics).""",
+    )
+
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=True,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=True,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    return parser
+
+
+def load_vocoder(checkpoint_path):
+    checkpoint_path = str(checkpoint_path)
+    if checkpoint_path.endswith("v1"):
+        h = AttributeDict(v1)
+    elif checkpoint_path.endswith("v2"):
+        h = AttributeDict(v2)
+    elif checkpoint_path.endswith("v3"):
+        h = AttributeDict(v3)
+    else:
+        raise ValueError(f"supports only v1, v2, and v3, given {checkpoint_path}")
+
+    hifigan = HiFiGAN(h).to("cpu")
+    hifigan.load_state_dict(
+        torch.load(checkpoint_path, map_location="cpu")["generator"]
+    )
+    _ = hifigan.eval()
+    hifigan.remove_weight_norm()
+    return hifigan
+
+
+def to_waveform(mel, vocoder, denoiser):
+    audio = vocoder(mel).clamp(-1, 1)
+    audio = denoiser(audio.squeeze(0), strength=0.00025).cpu().squeeze()
+    return audio.cpu().squeeze()
+
+
+def process_text(text: str, tokenizer):
+    x = tokenizer.texts_to_token_ids([text], add_sos=True, add_eos=True)
+    x = torch.tensor(x, dtype=torch.long)
+    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device="cpu")
+    return {"x_orig": text, "x": x, "x_lengths": x_lengths}
+
+
+def synthesise(
+    model, tokenizer, n_timesteps, text, length_scale, temperature, spks=None
+):
+    text_processed = process_text(text, tokenizer)
+    start_t = dt.datetime.now()
+    output = model.synthesise(
+        text_processed["x"],
+        text_processed["x_lengths"],
+        n_timesteps=n_timesteps,
+        temperature=temperature,
+        spks=spks,
+        length_scale=length_scale,
+    )
+    # merge everything to one dict
+    output.update({"start_t": start_t, **text_processed})
+    return output
+
+
+@torch.inference_mode()
+def main():
+    parser = get_parser()
+    args = parser.parse_args()
+    params = get_params()
+
+    params.update(vars(args))
+
+    tokenizer = Tokenizer(params.tokens)
+    params.blank_id = tokenizer.pad_id
+    params.vocab_size = tokenizer.vocab_size
+    params.model_args.n_vocab = params.vocab_size
+
+    with open(params.cmvn) as f:
+        stats = json.load(f)
+        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.data_args.data_statistics.mel_std = stats["fbank_std"]
+
+        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
+        params.model_args.data_statistics.mel_std = stats["fbank_std"]
+    logging.info(params)
+
+    logging.info("About to create model")
+    model = get_model(params)
+
+    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
+    model.eval()
+
+    if not Path(params.vocoder).is_file():
+        raise ValueError(f"{params.vocoder} does not exist")
+
+    vocoder = load_vocoder(params.vocoder)
+    denoiser = Denoiser(vocoder, mode="zeros")
+
+    # Number of ODE Solver steps
+    n_timesteps = 2
+
+    # Changes to the speaking rate
+    length_scale = 1.0
+
+    # Sampling temperature
+    temperature = 0.667
+
+    output = synthesise(
+        model=model,
+        tokenizer=tokenizer,
+        n_timesteps=n_timesteps,
+        text=params.input_text,
+        length_scale=length_scale,
+        temperature=temperature,
+    )
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser) + + sf.write(params.output_wav, output["waveform"], 22050, "PCM_16") + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/egs/ljspeech/TTS/matcha/train.py b/egs/ljspeech/TTS/matcha/train.py index 78f4f33734..31135f623b 100755 --- a/egs/ljspeech/TTS/matcha/train.py +++ b/egs/ljspeech/TTS/matcha/train.py @@ -150,7 +150,7 @@ def _get_data_params() -> AttributeDict: "n_spks": 1, "n_fft": 1024, "n_feats": 80, - "sample_rate": 22050, + "sampling_rate": 22050, "hop_length": 256, "win_length": 1024, "f_min": 0, @@ -445,11 +445,6 @@ def train_one_epoch( saved_bad_model = False - # used to track the stats over iterations in one epoch - tot_loss = MetricsTracker() - - saved_bad_model = False - def save_bad_model(suffix: str = ""): save_checkpoint( filename=params.exp_dir / f"bad-model{suffix}-{rank}.pt", diff --git a/egs/ljspeech/TTS/vits/infer.py b/egs/ljspeech/TTS/vits/infer.py index 7be76e3151..cf1067dfcc 100755 --- a/egs/ljspeech/TTS/vits/infer.py +++ b/egs/ljspeech/TTS/vits/infer.py @@ -234,7 +234,7 @@ def main(): logging.info(f"Number of parameters in discriminator: {num_param_d}") logging.info(f"Total number of parameters: {num_param_g + num_param_d}") - # we need cut ids to display recognition results. + # we need cut ids to organize tts results. args.return_cuts = True ljspeech = LJSpeechTtsDataModule(args) diff --git a/egs/ljspeech/TTS/vits/test_model.py b/egs/ljspeech/TTS/vits/test_model.py index 1de10f012b..4faaa96a54 100755 --- a/egs/ljspeech/TTS/vits/test_model.py +++ b/egs/ljspeech/TTS/vits/test_model.py @@ -18,7 +18,6 @@ from tokenizer import Tokenizer from train import get_model, get_params -from vits import VITS def test_model_type(model_type):