Fix allreduce for BF16 and ZeRO0 (microsoft#5170)
This PR fixes an issue with gradient allreduce for ZeRO-0 + BF16. (This replaces
microsoft#5154.)

DeepSpeed uses `BF16_Optimizer` when ZeRO-0 and BF16 are enabled. The
optimizer accumulates gradients in an FP32 buffer as soon as a backward
pass completes, but the DeepSpeed engine performs the allreduce on the BF16
gradients, so the averaged values never reach the FP32 buffer that the
optimizer actually steps from.
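
To make the mismatch concrete, here is a toy sketch (illustrative only, not DeepSpeed internals; `fp32_grad_acc` is a made-up name standing in for the optimizer's FP32 accumulation buffer):

```python
import torch

# Hypothetical stand-ins for a BF16 parameter gradient and the FP32
# accumulation buffer that BF16_Optimizer steps from.
param = torch.nn.Parameter(torch.zeros(4, dtype=torch.bfloat16))
fp32_grad_acc = torch.zeros(4, dtype=torch.float32)

for _ in range(4):  # gradient accumulation micro-steps
    param.grad = torch.randn(4).bfloat16()  # what backward() would leave behind
    fp32_grad_acc += param.grad.float()     # accumulated right after backward

# Allreducing param.grad (the old behavior) averages only the latest BF16
# micro-batch gradient and never touches fp32_grad_acc, the tensor the
# optimizer actually uses; the fix reduces the FP32 buffer instead.
```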

This PR fixes the issue by performing the allreduce on the FP32 buffer instead.
It also removes an assertion that prohibited BF16+PP+Z1, which is
actually runnable.

The plot below shows loss curves for the following configurations:
- BF16 / Z0, Z1, Z2, Z3 / no PP
- BF16 / Z0, Z1 / PP (2 stages)

(all runs used 8 GPUs with 4 gradient accumulation steps)

![image](https://github.com/microsoft/DeepSpeed/assets/81312776/0dc1e9ef-43bc-4b47-8b9e-d6aca137a217)
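
For reference, a minimal sketch of a setup that exercises this path (ZeRO-0 + BF16 with gradient accumulation, so `BF16_Optimizer` and the allreduce fallback are both in play). The model and batch sizes below are placeholders, not the configuration used for the curves above:

```python
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 2,   # placeholder value
    "gradient_accumulation_steps": 4,
    "bf16": {"enabled": True},             # BF16_Optimizer is selected
    "zero_optimization": {"stage": 0},     # ZeRO-0
    "optimizer": {"type": "AdamW", "params": {"lr": 1e-4}},
}

model = torch.nn.Linear(1024, 1024)  # placeholder model
engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                               model_parameters=model.parameters(),
                                               config=ds_config)
```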

---------

Co-authored-by: Logan Adams <[email protected]>
2 people authored and rraminen committed May 9, 2024
1 parent 4d9ce5a commit 0d69cbc
Showing 1 changed file with 5 additions and 4 deletions.
`deepspeed/runtime/engine.py` (5 additions, 4 deletions):

```diff
@@ -1911,9 +1911,6 @@ def print_forward_breakdown(self, fwd_time):
 
     @instrument_w_nvtx
     def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
-        assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \
-            f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled'
-
         # Pass (PP) gas boundary flag to optimizer (required for zero)
         self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary()
         # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well
@@ -1926,7 +1923,11 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
                     self.optimizer, 'reduce_gradients'):
                 self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism)
             else:
-                self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)
+                grads = None
+                if hasattr(self.optimizer, "get_grads_for_reduction"):
+                    # This is currently for BF16 optimizer
+                    grads = self.optimizer.get_grads_for_reduction()
+                self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size)
 
     @instrument_w_nvtx
     def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):
```
