From 5361ecdc56cbb9f0ed0a1ad6045f3e19e57d57c9 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 19 Nov 2024 04:21:32 +0000 Subject: [PATCH] add valle --- ...te_neural_codec_and_prepare_text_tokens.py | 575 ++++++ egs/libritts/TTS/valle/infer.py | 304 +++ egs/libritts/TTS/valle/optim.py | 1 + egs/libritts/TTS/valle/tokenizer.py | 121 ++ egs/libritts/TTS/valle/train.py | 1287 ++++++++++++ egs/libritts/TTS/valle/tts_datamodule.py | 344 ++++ egs/libritts/TTS/valle/valle.py | 1822 +++++++++++++++++ 7 files changed, 4454 insertions(+) create mode 100755 egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py create mode 100644 egs/libritts/TTS/valle/infer.py create mode 120000 egs/libritts/TTS/valle/optim.py create mode 100644 egs/libritts/TTS/valle/tokenizer.py create mode 100755 egs/libritts/TTS/valle/train.py create mode 100644 egs/libritts/TTS/valle/tts_datamodule.py create mode 100644 egs/libritts/TTS/valle/valle.py diff --git a/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py new file mode 100755 index 0000000000..588c7ddd33 --- /dev/null +++ b/egs/libritts/TTS/local/compute_neural_codec_and_prepare_text_tokens.py @@ -0,0 +1,575 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Phonemize Text and EnCodec Audio. 
+ +Usage example: + python3 bin/tokenizer.py \ + --src_dir ./data/manifests --output_dir ./data/tokenized + +""" +import argparse +import logging +import os +from pathlib import Path + +import torch +import torch.multiprocessing +from icefall.utils import get_executor +from lhotse import CutSet, NumpyHdf5Writer +from lhotse.recipes.utils import read_manifests_if_cached +from tqdm.auto import tqdm + +from valle.data import ( + AudioTokenConfig, + AudioTokenExtractor, + TextTokenizer, + tokenize_text, +) +# from valle.data.fbank import get_fbank_extractor +from valle.utils import SymbolTable + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + + +# Torch's multithreaded behavior needs to be disabled or +# it wastes a lot of CPU and slow things down. +# Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch.multiprocessing.set_sharing_strategy("file_system") + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--src-dir", + type=Path, + default=Path("data/manifests"), + help="Path to the manifest files", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=Path("data/tokenized"), + help="Path to the tokenized files", + ) + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + parser.add_argument( + "--audio-extractor", + type=str, + default="Encodec", + help="Encodec or Fbank", + ) + parser.add_argument( + "--dataset-parts", + type=str, + default="dev-clean test-clean", + help="Space separated dataset parts", + ) + parser.add_argument( + "--prefix", + type=str, + default="libritts", + help="prefix of the manifest file", + ) + parser.add_argument( + "--suffix", + type=str, + default="jsonl.gz", + help="suffix of the manifest file", + ) + parser.add_argument( + "--batch-duration", 
+ type=float, + default=400.0, + help="The maximum number of audio seconds in a batch." + "Determines batch size dynamically.", + ) + parser.add_argument( + "--split", + type=int, + default=1, + help="Split the cut_set into multiple parts", + ) + + return parser.parse_args() + +class PypinyinBackend: + """PypinyinBackend for Chinese. Most codes is referenced from espnet. + There are two types pinyin or initials_finals, one is + just like "ni1 hao3", the other is like "n i1 h ao3". + """ + + def __init__( + self, + backend="initials_finals", + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + ) -> None: + self.backend = backend + self.punctuation_marks = punctuation_marks + + def phonemize( + self, text: List[str], separator: Separator, strip=True, njobs=1 + ) -> List[str]: + assert isinstance(text, List) + phonemized = [] + for _text in text: + _text = re.sub(" +", " ", _text.strip()) + _text = _text.replace(" ", separator.word) + phones = [] + if self.backend == "pypinyin": + for n, py in enumerate( + pinyin( + _text, style=Style.TONE3, neutral_tone_with_five=True + ) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + + phones.extend(list(py[0])) + else: + phones.extend([py[0], separator.syllable]) + elif self.backend == "pypinyin_initials_finals": + for n, py in enumerate( + pinyin( + _text, style=Style.TONE3, neutral_tone_with_five=True + ) + ): + if all([c in self.punctuation_marks for c in py[0]]): + if len(phones): + assert phones[-1] == separator.syllable + phones.pop(-1) + phones.extend(list(py[0])) + else: + if py[0][-1].isalnum(): + initial = get_initials(py[0], strict=False) + if py[0][-1].isdigit(): + final = ( + get_finals(py[0][:-1], strict=False) + + py[0][-1] + ) + else: + final = get_finals(py[0], strict=False) + phones.extend( + [ + initial, + separator.phone, + final, + separator.syllable, + ] + ) + else: + assert ValueError + else: + 
raise NotImplementedError + phonemized.append( + "".join(phones).rstrip(f"{separator.word}{separator.syllable}") + ) + return phonemized + + +class TextTokenizer: + """Phonemize Text.""" + + def __init__( + self, + language="en-us", + backend="espeak", + separator=Separator(word="_", syllable="-", phone="|"), + preserve_punctuation=True, + punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(), + with_stress: bool = False, + tie: Union[bool, str] = False, + language_switch: LanguageSwitch = "keep-flags", + words_mismatch: WordMismatch = "ignore", + ) -> None: + if backend == "espeak": + phonemizer = EspeakBackend( + language, + punctuation_marks=punctuation_marks, + preserve_punctuation=preserve_punctuation, + with_stress=with_stress, + tie=tie, + language_switch=language_switch, + words_mismatch=words_mismatch, + ) + elif backend in ["pypinyin", "pypinyin_initials_finals"]: + phonemizer = PypinyinBackend( + backend=backend, + punctuation_marks=punctuation_marks + separator.word, + ) + else: + raise NotImplementedError(f"{backend}") + + self.backend = phonemizer + self.separator = separator + + def to_list(self, phonemized: str) -> List[str]: + fields = [] + for word in phonemized.split(self.separator.word): + # "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z. 
+ pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE) + fields.extend( + [p for p in pp if p != self.separator.phone] + + [self.separator.word] + ) + assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count( + self.separator.phone + ) + return fields[:-1] + + def __call__(self, text, strip=True) -> List[List[str]]: + if isinstance(text, str): + text = [text] + + phonemized = self.backend.phonemize( + text, separator=self.separator, strip=strip, njobs=1 + ) + return [self.to_list(p) for p in phonemized] + + +def tokenize_text(tokenizer: TextTokenizer, text: str) -> List[str]: + phonemes = tokenizer([text.strip()]) + return phonemes[0] # k2symbols + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +class AudioTokenizer: + """EnCodec audio.""" + + def 
__init__( + self, + device: Any = None, + ) -> None: + # Instantiate a pretrained EnCodec model + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + return self.codec.decode(frames) + + +@dataclass +class AudioTokenConfig: + frame_shift: Seconds = 320.0 / 24000 + num_quantizers: int = 8 + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @staticmethod + def from_dict(data: Dict[str, Any]) -> "AudioTokenConfig": + return AudioTokenConfig(**data) + +class AudioTokenExtractor(FeatureExtractor): + name = "encodec" + config_type = AudioTokenConfig + + def __init__(self, config: Optional[Any] = None): + super(AudioTokenExtractor, self).__init__(config) + self.tokenizer = AudioTokenizer() + + def extract( + self, samples: Union[np.ndarray, torch.Tensor], sampling_rate: int + ) -> np.ndarray: + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if sampling_rate != self.tokenizer.sample_rate: + samples = convert_audio( + samples, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + if len(samples.shape) == 2: + samples = samples.unsqueeze(0) + else: + raise ValueError() + + device = self.tokenizer.device + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + codes = encoded_frames[0][0] # [B, n_q, T] + if True: + duration = round(samples.shape[-1] / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + 
frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + assert abs(codes.shape[-1] - expected_num_frames) <= 1 + codes = codes[..., :expected_num_frames] + return codes.cpu().squeeze(0).permute(1, 0).numpy() + + @property + def frame_shift(self) -> Seconds: + return self.config.frame_shift + + def feature_dim(self, sampling_rate: int) -> int: + return self.config.num_quantizers + + def pad_tensor_list(self, tensor_list, device, padding_value=0): + # 计算每个张量的长度 + lengths = [tensor.shape[0] for tensor in tensor_list] + # 使用pad_sequence函数进行填充 + tensor_list = [torch.Tensor(t).to(device) for t in tensor_list] + padded_tensor = torch.nn.utils.rnn.pad_sequence( + tensor_list, batch_first=True, padding_value=padding_value + ) + return padded_tensor, lengths + + def extract_batch(self, samples, sampling_rate, lengths) -> np.ndarray: + samples = [wav.squeeze() for wav in samples] + device = self.tokenizer.device + samples, lengths = self.pad_tensor_list(samples, device) + samples = samples.unsqueeze(1) + + if not isinstance(samples, torch.Tensor): + samples = torch.from_numpy(samples) + if len(samples.shape) != 3: + raise ValueError() + if sampling_rate != self.tokenizer.sample_rate: + samples = [ + convert_audio( + wav, + sampling_rate, + self.tokenizer.sample_rate, + self.tokenizer.channels, + ) + for wav in samples + ] + samples = torch.stack(samples, 0) # convert samples from list to tensor + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = self.tokenizer.encode(samples.detach().to(device)) + encoded_frames = encoded_frames[0][0] # [B, n_q, T] + batch_codes = [] + for b, length in enumerate(lengths): + codes = encoded_frames[b] + duration = round(length / sampling_rate, ndigits=12) + expected_num_frames = compute_num_frames( + duration=duration, + frame_shift=self.frame_shift, + sampling_rate=sampling_rate, + ) + batch_codes.append(codes[..., :expected_num_frames]) + return [codes.cpu().permute(1, 0).numpy() for codes in 
batch_codes] + +def main(): + args = get_args() + + dataset_parts = args.dataset_parts.replace("--dataset-parts", "").strip() + if dataset_parts == "all": # LibriTTS + dataset_parts = [ + "dev-clean", + "dev-other", + "test-clean", + "test-other", + "train-clean-100", + "train-clean-360", + "train-other-500", + ] + else: + dataset_parts = dataset_parts.replace("-p", "").strip().split(" ") + + assert len(dataset_parts) >= 1 + + manifests = read_manifests_if_cached( + dataset_parts=dataset_parts, + output_dir=args.src_dir, + prefix=args.prefix, + suffix=args.suffix, + types=["recordings", "supervisions", "cuts"], + ) + + text_tokenizer = None + if args.text_extractor: + text_tokenizer = TextTokenizer(backend=args.text_extractor) + + audio_extractor = None + if args.audio_extractor: + if args.audio_extractor == "Encodec": + audio_extractor = AudioTokenExtractor(AudioTokenConfig()) + else: + assert args.audio_extractor == "Fbank" + audio_extractor = get_fbank_extractor() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + unique_symbols = set() + num_jobs = min(32, os.cpu_count()) + logging.info(f"dataset_parts: {dataset_parts} manifests {len(manifests)}") + + prefix = args.prefix + if prefix and not prefix.endswith("_"): + prefix = f"{prefix}_" + with get_executor() as ex: + for partition, m in manifests.items(): + logging.info( + f"Processing partition: {partition} CUDA: {torch.cuda.is_available()}" + ) + try: + cut_set = CutSet.from_manifests( + recordings=m["recordings"], + supervisions=m["supervisions"], + ) + except Exception: + cut_set = m["cuts"] + + # Split cut_set if split > 1 + split = 1 + if args.split > 1: + cut_sets = cut_set.split(args.split) + split = args.split + else: + cut_sets = [cut_set] + + for idx, part in enumerate(cut_sets): + # AudioTokenizer + if args.audio_extractor: + if args.audio_extractor == "Encodec": + storage_path = ( + f"{args.output_dir}/{args.prefix}_encodec_{partition}_{idx if split > 1 else ''}" + ) + else: + 
storage_path = ( + f"{args.output_dir}/{args.prefix}_fbank_{partition}_{idx if split > 1 else ''}" + ) + + if args.prefix.lower() in ["ljspeech", "aishell", "baker", "wenetspeech4tts"]: + part = part.resample(24000) + + with torch.no_grad(): + if ( + torch.cuda.is_available() + and args.audio_extractor == "Encodec" + ): + part = part.compute_and_store_features_batch( + extractor=audio_extractor, + storage_path=storage_path, + num_workers=num_jobs, + batch_duration=args.batch_duration, + collate=False, + overwrite=True, + storage_type=NumpyHdf5Writer, + ) + else: + part = part.compute_and_store_features( + extractor=audio_extractor, + storage_path=storage_path, + num_jobs=num_jobs if ex is None else 64, + executor=ex, + storage_type=NumpyHdf5Writer, + ) + + # TextTokenizer + if args.text_extractor: + for c in tqdm(part): + if args.prefix == "baker" and args.text_extractor == "labeled_pinyin": + phonemes = c.supervisions[0].custom["tokens"]["text"] + unique_symbols.update(phonemes) + else: + if args.prefix == "ljspeech": + text = c.supervisions[0].custom["normalized_text"] + text = text.replace(""", '"').replace(""", '"') + phonemes = tokenize_text(text_tokenizer, text=text) + elif args.prefix in ["aishell", "aishell2", "wenetspeech4tts", "libritts"]: + phonemes = tokenize_text( + text_tokenizer, text=c.supervisions[0].text + ) + if c.supervisions[0].custom is None: + c.supervisions[0].custom = {} + else: + raise NotImplementedError(f"{args.prefix}") + c.supervisions[0].custom["tokens"] = {"text": phonemes} + unique_symbols.update(phonemes) + + # Save each part with an index if split > 1 + cuts_filename = f"{prefix}cuts_{partition}.{idx if split > 1 else ''}.{args.suffix}" + part.to_file(f"{args.output_dir}/{cuts_filename}") + logging.info(f"Saved {cuts_filename}") + + if args.text_extractor: + unique_phonemes = SymbolTable() + for s in sorted(list(unique_symbols)): + unique_phonemes.add(s) + logging.info(f"{len(unique_symbols)} unique phonemes: {unique_symbols}") + 
+ unique_phonemes_file = f"{args.output_dir}/unique_text_tokens.k2symbols" + unique_phonemes.to_file(unique_phonemes_file) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/valle/infer.py b/egs/libritts/TTS/valle/infer.py new file mode 100644 index 0000000000..7c4a3bea47 --- /dev/null +++ b/egs/libritts/TTS/valle/infer.py @@ -0,0 +1,304 @@ +#!/usr/bin/env python3 +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Phonemize Text and EnCodec Audio. + +Usage example: + python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ + --checkpoint=${exp_dir}/epoch-${epoch}-avg-${avg}.pt \ + --text-prompts "KNOT one point one five miles per hour." \ + --audio-prompts ./prompts/8463_294825_000043_000000.wav \ + --text "To get up and running quickly just follow the steps below." 
+ + python3 bin/infer.py --output-dir demos_epoch_${epoch}_avg_${avg} \ + --top-k -1 --temperature 1.0 \ + --text-prompts "" \ + --audio-prompts "" \ + --text ./libritts.txt \ + --checkpoint ${exp_dir}/epoch-${epoch}-avg-${avg}.pt + +""" +import argparse +import logging +import os +from pathlib import Path + +os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + +import torch +import torchaudio +from icefall.utils import AttributeDict, str2bool + +from valle.data import ( + AudioTokenizer, + TextTokenizer, + tokenize_audio, + tokenize_text, +) +from valle.data.collation import get_text_token_collater +from valle.models import get_model + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--text-prompts", + type=str, + default="", + help="Text prompts which are separated by |.", + ) + + parser.add_argument( + "--audio-prompts", + type=str, + default="", + help="Audio prompts which are separated by | and should be aligned with --text-prompts.", + ) + + parser.add_argument( + "--text", + type=str, + default="To get up and running quickly just follow the steps below.", + help="Text to be synthesized.", + ) + + # model + # add_model_arguments(parser) + # parser.add_argument( + # "--text-tokens", + # type=str, + # default="data/tokenized/unique_text_tokens.k2symbols", + # help="Path to the unique text tokens file.", + # ) + + parser.add_argument( + "--text-extractor", + type=str, + default="espeak", + help="espeak or pypinyin or pypinyin_initials_finals", + ) + + parser.add_argument( + "--checkpoint", + type=str, + default="exp/vallf_nano_full/checkpoint-100000.pt", + help="Path to the saved checkpoint.", + ) + + parser.add_argument( + "--output-dir", + type=Path, + default=Path("infer/demo"), + help="Path to the tokenized files.", + ) + + parser.add_argument( + "--top-k", + type=int, + default=-100, + help="Whether AR Decoder do top_k(if > 0) sampling.", + ) + + parser.add_argument( + "--top-p", + type=float, + default=1.0, + 
help="Whether AR Decoder do top_p(if > 0) sampling.", + ) + + parser.add_argument( + "--temperature", + type=float, + default=1.0, + help="The temperature of AR Decoder top_k sampling.", + ) + + parser.add_argument( + "--continual", + type=str2bool, + default=False, + help="Do continual task.", + ) + + return parser.parse_args() + + +def load_model(checkpoint, device): + if not checkpoint: + return None + + checkpoint = torch.load(checkpoint, map_location=device) + + args = AttributeDict(checkpoint) + model = get_model(args) + + missing_keys, unexpected_keys = model.load_state_dict( + checkpoint["model"], strict=True + ) + assert not missing_keys + model.to(device) + model.eval() + + text_tokens = args.text_tokens + + return model, text_tokens + + +@torch.no_grad() +def main(): + args = get_args() + text_tokenizer = TextTokenizer(backend=args.text_extractor) + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", 0) + model, text_tokens = load_model(args.checkpoint, device) + text_collater = get_text_token_collater(text_tokens) + + audio_tokenizer = AudioTokenizer() + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + + text_prompts = " ".join(args.text_prompts.split("|")) + + audio_prompts = [] + if args.audio_prompts: + for n, audio_file in enumerate(args.audio_prompts.split("|")): + encoded_frames = tokenize_audio(audio_tokenizer, audio_file) + if False: + samples = audio_tokenizer.decode(encoded_frames) + torchaudio.save( + f"{args.output_dir}/p{n}.wav", samples[0], 24000 + ) + + audio_prompts.append(encoded_frames[0][0]) + + assert len(args.text_prompts.split("|")) == len(audio_prompts) + audio_prompts = torch.concat(audio_prompts, dim=-1).transpose(2, 1) + audio_prompts = audio_prompts.to(device) + + if os.path.isfile(args.text): # for demos + # https://github.com/lifeiteng/lifeiteng.github.com/blob/main/valle/prepare.py + with open(args.text) as f: + for line in f: + # fields = line.strip().split("\t") + 
fields = line.strip().split(" ") + fields = [item for item in fields if item] + assert len(fields) == 4 + prompt_text, prompt_audio, text, audio_path = fields + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{prompt_text} {text}".strip() + ) + ] + ) + _, enroll_x_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{prompt_text}".strip() + ) + ] + ) + + audio_prompts = tokenize_audio(audio_tokenizer, prompt_audio) + audio_prompts = audio_prompts[0][0].transpose(2, 1).to(device) + + # synthesis + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ) + + samples = audio_tokenizer.decode( + [(encoded_frames.transpose(2, 1), None)] + ) + # store + # save audio path into args.output_dir + audio_path + audio_path = f"{args.output_dir}/{audio_path}" + # mkdir -p + os.makedirs(os.path.dirname(audio_path), exist_ok=True) + torchaudio.save(audio_path, samples[0].cpu(), 24000) + return + + for n, text in enumerate(args.text.split("|")): + logging.info(f"synthesize text: {text}") + text_tokens, text_tokens_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{text_prompts} {text}".strip() + ) + ] + ) + + # synthesis + if args.continual: + assert text == "" + encoded_frames = model.continual( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + ) + else: + enroll_x_lens = None + if text_prompts: + _, enroll_x_lens = text_collater( + [ + tokenize_text( + text_tokenizer, text=f"{text_prompts}".strip() + ) + ] + ) + encoded_frames = model.inference( + text_tokens.to(device), + text_tokens_lens.to(device), + audio_prompts, + enroll_x_lens=enroll_x_lens, + top_k=args.top_k, + temperature=args.temperature, + top_p=args.top_p, + ) + + if audio_prompts != []: + samples = 
audio_tokenizer.decode( + [(encoded_frames.transpose(2, 1), None)] + ) + # store + torchaudio.save( + f"{args.output_dir}/{n}.wav", samples[0].cpu(), 24000 + ) + else: # Transformer + pass + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) +torch._C._jit_set_profiling_executor(False) +torch._C._jit_set_profiling_mode(False) +torch._C._set_graph_executor_optimize(False) +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/egs/libritts/TTS/valle/optim.py b/egs/libritts/TTS/valle/optim.py new file mode 120000 index 0000000000..5eaa3cffd4 --- /dev/null +++ b/egs/libritts/TTS/valle/optim.py @@ -0,0 +1 @@ +../../../librispeech/ASR/zipformer/optim.py \ No newline at end of file diff --git a/egs/libritts/TTS/valle/tokenizer.py b/egs/libritts/TTS/valle/tokenizer.py new file mode 100644 index 0000000000..55db84c68a --- /dev/null +++ b/egs/libritts/TTS/valle/tokenizer.py @@ -0,0 +1,121 @@ +from pathlib import Path +from typing import List, Tuple + +import numpy as np +import torch + +from k2 import SymbolTable + +class TextTokenCollater: + """Collate list of text tokens + + Map sentences to integers. Sentences are padded to equal length. + Beginning and end-of-sequence symbols can be added. + + Example: + >>> token_collater = TextTokenCollater(text_tokens) + >>> tokens_batch, tokens_lens = token_collater(text) + + Returns: + tokens_batch: IntTensor of shape (B, L) + B: batch dimension, number of input sentences + L: length of the longest sentence + tokens_lens: IntTensor of shape (B,) + Length of each sentence after adding and + but before padding. 
+ """ + + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + + self.add_eos = add_eos + self.add_bos = add_bos + + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = ( + [pad_symbol] + + ([bos_symbol] if add_bos else []) + + ([eos_symbol] if add_eos else []) + + sorted(text_tokens) + ) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = [token for token in unique_tokens] + + def index( + self, tokens_list: List[str] + ) -> Tuple[torch.Tensor, torch.Tensor]: + seqs, seq_lens = [], [] + for tokens in tokens_list: + assert ( + all([True if s in self.token2idx else False for s in tokens]) + is True + ) + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + list(tokens) + + ([self.eos_symbol] if self.add_eos else []) + ) + seqs.append(seq) + seq_lens.append(len(seq)) + + max_len = max(seq_lens) + for k, (seq, seq_len) in enumerate(zip(seqs, seq_lens)): + seq.extend([self.pad_symbol] * (max_len - seq_len)) + + tokens = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + tokens_lens = torch.IntTensor(seq_lens) + + return tokens, tokens_lens + + def __call__(self, texts: List[str]) -> Tuple[torch.Tensor, torch.Tensor]: + tokens_seqs = [[p for p in text] for text in texts] + max_len = len(max(tokens_seqs, key=len)) + + seqs = [ + ([self.bos_symbol] if self.add_bos else []) + + list(seq) + + ([self.eos_symbol] if self.add_eos else []) + + [self.pad_symbol] * (max_len - len(seq)) + for seq in tokens_seqs + ] + + tokens_batch = torch.from_numpy( + np.array( + [[self.token2idx[token] for token in seq] for seq in seqs], + dtype=np.int64, + ) + ) + + tokens_lens = torch.IntTensor( + [ + len(seq) + int(self.add_eos) + int(self.add_bos) + for seq in tokens_seqs + ] + ) + + return 
tokens_batch, tokens_lens + + +def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater: + text_tokens_path = Path(text_tokens_file) + unique_tokens = SymbolTable.from_file(text_tokens_path) + collater = TextTokenCollater( + unique_tokens.symbols, add_bos=True, add_eos=True + ) + return collater diff --git a/egs/libritts/TTS/valle/train.py b/egs/libritts/TTS/valle/train.py new file mode 100755 index 0000000000..302bdbee2b --- /dev/null +++ b/egs/libritts/TTS/valle/train.py @@ -0,0 +1,1287 @@ +#!/usr/bin/env python3 +# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Wei Kang, +# Mingshuang Luo) +# Copyright 2023 (authors: Feiteng Li) +# Copyright 2024 (authors: Yuekai Zhang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +Usage: +world_size=8 +exp_dir=exp/valle + +## Train AR model +python3 train.py --max-duration 320 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 1 \ + --num-buckets 6 --dtype "bfloat16" --save-every-n 1000 --valid-interval 2000 \ + --model-name valle --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 20 --start-epoch 1 --start-batch 0 --accumulate-grad-steps 1 \ + --exp-dir ${exp_dir} --world-size ${world_size} + +## Train NAR model +# cd ${exp_dir} +# ln -s ${exp_dir}/best-valid-loss.pt epoch-99.pt # --start-epoch 100=99+1 +# cd - +python3 train.py --max-duration 160 --filter-min-duration 0.5 --filter-max-duration 14 --train-stage 2 \ + --num-buckets 6 --dtype "float32" --save-every-n 1000 --valid-interval 2000 \ + --model-name valle --share-embedding true --norm-first true --add-prenet false \ + --decoder-dim 1024 --nhead 16 --num-decoder-layers 12 --prefix-mode 1 \ + --base-lr 0.03 --warmup-steps 200 --average-period 0 \ + --num-epochs 40 --start-epoch 100 --start-batch 0 --accumulate-grad-steps 2 \ + --exp-dir ${exp_dir} --world-size ${world_size} +""" + +import argparse +import copy +import logging +import os +from contextlib import nullcontext + +import random +import warnings +from pathlib import Path +from shutil import copyfile +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.multiprocessing as mp +import torch.nn as nn +from icefall.checkpoint import load_checkpoint, remove_checkpoints +from icefall.checkpoint import save_checkpoint as save_checkpoint_impl +from icefall.checkpoint import ( + save_checkpoint_with_global_batch_idx, + update_averaged_model, +) +from icefall.dist import cleanup_dist, setup_dist +from icefall.env import get_env_info +from icefall.hooks import register_inf_check_hooks +from icefall.utils import AttributeDict, 
MetricsTracker, setup_logger, str2bool +from lhotse import CutSet +from lhotse.cut import Cut +from lhotse.dataset.sampling.base import CutSampler +from lhotse.utils import fix_random_seed +from torch import Tensor +from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.utils.tensorboard import SummaryWriter + +from tts_datamodule import TtsDataModule +from optim import Eden, ScaledAdam +from valle import VALLE +from tokenizer import TextTokenCollater, get_text_token_collater + +LRSchedulerType = torch.optim.lr_scheduler._LRScheduler + +def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: + if isinstance(model, DDP): + # get underlying nn.Module + model = model.module + + for module in model.modules(): + if hasattr(module, "batch_count"): + module.batch_count = batch_count + +def add_model_arguments(parser: argparse.ArgumentParser): + parser.add_argument( + "--decoder-dim", + type=int, + default=1024, + help="Embedding dimension in the decoder model.", + ) + parser.add_argument( + "--nhead", + type=int, + default=16, + help="Number of attention heads in the Decoder layers.", + ) + parser.add_argument( + "--num-decoder-layers", + type=int, + default=12, + help="Number of Decoder layers.", + ) + parser.add_argument( + "--scale-factor", + type=float, + default=1.0, + help="Model scale factor which will be assigned different meanings in different models.", + ) + parser.add_argument( + "--norm-first", + type=str2bool, + default=True, + help="Pre or Post Normalization.", + ) + parser.add_argument( + "--add-prenet", + type=str2bool, + default=False, + help="Whether add PreNet after Inputs.", + ) + + parser.add_argument( + "--prefix-mode", + type=int, + default=0, + help="The mode for how to prefix VALL-E NAR Decoder, " + "0: no prefix, 1: 0 to random, 2: random to random, 4: chunk of pre or post utterance.", + ) + parser.add_argument( + "--share-embedding", + type=str2bool, + default=True, 
+ help="Share the parameters of the output projection layer with the parameters of the acoustic embedding.", + ) + parser.add_argument( + "--prepend-bos", + type=str2bool, + default=False, + help="Whether prepend to the acoustic tokens -> AR Decoder inputs.", + ) + parser.add_argument( + "--num-quantizers", + type=int, + default=8, + help="Number of Audio/Semantic quantization layers.", + ) + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--tensorboard", + type=str2bool, + default=True, + help="Should various information be logged in tensorboard.", + ) + + parser.add_argument( + "--num-epochs", + type=int, + default=20, + help="Number of epochs to train.", + ) + + parser.add_argument( + "--start-epoch", + type=int, + default=1, + help="""Resume training from this epoch. It should be positive. + If larger than 1, it will load checkpoint from + exp-dir/epoch-{start_epoch-1}.pt + """, + ) + + parser.add_argument( + "--start-batch", + type=int, + default=0, + help="""If positive, --start-epoch is ignored and + it loads the checkpoint from exp-dir/checkpoint-{start_batch}.pt + """, + ) + + parser.add_argument( + "--exp-dir", + type=str, + default="exp/valle_dev", + help="""The experiment dir. 
+ It specifies the directory where all training related + files, e.g., checkpoints, log, etc, are saved + """, + ) + + parser.add_argument( + "--text-tokens", + type=str, + default="data/tokenized/unique_text_tokens.k2symbols", + help="Path to the unique text tokens file", + ) + + parser.add_argument( + "--optimizer-name", + type=str, + default="ScaledAdam", + help="The optimizer.", + ) + parser.add_argument( + "--scheduler-name", + type=str, + default="Eden", + help="The scheduler.", + ) + parser.add_argument( + "--base-lr", type=float, default=0.05, help="The base learning rate." + ) + parser.add_argument( + "--warmup-steps", + type=int, + default=200, + help="""Number of steps that affects how rapidly the learning rate + decreases. We suggest not to change this.""", + ) + + parser.add_argument( + "--seed", + type=int, + default=42, + help="The seed for random generators intended for reproducibility", + ) + + parser.add_argument( + "--inf-check", + type=str2bool, + default=False, + help="Add hooks to check for infinite module outputs and gradients.", + ) + + parser.add_argument( + "--save-every-n", + type=int, + default=10000, + help="""Save checkpoint after processing this number of batches" + periodically. We save checkpoint to exp-dir/ whenever + params.batch_idx_train %% save_every_n == 0. The checkpoint filename + has the form: f'exp-dir/checkpoint-{params.batch_idx_train}.pt' + Note: It also saves checkpoint to `exp-dir/epoch-xxx.pt` at the + end of each epoch where `xxx` is the epoch number counting from 0. + """, + ) + parser.add_argument( + "--valid-interval", + type=int, + default=10000, + help="""Run validation if batch_idx %% valid_interval is 0.""", + ) + + parser.add_argument( + "--keep-last-k", + type=int, + default=20, + help="""Only keep this number of checkpoints on disk. + For instance, if it is 3, there are only 3 checkpoints + in the exp-dir with filenames `checkpoint-xxx.pt`. + It does not affect checkpoints with name `epoch-xxx.pt`. 
+ """, + ) + + parser.add_argument( + "--average-period", + type=int, + default=0, + help="""Update the averaged model, namely `model_avg`, after processing + this number of batches. `model_avg` is a separate version of model, + in which each floating-point parameter is the average of all the + parameters from the start of training. Each time we take the average, + we do: `model_avg = model * (average_period / batch_idx_train) + + model_avg * ((batch_idx_train - average_period) / batch_idx_train)`. + """, + ) + + parser.add_argument( + "--accumulate-grad-steps", + type=int, + default=1, + help="""update gradient when batch_idx_train %% accumulate_grad_steps == 0. + """, + ) + + parser.add_argument( + "--dtype", + type=str, + default="float32", + help="Training dtype: float32 bfloat16 float16.", + ) + + parser.add_argument( + "--filter-min-duration", + type=float, + default=0.0, + help="Keep only utterances with duration > this.", + ) + parser.add_argument( + "--filter-max-duration", + type=float, + default=20.0, + help="Keep only utterances with duration < this.", + ) + + parser.add_argument( + "--train-stage", + type=int, + default=0, + help="""0: train all modules, For VALL-E, support 1: AR Decoder 2: NAR Decoder(s) + """, + ) + + parser.add_argument( + "--visualize", + type=str2bool, + default=False, + help="visualize model results in eval step.", + ) + + parser.add_argument( + "--oom-check", + type=str2bool, + default=False, + help="perform OOM check on dataloader batches before starting training.", + ) + + add_model_arguments(parser) + + return parser + + + +def get_params() -> AttributeDict: + """Return a dict containing training parameters. + + All training related parameters that are not passed from the commandline + are saved in the variable `params`. + + Commandline options are merged into `params` after they are parsed, so + you can also access them via `params`. 
+ + Explanation of options saved in `params`: + + - best_train_loss: Best training loss so far. It is used to select + the model that has the lowest training loss. It is + updated during the training. + + - best_valid_loss: Best validation loss so far. It is used to select + the model that has the lowest validation loss. It is + updated during the training. + + - best_train_epoch: It is the epoch that has the best training loss. + + - best_valid_epoch: It is the epoch that has the best validation loss. + + - batch_idx_train: Used to writing statistics to tensorboard. It + contains number of batches trained so far across + epochs. + + - log_interval: Print training loss if batch_idx % log_interval` is 0 + + - reset_interval: Reset statistics if batch_idx % reset_interval is 0 + + - valid_interval: Run validation if batch_idx % valid_interval is 0 + """ + params = AttributeDict( + { + "best_train_loss": float("inf"), + "best_valid_loss": float("inf"), + "best_train_epoch": -1, + "best_valid_epoch": -1, + "batch_idx_train": 0, + "log_interval": 100, + "reset_interval": 200, + "valid_interval": 10000, + "env_info": get_env_info(), + } + ) + + return params + + +def load_checkpoint_if_available( + params: AttributeDict, + model: nn.Module, + model_avg: nn.Module = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, +) -> Optional[Dict[str, Any]]: + """Load checkpoint from file. + + If params.start_batch is positive, it will load the checkpoint from + `params.exp_dir/checkpoint-{params.start_batch}.pt`. Otherwise, if + params.start_epoch is larger than 1, it will load the checkpoint from + `params.start_epoch - 1`. + + Apart from loading state dict for `model` and `optimizer` it also updates + `best_train_epoch`, `best_train_loss`, `best_valid_epoch`, + and `best_valid_loss` in `params`. + + Args: + params: + The return value of :func:`get_params`. + model: + The training model. 
+ model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer that we are using. + scheduler: + The scheduler that we are using. + Returns: + Return a dict containing previously saved training info. + """ + if params.start_batch > 0: + filename = params.exp_dir / f"checkpoint-{params.start_batch}.pt" + elif params.start_epoch > 1: + filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt" + else: + return None + + assert filename.is_file(), f"{filename} does not exist!" + + if isinstance(model, DDP): + raise ValueError("load_checkpoint before DDP") + + saved_params = load_checkpoint( + filename, + model=model, + model_avg=model_avg, + optimizer=optimizer, + scheduler=scheduler, + ) + + saved_stage = saved_params.get("train_stage", 0) + if params.train_stage != saved_stage: + # switch training stage + if params.train_stage and saved_stage: # switch between 1 and 2 + params.start_epoch = 1 + params.start_batch = 0 + else: + # switch between 0 and 1/2 + assert params.num_epochs >= params.start_epoch + params.batch_idx_train = saved_params["batch_idx_train"] + + for key in ["optimizer", "grad_scaler", "sampler"]: + if key in saved_params: + saved_params.pop(key) + + # when base on stage 0, we keep scheduler + if saved_stage != 0: + for key in ["scheduler"]: + if key in saved_params: + saved_params.pop(key) + + best_train_filename = params.exp_dir / "best-train-loss.pt" + if best_train_filename.is_file(): + copyfile( + src=best_train_filename, + dst=params.exp_dir / f"best-train-loss-stage{saved_stage}.pt", + ) + + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + if best_valid_filename.is_file(): + copyfile( + src=best_valid_filename, + dst=params.exp_dir / f"best-valid-loss-stage{saved_stage}.pt", + ) + else: + + keys = [ + "best_train_epoch", + "best_valid_epoch", + "batch_idx_train", + "best_train_loss", + "best_valid_loss", + ] + for k in keys: + params[k] = saved_params[k] + + if params.start_batch > 0: + if 
"cur_epoch" in saved_params: + params["start_epoch"] = saved_params["cur_epoch"] + + return saved_params + + +def save_checkpoint( + params: AttributeDict, + model: Union[nn.Module, DDP], + model_avg: Optional[nn.Module] = None, + optimizer: Optional[torch.optim.Optimizer] = None, + scheduler: Optional[LRSchedulerType] = None, + sampler: Optional[CutSampler] = None, + scaler: Optional[GradScaler] = None, + rank: int = 0, +) -> None: + """Save model, optimizer, scheduler and training stats to file. + + Args: + params: + It is returned by :func:`get_params`. + model: + The training model. + model_avg: + The stored model averaged from the start of training. + optimizer: + The optimizer used in the training. + sampler: + The sampler for the training dataset. + scaler: + The scaler used for mix precision training. + """ + if rank != 0: + return + filename = params.exp_dir / f"epoch-{params.cur_epoch}.pt" + save_checkpoint_impl( + filename=filename, + model=model, + model_avg=model_avg, + params=params, + optimizer=optimizer, + scheduler=scheduler, + sampler=sampler, + scaler=scaler, + rank=rank, + ) + + if params.best_train_epoch == params.cur_epoch: + best_train_filename = params.exp_dir / "best-train-loss.pt" + copyfile(src=filename, dst=best_train_filename) + + if params.best_valid_epoch == params.cur_epoch: + best_valid_filename = params.exp_dir / "best-valid-loss.pt" + copyfile(src=filename, dst=best_valid_filename) + +def prepare_input(batch: dict, tokenizer: Tokenizer, device: torch.device): + """Parse batch data""" + + features = batch["features"].to(device) + features_lens = batch["features_lens"].to(device) + if "tokens" not in batch: + raise NotImplementedError("Need to tokenize text") + # tokens = [] + # for c in batch["cuts"]: + # phonemes = tokenize_text( + # tokenizer, text=c.supervisions[0].text + # ) + # tokens.append(phonemes) + else: + tokens = batch["tokens"] + + text_tokens, text_tokens_lens = tokenizer(tokens) + text_tokens = text_tokens.to(device) 
+ text_tokens_lens = text_tokens_lens.to(device) + + return features, features_lens, text_tokens, text_tokens_lens + +def compute_loss( + params: AttributeDict, + model: Union[nn.Module, DDP], + tokenizer: TextTokenCollater, + batch: dict, + is_training: bool, +) -> Tuple[Tensor, MetricsTracker]: + """ + Compute transducer loss given the model and its inputs. + + Args: + params: + Parameters for training. See :func:`get_params`. + model: + The model for training. It is an instance of Zipformer in our case. + batch: + A batch of data. See `lhotse.dataset.K2SpeechRecognitionDataset()` + for the content in it. + is_training: + True for training. False for validation. When it is True, this + function enables autograd during computation; when it is False, it + disables autograd. + warmup: a floating point value which increases throughout training; + values >= 1.0 are fully warmed up and have all modules present. + """ + device = ( + model.device + if isinstance(model, DDP) + else next(model.parameters()).device + ) + ( + audio_features, + audio_features_lens, + text_tokens, + text_tokens_lens, + ) = prepare_input(batch, tokenizer, device) + # at entry, TextTokens is (N, P) + assert text_tokens.ndim == 2 + assert audio_features.ndim == 3 + + with torch.set_grad_enabled(is_training): + predicts, loss, metrics = model( + x=text_tokens, + x_lens=text_tokens_lens, + y=audio_features, + y_lens=audio_features_lens, + train_stage=params.train_stage, + ) + + assert loss.requires_grad == is_training + + info = MetricsTracker() + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + info["frames"] = (audio_features_lens).sum().item() + info["utterances"] = text_tokens.size(0) + + # Note: We use reduction=sum while computing the loss. 
def train_one_epoch(
    params: AttributeDict,
    model: Union[nn.Module, DDP],
    tokenizer: TextTokenCollater,
    optimizer: torch.optim.Optimizer,
    scheduler: LRSchedulerType,
    train_dl: torch.utils.data.DataLoader,
    valid_dl: torch.utils.data.DataLoader,
    rng: random.Random,
    scaler: GradScaler,
    model_avg: Optional[nn.Module] = None,
    tb_writer: Optional[SummaryWriter] = None,
    world_size: int = 1,
    rank: int = 0,
) -> None:
    """Train the model for one epoch.

    A smoothed running loss, normalized by the smoothed frame count, is
    stored in `params.train_loss` at the end of the epoch, and
    `params.best_train_loss` / `params.best_train_epoch` are updated when it
    improves. Validation runs whenever
    `params.batch_idx_train % params.valid_interval == 0`.

    Args:
      params:
        It is returned by :func:`get_params`.
      model:
        The model for training.
      tokenizer:
        The text token collater forwarded to :func:`compute_loss`.
      optimizer:
        The optimizer we are using.
      scheduler:
        The learning rate scheduler. It is stepped once per accumulated
        batch each time the optimizer steps.
      train_dl:
        Dataloader for the training dataset.
      valid_dl:
        Dataloader for the validation dataset.
      rng:
        Random for selecting. Currently unused by this function.
      scaler:
        The scaler used for mix precision training.
      model_avg:
        The stored model averaged from the start of training.
      tb_writer:
        Writer to write log messages to tensorboard.
      world_size:
        Number of nodes in DDP training. If it is 1, DDP is disabled.
      rank:
        The rank of the node in DDP training. If no DDP is used, it should
        be set to 0.
    """
    model.train()
    tot_loss = MetricsTracker()

    # Autocast is only switched on for reduced-precision training.
    if params.dtype in ["bfloat16", "bf16"]:
        dtype, enabled = torch.bfloat16, True
    elif params.dtype in ["float16", "fp16"]:
        dtype, enabled = torch.float16, True
    else:
        dtype, enabled = torch.float32, False

    batch_idx = 0
    for batch_idx, batch in enumerate(train_dl, start=1):
        params.batch_idx_train += 1
        batch_size = len(batch["text"])

        try:
            with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                _, loss, loss_info = compute_loss(
                    params=params,
                    model=model,
                    tokenizer=tokenizer,
                    batch=batch,
                    is_training=True,
                )
            # summary stats
            tot_loss = (
                tot_loss * (1 - 1 / params.reset_interval)
            ) + loss_info * (1 / params.reset_interval)

            # NOTE: We use reduction==sum and loss is computed over utterances
            # in the batch and there is no normalization to it so far.

            scaler.scale(loss).backward()
            if (
                params.batch_idx_train >= params.accumulate_grad_steps
                and params.batch_idx_train % params.accumulate_grad_steps == 0
            ):
                if params.optimizer_name not in ["ScaledAdam", "Eve"]:
                    # Unscales the gradients of optimizer's assigned params in-place
                    scaler.unscale_(optimizer)
                    # Since the gradients of optimizer's assigned params are unscaled, clips as usual:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

                # One scheduler step per accumulated batch.
                for _ in range(params.accumulate_grad_steps):
                    if isinstance(scheduler, Eden):
                        scheduler.step_batch(params.batch_idx_train)
                    else:
                        scheduler.step()

            set_batch_count(model, params.batch_idx_train)
        except Exception:
            # Keep the offending batch on disk for post-mortem debugging.
            display_and_save_batch(batch, params=params)
            raise

        if (
            params.average_period > 0
            and params.batch_idx_train > 0
            and params.batch_idx_train % params.average_period == 0
        ):
            # Perform Operation in rank 0
            if rank == 0:
                update_averaged_model(
                    params=params,
                    model_cur=model,
                    model_avg=model_avg,
                )

        if (
            params.batch_idx_train > 0
            and params.batch_idx_train % params.save_every_n == 0
        ):
            # Perform Operation in rank 0
            if rank == 0:
                save_checkpoint_with_global_batch_idx(
                    out_dir=params.exp_dir,
                    global_batch_idx=params.batch_idx_train,
                    model=model,
                    model_avg=model_avg,
                    params=params,
                    optimizer=optimizer,
                    scheduler=scheduler,
                    sampler=train_dl.sampler,
                    scaler=scaler,
                    rank=rank,
                )
                remove_checkpoints(
                    out_dir=params.exp_dir,
                    topk=params.keep_last_k,
                    rank=rank,
                )

        if batch_idx % 100 == 0 and params.dtype in ["float16", "fp16"]:
            # If the grad scale was less than 1, try increasing it. The _growth_interval
            # of the grad scaler is configurable, but we can't configure it to have different
            # behavior depending on the current grad scale.
            cur_grad_scale = scaler._scale.item()
            if cur_grad_scale < 1.0 or (
                cur_grad_scale < 8.0 and batch_idx % 400 == 0
            ):
                scaler.update(cur_grad_scale * 2.0)

            if cur_grad_scale < 0.01:
                logging.warning(f"Grad scale is small: {cur_grad_scale}")
            if cur_grad_scale < 1.0e-05:
                raise RuntimeError(
                    f"grad_scale is too small, exiting: {cur_grad_scale}"
                )

        if batch_idx % params.log_interval == 0:
            cur_lr = scheduler.get_last_lr()[0]
            cur_grad_scale = (
                scaler._scale.item()
                if params.dtype in ["float16", "fp16"]
                else 1.0
            )

            logging.info(
                f"Epoch {params.cur_epoch}, "
                f"batch {batch_idx}, train_loss[{loss_info}], "
                f"tot_loss[{tot_loss}], "
                f"batch size: {batch_size}, "
                f"lr: {cur_lr:.2e}"
                + (
                    f", grad_scale: {cur_grad_scale}"
                    if params.dtype in ["float16", "fp16"]
                    else ""
                )
            )

            if tb_writer is not None:
                tb_writer.add_scalar(
                    "train/learning_rate", cur_lr, params.batch_idx_train
                )
                loss_info.write_summary(
                    tb_writer,
                    "train/current_",
                    params.batch_idx_train,
                )
                tot_loss.write_summary(
                    tb_writer, "train/tot_", params.batch_idx_train
                )
                if params.dtype in ["float16", "fp16"]:
                    tb_writer.add_scalar(
                        "train/grad_scale",
                        cur_grad_scale,
                        params.batch_idx_train,
                    )

        if params.batch_idx_train % params.valid_interval == 0:
            # Every rank runs validation; compute_validation_loss is given
            # world_size so it can combine the per-rank statistics.
            model.eval()
            logging.info("Computing validation loss")
            with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                valid_info = compute_validation_loss(
                    params=params,
                    model=model,
                    tokenizer=tokenizer,
                    valid_dl=valid_dl,
                    world_size=world_size,
                )
            logging.info(
                f"Epoch {params.cur_epoch}, validation: {valid_info}"
            )
            logging.info(
                f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
            )

            if tb_writer is not None:
                valid_info.write_summary(
                    tb_writer, "train/valid_", params.batch_idx_train
                )

            model.train()

    logging.info("Reaches end of dataloader.")

    if batch_idx == 0:
        # Nothing was accumulated into tot_loss, so the frame-normalized
        # loss below is undefined; leave the best-loss bookkeeping alone.
        logging.warning(
            "The training dataloader yielded no batches in this epoch; "
            "training loss statistics are not updated."
        )
        return

    loss_value = tot_loss["loss"] / tot_loss["frames"]
    params.train_loss = loss_value
    if params.train_loss < params.best_train_loss:
        params.best_train_epoch = params.cur_epoch
        params.best_train_loss = params.train_loss
