Commit

fix test instance
CoinCheung committed Oct 27, 2023
1 parent 38bce82 commit a689fa0
Showing 2 changed files with 47 additions and 21 deletions.
57 changes: 41 additions & 16 deletions deepspeed/runtime/lr_schedules.py
@@ -771,7 +771,7 @@ def _get_gamma(self):
float(max(1.0, self.total_num_steps - self.warmup_num_steps)))


class WarmupCosineLR(WarmupLR):
class WarmupCosineLR(object):
"""Increase the learning rate of each parameter group from min lr ratio to max lr ratio
over warmup_num_steps steps, and then decay at cosine rate over the remaining training steps to min cosine ratio.
@@ -803,14 +803,17 @@ def __init__(self,
warmup_type: str = WARMUP_LOG_RATE,
last_batch_iteration: int = -1):

self.optimizer = get_torch_optimizer(optimizer)

self.total_num_steps = total_num_steps
self.warmup_min_ratio = warmup_min_ratio
self.last_batch_iteration = last_batch_iteration
self.cos_min_ratio = cos_min_ratio
super(WarmupCosineLR, self).__init__(
optimizer,
last_batch_iteration=last_batch_iteration,
warmup_num_steps=warmup_num_steps,
)

self.warmup_type = warmup_type
self.warmup_min_ratio = warmup_min_ratio
self.warmup_num_steps = max(2, warmup_num_steps)
self.inverse_log_warm_up = 1.0 / math.log(warmup_num_steps)

if self.total_num_steps < self.warmup_num_steps:
logger.warning('total_num_steps {} is less than warmup_num_steps {}'.format(
total_num_steps, warmup_num_steps))
@@ -828,20 +831,42 @@ def get_lr_ratio(self):
ratio = self.last_batch_iteration / self.warmup_num_steps
ratio_delta = 1. - self.warmup_min_ratio
ratio = self.warmup_min_ratio + ratio * ratio_delta
else:
real_last_step = self.last_batch_iteration - self.warmup_num_steps
real_total_steps = self.total_num_steps - self.warmup_num_steps
ratio_delta = 1. - self.cos_min_ratio
ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
return ratio

real_last_step = self.last_batch_iteration - self.warmup_num_steps + 1
real_total_steps = self.total_num_steps - self.warmup_num_steps
ratio_delta = 1. - self.cos_min_ratio
ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
ratio = max(0.0, self.cos_min_ratio + ratio_delta * ratio)
return ratio

def step(self, last_batch_iteration=None):
if last_batch_iteration is None:
last_batch_iteration = self.last_batch_iteration + 1
self.last_batch_iteration = last_batch_iteration

lr_ratio = self.get_lr_ratio()
for param_group, org_lr in zip(self.optimizer.param_groups, self.org_lrs):
param_group['lr'] = lr_ratio * org_lr
lrs = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, lrs):
param_group['lr'] = lr
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]

def get_lr(self):
if self.last_batch_iteration < 0:
logger.warning("Attempting to get learning rate from scheduler before it has started")
return [0.0]
lr_ratio = self.get_lr_ratio()
return [org_lr * lr_ratio for org_lr in self.org_lrs]

def state_dict(self):
return {'last_batch_iteration': self.last_batch_iteration}

def load_state_dict(self, sd):
self.last_batch_iteration = sd['last_batch_iteration']

def _format_param(self, optimizer, param_value, param_name):
if isinstance(param_value, list) or isinstance(param_value, tuple):
if len(param_value) != len(optimizer.param_groups):
raise ValueError("expected {} value for {}, got {}".format(len(optimizer.param_groups), param_name,
FileNotFoundError(param_value)))
return list(param_value)
return [param_value] * len(optimizer.param_groups)
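
For readers skimming the diff, here is a minimal standalone sketch (not part of the commit) of the warmup-then-cosine ratio computed by the revised `get_lr_ratio` above. It reproduces only the linear-warmup branch plus the cosine decay with the `+ 1` step offset introduced here; the function name `warmup_cosine_ratio` and the example hyperparameters are illustrative, not from DeepSpeed.

```python
import math

def warmup_cosine_ratio(step, total_num_steps, warmup_num_steps,
                        warmup_min_ratio=0.0, cos_min_ratio=0.0):
    """Illustrative re-implementation of the schedule in get_lr_ratio above
    (linear-warmup branch only; DeepSpeed also supports a log-rate warmup)."""
    if step < warmup_num_steps:
        # Linear ramp from warmup_min_ratio up toward 1.0 over the warmup steps.
        ratio = step / warmup_num_steps
        return warmup_min_ratio + (1.0 - warmup_min_ratio) * ratio
    # Cosine decay down to cos_min_ratio over the remaining steps,
    # using the "+ 1" step offset added in this commit.
    real_last_step = step - warmup_num_steps + 1
    real_total_steps = total_num_steps - warmup_num_steps
    ratio = (1 + math.cos(math.pi * real_last_step / real_total_steps)) / 2
    return max(0.0, cos_min_ratio + (1.0 - cos_min_ratio) * ratio)

# Hypothetical values matching the new test case (100, 10, 0.1, 0.2):
base_lr = 0.001
lrs = [base_lr * warmup_cosine_ratio(s, 100, 10,
                                     warmup_min_ratio=0.2, cos_min_ratio=0.1)
       for s in range(100)]
print(lrs[0], lrs[99])  # starts at base_lr * 0.2, ends at base_lr * 0.1
```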
11 changes: 6 additions & 5 deletions tests/unit/runtime/test_lr_schedulers.py
@@ -449,6 +449,7 @@ class TestWarmupCosineLR(DistributedTest):

@pytest.mark.parametrize("total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_ratio",
[
(100, 10, 0.1, 0.2),
(200, 20, 0.1, 0.2),
(500, 30, 0.0, 0.2),
(600, 300, 0.1, 0.0),
@@ -490,22 +491,22 @@ def test_lr(self, total_num_steps, warmup_num_steps, cos_min_ratio, warmup_min_r

step_lrs = []
for _, batch in enumerate(data_loader):
step_lrs.extend(lr_scheduler.get_lr())
loss = model(batch[0], batch[1])
model.backward(loss)
model.step()
step_lrs.extend(lr_scheduler.get_lr())

# Verify starting lr
assert step_lrs[0] == opt_lr * warmup_min_ratio
assert abs(step_lrs[0] - opt_lr * warmup_min_ratio) < 1e-7

# Verify peak lr
assert step_lrs[warmup_num_steps] == opt_lr
assert abs(step_lrs[warmup_num_steps - 1] - opt_lr) < 1e-7

# Verify end lr
assert step_lrs[total_num_steps - 1] == opt_lr * cos_min_ratio
assert abs(step_lrs[total_num_steps - 1] - opt_lr * cos_min_ratio) < 1e-7

# Verify increasing phase
_verify_continuous_increase(step_lrs[:warmup_num_steps])

# Verify decreasing phase
_verify_continuous_decrease(step_lrs[warmup_num_steps:])
_verify_continuous_decrease(step_lrs[warmup_num_steps:total_num_steps])
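
Below is a hedged usage sketch, not taken from the repository, showing how the revised scheduler might be driven directly with a plain PyTorch optimizer outside of a DeepSpeed engine, assuming the constructor signature visible in the diff above; the toy model and hyperparameter values are made up for illustration.

```python
import torch
from deepspeed.runtime.lr_schedules import WarmupCosineLR

# Toy model and optimizer (illustrative only).
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# Constructor arguments as shown in the diff above.
scheduler = WarmupCosineLR(optimizer,
                           total_num_steps=100,
                           warmup_min_ratio=0.2,
                           warmup_num_steps=10,
                           cos_min_ratio=0.1)

for step in range(100):
    # ... forward / backward / optimizer.step() would normally go here ...
    scheduler.step()                  # advances last_batch_iteration and sets group lrs
    current_lrs = scheduler.get_lr()  # per-param-group lrs, via the new get_lr()

# Per the diff, checkpointing only tracks last_batch_iteration.
state = scheduler.state_dict()
scheduler.load_state_dict(state)
```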
