
Commit

Merge branch 'master' into loadams/adam-params
loadams authored Aug 20, 2024
2 parents f037d4f + 96393f5 commit 9a4b142
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions deepspeed/runtime/zero/linear.py
@@ -16,6 +16,7 @@
 #when implemented outside of torch.autograd.Function

 import math
+import functools

 import torch
 from torch import Tensor
@@ -33,8 +34,14 @@ def print_rank_0(message, debug=False, force=False):


 try:
-    autocast_custom_fwd = get_accelerator().amp().custom_fwd
-    autocast_custom_bwd = get_accelerator().amp().custom_bwd
+    # Fix `torch.[device].amp.custom_fwd/bwd` FutureWarning in torch 2.4
+    if hasattr(torch, 'amp') and hasattr(torch.amp, 'custom_fwd') and hasattr(torch.amp, 'custom_bwd'):
+        autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=get_accelerator().device_name())
+        autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=get_accelerator().device_name())
+    else:
+        # original implementation
+        autocast_custom_fwd = get_accelerator().amp().custom_fwd
+        autocast_custom_bwd = get_accelerator().amp().custom_bwd
 except (ImportError, AttributeError) as exp:
     autocast_custom_fwd = noop_decorator
     autocast_custom_bwd = noop_decorator
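For readers following the change: in torch 2.4 the per-device `torch.[device].amp.custom_fwd/custom_bwd` decorators emit a FutureWarning in favor of `torch.amp.custom_fwd/custom_bwd`, which take a required `device_type` keyword. Pre-binding that keyword with `functools.partial` keeps the names usable as plain decorators, exactly as before. Below is a minimal sketch of the pattern, assuming torch >= 2.4 and a CUDA device (the real code derives the device name from `get_accelerator()`); `ScaleFunction` is a hypothetical example class standing in for the autograd Function that linear.py actually decorates.

```python
import functools

import torch

# Pre-bind device_type so the result can still be used as a bare decorator,
# matching the behavior of the deprecated torch.cuda.amp.custom_fwd/bwd.
# "cuda" is an assumption here; DeepSpeed uses get_accelerator().device_name().
autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type="cuda")
autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type="cuda")


class ScaleFunction(torch.autograd.Function):
    """Hypothetical autograd Function, only to show where the decorators go;
    linear.py applies them to its own Function in the same way."""

    @staticmethod
    @autocast_custom_fwd
    def forward(ctx, input, scale):
        ctx.scale = scale
        return input * scale

    @staticmethod
    @autocast_custom_bwd
    def backward(ctx, grad_output):
        # Gradient w.r.t. input; None for the non-tensor scale argument.
        return grad_output * ctx.scale, None
```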
