From e21375747e1a1ad053da3541e76db6bb3f5a5f08 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Wed, 8 Nov 2023 13:41:59 +0200 Subject: [PATCH] Run `make style` --- TTS/cs_api.py | 9 +++++--- TTS/tts/layers/tortoise/dpm_solver.py | 23 +++++++++++++++----- TTS/tts/layers/xtts/tokenizer.py | 7 ++++-- TTS/tts/models/xtts.py | 4 ++-- tests/xtts_tests/test_xtts_gpt_train.py | 4 +++- tests/xtts_tests/test_xtts_v2-0_gpt_train.py | 4 +++- 6 files changed, 36 insertions(+), 15 deletions(-) diff --git a/TTS/cs_api.py b/TTS/cs_api.py index ac9c8698bf..476ce70596 100644 --- a/TTS/cs_api.py +++ b/TTS/cs_api.py @@ -82,7 +82,6 @@ class CS_API: }, } - SUPPORTED_LANGUAGES = ["en", "es", "de", "fr", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja"] def __init__(self, api_token=None, model="XTTS"): @@ -308,7 +307,11 @@ def tts_to_file( print(api.list_speakers_as_tts_models()) ts = time.time() - wav, sr = api.tts("It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name) + wav, sr = api.tts( + "It took me quite a long time to develop a voice.", language="en", speaker_name=api.speakers[0].name + ) print(f" [i] XTTS took {time.time() - ts:.2f}s") - filepath = api.tts_to_file(text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav") + filepath = api.tts_to_file( + text="Hello world!", speaker_name=api.speakers[0].name, language="en", file_path="output.wav" + ) diff --git a/TTS/tts/layers/tortoise/dpm_solver.py b/TTS/tts/layers/tortoise/dpm_solver.py index 2166eebb3c..c70888df42 100644 --- a/TTS/tts/layers/tortoise/dpm_solver.py +++ b/TTS/tts/layers/tortoise/dpm_solver.py @@ -562,15 +562,21 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type if order == 3: K = steps // 3 + 1 if steps % 3 == 0: - orders = [3,] * ( + orders = [ + 3, + ] * ( K - 2 ) + [2, 1] elif steps % 3 == 1: - orders = [3,] * ( + orders = [ + 3, + ] * ( K - 1 ) + [1] else: - orders = [3,] * ( + orders = [ + 3, + ] * ( K - 1 ) + [2] elif order == 2: @@ -581,7 +587,9 @@ def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type ] * K else: K = steps // 2 + 1 - orders = [2,] * ( + orders = [ + 2, + ] * ( K - 1 ) + [1] elif order == 1: @@ -1440,7 +1448,10 @@ def sample( model_prev_list[-1] = self.model_fn(x, t) elif method in ["singlestep", "singlestep_fixed"]: if method == "singlestep": - (timesteps_outer, orders,) = self.get_orders_and_timesteps_for_singlestep_solver( + ( + timesteps_outer, + orders, + ) = self.get_orders_and_timesteps_for_singlestep_solver( steps=steps, order=order, skip_type=skip_type, @@ -1548,4 +1559,4 @@ def expand_dims(v, dims): Returns: a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`. """ - return v[(...,) + (None,) * (dims - 1)] \ No newline at end of file + return v[(...,) + (None,) * (dims - 1)] diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index ed677773eb..28b522c45b 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -559,12 +559,15 @@ def __init__(self, vocab_file=None): @cached_property def katsu(self): import cutlet + return cutlet.Cutlet() - + def check_input_length(self, txt, lang): limit = self.char_limits.get(lang, 250) if len(txt) > limit: - print(f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio.") + print( + f"[!] Warning: The text length exceeds the character limit of {limit} for language '{lang}', this might cause truncated audio." + ) def preprocess_text(self, txt, lang): if lang in {"ar", "cs", "de", "en", "es", "fr", "hu", "it", "nl", "pl", "pt", "ru", "tr", "zh", "zh-cn"}: diff --git a/TTS/tts/models/xtts.py b/TTS/tts/models/xtts.py index da0952ff92..e01d008f76 100644 --- a/TTS/tts/models/xtts.py +++ b/TTS/tts/models/xtts.py @@ -606,7 +606,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): if overlap_len > len(wav_chunk): # wav_chunk is smaller than overlap_len, pass on last wav_gen if wav_gen_prev is not None: - wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len):] + wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) :] else: # not expecting will hit here as problem happens on last chunk wav_chunk = wav_gen[-overlap_len:] @@ -616,7 +616,7 @@ def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len): crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device) wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device) wav_chunk[:overlap_len] += crossfade_wav - + wav_overlap = wav_gen[-overlap_len:] wav_gen_prev = wav_gen return wav_chunk, wav_gen_prev, wav_overlap diff --git a/tests/xtts_tests/test_xtts_gpt_train.py b/tests/xtts_tests/test_xtts_gpt_train.py index 12c547d684..b8b9a4e388 100644 --- a/tests/xtts_tests/test_xtts_gpt_train.py +++ b/tests/xtts_tests/test_xtts_gpt_train.py @@ -60,7 +60,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = [ + "tests/data/ljspeech/wavs/LJ001-0002.wav" +] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language diff --git a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py index b19b7210d8..6663433c12 100644 --- a/tests/xtts_tests/test_xtts_v2-0_gpt_train.py +++ b/tests/xtts_tests/test_xtts_v2-0_gpt_train.py @@ -58,7 +58,9 @@ # Training sentences generations -SPEAKER_REFERENCE = ["tests/data/ljspeech/wavs/LJ001-0002.wav"] # speaker reference to be used in training test sentences +SPEAKER_REFERENCE = [ + "tests/data/ljspeech/wavs/LJ001-0002.wav" +] # speaker reference to be used in training test sentences LANGUAGE = config_dataset.language