Skip to content

Commit

Permalink
minor adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
JinZr committed Dec 7, 2024
1 parent fb34991 commit 64d8a43
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 178 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/ljspeech/TTS/run-matcha.sh
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ function infer() {

curl -SL -O https://github.com/csukuangfj/models/raw/refs/heads/master/hifigan/generator_v1

./matcha/synth.py \
./matcha/infer.py \
--epoch 1 \
--exp-dir ./matcha/exp \
--tokens data/tokens.txt \
Expand Down
64 changes: 53 additions & 11 deletions egs/ljspeech/TTS/matcha/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,28 @@ def get_parser():
help="""Path to vocabulary.""",
)

# The following arguments are used for inference on a single text
parser.add_argument(
"--input-text",
type=str,
required=False,
help="The text to generate speech for",
)

parser.add_argument(
"--output-wav",
type=str,
required=False,
help="The filename of the wave to save the generated speech",
)

parser.add_argument(
"--sampling-rate",
type=int,
default=22050,
help="The sampling rate of the generated speech (default: 22050 for LJSpeech)",
)

return parser


Expand Down Expand Up @@ -103,7 +125,7 @@ def process_text(text: str, tokenizer: Tokenizer, device: str = "cpu") -> dict:
return {"x_orig": text, "x": x, "x_lengths": x_lengths}


def synthesise(
def synthesize(
model: nn.Module,
tokenizer: Tokenizer,
n_timesteps: int,
Expand Down Expand Up @@ -169,7 +191,7 @@ def infer_dataset(
cut_ids = [cut.id for cut in batch["cut"]]

for i in range(batch_size):
output = synthesise(
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
Expand Down Expand Up @@ -271,15 +293,35 @@ def main():
denoiser = Denoiser(vocoder, mode="zeros")
denoiser.to(device)

infer_dataset(
dl=test_dl,
params=params,
model=model,
vocoder=vocoder,
denoiser=denoiser,
tokenizer=tokenizer,
)

if params.input_text is not None and params.output_wav is not None:
logging.info("Synthesizing a single text")
output = synthesize(
model=model,
tokenizer=tokenizer,
n_timesteps=params.n_timesteps,
text=params.input_text,
length_scale=params.length_scale,
temperature=params.temperature,
device=device,
)
output["waveform"] = to_waveform(output["mel"], vocoder, denoiser)

sf.write(
file=params.output_wav,
data=output["waveform"],
samplerate=params.sampling_rate,
subtype="PCM_16",
)
else:
logging.info("Decoding the test set")
infer_dataset(
dl=test_dl,
params=params,
model=model,
vocoder=vocoder,
denoiser=denoiser,
tokenizer=tokenizer,
)

if __name__ == "__main__":
main()
166 changes: 0 additions & 166 deletions egs/ljspeech/TTS/matcha/synth.py

This file was deleted.

0 comments on commit 64d8a43

Please sign in to comment.