diff --git a/nemo/collections/llm/gpt/model/hyena.py b/nemo/collections/llm/gpt/model/hyena.py
index 17b1dc66b574..23d93a8a1328 100644
--- a/nemo/collections/llm/gpt/model/hyena.py
+++ b/nemo/collections/llm/gpt/model/hyena.py
@@ -322,9 +322,6 @@ class HyenaTestConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    # fp8: str = 'hybrid'
-    # fp8_amax_history_len: int = 16
-    # fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 2
@@ -356,9 +353,6 @@ class Hyena7bConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    fp8: str = 'hybrid'
-    fp8_amax_history_len: int = 16
-    fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 4
@@ -389,9 +383,6 @@ class Hyena40bConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    fp8: str = 'hybrid'
-    fp8_amax_history_len: int = 16
-    fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 2
diff --git a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py
index 1d50fc961576..abc7d4740ff4 100644
--- a/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py
+++ b/nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py
@@ -21,6 +21,8 @@
 from megatron.core.utils import make_viewless_tensor
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from contextlib import nullcontext
+from megatron.core.packed_seq_params import PackedSeqParams
+
 try:
     from megatron.core.extensions.transformer_engine import (
         TEDelayedScaling,
@@ -184,7 +186,11 @@ def _checkpointed_forward(
 
         def custom(start: int, end: int):
             def custom_forward(
-                hidden_states, attention_mask, rotary_pos_emb
+                hidden_states,
+                attention_mask,
+                context,
+                context_mask,
+                rotary_pos_emb
             ):
                 for index in range(start, end):
                     layer = self._get_layer(index)
@@ -210,6 +216,8 @@ def checkpoint_handler(forward_func):
                     parallel_state.get_tensor_model_parallel_group(),
                     hidden_states,
                     attention_mask,
+                    None,
+                    None,
                     rotary_pos_emb,
                 )
             else:
@@ -218,6 +226,8 @@ def checkpoint_handler(forward_func):
                     self.config.distribute_saved_activations,
                     hidden_states,
                     attention_mask,
+                    None,
+                    None,
                     rotary_pos_emb,
                 )
 
diff --git a/tests/collections/llm/gpt/model/test_hyena.py b/tests/collections/llm/gpt/model/test_hyena.py
index abbb8c085a5f..e5342333e8c6 100644
--- a/tests/collections/llm/gpt/model/test_hyena.py
+++ b/tests/collections/llm/gpt/model/test_hyena.py
@@ -49,8 +49,9 @@
     --pipeline-model-parallel-size=1 \
     --context-parallel-size=1 \
     --global-batch-size=16 \
-    --micro-batch-size=2 \
-    --model-size=7b
+    --micro-batch-size=1 \
+    --model-size=7b \
+    --fp8
 """
 
 def get_args():
@@ -65,6 +66,9 @@ def get_args():
     parser.add_argument(
         "--sequence-parallel", action="store_true", help="Set to enable sequence parallel"
     )
+    parser.add_argument(
+        "--fp8", action="store_true", help="Set to enable FP8"
+    )
     parser.add_argument('--micro-batch-size', type=int, default=1, help="Pipeline Parallel Size")
     parser.add_argument('--global-batch-size', type=int, default=8, help="Pipeline Parallel Size")
     parser.add_argument('--max-steps', type=int, help="Number of steps to train for")
@@ -150,7 +154,7 @@ def get_args():
         context_parallel_size=args.context_parallel_size,
         pipeline_dtype=torch.bfloat16,
         sequence_parallel=args.sequence_parallel,
-        ckpt_load_optimizer=True,
+        ckpt_load_optimizer=False,
         ckpt_save_optimizer=True,
         ckpt_async_save=False,
         save_ckpt_format='zarr',
@@ -163,6 +167,9 @@ def get_args():
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
             params_dtype=torch.bfloat16,
+            fp8='hybrid' if args.fp8 else None,
+            fp8_amax_history_len=16 if args.fp8 else 1,
+            fp8_amax_compute_algo="max" if args.fp8 else "most_recent",
         ),
         val_check_interval=args.val_check_interval,
     )
@@ -185,7 +192,7 @@ def get_args():
         RestoreConfig(
             path=args.ckpt_dir,
             load_model_state = True,
-            load_optim_state = True,
+            load_optim_state = False,
         )
         if args.ckpt_dir
         else None
     ),
 )
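For context on the new plugin arguments: fp8='hybrid', fp8_amax_history_len, and fp8_amax_compute_algo select TransformerEngine's delayed-scaling FP8 recipe. A minimal sketch of the equivalent recipe object, assuming TransformerEngine's public recipe API (illustrative only, not code from this PR):

# Sketch: the TE delayed-scaling recipe that the plugin arguments map onto.
from transformer_engine.common.recipe import DelayedScaling, Format

fp8_recipe = DelayedScaling(
    fp8_format=Format.HYBRID,  # 'hybrid': E4M3 in the forward pass, E5M2 in the backward pass
    amax_history_len=16,       # window of recorded amax values per tensor
    amax_compute_algo="max",   # derive scaling factors from the max over that window
)

When --fp8 is not passed, the test hands the plugin neutral values (fp8=None, history length 1, "most_recent"), so the existing bf16-mixed behavior is unchanged.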