diff --git a/.github/scripts/ljspeech/TTS/run-matcha.sh b/.github/scripts/ljspeech/TTS/run-matcha.sh
index 93c4e13c05..0876cb47f2 100755
--- a/.github/scripts/ljspeech/TTS/run-matcha.sh
+++ b/.github/scripts/ljspeech/TTS/run-matcha.sh
@@ -56,7 +56,7 @@ function infer() {
 
   curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1
 
-  ./matcha/synth.py \
+  ./matcha/infer.py \
     --epoch 1 \
     --exp-dir ./matcha/exp \
     --tokens data/tokens.txt \
diff --git a/egs/ljspeech/TTS/matcha/infer.py b/egs/ljspeech/TTS/matcha/infer.py
index bb68a42f15..e756ffb89c 100755
--- a/egs/ljspeech/TTS/matcha/infer.py
+++ b/egs/ljspeech/TTS/matcha/infer.py
@@ -65,6 +65,28 @@ def get_parser():
         help="""Path to vocabulary.""",
     )
 
+    # The following arguments are used for inference on single text
+    parser.add_argument(
+        "--input-text",
+        type=str,
+        required=False,
+        help="The text to generate speech for",
+    )
+
+    parser.add_argument(
+        "--output-wav",
+        type=str,
+        required=False,
+        help="The filename of the wave to save the generated speech",
+    )
+
+    parser.add_argument(
+        "--sampling-rate",
+        type=int,
+        default=22050,
+        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
+    )
+
     return parser
 
 
@@ -103,7 +125,7 @@ def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
     return {"x_orig": text, "x": x, "x_lengths": x_lengths}
 
 
-def synthesise(
+def synthesize(
     model: nn.Module,
     tokenizer: Tokenizer,
     n_timesteps: int,
@@ -169,7 +191,7 @@ def infer_dataset(
         cut_ids = [cut.id for cut in batch["cut"]]
 
         for i in range(batch_size):
-            output = synthesise(
+            output = synthesize(
                 model=model,
                 tokenizer=tokenizer,
                 n_timesteps=params.n_timesteps,
@@ -271,15 +293,35 @@
     denoiser = Denoiser(vocoder, mode="zeros")
     denoiser.to(device)
 
-    infer_dataset(
-        dl=test_dl,
-        params=params,
-        model=model,
-        vocoder=vocoder,
-        denoiser=denoiser,
-        tokenizer=tokenizer,
-    )
-
+    if params.input_text is not None and params.output_wav is not None:
+        logging.info("Synthesizing a single text")
+        output = synthesize(
+            model=model,
+            tokenizer=tokenizer,
+            n_timesteps=params.n_timesteps,
+            text=params.input_text,
+            length_scale=params.length_scale,
+            temperature=params.temperature,
+            device=device,
+        )
+        output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
+
+        sf.write(
+            file=params.output_wav,
+            data=output["waveform"],
+            samplerate=params.sampling_rate,
+            subtype="PCM_16",
+        )
+    else:
+        logging.info("Decoding the test set")
+        infer_dataset(
+            dl=test_dl,
+            params=params,
+            model=model,
+            vocoder=vocoder,
+            denoiser=denoiser,
+            tokenizer=tokenizer,
+        )
 
 if __name__ == "__main__":
     main()
diff --git a/egs/ljspeech/TTS/matcha/synth.py b/egs/ljspeech/TTS/matcha/synth.py
deleted file mode 100755
index e9805f0670..0000000000
--- a/egs/ljspeech/TTS/matcha/synth.py
+++ /dev/null
@@ -1,166 +0,0 @@
-#!/usr/bin/env python3
-# Copyright 2024 Xiaomi Corp. (authors: Fangjun Kuang)
-
-import argparse
-import json
-import logging
-from pathlib import Path
-
-import soundfile as sf
-import torch
-from hifigan.denoiser import Denoiser
-from infer import load_vocoder, synthesise, to_waveform
-from tokenizer import Tokenizer
-from train import get_model, get_params
-
-from icefall.checkpoint import load_checkpoint
-
-
-def get_parser():
-    parser = argparse.ArgumentParser(
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter
-    )
-
-    parser.add_argument(
-        "--epoch",
-        type=int,
-        default=4000,
-        help="""It specifies the checkpoint to use for decoding.
-        Note: Epoch counts from 1.
-        """,
-    )
-
-    parser.add_argument(
-        "--exp-dir",
-        type=Path,
-        default="matcha/exp",
-        help="""The experiment dir.
-        It specifies the directory where all training related
-        files, e.g., checkpoints, log, etc, are saved
-        """,
-    )
-
-    parser.add_argument(
-        "--vocoder",
-        type=Path,
-        default="./generator_v1",
-        help="Path to the vocoder",
-    )
-
-    parser.add_argument(
-        "--tokens",
-        type=Path,
-        default="data/tokens.txt",
-    )
-
-    parser.add_argument(
-        "--cmvn",
-        type=str,
-        default="data/fbank/cmvn.json",
-        help="""Path to vocabulary.""",
-    )
-
-    parser.add_argument(
-        "--input-text",
-        type=str,
-        required=True,
-        help="The text to generate speech for",
-    )
-
-    parser.add_argument(
-        "--output-wav",
-        type=str,
-        required=True,
-        help="The filename of the wave to save the generated speech",
-    )
-
-    parser.add_argument(
-        "--sampling-rate",
-        type=int,
-        default=22050,
-        help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
-    )
-
-    return parser
-
-
-@torch.inference_mode()
-def main():
-    parser = get_parser()
-    args = parser.parse_args()
-    params = get_params()
-
-    params.update(vars(args))
-
-    logging.info("Infer started")
-
-    device = torch.device("cpu")
-    if torch.cuda.is_available():
-        device = torch.device("cuda", 0)
-
-    tokenizer = Tokenizer(params.tokens)
-    params.blank_id = tokenizer.pad_id
-    params.vocab_size = tokenizer.vocab_size
-    params.model_args.n_vocab = params.vocab_size
-
-    with open(params.cmvn) as f:
-        stats = json.load(f)
-        params.data_args.data_statistics.mel_mean = stats["fbank_mean"]
-        params.data_args.data_statistics.mel_std = stats["fbank_std"]
-
-        params.model_args.data_statistics.mel_mean = stats["fbank_mean"]
-        params.model_args.data_statistics.mel_std = stats["fbank_std"]
-
-    # Number of ODE Solver steps
-    params.n_timesteps = 2
-
-    # Changes to the speaking rate
-    params.length_scale = 1.0
-
-    # Sampling temperature
-    params.temperature = 0.667
-    logging.info(params)
-
-    logging.info("About to create model")
-    model = get_model(params)
-
-    load_checkpoint(f"{params.exp_dir}/epoch-{params.epoch}.pt", model)
-    model.to(device)
-    model.eval()
-
-    if not Path(params.vocoder).is_file():
-        raise ValueError(f"{params.vocoder} does not exist")
-
-    vocoder = load_vocoder(params.vocoder)
-    vocoder.to(device)
-
-    denoiser = Denoiser(vocoder, mode="zeros")
-    denoiser.to(device)
-
-    output = synthesise(
-        model=model,
-        tokenizer=tokenizer,
-        n_timesteps=params.n_timesteps,
-        text=params.input_text,
-        length_scale=params.length_scale,
-        temperature=params.temperature,
-        device=device,
-    )
-    output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)
-
-    sf.write(
-        file=params.output_wav,
-        data=output["waveform"],
-        samplerate=params.sampling_rate,
-        subtype="PCM_16",
-    )
-
-
-if __name__ == "__main__":
-    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
-
-    logging.basicConfig(format=formatter, level=logging.INFO)
-
-    torch.set_num_threads(1)
-    torch.set_num_interop_threads(1)
-    main()
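
Usage note: with synth.py deleted, ./matcha/infer.py becomes the single entry point. It dispatches on the new flags: if both --input-text and --output-wav are given, it synthesizes that one utterance and writes a 16-bit PCM wav at --sampling-rate (default 22050 Hz for LJSpeech); otherwise it falls through to test-set decoding via infer_dataset() as before. A minimal invocation sketch of the single-text mode follows; the epoch and paths mirror run-matcha.sh above, the sample text and output filename are illustrative, and --vocoder/--cmvn are assumed to keep the same names and defaults in infer.py that they had in the deleted synth.py:

  ./matcha/infer.py \
    --epoch 1 \
    --exp-dir ./matcha/exp \
    --tokens data/tokens.txt \
    --cmvn data/fbank/cmvn.json \
    --vocoder ./generator_v1 \
    --input-text "This is a test." \
    --output-wav ./generated.wav

Running the same command without --input-text and --output-wav exercises the else branch (test-set decoding), so CI and interactive use share one script instead of two near-duplicate ones.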