Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Push to Hub functionality to Model and Pipeline #1699

Open
wants to merge 20 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 221 additions & 2 deletions pyannote/audio/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,17 @@
from functools import cached_property
from importlib import import_module
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, List, Optional, Text, Tuple, Union
from urllib.parse import urlparse

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import RepositoryNotFoundError
import yaml
from huggingface_hub import HfApi, ModelCard, ModelCardData, hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError
from lightning_fabric.utilities.cloud_io import _load as pl_load
from pyannote.core import SlidingWindow
from pytorch_lightning.utilities.model_summary import ModelSummary
Expand Down Expand Up @@ -534,6 +536,7 @@ def from_pretrained(
strict: bool = True,
use_auth_token: Union[Text, None] = None,
cache_dir: Union[Path, Text] = CACHE_DIR,
subfolder: Optional[str] = None,
**kwargs,
) -> "Model":
"""Load pretrained model
Expand Down Expand Up @@ -566,6 +569,8 @@ def from_pretrained(
cache_dir: Path or str, optional
Path to model cache directory. Defaults to content of PYANNOTE_CACHE
environment variable, or "~/.cache/torch/pyannote" when unset.
subfolder: Path or str, optional
An optional value corresponding to a folder inside the model repo.
kwargs: optional
Any extra keyword args needed to init the model.
Can also be used to override saved hyperparameter values.
Expand All @@ -581,6 +586,7 @@ def from_pretrained(
"""

# pytorch-lightning expects str, not Path.

