diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py index 0cf73926b5..4bfa2b577b 100644 --- a/egs/wenetspeech4tts/TTS/valle/valle.py +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -1676,7 +1676,8 @@ def visualize( text_tokens, text_tokens_lens = tokenizer(tokens) assert text_tokens.ndim == 2 - utt_ids, texts = batch["utt_id"], batch["text"] + texts = batch["text"] + utt_ids = [cut.id for cut in batch["cut"]] encoder_outputs = predicts[0].to("cpu").type(torch.float32).detach().numpy() decoder_outputs = predicts[1]