minor fixes
Ali Taghibakhshi committed Dec 7, 2024
1 parent 0935784 commit 5faf46d
Showing 2 changed files with 33 additions and 31 deletions.
30 changes: 15 additions & 15 deletions nemo/collections/llm/gpt/model/hyena.py
@@ -23,7 +23,7 @@
 try:
     from megatron.core import parallel_state
     from megatron.core.models.hyena import HyenaModel as MCoreHyenaModel
-    from megatron.core.models.hyena.hyena_layer_specs import hyena_stack_spec
+    from megatron.core.models.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te
     from megatron.core.ssm.hyena_utils import hyena_no_weight_decay_cond

     HAVE_MEGATRON_CORE_OR_TE = True
@@ -42,14 +42,13 @@


 def hyena_forward_step(model, batch) -> torch.Tensor:
-
     forward_args = {
         "input_ids": batch["tokens"],
         "position_ids": batch["position_ids"],
         "labels": batch["labels"],
         "loss_mask": batch["loss_mask"],
+        "attention_mask": None
     }
-    forward_args["attention_mask"] = None
     return model(**forward_args)


@@ -90,6 +89,9 @@ class HyenaConfig(TransformerConfig, io.IOMixin):
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 4
+    fp8: str = 'hybrid'
+    fp8_amax_history_len: int = 16
+    fp8_amax_compute_algo: str = "max"
     forward_step_fn: Callable = hyena_forward_step
     data_step_fn: Callable = gpt_data_step
     tokenizer_model_path: str = None
@@ -135,7 +137,7 @@ def init(self) -> GPTModel:

         return GPTModel(self.config, tokenizer=self.tokenizer)

-    def apply(self, output_path: Path) -> Path:
+    def apply(self, output_path: Path, te_enabled=True) -> Path:

         source = torch.load(str(self), map_location='cpu')
         if 'model' in source:
@@ -187,7 +189,7 @@ def transform_source_dict(self, source):
         trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format='zarr')
         source.to(self.config.params_dtype)
         target.to(self.config.params_dtype)
-        self.convert_state(source, target)
+        self.convert_state(source, target, te_enabled)
         self.nemo_save(output_path, trainer)

         logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")
@@ -197,33 +199,31 @@ def transform_source_dict(self, source):

         return output_path

-    def convert_state(self, source, target):
+    def convert_state(self, source, target, te_enabled=True):

         mapping = {}
-        te_enabled = True
-        scale_or_weight = 'weight'
         mapping['sequential.0.word_embeddings.weight'] = 'embedding.word_embeddings.weight'
-        mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.{scale_or_weight}'] = (
+        mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.weight'] = (
             'decoder.final_norm.weight'
         )
         for i, symbol in enumerate(self.config.hybrid_override_pattern):
             if te_enabled:
-                mapping[f'sequential.{i}.pre_mlp_layernorm.{scale_or_weight}'] = (
+                mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                     f'decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight'
                 )
             else:
-                mapping[f'sequential.{i}.pre_mlp_layernorm.{scale_or_weight}'] = (
+                mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                     f'decoder.layers.{i}.pre_mlp_layernorm.weight'
                 )
             mapping[f'sequential.{i}.mlp.w3.weight'] = f'decoder.layers.{i}.mlp.linear_fc2.weight'

             if symbol != '*':
                 if te_enabled:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.mixer.dense_projection.layer_norm_weight'
                     )
                 else:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = f'decoder.layers.{i}.norm.weight'
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.norm.weight'

                 mapping[f'sequential.{i}.mixer.dense_projection.weight'] = (
                     f'decoder.layers.{i}.mixer.dense_projection.weight'
@@ -256,11 +256,11 @@ def convert_state(self, source, target):

             elif symbol == '*':
                 if te_enabled:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight'
                     )
                 else:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.input_layernorm.weight'
                     )

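The TE/non-TE dispatch added to the converter is easier to see outside the diff. Below is a minimal, self-contained sketch of that idea under the same naming scheme; the function name build_layernorm_mapping and the toy pattern 'SS*' are illustrative only, while the real convert_state lives on the importer class, reads the pattern from its config, and maps many more tensors.

# Simplified sketch (assumed helper, not part of the commit) of the te_enabled
# dispatch for layernorm weights. With Transformer Engine the norm weight is
# fused into the following linear module; without it the norm keeps its own name.
def build_layernorm_mapping(hybrid_override_pattern: str, te_enabled: bool = True) -> dict:
    mapping = {
        'sequential.0.word_embeddings.weight': 'embedding.word_embeddings.weight',
        f'sequential.{len(hybrid_override_pattern)}.norm.weight': 'decoder.final_norm.weight',
    }
    for i, symbol in enumerate(hybrid_override_pattern):
        if te_enabled:
            mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                f'decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight'
            )
        else:
            mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                f'decoder.layers.{i}.pre_mlp_layernorm.weight'
            )
        if symbol == '*':  # attention layer
            target = (
                f'decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight'
                if te_enabled
                else f'decoder.layers.{i}.input_layernorm.weight'
            )
        else:  # hyena mixer layer
            target = (
                f'decoder.layers.{i}.mixer.dense_projection.layer_norm_weight'
                if te_enabled
                else f'decoder.layers.{i}.norm.weight'
            )
        mapping[f'sequential.{i}.input_layernorm.weight'] = target
    return mapping


# Example with a toy pattern: two hyena blocks followed by one attention block.
print(build_layernorm_mapping('SS*', te_enabled=False))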
34 changes: 18 additions & 16 deletions tests/collections/llm/gpt/model/test_hyena.py
@@ -32,15 +32,15 @@
 from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, WarmupHoldPolicyScheduler

+# --ckpt-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp/small_ckpt \

 """
 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc-per-node=8 /opt/NeMo/tests/collections/llm/gpt/model/test_hyena.py \
     --num-nodes=1 \
     --devices=8 \
-    --max-steps=50000 \
-    --val-check-interval=10 \
-    --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp \
-    --ckpt-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp \
+    --max-steps=500000 \
+    --val-check-interval=200 \
+    --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp2 \
     --data-path=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/datasets/hyena_data/hg38/pretraining_data_hg38/data_hg38_all_text_CharLevelTokenizer_document \
     --seq-length=8192 \
     --tensor-parallel-size=1 \
@@ -147,15 +147,15 @@ def get_args():
             context_parallel_size=args.context_parallel_size,
             pipeline_dtype=torch.bfloat16,
             sequence_parallel=args.sequence_parallel,
-            ckpt_load_optimizer=True,
-            ckpt_save_optimizer=True,
+            ckpt_load_optimizer=False,
+            ckpt_save_optimizer=False,
             ckpt_async_save=False,
             save_ckpt_format='zarr',
         ),
         logger=loggers,
         callbacks = [checkpoint_callback],
         log_every_n_steps=1,
-        limit_val_batches=10,
+        limit_val_batches=100,
         num_sanity_val_steps=0,
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
@@ -174,24 +174,25 @@ def get_args():
     from nemo.lightning.pytorch.strategies.utils import RestoreConfig

     resume = nl.AutoResume(
-        resume_if_exists=True,
+        resume_if_exists=False,
         resume_ignore_no_checkpoint=True,
         resume_past_end=True,
         resume_from_directory=args.ckpt_dir,
-        # restore_config=(
-        #     RestoreConfig(
-        #         path=args.ckpt_dir,
-        #         load_model_state = True,
-        #         load_optim_state = True,
-        #     ) if args.ckpt_dir else None
-        # ),
+        restore_config=(
+            RestoreConfig(
+                path=args.ckpt_dir,
+                load_model_state = True,
+                load_optim_state = False,
+            ) if args.ckpt_dir else None
+        ),
     )
     resume.setup(trainer, model)

     # Optimizer and scheduler setup
     opt_config = OptimizerConfig(
         optimizer='adam',
         lr=0.0003,
+        weight_decay=0.1,
         adam_beta1=0.9,
         adam_beta2=0.95,
         use_distributed_optimizer=True,
@@ -203,7 +204,8 @@ def get_args():
         min_lr=0.00003,
     )

-    opt = MegatronOptimizerModule(opt_config, sched)
+    opt = MegatronOptimizerModule(config=opt_config, no_weight_decay_cond=hyena_config.hyena_no_weight_decay_cond_fn, lr_scheduler=sched)
+    # opt = MegatronOptimizerModule(opt_config, sched)
     opt.connect(model)

     # Start training
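For context, the new keyword-based optimizer wiring can be sketched on its own, as below. The scheduler values and the stand-in no_weight_decay_cond function are assumptions for illustration; the test itself passes hyena_config.hyena_no_weight_decay_cond_fn and its own scheduler settings.

# Sketch of the MegatronOptimizerModule wiring introduced above, using the same
# keyword arguments the commit switches to (config, no_weight_decay_cond, lr_scheduler).
from megatron.core.optimizer import OptimizerConfig
from nemo.lightning.pytorch.optim import CosineAnnealingScheduler
from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule


def no_weight_decay_cond(name, param):
    # Stand-in for hyena_config.hyena_no_weight_decay_cond_fn: skip weight decay
    # for biases and 1-D (norm-like) parameters.
    return name.endswith(".bias") or param.ndim == 1


opt_config = OptimizerConfig(
    optimizer='adam',
    lr=0.0003,
    weight_decay=0.1,
    adam_beta1=0.9,
    adam_beta2=0.95,
    use_distributed_optimizer=True,
)
sched = CosineAnnealingScheduler(
    max_steps=500000,   # placeholder values; the test derives these from its args
    warmup_steps=2500,
    min_lr=0.00003,
)
opt = MegatronOptimizerModule(
    config=opt_config,
    no_weight_decay_cond=no_weight_decay_cond,
    lr_scheduler=sched,
)
# opt.connect(model) then attaches the optimizer module to the LightningModule,
# as the test does right after constructing it.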
