Fix allreduce for BF16 and ZeRO0 #5170

Merged: 5 commits merged into master on Feb 21, 2024
Conversation

@tohtana (Contributor) commented Feb 21, 2024

This PR fixes an issue with allreduce for ZeRO0 + BF16. (This replaces #5154.)

DeepSpeed uses `BF16_Optimizer` when ZeRO0 and BF16 are enabled. The optimizer accumulates gradients in an FP32 buffer as soon as the backward pass completes, but the DeepSpeed engine performs allreduce on the BF16 gradients.

This PR fixes the issue by performing allreduce on the FP32 buffer instead. It also removes an assertion that prohibited the BF16+PP+Z1 combination, which actually runs fine.
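For illustration, here is a minimal sketch of the idea (not DeepSpeed's actual code; the helper name, buffer list, and process group argument are hypothetical): the reduction is applied to the FP32 accumulation buffers kept by a BF16-style optimizer rather than to the model's BF16 `.grad` tensors, so averaging happens in full precision.

```python
# Hypothetical sketch only; not the DeepSpeed engine's real API.
import torch
import torch.distributed as dist

def allreduce_fp32_grad_buffers(fp32_grad_buffers, dp_group=None):
    """Average gradients across data-parallel ranks on the FP32 accumulation
    buffers, instead of on the BF16 .grad tensors, so no precision is lost
    before the optimizer step."""
    world_size = dist.get_world_size(group=dp_group)
    for buf in fp32_grad_buffers:   # flat FP32 gradient buffers
        buf.div_(world_size)        # pre-divide so the SUM becomes a mean
        dist.all_reduce(buf, op=dist.ReduceOp.SUM, group=dp_group)
```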

The plot below shows loss curves for the following conditions:

- BF16/Z0,Z1,Z2,Z3/NoPP
- BF16/Z0,Z1/PP(2 stages)

(all runs used 8 GPUs, gradient accumulation steps: 4)

![image](https://github.com/microsoft/DeepSpeed/assets/81312776/0dc1e9ef-43bc-4b47-8b9e-d6aca137a217)

@tohtana tohtana changed the title Tohtana/fix bf16 z0 reduce Fix allreduce for BF16 and ZeRO0 Feb 21, 2024
@tohtana tohtana marked this pull request as ready for review February 21, 2024 17:30
@tohtana tohtana requested a review from mrwyattii as a code owner February 21, 2024 17:30
@tohtana tohtana enabled auto-merge February 21, 2024 18:10
@tohtana tohtana added this pull request to the merge queue Feb 21, 2024
Merged via the queue into master with commit dd3690c Feb 21, 2024
12 checks passed
@tohtana tohtana deleted the tohtana/fix_bf16_z0_reduce branch February 21, 2024 20:08
ShellyNR pushed a commit to ShellyNR/DeepSpeed that referenced this pull request Mar 11, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024