Skip to content

Commit

Permalink
Update Gemma conversion script (NVIDIA#9365)
Browse files Browse the repository at this point in the history
* Update Gemma conversion script

Signed-off-by: yaoyu-33 <[email protected]>

* Apply isort and black reformatting

Signed-off-by: yaoyu-33 <[email protected]>

---------

Signed-off-by: yaoyu-33 <[email protected]>
Signed-off-by: yaoyu-33 <[email protected]>
Co-authored-by: yaoyu-33 <[email protected]>
  • Loading branch information
yaoyu-33 and yaoyu-33 authored Jun 3, 2024
1 parent bd014d9 commit a0488f6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
3 changes: 2 additions & 1 deletion scripts/checkpoint_converters/convert_gemma_jax_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""
Requires to install: `pip install orbax jax flax jaxlib`
Requires to clone: https://github.com/google-deepmind/gemma.git
Required to set: `export PYTHONPATH=/path/to/google/gemma_jax:$PYTHONPATH`
python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_jax_to_nemo.py \
--input_name_or_path /path/to/gemma/checkpoints/jax/7b \
Expand All @@ -27,8 +28,8 @@

import jax
import torch
from gemma.params import load_params, nest_params, param_remapper
from omegaconf import OmegaConf
from params import load_params, nest_params, param_remapper
from transformer import TransformerConfig

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
Expand Down
10 changes: 6 additions & 4 deletions scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""
Requires to install: `pip install fairscale==0.4.13 immutabledict==4.1.0 tensorstore==0.1.45`
Requires to clone: https://github.com/google/gemma_pytorch.git
Required to set: `export PYTHONPATH=/path/to/google/gemma_pytorch:$PYTHONPATH`
python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_pyt_to_nemo.py \
--input_name_or_path /path/to/gemma/checkpoints/pyt/7b.ckpt \
Expand All @@ -26,9 +27,9 @@
from argparse import ArgumentParser

import torch
from model.config import get_config_for_2b, get_config_for_7b
from model.model import CausalLM
from model.tokenizer import Tokenizer
from gemma.config import get_config_for_2b, get_config_for_7b
from gemma.model import CausalLM
from gemma.tokenizer import Tokenizer
from omegaconf import OmegaConf

from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
Expand Down Expand Up @@ -152,7 +153,8 @@ def adjust_tensor_shapes(model, nemo_state_dict):
# [(head_num + 2 * num_query_groups) * head_size, hidden_size]
# -> [head_num, head_size, hidden_size], 2 * [num_query_groups, head_size, hidden_size]
q_weight, k_weight, v_weight = qkv_weight.split(
[head_num * head_size, num_query_groups * head_size, num_query_groups * head_size], dim=0,
[head_num * head_size, num_query_groups * head_size, num_query_groups * head_size],
dim=0,
)
q_weight = q_weight.reshape(head_num, head_size, hidden_size)
k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size)
Expand Down

0 comments on commit a0488f6

Please sign in to comment.