
Commit

Merge branch 'master' into cleanup_device
loadams authored Aug 14, 2024
2 parents ef307ea + f994fb2 commit afdc942
Showing 2 changed files with 140 additions and 11 deletions.
22 changes: 11 additions & 11 deletions deepspeed/runtime/engine.py
@@ -306,7 +306,7 @@ def __init__(self,

        if has_optimizer:
            self._configure_optimizer(optimizer, model_parameters)
-            self._configure_lr_scheduler(lr_scheduler)
+            self._configure_lr_scheduler()
            self._report_progress(0)
        elif self.zero_optimization():
            # no optim selected but zero is enabled
@@ -943,19 +943,19 @@ def _optimizer_has_ckpt_event_prologue(self):
    def _optimizer_has_ckpt_event_epilogue(self):
        return self.optimizer is not None and hasattr(self.optimizer, 'checkpoint_event_epilogue')

-    def _configure_lr_scheduler(self, client_lr_scheduler):
-        # First check for scheduler in json configuration
-        lr_scheduler = self._scheduler_from_config(self.optimizer)
-        if lr_scheduler:
-            log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
-            self.lr_scheduler = lr_scheduler
-        else:
-            if isinstance(client_lr_scheduler, Callable):
+    def _configure_lr_scheduler(self):
+        if self.client_lr_scheduler:
+            if isinstance(self.client_lr_scheduler, Callable):
                log_dist('DeepSpeed using client callable to create LR scheduler', ranks=[0])
-                self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)
+                self.lr_scheduler = self.client_lr_scheduler(self.basic_optimizer)
            else:
                log_dist('DeepSpeed using client LR scheduler', ranks=[0])
-                self.lr_scheduler = client_lr_scheduler
+                self.lr_scheduler = self.client_lr_scheduler
+        else:
+            # load lr scheduler from json configuration if lr scheduler is not defined and passed in
+            lr_scheduler = self._scheduler_from_config(self.optimizer)
+            log_dist(f"DeepSpeed using configured LR scheduler = {self.scheduler_name()}", ranks=[0])
+            self.lr_scheduler = lr_scheduler

        log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])
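The net effect of this change: a scheduler supplied by the caller (either an instantiated scheduler or a callable that builds one from the optimizer) now takes precedence, and the scheduler section of the JSON config is consulted only when nothing was passed in. The call site no longer forwards the argument because the caller's scheduler is presumably stored on the engine as self.client_lr_scheduler during construction. A minimal sketch of the two supply paths, assuming a working DeepSpeed install and launcher environment (the model, learning rates, and config values below are illustrative only):

# Sketch, not part of this commit: shows which scheduler wins at deepspeed.initialize().
import torch
import deepspeed
from torch.optim.lr_scheduler import LambdaLR

net = torch.nn.Linear(10, 10)
opt = torch.optim.Adam(net.parameters(), lr=0.01)

# The config also declares a scheduler, but it is only a fallback.
ds_config = {'train_batch_size': 1, 'scheduler': {'type': 'WarmupLR', 'params': {}}}

# Path 1: pass a constructed scheduler (or a callable) -> it is used, WarmupLR is ignored.
client_sched = LambdaLR(opt, lambda epoch: epoch // 10)
_, _, _, sched = deepspeed.initialize(config=ds_config, model=net, optimizer=opt, lr_scheduler=client_sched)

# Path 2: pass lr_scheduler=None -> the WarmupLR declared in ds_config is built instead.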
129 changes: 129 additions & 0 deletions tests/unit/runtime/test_ds_initialize.py
@@ -305,3 +305,132 @@ def _lr_scheduler_callable(optimizer) -> _LRScheduler:
            assert ds_lr_scheduler == client_scheduler
        else:
            assert isinstance(ds_lr_scheduler, LambdaLR)


@pytest.mark.parametrize("scheduler_type", [None, _LRScheduler, Callable])
class TestClientLrSchedulerInit(DistributedTest):
    world_size = 1

    def test_same_lrscheler_and_callable(self, scheduler_type):
        """
        Expected behavior:
        - if an lr scheduler is defined in code and passed to deepspeed.initialize as an argument,
          it is used even when an lr scheduler is also defined in the config.
        - the lr scheduler is initialized from the config only when no lr scheduler is defined in code.
        """

        def _my_lambda(epoch):
            return epoch // 10

        def _lr_scheduler_callable(optimizer) -> _LRScheduler:
            return LambdaLR(optimizer, _my_lambda)

        config_dict = {'train_batch_size': 1}

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        if scheduler_type is None:
            config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}}
            client_scheduler = None
        elif scheduler_type == _LRScheduler:
            client_scheduler = LambdaLR(client_optimizer, _my_lambda)
        else:
            client_scheduler = _lr_scheduler_callable

        _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict,
                                                        model=model,
                                                        model_parameters=list(model.parameters()),
                                                        optimizer=client_optimizer,
                                                        lr_scheduler=client_scheduler)
        if scheduler_type is None:
            # in this case, we initialize from the config
            assert not isinstance(ds_lr_scheduler, LambdaLR)
            assert isinstance(ds_lr_scheduler, WarmupLR)
        else:
            # in this case, we initialize from the passed-in scheduler
            assert isinstance(ds_lr_scheduler, LambdaLR)
            assert not isinstance(ds_lr_scheduler, WarmupLR)

    def test_diff_lrscheler_and_callable(self, scheduler_type):
        """
        In this test, LambdaLR is used for the _LRScheduler type
        and StepLR is used for the callable type.
        """

        from torch.optim.lr_scheduler import StepLR

        def _my_lambda(epoch):
            return epoch // 10

        def _lr_scheduler_callable(optimizer) -> _LRScheduler:
            return StepLR(optimizer, step_size=30)

        config_dict = {'train_batch_size': 1}

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        if scheduler_type is None:
            config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}}
            client_scheduler = None
        elif scheduler_type == _LRScheduler:
            client_scheduler = LambdaLR(client_optimizer, _my_lambda)
        else:
            client_scheduler = _lr_scheduler_callable

        _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict,
                                                        model=model,
                                                        model_parameters=list(model.parameters()),
                                                        optimizer=client_optimizer,
                                                        lr_scheduler=client_scheduler)
        if scheduler_type is None:
            assert isinstance(ds_lr_scheduler, WarmupLR)
        elif scheduler_type == _LRScheduler:
            assert isinstance(ds_lr_scheduler, LambdaLR)
        else:
            # callable
            assert isinstance(ds_lr_scheduler, StepLR)

    def test_diff_lrscheler_and_callable_onecyclelr_steplr(self, scheduler_type):

        from deepspeed.runtime.lr_schedules import OneCycle, ONE_CYCLE, CYCLE_MIN_LR, CYCLE_MAX_LR
        from torch.optim.lr_scheduler import OneCycleLR, StepLR

        def _lr_scheduler_callable(optimizer) -> _LRScheduler:
            return OneCycleLR(optimizer, max_lr=0.01, total_steps=200)

        config_dict = {'train_batch_size': 1}

        hidden_dim = 10
        model = SimpleModel(hidden_dim)

        client_optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

        if scheduler_type is None:
            config_dict['scheduler'] = {'type': ONE_CYCLE, 'params': {CYCLE_MIN_LR: 0, CYCLE_MAX_LR: 0.1}}
            client_scheduler = None
        elif scheduler_type == _LRScheduler:
            client_scheduler = StepLR(client_optimizer, step_size=30)
        else:
            client_scheduler = _lr_scheduler_callable

        _, _, _, ds_lr_scheduler = deepspeed.initialize(config=config_dict,
                                                        model=model,
                                                        model_parameters=list(model.parameters()),
                                                        optimizer=client_optimizer,
                                                        lr_scheduler=client_scheduler)
        if scheduler_type is None:
            assert isinstance(ds_lr_scheduler, OneCycle)
        elif scheduler_type == _LRScheduler:
            assert isinstance(ds_lr_scheduler, StepLR)
        else:
            # callable
            assert isinstance(ds_lr_scheduler, OneCycleLR)
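A detail the callable case above exercises: the engine invokes the callable with the optimizer it actually built (self.basic_optimizer), so this form also works when the optimizer comes from the DeepSpeed config rather than from the caller. A minimal sketch under that assumption (the optimizer and config values are illustrative only, and a DeepSpeed launcher/distributed environment is assumed):

# Sketch, not part of this commit: a callable scheduler paired with a config-defined optimizer.
import torch
import deepspeed
from torch.optim.lr_scheduler import LambdaLR

def make_scheduler(optimizer):
    # DeepSpeed calls this with the optimizer it constructed from ds_config,
    # so the schedule attaches to the right object even though we never build it ourselves.
    return LambdaLR(optimizer, lambda epoch: epoch // 10)

ds_config = {
    'train_batch_size': 1,
    'optimizer': {'type': 'Adam', 'params': {'lr': 0.001}},
}

net = torch.nn.Linear(10, 10)
_, _, _, sched = deepspeed.initialize(config=ds_config,
                                      model=net,
                                      model_parameters=net.parameters(),
                                      lr_scheduler=make_scheduler)
assert isinstance(sched, LambdaLR)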