+ + logging.info("About to create model") + + model = VALLE( + params.decoder_dim, + params.nhead, + params.num_decoder_layers, + norm_first=params.norm_first, + add_prenet=params.add_prenet, + prefix_mode=params.prefix_mode, + share_embedding=params.share_embedding, + nar_scale_factor=params.scale_factor, + prepend_bos=params.prepend_bos, + num_quantizers=params.num_quantizers, + ) + + with open(f"{params.exp_dir}/model.txt", "w") as f: + print(model) + print(model, file=f) + + num_param = sum([p.numel() for p in model.parameters()]) + logging.info(f"Number of model parameters: {num_param}") + + assert params.save_every_n >= params.average_period + model_avg: Optional[nn.Module] = None + if rank == 0 and params.average_period > 0: + # model_avg is only used with rank 0 + model_avg = copy.deepcopy(model).to(torch.float64) + + assert params.start_epoch > 0, params.start_epoch + checkpoints = load_checkpoint_if_available( + params=params, model=model, model_avg=model_avg + ) + + model.to(device) + if world_size > 1: + logging.info("Using DDP") + model = DDP(model, device_ids=[rank], find_unused_parameters=True) + + if params.train_stage: + _model = model.module if isinstance(model, DDP) else model + model_parameters = _model.stage_parameters(params.train_stage) + else: + model_parameters = model.parameters() + + if params.optimizer_name == "ScaledAdam": + parameters_names = [] + if params.train_stage: # != 0 + _model = model.module if isinstance(model, DDP) else model + parameters_names.append( + [ + name_param_pair[0] + for name_param_pair in _model.stage_named_parameters( + params.train_stage + ) + ] + ) + else: + parameters_names.append( + [ + name_param_pair[0] + for name_param_pair in model.named_parameters() + ] + ) + + optimizer = ScaledAdam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + clipping_scale=2.0, + parameters_names=parameters_names, + show_dominant_parameters=False, + clipping_update_period=1000, + ) + elif params.optimizer_name == 
"AdamW": + optimizer = torch.optim.AdamW( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + weight_decay=1e-2, + eps=1e-8, + ) + elif params.optimizer_name == "Adam": + optimizer = torch.optim.Adam( + model_parameters, + lr=params.base_lr, + betas=(0.9, 0.95), + eps=1e-8, + ) + else: + raise NotImplementedError() + + scheduler = Eden(optimizer, 5000, 4, warmup_batches=params.warmup_steps) + optimizer.zero_grad() + + if checkpoints and "optimizer" in checkpoints: + logging.info("Loading optimizer state dict") + optimizer.load_state_dict(checkpoints["optimizer"]) + + if ( + checkpoints + and "scheduler" in checkpoints + and checkpoints["scheduler"] is not None + ): + logging.info("Loading scheduler state dict") + scheduler.load_state_dict(checkpoints["scheduler"]) + + if params.inf_check: + register_inf_check_hooks(model) + + if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: + sampler_state_dict = checkpoints["sampler"] + else: + sampler_state_dict = None + + dataset = TtsDataModule(args) + train_cuts = dataset.train_cuts() + valid_cuts = dataset.dev_cuts() + + train_cuts = filter_short_and_long_utterances( + train_cuts, params.filter_min_duration, params.filter_max_duration + ) + valid_cuts = filter_short_and_long_utterances( + valid_cuts, params.filter_min_duration, params.filter_max_duration + ) + + train_dl = dataset.train_dataloaders( + train_cuts, sampler_state_dict=sampler_state_dict + ) + valid_dl = dataset.valid_dataloaders(valid_cuts) + + if params.oom_check: + scan_pessimistic_batches_for_oom( + model=model, + tokenizer=tokenizer, + train_dl=train_dl, + optimizer=optimizer, + params=params, + ) + + scaler = GradScaler( + enabled=(params.dtype in ["fp16", "float16"]), init_scale=1.0 + ) + if checkpoints and "grad_scaler" in checkpoints: + logging.info("Loading grad scaler state dict") + scaler.load_state_dict(checkpoints["grad_scaler"]) + + for epoch in range(params.start_epoch, params.num_epochs + 1): + if 
def scan_pessimistic_batches_for_oom(
    model: Union[nn.Module, DDP],
    tokenizer: TextTokenCollater,
    train_dl: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    params: AttributeDict,
):
    """Run a forward/backward pass on the most demanding batches up front.

    The batches come from `lhotse.dataset.find_pessimistic_batches` applied
    to the training sampler. Any exception is re-raised after the offending
    batch has been saved; for CUDA out-of-memory errors a hint about
    `max_duration` is logged first.

    Args:
      model:
        The model for training.
      tokenizer:
        The text token collater forwarded to :func:`compute_loss`.
      train_dl:
        Dataloader for the training dataset; its sampler and dataset are
        used directly.
      optimizer:
        The optimizer. Only `zero_grad()` is called on it, no step is taken.
      params:
        Parameters for training. See :func:`get_params`.
    """
    from lhotse.dataset import find_pessimistic_batches

    logging.info(
        "Sanity check -- see if any of the batches in epoch 1 would cause OOM."
    )
    batches, crit_values = find_pessimistic_batches(train_dl.sampler)

    # Same precision policy as the training loop: autocast is enabled only
    # for reduced-precision dtypes.
    if params.dtype in ["bfloat16", "bf16"]:
        dtype, enabled = torch.bfloat16, True
    elif params.dtype in ["float16", "fp16"]:
        dtype, enabled = torch.float16, True
    else:
        dtype, enabled = torch.float32, False

    for criterion, cuts in batches.items():
        batch = train_dl.dataset[cuts]
        try:
            with torch.cuda.amp.autocast(dtype=dtype, enabled=enabled):
                _, loss, _ = compute_loss(
                    params=params,
                    model=model,
                    tokenizer=tokenizer,
                    batch=batch,
                    is_training=True,
                )
            loss.backward()
            optimizer.zero_grad()
        except Exception as e:
            if "CUDA out of memory" in str(e):
                logging.error(
                    "Your GPU ran out of memory with the current "
                    "max_duration setting. We recommend decreasing "
                    "max_duration and trying again.\n"
                    f"Failing criterion: {criterion} "
                    f"(={crit_values[criterion]}) ..."
                )
            display_and_save_batch(batch, params=params)
            raise
        logging.info(
            f"Maximum memory allocated so far is {torch.cuda.max_memory_allocated()//1000000}MB"
        )
