Don't check overflow for bf16 data type (microsoft#4512)
Always check overflow for fp16.
bf16's dynamic range is similar to fp32's, so don't check overflow for it by default.
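The rationale can be illustrated numerically: bf16 keeps fp32's 8-bit exponent (and thus roughly the same dynamic range), while fp16's 5-bit exponent tops out at 65504, which is why fp16 training needs loss scaling and overflow checks. A minimal sketch computing each format's largest finite value from its exponent/mantissa widths (pure Python; no torch required):

```python
# Largest finite value of an IEEE-style binary float format:
# max = (2 - 2**-mantissa_bits) * 2**max_biased_exponent
def max_finite(exponent_bits, mantissa_bits):
    bias = 2 ** (exponent_bits - 1) - 1
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias

fp16_max = max_finite(5, 10)   # 65504.0 -- overflows easily, needs loss scaling
bf16_max = max_finite(8, 7)    # ~3.39e38 -- same exponent range as fp32
fp32_max = max_finite(8, 23)   # ~3.40e38

print(fp16_max, bf16_max, fp32_max)
```

Because `bf16_max` is within a couple percent of `fp32_max`, gradients that fit in fp32 essentially never overflow in bf16, so the per-step overflow check is pure overhead there.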

Co-authored-by: Olatunji Ruwase <[email protected]>
2 people authored and amaurya committed Feb 17, 2024
1 parent 9dbac25 commit 0b3f866
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1872,7 +1872,8 @@ def _overflow_clean_up(self, prev_scale):
     def _overflow_check_and_loss_scale_update(self):
 
         # First compute norm for all group so we know if there is overflow
-        self.check_overflow()
+        if self.dtype == torch.float16:
+            self.check_overflow()
 
         #loss scaling related computation
         prev_scale = self.loss_scale
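Both files apply the same pattern: the overflow check only runs for fp16, and for other dtypes `self.overflow` stays `False`, so the subsequent loss-scale update leaves the scale untouched. A hedged, standalone sketch of that control flow (the `ToyOptimizer` class and its `check_overflow` are illustrative, not DeepSpeed's actual implementation):

```python
import math

class ToyOptimizer:
    """Minimal sketch of a dtype-gated overflow check (not DeepSpeed's real class)."""

    def __init__(self, dtype, loss_scale=2.0 ** 16):
        self.dtype = dtype          # e.g. "fp16" or "bf16"
        self.loss_scale = loss_scale
        self.overflow = False

    def check_overflow(self, grads):
        # Overflow == any non-finite gradient value.
        self.overflow = any(not math.isfinite(g) for g in grads)

    def step(self, grads):
        # Only fp16 needs the overflow check; bf16's dynamic range
        # matches fp32's, so overflow is assumed not to happen.
        if self.dtype == "fp16":
            self.check_overflow(grads)
        if self.overflow:
            self.loss_scale /= 2     # back off the scale, skip this update
            return False
        return True                  # gradients would be applied (elided here)

opt = ToyOptimizer("bf16")
print(opt.step([1.0, float("inf")]))  # True: bf16 skips the check entirely
```

Note the trade-off this encodes: a bf16 run with a genuinely non-finite gradient would go undetected by this path, which the commit accepts as a non-issue given bf16's fp32-like range.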
3 changes: 2 additions & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
@@ -1712,7 +1712,8 @@ def step(self, closure=None):
         see_memory_usage(f"In step before checking overflow")
 
         # First compute norm for all group so we know if there is overflow
-        self.check_overflow()
+        if self.dtype == torch.float16:
+            self.check_overflow()
 
         prev_scale = self.loss_scale
         self._update_scale(self.overflow)
