Skip to content

Commit

Permalink
Swap checkpoint_dir, checkpoint_path parameters in xtts
Browse files Browse the repository at this point in the history
  • Loading branch information
bivashy committed Dec 29, 2023
1 parent c176266 commit 1432572
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions TTS/tts/models/xtts.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,8 +731,8 @@ def get_compatible_checkpoint_state_dict(self, model_path):
def load_checkpoint(
self,
config,
checkpoint_dir=None,
checkpoint_path=None,
checkpoint_dir=None,
vocab_path=None,
eval=True,
strict=True,
Expand All @@ -744,16 +744,17 @@ def load_checkpoint(
Args:
config (dict): The configuration dictionary for the model.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
checkpoint_path (str, optional): The path to the checkpoint file. Defaults to None.
checkpoint_dir (str, optional): The directory where the checkpoint is stored. Defaults to None.
vocab_path (str, optional): The path to the vocabulary file. Defaults to None.
eval (bool, optional): Whether to set the model to evaluation mode. Defaults to True.
strict (bool, optional): Whether to strictly enforce that the keys in the checkpoint match the keys in the model. Defaults to True.
Returns:
None
"""

if checkpoint_dir is None and checkpoint_path:
checkpoint_dir = os.path.dirname(checkpoint_path)
model_path = checkpoint_path or os.path.join(checkpoint_dir, "model.pth")
vocab_path = vocab_path or os.path.join(checkpoint_dir, "vocab.json")

Expand Down

0 comments on commit 1432572

Please sign in to comment.