Fix baichuan export (#11640)
* Fix baichuan export

Signed-off-by: Chen Cui <[email protected]>

* update import

Signed-off-by: Chen Cui <[email protected]>

---------

Signed-off-by: Chen Cui <[email protected]>
cuichenx authored Dec 20, 2024
1 parent ce88a09 commit ec6df08
Showing 1 changed file with 12 additions and 22 deletions.
nemo/collections/llm/gpt/model/baichuan.py (34 changes: 12 additions & 22 deletions)
@@ -20,13 +20,13 @@
 import torch.nn.functional as F
 from torch import nn
 
-from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel
+from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel, torch_dtype_from_mcore_config
 from nemo.collections.llm.utils import Config
 from nemo.lightning import OptimizerModule, io, teardown
 from nemo.lightning.pytorch.utils import dtype_from_hf
 
 if TYPE_CHECKING:
-    from transformers import AutoConfig, AutoModelForCausalLM
+    from transformers import AutoModelForCausalLM
 
     from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
     from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
@@ -142,16 +142,23 @@ def make_vocab_size_divisible_by(vocab_size):
 
 @io.model_exporter(Baichuan2Model, "hf")
 class HFBaichuan2Exporter(io.ModelConnector[Baichuan2Model, "AutoModelForCausalLM"]):
-    def init(self) -> "AutoModelForCausalLM":
+    def init(self, dtype=torch.bfloat16, model_name="baichuan-inc/Baichuan2-7B-Base") -> "AutoModelForCausalLM":
         from transformers import AutoModelForCausalLM
         from transformers.modeling_utils import no_init_weights
 
         with no_init_weights(True):
-            return AutoModelForCausalLM.from_config(self.config, trust_remote_code=True)
+            # Since Baichuan is not importable from transformers, we can only initialize the HF model
+            # from a known checkpoint. If more than 1 Baichuan model is supported in NeMo in the future,
+            # the model_name will need to be passed in.
+            return AutoModelForCausalLM.from_pretrained(
+                model_name,
+                trust_remote_code=True,
+                torch_dtype=dtype,
+            )
 
     def apply(self, output_path: Path) -> Path:
-        target = self.init()
         source, _ = self.nemo_load(str(self))
+        target = self.init(torch_dtype_from_mcore_config(source.config))
         target = self.convert_state(source, target)
 
         target = target.cpu()
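The dtype handed to init now comes from torch_dtype_from_mcore_config(source.config), so the exported HF weights follow the precision of the NeMo checkpoint rather than a hard-coded default. A rough sketch of what such a helper does, assuming the Megatron-Core config exposes the usual fp16/bf16 flags (the real helper lives in nemo.collections.llm.gpt.model.base and may differ in detail):

import torch

def torch_dtype_from_mcore_config_sketch(config) -> torch.dtype:
    # Approximation only: map the Megatron-Core precision flags to a torch dtype
    # so the HF model is instantiated in the same precision as the NeMo checkpoint.
    if getattr(config, "fp16", False):
        return torch.float16
    if getattr(config, "bf16", False):
        return torch.bfloat16
    return torch.float32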
@@ -177,23 +184,6 @@ def convert_state(self, source, target):
     def tokenizer(self):
         return io.load_context(str(self)).model.tokenizer.tokenizer
 
-    @property
-    def config(self) -> "AutoConfig":
-        source: Baichuan2Config = io.load_context(str(self)).model.config
-
-        return AutoConfig(
-            num_hidden_layers=source.num_layers,
-            hidden_size=source.hidden_size,
-            intermediate_size=source.ffn_hidden_size,
-            num_attention_heads=source.num_attention_heads,
-            max_position_embeddings=source.seq_length,
-            initializer_range=source.init_method_std,
-            rms_norm_eps=source.layernorm_epsilon,
-            num_key_value_heads=source.num_query_groups,
-            rope_theta=source.rotary_base,
-            vocab_size=self.tokenizer.vocab_size,
-        )
-
 
 @io.state_transform(
     source_key="model.layers.*.self_attn.W_pack.weight",
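For completeness, a minimal sketch of how this exporter is typically exercised from NeMo 2.0. The checkpoint paths are placeholders, and it is assumed that llm.export_ckpt with target="hf" dispatches to the io.model_exporter registered above:

from pathlib import Path

from nemo.collections import llm

# Placeholder paths; point these at a real trained Baichuan2 NeMo checkpoint.
nemo_ckpt = Path("/checkpoints/baichuan2_7b_nemo")
hf_out = Path("/checkpoints/baichuan2_7b_hf")

# target="hf" selects the exporter registered via @io.model_exporter(Baichuan2Model, "hf").
llm.export_ckpt(path=nemo_ckpt, target="hf", output_path=hf_out)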
