From 6ae48845ba4a89fb64d2ceeb063c5bb0854d7028 Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Sat, 30 Nov 2024 20:12:07 +0800
Subject: [PATCH 1/7] minor fixes

---
 ...ute_neural_codec_and_prepare_text_tokens.py | 18 +++++++++++++-----
 egs/wenetspeech4tts/TTS/valle/requirements.txt |  2 ++
 2 files changed, 15 insertions(+), 5 deletions(-)
 create mode 100644 egs/wenetspeech4tts/TTS/valle/requirements.txt

diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
index 5494bf3400..4e0a47c68d 100755
--- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
+++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
@@ -516,9 +516,15 @@ def main():
     for idx, part in enumerate(cut_sets):
         if args.audio_extractor:
             if args.audio_extractor == "Encodec":
-                storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
+                else:
+                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}"
             else:
-                storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}"
+                if split > 1:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
+                else:
+                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}"

             if args.prefix.lower() in [
                 "ljspeech",
@@ -587,9 +593,11 @@ def main():
                 ].normalized_text, "normalized_text is None"

         # Save each part with an index if split > 1
-        cuts_filename = (
-            f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}"
-        )
+        if split > 1:
+            cuts_filename = f"{prefix}cuts_{partition}.{idx}.{args.suffix}"
+        else:
+            cuts_filename = f"{prefix}cuts_{partition}.{args.suffix}"
+
         part.to_file(f"{args.output_dir}/{cuts_filename}")
         logging.info(f"Saved {cuts_filename}")

diff --git a/egs/wenetspeech4tts/TTS/valle/requirements.txt b/egs/wenetspeech4tts/TTS/valle/requirements.txt
new file mode 100644
index 0000000000..06958dbeaf
--- /dev/null
+++ b/egs/wenetspeech4tts/TTS/valle/requirements.txt
@@ -0,0 +1,2 @@
+phonemizer==3.2.1
+git+https://github.com/facebookresearch/encodec.git
\ No newline at end of file

From 58f7875c7e150523e5d3c73f774035f12ddc7e9f Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Thu, 5 Dec 2024 20:08:42 +0800
Subject: [PATCH 2/7] fixed some default params.

---
 egs/wenetspeech4tts/TTS/valle/infer.py | 2 +-
 egs/wenetspeech4tts/TTS/valle/train.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py
index fd7ba9f216..44a251c561 100644
--- a/egs/wenetspeech4tts/TTS/valle/infer.py
+++ b/egs/wenetspeech4tts/TTS/valle/infer.py
@@ -86,7 +86,7 @@ def get_args():
     parser.add_argument(
         "--checkpoint",
         type=str,
-        default="exp/vallf_nano_full/checkpoint-100000.pt",
+        default="./valle/exp/checkpoint-100000.pt",
         help="Path to the saved checkpoint.",
     )

diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py
index fde209511c..27b947b777 100755
--- a/egs/wenetspeech4tts/TTS/valle/train.py
+++ b/egs/wenetspeech4tts/TTS/valle/train.py
@@ -216,7 +216,7 @@ def get_parser():
     parser.add_argument(
         "--exp-dir",
         type=str,
-        default="exp/valle_dev",
+        default="./valle/exp",
         help="""The experiment dir.
         It specifies the directory where all training related files, e.g., checkpoints, log, etc, are saved

From 60c5a1d53904700bafda5b42a79ebef70067ea2c Mon Sep 17 00:00:00 2001
From: zr_jin
Date: Fri, 6 Dec 2024 10:12:58 +0800
Subject: [PATCH 3/7] added the missing ``visualize`` function

---
 ...te_neural_codec_and_prepare_text_tokens.py |  8 +-
 egs/wenetspeech4tts/TTS/valle/valle.py        | 85 +++++++++++++++++++
 2 files changed, 91 insertions(+), 2 deletions(-)

diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
index 4e0a47c68d..7de2c6202e 100755
--- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
+++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py
@@ -519,12 +519,16 @@ def main():
                 if split > 1:
                     storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}"
                 else:
-                    storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_encodec_{partition}"
+                    )
             else:
                 if split > 1:
                     storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}"
                 else:
-                    storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    storage_path = (
+                        f"{args.output_dir}/{args.prefix}_fbank_{partition}"
+                    )

             if args.prefix.lower() in [
                 "ljspeech",
diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py
index b2eb8ae69d..bfe0476176 100644
--- a/egs/wenetspeech4tts/TTS/valle/valle.py
+++ b/egs/wenetspeech4tts/TTS/valle/valle.py
@@ -19,6 +19,8 @@
 from functools import partial
 from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

+import matplotlib.pyplot as plt
+import numpy as np
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -1658,6 +1660,89 @@ def continual(
         assert len(codes) == 8
         return torch.stack(codes, dim=-1)

+    def visualize(
+        self,
+        predicts: Tuple[torch.Tensor],
+        batch: Dict[str, Union[List, torch.Tensor]],
+        output_dir: str,
+        limit: int = 4,
+    ) -> None:
+        text_tokens = batch["text_tokens"].to("cpu").detach().numpy()
+        text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy()
+        audio_features = batch["audio_features"].to("cpu").detach().numpy()
+        audio_features_lens = (
+            batch["audio_features_lens"].to("cpu").detach().numpy()
+        )
+        assert text_tokens.ndim == 2
+
+        utt_ids, texts = batch["utt_id"], batch["text"]
+
+        encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy()
+        decoder_outputs = predicts[1]
+        if isinstance(decoder_outputs, list):
+            decoder_outputs = decoder_outputs[-1]
+        decoder_outputs = (
+            decoder_outputs.to("cpu").type(torch.float32).detach().numpy()
+        )
+
+        vmin, vmax = 0, 1024  # Encodec
+        if decoder_outputs.dtype == np.float32:
+            vmin, vmax = -6, 0  # Fbank
+
+        num_figures = 3
+        for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])):
+            _ = plt.figure(figsize=(14, 8 * num_figures))
+
+            S = text_tokens_lens[b]
+            T = audio_features_lens[b]
+
+            # encoder
+            plt.subplot(num_figures, 1, 1)
+            plt.title(f"Text: {text}")
+            plt.imshow(
+                X=np.transpose(encoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
+                interpolation="nearest",
+            )
+            plt.gca().invert_yaxis()
+            plt.axvline(x=S - 0.4, linewidth=2, color="r")
+            plt.xlabel("Encoder Output")
+            plt.colorbar()
+
+            # decoder
+            plt.subplot(num_figures, 1, 2)
+            plt.imshow(
+                X=np.transpose(decoder_outputs[b]),
+                cmap=plt.get_cmap("jet"),
+                aspect="auto",
interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Output") + plt.colorbar() + + # target + plt.subplot(num_figures, 1, 3) + plt.imshow( + X=np.transpose(audio_features[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Target") + plt.colorbar() + + plt.savefig(f"{output_dir}/{utt_id}.png") + plt.close() + # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering( From ce73643af641cf9b0edb22db8fc5a80fb31dd958 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 6 Dec 2024 10:44:14 +0800 Subject: [PATCH 4/7] black formatted --- egs/wenetspeech4tts/TTS/valle/valle.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index bfe0476176..40501736ba 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1670,9 +1670,7 @@ def visualize( text_tokens = batch["text_tokens"].to("cpu").detach().numpy() text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy() - audio_features_lens = ( - batch["audio_features_lens"].to("cpu").detach().numpy() - ) + audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() assert text_tokens.ndim == 2 utt_ids, texts = batch["utt_id"], batch["text"] @@ -1681,9 +1679,7 @@ def visualize( decoder_outputs = predicts[1] if isinstance(decoder_outputs, list): decoder_outputs = decoder_outputs[-1] - decoder_outputs = ( - decoder_outputs.to("cpu").type(torch.float32).detach().numpy() - ) + decoder_outputs = decoder_outputs.to("cpu").type(torch.float32).detach().numpy() vmin, vmax = 0, 1024 # Encodec if decoder_outputs.dtype == np.float32: From 2504036f5bd619b63b762ca2745e6f87b8f52312 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 6 Dec 2024 11:51:40 +0800 Subject: [PATCH 5/7] minor fixes --- egs/wenetspeech4tts/TTS/valle/train.py | 7 +++---- egs/wenetspeech4tts/TTS/valle/valle.py | 7 +++++-- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py index 27b947b777..e9ec548f33 100755 --- a/egs/wenetspeech4tts/TTS/valle/train.py +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -4,6 +4,7 @@ # Mingshuang Luo) # Copyright 2023 (authors: Feiteng Li) # Copyright 2024 (authors: Yuekai Zhang) +# Copyright 2024 Tsinghua University (authors: Zengrui Jin,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -48,10 +49,8 @@ import argparse import copy import logging -import os import random import warnings -from contextlib import nullcontext from pathlib import Path from shutil import copyfile from typing import Any, Dict, Optional, Tuple, Union @@ -686,9 +685,9 @@ def compute_validation_loss( output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") output_dir.mkdir(parents=True, exist_ok=True) if isinstance(model, DDP): - model.module.visualize(predicts, batch, output_dir=output_dir) + model.module.visualize(predicts, batch, tokenizer, output_dir=output_dir) else: - model.visualize(predicts, batch, output_dir=output_dir) + model.visualize(predicts, batch, tokenizer, output_dir=output_dir) return tot_loss diff 
--git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 40501736ba..206b843ba1 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -23,6 +23,7 @@ import numpy as np import torch import torch.nn as nn +from tokenizer import TextTokenCollater from torch import Tensor from torch.nn import Linear, Module from torch.nn import functional as F @@ -1664,13 +1665,15 @@ def visualize( self, predicts: Tuple[torch.Tensor], batch: Dict[str, Union[List, torch.Tensor]], + tokenizer: TextTokenCollater, output_dir: str, limit: int = 4, ) -> None: - text_tokens = batch["text_tokens"].to("cpu").detach().numpy() - text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() audio_features = batch["audio_features"].to("cpu").detach().numpy() audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() + + tokens = batch["tokens"] + text_tokens, text_tokens_lens = tokenizer(tokens) assert text_tokens.ndim == 2 utt_ids, texts = batch["utt_id"], batch["text"] From 94126e7f38c971ee02d8bd60dde3b5cc3ca743f6 Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 6 Dec 2024 13:57:59 +0800 Subject: [PATCH 6/7] Update valle.py --- egs/wenetspeech4tts/TTS/valle/valle.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 206b843ba1..0cf73926b5 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1669,8 +1669,8 @@ def visualize( output_dir: str, limit: int = 4, ) -> None: - audio_features = batch["audio_features"].to("cpu").detach().numpy() - audio_features_lens = batch["audio_features_lens"].to("cpu").detach().numpy() + audio_features = batch["features"].to("cpu").detach().numpy() + audio_features_lens = batch["features_lens"].to("cpu").detach().numpy() tokens = batch["tokens"] text_tokens, text_tokens_lens = tokenizer(tokens) From a8efe19aa42246103edc49f75e1af6da63837f1f Mon Sep 17 00:00:00 2001 From: JinZr Date: Fri, 6 Dec 2024 14:34:53 +0800 Subject: [PATCH 7/7] minor fixes --- egs/wenetspeech4tts/TTS/valle/valle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 0cf73926b5..4bfa2b577b 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1676,7 +1676,8 @@ def visualize( text_tokens, text_tokens_lens = tokenizer(tokens) assert text_tokens.ndim == 2 - utt_ids, texts = batch["utt_id"], batch["text"] + texts = batch["text"] + utt_ids = [cut.id for cut in batch["cut"]] encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() decoder_outputs = predicts[1]