From 18dba884bf458ab136ce0eee3fe56ca3972f685a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 2 Aug 2024 15:00:16 +0800 Subject: [PATCH] fix: split tags and text first before norm (#655) --- ChatTTS/norm.py | 35 ++++++++++++++++++++++++++++++++++- tests/#511.py | 2 +- tests/#588.py | 2 +- tests/#655.py | 4 +++- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/ChatTTS/norm.py b/ChatTTS/norm.py index 74e12fd2a..d3b42d015 100644 --- a/ChatTTS/norm.py +++ b/ChatTTS/norm.py @@ -33,6 +33,37 @@ def _fast_replace( replaced_words.append((chr(ch), chr(repl_char))) return result, replaced_words +@jit +def _split_tags(text: str) -> Tuple[List[str], List[str]]: + texts: List[str] = [] + tags: List[str] = [] + current_text = "" + current_tag = "" + for c in text: + if c == '[': + texts.append(current_text) + current_text = "" + current_tag = c + elif current_tag != "": + current_tag += c + else: + current_text += c + if c == ']' and current_tag != "": + tags.append(current_tag) + current_tag = "" + if current_text != "": + texts.append(current_text) + return texts, tags + +@jit +def _combine_tags(texts: List[str], tags: List[str]) -> str: + text = "" + for t in texts: + tg = "" + if len(tags) > 0: + tg = tags.pop(0) + text += t + tg + return text class Normalizer: def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)): @@ -136,7 +167,9 @@ def __call__( if do_text_normalization: _lang = self._detect_language(text) if lang is None else lang if _lang in self.normalizers: - text = self.normalizers[_lang](text) + texts, tags = _split_tags(text) + texts = [self.normalizers[_lang](t) for t in texts] + text = _combine_tags(texts, tags) if _lang == "zh": text = self._apply_half2full_map(text) invalid_characters = self._count_invalid_characters(text) diff --git a/tests/#511.py b/tests/#511.py index e225b8e1e..20198eaae 100644 --- a/tests/#511.py +++ b/tests/#511.py @@ -12,7 +12,7 @@ from tools.logger 
import get_logger -logger = get_logger("Test #511", lv=logging.WARN) +logger = get_logger("Test", lv=logging.WARN) chat = ChatTTS.Chat(logger) chat.load(compile=False, source="huggingface") # Set to True for better performance diff --git a/tests/#588.py b/tests/#588.py index d34188929..1c2d860c6 100644 --- a/tests/#588.py +++ b/tests/#588.py @@ -12,7 +12,7 @@ from tools.logger import get_logger -logger = get_logger("Test #588", lv=logging.WARN) +logger = get_logger("Test", lv=logging.WARN) chat = ChatTTS.Chat(logger) chat.load(compile=False, source="huggingface") # Set to True for better performance diff --git a/tests/#655.py b/tests/#655.py index 54a515ef2..bc872e02c 100644 --- a/tests/#655.py +++ b/tests/#655.py @@ -13,11 +13,13 @@ import ChatTTS from tools.logger import get_logger +from tools.normalizer import normalizer_en_nemo_text -logger = get_logger("Test #655", lv=logging.WARN) +logger = get_logger("Test", lv=logging.WARN) chat = ChatTTS.Chat(logger) chat.load(compile=False, source="huggingface") # Set to True for better performance +chat.normalizer.register("en", normalizer_en_nemo_text()) rand_spk = chat.sample_random_speaker()