diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 75f983924..f63705d2e 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -161,7 +161,11 @@ def sample_random_speaker(self) -> str: return self.speaker.sample_random() def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str: - return self.speaker.encode_prompt(self.dvae.sample_audio(wav)) + sample_audio = self.dvae.sample_audio(wav) + if "npu" in str(self.device): + # reset dvae to npu + self.dvae.to(self.device) + return self.speaker.encode_prompt(sample_audio) @dataclass(repr=False, eq=False) class RefineTextParams: @@ -268,13 +272,19 @@ def _load( self.vocos = vocos self.logger.log(logging.INFO, "vocos loaded.") - dvae = DVAE( - decoder_config=asdict(self.config.dvae.decoder), - encoder_config=asdict(self.config.dvae.encoder), - vq_config=asdict(self.config.dvae.vq), - dim=self.config.dvae.decoder.idim, - coef=coef, - device=device, + # Computation of MelSpectrogram on npu is not support now, use cpu fallback. + dvae_device = torch.device("cpu") if "npu" in str(self.device) else device + dvae = ( + DVAE( + decoder_config=asdict(self.config.dvae.decoder), + encoder_config=asdict(self.config.dvae.encoder), + vq_config=asdict(self.config.dvae.vq), + dim=self.config.dvae.decoder.idim, + coef=coef, + device=dvae_device, + ) + .to(dvae_device) + .eval() ) coef = str(dvae) assert dvae_ckpt_path, "dvae_ckpt_path should not be None"