Skip to content

Commit

Permalink
fix: split tags and text first before norm (#655)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Aug 2, 2024
1 parent 36c8723 commit 18dba88
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 4 deletions.
35 changes: 34 additions & 1 deletion ChatTTS/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,37 @@ def _fast_replace(
replaced_words.append((chr(ch), chr(repl_char)))
return result, replaced_words

@jit
def _split_tags(text: str) -> Tuple[List[str], List[str]]:
texts: List[str] = []
tags: List[str] = []
current_text = ""
current_tag = ""
for c in text:
if c == '[':
texts.append(current_text)
current_text = ""
current_tag = c
elif current_tag != "":
current_tag += c
else:
current_text += c
if c == ']':
tags.append(current_tag)
current_tag = ""
if current_text != "":
texts.append(current_text)
return texts, tags

@jit
def _combine_tags(texts: List[str], tags: List[str]) -> str:
text = ""
for t in texts:
tg = ""
if len(tags) > 0:
tg = tags.pop(0)
text += t + tg
return text

class Normalizer:
def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
Expand Down Expand Up @@ -136,7 +167,9 @@ def __call__(
if do_text_normalization:
_lang = self._detect_language(text) if lang is None else lang
if _lang in self.normalizers:
text = self.normalizers[_lang](text)
texts, tags = _split_tags(text)
texts = [self.normalizers[_lang](t) for t in text]
text = _combine_tags(texts, tags)
if _lang == "zh":
text = self._apply_half2full_map(text)
invalid_characters = self._count_invalid_characters(text)
Expand Down
2 changes: 1 addition & 1 deletion tests/#511.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from tools.logger import get_logger

logger = get_logger("Test #511", lv=logging.WARN)
logger = get_logger("Test", lv=logging.WARN)

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface") # Set to True for better performance
Expand Down
2 changes: 1 addition & 1 deletion tests/#588.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from tools.logger import get_logger

logger = get_logger("Test #588", lv=logging.WARN)
logger = get_logger("Test", lv=logging.WARN)

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface") # Set to True for better performance
Expand Down
4 changes: 3 additions & 1 deletion tests/#655.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@
import ChatTTS

from tools.logger import get_logger
from tools.normalizer import normalizer_en_nemo_text

logger = get_logger("Test #655", lv=logging.WARN)
logger = get_logger("Test", lv=logging.WARN)

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface") # Set to True for better performance
chat.normalizer.register("en", normalizer_en_nemo_text())

rand_spk = chat.sample_random_speaker()

Expand Down

0 comments on commit 18dba88

Please sign in to comment.