From 35fb010dbd13ecf020c930271685fc19d9035455 Mon Sep 17 00:00:00 2001 From: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Date: Tue, 25 Jun 2024 09:50:16 -0700 Subject: [PATCH] Update neva conversion script from and to HF (#9296) * Update NeMo script Signed-off-by: yaoyu-33 * Fix example scripts Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 * Update convert_llava_nemo_to_hf.py Signed-off-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> * address comments Signed-off-by: yaoyu-33 * Apply isort and black reformatting Signed-off-by: yaoyu-33 --------- Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 Signed-off-by: yaoyu-33 <54727607+yaoyu-33@users.noreply.github.com> Co-authored-by: yaoyu-33 --- .../neva/conf/llava_config.yaml | 4 +- .../convert_gemma_hf_to_nemo.py | 2 +- .../convert_gemma_pyt_to_nemo.py | 2 +- .../convert_llava_hf_to_nemo.py | 331 +++++++++++++++++ .../convert_llava_nemo_to_hf.py | 337 ++++++++++++++++++ 5 files changed, 672 insertions(+), 4 deletions(-) create mode 100644 scripts/checkpoint_converters/convert_llava_hf_to_nemo.py create mode 100644 scripts/checkpoint_converters/convert_llava_nemo_to_hf.py diff --git a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml index b47c719fef1d..3ec90b2d1b53 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml @@ -86,7 +86,7 @@ model: # LLM configs # use GPTModel from megatron.core - mcore_gpt: False + mcore_gpt: True # model architecture encoder_seq_length: 4096 @@ -149,7 +149,7 @@ model: bias_activation_fusion: False megatron_legacy: False - transformer_engine: False + transformer_engine: True fp8: False # enables fp8 in TransformerLayer forward fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID diff --git a/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py b/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py index de12aefd1844..9ce51e544661 100644 --- a/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py @@ -127,8 +127,8 @@ def adjust_tensor_shapes(model, nemo_state_dict): model_config = model.cfg num_query_groups = model_config["num_query_groups"] head_num = model_config["num_attention_heads"] - head_size = model_config["kv_channels"] hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] heads_per_group = head_num // num_query_groups # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. diff --git a/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py index d14e5f7de551..3cf3ed021527 100644 --- a/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py +++ b/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py @@ -133,8 +133,8 @@ def adjust_tensor_shapes(model, nemo_state_dict): model_config = model.cfg num_query_groups = model_config["num_query_groups"] head_num = model_config["num_attention_heads"] - head_size = model_config["kv_channels"] hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] heads_per_group = head_num // num_query_groups # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. 
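The YAML change above flips the NeVA example config to the megatron-core GPT path with Transformer Engine enabled by default. A minimal sketch (not part of the patch; the config path is taken from the diff) of loading the config, confirming the new defaults, and overriding them when Transformer Engine is unavailable locally:

from omegaconf import OmegaConf

# Load the updated example config and confirm the new defaults.
cfg = OmegaConf.load("examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml")
assert cfg.model.mcore_gpt and cfg.model.transformer_engine

# Override at load time if the local environment lacks Transformer Engine.
cfg = OmegaConf.merge(
    cfg, OmegaConf.from_dotlist(["model.mcore_gpt=false", "model.transformer_engine=false"])
)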
diff --git a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py new file mode 100644 index 000000000000..d91899348e8c --- /dev/null +++ b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py @@ -0,0 +1,331 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + python3 /opt/NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py \ + --input_name_or_path llava-hf/llava-1.5-7b-hf \ + --output_path /path/to/llava-7b.nemo \ + --tokenizer_path /path/to/tokenizer.model +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from transformers import LlamaTokenizer, LlavaForConditionalGeneration + +from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # Attention layers + rename_keys.extend( + [ + ( + f"language_model.model.layers.{i}.self_attn.o_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.q_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.k_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.v_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_v.weight", + ), + # MLP and LayerNorm + ( + f"language_model.model.layers.{i}.mlp.gate_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_gate.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.up_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_proj.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.down_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"language_model.model.layers.{i}.input_layernorm.weight", + f"model.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"language_model.model.layers.{i}.post_attention_layernorm.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ] + ) + + rename_keys.extend( + [ + ( + "multi_modal_projector.linear_1.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.weight", + ), + ( + "multi_modal_projector.linear_1.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.bias", + ), + ( + "multi_modal_projector.linear_2.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.weight", + ), + ( + 
"multi_modal_projector.linear_2.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.bias", + ), + ("language_model.model.embed_tokens.weight", "model.embedding.word_embeddings.weight"), + ("language_model.model.norm.weight", "model.decoder.final_layernorm.weight"), + ("language_model.lm_head.weight", "model.output_layer.weight"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for old_key, new_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model, nemo_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + model_config = model.cfg + num_query_groups = model_config["num_query_groups"] + head_num = model_config["num_attention_heads"] + hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] + heads_per_group = head_num // num_query_groups + + # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. 
+ for key_ in list(nemo_state_dict.keys()): + if 'vision_towel' in key_: + del nemo_state_dict[key_] + + if 'word_embeddings.weight' in key_ or 'output_layer.weight' in key_: + # padding + loaded_weight = nemo_state_dict[key_] + new_weight = model.state_dict()[key_] + new_weight[: loaded_weight.shape[0], : loaded_weight.shape[1]] = loaded_weight + nemo_state_dict[key_] = new_weight + + if 'mlp.linear_fc1_gate.weight' in key_: + key_gate = key_ + key_proj = key_.replace('mlp.linear_fc1_gate.weight', 'mlp.linear_fc1_proj.weight') + new_key = key_.replace('mlp.linear_fc1_gate.weight', 'mlp.linear_fc1.weight') + gate_weight = nemo_state_dict[key_gate] + proj_weight = nemo_state_dict[key_proj] + nemo_state_dict[new_key] = torch.cat((gate_weight, proj_weight)) + del nemo_state_dict[key_gate], nemo_state_dict[key_proj] + + if 'self_attention.linear_q.weight' in key_: + key_q = key_ + key_k = key_.replace('linear_q', 'linear_k') + key_v = key_.replace('linear_q', 'linear_v') + key_qkv = key_.replace('linear_q', 'linear_qkv') + + # [(head_num + 2 * num_query_groups) * head_size, hidden_size] + # -> [head_num, head_size, hidden_size], 2 * [num_query_groups, head_size, hidden_size] + q_weight, k_weight, v_weight = nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + q_weight = q_weight.reshape(head_num, head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups, head_size, hidden_size) + + qkv_weight = torch.empty((0, head_size, hidden_size), device=q_weight.device) + for i in range(num_query_groups): + qkv_weight = torch.cat((qkv_weight, q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weight = torch.cat((qkv_weight, k_weight[i : i + 1, :, :])) + qkv_weight = torch.cat((qkv_weight, v_weight[i : i + 1, :, :])) + qkv_weight = qkv_weight.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + nemo_state_dict[key_qkv] = qkv_weight + del nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.mm_cfg.mm_mlp_adapter_type = "mlp2x_gelu" + if ref_config["vision_config"].image_size == 336: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14-336" + model_config.data.image_token_len = 576 + else: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14" + model_config.data.image_token_len = 256 + + ref_config = ref_config['text_config'].__dict__ + model_config["encoder_seq_length"] = ref_config["max_position_embeddings"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["num_query_groups"] = ref_config["num_key_value_heads"] + model_config["layernorm_epsilon"] = ref_config["rms_norm_eps"] + model_config["init_method_std"] = ref_config["initializer_range"] + model_config["kv_channels"] = ref_config.get( + "head_dim", model_config["hidden_size"] // model_config["num_attention_heads"] + ) + if ref_config.get("rope_scaling") is not None: + if ref_config["rope_scaling"]["type"] == "linear": + model_config["seq_len_interpolation_factor"] = ref_config["rope_scaling"]["factor"] + else: + raise ValueError("Only linear rope scaling type is supported now") + 
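    # Note (added for clarity; not in the original patch): CPU initialization keeps the
    # randomly initialized Megatron weights in host memory; they are overwritten by the
    # converted HF weights right after model construction, so GPU allocation is unnecessary.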
model_config["use_cpu_initialization"] = True + + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str) + parser.add_argument("--tokenizer_path", type=str) + parser.add_argument("--conv_template", default="v1", type=str) + parser.add_argument( + "--hparams_file", + type=str, + default=os.path.join( + os.path.dirname(__file__), '../../examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml' + ), + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, help="Path to output .nemo file.") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weight saved" + ) + parser.add_argument("--skip_verification", action="store_true") + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF Llava: `{args.input_name_or_path}`") + hf_tokenizer = LlamaTokenizer.from_pretrained(args.input_name_or_path) + hf_model = LlavaForConditionalGeneration.from_pretrained(args.input_name_or_path) + logging.info("HF Model loading done.") + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.model = adjust_nemo_config(nemo_config.model, hf_model.config.__dict__) + nemo_config.model.data["conv_template"] = args.conv_template + nemo_config.model.mm_cfg.llm["model_type"] = args.conv_template + nemo_config.model.tokenizer["model"] = args.tokenizer_path + + nemo_config.trainer["precision"] = args.precision + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronNevaModel(nemo_config.model, trainer) + + rename_keys = create_rename_keys(nemo_config.model.num_layers) + old_state_dict = hf_model.state_dict() + new_state_dict = rename_model_keys(model_state_dict=old_state_dict, rename_keys=rename_keys) + + nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) + model.load_state_dict(nemo_state_dict, strict=False) + + logging.info(f'=' * 100) + if not args.skip_verification: + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + ] + logging.info(f"Running verifications {input_texts} ...") + + # Tokenize the input texts + hf_tokenizer.pad_token = hf_tokenizer.eos_token + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + + hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True) + ids = batch_dict_cuda['input_ids'] + + id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids.cpu()] + + masks_and_position_ids = [ + get_ltor_masks_and_position_ids(id_tensor, hf_tokenizer.eos_token, False, False, False) + for id_tensor in id_tensors + ] + for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids): + attn_mask, _, pos_ids = attn_mask_and_pos_ids + + outputs = model( + tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None + ) + + hf_next_token = hf_outputs.logits[0, -1].argmax() + next_token = outputs.squeeze()[-1].argmax() + + logging.info(f"HF predicted next token is: '{hf_tokenizer._convert_id_to_token(int(hf_next_token))}'.") + logging.info(f"NeMo 
predicted next token is: '{hf_tokenizer._convert_id_to_token(int(next_token))}'.") + assert ( + hf_next_token == next_token + ), f'prediction mismatch: {hf_tokenizer.decode(hf_next_token)} != {hf_tokenizer.decode(next_token)}' + logging.info(f'=' * 100) + + dtype = torch_dtype_from_precision(args.precision) + model = model.to(dtype=dtype) + model.save_to(args.output_path) + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py new file mode 100644 index 000000000000..430a74567ec2 --- /dev/null +++ b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py @@ -0,0 +1,337 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_hf_to_nemo.py \ + --input_name_or_path /path/to/llava-v1.5-7b.nemo \ + --hf_input_path llava-hf/llava-1.5-7b-hf \ + --hf_output_path=/path/to/hf_updated_checkpoint +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from transformers import LlamaTokenizer, LlavaForConditionalGeneration + +from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import logging + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # Attention layers + rename_keys.extend( + [ + ( + f"language_model.model.layers.{i}.self_attn.o_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.q_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.k_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.v_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_v.weight", + ), + # MLP and LayerNorm + ( + f"language_model.model.layers.{i}.mlp.gate_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_gate.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.up_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_proj.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.down_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"language_model.model.layers.{i}.input_layernorm.weight", + f"model.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"language_model.model.layers.{i}.post_attention_layernorm.weight", + 
f"model.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ] + ) + + rename_keys.extend( + [ + ( + "multi_modal_projector.linear_1.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.weight", + ), + ( + "multi_modal_projector.linear_1.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.bias", + ), + ( + "multi_modal_projector.linear_2.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.weight", + ), + ( + "multi_modal_projector.linear_2.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.bias", + ), + ("language_model.model.embed_tokens.weight", "model.embedding.word_embeddings.weight"), + ("language_model.model.norm.weight", "model.decoder.final_layernorm.weight"), + ("language_model.lm_head.weight", "model.output_layer.weight"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for new_key, old_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def reverse_adjust_tensor_shapes(model, hf_model, nemo_state_dict): + """ + Reverse the tensor adjustments made in the state dictionary to retrieve the original model structure. + + Parameters: + model (torch.nn.Module): The model instance to reference the state dictionary. + nemo_state_dict (dict): The state dictionary containing the adjusted tensors. + + Returns: + dict: The updated state dictionary with original tensor shapes and structures. 
+ """ + model_config = model.cfg + num_query_groups = model_config["num_query_groups"] + head_num = model_config["num_attention_heads"] + hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] + if head_size is None: + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + vocab_size = hf_model.config.vocab_size + + for key_ in list(nemo_state_dict.keys()): + if 'word_embeddings.weight' in key_ or 'output_layer.weight' in key_: + # Reverse padding + loaded_weight = model.state_dict()[key_] + nemo_state_dict[key_] = loaded_weight[:vocab_size] + + if 'mlp.linear_fc1.weight' in key_: + new_key_gate = key_.replace('mlp.linear_fc1.weight', 'mlp.linear_fc1_gate.weight') + new_key_proj = key_.replace('mlp.linear_fc1.weight', 'mlp.linear_fc1_proj.weight') + + # Split concatenated gate and projection weights + combined_weight = nemo_state_dict[key_] + gate_weight, proj_weight = torch.chunk(combined_weight, 2, dim=0) + nemo_state_dict[new_key_gate] = gate_weight + nemo_state_dict[new_key_proj] = proj_weight + del nemo_state_dict[key_] + + if 'self_attention.linear_qkv.weight' in key_: + key_qkv = key_ + key_q = key_qkv.replace('linear_qkv', 'linear_q') + key_k = key_qkv.replace('linear_qkv', 'linear_k') + key_v = key_qkv.replace('linear_qkv', 'linear_v') + qkv_weight = nemo_state_dict[key_qkv].reshape(-1, head_size, hidden_size) + q_weight = torch.empty((head_num, head_size, hidden_size), device=qkv_weight.device) + k_weight = torch.empty((num_query_groups, head_size, hidden_size), device=qkv_weight.device) + v_weight = torch.empty((num_query_groups, head_size, hidden_size), device=qkv_weight.device) + + qkv_index = 0 + for i in range(num_query_groups): + q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[ + qkv_index : qkv_index + heads_per_group, :, : + ] + qkv_index += heads_per_group + k_weight[i, :, :] = qkv_weight[qkv_index, :, :] + qkv_index += 1 + v_weight[i, :, :] = qkv_weight[qkv_index, :, :] + qkv_index += 1 + + nemo_state_dict[key_q] = q_weight.reshape(head_num * head_size, hidden_size) + nemo_state_dict[key_k] = k_weight.reshape(num_query_groups * head_size, hidden_size) + nemo_state_dict[key_v] = v_weight.reshape(num_query_groups * head_size, hidden_size) + + del nemo_state_dict[key_qkv] + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.mm_cfg.mm_mlp_adapter_type = "mlp2x_gelu" + if ref_config["vision_config"].image_size == 336: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14-336" + model_config.data.image_token_len = 576 + else: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14" + model_config.data.image_token_len = 256 + + ref_config = ref_config['text_config'].__dict__ + model_config["encoder_seq_length"] = ref_config["max_position_embeddings"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["num_query_groups"] = ref_config["num_key_value_heads"] + model_config["layernorm_epsilon"] = ref_config["rms_norm_eps"] + model_config["init_method_std"] = ref_config["initializer_range"] + model_config["kv_channels"] = ref_config.get( + "head_dim", model_config["hidden_size"] // model_config["num_attention_heads"] + ) + if ref_config.get("rope_scaling") is not 
None: + if ref_config["rope_scaling"]["type"] == "linear": + model_config["seq_len_interpolation_factor"] = ref_config["rope_scaling"]["factor"] + else: + raise ValueError("Only linear rope scaling type is supported now") + model_config["use_cpu_initialization"] = True + + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file or extracted folder", + ) + parser.add_argument( + "--hf_input_path", + type=str, + default=None, + help="A HF model path, " "e.g. a folder containing https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main", + ) + parser.add_argument( + "--hf_output_path", + type=str, + default=None, + help="Output HF model path, " "with the same format as above but user's own weights", + ) + parser.add_argument("--skip_verification", action="store_true") + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF Llava: `{args.hf_input_path}`") + hf_tokenizer = LlamaTokenizer.from_pretrained(args.hf_input_path) + hf_model = LlavaForConditionalGeneration.from_pretrained(args.hf_input_path) + logging.info("HF Model loading done.") + + nemo_config = OmegaConf.load( + os.path.join(os.path.dirname(__file__), '../../examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml') + ) + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronNevaModel.restore_from( + restore_path=args.input_name_or_path, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), + ) + + rename_keys = create_rename_keys(model.cfg.num_layers) + old_state_dict = model.state_dict() + nemo_state_dict = reverse_adjust_tensor_shapes(model, hf_model, old_state_dict) + hf_state_dict = rename_model_keys(model_state_dict=nemo_state_dict, rename_keys=rename_keys) + + hf_model.load_state_dict(hf_state_dict, strict=False) + + logging.info(f'=' * 100) + if not args.skip_verification: + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + ] + logging.info(f"Running verifications {input_texts} ...") + + # Tokenize the input texts + hf_tokenizer.pad_token = hf_tokenizer.eos_token + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + + hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True) + ids = batch_dict_cuda['input_ids'] + + id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids.cpu()] + + masks_and_position_ids = [ + get_ltor_masks_and_position_ids(id_tensor, hf_tokenizer.eos_token, False, False, False) + for id_tensor in id_tensors + ] + for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids): + attn_mask, _, pos_ids = attn_mask_and_pos_ids + + outputs = model( + tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None + ) + + hf_next_token = hf_outputs.logits[0, -1].argmax() + next_token = outputs.squeeze()[-1].argmax() + + logging.info(f"HF predicted next token is: '{hf_tokenizer._convert_id_to_token(int(hf_next_token))}'.") + logging.info(f"NeMo predicted next token is: '{hf_tokenizer._convert_id_to_token(int(next_token))}'.") + assert ( + hf_next_token == next_token + ), f'prediction mismatch: {hf_tokenizer.decode(hf_next_token)} != {hf_tokenizer.decode(next_token)}' + 
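        # Note (added for clarity; not in the original patch): this check compares only the
        # greedy next token for a single text-only prompt, so it validates the language-model
        # weight mapping rather than the full image-conditioned path.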
+        logging.info(f'=' * 100)
+
+    hf_model.save_pretrained(args.hf_output_path)
+    logging.info(f"Full HF model saved to {args.hf_output_path}")
+
+
+if __name__ == '__main__':
+    args = get_args()
+    convert(args)
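For readers of this patch, the least obvious step in both converters is the grouped-query QKV packing performed by adjust_tensor_shapes and undone by reverse_adjust_tensor_shapes. Below is a small self-contained sketch (toy sizes, plain PyTorch, not part of the patch) of that interleave-and-split round trip:

import torch

# Toy grouped-query-attention sizes (illustrative only).
head_num, num_query_groups, head_size, hidden_size = 8, 2, 4, 16
heads_per_group = head_num // num_query_groups

q = torch.randn(head_num * head_size, hidden_size)
k = torch.randn(num_query_groups * head_size, hidden_size)
v = torch.randn(num_query_groups * head_size, hidden_size)

# Pack (HF -> NeMo): interleave each group's query heads with its shared K and V head.
q3 = q.reshape(head_num, head_size, hidden_size)
k3 = k.reshape(num_query_groups, head_size, hidden_size)
v3 = v.reshape(num_query_groups, head_size, hidden_size)
chunks = []
for g in range(num_query_groups):
    chunks.append(q3[g * heads_per_group : (g + 1) * heads_per_group])
    chunks.append(k3[g : g + 1])
    chunks.append(v3[g : g + 1])
qkv = torch.cat(chunks).reshape((head_num + 2 * num_query_groups) * head_size, hidden_size)

# Unpack (NeMo -> HF): walk the interleaved layout back out, group by group.
qkv3 = qkv.reshape(-1, head_size, hidden_size)
q_rows, k_rows, v_rows = [], [], []
idx = 0
for g in range(num_query_groups):
    q_rows.append(qkv3[idx : idx + heads_per_group])
    idx += heads_per_group
    k_rows.append(qkv3[idx : idx + 1])
    idx += 1
    v_rows.append(qkv3[idx : idx + 1])
    idx += 1

assert torch.equal(torch.cat(q_rows).reshape_as(q), q)
assert torch.equal(torch.cat(k_rows).reshape_as(k), k)
assert torch.equal(torch.cat(v_rows).reshape_as(v), v)
print("QKV pack/unpack round trip OK")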