diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 1d2333d00..47fb43124 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -476,28 +476,10 @@ def _infer_code(
         else:
             temperature = params.temperature
 
-        for i, t in enumerate(text):
-            text[i] = (
-                t.replace("[Stts]", "")
-                .replace("[spk_emb]", "")
-                .replace("[empty_spk]", "")
-                .strip()
-            )
-            """
-            see https://github.com/2noise/ChatTTS/issues/459
-            """
-
-        if params.prompt:
-            text = [params.prompt + i for i in text]
-
-        txt_smp = "" if params.txt_smp is None else params.txt_smp
-        if params.spk_emb is not None:
-            text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
-        else:
-            text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
-
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
+            self.tokenizer.decorate_code_prompts(
+                text, params.prompt, params.txt_smp, params.spk_emb,
+            ),
             self.config.gpt.num_vq,
             prompt_str=params.spk_smp,
             device=self.device_gpt,
@@ -597,10 +579,8 @@ def _refine_text(
         if not isinstance(text, list):
             text = [text]
 
-        text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]
-
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
+            self.tokenizer.decorate_text_prompts(text, params.prompt),
             self.config.gpt.num_vq,
             device=self.device_gpt,
         )
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index d967fcaab..97661b064 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -251,6 +251,7 @@ def to(self, device: torch.device, dtype: torch.dtype):
         if self.cache_position is not None:
             self.cache_position = self.cache_position.to(device, dtype=dtype)
 
+    @torch.no_grad()
     def _prepare_generation_inputs(
         self,
         input_ids: torch.Tensor,
@@ -371,6 +372,7 @@ def destroy(self):
         del_all(self.attentions)
         del_all(self.hiddens)
 
+    @torch.no_grad()
     def _prepare_generation_outputs(
         self,
         inputs_ids: torch.Tensor,
diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py
index 24d277cb3..e345a7a2f 100644
--- a/ChatTTS/model/tokenizer.py
+++ b/ChatTTS/model/tokenizer.py
@@ -5,7 +5,7 @@
 https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
 """
 
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 
 import lzma
 import numpy as np
@@ -31,8 +31,6 @@ def __init__(
         self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
         self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
 
-        self.decode = self._tokenizer.batch_decode
-
     @torch.inference_mode()
     def encode(
         self,
@@ -128,6 +126,15 @@ def encode(
 
         return new_input_ids, attention_mask, text_mask
 
+    @torch.inference_mode
+    def decode(
+        self, sequences: Union[List[int], List[List[int]]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = None,
+        **kwargs,
+    ):
+        return self._tokenizer.batch_decode(sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
+
     @staticmethod
     def _decode_spk_emb(spk_emb: str) -> np.ndarray:
         return np.frombuffer(
@@ -212,3 +219,35 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
         )
         del arr
         return s
+
+    @staticmethod
+    @torch.no_grad()
+    def decorate_code_prompts(
+        text: List[str], prompt: str, txt_smp: Optional[str], spk_emb: Optional[str],
+    ) -> List[str]:
+        for i, t in enumerate(text):
+            text[i] = (
+                t.replace("[Stts]", "")
+                .replace("[spk_emb]", "")
+                .replace("[empty_spk]", "")
+                .strip()
+            )
+            """
+            see https://github.com/2noise/ChatTTS/issues/459
+            """
+
+        if prompt:
+            text = [prompt + i for i in text]
+
+        txt_smp = "" if txt_smp is None else txt_smp
+        if spk_emb is not None:
+            text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
+        else:
+            text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
+
+        return text
+
+    @staticmethod
+    @torch.no_grad()
+    def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
+        return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
diff --git a/tests/#655.py b/tests/#655.py
new file mode 100644
index 000000000..54a515ef2
--- /dev/null
+++ b/tests/#655.py
@@ -0,0 +1,60 @@
+import os, sys
+
+if sys.platform == "darwin":
+    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+now_dir = os.getcwd()
+sys.path.append(now_dir)
+
+import logging
+
+import torch
+
+import ChatTTS
+
+from tools.logger import get_logger
+
+logger = get_logger("Test #655", lv=logging.WARN)
+
+chat = ChatTTS.Chat(logger)
+chat.load(compile=False, source="huggingface") # Set to True for better performance
+
+rand_spk = chat.sample_random_speaker()
+
+params = ChatTTS.Chat.InferCodeParams(
+    spk_emb = rand_spk, # add sampled speaker
+    temperature = .3, # using custom temperature
+    top_P = 0.7, # top P decode
+    top_K = 20, # top K decode
+)
+
+text = ['What is [uv_break]your favorite english food?[laugh][lbreak]']
+
+fail = False
+
+input_ids, attention_mask, text_mask = chat.tokenizer.encode(
+    chat.tokenizer.decorate_code_prompts(
+        text, params.prompt, params.txt_smp, params.spk_emb,
+    ),
+    chat.config.gpt.num_vq,
+    prompt_str=params.spk_smp,
+    device=chat.device_gpt,
+)
+with torch.inference_mode():
+    start_idx, end_idx = 0, torch.zeros(
+        input_ids.shape[0], device=input_ids.device, dtype=torch.long
+    ).fill_(input_ids.shape[1])
+
+    recoded_text = chat.tokenizer.decode(chat.gpt._prepare_generation_outputs(
+        input_ids, start_idx, end_idx, [], [], True,
+    ).ids)
+
+fail = recoded_text[0] != '[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]'
+
+if fail:
+
+    logging.warning("got recoded_text '%s'", recoded_text[0])
+
+    import sys
+
+    sys.exit(1)
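
For reference, a minimal sketch of what the two extracted helpers return. It is illustrative only and assumes the tokenizer class added to ChatTTS/model/tokenizer.py is importable as ChatTTS.model.tokenizer.Tokenizer; both helpers are static methods in the patch, so no model weights need to be loaded, and the expected strings in the comments follow directly from the templates above.

    from ChatTTS.model.tokenizer import Tokenizer

    text = ["What is [uv_break]your favorite english food?[laugh][lbreak]"]

    # Code-prompt decoration with no user prompt, no sample text, no speaker embedding:
    # stray [Stts]/[spk_emb]/[empty_spk] tokens are stripped, then the frame is added.
    # list(text) is passed because decorate_code_prompts rewrites the list in place.
    print(Tokenizer.decorate_code_prompts(list(text), "", None, None))
    # ['[Stts][empty_spk]What is [uv_break]your favorite english food?[laugh][lbreak][Ptts]']

    # Refine-text decoration: wrap in [Sbreak]...[Pbreak] and append the control prompt.
    print(Tokenizer.decorate_text_prompts(list(text), "[oral_2][laugh_0][break_6]"))
    # ['[Sbreak]What is [uv_break]your favorite english food?[laugh][lbreak][Pbreak][oral_2][laugh_0][break_6]']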