checkpoint = str(checkpoint)
if hparams_file is not None:
hparams_file = str(hparams_file)
Expand Down Expand Up @@ -618,6 +624,7 @@ def from_pretrained(
use_auth_token=use_auth_token,
# local_files_only=False,
# legacy_cache_layout=False,
subfolder=subfolder,
)
except RepositoryNotFoundError:
print(
Expand Down Expand Up @@ -702,3 +709,215 @@ def default_map_location(storage, loc):
raise e

return model

def save_pretrained(self, checkpoint_dir):
    """Save model checkpoint and config to a local directory.

    Writes two files into ``checkpoint_dir``:
    * a minimal pytorch-lightning-style checkpoint (``HF_PYTORCH_WEIGHTS_NAME``)
      so that ``Model.from_pretrained`` can load it back;
    * a ``config.yaml`` describing the model (and, for PyanNet, its task).

    Parameters
    ----------
    checkpoint_dir : str or Path
        Directory in which to save the model checkpoint and config.

    Raises
    ------
    ValueError
        If the model is neither a PyanNet nor a WeSpeakerResNet34 instance.
    """

    model_class = type(self)
    model_type = model_class.__name__

    # Validate with an explicit exception rather than `assert`
    # (asserts are stripped when Python runs with -O).
    if model_type not in ("PyanNet", "WeSpeakerResNet34"):
        raise ValueError(
            "save_pretrained only supports PyanNet and WeSpeakerResNet34 "
            f"models, got {model_type}."
        )

    checkpoint_dir = Path(checkpoint_dir)

    # Build a minimal checkpoint dict mimicking what pytorch-lightning
    # writes, so the standard loading path keeps working.
    checkpoint = {"state_dict": self.state_dict()}
    self.on_save_checkpoint(checkpoint)
    checkpoint["pytorch-lightning_version"] = pl.__version__

    if model_type == "PyanNet":
        checkpoint["hyper_parameters"] = dict(self.hparams)

    torch.save(checkpoint, checkpoint_dir / HF_PYTORCH_WEIGHTS_NAME)

    # Config shared by both model types: hyper-parameters plus the
    # fully-qualified class name used as instantiation target.
    config = {"model": dict(self.hparams)}
    config["model"]["_target_"] = f"{model_class.__module__}.{model_class.__name__}"

    if model_type == "PyanNet":
        # Task-level metadata needed to rebuild `specifications` at load time.
        config["task"] = {
            "duration": self.specifications.duration,
            "max_speakers_per_chunk": len(self.specifications.classes),
            "max_speakers_per_frame": self.specifications.powerset_max_classes,
        }

    with open(checkpoint_dir / "config.yaml", "w") as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

def push_to_hub(
    self,
    repo_id: str,
    commit_message: Optional[str] = None,
    private: Optional[bool] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    create_pr: bool = False,
    revision: Optional[str] = None,
    commit_description: Optional[str] = None,
    tags: Optional[List[str]] = None,
) -> None:
    """
    Upload the pyannote Model to the 🤗 Model Hub.

    Parameters:
        repo_id (`str`):
            The name of the repository you want to push your Model to. It should contain your organization name
            when pushing to a given organization.
        commit_message (`str`, *optional*):
            Message to commit while pushing. Will default to `"Upload Model"`.
        private (`bool`, *optional*):
            Whether or not the repository created should be private.
        use_auth_token (`bool` or `str`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `huggingface-cli login` (stored in `~/.huggingface`).
        create_pr (`bool`, *optional*, defaults to `False`):
            Whether or not to create a PR with the uploaded files or directly commit.
        revision (`str`, *optional*):
            Branch to push the uploaded files to.
        commit_description (`str`, *optional*):
            The description of the commit that will be created.
        tags (`List[str]`, *optional*):
            List of tags to push on the Hub.
    """

    api = HfApi()

    # No-op when the repository already exists.
    api.create_repo(
        repo_id,
        private=private,
        token=use_auth_token,
        exist_ok=True,
        repo_type="model",
    )

    with TemporaryDirectory() as tmpdir:
        tmpdir = Path(tmpdir)

        # Save model checkpoint and config.
        self.save_pretrained(tmpdir)

        # Create (or update) the model card with pyannote tags.
        model_type = type(self).__name__
        model_card = create_and_tag_model_card(
            repo_id,
            model_type,
            tags,
            use_auth_token=use_auth_token,
        )
        model_card.save(tmpdir / "README.md")

        # Push to hub. Note: `upload_folder` expects `token=`,
        # not `use_auth_token=`.
        return api.upload_folder(
            repo_id=repo_id,
            folder_path=tmpdir,
            token=use_auth_token,
            repo_type="model",
            commit_message=commit_message,
            create_pr=create_pr,
            revision=revision,
            commit_description=commit_description,
        )


def create_and_tag_model_card(
    repo_id: str,
    model_type: str,
    tags: Optional[List[str]] = None,
    use_auth_token: Optional[str] = None,
):
    """
    Creates or loads an existing model card and tags it.

    Args:
        repo_id (`str`):
            The repo_id where to look for the model card.
        model_type (`str`):
            Specify the model type (PyanNet or WeSpeakerResNet34) to create the associated model card.
        tags (`List[str]`, *optional*):
            The list of optional tags to add in the model card.
        use_auth_token (`str`, *optional*):
            Authentication token, obtained with `huggingface_hub.HfApi.login` method. Will default to the stored token.

    Returns:
        `ModelCard`: the loaded-or-created card, with tags and license applied.
    """

    # Copy so we never mutate the caller's list in place.
    tags = list(tags) if tags is not None else []

    if model_type == "PyanNet":
        tags += [
            "pyannote",
            "pyannote.audio",
            "pyannote-audio-model",
            "audio",
            "voice",
            "speech",
            "speaker",
            "speaker-diarization",
            "speaker-change-detection",
            "speaker-segmentation",
            "voice-activity-detection",
            "overlapped-speech-detection",
            "resegmentation",
        ]
        card_license = "mit"

    elif model_type == "WeSpeakerResNet34":
        tags += [
            "pyannote",
            "pyannote.audio",
            "pyannote-audio-model",
            "audio",
            "voice",
            "speech",
            "speaker",
            "speaker-recognition",
            "speaker-verification",
            "speaker-identification",
            "speaker-embedding",
            "PyTorch",
            "wespeaker",
        ]
        card_license = "cc-by-4.0"

    else:
        # Unknown model type: no pyannote-specific tags, no license
        # (previously this path left `licence` unbound -> NameError).
        card_license = None

    model_description = "This is the model card of a pyannote model that has been pushed on the Hub. This model card has been automatically generated."

    try:
        # Check if the model card is present on the remote repo
        model_card = ModelCard.load(repo_id, token=use_auth_token)
    except EntryNotFoundError:
        # Otherwise create a simple model card from template
        card_data = ModelCardData(tags=tags, library_name="pyannote")
        model_card = ModelCard.from_template(
            card_data, model_description=model_description
        )

    # A card loaded from the Hub may have no tags at all.
    if model_card.data.tags is None:
        model_card.data.tags = []

    for model_tag in tags:
        if model_tag not in model_card.data.tags:
            model_card.data.tags.append(model_tag)

    if card_license is not None:
        # `license` (not `licence`) is the metadata key the Hub recognizes.
        model_card.data.license = card_license

    # NOTE(review): this replaces the body of any pre-existing card on the
    # Hub with the generated description — confirm this is intended.
    model_card.text = model_description

    return model_card
Loading