Commit 5c894d0: fix fp8

Ali Taghibakhshi committed Dec 13, 2024
1 parent 0f47825
Showing 3 changed files with 22 additions and 14 deletions.
9 changes: 0 additions & 9 deletions nemo/collections/llm/gpt/model/hyena.py
@@ -322,9 +322,6 @@ class HyenaTestConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    # fp8: str = 'hybrid'
-    # fp8_amax_history_len: int = 16
-    # fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 2
@@ -356,9 +353,6 @@ class Hyena7bConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    fp8: str = 'hybrid'
-    fp8_amax_history_len: int = 16
-    fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 4
@@ -389,9 +383,6 @@ class Hyena40bConfig(HyenaConfig):
     add_qkv_bias: bool = False
     add_bias_linear: bool = False
     layernorm_epsilon: float = 1e-6
-    fp8: str = 'hybrid'
-    fp8_amax_history_len: int = 16
-    fp8_amax_compute_algo: str = "max"
     recompute_granularity: str = 'full'
     recompute_method: str = 'uniform'
     recompute_num_layers: int = 2
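
With these three hunks, none of the Hyena config dataclasses pins an FP8 recipe anymore; the recipe moves to the trainer-level precision plugin (see the test change below). As a reference, here is a minimal sketch of the plugin-side equivalent of the removed fields, assuming the `nl.MegatronMixedPrecision` keyword arguments this commit uses later; the `nl` import alias is an assumption based on NeMo convention:

```python
import torch
from nemo import lightning as nl  # alias as conventionally used in NeMo tests

# Sketch only: the delayed-scaling recipe the deleted config fields encoded,
# now expressed through the trainer-level mixed-precision plugin.
precision_plugin = nl.MegatronMixedPrecision(
    precision="bf16-mixed",
    params_dtype=torch.bfloat16,
    fp8="hybrid",                 # E4M3 forward tensors, E5M2 gradients
    fp8_amax_history_len=16,      # amax history window used for scaling
    fp8_amax_compute_algo="max",  # reduce the window with max()
)
```
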
12 changes: 11 additions & 1 deletion nemo/collections/llm/gpt/model/megatron/hyena/hyena_block.py
@@ -21,6 +21,8 @@
 from megatron.core.utils import make_viewless_tensor
 from megatron.core import InferenceParams, parallel_state, tensor_parallel
 from contextlib import nullcontext
+from megatron.core.packed_seq_params import PackedSeqParams
+
 try:
     from megatron.core.extensions.transformer_engine import (
         TEDelayedScaling,
@@ -184,7 +186,11 @@ def _checkpointed_forward(

         def custom(start: int, end: int):
             def custom_forward(
-                hidden_states, attention_mask, rotary_pos_emb
+                hidden_states,
+                attention_mask,
+                context,
+                context_mask,
+                rotary_pos_emb
             ):
                 for index in range(start, end):
                     layer = self._get_layer(index)
@@ -210,6 +216,8 @@ def checkpoint_handler(forward_func):
                     parallel_state.get_tensor_model_parallel_group(),
                     hidden_states,
                     attention_mask,
+                    None,
+                    None,
                     rotary_pos_emb,
                 )
             else:
@@ -218,6 +226,8 @@ def checkpoint_handler(forward_func):
                     self.config.distribute_saved_activations,
                     hidden_states,
                     attention_mask,
+                    None,
+                    None,
                     rotary_pos_emb,
                 )

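
The widened `custom_forward` signature matters because `tensor_parallel.checkpoint` (and its Transformer Engine counterpart) forwards every argument after `distribute_saved_activations` to the checkpointed function positionally, so the two added `None` placeholders must line up with the new `context` and `context_mask` parameters. A minimal sketch of the same arity contract, using `torch.utils.checkpoint` as a stand-in for Megatron's checkpoint wrapper:

```python
import torch
from torch.utils.checkpoint import checkpoint

def custom_forward(hidden_states, attention_mask, context, context_mask, rotary_pos_emb):
    # context and context_mask are unused placeholders here, mirroring the
    # None, None the hyena block threads through its checkpoint calls.
    return hidden_states * 2

x = torch.randn(4, 8, requires_grad=True)
# Positional args after the function must match its parameters one-to-one;
# omitting the two None slots would raise a TypeError inside the wrapper.
out = checkpoint(custom_forward, x, None, None, None, None, use_reentrant=False)
out.sum().backward()
```
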
15 changes: 11 additions & 4 deletions tests/collections/llm/gpt/model/test_hyena.py
@@ -49,8 +49,9 @@
     --pipeline-model-parallel-size=1 \
     --context-parallel-size=1 \
     --global-batch-size=16 \
-    --micro-batch-size=2 \
-    --model-size=7b
+    --micro-batch-size=1 \
+    --model-size=7b \
+    --fp8
 """

 def get_args():
@@ -65,6 +66,9 @@ def get_args():
     parser.add_argument(
         "--sequence-parallel", action="store_true", help="Set to enable sequence parallel"
     )
+    parser.add_argument(
+        "--fp8", action="store_true", help="Set to enable FP8"
+    )
     parser.add_argument('--micro-batch-size', type=int, default=1, help="Pipeline Parallel Size")
     parser.add_argument('--global-batch-size', type=int, default=8, help="Pipeline Parallel Size")
     parser.add_argument('--max-steps', type=int, help="Number of steps to train for")
@@ -150,7 +154,7 @@ def get_args():
         context_parallel_size=args.context_parallel_size,
         pipeline_dtype=torch.bfloat16,
         sequence_parallel=args.sequence_parallel,
-        ckpt_load_optimizer=True,
+        ckpt_load_optimizer=False,
         ckpt_save_optimizer=True,
         ckpt_async_save=False,
         save_ckpt_format='zarr',
@@ -163,6 +167,9 @@ def get_args():
         plugins=nl.MegatronMixedPrecision(
             precision="bf16-mixed",
             params_dtype=torch.bfloat16,
+            fp8='hybrid' if args.fp8 else None,
+            fp8_amax_history_len=16 if args.fp8 else 1,
+            fp8_amax_compute_algo="max" if args.fp8 else "most_recent",
         ),
         val_check_interval=args.val_check_interval,
     )
@@ -185,7 +192,7 @@ def get_args():
             RestoreConfig(
                 path=args.ckpt_dir,
                 load_model_state = True,
-                load_optim_state = True,
+                load_optim_state = False,
             ) if args.ckpt_dir else None
         ),
     )
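
Note that `--fp8` is a `store_true` flag, so FP8 stays off by default and the plugin falls back to `fp8=None`. The commit also flips optimizer-state restoration off (`ckpt_load_optimizer=False`, `load_optim_state = False`), so an existing checkpoint is reloaded weights-only, for example when resuming a BF16 run with the FP8 recipe enabled. A minimal sketch of that weights-only restore pattern; the import path is an assumption, since the diff confirms only how `RestoreConfig` is used:

```python
from nemo.lightning.resume import RestoreConfig  # import path is an assumption

# Sketch of the weights-only restore used above: keep the trained parameters,
# drop optimizer state so it is rebuilt fresh under the new numerics recipe.
restore_config = RestoreConfig(
    path="/path/to/ckpt",    # placeholder checkpoint directory
    load_model_state=True,   # restore model weights
    load_optim_state=False,  # do not restore optimizer state
)
```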
