diff --git a/.github/workflows/pypi-release.yml b/.github/workflows/pypi-release.yml index 49a5b3004e..2bbcf3cd70 100644 --- a/.github/workflows/pypi-release.yml +++ b/.github/workflows/pypi-release.yml @@ -10,7 +10,7 @@ jobs: build-sdist: runs-on: ubuntu-20.04 steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - name: Verify tag matches version run: | set -ex @@ -38,7 +38,7 @@ jobs: matrix: python-version: ["3.9", "3.10", "3.11"] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v3 - uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} diff --git a/TTS/.models.json b/TTS/.models.json index 13da715b89..5f4008fb01 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -3,14 +3,14 @@ "multilingual": { "multi-dataset": { "xtts_v2": { - "description": "XTTS-v2 by Coqui with 16 languages.", + "description": "XTTS-v2.0.2 by Coqui with 16 languages.", "hf_url": [ "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/model.pth", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/config.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json", "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/hash.md5" ], - "model_hash": "6a09d1ad43896f06041ed8195956c9698f13b6189dc80f1c74bdc2b8e8d15324", + "model_hash": "5ce0502bfe3bc88dc8d9312b12a7558c", "default_vocoder": null, "commit": "480a6cdf7", "license": "CPML", diff --git a/TTS/VERSION b/TTS/VERSION index 1b619f3482..752e630381 100644 --- a/TTS/VERSION +++ b/TTS/VERSION @@ -1 +1 @@ -0.20.5 +0.20.6 diff --git a/TTS/bin/extract_tts_spectrograms.py b/TTS/bin/extract_tts_spectrograms.py index 9eadee070e..c6048626b3 100755 --- a/TTS/bin/extract_tts_spectrograms.py +++ b/TTS/bin/extract_tts_spectrograms.py @@ -15,6 +15,7 @@ from TTS.tts.utils.speakers import SpeakerManager from TTS.tts.utils.text.tokenizer import TTSTokenizer from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import quantize from TTS.utils.generic_utils import count_parameters use_cuda = torch.cuda.is_available() @@ -159,7 +160,7 @@ def inference( def extract_spectrograms( - data_loader, model, ap, output_path, quantized_wav=False, save_audio=False, debug=False, metada_name="metada.txt" + data_loader, model, ap, output_path, quantize_bits=0, save_audio=False, debug=False, metada_name="metada.txt" ): model.eval() export_metadata = [] @@ -196,8 +197,8 @@ def extract_spectrograms( _, wavq_path, mel_path, wav_gl_path, wav_path = set_filename(wav_file_path, output_path) # quantize and save wav - if quantized_wav: - wavq = ap.quantize(wav) + if quantize_bits > 0: + wavq = quantize(wav, quantize_bits) np.save(wavq_path, wavq) # save TTS mel @@ -263,7 +264,7 @@ def main(args): # pylint: disable=redefined-outer-name model, ap, args.output_path, - quantized_wav=args.quantized, + quantize_bits=args.quantize_bits, save_audio=args.save_audio, debug=args.debug, metada_name="metada.txt", @@ -277,7 +278,7 @@ def main(args): # pylint: disable=redefined-outer-name parser.add_argument("--output_path", type=str, help="Path to save mel specs", required=True) parser.add_argument("--debug", default=False, action="store_true", help="Save audio files for debug") parser.add_argument("--save_audio", default=False, action="store_true", help="Save audio files") - parser.add_argument("--quantized", action="store_true", help="Save quantized audio files") + parser.add_argument("--quantize_bits", type=int, default=0, help="Save quantized audio files if non-zero") parser.add_argument("--eval", type=bool, 
help="compute eval.", default=True) args = parser.parse_args() diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index cb350af779..7bea02ca08 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -13,12 +13,18 @@ import numpy as np import torch import torch as th -from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral from tqdm import tqdm from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper -K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +try: + from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral + + K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +except ImportError: + K_DIFFUSION_SAMPLERS = None + + SAMPLERS = ["dpm++2m", "p", "ddim"] @@ -531,6 +537,8 @@ def sample_loop(self, *args, **kwargs): if self.conditioning_free is not True: raise RuntimeError("cond_free must be true") with tqdm(total=self.num_timesteps) as pbar: + if K_DIFFUSION_SAMPLERS is None: + raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers") return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) else: raise RuntimeError("sampler not impl") diff --git a/TTS/tts/layers/xtts/gpt.py b/TTS/tts/layers/xtts/gpt.py index d914ebf90f..e7b186b858 100644 --- a/TTS/tts/layers/xtts/gpt.py +++ b/TTS/tts/layers/xtts/gpt.py @@ -441,7 +441,9 @@ def forward( audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token) # Pad mel codes with stop_audio_token - audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3) # -3 to get the real code lengths without consider start and stop tokens that was not added yet + audio_codes = self.set_mel_padding( + audio_codes, code_lengths - 3 + ) # -3 to get the real code lengths without consider start and stop tokens that was not added yet # Build input and target tensors # Prepend start token to inputs and append stop token to targets diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 7726d829ac..5284874397 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -1,6 +1,6 @@ -import json import os import re +import textwrap from functools import cached_property import pypinyin @@ -8,10 +8,66 @@ from hangul_romanize import Transliter from hangul_romanize.rule import academic from num2words import num2words +from spacy.lang.ar import Arabic +from spacy.lang.en import English +from spacy.lang.es import Spanish +from spacy.lang.ja import Japanese +from spacy.lang.zh import Chinese from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words + +def get_spacy_lang(lang): + if lang == "zh": + return Chinese() + elif lang == "ja": + return Japanese() + elif lang == "ar": + return Arabic() + elif lang == "es": + return Spanish() + else: + # For most languages, Enlish does the job + return English() + + +def split_sentence(text, lang, text_split_length=250): + """Preprocess the input text""" + text_splits = [] + if text_split_length is not None and len(text) >= text_split_length: + text_splits.append("") + nlp = get_spacy_lang(lang) + nlp.add_pipe("sentencizer") + doc = nlp(text) + for sentence in doc.sents: + if len(text_splits[-1]) + len(str(sentence)) <= text_split_length: + # if the last sentence + the current sentence is less than the text_split_length + # then add the 
current sentence to the last sentence + text_splits[-1] += " " + str(sentence) + text_splits[-1] = text_splits[-1].lstrip() + elif len(str(sentence)) > text_split_length: + # if the current sentence is greater than the text_split_length + for line in textwrap.wrap( + str(sentence), + width=text_split_length, + drop_whitespace=True, + break_on_hyphens=False, + tabsize=1, + ): + text_splits.append(str(line)) + else: + text_splits.append(str(sentence)) + + if len(text_splits) > 1: + if text_splits[0] == "": + del text_splits[0] + else: + text_splits = [text.lstrip()] + + return text_splits + + _whitespace_re = re.compile(r"\s+") # List of (regular expression, replacement) pairs for abbreviations: @@ -115,7 +171,7 @@ # There are not many common abbreviations in Arabic as in English. ] ], - "zh-cn": [ + "zh": [ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [ # Chinese doesn't typically use abbreviations in the same way as Latin-based scripts. @@ -280,7 +336,7 @@ def expand_abbreviations_multilingual(text, lang="en"): ("°", " درجة "), ] ], - "zh-cn": [ + "zh": [ # Chinese (re.compile(r"%s" % re.escape(x[0]), re.IGNORECASE), x[1]) for x in [ @@ -464,7 +520,7 @@ def _expand_number(m, lang="en"): def expand_numbers_multilingual(text, lang="en"): - if lang == "zh" or lang == "zh-cn": + if lang == "zh": text = zh_num2words()(text) else: if lang in ["en", "ru"]: @@ -525,7 +581,7 @@ def japanese_cleaners(text, katsu): return text -def korean_cleaners(text): +def korean_transliterate(text): r = Transliter(academic) return r.translit(text) @@ -546,7 +602,7 @@ def __init__(self, vocab_file=None): "it": 213, "pt": 203, "pl": 224, - "zh-cn": 82, + "zh": 82, "ar": 166, "cs": 186, "ru": 182, @@ -564,6 +620,7 @@ def katsu(self): return cutlet.Cutlet() def check_input_length(self, txt, lang): + lang = lang.split("-")[0] # remove the region limit = self.char_limits.get(lang, 250) if len(txt) > limit: print( @@ -571,21 +628,23 @@ def check_input_length(self, txt, lang): ) def preprocess_text(self, txt, lang): - if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh-cn", "zh-cn"}: + if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "ko"}: txt = multilingual_cleaners(txt, lang) - if lang in {"zh", "zh-cn"}: + if lang == "zh": txt = chinese_transliterate(txt) + if lang == "ko": + txt = korean_transliterate(txt) elif lang == "ja": txt = japanese_cleaners(txt, self.katsu) - elif lang == "ko": - txt = korean_cleaners(txt) else: raise NotImplementedError(f"Language '{lang}' is not supported.") return txt def encode(self, txt, lang): + lang = lang.split("-")[0] # remove the region self.check_input_length(txt, lang) txt = self.preprocess_text(txt, lang) + lang = "zh-cn" if lang == "zh" else lang txt = f"[{lang}]{txt}" txt = txt.replace(" ", "[SPACE]") return self.tokenizer.encode(txt).ids @@ -682,8 +741,8 @@ def test_expand_numbers_multilingual(): ("Dat wordt dan $20 meneer.", "Dat wordt dan twintig dollar meneer.", "nl"), ("Dat wordt dan 20€ meneer.", "Dat wordt dan twintig euro meneer.", "nl"), # Chinese (Simplified) - ("在12.5秒内", "在十二点五秒内", "zh-cn"), - ("有50名士兵", "有五十名士兵", "zh-cn"), + ("在12.5秒内", "在十二点五秒内", "zh"), + ("有50名士兵", "有五十名士兵", "zh"), # ("那将是$20先生", '那将是二十美元先生', 'zh'), currency doesn't work # ("那将是20€先生", '那将是二十欧元先生', 'zh'), # Turkish @@ -764,7 +823,7 @@ def test_symbols_multilingual(): ("Ik heb 14% batterij", "Ik heb 14 procent batterij", "nl"), ("Ik zie je @ het feest", "Ik zie je bij het feest", "nl"), ("لدي 
14% في البطارية", "لدي 14 في المئة في البطارية", "ar"), - ("我的电量为 14%", "我的电量为 14 百分之", "zh-cn"), + ("我的电量为 14%", "我的电量为 14 百分之", "zh"), ("Pilim %14 dolu.", "Pilim yüzde 14 dolu.", "tr"), ("Az akkumulátorom töltöttsége 14%", "Az akkumulátorom töltöttsége 14 százalék", "hu"), ("배터리 잔량이 14%입니다.", "배터리 잔량이 14 퍼센트입니다.", "ko"), diff --git a/TTS/tts/layers/xtts/trainer/gpt_trainer.py b/TTS/tts/layers/xtts/trainer/gpt_trainer.py index 005b30bede..4789e1f43f 100644 --- a/TTS/tts/layers/xtts/trainer/gpt_trainer.py +++ b/TTS/tts/layers/xtts/trainer/gpt_trainer.py @@ -318,9 +318,10 @@ def eval_step(self, batch, criterion): batch["cond_idxs"] = None return self.train_step(batch, criterion) - def on_epoch_start(self, trainer): # pylint: disable=W0613 - # guarante that dvae will be in eval mode after .train() on evaluation end - self.dvae = self.dvae.eval() + def on_train_epoch_start(self, trainer): + trainer.model.eval() # set the whole model to eval mode + # put the gpt model in training mode + trainer.model.xtts.gpt.train() def on_init_end(self, trainer): # pylint: disable=W0613 # ignore similarities.pth on clearml save/upload diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index f37f08449d..208ec4d561 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -10,7 +10,7 @@ from TTS.tts.layers.xtts.gpt import GPT from TTS.tts.layers.xtts.hifigan_decoder import HifiDecoder from TTS.tts.layers.xtts.stream_generator import init_stream_support -from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer +from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, split_sentence from TTS.tts.models.base_tts import BaseTTS from TTS.utils.io import load_fsspec @@ -396,7 +396,7 @@ def inference_with_config(self, text, config, ref_audio_path, language, **kwargs inference with config """ assert ( - language in self.config.languages + ("zh-cn" if language == "zh" else language) in self.config.languages ), f" ❗ Language {language} is not supported. Supported languages are {self.config.languages}" # Use generally found best tuning knobs for generation. settings = { @@ -420,9 +420,9 @@ def full_inference( ref_audio_path, language, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, @@ -502,71 +502,76 @@ def inference( gpt_cond_latent, speaker_embedding, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, num_beams=1, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) - - # print(" > Input text: ", text) - # print(" > Input text preprocessed: ",self.tokenizer.preprocess_text(text, language)) - # print(" > Input tokens: ", text_tokens) - # print(" > Decoded text: ", self.tokenizer.decode(text_tokens[0].cpu().numpy())) - assert ( - text_tokens.shape[-1] < self.args.gpt_max_text_tokens - ), " ❗ XTTS can only generate text with a maximum of 400 tokens."
- - with torch.no_grad(): - gpt_codes = self.gpt.generate( - cond_latents=gpt_cond_latent, - text_inputs=text_tokens, - input_tokens=None, - do_sample=do_sample, - top_p=top_p, - top_k=top_k, - temperature=temperature, - num_return_sequences=self.gpt_batch_size, - num_beams=num_beams, - length_penalty=length_penalty, - repetition_penalty=repetition_penalty, - output_attentions=False, - **hf_generate_kwargs, - ) - expected_output_len = torch.tensor( - [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device - ) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] + + wavs = [] + gpt_latents_list = [] + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) + + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + with torch.no_grad(): + gpt_codes = self.gpt.generate( + cond_latents=gpt_cond_latent, + text_inputs=text_tokens, + input_tokens=None, + do_sample=do_sample, + top_p=top_p, + top_k=top_k, + temperature=temperature, + num_return_sequences=self.gpt_batch_size, + num_beams=num_beams, + length_penalty=length_penalty, + repetition_penalty=repetition_penalty, + output_attentions=False, + **hf_generate_kwargs, + ) + expected_output_len = torch.tensor( + [gpt_codes.shape[-1] * self.gpt.code_stride_len], device=text_tokens.device + ) - text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) - gpt_latents = self.gpt( - text_tokens, - text_len, - gpt_codes, - expected_output_len, - cond_latents=gpt_cond_latent, - return_attentions=False, - return_latent=True, - ) + text_len = torch.tensor([text_tokens.shape[-1]], device=self.device) + gpt_latents = self.gpt( + text_tokens, + text_len, + gpt_codes, + expected_output_len, + cond_latents=gpt_cond_latent, + return_attentions=False, + return_latent=True, + ) - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) - wav = self.hifigan_decoder(gpt_latents, g=speaker_embedding) + gpt_latents_list.append(gpt_latents.cpu()) + wavs.append(self.hifigan_decoder(gpt_latents, g=speaker_embedding).cpu().squeeze()) return { - "wav": wav.cpu().numpy().squeeze(), - "gpt_latents": gpt_latents, + "wav": torch.cat(wavs, dim=0).numpy(), + "gpt_latents": torch.cat(gpt_latents_list, dim=1).numpy(), "speaker_embedding": speaker_embedding, } @@ -606,66 +611,76 @@ def inference_stream( stream_chunk_size=20, overlap_wav_len=1024, # GPT inference - temperature=0.65, - length_penalty=1, - repetition_penalty=2.0, + temperature=0.75, + length_penalty=1.0, + repetition_penalty=10.0, top_k=50, top_p=0.85, do_sample=True, speed=1.0, + enable_text_splitting=False, **hf_generate_kwargs, ): + language = language.split("-")[0] # remove the country code length_scale = 1.0 / max(speed, 0.05) - text = text.strip().lower() - text_tokens = torch.IntTensor(self.tokenizer.encode(text, lang=language)).unsqueeze(0).to(self.device) + if enable_text_splitting: + text = split_sentence(text, language, self.tokenizer.char_limits[language]) + else: + text = [text] - fake_inputs = self.gpt.compute_embeddings( - gpt_cond_latent.to(self.device), - text_tokens, - ) - 
gpt_generator = self.gpt.get_generator( - fake_inputs=fake_inputs, - top_k=top_k, - top_p=top_p, - temperature=temperature, - do_sample=do_sample, - num_beams=1, - num_return_sequences=1, - length_penalty=float(length_penalty), - repetition_penalty=float(repetition_penalty), - output_attentions=False, - output_hidden_states=True, - **hf_generate_kwargs, - ) + for sent in text: + sent = sent.strip().lower() + text_tokens = torch.IntTensor(self.tokenizer.encode(sent, lang=language)).unsqueeze(0).to(self.device) - last_tokens = [] - all_latents = [] - wav_gen_prev = None - wav_overlap = None - is_end = False - - while not is_end: - try: - x, latent = next(gpt_generator) - last_tokens += [x] - all_latents += [latent] - except StopIteration: - is_end = True - - if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): - gpt_latents = torch.cat(all_latents, dim=0)[None, :] - if length_scale != 1.0: - gpt_latents = F.interpolate( - gpt_latents.transpose(1, 2), - scale_factor=length_scale, - mode="linear" - ).transpose(1, 2) - wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) - wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( - wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len - ) - last_tokens = [] - yield wav_chunk + assert ( + text_tokens.shape[-1] < self.args.gpt_max_text_tokens + ), " ❗ XTTS can only generate text with a maximum of 400 tokens." + + fake_inputs = self.gpt.compute_embeddings( + gpt_cond_latent.to(self.device), + text_tokens, + ) + gpt_generator = self.gpt.get_generator( + fake_inputs=fake_inputs, + top_k=top_k, + top_p=top_p, + temperature=temperature, + do_sample=do_sample, + num_beams=1, + num_return_sequences=1, + length_penalty=float(length_penalty), + repetition_penalty=float(repetition_penalty), + output_attentions=False, + output_hidden_states=True, + **hf_generate_kwargs, + ) + + last_tokens = [] + all_latents = [] + wav_gen_prev = None + wav_overlap = None + is_end = False + + while not is_end: + try: + x, latent = next(gpt_generator) + last_tokens += [x] + all_latents += [latent] + except StopIteration: + is_end = True + + if is_end or (stream_chunk_size > 0 and len(last_tokens) >= stream_chunk_size): + gpt_latents = torch.cat(all_latents, dim=0)[None, :] + if length_scale != 1.0: + gpt_latents = F.interpolate( + gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear" + ).transpose(1, 2) + wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device)) + wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks( + wav_gen.squeeze(), wav_gen_prev, wav_overlap, overlap_wav_len + ) + last_tokens = [] + yield wav_chunk def forward(self): raise NotImplementedError( diff --git a/TTS/utils/audio/numpy_transforms.py b/TTS/utils/audio/numpy_transforms.py index b701e76712..af88569fc3 100644 --- a/TTS/utils/audio/numpy_transforms.py +++ b/TTS/utils/audio/numpy_transforms.py @@ -201,7 +201,6 @@ def stft( def istft( *, y: np.ndarray = None, - fft_size: int = None, hop_length: int = None, win_length: int = None, window: str = "hann", diff --git a/TTS/utils/audio/processor.py b/TTS/utils/audio/processor.py index 4ceb7da4b3..c53bad562e 100644 --- a/TTS/utils/audio/processor.py +++ b/TTS/utils/audio/processor.py @@ -5,10 +5,26 @@ import numpy as np import scipy.io.wavfile import scipy.signal -import soundfile as sf from TTS.tts.utils.helpers import StandardScaler -from TTS.utils.audio.numpy_transforms import compute_f0 +from TTS.utils.audio.numpy_transforms import ( + amp_to_db, + 
build_mel_basis, + compute_f0, + db_to_amp, + deemphasis, + find_endpoint, + griffin_lim, + load_wav, + mel_to_spec, + millisec_to_length, + preemphasis, + rms_volume_norm, + spec_to_mel, + stft, + trim_silence, + volume_norm, +) # pylint: disable=too-many-public-methods @@ -200,7 +216,9 @@ def __init__( # setup stft parameters if hop_length is None: # compute stft parameters from given time values - self.hop_length, self.win_length = self._stft_parameters() + self.win_length, self.hop_length = millisec_to_length( + frame_length_ms=self.frame_length_ms, frame_shift_ms=self.frame_shift_ms, sample_rate=self.sample_rate + ) else: # use stft parameters from config file self.hop_length = hop_length @@ -215,8 +233,13 @@ def __init__( for key, value in members.items(): print(" | > {}:{}".format(key, value)) # create spectrogram utils - self.mel_basis = self._build_mel_basis() - self.inv_mel_basis = np.linalg.pinv(self._build_mel_basis()) + self.mel_basis = build_mel_basis( + sample_rate=self.sample_rate, + fft_size=self.fft_size, + num_mels=self.num_mels, + mel_fmax=self.mel_fmax, + mel_fmin=self.mel_fmin, + ) # setup scaler if stats_path and signal_norm: mel_mean, mel_std, linear_mean, linear_std, _ = self.load_stats(stats_path) @@ -232,35 +255,6 @@ def init_from_config(config: "Coqpit", verbose=True): return AudioProcessor(verbose=verbose, **config.audio) return AudioProcessor(verbose=verbose, **config) - ### setting up the parameters ### - def _build_mel_basis( - self, - ) -> np.ndarray: - """Build melspectrogram basis. - - Returns: - np.ndarray: melspectrogram basis. - """ - if self.mel_fmax is not None: - assert self.mel_fmax <= self.sample_rate // 2 - return librosa.filters.mel( - sr=self.sample_rate, n_fft=self.fft_size, n_mels=self.num_mels, fmin=self.mel_fmin, fmax=self.mel_fmax - ) - - def _stft_parameters( - self, - ) -> Tuple[int, int]: - """Compute the real STFT parameters from the time values. - - Returns: - Tuple[int, int]: hop length and window length for STFT. - """ - factor = self.frame_length_ms / self.frame_shift_ms - assert (factor).is_integer(), " [!] frame_shift_ms should divide frame_length_ms" - hop_length = int(self.frame_shift_ms / 1000.0 * self.sample_rate) - win_length = int(hop_length * factor) - return hop_length, win_length - ### normalization ### def normalize(self, S: np.ndarray) -> np.ndarray: """Normalize values into `[0, self.max_norm]` or `[-self.max_norm, self.max_norm]` @@ -386,31 +380,6 @@ def setup_scaler( self.linear_scaler = StandardScaler() self.linear_scaler.set_stats(linear_mean, linear_std) - ### DB and AMP conversion ### - # pylint: disable=no-self-use - def _amp_to_db(self, x: np.ndarray) -> np.ndarray: - """Convert amplitude values to decibels. - - Args: - x (np.ndarray): Amplitude spectrogram. - - Returns: - np.ndarray: Decibels spectrogram. - """ - return self.spec_gain * _log(np.maximum(1e-5, x), self.base) - - # pylint: disable=no-self-use - def _db_to_amp(self, x: np.ndarray) -> np.ndarray: - """Convert decibels spectrogram to amplitude spectrogram. - - Args: - x (np.ndarray): Decibels spectrogram. - - Returns: - np.ndarray: Amplitude spectrogram. - """ - return _exp(x / self.spec_gain, self.base) - ### Preemphasis ### def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: """Apply pre-emphasis to the audio signal. Useful to reduce the correlation between neighbouring signal values. @@ -424,32 +393,13 @@ def apply_preemphasis(self, x: np.ndarray) -> np.ndarray: Returns: np.ndarray: Decorrelated audio signal. 
""" - if self.preemphasis == 0: - raise RuntimeError(" [!] Preemphasis is set 0.0.") - return scipy.signal.lfilter([1, -self.preemphasis], [1], x) + return preemphasis(x=x, coef=self.preemphasis) def apply_inv_preemphasis(self, x: np.ndarray) -> np.ndarray: """Reverse pre-emphasis.""" - if self.preemphasis == 0: - raise RuntimeError(" [!] Preemphasis is set 0.0.") - return scipy.signal.lfilter([1], [1, -self.preemphasis], x) + return deemphasis(x=x, coef=self.preemphasis) ### SPECTROGRAMs ### - def _linear_to_mel(self, spectrogram: np.ndarray) -> np.ndarray: - """Project a full scale spectrogram to a melspectrogram. - - Args: - spectrogram (np.ndarray): Full scale spectrogram. - - Returns: - np.ndarray: Melspectrogram - """ - return np.dot(self.mel_basis, spectrogram) - - def _mel_to_linear(self, mel_spec: np.ndarray) -> np.ndarray: - """Convert a melspectrogram to full scale spectrogram.""" - return np.maximum(1e-10, np.dot(self.inv_mel_basis, mel_spec)) - def spectrogram(self, y: np.ndarray) -> np.ndarray: """Compute a spectrogram from a waveform. @@ -460,11 +410,16 @@ def spectrogram(self, y: np.ndarray) -> np.ndarray: np.ndarray: Spectrogram. """ if self.preemphasis != 0: - D = self._stft(self.apply_preemphasis(y)) - else: - D = self._stft(y) + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) if self.do_amp_to_db_linear: - S = self._amp_to_db(np.abs(D)) + S = amp_to_db(x=np.abs(D), gain=self.spec_gain, base=self.base) else: S = np.abs(D) return self.normalize(S).astype(np.float32) @@ -472,32 +427,35 @@ def spectrogram(self, y: np.ndarray) -> np.ndarray: def melspectrogram(self, y: np.ndarray) -> np.ndarray: """Compute a melspectrogram from a waveform.""" if self.preemphasis != 0: - D = self._stft(self.apply_preemphasis(y)) - else: - D = self._stft(y) + y = self.apply_preemphasis(y) + D = stft( + y=y, + fft_size=self.fft_size, + hop_length=self.hop_length, + win_length=self.win_length, + pad_mode=self.stft_pad_mode, + ) + S = spec_to_mel(spec=np.abs(D), mel_basis=self.mel_basis) if self.do_amp_to_db_mel: - S = self._amp_to_db(self._linear_to_mel(np.abs(D))) - else: - S = self._linear_to_mel(np.abs(D)) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) + return self.normalize(S).astype(np.float32) def inv_spectrogram(self, spectrogram: np.ndarray) -> np.ndarray: """Convert a spectrogram to a waveform using Griffi-Lim vocoder.""" S = self.denormalize(spectrogram) - S = self._db_to_amp(S) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) # Reconstruct phase - if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W def inv_melspectrogram(self, mel_spectrogram: np.ndarray) -> np.ndarray: """Convert a melspectrogram to a waveform using Griffi-Lim vocoder.""" D = self.denormalize(mel_spectrogram) - S = self._db_to_amp(D) - S = self._mel_to_linear(S) # Convert back to linear - if self.preemphasis != 0: - return self.apply_inv_preemphasis(self._griffin_lim(S**self.power)) - return self._griffin_lim(S**self.power) + S = db_to_amp(x=D, gain=self.spec_gain, base=self.base) + S = mel_to_spec(mel=S, mel_basis=self.mel_basis) # Convert back to linear + W = self._griffin_lim(S**self.power) + return self.apply_inv_preemphasis(W) if self.preemphasis != 0 else W def 
out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: """Convert a full scale linear spectrogram output of a network to a melspectrogram. @@ -509,60 +467,22 @@ def out_linear_to_mel(self, linear_spec: np.ndarray) -> np.ndarray: np.ndarray: Normalized melspectrogram. """ S = self.denormalize(linear_spec) - S = self._db_to_amp(S) - S = self._linear_to_mel(np.abs(S)) - S = self._amp_to_db(S) + S = db_to_amp(x=S, gain=self.spec_gain, base=self.base) + S = spec_to_mel(spec=np.abs(S), mel_basis=self.mel_basis) + S = amp_to_db(x=S, gain=self.spec_gain, base=self.base) mel = self.normalize(S) return mel - ### STFT and ISTFT ### - def _stft(self, y: np.ndarray) -> np.ndarray: - """Librosa STFT wrapper. - - Args: - y (np.ndarray): Audio signal. - - Returns: - np.ndarray: Complex number array. - """ - return librosa.stft( - y=y, - n_fft=self.fft_size, + def _griffin_lim(self, S): + return griffin_lim( + spec=S, + num_iter=self.griffin_lim_iters, hop_length=self.hop_length, win_length=self.win_length, + fft_size=self.fft_size, pad_mode=self.stft_pad_mode, - window="hann", - center=True, ) - def _istft(self, y: np.ndarray) -> np.ndarray: - """Librosa iSTFT wrapper.""" - return librosa.istft(y, hop_length=self.hop_length, win_length=self.win_length) - - def _griffin_lim(self, S): - angles = np.exp(2j * np.pi * np.random.rand(*S.shape)) - try: - S_complex = np.abs(S).astype(np.complex) - except AttributeError: # np.complex is deprecated since numpy 1.20.0 - S_complex = np.abs(S).astype(complex) - y = self._istft(S_complex * angles) - if not np.isfinite(y).all(): - print(" [!] Waveform is not finite everywhere. Skipping the GL.") - return np.array([0.0]) - for _ in range(self.griffin_lim_iters): - angles = np.exp(1j * np.angle(self._stft(y))) - y = self._istft(S_complex * angles) - return y - - def compute_stft_paddings(self, x, pad_sides=1): - """Compute paddings used by Librosa's STFT. Compute right padding (final frame) or both sides padding - (first and final frames)""" - assert pad_sides in (1, 2) - pad = (x.shape[0] // self.hop_length + 1) * self.hop_length - x.shape[0] - if pad_sides == 1: - return 0, pad - return pad // 2, pad // 2 + pad % 2 - def compute_f0(self, x: np.ndarray) -> np.ndarray: """Compute pitch (f0) of a waveform using the same parameters used for computing melspectrogram. @@ -581,8 +501,6 @@ def compute_f0(self, x: np.ndarray) -> np.ndarray: >>> wav = ap.load_wav(WAV_FILE, sr=ap.sample_rate)[:5 * ap.sample_rate] >>> pitch = ap.compute_f0(wav) """ - assert self.pitch_fmax is not None, " [!] Set `pitch_fmax` before caling `compute_f0`." - assert self.pitch_fmin is not None, " [!] Set `pitch_fmin` before caling `compute_f0`." # align F0 length to the spectrogram length if len(x) % self.hop_length == 0: x = np.pad(x, (0, self.hop_length // 2), mode=self.stft_pad_mode) @@ -612,21 +530,24 @@ def find_endpoint(self, wav: np.ndarray, min_silence_sec=0.8) -> int: Returns: int: Last point without silence. 
""" - window_length = int(self.sample_rate * min_silence_sec) - hop_length = int(window_length / 4) - threshold = self._db_to_amp(-self.trim_db) - for x in range(hop_length, len(wav) - window_length, hop_length): - if np.max(wav[x : x + window_length]) < threshold: - return x + hop_length - return len(wav) + return find_endpoint( + wav=wav, + trim_db=self.trim_db, + sample_rate=self.sample_rate, + min_silence_sec=min_silence_sec, + gain=self.spec_gain, + base=self.base, + ) def trim_silence(self, wav): """Trim silent parts with a threshold and 0.01 sec margin""" - margin = int(self.sample_rate * 0.01) - wav = wav[margin:-margin] - return librosa.effects.trim(wav, top_db=self.trim_db, frame_length=self.win_length, hop_length=self.hop_length)[ - 0 - ] + return trim_silence( + wav=wav, + sample_rate=self.sample_rate, + trim_db=self.trim_db, + win_length=self.win_length, + hop_length=self.hop_length, + ) @staticmethod def sound_norm(x: np.ndarray) -> np.ndarray: @@ -638,13 +559,7 @@ def sound_norm(x: np.ndarray) -> np.ndarray: Returns: np.ndarray: Volume normalized waveform. """ - return x / abs(x).max() * 0.95 - - @staticmethod - def _rms_norm(wav, db_level=-27): - r = 10 ** (db_level / 20) - a = np.sqrt((len(wav) * (r**2)) / np.sum(wav**2)) - return wav * a + return volume_norm(x=x) def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: """Normalize the volume based on RMS of the signal. @@ -657,9 +572,7 @@ def rms_volume_norm(self, x: np.ndarray, db_level: float = None) -> np.ndarray: """ if db_level is None: db_level = self.db_level - assert -99 <= db_level <= 0, " [!] db_level should be between -99 and 0" - wav = self._rms_norm(x, db_level) - return wav + return rms_volume_norm(x=x, db_level=db_level) ### save and load ### def load_wav(self, filename: str, sr: int = None) -> np.ndarray: @@ -674,15 +587,10 @@ def load_wav(self, filename: str, sr: int = None) -> np.ndarray: Returns: np.ndarray: Loaded waveform. """ - if self.resample: - # loading with resampling. It is significantly slower. - x, sr = librosa.load(filename, sr=self.sample_rate) - elif sr is None: - # SF is faster than librosa for loading files - x, sr = sf.read(filename) - assert self.sample_rate == sr, "%s vs %s" % (self.sample_rate, sr) + if sr is not None: + x = load_wav(filename=filename, sample_rate=sr, resample=True) else: - x, sr = librosa.load(filename, sr=sr) + x = load_wav(filename=filename, sample_rate=self.sample_rate, resample=self.resample) if self.do_trim_silence: try: x = self.trim_silence(x) @@ -723,55 +631,3 @@ def get_duration(self, filename: str) -> float: filename (str): Path to the wav file. """ return librosa.get_duration(filename=filename) - - @staticmethod - def mulaw_encode(wav: np.ndarray, qc: int) -> np.ndarray: - mu = 2**qc - 1 - # wav_abs = np.minimum(np.abs(wav), 1.0) - signal = np.sign(wav) * np.log(1 + mu * np.abs(wav)) / np.log(1.0 + mu) - # Quantize signal to the specified number of levels. - signal = (signal + 1) / 2 * mu + 0.5 - return np.floor( - signal, - ) - - @staticmethod - def mulaw_decode(wav, qc): - """Recovers waveform from quantized values.""" - mu = 2**qc - 1 - x = np.sign(wav) / mu * ((1 + mu) ** np.abs(wav) - 1) - return x - - @staticmethod - def encode_16bits(x): - return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16) - - @staticmethod - def quantize(x: np.ndarray, bits: int) -> np.ndarray: - """Quantize a waveform to a given number of bits. - - Args: - x (np.ndarray): Waveform to quantize. Must be normalized into the range `[-1, 1]`. 
- bits (int): Number of quantization bits. - - Returns: - np.ndarray: Quantized waveform. - """ - return (x + 1.0) * (2**bits - 1) / 2 - - @staticmethod - def dequantize(x, bits): - """Dequantize a waveform from the given number of bits.""" - return 2 * x / (2**bits - 1) - 1 - - -def _log(x, base): - if base == 10: - return np.log10(x) - return np.log(x) - - -def _exp(x, base): - if base == 10: - return np.power(10, x) - return np.exp(x) diff --git a/TTS/vocoder/configs/parallel_wavegan_config.py b/TTS/vocoder/configs/parallel_wavegan_config.py index 7845dd6bf8..6059d7f04f 100644 --- a/TTS/vocoder/configs/parallel_wavegan_config.py +++ b/TTS/vocoder/configs/parallel_wavegan_config.py @@ -94,6 +94,7 @@ class ParallelWaveganConfig(BaseGANVocoderConfig): use_noise_augment: bool = False use_cache: bool = True steps_to_start_discriminator: int = 200000 + target_loss: str = "loss_1" # LOSS PARAMETERS - overrides use_stft_loss: bool = True diff --git a/TTS/vocoder/datasets/preprocess.py b/TTS/vocoder/datasets/preprocess.py index 0f69b812fa..503bb04b2f 100644 --- a/TTS/vocoder/datasets/preprocess.py +++ b/TTS/vocoder/datasets/preprocess.py @@ -7,6 +7,7 @@ from tqdm import tqdm from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): @@ -29,7 +30,11 @@ def preprocess_wav_files(out_path: str, config: Coqpit, ap: AudioProcessor): mel = ap.melspectrogram(y) np.save(mel_path, mel) if isinstance(config.mode, int): - quant = ap.mulaw_encode(y, qc=config.mode) if config.model_args.mulaw else ap.quantize(y, bits=config.mode) + quant = ( + mulaw_encode(wav=y, mulaw_qc=config.mode) + if config.model_args.mulaw + else quantize(x=y, quantize_bits=config.mode) + ) np.save(quant_path, quant) diff --git a/TTS/vocoder/datasets/wavernn_dataset.py b/TTS/vocoder/datasets/wavernn_dataset.py index c390796428..a67c5b31a0 100644 --- a/TTS/vocoder/datasets/wavernn_dataset.py +++ b/TTS/vocoder/datasets/wavernn_dataset.py @@ -2,6 +2,8 @@ import torch from torch.utils.data import Dataset +from TTS.utils.audio.numpy_transforms import mulaw_encode, quantize + class WaveRNNDataset(Dataset): """ @@ -66,7 +68,9 @@ def load_item(self, index): x_input = audio elif isinstance(self.mode, int): x_input = ( - self.ap.mulaw_encode(audio, qc=self.mode) if self.mulaw else self.ap.quantize(audio, bits=self.mode) + mulaw_encode(wav=audio, mulaw_qc=self.mode) + if self.mulaw + else quantize(x=audio, quantize_bits=self.mode) ) else: raise RuntimeError("Unknown dataset mode - ", self.mode) diff --git a/TTS/vocoder/models/wavernn.py b/TTS/vocoder/models/wavernn.py index 903f4b7e63..7f74ba3ebf 100644 --- a/TTS/vocoder/models/wavernn.py +++ b/TTS/vocoder/models/wavernn.py @@ -13,6 +13,7 @@ from TTS.tts.utils.visual import plot_spectrogram from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import mulaw_decode from TTS.utils.io import load_fsspec from TTS.vocoder.datasets.wavernn_dataset import WaveRNNDataset from TTS.vocoder.layers.losses import WaveRNNLoss @@ -399,7 +400,7 @@ def inference(self, mels, batched=None, target=None, overlap=None): output = output[0] if self.args.mulaw and isinstance(self.args.mode, int): - output = AudioProcessor.mulaw_decode(output, self.args.mode) + output = mulaw_decode(wav=output, mulaw_qc=self.args.mode) # Fade-out at the end to avoid signal cutting out suddenly fade_out = np.linspace(1, 0, 20 * self.config.audio.hop_length) diff --git 
a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 9acc9929fc..0ec5f167b4 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -13,23 +13,28 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "import sys\n", - "import torch\n", "import importlib\n", + "import os\n", + "import pickle\n", + "\n", "import numpy as np\n", - "from tqdm import tqdm\n", - "from torch.utils.data import DataLoader\n", "import soundfile as sf\n", - "import pickle\n", + "import torch\n", + "from matplotlib import pylab as plt\n", + "from torch.utils.data import DataLoader\n", + "from tqdm import tqdm\n", + "\n", + "from TTS.config import load_config\n", + "from TTS.tts.configs.shared_configs import BaseDatasetConfig\n", + "from TTS.tts.datasets import load_tts_samples\n", "from TTS.tts.datasets.dataset import TTSDataset\n", "from TTS.tts.layers.losses import L1LossMasked\n", - "from TTS.utils.audio import AudioProcessor\n", - "from TTS.config import load_config\n", - "from TTS.tts.utils.visual import plot_spectrogram\n", - "from TTS.tts.utils.helpers import sequence_mask\n", "from TTS.tts.models import setup_model\n", - "from TTS.tts.utils.text.symbols import make_symbols, symbols, phonemes\n", + "from TTS.tts.utils.helpers import sequence_mask\n", + "from TTS.tts.utils.text.tokenizer import TTSTokenizer\n", + "from TTS.tts.utils.visual import plot_spectrogram\n", + "from TTS.utils.audio import AudioProcessor\n", + "from TTS.utils.audio.numpy_transforms import quantize\n", "\n", "%matplotlib inline\n", "\n", @@ -49,11 +54,9 @@ " file_name = wav_file.split('.')[0]\n", " os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n", " os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n", - " os.makedirs(os.path.join(out_path, \"wav_gl\"), exist_ok=True)\n", " wavq_path = os.path.join(out_path, \"quant\", file_name)\n", " mel_path = os.path.join(out_path, \"mel\", file_name)\n", - " wav_path = os.path.join(out_path, \"wav_gl\", file_name)\n", - " return file_name, wavq_path, mel_path, wav_path" + " return file_name, wavq_path, mel_path" ] }, { @@ -65,14 +68,14 @@ "# Paths and configurations\n", "OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n", "DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n", + "PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n", "DATASET = \"ljspeech\"\n", "METADATA_FILE = \"metadata.csv\"\n", "CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n", "MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n", "BATCH_SIZE = 32\n", "\n", - "QUANTIZED_WAV = False\n", - "QUANTIZE_BIT = None\n", + "QUANTIZE_BITS = 0 # if non-zero, quantize wav files with the given number of bits\n", "DRY_RUN = False # if False, does not generate output files, only computes loss and visuals.\n", "\n", "# Check CUDA availability\n", @@ -80,10 +83,10 @@ "print(\" > CUDA enabled: \", use_cuda)\n", "\n", "# Load the configuration\n", + "dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n", "C = load_config(CONFIG_PATH)\n", "C.audio['do_trim_silence'] = False # IMPORTANT!!!!!!!!!!!!!!! 
disable to align mel specs with the wav files\n", - "ap = AudioProcessor(bits=QUANTIZE_BIT, **C.audio)\n", - "print(C['r'])" + "ap = AudioProcessor(**C.audio)" ] }, { @@ -92,12 +95,10 @@ "metadata": {}, "outputs": [], "source": [ - "# If the vocabulary was passed, replace the default\n", - "if 'characters' in C and C['characters']:\n", - " symbols, phonemes = make_symbols(**C.characters)\n", + "# Initialize the tokenizer\n", + "tokenizer, C = TTSTokenizer.init_from_config(C)\n", "\n", "# Load the model\n", - "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speakers\n", "model = setup_model(C)\n", "model.load_checkpoint(C, MODEL_FILE, eval=True)" @@ -109,42 +110,21 @@ "metadata": {}, "outputs": [], "source": [ - "# Load the preprocessor based on the dataset\n", - "preprocessor = importlib.import_module(\"TTS.tts.datasets.formatters\")\n", - "preprocessor = getattr(preprocessor, DATASET.lower())\n", - "meta_data = preprocessor(DATA_PATH, METADATA_FILE)\n", + "# Load data instances\n", + "meta_data_train, meta_data_eval = load_tts_samples(dataset_config)\n", + "meta_data = meta_data_train + meta_data_eval\n", + "\n", "dataset = TTSDataset(\n", - " C,\n", - " C.text_cleaner,\n", - " False,\n", - " ap,\n", - " meta_data,\n", - " characters=C.get('characters', None),\n", - " use_phonemes=C.use_phonemes,\n", - " phoneme_cache_path=C.phoneme_cache_path,\n", - " enable_eos_bos=C.enable_eos_bos_chars,\n", + " outputs_per_step=C[\"r\"],\n", + " compute_linear_spec=False,\n", + " ap=ap,\n", + " samples=meta_data,\n", + " tokenizer=tokenizer,\n", + " phoneme_cache_path=PHONEME_CACHE_PATH,\n", ")\n", "loader = DataLoader(\n", " dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False\n", - ")\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Initialize lists for storing results\n", - "file_idxs = []\n", - "metadata = []\n", - "losses = []\n", - "postnet_losses = []\n", - "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", - "\n", - "# Create log file\n", - "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", - "log_file = open(log_file_path, \"w\")" + ")" ] }, { @@ -160,26 +140,33 @@ "metadata": {}, "outputs": [], "source": [ + "# Initialize lists for storing results\n", + "file_idxs = []\n", + "metadata = []\n", + "losses = []\n", + "postnet_losses = []\n", + "criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n", + "\n", "# Start processing with a progress bar\n", - "with torch.no_grad():\n", + "log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n", + "with torch.no_grad() and open(log_file_path, \"w\") as log_file:\n", " for data in tqdm(loader, desc=\"Processing\"):\n", " try:\n", - " # setup input data\n", - " text_input, text_lengths, _, linear_input, mel_input, mel_lengths, stop_targets, item_idx = data\n", - "\n", " # dispatch data to GPU\n", " if use_cuda:\n", - " text_input = text_input.cuda()\n", - " text_lengths = text_lengths.cuda()\n", - " mel_input = mel_input.cuda()\n", - " mel_lengths = mel_lengths.cuda()\n", + " data[\"token_id\"] = data[\"token_id\"].cuda()\n", + " data[\"token_id_lengths\"] = data[\"token_id_lengths\"].cuda()\n", + " data[\"mel\"] = data[\"mel\"].cuda()\n", + " data[\"mel_lengths\"] = data[\"mel_lengths\"].cuda()\n", "\n", - " mask = sequence_mask(text_lengths)\n", - " mel_outputs, postnet_outputs, alignments, stop_tokens = model.forward(text_input, text_lengths, mel_input)\n", + " mask = 
sequence_mask(data[\"token_id_lengths\"])\n", + " outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n", + " mel_outputs = outputs[\"decoder_outputs\"]\n", + " postnet_outputs = outputs[\"model_outputs\"]\n", "\n", " # compute loss\n", - " loss = criterion(mel_outputs, mel_input, mel_lengths)\n", - " loss_postnet = criterion(postnet_outputs, mel_input, mel_lengths)\n", + " loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", + " loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n", " losses.append(loss.item())\n", " postnet_losses.append(loss_postnet.item())\n", "\n", @@ -193,28 +180,27 @@ " postnet_outputs = torch.stack(mel_specs)\n", " elif C.model == \"Tacotron2\":\n", " postnet_outputs = postnet_outputs.detach().cpu().numpy()\n", - " alignments = alignments.detach().cpu().numpy()\n", + " alignments = outputs[\"alignments\"].detach().cpu().numpy()\n", "\n", " if not DRY_RUN:\n", - " for idx in range(text_input.shape[0]):\n", - " wav_file_path = item_idx[idx]\n", + " for idx in range(data[\"token_id\"].shape[0]):\n", + " wav_file_path = data[\"item_idxs\"][idx]\n", " wav = ap.load_wav(wav_file_path)\n", - " file_name, wavq_path, mel_path, wav_path = set_filename(wav_file_path, OUT_PATH)\n", + " file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n", " file_idxs.append(file_name)\n", "\n", " # quantize and save wav\n", - " if QUANTIZED_WAV:\n", - " wavq = ap.quantize(wav)\n", + " if QUANTIZE_BITS > 0:\n", + " wavq = quantize(wav, QUANTIZE_BITS)\n", " np.save(wavq_path, wavq)\n", "\n", " # save TTS mel\n", " mel = postnet_outputs[idx]\n", - " mel_length = mel_lengths[idx]\n", + " mel_length = data[\"mel_lengths\"][idx]\n", " mel = mel[:mel_length, :].T\n", " np.save(mel_path, mel)\n", "\n", " metadata.append([wav_file_path, mel_path])\n", - "\n", " except Exception as e:\n", " log_file.write(f\"Error processing data: {str(e)}\\n\")\n", "\n", @@ -224,35 +210,20 @@ " log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n", " log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n", "\n", - "# Close the log file\n", - "log_file.close()\n", - "\n", "# For wavernn\n", "if not DRY_RUN:\n", " pickle.dump(file_idxs, open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\"))\n", "\n", "# For pwgan\n", "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")\n", + " for wav_file_path, mel_path in metadata:\n", + " f.write(f\"{wav_file_path[0]}|{mel_path[1]+'.npy'}\\n\")\n", "\n", "# Print mean losses\n", "print(f\"Mean Loss: {mean_loss}\")\n", "print(f\"Mean Postnet Loss: {mean_postnet_loss}\")" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for pwgan\n", - "with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\") as f:\n", - " for data in metadata:\n", - " f.write(f\"{data[0]}|{data[1]+'.npy'}\\n\")" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -267,7 +238,7 @@ "outputs": [], "source": [ "idx = 1\n", - "ap.melspectrogram(ap.load_wav(item_idx[idx])).shape" + "ap.melspectrogram(ap.load_wav(data[\"item_idxs\"][idx])).shape" ] }, { @@ -276,10 +247,9 @@ "metadata": {}, "outputs": [], "source": [ - "import soundfile as sf\n", - "wav, sr = sf.read(item_idx[idx])\n", - "mel_postnet = postnet_outputs[idx][:mel_lengths[idx], :]\n", - "mel_decoder = mel_outputs[idx][:mel_lengths[idx], :].detach().cpu().numpy()\n", + "wav, sr = 
sf.read(data[\"item_idxs\"][idx])\n", + "mel_postnet = postnet_outputs[idx][:data[\"mel_lengths\"][idx], :]\n", + "mel_decoder = mel_outputs[idx][:data[\"mel_lengths\"][idx], :].detach().cpu().numpy()\n", "mel_truth = ap.melspectrogram(wav)\n", "print(mel_truth.shape)" ] @@ -291,7 +261,7 @@ "outputs": [], "source": [ "# plot posnet output\n", - "print(mel_postnet[:mel_lengths[idx], :].shape)\n", + "print(mel_postnet[:data[\"mel_lengths\"][idx], :].shape)\n", "plot_spectrogram(mel_postnet, ap)" ] }, @@ -324,10 +294,9 @@ "outputs": [], "source": [ "# postnet, decoder diff\n", - "from matplotlib import pylab as plt\n", "mel_diff = mel_decoder - mel_postnet\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff[:mel_lengths[idx],:]).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff[:data[\"mel_lengths\"][idx],:]).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -339,10 +308,9 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel_diff2 = mel_truth.T - mel_decoder\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] @@ -354,21 +322,13 @@ "outputs": [], "source": [ "# PLOT GT SPECTROGRAM diff\n", - "from matplotlib import pylab as plt\n", "mel = postnet_outputs[idx]\n", "mel_diff2 = mel_truth.T - mel[:mel_truth.shape[1]]\n", "plt.figure(figsize=(16, 10))\n", - "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\");\n", + "plt.imshow(abs(mel_diff2).T,aspect=\"auto\", origin=\"lower\")\n", "plt.colorbar()\n", "plt.tight_layout()" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { diff --git a/requirements.txt b/requirements.txt index 53e8af590c..1f7a44f6d8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,33 +1,33 @@ # core deps numpy==1.22.0;python_version<="3.10" -numpy==1.24.3;python_version>"3.10" -cython==0.29.30 +numpy>=1.24.3;python_version>"3.10" +cython>=0.29.30 scipy>=1.11.2 torch>=2.1 torchaudio -soundfile==0.12.* -librosa==0.10.* -scikit-learn==1.3.0 +soundfile>=0.12.0 +librosa>=0.10.0 +scikit-learn>=1.3.0 numba==0.55.1;python_version<"3.9" -numba==0.57.0;python_version>="3.9" -inflect==5.6.* -tqdm==4.64.* -anyascii==0.3.* -pyyaml==6.* -fsspec==2023.6.0 # <= 2023.9.1 makes aux tests fail -aiohttp==3.8.* -packaging==23.1 +numba>=0.57.0;python_version>="3.9" +inflect>=5.6.0 +tqdm>=4.64.1 +anyascii>=0.3.0 +pyyaml>=6.0 +fsspec>=2023.6.0 # <= 2023.9.1 makes aux tests fail +aiohttp>=3.8.1 +packaging>=23.1 # deps for examples -flask==2.* +flask>=2.0.1 # deps for inference -pysbd==0.3.4 +pysbd>=0.3.4 # deps for notebooks -umap-learn==0.5.* +umap-learn>=0.5.1 pandas>=1.4,<2.0 # deps for training -matplotlib==3.7.* +matplotlib>=3.7.0 # coqui stack -trainer +trainer>=0.0.32 # config management coqpit>=0.0.16 # chinese g2p deps @@ -46,11 +46,11 @@ bangla bnnumerizer bnunicodenormalizer #deps for tortoise -k_diffusion -einops==0.6.* -transformers==4.33.* +einops>=0.6.0 +transformers>=4.33.0 #deps for bark -encodec==0.1.* +encodec>=0.1.1 # deps for XTTS -unidecode==1.3.* +unidecode>=1.3.2 num2words +spacy[ja]>=3 \ No newline at end of file diff --git a/tests/vocoder_tests/test_vocoder_losses.py b/tests/vocoder_tests/test_vocoder_losses.py index 2a35aa2e37..95501c2d39 100644 --- 
a/tests/vocoder_tests/test_vocoder_losses.py +++ b/tests/vocoder_tests/test_vocoder_losses.py @@ -5,6 +5,7 @@ from tests import get_tests_input_path, get_tests_output_path, get_tests_path from TTS.config import BaseAudioConfig from TTS.utils.audio import AudioProcessor +from TTS.utils.audio.numpy_transforms import stft from TTS.vocoder.layers.losses import MelganFeatureLoss, MultiScaleSTFTLoss, STFTLoss, TorchSTFT TESTS_PATH = get_tests_path() @@ -21,7 +22,7 @@ def test_torch_stft(): torch_stft = TorchSTFT(ap.fft_size, ap.hop_length, ap.win_length) # librosa stft wav = ap.load_wav(WAV_FILE) - M_librosa = abs(ap._stft(wav)) # pylint: disable=protected-access + M_librosa = abs(stft(y=wav, fft_size=ap.fft_size, hop_length=ap.hop_length, win_length=ap.win_length)) # torch stft wav = torch.from_numpy(wav[None, :]).float() M_torch = torch_stft(wav) diff --git a/tests/zoo_tests/test_models.py b/tests/zoo_tests/test_models.py index a5aad5c1ea..8fa56e287a 100644 --- a/tests/zoo_tests/test_models.py +++ b/tests/zoo_tests/test_models.py @@ -186,7 +186,7 @@ def test_xtts_v2_streaming(): "en", gpt_cond_latent, speaker_embedding, - speed=1.5 + speed=1.5, ) wav_chuncks = [] for i, chunk in enumerate(chunks): @@ -198,7 +198,7 @@ def test_xtts_v2_streaming(): "en", gpt_cond_latent, speaker_embedding, - speed=0.66 + speed=0.66, ) wav_chuncks = [] for i, chunk in enumerate(chunks):