diff --git a/vocos/pretrained.py b/vocos/pretrained.py index a8a5935..9385085 100644 --- a/vocos/pretrained.py +++ b/vocos/pretrained.py @@ -67,7 +67,9 @@ def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision) model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision) model = cls.from_hparams(config_path) - state_dict = torch.load(model_path, map_location="cpu") + # Check if GPU is available and load model on it. + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + state_dict = torch.load(model_path, map_location=device) if isinstance(model.feature_extractor, EncodecFeatures): encodec_parameters = { "feature_extractor.encodec." + key: value