diff --git a/ChatTTS/model/tokenizer.py b/ChatTTS/model/tokenizer.py index 0ee4ca706..72cdacd3c 100644 --- a/ChatTTS/model/tokenizer.py +++ b/ChatTTS/model/tokenizer.py @@ -162,7 +162,7 @@ def apply_spk_emb( .unsqueeze_(1) .expand(emb.shape) ) - cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape) + cond = input_ids.narrow(-1, 0, 1).eq(self.spk_emb_ids).expand(emb.shape).to(device) torch.where(cond, n, emb, out=emb) del cond, n @@ -180,12 +180,13 @@ def _decode_prompt(prompt: str) -> torch.Tensor: dtype=" str: - arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() + # arr: np.ndarray = prompt.to(dtype=torch.uint16, device="cpu").numpy() + arr: np.ndarray = prompt.to(device="cpu").numpy().astype(np.uint16) shp = arr.shape assert len(shp) == 2, "prompt must be a 2D tensor" s = b14.encode_to_string(