Skip to content

Commit

Permalink
optimize: revert default device to cpu to satisfy non-cuda users
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Sep 4, 2024
1 parent 651093e commit a4c8a5a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _load(
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=self.device,
device=device,
)
.to(device)
.eval()
Expand All @@ -289,8 +289,8 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(embed_path, device=self.device)
self.embed = embed.to(self.device)
embed.from_pretrained(embed_path, device=device)
self.embed = embed.to(device)
self.logger.log(logging.INFO, "embed loaded.")

gpt = GPT(
Expand Down Expand Up @@ -318,6 +318,7 @@ def _load(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
device=device,
)
.to(device)
.eval()
Expand Down
4 changes: 2 additions & 2 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
hop_length=256,
n_mels=100,
padding: Literal["center", "same"] = "center",
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.device = device
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
vq_config: Optional[dict] = None,
dim=512,
coef: Optional[str] = None,
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
if coef is None:
Expand Down

0 comments on commit a4c8a5a

Please sign in to comment.