job hangs or IndexError when training reward model with PP > 1 #251

Open
zirui opened this issue Jul 24, 2024 · 6 comments
Labels: bug Something isn't working

zirui commented Jul 24, 2024

Describe the bug

I attempted to train reward models of different sizes (3B/6B/30B) and found that when PP > 1, two types of issues arise:

3B/6B:

  • TP=4, PP=1: ok
  • TP=4, PP=2: the job hangs

30B:

  • TP=8, PP=1: ok
  • TP=8, PP=2/4/8: IndexError: index xx is out of range

Parameter configuration:

python $CODE_DIR/NeMo-Aligner/examples/nlp/gpt/train_reward_model.py \
      trainer.num_nodes=$NNODES \
      trainer.devices=8 \
      trainer.rm.max_epochs=3 \
      trainer.rm.max_steps=-1 \
      ++model.encoder_seq_length=4096 \
      ++model.micro_batch_size=1 \
      ++model.global_batch_size=64 \
      ++model.data.data_impl=jsonl \
      +model.use_flash_attention=True \
      model.data.seq_length=2048 \
      pretrained_checkpoint.restore_from_path=$MODEL_PATH \
      "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
      trainer.rm.save_interval=10 \
      trainer.rm.val_check_interval=10 \
      exp_manager.create_wandb_logger=True \
      exp_manager.wandb_logger_kwargs.project=$WANDB_PROJECT \
      exp_manager.wandb_logger_kwargs.name=$EXP_NAME \
      exp_manager.explicit_log_dir=$EXP_OUTPUT_DIR \
      ++model.tensor_model_parallel_size=8 \
      ++model.pipeline_model_parallel_size=4 \
      ++model.optim.name=distributed_fused_adam \
      ++model.optim.lr=${LR}  \
      ++model.optim.sched.constant_steps=0 \
      model.optim.weight_decay=0.1
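
For context on the sizing these overrides imply, here is a rough sanity check (not taken from the issue; NNODES=4 is just an assumed example value, and the data-parallel / gradient-accumulation arithmetic assumes the standard Megatron layout):

    # hypothetical sanity check of the parallel layout (example values, not from the issue)
    TP=8; PP=4; NNODES=4; DEVICES=8; GBS=64; MBS=1
    WORLD=$((NNODES * DEVICES))        # total GPUs
    DP=$((WORLD / (TP * PP)))          # data-parallel size; WORLD must be divisible by TP*PP
    echo "world=$WORLD dp=$DP grad_acc=$((GBS / (MBS * DP)))"   # GBS must be divisible by MBS*DP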

Environment details

  • docker: nvcr.io/nvidia/nemo:24.05.01

Traceback:
  File "/code/NeMo-Aligner/examples/nlp/gpt/train_reward_model.py", line 153, in main
    rm_trainer.fit()
  File "/code/NeMo-Aligner/nemo_aligner/algorithms/supervised.py", line 210, in fit
    loss, metrics = self.train_single_step(batch)
  File "/code/NeMo-Aligner/nemo_aligner/algorithms/supervised.py", line 145, in train_single_step
    loss_mean, metrics = self.model.get_loss_and_metrics(batch=batch, forward_only=False)
  File "/code/NeMo-Aligner/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py", line 252, in get_loss_and_metrics
    losses_reduced_per_micro_batch = fwd_bwd_function(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 1271, in forward_backward_pipelining_without_interleaving
    output_tensor, num_tokens = forward_step(
  File "/opt/megatron-lm/megatron/core/pipeline_parallel/schedules.py", line 206, in forward_step
    output_tensor, loss_func = forward_step_func(data_iterator, model)
  File "/code/NeMo-Aligner/nemo_aligner/models/nlp/gpt/megatron_gpt_reward_model.py", line 158, in fwd_output_and_loss_func
    output_tensor = model(**forward_args)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/module.py", line 173, in forward
    outputs = self.module(*inputs, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/code/NeMo-Aligner/nemo_aligner/models/nlp/gpt/gpt_reward_model.py", line 214, in forward
    hidden_states = super().forward(
  File "/opt/megatron-lm/megatron/core/models/gpt/gpt_model.py", line 190, in forward
    hidden_states = self.decoder(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 371, in forward
    hidden_states = self._checkpointed_forward(
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 257, in _checkpointed_forward
    hidden_states, context = checkpoint_handler(
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 240, in checkpoint_handler
    return tensor_parallel.checkpoint(
  File "/opt/megatron-lm/megatron/core/tensor_parallel/random.py", line 301, in checkpoint
    return CheckpointFunction.apply(function, distribute_saved_activations, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 553, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/opt/megatron-lm/megatron/core/tensor_parallel/random.py", line 240, in forward
    outputs = run_function(*args)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 211, in custom_forward
    layer = self._get_layer(index)
  File "/opt/megatron-lm/megatron/core/transformer/transformer_block.py", line 188, in _get_layer
    return self.layers[layer_number]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 295, in __getitem__
    return self._modules[self._get_abs_string_index(idx)]
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/container.py", line 285, in _get_abs_string_index
    raise IndexError(f'index {idx} is out of range')
IndexError: index 15 is out of range
Training steps:   0%|          | 0/423 [00:03<?, ?it/s]
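
A note that may help when reading this traceback: in Megatron-core each pipeline stage instantiates only its own slice of the decoder, so the self.layers[layer_number] lookup that fails uses a stage-local index. Assuming the standard even split with no virtual pipeline, each stage owns num_layers / PP layers; a quick sketch of that arithmetic with a placeholder layer count:

    # stage-local layer counts under an even pipeline split (NUM_LAYERS is a placeholder value)
    NUM_LAYERS=32
    for PP in 1 2 4 8; do
      echo "PP=$PP -> $((NUM_LAYERS / PP)) layers per stage (local indices 0..$((NUM_LAYERS / PP - 1)))"
    done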
odelalleau (Collaborator) commented:

Would you be able to share your model_config.yaml?

zirui (Author) commented Jul 24, 2024

Would you be able to share your model_config.yaml?

I use the default config: NeMo-Aligner/examples/nlp/gpt/conf/training_rm.yaml

odelalleau (Collaborator) commented:

I was wondering about the model_config.yaml found in your .nemo file (you can extract it with tar xvf on your .nemo). I'm curious about your model's architecture.

zirui (Author) commented Jul 24, 2024

I was wondering about the model_config.yaml found in your .nemo file (you can extract it with tar xvf on your .nemo). I'm curious about your model's architecture.

Below is the content of model_config.yaml.
(It is actually a Llama-like model converted to .nemo from a Hugging Face checkpoint.)

mcore_gpt: true
micro_batch_size: 4
global_batch_size: 8
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
virtual_pipeline_model_parallel_size: null
encoder_seq_length: 4096
max_position_embeddings: 4096
num_layers: 32
hidden_size: 4096
ffn_hidden_size: 11008
num_attention_heads: 32
init_method_std: 0.02
use_scaled_init_method: true
hidden_dropout: 0.0
attention_dropout: 0.0
ffn_dropout: 0.0
kv_channels: null
apply_query_key_layer_scaling: true
normalization: rmsnorm
layernorm_epsilon: 1.0e-06
do_layer_norm_weight_decay: false
make_vocab_size_divisible_by: 128
pre_process: true
post_process: true
persist_layer_norm: true
bias: false
activation: fast-swiglu
headscale: false
transformer_block_type: pre_ln
openai_gelu: false
normalize_attention_scores: true
position_embedding_type: rope
rotary_percentage: 1.0
attention_type: multihead
share_embeddings_and_output_weights: false
overlap_p2p_comm: false
batch_p2p_comm: true
num_query_groups: 4
tokenizer:
  library: sentencepiece
  type: null
  model: nemo:4f7bc9bb269d4abd9680ac15dcec4b16_tokenizer.model
  vocab_file: null
  merge_file: null
  delimiter: null
  sentencepiece_legacy: false
native_amp_init_scale: 4294967296
native_amp_growth_interval: 1000
hysteresis: 2
fp32_residual_connection: false
fp16_lm_cross_entropy: false
megatron_amp_O2: false
grad_allreduce_chunk_size_mb: 125
grad_div_ar_fusion: true
gradient_accumulation_fusion: false
bias_activation_fusion: false
bias_dropout_add_fusion: false
masked_softmax_fusion: true
get_attention_mask_from_fusion: true
apply_rope_fusion: false
seed: 1234
resume_from_checkpoint: null
use_cpu_initialization: false
onnx_safe: false
apex_transformer_log_level: 30
gradient_as_bucket_view: true
sync_batch_comm: false
activations_checkpoint_granularity: null
activations_checkpoint_method: null
activations_checkpoint_num_layers: null
num_micro_batches_with_partial_activation_checkpoints: null
activations_checkpoint_layers_per_pipeline: null
sequence_parallel: false
transformer_engine: true
fp8: false
fp8_e4m3: false
fp8_hybrid: true
fp8_margin: 0
fp8_interval: 1
fp8_amax_history_len: 1024
fp8_amax_compute_algo: max
reduce_amax: true
use_emha: false
data:
  index_mapping_dir: null
  data_impl: mmap
  splits_string: 900,50,50
  seq_length: 4096
  skip_warmup: true
  num_workers: 2
  dataloader_type: single
  reset_position_ids: false
  reset_attention_mask: false
  eod_mask_loss: false
  validation_drop_last: true
  no_seqlen_plus_one_input_tokens: false
  pad_samples_to_global_batch_size: false
  shuffle_documents: true
nsys_profile:
  enabled: false
  start_step: 10
  end_step: 10
  ranks:
  - 0
  gen_shape: false
optim:
  name: fused_adam
  lr: 0.0002
  weight_decay: 0.01
  betas:
  - 0.9
  - 0.98
  sched:
    name: CosineAnnealing
    warmup_steps: 500
    constant_steps: 50000
    min_lr: 2.0e-05
rotary_base: 5000000.0
precision: 16
target: nemo.collections.nlp.models.language_modeling.megatron_gpt_model.MegatronGPTModel
nemo_version: 1.23.0

zirui (Author) commented Jul 24, 2024

Some additional information:

  • I converted Llama-2-7b-hf to Llama-2-7b-hf.nemo and set TP=4, PP=2; the reward model training job still hangs, so the issue should not be related to my model's architecture.
  • I also tried SFT training (following the user guide), and that run works fine with TP=4, PP=2.

ashors1 (Collaborator) commented Oct 8, 2024

Hi, could you try setting model.encoder_seq_length and model.data.seq_length to the same value? I was able to reproduce the hang, and making those seq_length arguments consistent fixed the problem for me.
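
For reference, a minimal adjustment to the overrides in the original command might look like the following (4096 matches the checkpoint's encoder_seq_length; treat the exact value as an assumption):

      ++model.encoder_seq_length=4096 \
      model.data.seq_length=4096 \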
