From dd3690c5348ddfa475cef112eb01aebe6a409847 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka <81312776+tohtana@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:07:33 -0800 Subject: [PATCH] Fix allreduce for BF16 and ZeRO0 (#5170) This PR fixes an issue with gradient allreduce for ZeRO0 + BF16. (This replaces #5154) DeepSpeed uses `BF16_Optimizer` when ZeRO0 and BF16 are enabled. The optimizer accumulates gradients in an FP32 buffer soon after a backward pass completes. However, the DeepSpeed engine performs allreduce on the BF16 gradients. This PR fixes the issue by performing allreduce on the FP32 buffer. It also eliminates an assertion that prohibits BF16+PP+Z1, which is actually runnable. The figure below shows loss curves for the following conditions: - BF16/Z0,Z1,Z2,Z3/NoPP - BF16/Z0,Z1/PP(2 stages) (all runs used 8 GPUs, gradient accumulation steps: 4) ![image](https://github.com/microsoft/DeepSpeed/assets/81312776/0dc1e9ef-43bc-4b47-8b9e-d6aca137a217) --------- Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com> --- deepspeed/runtime/engine.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e15cc49339ff..5c1202ba06ae 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1911,9 +1911,6 @@ def print_forward_breakdown(self, fwd_time): @instrument_w_nvtx def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): - assert not (self.bfloat16_enabled() and self.pipeline_parallelism), \ - f'allreduce_gradients() is not valid when bfloat+pipeline_parallelism is enabled' - # Pass (PP) gas boundary flag to optimizer (required for zero) self.optimizer.is_gradient_accumulation_boundary = self.is_gradient_accumulation_boundary() # ZeRO stage >= 2 communicates during non gradient accumulation boundaries as well @@ -1926,7 +1923,11 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE): self.optimizer, 'reduce_gradients'): 
self.optimizer.reduce_gradients(pipeline_parallel=self.pipeline_parallelism) else: - self.buffered_allreduce_fallback(elements_per_buffer=bucket_size) + grads = None + if hasattr(self.optimizer, "get_grads_for_reduction"): + # This is currently for BF16 optimizer + grads = self.optimizer.get_grads_for_reduction() + self.buffered_allreduce_fallback(grads=grads, elements_per_buffer=bucket_size) @instrument_w_nvtx def backward(self, loss, allreduce_gradients=True, release_loss=False, retain_graph=False, scale_wrt_gas=True):