From 9e96be7221c623beb7c4ed45df5e4e2f77e6a474 Mon Sep 17 00:00:00 2001 From: ZaymeShaw <402147150@qq.com> Date: Sat, 27 Jul 2024 13:08:32 +0800 Subject: [PATCH] fix some problem --- ChatTTS/model/tokenizer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index 0ee4ca706..72cdacd3c 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -162,7 +162,7 @@ def apply_spk_emb( .unsqueeze_(1) .expand(emb.shape) ) - cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape) + cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape).to(device) torch.where(cond, n, emb, out=emb) del cond, n @@ -180,12 +180,13 @@ def _decode_prompt(prompt: str) -> torch.Tensor: dtype=" str: - arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() + # arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() + arr: np.ndarray = prompt.to(device="cpu").numpy().astype(np.uint16) shp = arr.shape assert len(shp) == 2, "prompt must be a 2D tensor" s = b14.encode_to_string(