diff --git a/README.md b/README.md index cc471c1..a2272d3 100644 --- a/README.md +++ b/README.md @@ -105,6 +105,28 @@ python train.py -c configs/vocos.yaml Refer to [Pytorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the training pipeline. +## Offline mode + +If internet is not available in your GPU cluster, you can upload vocos and encodec checkpoint +to your GPU cluster and enable offline mode. + +```python +import sys +import os + +os.environ['HF_HOME'] = '/home/your_name/your_hf_home' +os.environ['HF_DATASETS_OFFLINE'] = '1' +os.environ['TRANSFORMERS_OFFLINE'] = '1' +os.environ['VOCOS_OFFLINE'] = '1' # just like `transformers` and `datasets` + +import torch +from vocos import Vocos + +vocos = Vocos.from_pretrained( + "/your_folder/charactr/vocos-encodec-24khz", + feature_extractor_repo='/your_folder/torch-hub/hub/checkpoints') +``` + ## Citation If this code contributes to your research, please cite our work: diff --git a/vocos/feature_extractors.py b/vocos/feature_extractors.py index 799f1b4..025b5fc 100644 --- a/vocos/feature_extractors.py +++ b/vocos/feature_extractors.py @@ -1,4 +1,5 @@ -from typing import List +from typing import Optional, List +from pathlib import Path import torch import torchaudio @@ -55,6 +56,7 @@ def __init__( encodec_model: str = "encodec_24khz", bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0], train_codebooks: bool = False, + repo: Optional[str] = None, ): super().__init__() if encodec_model == "encodec_24khz": @@ -65,7 +67,9 @@ def __init__( raise ValueError( f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'." ) - self.encodec = encodec(pretrained=True) + if repo is not None: + repo = Path(repo) + self.encodec = encodec(pretrained=True, repository=repo) for param in self.encodec.parameters(): param.requires_grad = False self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth( diff --git a/vocos/pretrained.py b/vocos/pretrained.py index a8a5935..d511d95 100644 --- a/vocos/pretrained.py +++ b/vocos/pretrained.py @@ -1,5 +1,6 @@ from __future__ import annotations +import os from typing import Any, Dict, Tuple, Union, Optional import torch @@ -47,12 +48,17 @@ def __init__( self.head = head @classmethod - def from_hparams(cls, config_path: str) -> Vocos: + def from_hparams(cls, + config_path: str, + feature_extractor_repo: Optional[str] = None) -> Vocos: """ Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. """ with open(config_path, "r") as f: config = yaml.safe_load(f) + if feature_extractor_repo is not None: + kwargs = config['feature_extractor'].setdefault("init_args", {}) + kwargs['repo'] = feature_extractor_repo feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) backbone = instantiate_class(args=(), init=config["backbone"]) head = instantiate_class(args=(), init=config["head"]) @@ -60,13 +66,26 @@ def from_hparams(cls, config_path: str) -> Vocos: return model @classmethod - def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos: + def from_pretrained(cls, + repo_id: str, + revision: Optional[str] = None, + feature_extractor_repo: Optional[str] = None) -> Vocos: """ Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub. """ - config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision) - model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision) - model = cls.from_hparams(config_path) + offline = os.environ.get('VOCOS_OFFLINE', '0') == '1' + if offline: + config_path = os.path.join(repo_id, 'config.yaml') + model_path = os.path.join(repo_id, 'pytorch_model.bin') + else: + config_path = hf_hub_download(repo_id=repo_id, + filename="config.yaml", + revision=revision) + model_path = hf_hub_download(repo_id=repo_id, + filename="pytorch_model.bin", + revision=revision) + model = cls.from_hparams(config_path, + feature_extractor_repo=feature_extractor_repo) state_dict = torch.load(model_path, map_location="cpu") if isinstance(model.feature_extractor, EncodecFeatures): encodec_parameters = {