class TtsDataModule:
    """
    DataModule for TTS experiments.

    It turns Lhotse ``CutSet`` manifests into PyTorch ``DataLoader`` objects
    for training, validation and testing, and provides accessors for the
    manifests stored under ``--manifest-dir``.

    It contains the common data pipeline pieces used in TTS experiments:
        - dynamic batch size (``--max-duration``),
        - bucketing samplers,
        - a selectable feature input strategy (``--input-strategy``).

    On-the-fly feature extraction (``--on-the-fly-feats``) is accepted on the
    command line but not implemented: every dataloader factory raises
    ``NotImplementedError`` when it is enabled.
    """

    # Input strategies selectable with --input-strategy.  Each one is
    # instantiated without arguments.  An explicit table is used instead of
    # eval() so that a command-line string is never executed as code.
    _INPUT_STRATEGIES = {
        "AudioSamples": AudioSamples,
        "PrecomputedFeatures": PrecomputedFeatures,
    }

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register the data-related command-line options on ``parser``."""
        group = parser.add_argument_group(
            title="TTS data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/tokenized"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--speaker-embeds",
            type=Path,
            default=Path("exp/xvector_nnet_1a/"),
            help="Path to directory with speaker embeddings.",
        )
        group.add_argument(
            "--max-duration",
            # A duration in seconds: parse it as float so that it matches the
            # float default and fractional values are accepted.
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )

        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=4,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=False,
            help="When enabled, use SpecAugment for training dataset.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

        group.add_argument(
            "--dataset",
            type=str,
            default="libritts",
            help="--input-strategy PromptedPrecomputedFeatures needs dataset name to prepare prompts.",
        )

        parser.add_argument(
            "--sampling-rate",
            type=int,
            default=24000,
            help="""Audio sampling rate.""",
        )

    def _make_dataset(self) -> SpeechSynthesisDataset:
        """Build the dataset shared by the train/dev/test dataloaders.

        Raises:
          NotImplementedError: if ``--on-the-fly-feats`` is enabled.
          ValueError: if ``--input-strategy`` names an unknown strategy.
        """
        if self.args.on_the_fly_feats:
            raise NotImplementedError

        name = self.args.input_strategy
        try:
            strategy_cls = self._INPUT_STRATEGIES[name]
        except KeyError:
            raise ValueError(
                f"Unsupported --input-strategy {name!r}; expected one of "
                f"{sorted(self._INPUT_STRATEGIES)}"
            ) from None

        return SpeechSynthesisDataset(
            return_text=False,
            return_tokens=False,
            return_spk_ids=False,
            feature_input_strategy=strategy_cls(),
            return_cuts=self.args.return_cuts,
        )

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = self._make_dataset()

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 2000,
                shuffle_buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'seed' is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        # batch_size=None: the sampler already yields whole batches of cuts.
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl

    def dev_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Return the validation dataloader for ``cuts_valid`` (no shuffling)."""
        logging.info("About to create dev dataset")
        validate = self._make_dataset()
        dev_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create valid dataloader")
        dev_dl = DataLoader(
            validate,
            sampler=dev_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return dev_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Return the test dataloader for ``cuts`` (no shuffling)."""
        logging.info("About to create test dataset")
        test = self._make_dataset()
        test_sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=test_sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    # NOTE: lru_cache on a method keys on ``self`` and keeps the instance
    # alive for the lifetime of the cache; kept as in the original so that
    # repeated calls keep returning the same lazily-loaded CutSet.
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "cuts_train.jsonl.gz"
        )

    @lru_cache()
    def dev_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "cuts_test.jsonl.gz")

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "libritts_cuts_test-other.jsonl.gz"
        )
import copy
import math
import numbers
import random
from functools import partial
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from icefall.utils import make_pad_mask
from torch import Tensor
from torch.nn import Linear, Module, Parameter
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torchmetrics.classification import MulticlassAccuracy

# from valle.data.input_strategies import PromptedFeatures
# from valle.modules.embedding import SinePositionalEmbedding, TokenEmbedding
# from valle.modules.transformer import (
#     AdaptiveLayerNorm,
#     LayerNorm,
#     TransformerEncoder,
#     TransformerEncoderLayer,
# )

from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
from .visualizer import visualize
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Same result as ``input.transpose(1, 2)``.
        swapped = torch.transpose(input, 1, 2)
        return swapped
This + restriction will be loosened in the future.) + - Either autograd is disabled (using ``torch.inference_mode`` or ``torch.no_grad``) or no tensor argument ``requires_grad`` + - training is disabled (using ``.eval()``) + - dropout is 0 + - ``add_bias_kv`` is ``False`` + - ``add_zero_attn`` is ``False`` + - ``batch_first`` is ``True`` and the input is batched + - ``kdim`` and ``vdim`` are equal to ``embed_dim`` + - at most one of ``key_padding_mask`` or ``attn_mask`` is passed + - if a `NestedTensor `_ is passed, neither ``key_padding_mask`` + nor ``attn_mask`` is passed + + If the optimized implementation is in use, a + `NestedTensor `_ can be passed for + ``query``/``key``/``value`` to represent padding more efficiently than using a + padding mask. In this case, a `NestedTensor `_ + will be returned, and an additional speedup proportional to the fraction of the input + that is padding can be expected. + + Args: + embed_dim: Total dimension of the model. + num_heads: Number of parallel attention heads. Note that ``embed_dim`` will be split + across ``num_heads`` (i.e. each head will have dimension ``embed_dim // num_heads``). + dropout: Dropout probability on ``attn_output_weights``. Default: ``0.0`` (no dropout). + bias: If specified, adds bias to input / output projection layers. Default: ``True``. + add_bias_kv: If specified, adds bias to the key and value sequences at dim=0. Default: ``False``. + add_zero_attn: If specified, adds a new batch of zeros to the key and value sequences at dim=1. + Default: ``False``. + kdim: Total number of features for keys. Default: ``None`` (uses ``kdim=embed_dim``). + vdim: Total number of features for values. Default: ``None`` (uses ``vdim=embed_dim``). + batch_first: If ``True``, then the input and output tensors are provided + as (batch, seq, feature). Default: ``False`` (seq, batch, feature). 
+ + Examples:: + + >>> # xdoctest: +SKIP + >>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + >>> attn_output, attn_output_weights = multihead_attn(query, key, value) + + """ + __constants__ = ["batch_first"] + bias_k: Optional[torch.Tensor] + bias_v: Optional[torch.Tensor] + + def __init__( + self, + embed_dim, + num_heads, + dropout=0.0, + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + linear1_cls=Linear, + linear2_cls=Linear, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super(MultiheadAttention, self).__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = ( + self.kdim == embed_dim and self.vdim == embed_dim + ) + + self.num_heads = num_heads + self.dropout = dropout + self.batch_first = batch_first + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + + if add_bias_kv: + self.bias_k = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + self.bias_v = Parameter( + torch.empty((1, 1, embed_dim), **factory_kwargs) + ) + else: + self.bias_k = self.bias_v = None + + if linear1_cls == Linear: + if not self._qkv_same_embed_dim: + self.q_proj_weight = Parameter( + torch.empty((embed_dim, embed_dim), **factory_kwargs) + ) + self.k_proj_weight = Parameter( + torch.empty((embed_dim, self.kdim), **factory_kwargs) + ) + self.v_proj_weight = Parameter( + torch.empty((embed_dim, self.vdim), **factory_kwargs) + ) + self.register_parameter("in_proj_weight", None) + else: + self.in_proj_weight = Parameter( + torch.empty((3 * embed_dim, embed_dim), **factory_kwargs) + ) + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + 
self.in_proj_bias = Parameter( + torch.empty(3 * embed_dim, **factory_kwargs) + ) + else: + self.register_parameter("in_proj_bias", None) + self.out_proj = NonDynamicallyQuantizableLinear( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + self._reset_parameters() + else: + if not self._qkv_same_embed_dim: + raise NotImplementedError + else: + self.in_proj_linear = linear1_cls( + embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs + ) + self.in_proj_weight = self.in_proj_linear.weight + + self.register_parameter("q_proj_weight", None) + self.register_parameter("k_proj_weight", None) + self.register_parameter("v_proj_weight", None) + + if bias: + self.in_proj_bias = self.in_proj_linear.bias + else: + self.register_parameter("in_proj_bias", None) + + self.out_proj = linear2_cls( + embed_dim, embed_dim, bias=bias, **factory_kwargs + ) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + self.add_zero_attn = add_zero_attn + + def _reset_parameters(self): + if self._qkv_same_embed_dim: + xavier_uniform_(self.in_proj_weight) + else: + xavier_uniform_(self.q_proj_weight) + xavier_uniform_(self.k_proj_weight) + xavier_uniform_(self.v_proj_weight) + + if self.in_proj_bias is not None: + constant_(self.in_proj_bias, 0.0) + constant_(self.out_proj.bias, 0.0) + + if self.bias_k is not None: + xavier_normal_(self.bias_k) + if self.bias_v is not None: + xavier_normal_(self.bias_v) + + def __setstate__(self, state): + # Support loading old MultiheadAttention checkpoints generated by v1.1.0 + if "_qkv_same_embed_dim" not in state: + state["_qkv_same_embed_dim"] = True + + super(MultiheadAttention, self).__setstate__(state) + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + key_padding_mask: Optional[Tensor] = None, + need_weights: bool = True, + attn_mask: Optional[Tensor] = None, + average_attn_weights: bool = True, + ) -> Tuple[Tensor, Optional[Tensor]]: + r""" + Args: + 
query: Query embeddings of shape :math:`(L, E_q)` for unbatched input, :math:`(L, N, E_q)` when ``batch_first=False`` + or :math:`(N, L, E_q)` when ``batch_first=True``, where :math:`L` is the target sequence length, + :math:`N` is the batch size, and :math:`E_q` is the query embedding dimension ``embed_dim``. + Queries are compared against key-value pairs to produce the output. + See "Attention Is All You Need" for more details. + key: Key embeddings of shape :math:`(S, E_k)` for unbatched input, :math:`(S, N, E_k)` when ``batch_first=False`` + or :math:`(N, S, E_k)` when ``batch_first=True``, where :math:`S` is the source sequence length, + :math:`N` is the batch size, and :math:`E_k` is the key embedding dimension ``kdim``. + See "Attention Is All You Need" for more details. + value: Value embeddings of shape :math:`(S, E_v)` for unbatched input, :math:`(S, N, E_v)` when + ``batch_first=False`` or :math:`(N, S, E_v)` when ``batch_first=True``, where :math:`S` is the source + sequence length, :math:`N` is the batch size, and :math:`E_v` is the value embedding dimension ``vdim``. + See "Attention Is All You Need" for more details. + key_padding_mask: If specified, a mask of shape :math:`(N, S)` indicating which elements within ``key`` + to ignore for the purpose of attention (i.e. treat as "padding"). For unbatched `query`, shape should be :math:`(S)`. + Binary and byte masks are supported. + For a binary mask, a ``True`` value indicates that the corresponding ``key`` value will be ignored for + the purpose of attention. For a float mask, it will be directly added to the corresponding ``key`` value. + need_weights: If specified, returns ``attn_output_weights`` in addition to ``attn_outputs``. + Default: ``True``. + attn_mask: If specified, a 2D or 3D mask preventing attention to certain positions. 
Must be of shape + :math:`(L, S)` or :math:`(N\cdot\text{num\_heads}, L, S)`, where :math:`N` is the batch size, + :math:`L` is the target sequence length, and :math:`S` is the source sequence length. A 2D mask will be + broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. + Binary, byte, and float masks are supported. For a binary mask, a ``True`` value indicates that the + corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the + corresponding position is not allowed to attend. For a float mask, the mask values will be added to + the attention weight. + average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across + heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an + effect when ``need_weights=True``. Default: ``True`` (i.e. average weights across heads) + + Outputs: + - **attn_output** - Attention outputs of shape :math:`(L, E)` when input is unbatched, + :math:`(L, N, E)` when ``batch_first=False`` or :math:`(N, L, E)` when ``batch_first=True``, + where :math:`L` is the target sequence length, :math:`N` is the batch size, and :math:`E` is the + embedding dimension ``embed_dim``. + - **attn_output_weights** - Only returned when ``need_weights=True``. If ``average_attn_weights=True``, + returns attention weights averaged across heads of shape :math:`(L, S)` when input is unbatched or + :math:`(N, L, S)`, where :math:`N` is the batch size, :math:`L` is the target sequence length, and + :math:`S` is the source sequence length. If ``average_attn_weights=False``, returns attention weights per + head of shape :math:`(\text{num\_heads}, L, S)` when input is unbatched or :math:`(N, \text{num\_heads}, L, S)`. + + .. note:: + `batch_first` argument is ignored for unbatched inputs. 
+ """ + is_batched = query.dim() == 3 + if key_padding_mask is not None: + _kpm_dtype = key_padding_mask.dtype + if _kpm_dtype != torch.bool and not torch.is_floating_point( + key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + why_not_fast_path = "" + if not is_batched: + why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}" + elif query is not key or key is not value: + # When lifting this restriction, don't forget to either + # enforce that the dtypes all match or test cases where + # they don't! + why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)" + elif ( + self.in_proj_bias is not None + and query.dtype != self.in_proj_bias.dtype + ): + why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match" + elif ( + self.in_proj_weight is not None + and query.dtype != self.in_proj_weight.dtype + ): + # this case will fail anyway, but at least they'll get a useful error message. 
+ why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match" + elif self.training: + why_not_fast_path = "training is enabled" + elif not self.batch_first: + why_not_fast_path = "batch_first was not True" + elif self.bias_k is not None: + why_not_fast_path = "self.bias_k was not None" + elif self.bias_v is not None: + why_not_fast_path = "self.bias_v was not None" + elif self.dropout: + why_not_fast_path = f"dropout was {self.dropout}, required zero" + elif self.add_zero_attn: + why_not_fast_path = "add_zero_attn was enabled" + elif not self._qkv_same_embed_dim: + why_not_fast_path = "_qkv_same_embed_dim was not True" + elif attn_mask is not None: + why_not_fast_path = "attn_mask was not None" + elif query.is_nested and key_padding_mask is not None: + why_not_fast_path = ( + "key_padding_mask is not supported with NestedTensor input" + ) + elif self.num_heads % 2 == 1: + why_not_fast_path = "num_heads is odd" + elif torch.is_autocast_enabled(): + why_not_fast_path = "autocast is enabled" + + if not why_not_fast_path: + tensor_args = ( + query, + key, + value, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + ) + # We have to use list comprehensions below because TorchScript does not support + # generator expressions. 
+ if torch.overrides.has_torch_function(tensor_args): + why_not_fast_path = "some Tensor argument has_torch_function" + elif not all( + [ + (x is None or x.is_cuda or "cpu" in str(x.device)) + for x in tensor_args + ] + ): + why_not_fast_path = ( + "some Tensor argument is neither CUDA nor CPU" + ) + elif torch.is_grad_enabled() and any( + [x is not None and x.requires_grad for x in tensor_args] + ): + why_not_fast_path = ( + "grad is enabled and at least one of query or the " + "input/output projection weights or biases requires_grad" + ) + if not why_not_fast_path: + return torch._native_multi_head_attention( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.out_proj.weight, + self.out_proj.bias, + key_padding_mask + if key_padding_mask is not None + else attn_mask, + need_weights, + average_attn_weights, + 1 + if key_padding_mask is not None + else 0 + if attn_mask is not None + else None, + ) + + any_nested = query.is_nested or key.is_nested or value.is_nested + assert not any_nested, ( + "MultiheadAttention does not support NestedTensor outside of its fast path. 
" + + f"The fast path was not hit because {why_not_fast_path}" + ) + + if self.batch_first and is_batched: + # make sure that the transpose op does not affect the "is" property + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [ + x.transpose(1, 0) for x in (query, key, value) + ] + + if not self._qkv_same_embed_dim: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight, + average_attn_weights=average_attn_weights, + ) + else: + attn_output, attn_output_weights = F.multi_head_attention_forward( + query, + key, + value, + self.embed_dim, + self.num_heads, + self.in_proj_weight, + self.in_proj_bias, + self.bias_k, + self.bias_v, + self.add_zero_attn, + self.dropout, + self.out_proj.weight, + self.out_proj.bias, + training=self.training, + key_padding_mask=key_padding_mask, + need_weights=need_weights, + attn_mask=attn_mask, + average_attn_weights=average_attn_weights, + ) + if self.batch_first and is_batched: + return attn_output.transpose(1, 0), attn_output_weights + else: + return attn_output, attn_output_weights + + +class LayerNorm(nn.Module): + __constants__ = ["normalized_shape", "eps", "elementwise_affine"] + normalized_shape: Tuple[int, ...] 
class AdaptiveLayerNorm(nn.Module):
    r"""Layer normalization whose scale and shift are predicted from an embedding.

    The wrapped ``norm`` is applied to the input; the result is multiplied by
    the first half and shifted by the second half of
    ``project_layer(embedding)``.

    Args:
      d_model: feature size of the embedding; the projection maps it to
        ``2 * d_model`` values (scale followed by shift).
      norm: the normalization module to wrap.  It must expose ``eps``.
    """

    def __init__(self, d_model, norm) -> None:
        super(AdaptiveLayerNorm, self).__init__()
        self.project_layer = nn.Linear(d_model, 2 * d_model)
        self.norm = norm
        self.d_model = d_model
        self.eps = self.norm.eps

    def forward(
        self,
        input: Union[Tensor, Tuple[Tensor, Tensor]],
        embedding: Optional[Tensor] = None,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Normalize ``input`` and modulate it with ``embedding``.

        Args:
          input: either the tensor to normalize, or a tuple
            ``(tensor, embedding)``; in the tuple form the ``embedding``
            argument is ignored.
          embedding: the conditioning embedding (required unless ``input``
            is a tuple that carries it).

        Returns:
          The modulated tensor, or ``(modulated tensor, embedding)`` when
          ``input`` was given as a tuple.

        Raises:
          TypeError: if no embedding is available.
        """
        packed = isinstance(input, tuple)
        if packed:
            input, embedding = input

        # Fail with a clear message instead of an opaque error raised from
        # inside nn.Linear when the embedding is missing.
        if embedding is None:
            raise TypeError(
                "AdaptiveLayerNorm.forward() requires an embedding, "
                "got None"
            )

        weight, bias = torch.split(
            self.project_layer(embedding),
            split_size_or_sections=self.d_model,
            dim=-1,
        )
        output = weight * self.norm(input) + bias
        return (output, embedding) if packed else output
+ if isinstance(activation, str): + activation = _get_activation_fn(activation) + elif isinstance(activation, partial): + activation = activation(d_model) + # elif activation == BalancedDoubleSwish: + # activation = BalancedDoubleSwish(d_model) + + # # We can't test self.activation in forward() in TorchScript, + # # so stash some information about it instead. + # if activation is F.relu or isinstance(activation, torch.nn.ReLU): + # self.activation_relu_or_gelu = 1 + # elif activation is F.gelu or isinstance(activation, torch.nn.GELU): + # self.activation_relu_or_gelu = 2 + # else: + # self.activation_relu_or_gelu = 0 + self.activation = activation + + norm1 = layer_norm_cls(d_model, eps=layer_norm_eps, **factory_kwargs) + if layer_norm_cls == IdentityNorm: + norm2 = BalancedBasicNorm( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + else: + norm2 = layer_norm_cls( + d_model, eps=layer_norm_eps, **factory_kwargs + ) + + if adaptive_layer_norm: + self.norm1 = AdaptiveLayerNorm(d_model, norm1) + self.norm2 = AdaptiveLayerNorm(d_model, norm2) + else: + self.norm1 = norm1 + self.norm2 = norm2 + + def __setstate__(self, state): + super(TransformerEncoderLayer, self).__setstate__(state) + if not hasattr(self, "activation"): + self.activation = F.relu + + def forward( + self, + src: Tensor, + src_mask: Optional[Tensor] = None, + src_key_padding_mask: Optional[Tensor] = None, + ) -> Tensor: + r"""Pass the input through the encoder layer. + + Args: + src: the sequence to the encoder layer (required). + src_mask: the mask for the src sequence (optional). + src_key_padding_mask: the mask for the src keys per batch (optional). + + Shape: + see the docs in Transformer class. 
+ """ + x, stage_embedding = src, None + is_src_tuple = False + if isinstance(src, tuple): + x, stage_embedding = src + is_src_tuple = True + + if src_key_padding_mask is not None: + _skpm_dtype = src_key_padding_mask.dtype + if _skpm_dtype != torch.bool and not torch.is_floating_point( + src_key_padding_mask + ): + raise AssertionError( + "only bool and floating types of key_padding_mask are supported" + ) + + if self.norm_first: + x = x + self._sa_block( + self.norm1(x, stage_embedding), + src_mask, + src_key_padding_mask, + ) + x = x + self._ff_block(self.norm2(x, stage_embedding)) + else: + x = self.norm1( + x + self._sa_block(x, src_mask, src_key_padding_mask), + stage_embedding, + ) + x = self.norm2(x + self._ff_block(x), stage_embedding) + + if is_src_tuple: + return (x, stage_embedding) + return x + + # self-attention block + def _sa_block( + self, + x: Tensor, + attn_mask: Optional[Tensor], + key_padding_mask: Optional[Tensor], + ) -> Tensor: + x = self.self_attn( + x, + x, + x, + attn_mask=attn_mask, + key_padding_mask=key_padding_mask, + need_weights=False, + )[0] + return self.dropout1(x) + + # feed forward block + def _ff_block(self, x: Tensor) -> Tensor: + x = self.linear2(self.dropout(self.activation(self.linear1(x)))) + return self.dropout2(x) + + +class TransformerEncoder(nn.Module): + r"""TransformerEncoder is a stack of N encoder layers. Users can build the + BERT(https://arxiv.org/abs/1810.04805) model with corresponding parameters. + + Args: + encoder_layer: an instance of the TransformerEncoderLayer() class (required). + num_layers: the number of sub-encoder-layers in the encoder (required). + norm: the layer normalization component (optional). + enable_nested_tensor: if True, input will automatically convert to nested tensor + (and convert back on output). This will improve the overall performance of + TransformerEncoder when padding rate is high. Default: ``True`` (enabled). 
class TransformerEncoder(nn.Module):
    r"""Run an input through a stack of cloned encoder layers.

    Args:
        encoder_layer: layer instance that is cloned ``num_layers`` times
            (required).
        num_layers: number of layers in the stack (required).
        norm: module applied to the output of the last layer (optional).

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> transformer_encoder = TransformerEncoder(encoder_layer, num_layers=6)
        >>> src = torch.rand(10, 32, 512)
        >>> out = transformer_encoder(src)
    """
    __constants__ = ["norm"]

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(
        self,
        src: Tensor,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
        return_layer_states: bool = False,
    ) -> Tensor:
        r"""Feed ``src`` through every layer in order.

        Args:
            src: input handed to the first layer (required).
            mask: passed to each layer as ``src_mask`` (optional).
            src_key_padding_mask: passed to each layer unchanged (optional).
            return_layer_states: when True, also collect element ``[0]`` of
                every layer's output.

        Returns:
            The (optionally normalised) output of the last layer, or
            ``(layer_states, output)`` when ``return_layer_states`` is True.
        """
        collected = [] if return_layer_states else None

        hidden = src
        for layer in self.layers:
            hidden = layer(
                hidden,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )
            if collected is not None:
                # Layer outputs are indexable here; keep their first element.
                collected.append(hidden[0])

        if self.norm is not None:
            hidden = self.norm(hidden)

        if collected is not None:
            return collected, hidden
        return hidden
def __init__(
    self,
    d_model: int,
    nhead: int,
    num_layers: int,
    norm_first: bool = True,
    add_prenet: bool = False,
    decoder_cls=TransformerEncoder,
    decoder_layer_cls=TransformerEncoderLayer,
    prefix_mode: int = 0,
    share_embedding: bool = True,
    nar_scale_factor: float = 1.0,
    prepend_bos: bool = False,
    num_quantizers: int = 8,
    **kwargs,
):
    """Build the autoregressive (``ar_*``) and non-autoregressive (``nar_*``) modules.

    Args:
      d_model:
        The number of expected features in the input (required).
      nhead:
        The number of heads in the multiheadattention models (required).
      num_layers:
        The number of sub-decoder-layers in the decoder (required).
      norm_first:
        Forwarded to the decoder layers; when True a final norm is also
        attached to each decoder.
      add_prenet:
        If True, text and audio embeddings go through conv/linear pre-nets;
        otherwise the pre-nets are ``nn.Identity``.
      decoder_cls, decoder_layer_cls:
        Constructors used for both decoders and their layers.
      prefix_mode:
        Stored on the instance; selects how NAR acoustic prompts are built.
      share_embedding:
        If True, NAR prediction layer ``j`` shares its weight with NAR audio
        embedding ``j + 2``.
      nar_scale_factor:
        Multiplier applied to ``d_model``, ``nhead`` and ``num_layers`` for
        the NAR modules (results are truncated with ``int``).
      prepend_bos:
        If True, the AR audio embedding gets one extra entry for BOS.
      num_quantizers:
        Number of codebooks; NAR modules are only created when it is > 1.
      **kwargs:
        Accepted and ignored.
    """
    super().__init__()
    nar_d_model = int(d_model * nar_scale_factor)

    self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
    self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

    # ID NUM_AUDIO_TOKENS     -> PAD
    # ID NUM_AUDIO_TOKENS + 1 -> BOS
    self.ar_audio_prepend_bos = prepend_bos
    self.ar_audio_embedding = TokenEmbedding(
        d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
    )

    # PreNet
    if add_prenet:
        self.ar_text_prenet = nn.Sequential(
            Transpose(),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Conv1d(d_model, d_model, kernel_size=5, padding="same"),
            nn.BatchNorm1d(d_model),
            nn.ReLU(),
            nn.Dropout(0.5),
            Transpose(),
            nn.Linear(d_model, d_model),
        )

        self.ar_audio_prenet = nn.Sequential(
            nn.Linear(d_model, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, d_model),
        )
    else:
        self.ar_text_prenet = nn.Identity()
        self.ar_audio_prenet = nn.Identity()

    self.ar_text_position = SinePositionalEmbedding(
        d_model,
        dropout=0.1,
        scale=False,
        alpha=True,
    )
    self.ar_audio_position = SinePositionalEmbedding(
        d_model,
        dropout=0.1,
        scale=False,
        alpha=True,
    )

    self.ar_decoder = decoder_cls(
        decoder_layer_cls(
            d_model,
            nhead,
            dim_feedforward=d_model * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=norm_first,
        ),
        num_layers=num_layers,
        norm=LayerNorm(d_model) if norm_first else None,
    )
    # One extra output class on top of the audio tokens (ID NUM_AUDIO_TOKENS).
    self.ar_predict_layer = nn.Linear(
        d_model, NUM_AUDIO_TOKENS + 1, bias=False
    )

    self.ar_accuracy_metric = MulticlassAccuracy(
        NUM_AUDIO_TOKENS + 1,
        top_k=10,
        average="micro",
        multidim_average="global",
        ignore_index=NUM_AUDIO_TOKENS,
    )

    # Fixed-seed Python RNG, used for NAR stage and prompt-segment selection.
    self.rng = random.Random(0)
    self.num_heads = nhead
    self.prefix_mode = prefix_mode
    self.num_quantizers = num_quantizers

    assert num_quantizers >= 1
    if num_quantizers > 1:
        # The first embedding has one extra entry compared with the others.
        self.nar_audio_embeddings = nn.ModuleList(
            [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
            + [
                TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                for i in range(num_quantizers - 1)
            ]
        )  # W_a

        # PreNet
        if add_prenet:
            self.nar_text_prenet = nn.Sequential(
                Transpose(),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                nn.Conv1d(
                    nar_d_model, nar_d_model, kernel_size=5, padding="same"
                ),
                nn.BatchNorm1d(nar_d_model),
                nn.ReLU(),
                nn.Dropout(0.5),
                Transpose(),
                nn.Linear(nar_d_model, nar_d_model),
            )
            self.nar_audio_prenet = nn.Sequential(
                nn.Linear(nar_d_model, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, 256),
                nn.ReLU(),
                nn.Dropout(0.25),
                nn.Linear(256, nar_d_model),
            )
        else:
            self.nar_text_prenet = nn.Identity()
            self.nar_audio_prenet = nn.Identity()

        self.nar_text_position = SinePositionalEmbedding(
            nar_d_model,
            dropout=0.0,
            scale=False,
            alpha=False,
        )
        self.nar_audio_position = SinePositionalEmbedding(
            nar_d_model,
            dropout=0.1,
            scale=False,
            alpha=False,
        )

        self.nar_decoder = decoder_cls(
            decoder_layer_cls(
                nar_d_model,
                int(nhead * nar_scale_factor),
                dim_feedforward=nar_d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
                adaptive_layer_norm=True,
            ),
            num_layers=int(num_layers * nar_scale_factor),
            norm=AdaptiveLayerNorm(
                nar_d_model, norm=nn.LayerNorm(nar_d_model)
            )
            if norm_first
            else None,
        )
        # One prediction head and one stage embedding per NAR codebook.
        self.nar_predict_layers = nn.ModuleList(
            [
                nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                for i in range(num_quantizers - 1)
            ]
        )
        self.nar_stage_embeddings = nn.ModuleList(
            [
                TokenEmbedding(nar_d_model, 1)
                for i in range(num_quantizers - 1)
            ]
        )

        if share_embedding:
            # We share the parameters of the output projection layer with the parameters of the acoustic embedding Wa
            # NOTE(Feiteng): In the experiment, this undermines accuracy
            # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

            # We also share the parameters of the acoustic embedding layer and the output prediction layer,
            # which means the weights of the j-th prediction layer are the same as the (j + 1)-th acoustic embedding layer.
            for j in range(0, num_quantizers - 2):
                self.nar_predict_layers[
                    j
                ].weight = self.nar_audio_embeddings[j + 2].weight

        self.nar_accuracy_metric = MulticlassAccuracy(
            NUM_AUDIO_TOKENS + 1,
            top_k=10,
            average="micro",
            multidim_average="global",
            ignore_index=NUM_AUDIO_TOKENS,
        )
def pad_y_eos(self, y, y_mask_int, eos_id):
    """Append one frame to ``y`` and derive the (inputs, targets) pair.

    ``y`` and ``y_mask_int`` are both extended by one trailing column; the
    mask's new column is 1.  ``eos_id`` is added wherever the extended mask
    is 1.

    Returns:
        ``(inputs, targets)``.  Without BOS the two are the extended sequence
        minus its last / first column.  With ``self.ar_audio_prepend_bos`` the
        inputs are the sequence minus its last column with
        ``NUM_AUDIO_TOKENS + 1`` prepended, and the targets are the full
        extended sequence.
    """
    eos_fill = eos_id * F.pad(y_mask_int, (0, 1), value=1)
    extended = F.pad(y, (0, 1), value=0) + eos_fill

    if not self.ar_audio_prepend_bos:
        # Plain one-step shift between inputs and targets.
        return extended[:, :-1], extended[:, 1:]

    bos_inputs = F.pad(extended[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1)
    return bos_inputs, extended
def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes):
    """Build the NAR audio embedding (prompt + target part) for one stage.

    Args:
        y: first-codebook tokens, embedded with ``nar_audio_embeddings[0]``.
        y_lens: per-utterance lengths; used to size the random prefix.
        codes: all codebooks, indexed as ``codes[..., j]``.  For
            ``prefix_mode == 2`` the chosen prompt span of codebook
            ``nar_stage`` is overwritten in place with ``NUM_AUDIO_TOKENS``.
        nar_stage: index of the codebook being predicted; codebooks below it
            are summed into the target embedding.
        y_prompts_codes: externally supplied prompt codes, read only when
            ``prefix_mode == 4``.

    Returns:
        ``(y_emb, prefix_len)``.  For modes 2 and 4 ``y_emb`` is longer than
        ``y`` by ``prefix_len`` because the prompt is prepended.

    Raises:
        ValueError: for any other ``prefix_mode``.
    """
    # 5.1 For the NAR acoustic prompt tokens, we select a random segment waveform of 3 seconds
    # from the same utterance.
    # We implement this differently.
    if self.prefix_mode == 0:
        # no prefix
        prefix_len = 0
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, nar_stage):
            # Formula (4) (5)
            y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
    elif self.prefix_mode == 1:
        # prefix at beginning
        int_low = (0.25 * y_lens.min()).type(torch.int64).item()
        # NOTE(review): when y_lens.min() < 4, int_low is 0 and randint is
        # called with low == high — confirm such short items are filtered out.
        prefix_len = torch.randint(int_low, int_low * 2, size=()).item()
        prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

        # The prompt part sums all codebooks; the rest only those below nar_stage.
        y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
        y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
        for j in range(1, self.num_quantizers):
            y_prompts += self.nar_audio_embeddings[j](
                codes[:, :prefix_len, j]
            )
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](
                    codes[:, prefix_len:, j]
                )
        y_emb = torch.concat([y_prompts, y_emb], axis=1)
    elif self.prefix_mode in [2, 4]:
        if self.prefix_mode == 2:
            # random prefix
            prefix_len = min(225, int(0.25 * y_lens.min().item()))

            y_prompts_codes = []
            for b in range(codes.shape[0]):
                start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                # Copy the span before it is overwritten below.
                y_prompts_codes.append(
                    torch.clone(codes[b, start : start + prefix_len])
                )
                # In-place: the caller's `codes` tensor is modified.
                codes[
                    b, start : start + prefix_len, nar_stage
                ] = NUM_AUDIO_TOKENS
            y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
        else:
            prefix_len = y_prompts_codes.shape[1]

        y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
        y_emb = self.nar_audio_embeddings[0](y)
        for j in range(1, self.num_quantizers):
            y_prompts += self.nar_audio_embeddings[j](
                y_prompts_codes[..., j]
            )
            if j < nar_stage:
                y_emb += self.nar_audio_embeddings[j](codes[..., j])
        y_emb = torch.concat([y_prompts, y_emb], axis=1)
    else:
        raise ValueError

    return y_emb, prefix_len
def forward(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: Union[torch.Tensor, PromptedFeatures],
    y_lens: Union[torch.Tensor, PromptedFeatures],
    reduction: str = "sum",
    train_stage: int = 0,
    **kwargs,
) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
    """Compute the training loss and top-10 accuracy for the AR and/or NAR stage.

    Args:
      x:
        A 2-D tensor of shape (N, S).
      x_lens:
        A 1-D tensor of shape (N,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (N, T, 8), or a `PromptedFeatures` whose
        `.data` unpacks to (prompt codes, target codes). The latter requires
        `prefix_mode == 4`.
      y_lens:
        A 1-D tensor of shape (N,). It contains the number of frames in `y`
        before padding. With `PromptedFeatures`, `.data` unpacks to
        (prompt lengths, target lengths) and all prompt lengths must be equal.
      reduction:
        Passed to `F.cross_entropy`.
      train_stage:
        0: AR & NAR modules, 1: AR modules, 2: NAR modules
    Returns:
      A 3-tuple `((x, codes), total_loss, metrics)`: `x` is the text
      embedding of the last stage that ran, `codes` the int64 codes with
      padded frames zeroed, `total_loss` the cross-entropy loss (halved when
      `train_stage == 0`), and `metrics` a dict with "ArTop10Accuracy" and/or
      "NarTop10Accuracy", each scaled by the total number of frames.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape

    y_prompts_codes = None
    if isinstance(y, PromptedFeatures):
        y_prompts_codes, y = y.data
        prompts_len, y_lens = y_lens.data
        assert prompts_len.min() == prompts_len.max()
        assert self.prefix_mode == 4
        y_prompts_codes = y_prompts_codes.type(torch.int64)

    assert y.ndim == 3, y.shape
    assert y_lens.ndim == 1, y_lens.shape

    # NOTE: x has been padded in TextTokenCollater
    x_mask = make_pad_mask(x_lens).to(x.device)
    y_mask = make_pad_mask(y_lens).to(y.device)
    y_mask_int = y_mask.type(torch.int64)

    text = x
    # Zero the codes wherever the pad mask is set.
    codes = y.type(torch.int64) * (1 - y_mask_int.unsqueeze(dim=-1))

    # AR inputs/targets are built from the first codebook only.
    y, targets = self.pad_y_eos(
        codes[..., 0], y_mask_int, eos_id=NUM_AUDIO_TOKENS
    )

    x_len = x_lens.max()

    metrics = {}
    total_loss = 0.0

    xy_padding_mask = torch.concat([x_mask, y_mask], dim=1)
    if self.ar_audio_prepend_bos:
        # One extra unmasked column for the prepended BOS position.
        ar_xy_padding_mask = torch.concat(
            [x_mask, F.pad(y_mask, (1, 0), value=False)], dim=1
        )
    else:
        ar_xy_padding_mask = xy_padding_mask
    # AR Decoder
    if train_stage in [0, 1]:
        x = self.ar_text_embedding(text)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        y_len = y_lens.max() + int(self.ar_audio_prepend_bos)

        # Text rows: text columns visible, audio columns masked.
        x_attn_mask = F.pad(
            torch.zeros((x_len, x_len), dtype=torch.bool, device=x.device),
            (0, y_len),
            value=True,
        )
        # Audio rows: text columns visible, later audio columns masked.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool, device=x.device),
                diagonal=1,
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat([x_attn_mask, y_attn_mask], dim=0)

        # merge key padding and attention masks
        bsz, src_len = x.shape[0], x_len + y_len
        _xy_padding_mask = (
            ar_xy_padding_mask.view(bsz, 1, 1, src_len)
            .expand(-1, self.num_heads, -1, -1)
            .reshape(bsz * self.num_heads, 1, src_len)
        )
        xy_attn_mask = xy_attn_mask.logical_or(_xy_padding_mask)

        # Convert the boolean mask into an additive one (-inf where masked).
        new_attn_mask = torch.zeros_like(xy_attn_mask, dtype=x.dtype)
        new_attn_mask.masked_fill_(xy_attn_mask, float("-inf"))
        xy_attn_mask = new_attn_mask

        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)

        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
            # src_key_padding_mask=xy_padding_mask,
            # is_causal=True,
        )
        # Only the audio positions are scored; classes move to dim 1.
        logits = self.ar_predict_layer(xy_dec[:, x_len:]).permute(0, 2, 1)
        # loss
        total_loss = F.cross_entropy(logits, targets, reduction=reduction)

        metrics["ArTop10Accuracy"] = self.ar_accuracy_metric(
            logits.detach(), targets
        ).item() * y_lens.sum().type(torch.float32)

    if self.num_quantizers == 1:
        return ((x, codes), total_loss, metrics)

    # Non-AR Decoders
    if self.ar_audio_prepend_bos:
        y = y[:, 1:]
    if train_stage in [0, 2]:
        # Pick one NAR codebook uniformly at random for this batch.
        num_nar_layers = self.num_quantizers - 1
        nar_stage = self.rng.choices(
            [_k for _k in range(1, self.num_quantizers)],
            weights=[1.0 / num_nar_layers] * num_nar_layers,
            k=1,
        )[0]

        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        # For prefix_mode 2 this call also overwrites part of `codes` in
        # place, so it must run before `targets` is computed below.
        y_emb, prefix_len = self._prepare_prompts(
            y, y_lens, codes, nar_stage, y_prompts_codes
        )

        y_len = y_lens.max()
        # Padded frames become NUM_AUDIO_TOKENS, the ignore_index used below.
        targets = codes[..., nar_stage] + NUM_AUDIO_TOKENS * y_mask_int
        if self.prefix_mode in [2, 4]:
            # The prompt was prepended to y_emb; extend the mask to match.
            xy_padding_mask = torch.concat(
                [
                    x_mask,
                    F.pad(y_mask, (y_emb.shape[1] - y_len, 0), value=False),
                ],
                dim=1,
            )
        elif self.prefix_mode == 1:
            targets = targets[:, prefix_len:]

        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)
        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[nar_stage - 1].weight),
            src_key_padding_mask=xy_padding_mask,
            # is_causal=False,
        )
        # Drop the text and prompt positions before prediction.
        xy_dec = xy_dec[:, x_lens.max() + prefix_len :]
        if self.prefix_mode == 4:
            prefix_len = 0  # reset for Top10Accuracy metric
        logits = self.nar_predict_layers[nar_stage - 1](xy_dec).permute(
            0, 2, 1
        )

        # loss
        total_length = (y_lens).sum().type(torch.float32)
        # The loss is rescaled by total_length / (total_length - prefix frames).
        total_loss += (
            F.cross_entropy(
                logits,
                targets,
                ignore_index=NUM_AUDIO_TOKENS,
                reduction=reduction,
            )
            * (total_length / (total_length - prefix_len * x.shape[0]))
        )
        # Pad one extra class (filled with the minimum logit) so the class
        # dimension matches the metric's NUM_AUDIO_TOKENS + 1 classes.
        metrics["NarTop10Accuracy"] = (
            self.nar_accuracy_metric(
                F.pad(
                    logits.detach(),
                    (0, 0, 0, 1, 0, 0),
                    value=logits.min().cpu().item(),
                ),
                targets,
            ).item()
            * total_length
        )

    if train_stage == 0:
        total_loss = total_loss / 2.0

    return ((x, codes), total_loss, metrics)
def inference(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
    enroll_x_lens: torch.Tensor,
    top_k: int = -100,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> torch.Tensor:
    """Generate audio codes for `x`, continuing the acoustic prompt `y`.

    The first codebook is sampled token by token with the AR decoder; the
    remaining codebooks are then filled in with the NAR decoder (argmax).

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
      enroll_x_lens:
        Lengths of the enrolled (prompt) text. Only read when `prefix_mode`
        is 2 or 4, where the enrolled tokens are removed from the text fed
        to the NAR decoder.
      top_k: (`optional`) int
        The number of highest probability tokens to keep for top-k-filtering. Default to -100.
      temperature: (`optional`) float
        The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
      top_p: (`optional`) float
        Nucleus-sampling threshold passed to `topk_sampling`. Default to 1.0.
    Returns:
      Return the predicted audio code matrix.
    Raises:
      SyntaxError: if generation stops before any token was appended.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)

    # NOTE: x has been padded in TextTokenCollater
    text = x
    x = self.ar_text_embedding(text)
    x = self.ar_text_prenet(x)
    x = self.ar_text_position(x)

    text_len = x_lens.max()
    prompts = y
    prefix_len = y.shape[1]

    # AR Decoder
    # TODO: Managing decoder steps avoid repetitive computation
    y = prompts[..., 0]
    if self.ar_audio_prepend_bos:
        y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

    x_len = x_lens.max()
    x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

    while True:
        # The whole sequence is re-embedded and re-decoded at every step.
        y_emb = self.ar_audio_embedding(y)
        y_emb = self.ar_audio_prenet(y_emb)
        y_pos = self.ar_audio_position(y_emb)
        xy_pos = torch.concat([x, y_pos], dim=1)

        y_len = y.shape[1]
        # Text rows: text columns visible, audio columns masked.
        x_attn_mask_pad = F.pad(
            x_attn_mask,
            (0, y_len),
            value=True,
        )
        # Audio rows: text columns visible, later audio columns masked.
        y_attn_mask = F.pad(
            torch.triu(
                torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
            ),
            (x_len, 0),
            value=False,
        )
        xy_attn_mask = torch.concat(
            [x_attn_mask_pad, y_attn_mask], dim=0
        ).to(y.device)

        xy_dec, _ = self.ar_decoder(
            (xy_pos, None),
            mask=xy_attn_mask,
        )
        # Only the last position is needed to predict the next token.
        logits = self.ar_predict_layer(xy_dec[:, -1])
        # Repetition-aware sampling is always enabled here.
        ras=True
        samples = topk_sampling(
            logits, top_k=top_k, top_p=top_p, temperature=temperature, repetition_aware_sampling=ras, preceding_tokens=y
        )

        # Stop on EOS (greedy or sampled) or once more than 16 audio tokens
        # per text token have been generated.
        if (
            torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
            or samples[0, 0] == NUM_AUDIO_TOKENS
            or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
        ):
            if prompts.shape[1] == y.shape[1]:
                raise SyntaxError(
                    "well trained model shouldn't reach here."
                )

            print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
            break

        y = torch.concat([y, samples], dim=1)

    # Keep only the newly generated first-codebook tokens.
    codes = [y[:, prefix_len + int(self.ar_audio_prepend_bos) :]]
    if self.num_quantizers == 1:
        return torch.stack(codes, dim=-1)

    # Non-AR Decoders
    y_emb = self.nar_audio_embeddings[0](
        y[:, int(self.ar_audio_prepend_bos) :]
    )

    if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
        enrolled_len = enroll_x_lens.max().item()
        # SOS + Synthesis Text + EOS
        text = torch.concat(
            [
                text[:, :1],
                text[:, enrolled_len - 1 :],
            ],
            dim=1,
        )
        text_len = text_len - (enrolled_len - 2)
        assert text.shape[0] == 1

    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    if self.prefix_mode == 0:
        # Prompt and generated embeddings are accumulated one codebook at a time.
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            # No update after the last codebook: nothing consumes it.
            if i < self.num_quantizers - 2:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
                y_emb[:, prefix_len:] += embedding_layer(samples)
    else:
        # The prompt part receives all codebooks up front.
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            codes.append(samples)

            if i < self.num_quantizers - 2:
                y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    return torch.stack(codes, dim=-1)
def continual(
    self,
    x: torch.Tensor,
    x_lens: torch.Tensor,
    y: torch.Tensor,
) -> torch.Tensor:
    """Re-predict the NAR codebooks for the part of `y` after its prompt half.

    The first `prefix_len` frames of `y` (half of it, capped at 3 * 75
    frames) serve as acoustic prompt.  First-codebook tokens of the
    remaining frames are taken from `y` itself; codebooks 2..8 are predicted
    greedily with the NAR decoder.

    Args:
      x:
        A 2-D tensor of shape (1, S).
      x_lens:
        A 1-D tensor of shape (1,). It contains the number of tokens in `x`
        before padding.
      y:
        A 3-D tensor of shape (1, T, 8).
    Returns:
      Return the predicted audio code matrix, one row per frame after the
      prompt and one column per codebook.
    """
    assert x.ndim == 2, x.shape
    assert x_lens.ndim == 1, x_lens.shape
    assert y.ndim == 3, y.shape
    assert y.shape[0] == 1, y.shape

    assert torch.all(x_lens > 0)
    assert self.num_quantizers == 8

    # NOTE: x has been padded in TextTokenCollater
    text = x
    text_len = x_lens.max()

    prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
    prompts = y[:, :prefix_len]

    # No AR decoding here: first-codebook tokens come straight from `y`, so
    # the AR text embedding the old code computed and discarded is dropped.
    codes = [y[:, prefix_len:, 0]]

    # Non-AR Decoders
    x = self.nar_text_embedding(text)
    x = self.nar_text_prenet(x)
    x = self.nar_text_position(x)

    y_emb = self.nar_audio_embeddings[0](y[..., 0])

    # prefix_mode 0 adds prompt codebooks stage by stage; every other mode
    # adds all of them before the first stage.
    incremental_prompt = self.prefix_mode == 0
    if not incremental_prompt:
        for j in range(1, self.num_quantizers):
            y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                prompts[..., j]
            )

    # The last predicted codebook is not consumed by any later stage.
    last_fed_stage = self.num_quantizers - 2
    for i, (predict_layer, embedding_layer) in enumerate(
        zip(
            self.nar_predict_layers,
            self.nar_audio_embeddings[1:],
        )
    ):
        # Pre-net first, then positions: the order used by forward() and
        # inference().  The prefix_mode == 0 branch used to swap the two.
        y_pos = self.nar_audio_prenet(y_emb)
        y_pos = self.nar_audio_position(y_pos)
        xy_pos = torch.concat([x, y_pos], dim=1)

        xy_dec, _ = self.nar_decoder(
            (xy_pos, self.nar_stage_embeddings[i].weight)
        )
        # Score only the frames after the text and the prompt.
        logits = predict_layer(xy_dec[:, text_len + prefix_len :])

        samples = torch.argmax(logits, dim=-1)
        codes.append(samples)

        if i < last_fed_stage:
            if incremental_prompt:
                y_emb[:, :prefix_len] += embedding_layer(
                    prompts[..., i + 1]
                )
            y_emb[:, prefix_len:] += embedding_layer(samples)

    assert len(codes) == self.num_quantizers
    return torch.stack(codes, dim=-1)
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Mask logits outside the top-k set and/or the top-p nucleus.

    The input tensor is modified in place and also returned.

    Args:
        logits: logits of shape (batch size, vocabulary size).
        top_k: if > 0, keep only the ``top_k`` largest logits per row
            (clamped to at least ``min_tokens_to_keep`` and at most the
            vocabulary size).
        top_p: if < 1.0, keep the highest-probability tokens up to and
            including the first one whose cumulative probability exceeds
            ``top_p`` (nucleus filtering, Holtzman et al.,
            http://arxiv.org/abs/1904.09751).
        filter_value: value written into removed positions.
        min_tokens_to_keep: when > 1, the first ``min_tokens_to_keep``
            ranked tokens are protected before the one-step shift, so
            ``min_tokens_to_keep + 1`` tokens survive nucleus filtering.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Everything strictly below the k-th largest value of its row goes.
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value

    if top_p < 1.0:
        ranked_logits, ranked_idx = torch.sort(logits, descending=True)
        cum_probs = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

        drop_ranked = cum_probs > top_p
        if min_tokens_to_keep > 1:
            drop_ranked[..., :min_tokens_to_keep] = 0
        # Shift right by one so the first token crossing the threshold is
        # kept as well; the best-ranked token is therefore never dropped.
        drop_ranked = torch.cat(
            [torch.zeros_like(drop_ranked[..., :1]), drop_ranked[..., :-1]],
            dim=-1,
        )

        # Map the decision from rank order back to vocabulary order.
        drop = drop_ranked.scatter(1, ranked_idx, drop_ranked)
        logits[drop] = filter_value
    return logits
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0, repetition_aware_sampling=False, preceding_tokens=None):
    """Sample one token per row after temperature, top-k and top-p filtering.

    Args:
        logits: tensor of shape (batch, vocabulary).  It is never modified;
            filtering is applied to a scaled copy.
        top_k: number of highest-probability tokens kept by top-k filtering
            (values <= 0 disable it).
        top_p: cumulative-probability threshold for nucleus filtering
            (values >= 1.0 disable it).
        temperature: divisor applied to the logits; must be strictly
            positive.  Higher values make low-probability tokens more likely.
        repetition_aware_sampling: if True, a sampled token that already
            occurs too often in the recent history is re-drawn once.
        preceding_tokens: tensor of shape (batch, T) with the tokens
            generated so far; only read when ``repetition_aware_sampling``
            is True.  ``None`` is treated as an empty history.

    Returns:
        Tensor of shape (batch, 1) with the sampled token ids.
    """
    # Always work on a copy: the filter writes in place, and previously the
    # caller's tensor was modified only when temperature == 1.0.
    if temperature != 1.0:
        logits = logits / temperature
    else:
        logits = logits.clone()

    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(
        logits, top_k=top_k, top_p=top_p, min_tokens_to_keep=2
    )
    probs = F.softmax(logits, dim=-1)
    tokens = torch.multinomial(probs, num_samples=1)

    if not repetition_aware_sampling or preceding_tokens is None:
        return tokens

    # Repetition-aware sampling: look at the last `window_size` tokens and
    # re-draw when the candidate's repetition ratio exceeds `threshold`.
    window_size = 10
    threshold = 0.1
    history = preceding_tokens[:, -window_size:]
    if history.shape[1] == 0:
        return tokens

    for i, recent in enumerate(history):
        # The ratio is taken over the full window size even when fewer
        # tokens are available.
        if (recent == tokens[i]).sum() / window_size > threshold:
            # NOTE(review): the replacement is drawn from the same filtered
            # distribution, so it can be the same token again — confirm
            # whether the unfiltered probabilities were intended.
            tokens[i] = torch.multinomial(probs[i], num_samples=1)
    return tokens