From 60c5a1d53904700bafda5b42a79ebef70067ea2c Mon Sep 17 00:00:00 2001 From: zr_jin Date: Fri, 6 Dec 2024 10:12:58 +0800 Subject: [PATCH] added the missing ``visualize`` function --- ...te_neural_codec_and_prepare_text_tokens.py | 8 +- egs/wenetspeech4tts/TTS/valle/valle.py | 85 +++++++++++++++++++ 2 files changed, 91 insertions(+), 2 deletions(-) diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py index 4e0a47c68d..7de2c6202e 100755 --- a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py +++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -519,12 +519,16 @@ def main(): if split > 1: storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx}" else: - storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}" + storage_path = ( + f"{args.output_dir}/{args.prefix}_encodec_{partition}" + ) else: if split > 1: storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx}" else: - storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}" + storage_path = ( + f"{args.output_dir}/{args.prefix}_fbank_{partition}" + ) if args.prefix.lower() in [ "ljspeech", diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index b2eb8ae69d..bfe0476176 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -19,6 +19,8 @@ from functools import partial from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union +import matplotlib.pyplot as plt +import numpy as np import torch import torch.nn as nn from torch import Tensor @@ -1658,6 +1660,89 @@ def continual( assert len(codes) == 8 return torch.stack(codes, dim=-1) + def visualize( + self, + predicts: Tuple[torch.Tensor], + batch: Dict[str, Union[List, torch.Tensor]], + output_dir: str, + limit: int = 4, + ) -> None: + text_tokens = batch["text_tokens"].to("cpu").detach().numpy() + text_tokens_lens = batch["text_tokens_lens"].to("cpu").detach().numpy() + audio_features = batch["audio_features"].to("cpu").detach().numpy() + audio_features_lens = ( + batch["audio_features_lens"].to("cpu").detach().numpy() + ) + assert text_tokens.ndim == 2 + + utt_ids, texts = batch["utt_id"], batch["text"] + + encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() + decoder_outputs = predicts[1] + if isinstance(decoder_outputs, list): + decoder_outputs = decoder_outputs[-1] + decoder_outputs = ( + decoder_outputs.to("cpu").type(torch.float32).detach().numpy() + ) + + vmin, vmax = 0, 1024 # Encodec + if decoder_outputs.dtype == np.float32: + vmin, vmax = -6, 0 # Fbank + + num_figures = 3 + for b, (utt_id, text) in enumerate(zip(utt_ids[:limit], texts[:limit])): + _ = plt.figure(figsize=(14, 8 * num_figures)) + + S = text_tokens_lens[b] + T = audio_features_lens[b] + + # encoder + plt.subplot(num_figures, 1, 1) + plt.title(f"Text: {text}") + plt.imshow( + X=np.transpose(encoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + ) + plt.gca().invert_yaxis() + plt.axvline(x=S - 0.4, linewidth=2, color="r") + plt.xlabel("Encoder Output") + plt.colorbar() + + # decoder + plt.subplot(num_figures, 1, 2) + plt.imshow( + X=np.transpose(decoder_outputs[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Output") + plt.colorbar() + + # target + plt.subplot(num_figures, 1, 3) + plt.imshow( + X=np.transpose(audio_features[b]), + cmap=plt.get_cmap("jet"), + aspect="auto", + interpolation="nearest", + vmin=vmin, + vmax=vmax, + ) + plt.gca().invert_yaxis() + plt.axvline(x=T - 0.4, linewidth=2, color="r") + plt.xlabel("Decoder Target") + plt.colorbar() + + plt.savefig(f"{output_dir}/{utt_id}.png") + plt.close() + # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py def top_k_top_p_filtering(