Skip to content

Commit

Permalink
optimize: drop seed context & add text temps
Browse files Browse the repository at this point in the history
- add param `manual_seed`
- add missing parameters to `refine_text`
  • Loading branch information
fumiama committed Jul 31, 2024
1 parent 63f4868 commit e675a59
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 30 deletions.
6 changes: 3 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ def has_loaded(self, use_decoder=False):
self.logger.warning(f"{module} not initialized.")
not_finish = True

if not not_finish:
self.logger.info("all models has been initialized.")

return not not_finish

def download_models(
Expand Down Expand Up @@ -186,6 +183,7 @@ class RefineTextParams:
min_new_token: int = 0
show_tqdm: bool = True
ensure_non_empty: bool = True
manual_seed: Optional[int] = None

@dataclass(repr=False, eq=False)
class InferCodeParams(RefineTextParams):
Expand Down Expand Up @@ -578,6 +576,7 @@ def _infer_code(
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

Expand Down Expand Up @@ -667,6 +666,7 @@ def _refine_text(
stream=False,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
manual_seed=params.manual_seed,
context=self.context,
)
)
Expand Down
11 changes: 10 additions & 1 deletion ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ def __init__(
self.device = device
self.device_gpt = device_gpt

self.generator = torch.Generator(device=device)

self.config = gpt_config
self.num_vq = int(gpt_config["num_vq"])
self.num_audio_tokens = int(gpt_config["num_audio_tokens"])
Expand Down Expand Up @@ -416,6 +418,7 @@ def generate(
show_tqdm=True,
ensure_non_empty=True,
stream_batch=24,
manual_seed: Optional[int] = None,
context=Context(),
):

Expand Down Expand Up @@ -581,7 +584,13 @@ def generate(

del logits

idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
if manual_seed is None:
idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
else:
idx_next = torch.multinomial(
scores, num_samples=1,
generator=self.generator.manual_seed(manual_seed),
).to(finish.device)

del scores

Expand Down
7 changes: 2 additions & 5 deletions examples/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ class ChatTTSParams(BaseModel):
use_decoder: bool = True
do_text_normalization: bool = True
do_homophone_replacement: bool = False
audio_seed: int
text_seed: int
params_refine_text: ChatTTS.Chat.RefineTextParams
params_infer_code: ChatTTS.Chat.InferCodeParams

Expand All @@ -63,13 +61,12 @@ async def generate_voice(params: ChatTTSParams):
logger.info("Text input: %s", str(params.text))

# audio seed
if params.audio_seed:
torch.manual_seed(params.audio_seed)
if params.params_infer_code.manual_seed is not None:
torch.manual_seed(params.params_infer_code.manual_seed)
params.params_infer_code.spk_emb = chat.sample_random_speaker()

# text seed for text refining
if params.params_refine_text:
torch.manual_seed(params.text_seed)
text = chat.infer(
text=params.text, skip_refine_text=False, refine_text_only=True
)
Expand Down
50 changes: 29 additions & 21 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,19 +133,27 @@ def refine_text(
text,
text_seed_input,
refine_text_flag,
temperature,
top_P,
top_K,
):
global chat

if not refine_text_flag:
sleep(1) # to skip fast answer of loading mark
return text

with TorchSeedContext(text_seed_input):
text = chat.infer(
text,
skip_refine_text=False,
refine_text_only=True,
)
text = chat.infer(
text,
skip_refine_text=False,
refine_text_only=True,
params_refine_text=ChatTTS.Chat.RefineTextParams(
temperature=temperature,
top_P=top_P,
top_K=top_K,
manual_seed=text_seed_input,
),
)

return text[0] if isinstance(text, list) else text

Expand All @@ -171,28 +179,28 @@ def generate_audio(
temperature=temperature,
top_P=top_P,
top_K=top_K,
manual_seed=audio_seed_input,
)

if sample_text_input and sample_audio_code_input:
params_infer_code.txt_smp = sample_text_input
params_infer_code.spk_smp = sample_audio_code_input
params_infer_code.spk_emb = None

with TorchSeedContext(audio_seed_input):
wav = chat.infer(
text,
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
)
if stream:
for gen in wav:
audio = gen[0]
if audio is not None and len(audio) > 0:
yield 24000, float_to_int16(audio).T
del audio
else:
yield 24000, float_to_int16(wav[0]).T
wav = chat.infer(
text,
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
)
if stream:
for gen in wav:
audio = gen[0]
if audio is not None and len(audio) > 0:
yield 24000, float_to_int16(audio).T
del audio
else:
yield 24000, float_to_int16(wav[0]).T


def interrupt_generate():
Expand Down
3 changes: 3 additions & 0 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ def make_audio(autoplay, stream):
text_input,
text_seed_input,
refine_text_checkbox,
temperature_slider,
top_p_slider,
top_k_slider,
],
outputs=text_output,
).then(
Expand Down

0 comments on commit e675a59

Please sign in to comment.