From 7f07c356a46740f50809ef761b9032fe45361807 Mon Sep 17 00:00:00 2001
From: Yoach Lacombe <52246514+ylacombe@users.noreply.github.com>
Date: Fri, 8 Dec 2023 13:48:29 +0000
Subject: [PATCH] Fix CLAP converting script (#27153)

* update converting script

* make style
---
 .../convert_clap_original_pytorch_to_hf.py    | 42 ++++++++++++-------
 1 file changed, 26 insertions(+), 16 deletions(-)

diff --git a/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py b/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
index 908fef5927af02..d422bc45ab3de0 100644
--- a/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
+++ b/src/transformers/models/clap/convert_clap_original_pytorch_to_hf.py
@@ -16,8 +16,7 @@
 import argparse
 import re
 
-import torch
-from CLAP import create_model
+from laion_clap import CLAP_Module
 
 from transformers import AutoFeatureExtractor, ClapConfig, ClapModel
 
@@ -38,17 +37,25 @@
 processor = AutoFeatureExtractor.from_pretrained("laion/clap-htsat-unfused", truncation="rand_trunc")
 
 
-def init_clap(checkpoint_path, enable_fusion=False):
-    model, model_cfg = create_model(
-        "HTSAT-tiny",
-        "roberta",
-        checkpoint_path,
-        precision="fp32",
-        device="cuda:0" if torch.cuda.is_available() else "cpu",
+def init_clap(checkpoint_path, model_type, enable_fusion=False):
+    model = CLAP_Module(
+        amodel=model_type,
         enable_fusion=enable_fusion,
-        fusion_type="aff_2d" if enable_fusion else None,
     )
-    return model, model_cfg
+    model.load_ckpt(checkpoint_path)
+    return model
+
+
+def get_config_from_original(clap_model):
+    audio_config = {
+        "patch_embeds_hidden_size": clap_model.model.audio_branch.embed_dim,
+        "depths": clap_model.model.audio_branch.depths,
+        "hidden_size": clap_model.model.audio_projection[0].in_features,
+    }
+
+    text_config = {"hidden_size": clap_model.model.text_branch.pooler.dense.in_features}
+
+    return ClapConfig(audio_config=audio_config, text_config=text_config)
 
 
 def rename_state_dict(state_dict):
@@ -94,14 +101,14 @@ def rename_state_dict(state_dict):
     return model_state_dict
 
 
-def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, enable_fusion=False):
-    clap_model, clap_model_cfg = init_clap(checkpoint_path, enable_fusion=enable_fusion)
+def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_path, model_type, enable_fusion=False):
+    clap_model = init_clap(checkpoint_path, model_type, enable_fusion=enable_fusion)
 
     clap_model.eval()
-    state_dict = clap_model.state_dict()
+    state_dict = clap_model.model.state_dict()
     state_dict = rename_state_dict(state_dict)
 
-    transformers_config = ClapConfig()
+    transformers_config = get_config_from_original(clap_model)
     transformers_config.audio_config.enable_fusion = enable_fusion
     model = ClapModel(transformers_config)
 
@@ -118,6 +125,9 @@ def convert_clap_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa
     parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
     parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
    parser.add_argument("--enable_fusion", action="store_true", help="Whether to enable fusion or not")
+    parser.add_argument("--model_type", default="HTSAT-tiny", type=str, help="Type of audio model to convert")
     args = parser.parse_args()
 
-    convert_clap_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.enable_fusion)
+    convert_clap_checkpoint(
+        args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.model_type, args.enable_fusion
+    )
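
Usage sketch (a minimal illustration, not part of the patch): with these changes the converter takes the original checkpoint path, an output folder, and the new model_type argument, and loads the weights through laion_clap's CLAP_Module. The call below drives the updated function directly from Python; the placeholder paths and the direct module import are assumptions for illustration, and it requires the laion_clap package plus network access, since importing the script fetches the laion/clap-htsat-unfused feature extractor.

    # Illustrative only: paths are placeholders, not values from the patch.
    from transformers.models.clap.convert_clap_original_pytorch_to_hf import (
        convert_clap_checkpoint,
    )

    convert_clap_checkpoint(
        checkpoint_path="/path/to/original_clap_checkpoint.pt",  # LAION-CLAP weights (placeholder)
        pytorch_dump_folder_path="/path/to/hf_dump",  # output folder for the converted model (placeholder)
        config_path=None,  # mirrors the --config_path default
        model_type="HTSAT-tiny",  # default of the new --model_type flag
        enable_fusion=False,  # set True when converting a fused checkpoint
    )

Equivalently, the script can be run from the command line with --checkpoint_path, --pytorch_dump_folder_path, --model_type and, for fused checkpoints, --enable_fusion.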