From 0b3f866d64806cb4c8bee079ada6320db9383603 Mon Sep 17 00:00:00 2001 From: hablb Date: Sat, 28 Oct 2023 00:50:19 +0300 Subject: [PATCH] Don't check overflow for bf16 data type (#4512) Always check overflow for fp16. Since bf16's dynamic range is similar to fp32's, don't check overflow for bf16 by default. Co-authored-by: Olatunji Ruwase --- deepspeed/runtime/zero/stage3.py | 3 ++- deepspeed/runtime/zero/stage_1_and_2.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 84065a0c4a6e..47e453cea192 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1872,7 +1872,8 @@ def _overflow_clean_up(self, prev_scale): def _overflow_check_and_loss_scale_update(self): # First compute norm for all group so we know if there is overflow - self.check_overflow() + if self.dtype == torch.float16: + self.check_overflow() #loss scaling related computation prev_scale = self.loss_scale diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index d78597e85353..dbb5b128ae14 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -1712,7 +1712,8 @@ def step(self, closure=None): see_memory_usage(f"In step before checking overflow") # First compute norm for all group so we know if there is overflow - self.check_overflow() + if self.dtype == torch.float16: + self.check_overflow() prev_scale = self.loss_scale self._update_scale(self.overflow)