From cbe012d54c7007bfbbb82e71a9f1184500fb0824 Mon Sep 17 00:00:00 2001 From: Yuekai Zhang Date: Fri, 22 Nov 2024 11:18:01 +0800 Subject: [PATCH] Valle Recipe for WenetSpeech4TTS, LibriTTS, LibriTTS-R (#1805) * add valle * update readme --- egs/libritts/TTS/README.md | 65 +- ...te_neural_codec_and_prepare_text_tokens.py | 1 + egs/libritts/TTS/prepare.sh | 43 +- egs/libritts/TTS/valle | 1 + egs/wenetspeech4tts/TTS/README.md | 72 + ...te_neural_codec_and_prepare_text_tokens.py | 609 ++++++ .../TTS/local/display_manifest_statistics.py | 53 + egs/wenetspeech4tts/TTS/prepare.sh | 100 + egs/wenetspeech4tts/TTS/shared | 1 + ...te_neural_codec_and_prepare_text_tokens.py | 1 + egs/wenetspeech4tts/TTS/valle/infer.py | 300 +++ egs/wenetspeech4tts/TTS/valle/optim.py | 1 + egs/wenetspeech4tts/TTS/valle/tokenizer.py | 111 ++ egs/wenetspeech4tts/TTS/valle/train.py | 1244 ++++++++++++ .../TTS/valle/tts_datamodule.py | 343 ++++ egs/wenetspeech4tts/TTS/valle/valle.py | 1745 +++++++++++++++++ 16 files changed, 4675 insertions(+), 15 deletions(-) create mode 120000 egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py create mode 120000 egs/libritts/TTS/valle create mode 100644 egs/wenetspeech4tts/TTS/README.md create mode 100755 egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py create mode 100755 egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py create mode 100755 egs/wenetspeech4tts/TTS/prepare.sh create mode 120000 egs/wenetspeech4tts/TTS/shared create mode 120000 egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py create mode 100644 egs/wenetspeech4tts/TTS/valle/infer.py create mode 120000 egs/wenetspeech4tts/TTS/valle/optim.py create mode 100644 egs/wenetspeech4tts/TTS/valle/tokenizer.py create mode 100755 egs/wenetspeech4tts/TTS/valle/train.py create mode 100644 egs/wenetspeech4tts/TTS/valle/tts_datamodule.py create mode 100644 egs/wenetspeech4tts/TTS/valle/valle.py diff --git a/egs/libritts/TTS/README.md b/egs/libritts/TTS/README.md index 4d4fb85803..67424a1ca7 100644 --- a/egs/libritts/TTS/README.md +++ b/egs/libritts/TTS/README.md @@ -1,7 +1,7 @@ # Introduction -LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. -The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. +LibriTTS is a multi-speaker English corpus of approximately 585 hours of read English speech at 24kHz sampling rate, prepared by Heiga Zen with the assistance of Google Speech and Google Brain team members. +The LibriTTS corpus is designed for TTS research. It is derived from the original materials (mp3 audio files from LibriVox and text files from Project Gutenberg) of the LibriSpeech corpus. The main differences from the LibriSpeech corpus are listed below: 1. The audio files are at 24kHz sampling rate. 2. The speech is split at sentence breaks. @@ -11,16 +11,16 @@ The main differences from the LibriSpeech corpus are listed below: For more information, refer to the paper "LibriTTS: A Corpus Derived from LibriSpeech for Text-to-Speech", Heiga Zen, Viet Dang, Rob Clark, Yu Zhang, Ron J. Weiss, Ye Jia, Zhifeng Chen, and Yonghui Wu, arXiv, 2019. If you use the LibriTTS corpus in your work, please cite this paper where it was introduced. > [!CAUTION] -> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). > While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. -> +> > By using this framework, you agree to the following: > 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. -> +> > 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. -> +> > 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. -> +> > 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. @@ -49,3 +49,54 @@ To inference, use: --epoch 400 \ --tokens data/tokens.txt ``` + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_libritts). The demo of the model trained with libritts and [libritts-r](https://www.openslr.org/141/) is available [here](https://huggingface.co/spaces/yuekai/valle-libritts-demo). + +Preparation: + +``` +bash prepare.sh --start-stage 4 +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_libritts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./libritts.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt --top-p ${top_p} +``` diff --git a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 0000000000..68579ffd44 --- /dev/null +++ b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../../../wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/libritts/TTS/prepare.sh b/egs/libritts/TTS/prepare.sh index 44016e6d21..1700e0737d 100755 --- a/egs/libritts/TTS/prepare.sh +++ b/egs/libritts/TTS/prepare.sh @@ -32,7 +32,7 @@ if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then cd vits/monotonic_align python setup.py build_ext --inplace cd ../../ - else + else log "monotonic_align lib already built" fi fi @@ -75,11 +75,11 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then log "Stage 2: Compute Spectrogram for LibriTTS" mkdir -p data/spectrogram if [ ! -e data/spectrogram/.libritts.done ]; then - ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate + ./local/compute_spectrogram_libritts.py --sampling-rate $sampling_rate touch data/spectrogram/.libritts.done fi - # Here we shuffle and combine the train-clean-100, train-clean-360 and + # Here we shuffle and combine the train-clean-100, train-clean-360 and # train-other-500 together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ @@ -88,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then shuf | gzip -c > data/spectrogram/libritts_cuts_train-all-shuf.jsonl.gz fi - # Here we shuffle and combine the train-clean-100, train-clean-360 + # Here we shuffle and combine the train-clean-100, train-clean-360 # together to form the training set. if [ ! -f data/spectrogram/libritts_cuts_train-clean-460.jsonl.gz ]; then cat <(gunzip -c data/spectrogram/libritts_cuts_train-clean-100.jsonl.gz) \ @@ -108,10 +108,10 @@ if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then log "Stage 3: Prepare phoneme tokens for LibriTTS" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/spectrogram/.libritts_with_token.done ]; then ./local/prepare_tokens_libritts.py @@ -123,12 +123,39 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then log "Stage 4: Generate token file" # We assume you have installed piper_phonemize and espnet_tts_frontend. # If not, please install them with: - # - piper_phonemize: + # - piper_phonemize: # refer to https://github.com/rhasspy/piper-phonemize, # could install the pre-built wheels from https://github.com/csukuangfj/piper-phonemize/releases/tag/2023.12.5 - # - espnet_tts_frontend: + # - espnet_tts_frontend: # `pip install espnet_tts_frontend`, refer to https://github.com/espnet/espnet_tts_frontend/ if [ ! -e data/tokens.txt ]; then ./local/prepare_token_file.py --tokens data/tokens.txt fi fi + +audio_feats_dir=data/tokenized +dataset_parts="--dataset-parts all" # debug "-p dev-clean -p test-clean" +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Tokenize/Fbank LibriTTS for valle" + mkdir -p ${audio_feats_dir} + if [ ! -e ${audio_feats_dir}/.libritts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --audio-extractor "Encodec" \ + --batch-duration 400 \ + --src-dir "data/manifests" \ + --output-dir "${audio_feats_dir}" + fi + touch ${audio_feats_dir}/.libritts.tokenize.done + + lhotse combine \ + ${audio_feats_dir}/libritts_cuts_train-clean-100.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-clean-360.jsonl.gz \ + ${audio_feats_dir}/libritts_cuts_train-other-500.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_dev-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + lhotse copy \ + ${audio_feats_dir}/libritts_cuts_test-clean.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz +fi diff --git a/egs/libritts/TTS/valle b/egs/libritts/TTS/valle new file mode 120000 index 0000000000..c8fe8fdb07 --- /dev/null +++ b/egs/libritts/TTS/valle @@ -0,0 +1 @@ +../../wenetspeech4tts/TTS/valle/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/README.md b/egs/wenetspeech4tts/TTS/README.md new file mode 100644 index 0000000000..f35bb51c7f --- /dev/null +++ b/egs/wenetspeech4tts/TTS/README.md @@ -0,0 +1,72 @@ +# Introduction + +[**WenetSpeech4TTS**](https://huggingface.co/datasets/Wenetspeech4TTS/WenetSpeech4TTS) is a multi-domain **Mandarin** corpus derived from the open-sourced [WenetSpeech](https://arxiv.org/abs/2110.03370) dataset. + +> [!CAUTION] +> The next-gen Kaldi framework provides tools and models for generating high-quality, synthetic speech (Text-to-Speech, TTS). +> While these recipes has the potential to advance various fields such as accessibility, language education, and AI-driven solutions, it also carries certain ethical and legal responsibilities. +> +> By using this framework, you agree to the following: +> 1. Legal and Ethical Use: You shall not use this framework, or any models derived from it, for any unlawful or unethical purposes. This includes, but is not limited to: Creating voice clones without the explicit, informed consent of the individual whose voice is being cloned. Engaging in any form of identity theft, impersonation, or fraud using cloned voices. Violating any local, national, or international laws regarding privacy, intellectual property, or personal data. +> +> 2. Responsibility of Use: The users of this framework are solely responsible for ensuring that their use of voice cloning technologies complies with all applicable laws and ethical guidelines. We explicitly disclaim any liability for misuse of the technology. +> +> 3. Attribution and Use of Open-Source Components: This project is provided under the Apache 2.0 license. Users must adhere to the terms of this license and provide appropriate attribution when required. +> +> 4. No Warranty: This framework is provided “as-is,” without warranty of any kind, either express or implied. We do not guarantee that the use of this software will comply with legal requirements or that it will not infringe the rights of third parties. + + +# [VALL-E](https://arxiv.org/abs/2301.02111) + +./valle contains the code for training VALL-E TTS model. + +Checkpoints and training logs can be found [here](https://huggingface.co/yuekai/vall-e_wenetspeech4tts). The demo of the model trained with Wenetspeech4TTS Premium (945 hours) is available [here](https://huggingface.co/spaces/yuekai/valle_wenetspeech4tts_demo). + +Preparation: + +``` +bash prepare.sh +``` + +The training command is given below: + +``` +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +``` + +To inference, use: +``` +huggingface-cli login +huggingface-cli download --local-dir ${exp_dir} yuekai/vall-e_wenetspeech4tts +top_p=1.0 +python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} +``` + +# Credits +- [vall-e](https://github.com/lifeiteng/vall-e) diff --git a/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 100755 index 0000000000..5494bf3400 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1,609 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Phonemize Text and EnCodec Audio. + +Usage example: + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + +""" +import argparse +import logging +import os +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import torch +import torch.multiprocessing +from encodec import EncodecModel +from encodec.utils import convert_audio +from lhotse import CutSet, NumpyHdf5Writer +from lhotse.features import FeatureExtractor +from lhotse.recipes.utils import read_manifests_if_cached +from lhotse.utils import Seconds, compute_num_frames +from phonemizer.backend import EspeakBackend +from phonemizer.backend.espeak.language_switch import LanguageSwitch +from phonemizer.backend.espeak.words_mismatch import WordMismatch +from phonemizer.punctuation import Punctuation +from phonemizer.separator import Separator +from tqdm.auto import tqdm + +from icefall.utils import get_executor + +try: + from pypinyin import Style, pinyin + from pypinyin.style._utils import get_finals, get_initials +except Exception: + pass + + +import re +from typing import Pattern + +import numpy as np +from k2 import SymbolTable + +# from valle.data import ( +# AudioTokenConfig, +# AudioTokenExtractor, +# TextTokenizer, +# tokenize_text, +# ) +# from valle.data.fbank import get_fbank_extractor +# from valle.utils import SymbolTable + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized files", + ) + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + parser.add_argument( + "--audio-extractor", + type=str, + default="Encodec", + help="Encodec or Fbank", + ) + parser.add_argument( + "--dataset-parts", + type=str, + default="dev-clean test-clean", + help="Space separated dataset parts", + ) + parser.add_argument( + "--prefix", + type=str, + default="libritts", + help="prefix of the manifest file", + ) + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + parser.add_argument( + "--batch-duration", + type=float, + default=400.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + parser.add_argument( + "--split", + type=int, + default=1, + help="Split the cut_set into multiple parts", + ) + + return parser.parse_args() + + +class PypinyinBackend: + """PypinyinBackend for Chinese. Most codes is referenced from espnet. + There are two types pinyin or initials_finals, one is + just like "ni1 hao3", the other is like "n i1 h ao3". + """ + + def __init__( + self, + backend="initials_finals", + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + ) -> None: + self.backend = backend + self.punctuation_marks = punctuation_marks + + def phonemize( + self, text: List[str], separator: Separator, strip=True, njobs=1 + ) -> List[str]: + assert isinstance(text, List) + phonemized = [] + for _text in text: + _text = re.sub(" +", " ", _text.strip()) + _text = _text.replace(" ", separator.word) + phones = [] + if self.backend == "pypinyin": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + + phones.extend(list(py[0])) + else: + phones.extend([py[0], separator.syllable]) + elif self.backend == "pypinyin_initials_finals": + for n, py in enumerate( + pinyin(_text, style=Style.TONE3, neutral_tone_with_five=True) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + phones.extend(list(py[0])) + else: + if py[0][-1].isalnum(): + initial = get_initials(py[0], strict=False) + if py[0][-1].isdigit(): + final = get_finals(py[0][:-1], strict=False) + py[0][-1] + else: + final = get_finals(py[0], strict=False) + phones.extend( + [ + initial, + separator.phone, + final, + separator.syllable, + ] + ) + else: + assert ValueError + else: + raise NotImplementedError + phonemized.append( + "".join(phones).rstrip(f"{separator.word}{separator.syllable}") + ) + return phonemized + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + if backend == "espeak": + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + elif backend in ["pypinyin", "pypinyin_initials_finals"]: + phonemizer = PypinyinBackend( + backend=backend, + punctuation_marks=punctuation_marks + separator.word, + ) + else: + raise NotImplementedError(f"{backend}") + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. + pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: + phonemes = tokenizer([text.strip()]) + return phonemes[0] # k2symbols + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +class AudioTokenizer: + """EnCodec audio.""" + + def __init__( + self, + device: Any = None, + ) -> None: + # Instantiate a pretrained EnCodec model + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + return self.codec.decode(frames) + + +@dataclass +class AudioTokenConfig: + frame_shift: Seconds = 320.0 / 24000 + num_quantizers: int = 8 + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": + return AudioTokenConfig(**data) + + +class AudioTokenExtractor(FeatureExtractor): + name = "encodec" + config_type = AudioTokenConfig + + def __init__(self, config: Optional[Any] = None): + super(AudioTokenExtractor, self).__init__(config) + self.tokenizer = AudioTokenizer() + + def extract( + self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int + ) -> np.ndarray: + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if sampling_rate != self.tokenizer.sample_rate: + samples = convert_audio( + samples, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + if len(samples.shape) == 2: + samples = samples.unsqueeze(0) + else: + raise ValueError() + + device = self.tokenizer.device + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + codes = encoded_frames[0][0] # [B, n_q, T] + if True: + duration = round(samples.shape[-1] / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + assert abs(codes.shape[-1] - expected_num_frames) <= 1 + codes = codes[..., :expected_num_frames] + return codes.cpu().squeeze(0).permute(1, 0).numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.frame_shift + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.num_quantizers + + def pad_tensor_list(self, tensor_list, device, padding_value=0): + lengths = [tensor.shape[0] for tensor in tensor_list] + tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] + padded_tensor = torch.nn.utils.rnn.pad_sequence( + tensor_list, batch_first=True, padding_value=padding_value + ) + return padded_tensor, lengths + + def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: + samples = [wav.squeeze() for wav in samples] + device = self.tokenizer.device + samples, lengths = self.pad_tensor_list(samples, device) + samples = samples.unsqueeze(1) + + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if len(samples.shape) != 3: + raise ValueError() + if sampling_rate != self.tokenizer.sample_rate: + samples = [ + convert_audio( + wav, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + for wav in samples + ] + samples = torch.stack(samples, 0) # convert samples from list to tensor + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + encoded_frames = encoded_frames[0][0] # [B, n_q, T] + batch_codes = [] + for b, length in enumerate(lengths): + codes = encoded_frames[b] + duration = round(length / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + batch_codes.append(codes[..., :expected_num_frames]) + return [codes.cpu().permute(1, 0).numpy() for codes in batch_codes] + + +def main(): + args = get_args() + + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() + if dataset_parts == "all": # LibriTTS + dataset_parts = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ] + else: + dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") + + assert len(dataset_parts) >= 1 + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + types=["recordings", "supervisions", "cuts"], + ) + + text_tokenizer = None + if args.text_extractor: + text_tokenizer = TextTokenizer(backend=args.text_extractor) + + audio_extractor = None + if args.audio_extractor: + if args.audio_extractor == "Encodec": + audio_extractor = AudioTokenExtractor(AudioTokenConfig()) + else: + raise NotImplementedError(f"{args.audio_extractor}") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + unique_symbols = set() + num_jobs = min(32, os.cpu_count()) + logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") + + prefix = args.prefix + if prefix and not prefix.endswith("_"): + prefix = f"{prefix}_" + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + # Split cut_set if split > 1 + split = 1 + if args.split > 1: + cut_sets = cut_set.split(args.split) + split = args.split + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + if args.audio_extractor: + if args.audio_extractor == "Encodec": + storage_path = f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}" + else: + storage_path = f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}" + + if args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + ]: + part = part.resample(24000) + assert args.prefix.lower() in [ + "ljspeech", + "aishell", + "baker", + "wenetspeech4tts", + "libritts", + "libritts-r", + ] + with torch.no_grad(): + if ( + torch.cuda.is_available() + and args.audio_extractor == "Encodec" + ): + part = part.compute_and_store_features_batch( + extractor=audio_extractor, + storage_path=storage_path, + num_workers=num_jobs, + batch_duration=args.batch_duration, + collate=False, + overwrite=True, + storage_type=NumpyHdf5Writer, + ) + else: + part = part.compute_and_store_features( + extractor=audio_extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=NumpyHdf5Writer, + ) + + # TextTokenizer + if args.text_extractor: + for c in tqdm(part): + if args.prefix == "ljspeech": + text = c.supervisions[0].custom["normalized_text"] + text = text.replace(""", '"').replace(""", '"') + phonemes = tokenize_text(text_tokenizer, text=text) + elif args.prefix in [ + "aishell", + "aishell2", + "wenetspeech4tts", + "libritts", + "libritts-r", + ]: + phonemes = tokenize_text( + text_tokenizer, text=c.supervisions[0].text + ) + if c.supervisions[0].custom is None: + c.supervisions[0].custom = {} + c.supervisions[0].normalized_text = c.supervisions[0].text + else: + raise NotImplementedError(f"{args.prefix}") + unique_symbols.update(phonemes) + c.tokens = phonemes + assert c.supervisions[ + 0 + ].normalized_text, "normalized_text is None" + + # Save each part with an index if split > 1 + cuts_filename = ( + f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}" + ) + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + if args.text_extractor: + unique_phonemes = SymbolTable() + for s in sorted(list(unique_symbols)): + unique_phonemes.add(s) + logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") + + unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" + unique_phonemes.to_file(unique_phonemes_file) + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py new file mode 100755 index 0000000000..f967dfd2b7 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/local/display_manifest_statistics.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang) +# Copyright 2023 (authors: Feiteng Li) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file displays duration statistics of utterances in the manifests. +You can use the displayed value to choose minimum/maximum duration +to remove short and long utterances during the training. +""" + +import argparse +from pathlib import Path + +from lhotse import load_manifest_lazy + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized manifests.", + ) + return parser.parse_args() + + +def main(): + args = get_args() + manifest_dir = args.manifest_dir or Path("data/tokenized") + for part in ["train", "dev", "test"]: + print(f"## {part}") + cuts = load_manifest_lazy(manifest_dir / f"cuts_{part}.jsonl.gz") + cuts.describe() + print("\n") + + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/prepare.sh b/egs/wenetspeech4tts/TTS/prepare.sh new file mode 100755 index 0000000000..54e140dbb1 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/prepare.sh @@ -0,0 +1,100 @@ +#!/usr/bin/env bash + +set -eou pipefail + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +stage=1 +stop_stage=4 + +dl_dir=$PWD/download + +dataset_parts="Premium" # Basic for all 10k hours data, Premium for about 10% of the data + +text_extractor="pypinyin_initials_finals" # default is espeak for English +audio_extractor="Encodec" # or Fbank +audio_feats_dir=data/tokenized + +. shared/parse_options.sh || exit 1 + + +# All files generated by this script are saved in "data". +# You can safely remove "data" and rerun this script to regenerate it. +mkdir -p data +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then + log "dl_dir: $dl_dir" + log "Stage 0: Download data" + huggingface-cli login + huggingface-cli download --repo-type dataset --local-dir $dl_dir Wenetspeech4TTS/WenetSpeech4TTS + + # Extract the downloaded data: + for folder in Standard Premium Basic; do + for file in "$dl_dir/$folder"/*.tar.gz; do + tar -xzvf "$file" -C "$dl_dir/$folder" + done + done +fi + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + log "Stage 1: Prepare wenetspeech4tts manifest" + # We assume that you have downloaded the wenetspeech4tts corpus + # to $dl_dir/wenetspeech4tts + mkdir -p data/manifests + if [ ! -e data/manifests/.wenetspeech4tts.done ]; then + lhotse prepare wenetspeech4tts $dl_dir data/manifests --dataset-parts "${dataset_parts}" + touch data/manifests/.wenetspeech4tts.done + fi +fi + + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + log "Stage 2: Tokenize/Fbank wenetspeech4tts" + mkdir -p ${audio_feats_dir} + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.tokenize.done ]; then + python3 ./local/compute_neural_codec_and_prepare_text_tokens.py --dataset-parts "${dataset_parts}" \ + --text-extractor ${text_extractor} \ + --audio-extractor ${audio_extractor} \ + --batch-duration 2500 --prefix "wenetspeech4tts" \ + --src-dir "data/manifests" \ + --split 100 \ + --output-dir "${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100" + cp ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100/unique_text_tokens.k2symbols ${audio_feats_dir} + fi + touch ${audio_feats_dir}/.wenetspeech4tts.tokenize.done +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + log "Stage 3: Combine features" + if [ ! -f ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz ]; then + pieces=$(find ${audio_feats_dir}/wenetspeech4tts_${dataset_parts}_split_100 -name "*.jsonl.gz") + lhotse combine $pieces ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz + fi +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + log "Stage 4: Prepare wenetspeech4tts train/dev/test" + if [ ! -e ${audio_feats_dir}/.wenetspeech4tts.train.done ]; then + + lhotse subset --first 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_dev.jsonl.gz + + lhotse subset --last 400 \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_test.jsonl.gz + + lhotse copy \ + ${audio_feats_dir}/wenetspeech4tts_cuts_${dataset_parts}.jsonl.gz \ + ${audio_feats_dir}/cuts_train.jsonl.gz + + touch ${audio_feats_dir}/.wenetspeech4tts.train.done + fi + python3 ./local/display_manifest_statistics.py --manifest-dir ${audio_feats_dir} +fi diff --git a/egs/wenetspeech4tts/TTS/shared b/egs/wenetspeech4tts/TTS/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py new file mode 120000 index 0000000000..e70ee319a1 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1 @@ +../local/compute_neural_codec_and_prepare_text_tokens.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/infer.py b/egs/wenetspeech4tts/TTS/valle/infer.py new file mode 100644 index 0000000000..fd7ba9f216 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/infer.py @@ -0,0 +1,300 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script is used to synthesize speech from text prompts and audio prompts. +Usage example: + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ + --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-prompts "KNOT one point one five miles per hour." \ + --audio-prompts ./prompts/8463_294825_000043_000000.wav \ + --text "To get up and running quickly just follow the steps below." + + top_p=1.0 + python3 valle/infer.py --output-dir demos_epoch_${epoch}_avg_${avg}_top_p_${top_p} \ + --top-k -1 --temperature 1.0 \ + --text ./aishell3.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-extractor pypinyin_initials_finals --top-p ${top_p} + +""" +import argparse +import logging +import os +from pathlib import Path + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +import torch +import torchaudio +from compute_neural_codec_and_prepare_text_tokens import ( + AudioTokenizer, + TextTokenizer, + tokenize_text, +) +from encodec.utils import convert_audio +from k2 import symbol_table +from tokenizer import get_text_token_collater +from valle import VALLE + +from icefall.utils import AttributeDict, str2bool + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--text-prompts", + type=str, + default="", + help="Text prompts which are separated by |.", + ) + + parser.add_argument( + "--audio-prompts", + type=str, + default="", + help="Audio prompts which are separated by | and should be aligned with --text-prompts.", + ) + + parser.add_argument( + "--text", + type=str, + default="", + help="prompt text\t prompt audio\ttarget text\ttarget audio", + ) + + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + + parser.add_argument( + "--checkpoint", + type=str, + default="exp/vallf_nano_full/checkpoint-100000.pt", + help="Path to the saved checkpoint.", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("infer/demo"), + help="Path to the tokenized files.", + ) + + parser.add_argument( + "--top-k", + type=int, + default=-100, + help="Whether AR Decoder do top_k(if > 0) sampling.", + ) + + parser.add_argument( + "--top-p", + type=float, + default=1.0, + help="Whether AR Decoder do top_p(if > 0) sampling.", + ) + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="The temperature of AR Decoder top_k sampling.", + ) + + parser.add_argument( + "--continual", + type=str2bool, + default=False, + help="Do continual task.", + ) + + parser.add_argument( + "--repetition-aware-sampling", + type=str2bool, + default=False, + help="Whether AR Decoder do valle-2 repetition-aware sampling. https://arxiv.org/pdf/2406.05370", + ) + + return parser.parse_args() + + +def load_model(checkpoint, device): + if not checkpoint: + return None + + checkpoint = torch.load(checkpoint, map_location=device) + + params = AttributeDict(checkpoint) + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model"], strict=True + ) + assert not missing_keys + model.to(device) + model.eval() + + return model, params.text_tokens + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): + # Load and pre-process the audio waveform + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames + + +@torch.no_grad() +def main(): + args = get_args() + text_tokenizer = TextTokenizer(backend=args.text_extractor) + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + + model, text_tokens = load_model(args.checkpoint, device) + + text_collater = get_text_token_collater(text_tokens) + + audio_tokenizer = AudioTokenizer() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + text_prompts = " ".join(args.text_prompts.split("|")) + + audio_prompts = [] + if args.audio_prompts: + for n, audio_file in enumerate(args.audio_prompts.split("|")): + encoded_frames = tokenize_audio(audio_tokenizer, audio_file) + if False: + samples = audio_tokenizer.decode(encoded_frames) + torchaudio.save(f"{args.output_dir}/p{n}.wav", samples[0], 24000) + + audio_prompts.append(encoded_frames[0][0]) + + assert len(args.text_prompts.split("|")) == len(audio_prompts) + audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) + audio_prompts = audio_prompts.to(device) + + if os.path.isfile(args.text): # for demos + # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py + with open(args.text) as f: + for line in f: + fields = line.strip().split(" ") + fields = [item for item in fields if item] + assert len(fields) == 4 + prompt_text, prompt_audio, text, audio_path = fields + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{prompt_text} {text}".strip() + ) + ] + ) + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{prompt_text}".strip())] + ) + + audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) + audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) + + # synthesis + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + samples = audio_tokenizer.decode( + [(encoded_frames.transpose(2, 1), None)] + ) + # store + # save audio path into args.output_dir + audio_path + audio_path = f"{args.output_dir}/{audio_path}" + # mkdir -p + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + torchaudio.save(audio_path, samples[0].cpu(), 24000) + return + + for n, text in enumerate(args.text.split("|")): + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts} {text}".strip())] + ) + + # synthesis + if args.continual: + assert text == "" + encoded_frames = model.continual( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + ) + else: + enroll_x_lens = None + if text_prompts: + _, enroll_x_lens = text_collater( + [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())] + ) + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ras=args.repetition_aware_sampling, + ) + + if audio_prompts != []: + samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)]) + # store + torchaudio.save(f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000) + else: # Transformer + pass + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/wenetspeech4tts/TTS/valle/optim.py b/egs/wenetspeech4tts/TTS/valle/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/wenetspeech4tts/TTS/valle/tokenizer.py b/egs/wenetspeech4tts/TTS/valle/tokenizer.py new file mode 100644 index 0000000000..db4f003964 --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tokenizer.py @@ -0,0 +1,111 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch +from k2 import SymbolTable + + +class TextTokenCollater: + """Collate list of text tokens + + Map sentences to integers. Sentences are padded to equal length. + Beginning and end-of-sequence symbols can be added. + + Example: + >>> token_collater = TextTokenCollater(text_tokens) + >>> tokens_batch, tokens_lens = token_collater(text) + + Returns: + tokens_batch: IntTensor of shape (B, L) + B: batch dimension, number of input sentences + L: length of the longest sentence + tokens_lens: IntTensor of shape (B,) + Length of each sentence after adding and + but before padding. + """ + + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + + self.add_eos = add_eos + self.add_bos = add_bos + + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = ( + [pad_symbol] + + ([bos_symbol] if add_bos else []) + + ([eos_symbol] if add_eos else []) + + sorted(text_tokens) + ) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = [token for token in unique_tokens] + + def index(self, tokens_list: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + seqs, seq_lens = [], [] + for tokens in tokens_list: + assert all([True if s in self.token2idx else False for s in tokens]) is True + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + list(tokens) + + ([self.eos_symbol] if self.add_eos else []) + ) + seqs.append(seq) + seq_lens.append(len(seq)) + + max_len = max(seq_lens) + for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): + seq.extend([self.pad_symbol] * (max_len - seq_len)) + + tokens = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + tokens_lens = torch.IntTensor(seq_lens) + + return tokens, tokens_lens + + def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_seqs = [[p for p in text] for text in texts] + max_len = len(max(tokens_seqs, key=len)) + + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + list(seq) + + ([self.eos_symbol] if self.add_eos else []) + + [self.pad_symbol] * (max_len - len(seq)) + for seq in tokens_seqs + ] + + tokens_batch = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + + tokens_lens = torch.IntTensor( + [len(seq) + int(self.add_eos) + int(self.add_bos) for seq in tokens_seqs] + ) + + return tokens_batch, tokens_lens + + +def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + collater = TextTokenCollater(unique_tokens.symbols, add_bos=True, add_eos=True) + return collater diff --git a/egs/wenetspeech4tts/TTS/valle/train.py b/egs/wenetspeech4tts/TTS/valle/train.py new file mode 100755 index 0000000000..fde209511c --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/train.py @@ -0,0 +1,1244 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Usage: +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 valle/train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 valle/train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +""" + +import argparse +import copy +import logging +import os +import random +import warnings +from contextlib import nullcontext +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from optim import Eden, ScaledAdam +from tokenizer import TextTokenCollater, get_text_token_collater +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter +from tts_datamodule import TtsDataModule +from valle import VALLE + +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, MetricsTracker, setup_logger, str2bool + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=12, + help="Number of Decoder layers.", + ) + parser.add_argument( + "--scale-factor", + type=float, + default=1.0, + help="Model scale factor which will be assigned different meanings in different models.", + ) + parser.add_argument( + "--norm-first", + type=str2bool, + default=True, + help="Pre or Post Normalization.", + ) + parser.add_argument( + "--add-prenet", + type=str2bool, + default=False, + help="Whether add PreNet after Inputs.", + ) + + parser.add_argument( + "--prefix-mode", + type=int, + default=0, + help="The mode for how to prefix VALL-E NAR Decoder, " + "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", + ) + parser.add_argument( + "--share-embedding", + type=str2bool, + default=True, + help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", + ) + parser.add_argument( + "--prepend-bos", + type=str2bool, + default=False, + help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", + ) + parser.add_argument( + "--num-quantizers", + type=int, + default=8, + help="Number of Audio/Semantic quantization layers.", + ) + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="exp/valle_dev", + help="""The experiment dir. + It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--text-tokens", + type=str, + default="data/tokenized/unique_text_tokens.k2symbols", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="ScaledAdam", + help="The optimizer.", + ) + parser.add_argument( + "--scheduler-name", + type=str, + default="Eden", + help="The scheduler.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. + """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--train-stage", + type=int, + default=0, + help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s) + """, + ) + + parser.add_argument( + "--visualize", + type=str2bool, + default=False, + help="visualize model results in eval step.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. + + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + saved_stage = saved_params.get("train_stage", 0) + if params.train_stage != saved_stage: + # switch training stage + if params.train_stage and saved_stage: # switch between 1 and 2 + params.start_epoch = 1 + params.start_batch = 0 + else: + # switch between 0 and 1/2 + assert params.num_epochs >= params.start_epoch + params.batch_idx_train = saved_params["batch_idx_train"] + + for key in ["optimizer", "grad_scaler", "sampler"]: + if key in saved_params: + saved_params.pop(key) + + # when base on stage 0, we keep scheduler + if saved_stage != 0: + for key in ["scheduler"]: + if key in saved_params: + saved_params.pop(key) + + best_train_filename = params.exp_dir / "best-train-loss.pt" + if best_train_filename.is_file(): + copyfile( + src=best_train_filename, + dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt", + ) + + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + if best_valid_filename.is_file(): + copyfile( + src=best_valid_filename, + dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt", + ) + else: + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if "cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + + +def prepare_input(batch: dict, tokenizer: TextTokenCollater, device: torch.device): + """Parse batch data""" + + features = batch["features"].to(device) + features_lens = batch["features_lens"].to(device) + if "tokens" not in batch: + raise NotImplementedError("Need to tokenize text") + # tokens = [] + # for c in batch["cuts"]: + # phonemes = tokenize_text( + # tokenizer, text=c.supervisions[0].text + # ) + # tokens.append(phonemes) + else: + tokens = batch["tokens"] + + text_tokens, text_tokens_lens = tokenizer(tokens) + text_tokens = text_tokens.to(device) + text_tokens_lens = text_tokens_lens.to(device) + + return features, features_lens, text_tokens, text_tokens_lens + + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = model.device if isinstance(model, DDP) else next(model.parameters()).device + ( + audio_features, + audio_features_lens, + text_tokens, + text_tokens_lens, + ) = prepare_input(batch, tokenizer, device) + # at entry, TextTokens is (N, P) + assert text_tokens.ndim == 2 + assert audio_features.ndim == 3 + + with torch.set_grad_enabled(is_training): + predicts, loss, metrics = model( + x=text_tokens, + x_lens=text_tokens_lens, + y=audio_features, + y_lens=audio_features_lens, + train_stage=params.train_stage, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (audio_features_lens).sum().item() + info["utterances"] = text_tokens.size(0) + + # Note: We use reduction=sum while computing the loss. + info["loss"] = loss.detach().cpu().item() + for metric in metrics: + info[metric] = metrics[metric].detach().cpu().item() + del metrics + + return predicts, loss, info + + +def compute_validation_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + valid_dl: torch.utils.data.DataLoader, + world_size: int = 1, +) -> MetricsTracker: + """Run the validation process.""" + tot_loss = MetricsTracker() + + for batch_idx, batch in enumerate(valid_dl): + predicts, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=False, + ) + assert loss.requires_grad is False + tot_loss = tot_loss + loss_info + if world_size > 1: + tot_loss.reduce(loss.device) + loss_value = tot_loss["loss"] / tot_loss["frames"] + if loss_value < params.best_valid_loss: + params.best_valid_epoch = params.cur_epoch + params.best_valid_loss = loss_value + + if params.visualize: + output_dir = Path(f"{params.exp_dir}/eval/step-{params.batch_idx_train:06d}") + output_dir.mkdir(parents=True, exist_ok=True) + if isinstance(model, DDP): + model.module.visualize(predicts, batch, output_dir=output_dir) + else: + model.visualize(predicts, batch, output_dir=output_dir) + + return tot_loss + + +def train_one_epoch( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + optimizer: torch.optim.Optimizer, + scheduler: LRSchedulerType, + train_dl: torch.utils.data.DataLoader, + valid_dl: torch.utils.data.DataLoader, + rng: random.Random, + scaler: GradScaler, + model_avg: Optional[nn.Module] = None, + tb_writer: Optional[SummaryWriter] = None, + world_size: int = 1, + rank: int = 0, +) -> None: + """Train the model for one epoch. + + The training loss from the mean of all frames is saved in + `params.train_loss`. It runs the validation process every + `params.valid_interval` batches. + + Args: + params: + It is returned by :func:`get_params`. + model: + The model for training. + optimizer: + The optimizer we are using. + scheduler: + The learning rate scheduler, we call step() every step. + train_dl: + Dataloader for the training dataset. + valid_dl: + Dataloader for the validation dataset. + rng: + Random for selecting. + scaler: + The scaler used for mix precision training. + model_avg: + The stored model averaged from the start of training. + tb_writer: + Writer to write log messages to tensorboard. + world_size: + Number of nodes in DDP training. If it is 1, DDP is disabled. + rank: + The rank of the node in DDP training. If no DDP is used, it should + be set to 0. + """ + model.train() + tot_loss = MetricsTracker() + iter_dl = iter(train_dl) + + dtype, enabled = torch.float32, False + if params.dtype in ["bfloat16", "bf16"]: + dtype, enabled = torch.bfloat16, True + elif params.dtype in ["float16", "fp16"]: + dtype, enabled = torch.float16, True + + batch_idx = 0 + while True: + try: + batch = next(iter_dl) + except StopIteration: + logging.info("Reaches end of dataloader.") + break + + batch_idx += 1 + + params.batch_idx_train += 1 + batch_size = len(batch["text"]) + + try: + with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled): + _, loss, loss_info = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + # summary stats + tot_loss = (tot_loss * (1 - 1 / params.reset_interval)) + loss_info * ( + 1 / params.reset_interval + ) + + # NOTE: We use reduction==sum and loss is computed over utterances + # in the batch and there is no normalization to it so far. + + scaler.scale(loss).backward() + if params.batch_idx_train >= params.accumulate_grad_steps: + if params.batch_idx_train % params.accumulate_grad_steps == 0: + if params.optimizer_name not in ["ScaledAdam", "Eve"]: + # Unscales the gradients of optimizer's assigned params in-place + scaler.unscale_(optimizer) + # Since the gradients of optimizer's assigned params are unscaled, clips as usual: + torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + + for k in range(params.accumulate_grad_steps): + if isinstance(scheduler, Eden): + scheduler.step_batch(params.batch_idx_train) + else: + scheduler.step() + + set_batch_count(model, params.batch_idx_train) + except: # noqa + display_and_save_batch(batch, params=params) + raise + + if params.average_period > 0: + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.average_period == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + update_averaged_model( + params=params, + model_cur=model, + model_avg=model_avg, + ) + + if ( + params.batch_idx_train > 0 + and params.batch_idx_train % params.save_every_n == 0 + ): + # Perform Operation in rank 0 + if rank == 0: + save_checkpoint_with_global_batch_idx( + out_dir=params.exp_dir, + global_batch_idx=params.batch_idx_train, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + remove_checkpoints( + out_dir=params.exp_dir, + topk=params.keep_last_k, + rank=rank, + ) + + if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]: + # If the grad scale was less than 1, try increasing it. The _growth_interval + # of the grad scaler is configurable, but we can't configure it to have different + # behavior depending on the current grad scale. + cur_grad_scale = scaler._scale.item() + if cur_grad_scale < 1.0 or (cur_grad_scale < 8.0 and batch_idx % 400 == 0): + scaler.update(cur_grad_scale * 2.0) + + if cur_grad_scale < 0.01: + logging.warning(f"Grad scale is small: {cur_grad_scale}") + if cur_grad_scale < 1.0e-05: + raise RuntimeError( + f"grad_scale is too small, exiting: {cur_grad_scale}" + ) + + if batch_idx % params.log_interval == 0: + cur_lr = scheduler.get_last_lr()[0] + cur_grad_scale = ( + scaler._scale.item() if params.dtype in ["float16", "fp16"] else 1.0 + ) + + logging.info( + f"Epoch {params.cur_epoch}, " + f"batch {batch_idx}, train_loss[{loss_info}], " + f"tot_loss[{tot_loss}], " + f"batch size: {batch_size}, " + f"lr: {cur_lr:.2e}" + + ( + f", grad_scale: {cur_grad_scale}" + if params.dtype in ["float16", "fp16"] + else "" + ) + ) + + if tb_writer is not None: + tb_writer.add_scalar( + "train/learning_rate", cur_lr, params.batch_idx_train + ) + loss_info.write_summary( + tb_writer, + "train/current_", + params.batch_idx_train, + ) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train) + if params.dtype in ["float16", "fp16"]: + tb_writer.add_scalar( + "train/grad_scale", + cur_grad_scale, + params.batch_idx_train, + ) + + if params.batch_idx_train % params.valid_interval == 0: + # Calculate validation loss in Rank 0 + model.eval() + logging.info("Computing validation loss") + with torch.cuda.amp.autocast(dtype=dtype): + valid_info = compute_validation_loss( + params=params, + model=model, + tokenizer=tokenizer, + valid_dl=valid_dl, + world_size=world_size, + ) + logging.info(f"Epoch {params.cur_epoch}, validation: {valid_info}") + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + if tb_writer is not None: + valid_info.write_summary( + tb_writer, "train/valid_", params.batch_idx_train + ) + + model.train() + + loss_value = tot_loss["loss"] / tot_loss["frames"] + params.train_loss = loss_value + if params.train_loss < params.best_train_loss: + params.best_train_epoch = params.cur_epoch + params.best_train_loss = params.train_loss + + +def filter_short_and_long_utterances( + cuts: CutSet, min_duration: float, max_duration: float +) -> CutSet: + def remove_short_and_long_utt(c: Cut): + # Keep only utterances with duration between 0.6 second and 20 seconds + if c.duration < min_duration or c.duration > max_duration: + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) + return False + return True + + cuts = cuts.filter(remove_short_and_long_utt) + + return cuts + + +def run(rank, world_size, args): + """ + Args: + rank: + It is a value between 0 and `world_size-1`, which is + passed automatically by `mp.spawn()` in :func:`main`. + The node with rank 0 is responsible for saving checkpoint. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + fix_random_seed(params.seed) + rng = random.Random(params.seed) + if world_size > 1: + setup_dist(rank, world_size, params.master_port) + + setup_logger(f"{params.exp_dir}/log/log-train") + logging.info("Training started") + + if args.tensorboard and rank == 0: + if params.train_stage: + tb_writer = SummaryWriter( + log_dir=f"{params.exp_dir}/tensorboard_stage{params.train_stage}" + ) + else: + tb_writer = SummaryWriter(log_dir=f"{params.exp_dir}/tensorboard") + else: + tb_writer = None + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + + logging.info(f"Device: {device}") + + tokenizer = get_text_token_collater(params.text_tokens) + logging.info(params) + + logging.info("About to create model") + + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.train_stage: + _model = model.module if isinstance(model, DDP) else model + model_parameters = _model.stage_parameters(params.train_stage) + else: + model_parameters = model.parameters() + + if params.optimizer_name == "ScaledAdam": + optimizer = ScaledAdam( + model_parameters, + lr=params.base_lr, + clipping_scale=2.0, + ) + elif params.optimizer_name == "AdamW": + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + elif params.optimizer_name == "Adam": + optimizer = torch.optim.Adam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + eps=1e-8, + ) + else: + raise NotImplementedError() + + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.dev_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + ) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.dev_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler(enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + if isinstance(scheduler, Eden): + scheduler.step_epoch(epoch - 1) + + fix_random_seed(params.seed + epoch - 1) + train_dl.sampler.set_epoch(epoch - 1) + + if tb_writer is not None: + tb_writer.add_scalar("train/epoch", epoch, params.batch_idx_train) + + params.cur_epoch = epoch + + train_one_epoch( + params=params, + model=model, + tokenizer=tokenizer, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + train_dl=train_dl, + valid_dl=valid_dl, + rng=rng, + scaler=scaler, + tb_writer=tb_writer, + world_size=world_size, + rank=rank, + ) + + save_checkpoint( + params=params, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + sampler=train_dl.sampler, + scaler=scaler, + rank=rank, + ) + + logging.info("Done!") + + if world_size > 1: + torch.distributed.barrier() + cleanup_dist() + + +def display_and_save_batch( + batch: dict, + params: AttributeDict, +) -> None: + """Display the batch statistics and save the batch into disk. + + Args: + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + params: + Parameters for training. See :func:`get_params`. + """ + from lhotse.utils import uuid4 + + filename = f"{params.exp_dir}/batch-{uuid4()}.pt" + logging.info(f"Saving batch to {filename}") + torch.save(batch, filename) + + +def scan_pessimistic_batches_for_oom( + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + train_dl: torch.utils.data.DataLoader, + optimizer: torch.optim.Optimizer, + params: AttributeDict, +): + from lhotse.dataset import find_pessimistic_batches + + logging.info( + "Sanity check -- see if any of the batches in epoch 1 would cause OOM." + ) + batches, crit_values = find_pessimistic_batches(train_dl.sampler) + + dtype = torch.float32 + if params.dtype in ["bfloat16", "bf16"]: + dtype = torch.bfloat16 + elif params.dtype in ["float16", "fp16"]: + dtype = torch.float16 + + for criterion, cuts in batches.items(): + batch = train_dl.dataset[cuts] + try: + with torch.cuda.amp.autocast(dtype=dtype): + _, loss, _ = compute_loss( + params=params, + model=model, + tokenizer=tokenizer, + batch=batch, + is_training=True, + ) + loss.backward() + optimizer.zero_grad() + except Exception as e: + if "CUDA out of memory" in str(e): + logging.error( + "Your GPU ran out of memory with the current " + "max_duration setting. We recommend decreasing " + "max_duration and trying again.\n" + f"Failing criterion: {criterion} " + f"(={crit_values[criterion]}) ..." + ) + display_and_save_batch(batch, params=params) + raise + logging.info( + f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB" + ) + + +def main(): + parser = get_parser() + TtsDataModule.add_arguments(parser) + args = parser.parse_args() + args.exp_dir = Path(args.exp_dir) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + mp.spawn(run, args=(world_size, args), nprocs=world_size, join=True) + else: + run(rank=0, world_size=1, args=args) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py new file mode 100644 index 0000000000..8e34d06dcb --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/tts_datamodule.py @@ -0,0 +1,343 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022-2024 Xiaomi Corporation (Authors: Mingshuang Luo, +# Zengwei Yao, +# Zengrui Jin,) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (Author: Yuekai Zhang) +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Any, Dict, Optional + +import torch +from lhotse import CutSet, Spectrogram, SpectrogramConfig, load_manifest_lazy +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + DynamicBucketingSampler, + PrecomputedFeatures, + SimpleCutSampler, + SpeechSynthesisDataset, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + OnTheFlyFeatures, +) +from lhotse.features.io import KaldiReader +from lhotse.utils import fix_random_seed +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class _SeedWorkers: + def __init__(self, seed: int): + self.seed = seed + + def __call__(self, worker_id: int): + fix_random_seed(self.seed + worker_id) + + +class TtsDataModule: + """ + DataModule for tts experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in TTS + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="TTS data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--speaker-embeds", + type=Path, + default=Path("exp/xvector_nnet_1a/"), + help="Path to directory with speaker embeddings.", + ) + group.add_argument( + "--max-duration", + type=int, + default=200.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--num-buckets", + type=int, + default=30, + help="The number of buckets for the DynamicBucketingSampler" + "(you might want to increase it for larger datasets).", + ) + + group.add_argument( + "--on-the-fly-feats", + type=str2bool, + default=False, + help="When enabled, use on-the-fly cut mixing and feature " + "extraction. Will drop existing precomputed feature manifests " + "if available.", + ) + group.add_argument( + "--shuffle", + type=str2bool, + default=True, + help="When enabled (=default), the examples will be " + "shuffled for each epoch.", + ) + group.add_argument( + "--drop-last", + type=str2bool, + default=True, + help="Whether to drop last batch. Used by sampler.", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=4, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--enable-spec-aug", + type=str2bool, + default=False, + help="When enabled, use SpecAugment for training dataset.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + group.add_argument( + "--dataset", + type=str, + default="libritts", + help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.", + ) + + parser.add_argument( + "--sampling-rate", + type=int, + default=24000, + help="""Audio sampling rate.""", + ) + + def train_dataloaders( + self, + cuts_train: CutSet, + sampler_state_dict: Optional[Dict[str, Any]] = None, + ) -> DataLoader: + """ + Args: + cuts_train: + CutSet for training. + sampler_state_dict: + The state dict for the training sampler. + """ + logging.info("About to create train dataset") + train = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + + if self.args.on_the_fly_feats: + raise NotImplementedError + + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + train_sampler = DynamicBucketingSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + num_buckets=self.args.num_buckets, + buffer_size=self.args.num_buckets * 2000, + shuffle_buffer_size=self.args.num_buckets * 5000, + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + ) + logging.info("About to create train dataloader") + + if sampler_state_dict is not None: + logging.info("Loading sampler state dict") + train_sampler.load_state_dict(sampler_state_dict) + + # 'seed' is derived from the current random state, which will have + # previously been set in the main process. + seed = torch.randint(0, 100000, ()).item() + worker_init_fn = _SeedWorkers(seed) + + train_dl = DataLoader( + train, + sampler=train_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + worker_init_fn=worker_init_fn, + ) + + return train_dl + + def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader: + logging.info("About to create dev dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + validate = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + dev_sampler = DynamicBucketingSampler( + cuts_valid, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create valid dataloader") + dev_dl = DataLoader( + validate, + sampler=dev_sampler, + batch_size=None, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + + return dev_dl + + def test_dataloaders(self, cuts: CutSet) -> DataLoader: + logging.info("About to create test dataset") + if self.args.on_the_fly_feats: + raise NotImplementedError + else: + test = SpeechSynthesisDataset( + return_text=True, + return_tokens=True, + return_spk_ids=False, + feature_input_strategy=eval(self.args.input_strategy)(), + return_cuts=self.args.return_cuts, + ) + test_sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + ) + logging.info("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=test_sampler, + num_workers=self.args.num_workers, + ) + return test_dl + + @lru_cache() + def train_cuts(self) -> CutSet: + logging.info("About to get train cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_train.jsonl.gz") + + @lru_cache() + def dev_cuts(self) -> CutSet: + logging.info("About to get dev cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz") + + @lru_cache() + def test_cuts(self) -> CutSet: + logging.info("About to get test cuts") + return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz") + + @lru_cache() + def dev_clean_cuts(self) -> CutSet: + logging.info("About to get dev-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz" + ) + + @lru_cache() + def dev_other_cuts(self) -> CutSet: + logging.info("About to get dev-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz" + ) + + @lru_cache() + def test_clean_cuts(self) -> CutSet: + logging.info("About to get test-clean cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz" + ) + + @lru_cache() + def test_other_cuts(self) -> CutSet: + logging.info("About to get test-other cuts") + return load_manifest_lazy( + self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz" + ) diff --git a/egs/wenetspeech4tts/TTS/valle/valle.py b/egs/wenetspeech4tts/TTS/valle/valle.py new file mode 100644 index 0000000000..b2eb8ae69d --- /dev/null +++ b/egs/wenetspeech4tts/TTS/valle/valle.py @@ -0,0 +1,1745 @@ +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy +import math +import numbers +import random +from functools import partial +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor +from torch.nn import Linear, Module +from torch.nn import functional as F +from torch.nn.init import constant_, xavier_normal_, xavier_uniform_ +from torch.nn.modules.linear import NonDynamicallyQuantizableLinear +from torch.nn.parameter import Parameter +from torchmetrics.classification import MulticlassAccuracy + +from icefall.utils import make_pad_mask + +NUM_TEXT_TOKENS = 5000 +NUM_AUDIO_TOKENS = 1024 # EnCodec RVQ bins + + +class PromptedFeatures: + def __init__(self, prompts, features): + self.prompts = prompts + self.features = features + + def to(self, device): + return PromptedFeatures(self.prompts.to(device), self.features.to(device)) + + def sum(self): + return self.features.sum() + + @property + def ndim(self): + return self.features.ndim + + @property + def data(self): + return (self.prompts, self.features) + + +class TokenEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + vocab_size: int, + dropout: float = 0.0, + ): + super().__init__() + + self.vocab_size = vocab_size + self.dim_model = dim_model + + self.dropout = torch.nn.Dropout(p=dropout) + self.word_embeddings = nn.Embedding(self.vocab_size, self.dim_model) + + @property + def weight(self) -> torch.Tensor: + return self.word_embeddings.weight + + def embedding(self, index: int) -> torch.Tensor: + return self.word_embeddings.weight[index : index + 1] + + def forward(self, x: torch.Tensor): + X = self.word_embeddings(x) + X = self.dropout(X) + + return X + + +class SinePositionalEmbedding(nn.Module): + def __init__( + self, + dim_model: int, + dropout: float = 0.0, + scale: bool = False, + alpha: bool = False, + ): + super().__init__() + self.dim_model = dim_model + self.x_scale = math.sqrt(dim_model) if scale else 1.0 + self.alpha = nn.Parameter(torch.ones(1), requires_grad=alpha) + self.dropout = torch.nn.Dropout(p=dropout) + + self.reverse = False + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, 4000)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.dim_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.dim_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.dim_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype).detach() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + self.extend_pe(x) + output = x.unsqueeze(-1) if x.ndim == 2 else x + output = output * self.x_scale + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(output) + + +class Transpose(nn.Identity): + """(N, T, D) -> (N, D, T)""" + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return input.transpose(1, 2) + + +_shape_t = Union[int, List[int], torch.Size] + + +class MultiheadAttention(Module): + r"""Allows the model to jointly attend to information + from different representation subspaces as described in the paper: + `Attention Is All You Need `_. + + Multi-Head Attention is defined as: + + .. math:: + \text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O + + where :math:`head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)`. + + ``forward()`` will use a special optimized implementation if all of the following + conditions are met: + + - self attention is being computed (i.e., ``query``, ``key``, and ``value`` are the same tensor. This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). + + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + self.bias_v = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs)) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. + """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = ( + f"input not batched; expected query.dim() of 3 but got {query.dim()}" + ) + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype: + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. + if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = "some Tensor argument is neither CUDA nor CPU" + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask if key_padding_mask is not None else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. " + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] + eps: float + elementwise_affine: bool + + def __init__( + self, + normalized_shape: _shape_t, + eps: float = 1e-5, + elementwise_affine: bool = True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + # mypy error: incompatible types in assignment + normalized_shape = (normalized_shape,) # type: ignore[assignment] + self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type] + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + self.bias = nn.Parameter( + torch.empty(self.normalized_shape, **factory_kwargs) + ) + else: + self.register_parameter("weight", None) + self.register_parameter("bias", None) + + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor, embedding: Any = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + return ( + F.layer_norm( + input, + self.normalized_shape, + self.weight, + self.bias, + self.eps, + ), + embedding, + ) + + assert embedding is None + return F.layer_norm( + input, self.normalized_shape, self.weight, self.bias, self.eps + ) + + def extra_repr(self) -> str: + return ( + "{normalized_shape}, eps={eps}, " + "elementwise_affine={elementwise_affine}".format(**self.__dict__) + ) + + +class AdaptiveLayerNorm(nn.Module): + r"""Adaptive Layer Normalization""" + + def __init__(self, d_model, norm) -> None: + super(AdaptiveLayerNorm, self).__init__() + self.project_layer = nn.Linear(d_model, 2 * d_model) + self.norm = norm + self.d_model = d_model + self.eps = self.norm.eps + + def forward(self, input: Tensor, embedding: Tensor = None) -> Tensor: + if isinstance(input, tuple): + input, embedding = input + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return (weight * self.norm(input) + bias, embedding) + + weight, bias = torch.split( + self.project_layer(embedding), + split_size_or_sections=self.d_model, + dim=-1, + ) + return weight * self.norm(input) + bias + + +class TransformerEncoderLayer(nn.Module): + __constants__ = ["batch_first", "norm_first"] + + def __init__( + self, + d_model: int, + nhead: int, + dim_feedforward: int = 2048, + dropout: float = 0.1, + activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, + batch_first: bool = False, + norm_first: bool = False, + device=None, + dtype=None, + linear1_self_attention_cls: nn.Module = nn.Linear, + linear2_self_attention_cls: nn.Module = nn.Linear, + linear1_feedforward_cls: nn.Module = nn.Linear, + linear2_feedforward_cls: nn.Module = nn.Linear, + layer_norm_cls: nn.Module = LayerNorm, + layer_norm_eps: float = 1e-5, + adaptive_layer_norm=False, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(TransformerEncoderLayer, self).__init__() + self.self_attn = MultiheadAttention( + d_model, + nhead, + dropout=dropout, + batch_first=batch_first, + linear1_cls=linear1_self_attention_cls, + linear2_cls=linear2_self_attention_cls, + **factory_kwargs, + ) + + # Implementation of Feedforward model + self.linear1 = linear1_feedforward_cls( + d_model, dim_feedforward, **factory_kwargs + ) + self.dropout = nn.Dropout(dropout) + self.linear2 = linear2_feedforward_cls( + dim_feedforward, d_model, **factory_kwargs + ) + + self.norm_first = norm_first + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + # Legacy string support for activation function. + if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + # elif activation == BalancedDoubleSwish: + # activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + # if layer_norm_cls == IdentityNorm: + # norm2 = BalancedBasicNorm(d_model, eps=layer_norm_eps, **factory_kwargs) + # else: + if True: + norm2 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. + """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + ) + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). + + Examples:: + >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) + >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6) + >>> src = torch.rand(10, 32, 512) + >>> out = transformer_encoder(src) + """ + __constants__ = ["norm"] + + def __init__(self, encoder_layer, num_layers, norm=None): + super(TransformerEncoder, self).__init__() + self.layers = _get_clones(encoder_layer, num_layers) + self.num_layers = num_layers + self.norm = norm + + def forward( + self, + src: Tensor, + mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + return_layer_states: bool = False, + ) -> Tensor: + r"""Pass the input through the encoder layers in turn. + + Args: + src: the sequence to the encoder (required). + mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + return_layer_states: return layers' state (optional). + + Shape: + see the docs in Transformer class. + """ + if return_layer_states: + layer_states = [] # layers' output + output = src + for mod in self.layers: + output = mod( + output, + src_mask=mask, + src_key_padding_mask=src_key_padding_mask, + ) + layer_states.append(output[0]) + + if self.norm is not None: + output = self.norm(output) + + return layer_states, output + + output = src + for mod in self.layers: + output = mod( + output, src_mask=mask, src_key_padding_mask=src_key_padding_mask + ) + + if self.norm is not None: + output = self.norm(output) + + return output + + +def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + +def _get_activation_fn(activation: str) -> Callable[[Tensor], Tensor]: + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + + raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) + + +class VALLE(nn.Module): + """It implements https://arxiv.org/abs/2301.02111 + "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers" + """ + + def __init__( + self, + d_model: int, + nhead: int, + num_layers: int, + norm_first: bool = True, + add_prenet: bool = False, + decoder_cls=TransformerEncoder, + decoder_layer_cls=TransformerEncoderLayer, + prefix_mode: int = 0, + share_embedding: bool = True, + nar_scale_factor: float = 1.0, + prepend_bos: bool = False, + num_quantizers: int = 8, + **kwargs, + ): + """ + Args: + d_model: + The number of expected features in the input (required). + nhead: + The number of heads in the multiheadattention models (required). + num_layers: + The number of sub-decoder-layers in the decoder (required). + """ + super().__init__() + nar_d_model = int(d_model * nar_scale_factor) + + self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS) # W_x + self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS) + + # ID NUM_AUDIO_TOKENS -> PAD + # ID NUM_AUDIO_TOKENS + 1 -> BOS + self.ar_audio_prepend_bos = prepend_bos + self.ar_audio_embedding = TokenEmbedding( + d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos) + ) + + # PreNet + if add_prenet: + self.ar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(d_model, d_model), + ) + + self.ar_audio_prenet = nn.Sequential( + nn.Linear(d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, d_model), + ) + else: + self.ar_text_prenet = nn.Identity() + self.ar_audio_prenet = nn.Identity() + + self.ar_text_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + self.ar_audio_position = SinePositionalEmbedding( + d_model, + dropout=0.1, + scale=False, + alpha=True, + ) + + self.ar_decoder = decoder_cls( + decoder_layer_cls( + d_model, + nhead, + dim_feedforward=d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + ), + num_layers=num_layers, + norm=LayerNorm(d_model) if norm_first else None, + ) + self.ar_predict_layer = nn.Linear(d_model, NUM_AUDIO_TOKENS + 1, bias=False) + + self.ar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + self.rng = random.Random(0) + self.num_heads = nhead + self.prefix_mode = prefix_mode + self.num_quantizers = num_quantizers + + assert num_quantizers >= 1 + if num_quantizers > 1: + self.nar_audio_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)] + + [ + TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS) + for i in range(num_quantizers - 1) + ] + ) # W_a + + # PreNet + if add_prenet: + self.nar_text_prenet = nn.Sequential( + Transpose(), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + nn.Conv1d(nar_d_model, nar_d_model, kernel_size=5, padding="same"), + nn.BatchNorm1d(nar_d_model), + nn.ReLU(), + nn.Dropout(0.5), + Transpose(), + nn.Linear(nar_d_model, nar_d_model), + ) + self.nar_audio_prenet = nn.Sequential( + nn.Linear(nar_d_model, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, 256), + nn.ReLU(), + nn.Dropout(0.25), + nn.Linear(256, nar_d_model), + ) + else: + self.nar_text_prenet = nn.Identity() + self.nar_audio_prenet = nn.Identity() + + self.nar_text_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.0, + scale=False, + alpha=False, + ) + self.nar_audio_position = SinePositionalEmbedding( + nar_d_model, + dropout=0.1, + scale=False, + alpha=False, + ) + + self.nar_decoder = decoder_cls( + decoder_layer_cls( + nar_d_model, + int(nhead * nar_scale_factor), + dim_feedforward=nar_d_model * 4, + dropout=0.1, + batch_first=True, + norm_first=norm_first, + adaptive_layer_norm=True, + ), + num_layers=int(num_layers * nar_scale_factor), + norm=AdaptiveLayerNorm(nar_d_model, norm=nn.LayerNorm(nar_d_model)) + if norm_first + else None, + ) + self.nar_predict_layers = nn.ModuleList( + [ + nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False) + for i in range(num_quantizers - 1) + ] + ) + self.nar_stage_embeddings = nn.ModuleList( + [TokenEmbedding(nar_d_model, 1) for i in range(num_quantizers - 1)] + ) + + if share_embedding: + # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa + # NOTE(Feiteng): In the experiment, this undermines accuracy + # self.ar_predict_layer.weight = self.ar_audio_embedding.weight + + # We also share the parameters of the acoustic embedding layer and the output prediction layer, + # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer. + for j in range(0, num_quantizers - 2): + self.nar_predict_layers[j].weight = self.nar_audio_embeddings[ + j + 2 + ].weight + + self.nar_accuracy_metric = MulticlassAccuracy( + NUM_AUDIO_TOKENS + 1, + top_k=10, + average="micro", + multidim_average="global", + ignore_index=NUM_AUDIO_TOKENS, + ) + + def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]: + assert stage > 0 + if stage == 1: + for name, param in self.named_parameters(): + if name.startswith("ar_"): + print(f" AR parameter: {name}") + yield param + + if stage == 2: + for name, param in self.named_parameters(): + if name.startswith("nar_"): + print(f"NAR parameter: {name}") + yield param + + def stage_named_parameters( + self, stage: int = 1 + ) -> Iterator[Tuple[str, nn.Parameter]]: + assert stage > 0 + if stage == 1: + for pair in self.named_parameters(): + if pair[0].startswith("ar_"): + yield pair + + if stage == 2: + for pair in self.named_parameters(): + if pair[0].startswith("nar_"): + yield pair + + def pad_y_eos(self, y, y_mask_int, eos_id): + targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad( + y_mask_int, (0, 1), value=1 + ) + # inputs, targets + if self.ar_audio_prepend_bos: + return ( + F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1), + targets, + ) + + return targets[:, :-1], targets[:, 1:] + + def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): + # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds + # from the same utterance. + # We implement this differently. + if self.prefix_mode == 0: + # no prefix + prefix_len = 0 + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, nar_stage): + # Formula (4) (5) + y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) + elif self.prefix_mode == 1: + # prefix at begining + int_low = (0.25 * y_lens.min()).type(torch.int64).item() + prefix_len = torch.randint(int_low, int_low * 2, size=()).item() + prefix_len = min(prefix_len, 225) # 24000/320 * 3s = 225 frames + + y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) + y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](codes[:, :prefix_len, j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[:, prefix_len:, j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + elif self.prefix_mode in [2, 4]: + if self.prefix_mode == 2: + # random prefix + prefix_len = min(225, int(0.25 * y_lens.min().item())) + + y_prompts_codes = [] + for b in range(codes.shape[0]): + start = self.rng.randint(0, y_lens[b].item() - prefix_len) + y_prompts_codes.append( + torch.clone(codes[b, start : start + prefix_len]) + ) + codes[b, start : start + prefix_len, nar_stage] = NUM_AUDIO_TOKENS + y_prompts_codes = torch.stack(y_prompts_codes, dim=0) + else: + prefix_len = y_prompts_codes.shape[1] + + y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j](y_prompts_codes[..., j]) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[..., j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + else: + raise ValueError + + return y_emb, prefix_len + + def forward( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: Union[torch.Tensor, PromptedFeatures], + y_lens: Union[torch.Tensor, PromptedFeatures], + reduction: str = "sum", + train_stage: int = 0, + **kwargs, + ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]: + """ + Args: + x: + A 2-D tensor of shape (N, S). + x_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (N, T, 8). + y_lens: + A 1-D tensor of shape (N,). It contains the number of tokens in `x` + before padding. + train_stage: + 0: AR & NAR modules, 1: AR modules, 2: NAR modules + Returns: + Return the predicted audio code matrix, cross-entropy loss and Top-10 accuracy. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + + y_prompts_codes = None + if isinstance(y, PromptedFeatures): + y_prompts_codes, y = y.data + prompts_len, y_lens = y_lens.data + assert prompts_len.min() == prompts_len.max() + assert self.prefix_mode == 4 + y_prompts_codes = y_prompts_codes.type(torch.int64) + + assert y.ndim == 3, y.shape + assert y_lens.ndim == 1, y_lens.shape + + # NOTE: x has been padded in TextTokenCollater + x_mask = make_pad_mask(x_lens).to(x.device) + y_mask = make_pad_mask(y_lens).to(y.device) + y_mask_int = y_mask.type(torch.int64) + + text = x + codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1)) + + y, targets = self.pad_y_eos(codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS) + + x_len = x_lens.max() + + metrics = {} + total_loss = 0.0 + + xy_padding_mask = torch.concat([x_mask, y_mask], dim=1) + if self.ar_audio_prepend_bos: + ar_xy_padding_mask = torch.concat( + [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1 + ) + else: + ar_xy_padding_mask = xy_padding_mask + # AR Decoder + if train_stage in [0, 1]: + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + y_len = y_lens.max() + int(self.ar_audio_prepend_bos) + + x_attn_mask = F.pad( + torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device), + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu( + torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), + diagonal=1, + ), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0) + + # merge key padding and attention masks + bsz, src_len = x.shape[0], x_len + y_len + _xy_padding_mask = ( + ar_xy_padding_mask.view(bsz, 1, 1, src_len) + .expand(-1, self.num_heads, -1, -1) + .reshape(bsz * self.num_heads, 1, src_len) + ) + xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask) + + new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype) + new_attn_mask.masked_fill_(xy_attn_mask, float("-inf")) + xy_attn_mask = new_attn_mask + + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + # src_key_padding_mask=xy_padding_mask, + # is_causal=True, + ) + logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1) + # loss + total_loss = F.cross_entropy(logits, targets, reduction=reduction) + + metrics["ArTop10Accuracy"] = self.ar_accuracy_metric( + logits.detach(), targets + ).item() * y_lens.sum().type(torch.float32) + + if self.num_quantizers == 1: + return ((x, codes), total_loss, metrics) + + # Non-AR Decoders + if self.ar_audio_prepend_bos: + y = y[:, 1:] + if train_stage in [0, 2]: + num_nar_layers = self.num_quantizers - 1 + nar_stage = self.rng.choices( + [_k for _k in range(1, self.num_quantizers)], + weights=[1.0 / num_nar_layers] * num_nar_layers, + k=1, + )[0] + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + y_emb, prefix_len = self._prepare_prompts( + y, y_lens, codes, nar_stage, y_prompts_codes + ) + + y_len = y_lens.max() + targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int + if self.prefix_mode in [2, 4]: + xy_padding_mask = torch.concat( + [ + x_mask, + F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False), + ], + dim=1, + ) + elif self.prefix_mode == 1: + targets = targets[:, prefix_len:] + + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight), + src_key_padding_mask=xy_padding_mask, + # is_causal=False, + ) + xy_dec = xy_dec[:, x_lens.max() + prefix_len :] + if self.prefix_mode == 4: + prefix_len = 0 # reset for Top10Accuracy metric + logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(0, 2, 1) + + # loss + total_length = (y_lens).sum().type(torch.float32) + total_loss += F.cross_entropy( + logits, + targets, + ignore_index=NUM_AUDIO_TOKENS, + reduction=reduction, + ) * (total_length / (total_length - prefix_len * x.shape[0])) + metrics["NarTop10Accuracy"] = ( + self.nar_accuracy_metric( + F.pad( + logits.detach(), + (0, 0, 0, 1, 0, 0), + value=logits.min().cpu().item(), + ), + targets, + ).item() + * total_length + ) + + if train_stage == 0: + total_loss = total_loss / 2.0 + + return ((x, codes), total_loss, metrics) + + def inference( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + enroll_x_lens: torch.Tensor, + top_k: int = -100, + temperature: float = 1.0, + top_p: float = 1.0, + ras: bool = False, + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, S). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, 8). + top_k: (`optional`) int + The number of highest probability tokens to keep for top-k-filtering. Default to -100. + temperature: (`optional`) float + The value used to module the next token probabilities. Must be strictly positive. Default to 1.0. + ras: (`optional`) bool + Whether to use repetition-aware sampling. Default to False. + Returns: + Return the predicted audio code matrix. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + + # NOTE: x has been padded in TextTokenCollater + text = x + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + text_len = x_lens.max() + prompts = y + prefix_len = y.shape[1] + + # AR Decoder + # TODO: Managing decoder steps avoid repetitive computation + y = prompts[..., 0] + if self.ar_audio_prepend_bos: + y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1) + + x_len = x_lens.max() + x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool) + + while True: + y_emb = self.ar_audio_embedding(y) + y_emb = self.ar_audio_prenet(y_emb) + y_pos = self.ar_audio_position(y_emb) + xy_pos = torch.concat([x, y_pos], dim=1) + + y_len = y.shape[1] + x_attn_mask_pad = F.pad( + x_attn_mask, + (0, y_len), + value=True, + ) + y_attn_mask = F.pad( + torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1), + (x_len, 0), + value=False, + ) + xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to( + y.device + ) + + xy_dec, _ = self.ar_decoder( + (xy_pos, None), + mask=xy_attn_mask, + ) + logits = self.ar_predict_layer(xy_dec[:, -1]) + samples = topk_sampling( + logits, + top_k=top_k, + top_p=top_p, + temperature=temperature, + repetition_aware_sampling=ras, + preceding_tokens=y, + ) + + if ( + torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS + or samples[0, 0] == NUM_AUDIO_TOKENS + or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16 + ): + if prompts.shape[1] == y.shape[1]: + raise SyntaxError("well trained model shouldn't reach here.") + break + + y = torch.concat([y, samples], dim=1) + + codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]] + if self.num_quantizers == 1: + return torch.stack(codes, dim=-1) + + # Non-AR Decoders + y_emb = self.nar_audio_embeddings[0](y[:, int(self.ar_audio_prepend_bos) :]) + + if self.prefix_mode in [2, 4]: # Exclude enrolled_phonemes + enrolled_len = enroll_x_lens.max().item() + # SOS + Synthesis Text + EOS + text = torch.concat( + [ + text[:, :1], + text[:, enrolled_len - 1 :], + ], + dim=1, + ) + text_len = text_len - (enrolled_len - 2) + assert text.shape[0] == 1 + + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + if self.prefix_mode == 0: + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) + y_emb[:, prefix_len:] += embedding_layer(samples) + else: + for j in range(1, self.num_quantizers): + y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) + + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < self.num_quantizers - 2: + y_emb[:, prefix_len:] += embedding_layer(samples) + + assert len(codes) == self.num_quantizers + return torch.stack(codes, dim=-1) + + def continual( + self, + x: torch.Tensor, + x_lens: torch.Tensor, + y: torch.Tensor, + ) -> torch.Tensor: + """ + Args: + x: + A 2-D tensor of shape (1, S). + x_lens: + A 1-D tensor of shape (1,). It contains the number of tokens in `x` + before padding. + y: + A 3-D tensor of shape (1, T, 8). + Returns: + Return the predicted audio code matrix. + """ + assert x.ndim == 2, x.shape + assert x_lens.ndim == 1, x_lens.shape + assert y.ndim == 3, y.shape + assert y.shape[0] == 1, y.shape + + assert torch.all(x_lens > 0) + assert self.num_quantizers == 8 + + # NOTE: x has been padded in TextTokenCollater + text = x + x = self.ar_text_embedding(text) + x = self.ar_text_prenet(x) + x = self.ar_text_position(x) + + text_len = x_lens.max() + + prefix_len = min(int(y.shape[1] * 0.5), 3 * 75) + + # AR Decoder + prompts = y[:, :prefix_len] + + codes = [y[:, prefix_len:, 0]] + # Non-AR Decoders + x = self.nar_text_embedding(text) + x = self.nar_text_prenet(x) + x = self.nar_text_position(x) + + y_emb = self.nar_audio_embeddings[0](y[..., 0]) + + if self.prefix_mode == 0: + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_position(y_emb) + y_pos = self.nar_audio_prenet(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < 6: + y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1]) + y_emb[:, prefix_len:] += embedding_layer(samples) + else: + for j in range(1, 8): + y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j]) + + for i, (predict_layer, embedding_layer) in enumerate( + zip( + self.nar_predict_layers, + self.nar_audio_embeddings[1:], + ) + ): + y_pos = self.nar_audio_prenet(y_emb) + y_pos = self.nar_audio_position(y_pos) + xy_pos = torch.concat([x, y_pos], dim=1) + + xy_dec, _ = self.nar_decoder( + (xy_pos, self.nar_stage_embeddings[i].weight) + ) + logits = predict_layer(xy_dec[:, text_len + prefix_len :]) + + samples = torch.argmax(logits, dim=-1) + codes.append(samples) + + if i < 6: + y_emb[:, prefix_len:] += embedding_layer(samples) + + assert len(codes) == 8 + return torch.stack(codes, dim=-1) + + +# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (batch size, vocabulary size) + if top_k > 0: keep only top k tokens with highest probability (top-k filtering). + if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # scatter sorted tensors to original indexing + indices_to_remove = sorted_indices_to_remove.scatter( + 1, sorted_indices, sorted_indices_to_remove + ) + logits[indices_to_remove] = filter_value + return logits + + +def topk_sampling( + logits, + top_k=10, + top_p=1.0, + temperature=1.0, + repetition_aware_sampling=False, + preceding_tokens=None, +): + if temperature != 1.0: + logits = logits / temperature + # Top-p/top-k filtering + logits_filtered = top_k_top_p_filtering( + logits.clone(), top_k=top_k, top_p=top_p, min_tokens_to_keep=2 + ) + # Sample + probs = F.softmax(logits_filtered, dim=-1) + tokens = torch.multinomial(probs, num_samples=1) + + if repetition_aware_sampling: + window_size = 10 + threshold = 0.1 + # we first generate the target code ct′ + # by nucleus sampling with a pre-defined top-p value v. Then, we + # calculate the repetition ratio r of token ct′ + # in the preceding code sequence with a window size K. + # If the ratio r exceeds a pre-defined repetition threshold ratio tn, we replace the target code ct′ + # by + # random sampling from p(ct′ + # |x, c window_size: + preceding_tokens = preceding_tokens[:, -window_size:] + if preceding_tokens.shape[1] > 0: + for i, item in enumerate(preceding_tokens): + # check if the repeat ratio exceeds the threshold + if (item == tokens[i]).sum() / window_size > threshold: + # replace the target code ct′ by random sampling + probs = F.softmax(logits[i], dim=-1) + token_new = torch.multinomial(probs, num_samples=1) + tokens[i] = token_new + return tokens