Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat: Offline mode #40

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,28 @@ python train.py -c configs/vocos.yaml
Refer to [Pytorch Lightning documentation](https://lightning.ai/docs/pytorch/stable/) for details about customizing the
training pipeline.

## Offline mode

If your GPU cluster has no internet access, you can upload the Vocos and EnCodec checkpoints
to the cluster in advance and enable offline mode:

```python
import sys
import os

os.environ['HF_HOME'] = '/home/your_name/your_hf_home'
os.environ['HF_DATASETS_OFFLINE'] = '1'
os.environ['TRANSFORMERS_OFFLINE'] = '1'
os.environ['VOCOS_OFFLINE'] = '1' # just like `transformers` and `datasets`

import torch
from vocos import Vocos

vocos = Vocos.from_pretrained(
"/your_folder/charactr/vocos-encodec-24khz",
feature_extractor_repo='/your_folder/torch-hub/hub/checkpoints')
```

## Citation

If this code contributes to your research, please cite our work:
Expand Down
8 changes: 6 additions & 2 deletions vocos/feature_extractors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import List
from typing import Optional, List
from pathlib import Path

import torch
import torchaudio
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
encodec_model: str = "encodec_24khz",
bandwidths: List[float] = [1.5, 3.0, 6.0, 12.0],
train_codebooks: bool = False,
repo: Optional[str] = None,
):
super().__init__()
if encodec_model == "encodec_24khz":
Expand All @@ -65,7 +67,9 @@ def __init__(
raise ValueError(
f"Unsupported encodec_model: {encodec_model}. Supported options are 'encodec_24khz' and 'encodec_48khz'."
)
self.encodec = encodec(pretrained=True)
if repo is not None:
repo = Path(repo)
self.encodec = encodec(pretrained=True, repository=repo)
for param in self.encodec.parameters():
param.requires_grad = False
self.num_q = self.encodec.quantizer.get_num_quantizers_for_bandwidth(
Expand Down
29 changes: 24 additions & 5 deletions vocos/pretrained.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import os
from typing import Any, Dict, Tuple, Union, Optional

import torch
Expand Down Expand Up @@ -47,26 +48,44 @@ def __init__(
self.head = head

@classmethod
def from_hparams(cls, config_path: str) -> Vocos:
def from_hparams(cls,
config_path: str,
feature_extractor_repo: Optional[str] = None) -> Vocos:
"""
Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.
"""
with open(config_path, "r") as f:
config = yaml.safe_load(f)
if feature_extractor_repo is not None:
kwargs = config['feature_extractor'].setdefault("init_args", {})
kwargs['repo'] = feature_extractor_repo
feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
backbone = instantiate_class(args=(), init=config["backbone"])
head = instantiate_class(args=(), init=config["head"])
model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head)
return model

@classmethod
def from_pretrained(cls, repo_id: str, revision: Optional[str] = None) -> Vocos:
def from_pretrained(cls,
repo_id: str,
revision: Optional[str] = None,
feature_extractor_repo: Optional[str] = None) -> Vocos:
"""
Class method to create a new Vocos model instance from a pre-trained model stored in the Hugging Face model hub.
"""
config_path = hf_hub_download(repo_id=repo_id, filename="config.yaml", revision=revision)
model_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin", revision=revision)
model = cls.from_hparams(config_path)
offline = os.environ.get('VOCOS_OFFLINE', '0') == '1'
if offline:
config_path = os.path.join(repo_id, 'config.yaml')
model_path = os.path.join(repo_id, 'pytorch_model.bin')
else:
config_path = hf_hub_download(repo_id=repo_id,
filename="config.yaml",
revision=revision)
model_path = hf_hub_download(repo_id=repo_id,
filename="pytorch_model.bin",
revision=revision)
model = cls.from_hparams(config_path,
feature_extractor_repo=feature_extractor_repo)
state_dict = torch.load(model_path, map_location="cpu")
if isinstance(model.feature_extractor, EncodecFeatures):
encodec_parameters = {
Expand Down