Support Callable type for client optimizer and lr_scheduler (#1316)
* Callable option for optimizer and scheduler

* Add unit test

* Formatting

* Disable debug prints

* Use base optimizer to construct lr scheduler

* Formatting

* Remove dead import
tjruwase authored Aug 25, 2021
1 parent aa12129 commit 274c375
Showing 4 changed files with 203 additions and 28 deletions.
29 changes: 17 additions & 12 deletions deepspeed/__init__.py
@@ -3,13 +3,16 @@
'''
import sys
import types

from typing import Optional, Union
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from packaging import version as pkg_version

from . import ops
from . import module_inject

from .runtime.engine import DeepSpeedEngine
from .runtime.engine import DeepSpeedEngine, DeepSpeedOptimizerCallable, DeepSpeedSchedulerCallable
from .runtime.engine import ADAM_OPTIMIZER, LAMB_OPTIMIZER
from .runtime.pipe.engine import PipelineEngine
from .inference.engine import InferenceEngine
@@ -56,13 +59,15 @@ def _parse_version(version_str):


def initialize(args=None,
model=None,
optimizer=None,
model_parameters=None,
training_data=None,
lr_scheduler=None,
model: torch.nn.Module = None,
optimizer: Optional[Union[Optimizer,
DeepSpeedOptimizerCallable]] = None,
model_parameters: Optional[torch.nn.Module] = None,
training_data: Optional[torch.utils.data.Dataset] = None,
lr_scheduler: Optional[Union[_LRScheduler,
DeepSpeedSchedulerCallable]] = None,
mpu=None,
dist_init_required=None,
dist_init_required: Optional[bool] = None,
collate_fn=None,
config=None,
config_params=None):
@@ -74,16 +79,16 @@ def initialize(args=None,
model: Required: nn.module class before apply any wrappers
optimizer: Optional: a user defined optimizer, this is typically used instead of defining
an optimizer in the DeepSpeed json config.
optimizer: Optional: a user defined Optimizer or Callable that returns an Optimizer object.
This overrides any optimizer definition in the DeepSpeed json config.
model_parameters: Optional: An iterable of torch.Tensors or dicts.
Specifies what Tensors should be optimized.
training_data: Optional: Dataset of type torch.utils.data.Dataset
lr_scheduler: Optional: Learning Rate Scheduler Object. It should define a get_lr(),
step(), state_dict(), and load_state_dict() methods
lr_scheduler: Optional: Learning Rate Scheduler Object or a Callable that takes an Optimizer and returns a Scheduler object.
The scheduler object should define a get_lr(), step(), state_dict(), and load_state_dict() methods
mpu: Optional: A model parallelism unit object that implements
get_{model,data}_parallel_{rank,group,world_size}()
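For context, here is a minimal sketch of how a client could use the new Callable support in deepspeed.initialize. The MyModel class is hypothetical, and the config is shown as a config_params dict purely for brevity; pass your DeepSpeed config however your setup normally does.

import torch
import deepspeed
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

model = MyModel()  # hypothetical torch.nn.Module

def make_optimizer(params) -> AdamW:
    # Called by DeepSpeed with the model parameters (or param groups).
    return AdamW(params, lr=1e-4)

def make_scheduler(optimizer) -> LambdaLR:
    # Called by DeepSpeed with the basic optimizer it constructed above.
    return LambdaLR(optimizer, lr_lambda=lambda step: min(1.0, (step + 1) / 1000))

engine, ds_optimizer, _, ds_scheduler = deepspeed.initialize(
    model=model,
    model_parameters=list(model.parameters()),
    optimizer=make_optimizer,      # a DeepSpeedOptimizerCallable
    lr_scheduler=make_scheduler,   # a DeepSpeedSchedulerCallable
    config_params={'train_batch_size': 1})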
61 changes: 48 additions & 13 deletions deepspeed/runtime/engine.py
@@ -13,9 +13,14 @@
from shutil import copyfile

from torch.nn.modules import Module
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from torch.distributed.distributed_c10d import _get_global_rank
from tensorboardX import SummaryWriter

from typing import Callable, Dict, Optional, Union, Iterable

from deepspeed.runtime.utils import see_memory_usage
from deepspeed.runtime.zero.stage2 import FP16_DeepSpeedZeroOptimizer
from deepspeed.runtime.zero.stage1 import FP16_DeepSpeedZeroOptimizer_Stage1
@@ -57,6 +62,10 @@

MEMORY_OPT_ALLREDUCE_SIZE = 500000000

DeepSpeedOptimizerCallable = \
Callable[[Union[Iterable[Parameter], Dict[str, Iterable]]], Optimizer]
DeepSpeedSchedulerCallable = Callable[[Optimizer], _LRScheduler]

try:
import apex
from apex import amp
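As a sketch of what these aliases describe: an optimizer callable maps a parameter iterable (or a param-group style mapping) to a torch Optimizer, and a scheduler callable maps that Optimizer to an _LRScheduler. The SGD and StepLR choices below are arbitrary examples.

from torch.optim import SGD, Optimizer
from torch.optim.lr_scheduler import StepLR, _LRScheduler

def opt_callable(params) -> Optimizer:
    # Compatible with DeepSpeedOptimizerCallable.
    return SGD(params, lr=0.01, momentum=0.9)

def sched_callable(optimizer: Optimizer) -> _LRScheduler:
    # Compatible with DeepSpeedSchedulerCallable.
    return StepLR(optimizer, step_size=10, gamma=0.5)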
@@ -198,6 +207,7 @@ def __init__(self,

# Configure optimizer and scheduler
self.optimizer = None
self.basic_optimizer = None
self.lr_scheduler = None
if model_parameters or optimizer:
self._configure_optimizer(optimizer, model_parameters)
@@ -536,9 +546,15 @@ def _configure_lr_scheduler(self, client_lr_scheduler):
f'DeepSpeed using configured LR scheduler = {self.scheduler_name()}')
self.lr_scheduler = lr_scheduler
else:
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
if isinstance(client_lr_scheduler, _LRScheduler):
if self.global_rank == 0:
logger.info('DeepSpeed using client LR scheduler')
self.lr_scheduler = client_lr_scheduler
elif isinstance(client_lr_scheduler, Callable):
if self.global_rank == 0:
logger.info('DeepSpeed using client callable to create LR scheduler')
self.lr_scheduler = client_lr_scheduler(self.basic_optimizer)

log_dist(f'DeepSpeed LR Scheduler = {self.lr_scheduler}', ranks=[0])

def _configure_checkpointing(self, dist_init_required):
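Note that a client scheduler callable is invoked with self.basic_optimizer, i.e. the plain torch optimizer DeepSpeed built (matching the "Use base optimizer to construct lr scheduler" item in the commit message), not a later FP16/ZeRO wrapper. A hypothetical client-side callable therefore only ever needs to handle a standard torch.optim.Optimizer:

import torch

def make_scheduler(optimizer):
    # `optimizer` is the basic optimizer (a client instance, the result of a client
    # callable, or the one built from the json config), so any standard scheduler
    # can bind to it directly.
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: epoch // 10)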
@@ -644,6 +660,9 @@ def _is_supported_optimizer(self, optimizer_name):

# Validate configuration based on command line arguments
def _do_sanity_check(self):
assert isinstance(self.client_optimizer, (type(None), Optimizer, Callable)), \
f'Client Optimizer is of unexpected type {type(self.client_optimizer)}'

if not self.client_optimizer:
if self.optimizer_name() is not None:
assert self._is_supported_optimizer(self.optimizer_name()), \
Expand All @@ -654,6 +673,14 @@ def _do_sanity_check(self):
assert self.dynamic_loss_scale(), \
'DeepSpeed {} optimizer requires dynamic loss scaling'.format(self.optimizer_name())

assert isinstance(self.client_lr_scheduler, (type(None), _LRScheduler, Callable)), \
f'Client LR Scheduler is of unexpected type {type(self.client_lr_scheduler)}'

# Detect invalid combinations of client optimizer and client scheduler
if isinstance(self.client_lr_scheduler, _LRScheduler):
assert isinstance(self.client_optimizer, Optimizer), \
f'Client Optimizer (type = {type(self.client_optimizer)} is not instantiated but Client LR Scheduler is instantiated'

def _broadcast_model(self):
def is_replicated(p):
if hasattr(p, 'ds_status') and p.ds_status is not ZeroParamStatus.AVAILABLE:
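The invalid pairing rejected above is an already-instantiated scheduler combined with a client optimizer that is not itself an Optimizer instance (None or a callable), since such a scheduler would be bound to the wrong optimizer. A sketch of the rejected call, with a hypothetical model and the config again abbreviated to a dict:

import torch
import deepspeed

# This scheduler is bound to a throwaway Adam, not to the optimizer the callable
# would later create inside DeepSpeed, so _do_sanity_check raises an AssertionError.
tmp_opt = torch.optim.Adam(model.parameters())
bad_scheduler = torch.optim.lr_scheduler.LambdaLR(tmp_opt, lambda epoch: epoch // 10)

deepspeed.initialize(model=model,
                     model_parameters=list(model.parameters()),
                     optimizer=lambda params: torch.optim.AdamW(params),
                     lr_scheduler=bad_scheduler,
                     config_params={'train_batch_size': 1})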
@@ -771,18 +798,23 @@ def ids_list(group):

# Configure optimizer
def _configure_optimizer(self, client_optimizer, model_parameters):

if client_optimizer is not None:
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
if self.global_rank == 0:
logger.info(
"Removing param_group that has no 'params' in the client Optimizer")
if isinstance(client_optimizer, Optimizer):
client_optimizer.param_groups[:] = [
pg for pg in client_optimizer.param_groups if len(pg["params"]) != 0
]
if self.global_rank == 0:
logger.info(
"Removing param_group that has no 'params' in the client Optimizer"
)

basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
basic_optimizer = client_optimizer
if self.global_rank == 0:
logger.info('Using client Optimizer as basic optimizer')
else:
basic_optimizer = client_optimizer(model_parameters)
if self.global_rank == 0:
logger.info('Using client callable to create basic optimizer')
else:
basic_optimizer = self._configure_basic_optimizer(model_parameters)
if self.global_rank == 0:
@@ -792,6 +824,7 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

self._check_for_duplicates(basic_optimizer)

self.basic_optimizer = basic_optimizer
if self.global_rank == 0:
logger.info('DeepSpeed Basic Optimizer = {}'.format(
basic_optimizer.__class__.__name__))
@@ -832,6 +865,8 @@ def _configure_optimizer(self, client_optimizer, model_parameters):

def _configure_basic_optimizer(self, model_parameters):
optimizer_parameters = self.optimizer_params()
if optimizer_parameters is None:
optimizer_parameters = {}
# print(optimizer_parameters.keys())
if 'max_grad_norm' in optimizer_parameters.keys():
raise ValueError(
6 changes: 3 additions & 3 deletions deepspeed/runtime/zero/stage2.py
@@ -261,7 +261,7 @@ def __init__(self,
see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.fp16_groups[i])
see_memory_usage(f"After moving param group {i} to CPU", force=True)
see_memory_usage(f"After moving param group {i} to CPU", force=False)

# Reorder group parameters for load balancing of gradient partitioning during backward among ranks.
# This ensures that gradients are reduced in a fashion such that ownership round robins among the ranks.
@@ -286,12 +286,12 @@ def __init__(self,
dist.get_world_size(group=self.real_dp_process_group[i])).cuda(
torch.cuda.current_device()))
see_memory_usage(f"After flattening and moving param group {i} to GPU",
force=True)
force=False)

if dist.get_rank(group=self.real_dp_process_group[i]) == 0:
see_memory_usage(
f"After Flattening and after emptying param group {i} cache",
force=True)
force=False)

# set model fp16 weight to slices of flattened buffer
self._update_model_fp16_weights(i)
135 changes: 135 additions & 0 deletions tests/unit/test_ds_initialize.py
@@ -0,0 +1,135 @@
import pytest
from typing import Callable
import torch
from torch.optim import Optimizer, Adam, AdamW
from torch.optim.lr_scheduler import _LRScheduler, LambdaLR

from simple_model import args_from_dict, SimpleModel
from common import distributed_test

import deepspeed
from deepspeed.ops.adam import FusedAdam
from deepspeed.runtime.lr_schedules import WARMUP_LR, WarmupLR
from deepspeed.runtime.config import ADAM_OPTIMIZER


@pytest.mark.parametrize('optimizer_type', [None, Optimizer, Callable])
def test_client_optimizer(tmpdir, optimizer_type):
def _optimizer_callable(params) -> Optimizer:
return AdamW(params=params)

hidden_dim = 10
model = SimpleModel(hidden_dim)

config_dict = {'train_batch_size': 1}
if optimizer_type is None:
client_optimizer = None
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = Adam(model.parameters())
else:
client_optimizer = _optimizer_callable

args = args_from_dict(tmpdir, config_dict)

@distributed_test(world_size=[1])
def _test_client_optimizer(args, model, client_optimizer):
_, ds_optimizer, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=client_optimizer)
if client_optimizer is None:
assert isinstance(ds_optimizer, FusedAdam)
elif isinstance(client_optimizer, Optimizer):
assert ds_optimizer == client_optimizer
else:
assert isinstance(ds_optimizer, AdamW)

_test_client_optimizer(args=args, model=model, client_optimizer=client_optimizer)


@pytest.mark.parametrize('scheduler_type, optimizer_type',
[(None,
None),
(None,
Optimizer),
(None,
Callable),
(_LRScheduler,
None),
(_LRScheduler,
Optimizer),
(_LRScheduler,
Callable),
(Callable,
None),
(Callable,
Optimizer),
(Callable,
Callable)])
def test_client_lr_scheduler(tmpdir, scheduler_type, optimizer_type):
def _my_lambda(epoch):
return epoch // 10

def _optimizer_callable(params) -> Optimizer:
return torch.optim.AdamW(params=params)

def _lr_scheduler_callable(optimizer) -> _LRScheduler:
return LambdaLR(optimizer, _my_lambda)

hidden_dim = 10
model = SimpleModel(hidden_dim)

config_dict = {'train_batch_size': 1}

client_optimizer = None
client_scheduler = None

if optimizer_type is None:
config_dict['optimizer'] = {'type': ADAM_OPTIMIZER}
elif optimizer_type is Optimizer:
client_optimizer = torch.optim.Adam(model.parameters())
else:
client_optimizer = _optimizer_callable

if scheduler_type is None:
config_dict['scheduler'] = {'type': WARMUP_LR, 'params': {}}
elif scheduler_type == _LRScheduler:
if isinstance(client_optimizer, Optimizer):
client_scheduler = LambdaLR(client_optimizer, _my_lambda)
else:
# Verify invalid combination is correctly handled
client_scheduler = LambdaLR(torch.optim.Adam(model.parameters()), _my_lambda)
else:
client_scheduler = _lr_scheduler_callable

args = args_from_dict(tmpdir, config_dict)

@distributed_test(world_size=[1])
def _test_client_lr_scheduler(args, model, optimizer, lr_scheduler):
if isinstance(lr_scheduler,
_LRScheduler) and not isinstance(optimizer,
Optimizer):
with pytest.raises(AssertionError):
_, _, _, _ = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=optimizer,
lr_scheduler=lr_scheduler)
else:
_, _, _, ds_lr_scheduler = deepspeed.initialize(args=args,
model=model,
model_parameters=list(model.parameters()),
optimizer=optimizer,
lr_scheduler=lr_scheduler)
if lr_scheduler is None:
assert isinstance(ds_lr_scheduler, WarmupLR)
elif isinstance(lr_scheduler, _LRScheduler):
assert ds_lr_scheduler == lr_scheduler
else:
assert isinstance(ds_lr_scheduler, LambdaLR)

_test_client_lr_scheduler(args=args,
model=model,
optimizer=client_optimizer,
lr_scheduler=client_scheduler)
