diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 2d3250427..7b7674396 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -377,7 +377,7 @@ def _infer( return if stream: - length = np.zeros(len(text), dtype=np.uint32) + length = 0 pass_batch_count = 0 for result in self._infer_code( text, @@ -397,22 +397,15 @@ def _infer( continue a = length b = a + params_infer_code.stream_speed - new_wavs = np.zeros((wavs.shape[0], params_infer_code.stream_speed)) - for i in range(wavs.shape[0]): - if b[i] > len(wavs[i]): - b[i] = len(wavs[i]) - new_wavs[i, : b[i] - a[i]] = wavs[i, a[i] : b[i]] + if b > wavs.shape[1]: + b = wavs.shape[1] + new_wavs = wavs[:, a:b] length = b yield new_wavs else: yield wavs if stream: - for i in range(wavs.shape[0]): - a = length[i] - b = len(wavs[i]) - wavs[i, : b - a] = wavs[i, a:] - wavs[i, b - a :] = 0 - yield wavs + yield wavs[:, length:] @torch.inference_mode() def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray: