
[BUG]: Parameter model.norm.weight failed at the gradient reduction. #4984

Closed
fancyerii opened this issue Oct 26, 2023 · 15 comments
Labels
bug Something isn't working

Comments

@fancyerii

🐛 Describe the bug

I am running the Colossal-LLaMA-2 example. It runs fine with zero2_cpu, but when I switch to gemini or gemini_auto it crashes with:

  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/colossalai/booster/booster.py", line 159, in backward
    optimizer.backward(loss)
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/colossalai/zero/gemini/gemini_optimizer.py", line 265, in backward
    self.module.backward(loss)
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/colossalai/zero/gemini/gemini_ddp.py", line 305, in backward
    loss.backward()
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/torch/_tensor.py", line 479, in backward
    return handle_torch_function(
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/torch/overrides.py", line 1534, in handle_torch_function
    result = torch_func_method(public_api, types, args, kwargs)
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/colossalai/tensor/colo_tensor.py", line 81, in __torch_function__
    return backward_tensor.backward(**tensor_kwargs)
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ubuntu/miniconda3/envs/py39_torch113/lib/python3.9/site-packages/colossalai/zero/gemini/gemini_ddp.py", line 320, in grad_handle
    raise RuntimeError(
RuntimeError: Parameter `model.norm.weight` failed at the gradient reduction. Some unsupported torch function is operated upon this parameter.

My launch script:

colossalai run --nproc_per_node 8   train.py \
    --pretrained /nas/lili/models_hf/13B-chat \
    --dataset "/nas/lili/colossalaitest/Colossal-LLaMA-2/spliced_tokenized_output_arrow/part-00000" \
    --plugin "gemini" \
    --save_interval 400 \
    --save_dir "13b_rv" \
    --tensorboard_dir tbdir \
    --config_file config-13b.json \
    --num_epochs 1 \
    --micro_batch_size 1 \
    --lr 1e-4 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --weight_decay 0.01 \
    --warmup_steps 100 \
    --use_grad_checkpoint \
    --use_flash_attn

Environment

cuda 11.7
cudnn 8.9.4.25_cuda11
python 3.9
pytorch 1.13.1+cu117
colossal 0.3.3

@fancyerii fancyerii added the bug Something isn't working label Oct 26, 2023
@JThh
Contributor

JThh commented Oct 29, 2023

I cannot see the exact cause from your description, but the error should be raised at this line.

@JThh
Contributor

JThh commented Oct 29, 2023

Could you please share your config file, and let me know where the model weights are from?

@fancyerii
Author

@JThh I don't use any config file. I just start it with:

colossalai run --nproc_per_node 8   train.py \
    --pretrained /nas/lili/models_hf/13B-chat \
    --dataset "/nas/lili/colossalaitest/Colossal-LLaMA-2/spliced_tokenized_output_arrow/part-00000" \
    --plugin "gemini" \
    --save_interval 400 \
    --save_dir "13b_rv" \
    --tensorboard_dir tbdir \
    --config_file config-13b.json \
    --num_epochs 1 \
    --micro_batch_size 1 \
    --lr 1e-4 \
    --mixed_precision "bf16" \
    --grad_clip 1.0 \
    --weight_decay 0.01 \
    --warmup_steps 100 \
    --use_grad_checkpoint \
    --use_flash_attn

@fancyerii
Author

Can anyone help? I searched here and found many similar problems without answers.

I also tried the llama2 example. Because of this issue, I have to use transformers 4.33.3, and with it I can run script/gemini_auto.sh with the 70B model correctly.

So the difference between examples/language/llama2 and application/Colossal-LLaMA-2 may be related to the bug.
I also used transformers 4.33.3 to run application/Colossal-LLaMA-2, and it throws the same exception.

@Fridge003
Contributor

Fridge003 commented Nov 1, 2023

Hi @fancyerii, I found that this bug is triggered by the rms_norm module of flash attention.
You can try commenting out the last two lines of the function replace_with_flash_attention in colossalai_llama2/utils/flash_attention_patch.py in the following way:

def replace_with_flash_attention(model: LlamaForCausalLM) -> None:
    for name, module in model.named_modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(attention_forward, module)
        if isinstance(module, LlamaModel):
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
        # if isinstance(module, LlamaRMSNorm):
        #     module.forward = MethodType(rms_norm_forward, module)

and the bug might be fixed.
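
If you would rather not edit flash_attention_patch.py, the same effect can be had by undoing only the RMSNorm patch after calling replace_with_flash_attention. This is a minimal sketch assuming transformers' stock LlamaRMSNorm; the helper name undo_rms_norm_patch is mine, not part of the repo:

from types import MethodType

from transformers.models.llama.modeling_llama import LlamaRMSNorm

def undo_rms_norm_patch(model) -> None:
    # Re-bind the stock forward so the per-instance fused rms_norm patch is dropped
    # and Gemini's gradient hook only sees plain torch ops on model.norm.weight.
    for module in model.modules():
        if isinstance(module, LlamaRMSNorm):
            module.forward = MethodType(LlamaRMSNorm.forward, module)

# usage sketch:
# replace_with_flash_attention(model)
# undo_rms_norm_patch(model)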

@fancyerii
Author

Thanks. I also found the problem and tested replace_xformers; it runs correctly now, though I had to switch back to transformers==4.33.3.
Is commenting out those two lines a workaround or a permanent solution?

@Fridge003
Contributor

Fridge003 commented Nov 1, 2023

Thanks. I also found the problem and tested replace_xformers; it runs correctly now, though I had to switch back to transformers==4.33.3.
Is commenting out those two lines a workaround or a permanent solution?

Currently it's just a workaround; the mechanism behind it is complex and we are still looking for a better solution. Thank you for raising the issue!

@fancyerii
Author

By the way, what's the difference between replace_xformers and replace_with_flash_attention?

@Fridge003
Contributor

By the way, what's the difference between replace_xformers and replace_with_flash_attention?

replace_xformers uses xformers' implementation of flash attention, while replace_with_flash_attention uses the original implementation by Tri Dao. Also, replace_xformers doesn't replace rms_norm, so it won't trigger this bug.
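
For reference, a minimal usage sketch of the xformers path. The import location below is an assumption; find where replace_xformers is defined in the llama2 example of your ColossalAI checkout, and it is assumed to take the model as its only argument:

from transformers import LlamaForCausalLM

# Hypothetical import path; adjust to where replace_xformers lives in
# examples/language/llama2 of your checkout.
from attn import replace_xformers

model = LlamaForCausalLM.from_pretrained("/nas/lili/models_hf/13B-chat")
replace_xformers(model)  # patches attention only; LlamaRMSNorm is left untouched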

@fancyerii
Author

Thanks.

@fancyerii
Author

@Fridge003 If the LlamaRMSNorm replacement is commented out, is RMSNorm no longer used at all? Will the training still be correct?

@fancyerii fancyerii reopened this Nov 1, 2023
@Fridge003
Contributor

@Fridge003 If the LlamaRMSNorm replacement is commented out, is RMSNorm no longer used at all? Will the training still be correct?

LlamaRMSNorm is still used, just not replaced with flash attention's fused version. The training will be correct, but a bit slower.

@fancyerii
Author

I got it. So when I use another plugin such as zero2_cpu, I can keep the original code, but if I want to use gemini or gemini_auto, I need to comment out those two lines.

@Fridge003
Contributor

I got it. So when I use another plugin such as zero2_cpu, I can keep the original code, but if I want to use gemini or gemini_auto, I need to comment out those two lines.

Currently yes, we will try to fix this later.
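
For reference, one way to keep a single code path is to gate the RMSNorm replacement on the chosen plugin. A sketch reusing the names from flash_attention_patch.py above; the use_fused_rmsnorm flag is my addition, not part of the upstream example:

def replace_with_flash_attention(model: LlamaForCausalLM, use_fused_rmsnorm: bool = True) -> None:
    for name, module in model.named_modules():
        if isinstance(module, LlamaAttention):
            module.forward = MethodType(attention_forward, module)
        if isinstance(module, LlamaModel):
            module._prepare_decoder_attention_mask = MethodType(_prepare_decoder_attention_mask, module)
        # Only swap in the fused rms_norm kernel when the plugin can handle it.
        if use_fused_rmsnorm and isinstance(module, LlamaRMSNorm):
            module.forward = MethodType(rms_norm_forward, module)

# in train.py (sketch):
# replace_with_flash_attention(model, use_fused_rmsnorm=args.plugin not in ("gemini", "gemini_auto"))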
