Skip to content

Commit

Permalink
fix(norm): tags are read literally (fix #655)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Aug 2, 2024
1 parent e6b35a9 commit c6bae90
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 8 deletions.
2 changes: 2 additions & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,8 @@ def _infer(
for t in text
]

self.logger.debug("normed texts %s", str(text))

if not skip_refine_text:
refined = self._refine_text(
text,
Expand Down
13 changes: 9 additions & 4 deletions ChatTTS/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
"""
self.coding = "utf-16-le" if sys.byteorder == "little" else "utf-16-be"
self.reject_pattern = re.compile(r"[^\u4e00-\u9fffA-Za-z,。、,\. ]")
self.sub_pattern = re.compile(r"\[uv_break\]|\[laugh\]|\[lbreak\]")
self.sub_pattern = re.compile(r"\[[\w_]+\]")
self.chinese_char_pattern = re.compile(r"[\u4e00-\u9fff]")
self.english_word_pattern = re.compile(r"\b[A-Za-z]+\b")
self.character_simplifier = str.maketrans(
Expand All @@ -113,8 +113,8 @@ def __init__(self, map_file_path: str, logger=logging.getLogger(__name__)):
"!": ".",
"(": ",",
")": ",",
"[": ",",
"]": ",",
# "[": ",",
# "]": ",",
">": ",",
"<": ",",
"-": ",",
Expand Down Expand Up @@ -189,7 +189,12 @@ def __call__(
repl_res = ", ".join([f"{_[0]}->{_[1]}" for _ in replaced_words])
self.logger.info(f"replace homophones: {repl_res}")
if len(invalid_characters):
text = self.reject_pattern.sub("", text)
texts, tags = _split_tags(text)
self.logger.debug("split texts %s, tags %s", str(texts), str(tags))
texts = [self.reject_pattern.sub("", t) for t in texts]
self.logger.debug("normed texts %s", str(texts))
text = _combine_tags(texts, tags)
self.logger.debug("combined text %s", text)
return text

def register(self, name: str, normalizer: Callable[[str], str]) -> bool:
Expand Down
11 changes: 7 additions & 4 deletions tests/#655.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface") # Set to True for better performance
chat.normalizer.register("en", normalizer_en_nemo_text())
try:
chat.normalizer.register("en", normalizer_en_nemo_text())
except:
logger.warning("Package nemo_text_processing not found!")

rand_spk = chat.sample_random_speaker()

Expand All @@ -29,14 +32,14 @@

fail = False

with TorchSeedContext(1231231):
with TorchSeedContext(12345):
refined_text = chat.infer(
text, refine_text_only=True,
params_refine_text=ChatTTS.Chat.RefineTextParams(
prompt='[oral_2][laugh_0][break_6]',
),
)
if refined_text[0] != "What is [uv_break]your favorite english food?[laugh][lbreak]":
if refined_text[0] != "like [uv_break] what is [uv_break] your favorite english food [laugh] [lbreak]":
fail = True
logger.warning("refined text is '%s'", refined_text[0])

Expand All @@ -60,7 +63,7 @@
).fill_(input_ids.shape[1])

recoded_text = chat.tokenizer.decode(chat.gpt._prepare_generation_outputs(
input_ids, start_idx, end_idx, [], [], False,
input_ids, start_idx, end_idx, [], [], True,
).ids)

if recoded_text[0] != '[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]':
Expand Down

0 comments on commit c6bae90

Please sign in to comment.