Fix CLAP converting script (huggingface#27153)
* update converting script

* make style
ylacombe authored Dec 8, 2023
1 parent b31905d commit 7f07c35
Showing 1 changed file with 26 additions and 16 deletions.
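Judging from the argument parser in the diff below, the updated converter is run from the command line roughly as follows. The paths are placeholders, and the --pytorch_dump_folder_path flag is assumed from the args.pytorch_dump_folder_path reference, since its add_argument line sits outside the changed hunks:

python src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py \
    --checkpoint_path /path/to/laion_clap_checkpoint.pt \
    --pytorch_dump_folder_path ./clap-converted \
    --model_type HTSAT-tiny

Add --enable_fusion when converting a fused checkpoint; the same flag also flips enable_fusion on the resulting audio config.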
src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py (26 additions, 16 deletions)
@@ -16,8 +16,7 @@
 import argparse
 import re

-import torch
-from CLAP import create_model
+from laion_clap import CLAP_Module

 from transformers import AutoFeatureExtractor, ClapConfig, ClapModel

@@ -38,17 +37,25 @@
 processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")


-def init_clap(checkpoint_path, enable_fusion=False):
-    model, model_cfg = create_model(
-        "HTSAT-tiny",
-        "roberta",
-        checkpoint_path,
-        precision="fp32",
-        device="cuda:0" if torch.cuda.is_available() else "cpu",
+def init_clap(checkpoint_path, model_type, enable_fusion=False):
+    model = CLAP_Module(
+        amodel=model_type,
         enable_fusion=enable_fusion,
-        fusion_type="aff_2d" if enable_fusion else None,
     )
-    return model, model_cfg
+    model.load_ckpt(checkpoint_path)
+    return model
+
+
+def get_config_from_original(clap_model):
+    audio_config = {
+        "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
+        "depths": clap_model.model.audio_branch.depths,
+        "hidden_size": clap_model.model.audio_projection[0].in_features,
+    }
+
+    text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
+
+    return ClapConfig(audio_config=audio_config, text_config=text_config)


 def rename_state_dict(state_dict):
@@ -94,14 +101,14 @@ def rename_state_dict(state_dict):
     return model_state_dict


-def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, enable_fusion=False):
-    clap_model, clap_model_cfg = init_clap(checkpoint_path, enable_fusion=enable_fusion)
+def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
+    clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)

     clap_model.eval()
-    state_dict = clap_model.state_dict()
+    state_dict = clap_model.model.state_dict()
     state_dict = rename_state_dict(state_dict)

-    transformers_config = ClapConfig()
+    transformers_config = get_config_from_original(clap_model)
     transformers_config.audio_config.enable_fusion = enable_fusion
     model = ClapModel(transformers_config)

@@ -118,6 +125,9 @@ def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa
     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
     parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
+    parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Audio encoder type of the checkpoint, e.g. HTSAT-tiny")
     args = parser.parse_args()

-    convert_clap_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.enable_fusion)
+    convert_clap_checkpoint(
+        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
+    )
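For context, the fix swaps create_model from the original CLAP repository for the pip-installable laion_clap package. A minimal sketch of the new loading path, assuming laion_clap is installed and the checkpoint path is a placeholder:

from laion_clap import CLAP_Module

# Build the wrapper around the chosen audio encoder; load_ckpt then reads
# the original LAION-CLAP weights from disk.
clap_model = CLAP_Module(amodel="HTSAT-tiny", enable_fusion=False)
clap_model.load_ckpt("/path/to/laion_clap_checkpoint.pt")

# CLAP_Module keeps the actual PyTorch network on its .model attribute,
# which is why the script now takes clap_model.model.state_dict().
state_dict = clap_model.model.state_dict()

The new get_config_from_original helper follows the same logic: rather than instantiating a default ClapConfig, it reads the hidden sizes and depths off the loaded model, so checkpoints other than HTSAT-tiny get a matching Transformers config.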
