test: add for issue #655
fumiama committed Aug 2, 2024
1 parent e483d55 commit 36c8723
Showing 4 changed files with 108 additions and 27 deletions.
28 changes: 4 additions & 24 deletions ChatTTS/core.py
@@ -476,28 +476,10 @@ def _infer_code(
         else:
             temperature = params.temperature
 
-        for i, t in enumerate(text):
-            text[i] = (
-                t.replace("[Stts]", "")
-                .replace("[spk_emb]", "")
-                .replace("[empty_spk]", "")
-                .strip()
-            )
-            """
-            see https://github.com/2noise/ChatTTS/issues/459
-            """
-
-        if params.prompt:
-            text = [params.prompt + i for i in text]
-
-        txt_smp = "" if params.txt_smp is None else params.txt_smp
-        if params.spk_emb is not None:
-            text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
-        else:
-            text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
-
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
+            self.tokenizer.decorate_code_prompts(
+                text, params.prompt, params.txt_smp, params.spk_emb,
+            ),
             self.config.gpt.num_vq,
             prompt_str=params.spk_smp,
             device=self.device_gpt,
@@ -597,10 +579,8 @@ def _refine_text(
         if not isinstance(text, list):
             text = [text]
 
-        text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]
-
         input_ids, attention_mask, text_mask = self.tokenizer.encode(
-            text,
+            self.tokenizer.decorate_text_prompts(text, params.prompt),
             self.config.gpt.num_vq,
             device=self.device_gpt,
         )
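With this change `_infer_code` and `_refine_text` no longer build the special-token prompt strings inline; they delegate to the new `decorate_code_prompts` / `decorate_text_prompts` helpers added to ChatTTS/model/tokenizer.py below, which is what lets tests/#655.py reproduce the exact strings fed to the tokenizer. A minimal equivalence sketch, assuming the class in tokenizer.py is named `Tokenizer` and is importable as shown (the import path is an assumption; the helper logic is the code in this commit):

from ChatTTS.model.tokenizer import Tokenizer  # assumed import path

text = ["Hello."]
# What the removed inline code in _infer_code built for the empty-speaker
# branch with no prompt and no txt_smp ...
old_style = [f"[Stts][empty_spk]{t}[Ptts]" for t in text]
# ... and what the new static helper returns for the same inputs.
new_style = Tokenizer.decorate_code_prompts(list(text), "", None, None)
assert old_style == new_style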
2 changes: 2 additions & 0 deletions ChatTTS/model/gpt.py
@@ -251,6 +251,7 @@ def to(self, device: torch.device, dtype: torch.dtype):
         if self.cache_position is not None:
             self.cache_position = self.cache_position.to(device, dtype=dtype)
 
+    @torch.no_grad()
     def _prepare_generation_inputs(
         self,
         input_ids: torch.Tensor,
@@ -371,6 +372,7 @@ def destroy(self):
         del_all(self.attentions)
         del_all(self.hiddens)
 
+    @torch.no_grad()
     def _prepare_generation_outputs(
         self,
         inputs_ids: torch.Tensor,
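The gpt.py change only wraps the two generation-prep helpers in `@torch.no_grad()`, so the tensor bookkeeping they perform never records an autograd graph during inference. A tiny standalone sketch of the decorator's effect (the function below is illustrative, not part of ChatTTS):

import torch

@torch.no_grad()
def prepare(x: torch.Tensor) -> torch.Tensor:
    # Operations here are excluded from autograd, which is exactly what the new
    # decorators do for _prepare_generation_inputs / _prepare_generation_outputs.
    return x * 2

x = torch.ones(3, requires_grad=True)
print(prepare(x).requires_grad)  # False: no graph was built inside the helper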
45 changes: 42 additions & 3 deletions ChatTTS/model/tokenizer.py
@@ -5,7 +5,7 @@
 https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning
 """
 
-from typing import List, Tuple, Optional
+from typing import List, Tuple, Optional, Union
 import lzma
 
 import numpy as np
@@ -31,8 +31,6 @@ def __init__(
         self.break_0_ids = tokenizer.convert_tokens_to_ids("[break_0]")
         self.eos_token = tokenizer.convert_tokens_to_ids("[Ebreak]")
 
-        self.decode = self._tokenizer.batch_decode
-
     @torch.inference_mode()
     def encode(
         self,
@@ -128,6 +126,15 @@ def encode(
 
         return new_input_ids, attention_mask, text_mask
 
+    @torch.inference_mode()
+    def decode(
+        self,
+        sequences: Union[List[int], List[List[int]]],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: bool = None,
+        **kwargs,
+    ):
+        return self._tokenizer.batch_decode(sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
 
     @staticmethod
     def _decode_spk_emb(spk_emb: str) -> np.ndarray:
         return np.frombuffer(
@@ -212,3 +219,35 @@ def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
         )
         del arr
         return s
+
+    @staticmethod
+    @torch.no_grad()
+    def decorate_code_prompts(
+        text: List[str], prompt: str, txt_smp: Optional[str], spk_emb: Optional[str],
+    ) -> List[str]:
+        for i, t in enumerate(text):
+            text[i] = (
+                t.replace("[Stts]", "")
+                .replace("[spk_emb]", "")
+                .replace("[empty_spk]", "")
+                .strip()
+            )
+            """
+            see https://github.com/2noise/ChatTTS/issues/459
+            """
+
+        if prompt:
+            text = [prompt + i for i in text]
+
+        txt_smp = "" if txt_smp is None else txt_smp
+        if spk_emb is not None:
+            text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
+        else:
+            text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]
+
+        return text
+
+    @staticmethod
+    @torch.no_grad()
+    def decorate_text_prompts(text: List[str], prompt: str) -> List[str]:
+        return [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
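Because the new helpers are static and operate purely on strings, they can be exercised without loading a model, which is how the new test drives them. A usage sketch, assuming the class in ChatTTS/model/tokenizer.py is named `Tokenizer` (import path assumed; the prompt values below are illustrative):

from ChatTTS.model.tokenizer import Tokenizer  # assumed import path

with_spk = Tokenizer.decorate_code_prompts(["hi"], "[speed_5]", None, "spk")  # any non-None spk_emb
refine = Tokenizer.decorate_text_prompts(["hi"], "[oral_2]")
print(with_spk)  # ['[Stts][spk_emb][speed_5]hi[Ptts]']
print(refine)    # ['[Sbreak]hi[Pbreak][oral_2]']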
60 changes: 60 additions & 0 deletions tests/#655.py
@@ -0,0 +1,60 @@
import os, sys

if sys.platform == "darwin":
    os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

import logging

import torch

import ChatTTS

from tools.logger import get_logger

logger = get_logger("Test #655", lv=logging.WARN)

chat = ChatTTS.Chat(logger)
chat.load(compile=False, source="huggingface")  # Set to True for better performance

rand_spk = chat.sample_random_speaker()

params = ChatTTS.Chat.InferCodeParams(
    spk_emb = rand_spk,  # add sampled speaker
    temperature = .3,    # using custom temperature
    top_P = 0.7,         # top P decode
    top_K = 20,          # top K decode
)

text = ['What is [uv_break]your favorite english food?[laugh][lbreak]']

fail = False

input_ids, attention_mask, text_mask = chat.tokenizer.encode(
    chat.tokenizer.decorate_code_prompts(
        text, params.prompt, params.txt_smp, params.spk_emb,
    ),
    chat.config.gpt.num_vq,
    prompt_str=params.spk_smp,
    device=chat.device_gpt,
)

with torch.inference_mode():
    start_idx, end_idx = 0, torch.zeros(
        input_ids.shape[0], device=input_ids.device, dtype=torch.long
    ).fill_(input_ids.shape[1])

    recoded_text = chat.tokenizer.decode(chat.gpt._prepare_generation_outputs(
        input_ids, start_idx, end_idx, [], [], True,
    ).ids)

fail = recoded_text[0] != '[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]'

if fail:
    logging.warning("got recoded_text '%s'", recoded_text[0])

    import sys

    sys.exit(1)

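In essence the test asserts one string equality: the decorated prompt, once encoded and passed through `_prepare_generation_outputs` and `decode`, must come back lower-cased with single spaces around the special tokens. A sketch of the two strings being compared (the `[speed_5]` token is presumably contributed by the default `InferCodeParams.prompt`; that default is an assumption, not something shown in this diff):

# Illustration only: the round trip tests/#655.py checks.
decorated = "[Stts][spk_emb][speed_5]What is [uv_break]your favorite english food?[laugh][lbreak][Ptts]"
expected  = "[Stts] [spk_emb] [speed_5] what is [uv_break] your favorite english food? [laugh] [lbreak] [Ptts]"
# encode() -> _prepare_generation_outputs() -> decode() should map `decorated`
# to `expected`: text lower-cased, one space around each special token.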