Skip to content

Commit

Permalink
[Bug fix] WarmupCosineLR issues (#4688)
Browse files Browse the repository at this point in the history
The original code was missing the `self.` prefix on `warmup_num_steps`, so the
unclamped argument was used: if `warmup_num_steps` was 0 it caused a math domain
error in `math.log(0)` (and a value of 1 would make `math.log(1)` return 0,
breaking the reciprocal).

```py
        self.warmup_num_steps = max(2, warmup_num_steps)
        self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
```
  • Loading branch information
sbwww authored Nov 16, 2023
1 parent bcdabf4 commit ce0ebda
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion deepspeed/runtime/lr_schedules.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def __init__(self,
self.warmup_type = warmup_type
self.warmup_min_ratio = warmup_min_ratio
self.warmup_num_steps = max(2, warmup_num_steps)
self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)
self.inverse_log_warm_up = 1.0 / math.log(self.warmup_num_steps)

if self.total_num_steps < self.warmup_num_steps:
logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
Expand Down Expand Up @@ -857,6 +857,12 @@ def get_lr(self):
lr_ratio = self.get_lr_ratio()
return [org_lr * lr_ratio for org_lr in self.org_lrs]

def get_last_lr(self):
""" Return last computed learning rate by current scheduler.
"""
assert getattr(self, '_last_lr', None) is not None, "need to call step() first"
return self._last_lr

def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}

Expand Down

0 comments on commit ce0ebda

Please sign in to comment.