
ZeRO0 does not handle BF16 gradients properly #5154

Closed

Conversation


@tohtana tohtana commented Feb 19, 2024

The combination of BF16 and ZeRO0 (no ZeRO optimization) has some issues with gradient handling. BF16_Optimizer appears to be designed to accumulate gradients in FP32, but this does not match the rest of the engine.

  • The DeepSpeed engine converts BF16 gradients and accumulates them in FP32 soon after the backward pass via BF16_Optimizer. However, the engine then performs allreduce on the BF16 gradients. As a result, gradients are not properly reduced.
  • The engine calls BF16_Optimizer's backward(), which clears the BF16 gradients, so gradient accumulation with more than one step does not work (see the sketch after this list).
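
A minimal sketch of the mismatched flow described above, written in plain PyTorch with hypothetical names (`bf16_params`, `fp32_grad_buffers`); this is not DeepSpeed's actual engine code:

```python
import torch.distributed as dist


def backward_and_reduce_buggy(loss, bf16_params, fp32_grad_buffers):
    """Illustrates the mismatch: FP32 accumulation happens right after
    backward, and the later allreduce only sees the cleared BF16 grads."""
    loss.backward()

    # BF16_Optimizer-style bookkeeping right after backward: copy the fresh
    # BF16 gradients into FP32 buffers and clear the BF16 gradients.
    for p, fp32_buf in zip(bf16_params, fp32_grad_buffers):
        fp32_buf.add_(p.grad.float())
        p.grad = None  # cleared every micro-step, so accumulation > 1 breaks

    # The engine then allreduces the BF16 gradients, which were just cleared,
    # so the values accumulated in FP32 are never reduced across ranks.
    for p in bf16_params:
        if p.grad is not None:
            dist.all_reduce(p.grad)
```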

There are two possible approaches to resolving this issue:

  1. Accumulate gradients in BF16 until the gradient accumulation boundary (i.e., clear the BF16 gradients only at the boundary), then perform allreduce, conversion to FP32, and the parameter update.
  2. Accumulate gradients in FP32 and run allreduce on the FP32 buffer.

This PR takes the first approach, which ZeRO 1/2/3 also follow, although the second one has an advantage in terms of precision for gradient accumulation. A sketch of the first approach is shown below.
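
A minimal sketch of approach 1, assuming the same hypothetical setup as above (illustrative names, not DeepSpeed's API): BF16 gradients accumulate across micro-steps, and only at the boundary are they reduced, converted to FP32, and used to update the FP32 master weights.

```python
import torch
import torch.distributed as dist


def step_at_accumulation_boundary(bf16_params, fp32_master_params,
                                  fp32_optimizer, world_size):
    # Called only at the gradient accumulation boundary; on the other
    # micro-steps, BF16 gradients are simply left in place to accumulate.
    for bf16_p, fp32_p in zip(bf16_params, fp32_master_params):
        dist.all_reduce(bf16_p.grad)       # reduce accumulated BF16 grads
        bf16_p.grad.div_(world_size)       # average across ranks

        fp32_p.grad = bf16_p.grad.float()  # convert to FP32 for the update
        bf16_p.grad = None                 # clear only at the boundary

    fp32_optimizer.step()                  # update FP32 master weights

    # Copy the updated master weights back into the BF16 parameters.
    with torch.no_grad():
        for bf16_p, fp32_p in zip(bf16_params, fp32_master_params):
            bf16_p.copy_(fp32_p)
```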

@tohtana tohtana changed the base branch from master to tohtana/fix_fp32_clipping February 19, 2024 08:45
@tohtana tohtana marked this pull request as ready for review February 19, 2024 09:33
@tohtana tohtana changed the title from "Update bf16 optimizer's master weights after allreduce" to "ZeRO0 does not handle BF16 gradients properly" Feb 20, 2024
@tohtana tohtana commented Feb 21, 2024

This PR breaks PP (pipeline parallelism). Opened #5170 as an alternative solution.

@tohtana tohtana closed this Feb 21, 2024
github-merge-queue bot pushed a commit that referenced this pull request Feb 21, 2024
This PR fixes an issue with allreducing for ZeRO0 + BF16. (This replaces
#5154.)

DeepSpeed uses `BF16_Optimizer` when ZeRO0 and BF16 are enabled. The
optimizer accumulates gradients in an FP32 buffer soon after a backward
pass completes. However, the DeepSpeed engine performs allreduce on the BF16
gradients.

This PR fixes the issue by performing allreduce on the FP32 buffer instead. It
also removes an assertion that prohibits BF16+PP+Z1, which is
actually runnable.

The image below shows loss curves for the following configurations:
- BF16/Z0,Z1,Z2,Z3/NoPP
- BF16/Z0,Z1/PP (2 stages)
(all runs used 8 GPUs with gradient accumulation steps = 4)

![image](https://github.com/microsoft/DeepSpeed/assets/81312776/0dc1e9ef-43bc-4b47-8b9e-d6aca137a217)

---------

Co-authored-by: Logan Adams <[email protected]>
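
For illustration, a minimal sketch of what "performing allreduce on the FP32 buffer" could look like, using hypothetical buffer names rather than the actual code from the merged fix:

```python
import torch.distributed as dist


def allreduce_fp32_grad_buffers(fp32_grad_buffers, world_size):
    # Reduce the FP32 accumulation buffers directly, instead of the BF16
    # gradients that BF16_Optimizer has already consumed and cleared.
    for buf in fp32_grad_buffers:
        dist.all_reduce(buf)
        buf.div_(world_size)  # average across data-parallel ranks
```
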
ShellyNR pushed a commit to ShellyNR/DeepSpeed that referenced this pull request Mar 11, 2024
rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024