From c6bae901581d9574e774c9cba7876671808d1dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:38:22 +0800 Subject: [PATCH] fix(norm): tags are read literally (fix #655) --- ChatTTS/core.py | 2 ++ ChatTTS/norm.py | 13 +++++++++---- tests/#655.py | 11 +++++++---- 3 files changed, 18 insertions(+), 8 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 47fb43124..ad28bf1c7 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -368,6 +368,8 @@ def _infer( for t in text ] + self.logger.debug("normed texts %s", str(text)) + if not skip_refine_text: refined = self._refine_text( text, diff --git a/ChatTTS/norm.py b/ChatTTS/norm.py index 770a440e0..e5a9f6ae3 100644 --- a/ChatTTS/norm.py +++ b/ChatTTS/norm.py @@ -89,7 +89,7 @@ def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)): """ self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be" self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]") - self.sub_pattern = re.compile(r"\[uv_break\]|\[laugh\]|\[lbreak\]") + self.sub_pattern = re.compile(r"\[[\w_]+\]") self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]") self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b") self.character_simplifier = str.maketrans( @@ -113,8 +113,8 @@ def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)): "!": ".", "(": ",", ")": ",", - "[": ",", - "]": ",", + # "[": ",", + # "]": ",", ">": ",", "<": ",", "-": ",", @@ -189,7 +189,12 @@ def __call__( repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words]) self.logger.info(f"replace homophones: {repl_res}") if len(invalid_characters): - text = self.reject_pattern.sub("", text) + texts, tags = _split_tags(text) + self.logger.debug("split texts %s, tags %s", str(texts), str(tags)) + texts = [self.reject_pattern.sub("", t) for t in texts] + self.logger.debug("normed texts %s", str(texts)) + text = _combine_tags(texts, tags) + self.logger.debug("combined text %s", text) return text def register(self, name: str, normalizer: Callable[[str], str]) -> bool: diff --git a/tests/#655.py b/tests/#655.py index 502e40b99..f673b5ecc 100644 --- a/tests/#655.py +++ b/tests/#655.py @@ -20,7 +20,10 @@ chat = ChatTTS.Chat(logger) chat.load(compile=False, source="huggingface") # Set to True for better performance -chat.normalizer.register("en", normalizer_en_nemo_text()) +try: + chat.normalizer.register("en", normalizer_en_nemo_text()) +except: + logger.warning("Package nemo_text_processing not found!") rand_spk = chat.sample_random_speaker() @@ -29,14 +32,14 @@ fail = False -with TorchSeedContext(1231231): +with TorchSeedContext(12345): refined_text = chat.infer( text, refine_text_only=True, params_refine_text=ChatTTS.Chat.RefineTextParams( prompt='[oral_2][laugh_0][break_6]', ), ) -if refined_text[0] != "What is [uv_break]your favorite english food?[laugh][lbreak]": +if refined_text[0] != "like [uv_break] what is [uv_break] your favorite english food [laugh] [lbreak]": fail = True logger.warning("refined text is '%s'", refined_text[0]) @@ -60,7 +63,7 @@ ).fill_(input_ids.shape[1]) recoded_text = chat.tokenizer.decode(chat.gpt._prepare_generation_outputs( - input_ids, start_idx, end_idx, [], [], False, + input_ids, start_idx, end_idx, [], [], True, ).ids) if recoded_text[0] != '[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]':