diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py
index 79db37b398cb..389dc381abd8 100644
--- a/nemo/collections/llm/gpt/model/hyena.py
+++ b/nemo/collections/llm/gpt/model/hyena.py
@@ -23,7 +23,7 @@
 try:
     from megatron.core import parallel_state
     from megatron.core.models.hyena import HyenaModel as MCoreHyenaModel
-    from megatron.core.models.hyena.hyena_layer_specs import hyena_stack_spec
+    from megatron.core.models.hyena.hyena_layer_specs import hyena_stack_spec, hyena_stack_spec_no_te
     from megatron.core.ssm.hyena_utils import hyena_no_weight_decay_cond
 
     HAVE_MEGATRON_CORE_OR_TE = True
@@ -42,14 +42,13 @@
 
 
 def hyena_forward_step(model, batch) -> torch.Tensor:
-
     forward_args = {
         "input_ids": batch["tokens"],
         "position_ids": batch["position_ids"],
         "labels": batch["labels"],
         "loss_mask": batch["loss_mask"],
+        "attention_mask": None
     }
-    forward_args["attention_mask"] = None
     return model(**forward_args)
 
 
@@ -90,6 +89,9 @@ class HyenaConfig(TransformerConfig, io.IOMixin):
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 4
+    fp8: str = 'hybrid'
+    fp8_amax_history_len: int = 16
+    fp8_amax_compute_algo: str = "max"
     forward_step_fn: Callable = hyena_forward_step
     data_step_fn: Callable = gpt_data_step
     tokenizer_model_path: str = None
@@ -135,7 +137,7 @@ def init(self) -> GPTModel:
         return GPTModel(self.config, tokenizer=self.tokenizer)
 
-    def apply(self, output_path: Path) -> Path:
+    def apply(self, output_path: Path, te_enabled=True) -> Path:
         source = torch.load(str(self), map_location='cpu')
 
         if 'model' in source:
@@ -187,7 +189,7 @@ def transform_source_dict(self, source):
         trainer = self.nemo_setup(target, ckpt_async_save=False, save_ckpt_format='zarr')
         source.to(self.config.params_dtype)
         target.to(self.config.params_dtype)
-        self.convert_state(source, target)
+        self.convert_state(source, target, te_enabled)
         self.nemo_save(output_path, trainer)
 
         logging.info(f"Converted Hyena model to Nemo, model saved to {output_path}")
@@ -197,33 +199,31 @@ def transform_source_dict(self, source):
 
         return output_path
 
-    def convert_state(self, source, target):
+    def convert_state(self, source, target, te_enabled=True):
         mapping = {}
-        te_enabled = True
-        scale_or_weight = 'weight'
         mapping['sequential.0.word_embeddings.weight'] = 'embedding.word_embeddings.weight'
-        mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.{scale_or_weight}'] = (
+        mapping[f'sequential.{len(self.config.hybrid_override_pattern)}.norm.weight'] = (
             'decoder.final_norm.weight'
         )
         for i, symbol in enumerate(self.config.hybrid_override_pattern):
             if te_enabled:
-                mapping[f'sequential.{i}.pre_mlp_layernorm.{scale_or_weight}'] = (
+                mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                     f'decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight'
                 )
             else:
-                mapping[f'sequential.{i}.pre_mlp_layernorm.{scale_or_weight}'] = (
+                mapping[f'sequential.{i}.pre_mlp_layernorm.weight'] = (
                     f'decoder.layers.{i}.pre_mlp_layernorm.weight'
                 )
             mapping[f'sequential.{i}.mlp.w3.weight'] = f'decoder.layers.{i}.mlp.linear_fc2.weight'
             if symbol != '*':
                 if te_enabled:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.mixer.dense_projection.layer_norm_weight'
                     )
                 else:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = f'decoder.layers.{i}.norm.weight'
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = f'decoder.layers.{i}.norm.weight'
                 mapping[f'sequential.{i}.mixer.dense_projection.weight'] = (
                     f'decoder.layers.{i}.mixer.dense_projection.weight'
@@ -256,11 +256,11 @@ def convert_state(self, source, target):
 
             elif symbol == '*':
                 if te_enabled:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight'
                     )
                 else:
-                    mapping[f'sequential.{i}.input_layernorm.{scale_or_weight}'] = (
+                    mapping[f'sequential.{i}.input_layernorm.weight'] = (
                         f'decoder.layers.{i}.input_layernorm.weight'
                     )
diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py
index 9eb2c4c6f6f3..eef4a613aa76 100644
--- a/tests/collections/llm/gpt/model/test_hyena.py
+++ b/tests/collections/llm/gpt/model/test_hyena.py
@@ -32,15 +32,15 @@
 from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule
 from nemo.lightning.pytorch.optim import CosineAnnealingScheduler, WarmupHoldPolicyScheduler
 
+# --ckpt-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp/small_ckpt \
 """
 CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nproc-per-node=8 /opt/NeMo/tests/collections/llm/gpt/model/test_hyena.py \
     --num-nodes=1 \
     --devices=8 \
-    --max-steps=50000 \
-    --val-check-interval=10 \
-    --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp \
-    --ckpt-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp \
+    --max-steps=500000 \
+    --val-check-interval=200 \
+    --experiment-dir=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/checkpoints/hyena_exp2 \
     --data-path=/lustre/fsw/coreai_dlalgo_genai/ataghibakhsh/datasets/hyena_data/hg38/pretraining_data_hg38/data_hg38_all_text_CharLevelTokenizer_document \
     --seq-length=8192 \
     --tensor-parallel-size=1 \
@@ -147,15 +147,15 @@ def get_args():
             context_parallel_size=args.context_parallel_size,
             pipeline_dtype=torch.bfloat16,
             sequence_parallel=args.sequence_parallel,
-            ckpt_load_optimizer=True,
-            ckpt_save_optimizer=True,
+            ckpt_load_optimizer=False,
+            ckpt_save_optimizer=False,
             ckpt_async_save=False,
             save_ckpt_format='zarr',
         ),
         logger=loggers,
         callbacks = [checkpoint_callback],
         log_every_n_steps=1,
-        limit_val_batches=10,
+        limit_val_batches=100,
         num_sanity_val_steps=0,
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
@@ -174,17 +174,17 @@ def get_args():
     from nemo.lightning.pytorch.strategies.utils import RestoreConfig
 
     resume = nl.AutoResume(
-        resume_if_exists=True,
+        resume_if_exists=False,
         resume_ignore_no_checkpoint=True,
         resume_past_end=True,
         resume_from_directory=args.ckpt_dir,
-        # restore_config=(
-        #     RestoreConfig(
-        #         path=args.ckpt_dir,
-        #         load_model_state = True,
-        #         load_optim_state = True,
-        #     ) if args.ckpt_dir else None
-        # ),
+        restore_config=(
+            RestoreConfig(
+                path=args.ckpt_dir,
+                load_model_state = True,
+                load_optim_state = False,
+            ) if args.ckpt_dir else None
+        ),
     )
     resume.setup(trainer, model)
@@ -192,6 +192,7 @@ def get_args():
     opt_config = OptimizerConfig(
         optimizer='adam',
         lr=0.0003,
+        weight_decay=0.1,
         adam_beta1=0.9,
         adam_beta2=0.95,
         use_distributed_optimizer=True,
@@ -203,7 +204,8 @@ def get_args():
         min_lr=0.00003,
     )
 
-    opt = MegatronOptimizerModule(opt_config, sched)
+    opt = MegatronOptimizerModule(config=opt_config, no_weight_decay_cond=hyena_config.hyena_no_weight_decay_cond_fn, lr_scheduler=sched)
+    # opt = MegatronOptimizerModule(opt_config, sched)
     opt.connect(model)
 
     # Start training
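A note on the optimizer change in the last hunk: `MegatronOptimizerModule` is now constructed with an explicit `no_weight_decay_cond` (`hyena_config.hyena_no_weight_decay_cond_fn`, backed by `hyena_no_weight_decay_cond` newly imported from `megatron.core.ssm.hyena_utils`), so the newly enabled `weight_decay=0.1` is only applied to parameters the Hyena-specific rule allows. The sketch below shows the general shape of such a condition callable; the function name and the predicate body are illustrative assumptions, not the Megatron-Core implementation.

```python
import torch


def example_no_weight_decay_cond(name: str, param: torch.Tensor) -> bool:
    """Illustrative predicate: True for parameters that should skip weight decay.

    Megatron-style optimizers call the condition as cond(name, param) while
    building parameter groups, and parameters for which it returns True end up
    in a zero-weight-decay group. The real Hyena condition also exempts the
    mixer's filter/convolution parameters, which this sketch does not reproduce.
    """
    # Common convention: biases and 1-D tensors (layernorm / RMSNorm scales)
    # are trained without weight decay.
    return name.endswith(".bias") or param.ndim == 1
```

Wired up the same way as in the diff, this would read `MegatronOptimizerModule(config=opt_config, no_weight_decay_cond=example_no_weight_decay_cond, lr_scheduler=sched)`.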