BF16 optimizer: Improve device utilization by immediate grad update
Enable gradient accumulation hooks in the BF16 optimizer so that the fp32
gradients are updated as soon as each gradient becomes available.

This improves device utilization on some backends by parallelizing the
workload across compute engines.

To enable the feature (disabled by default), set the new config flag
"accumulate_grads_via_hooks" under the "bf16" section of the DeepSpeed
config.json.
  Example:
  "bf16": {
    "enabled": true,
    "accumulate_grads_via_hooks": true
  }
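
For illustration, a minimal sketch of enabling the flag programmatically (not part of the commit; the optimizer choice, batch size, and the rest of the training setup are placeholders, and bf16 support on the target accelerator is assumed):

  import torch
  import deepspeed

  ds_config = {
      "train_micro_batch_size_per_gpu": 1,
      "bf16": {
          "enabled": True,
          "accumulate_grads_via_hooks": True,  # flag introduced by this commit
      },
      "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
  }

  model = torch.nn.Linear(16, 16)
  engine, optimizer, _, _ = deepspeed.initialize(model=model,
                                                 model_parameters=model.parameters(),
                                                 config=ds_config)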
deepcharm committed Dec 18, 2023
1 parent 4d866bd commit 1272438
Showing 5 changed files with 66 additions and 18 deletions.
66 changes: 50 additions & 16 deletions deepspeed/runtime/bf16_optimizer.py
@@ -38,7 +38,8 @@ def __init__(self,
allgather_bucket_size=5000000000,
dp_process_group=None,
timers=None,
grad_acc_dtype=None):
grad_acc_dtype=None,
accumulate_grads_via_hooks=False):
super().__init__()
see_memory_usage('begin bf16_optimizer', force=True)
self.timers = timers
@@ -49,6 +50,7 @@ def __init__(self,
assert grad_acc_dtype in [torch.float32, torch.bfloat16
], f"BF16Optimizer: Unsupported gradient accumulation data type: {grad_acc_dtype}"
self.grad_acc_dtype = grad_acc_dtype
self.accumulate_grads_via_hooks = accumulate_grads_via_hooks

self.clip_grad = clip_grad
self.norm_type = norm_type
@@ -162,6 +164,9 @@ def _setup_for_real_optimizer(self):
self.initialize_optimizer_states()
see_memory_usage('end initialize_optimizer', force=True)

if self.accumulate_grads_via_hooks:
self.create_grad_acc_hooks()

# Need optimizer states initialized before linking lp to optimizer state
self._link_all_hp_params()
self._enable_universal_checkpoint()
@@ -276,27 +281,34 @@ def backward(self, loss, update_hp_grads=True, clear_lp_grads=False, **bwd_kwarg
self.clear_lp_grads()
loss.backward(**bwd_kwargs)

if update_hp_grads:
if not self.accumulate_grads_via_hooks and update_hp_grads:
self.update_hp_grads(clear_lp_grads=clear_lp_grads)

@torch.no_grad()
def update_hp_grads(self, clear_lp_grads=False):
for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
if lp.grad is None:
continue
def update_hp_grad(self, lp, group_idx, param_idx, clear_lp_grads):
if lp.grad is None:
return

hp_grad = self.fp32_groups_gradients[i][j]
assert hp_grad is not None, \
f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{i}][{j}]'
hp_grad = self.fp32_groups_gradients[group_idx][param_idx]
assert hp_grad is not None, \
f'high precision param has no gradient, lp param_id = {id(lp)} group_info = [{group_idx}][{param_idx}]'

hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
lp._hp_grad = hp_grad
self.fp32_groups_has_gradients[i][j] = True
hp_grad.data.add_(lp.grad.data.to(hp_grad.dtype).view(hp_grad.shape))
lp._hp_grad = hp_grad
self.fp32_groups_has_gradients[group_idx][param_idx] = True

# clear gradients
if clear_lp_grads:
lp.grad = None
# clear gradients
if clear_lp_grads:
lp.grad = None

@torch.no_grad()
def update_hp_grads(self, clear_lp_grads=False):
if self.accumulate_grads_via_hooks:
return

for i, group in enumerate(self.bf16_groups):
for j, lp in enumerate(group):
self.update_hp_grad(lp, i, j, clear_lp_grads)

@torch.no_grad()
def get_grads_for_reduction(self):
@@ -426,6 +438,28 @@ def _load_hp_checkpoint_state(self, checkpoint_dir):
lp.load_hp_checkpoint_state(os.path.join(checkpoint_dir, self.param_names[lp]), tp_rank,
tp_world_size)

def accumulate_hp_grads_and_remove_lp(self, lp_param, group_idx, param_idx):
assert self.accumulate_grads_via_hooks
self.update_hp_grad(lp_param, group_idx, param_idx, clear_lp_grads=False)

def create_grad_acc_hooks(self):
self.grad_accs = []
for i, param_group in enumerate(self.bf16_groups):
for j, param in enumerate(param_group):
if param.requires_grad:

def wrapper(param, i, j):
param_tmp = param.expand_as(param)
grad_acc = param_tmp.grad_fn.next_functions[0][0]

def accumulate_hp_grads_and_remove_lp(*notneeded):
self.accumulate_hp_grads_and_remove_lp(param, i, j)

grad_acc.register_hook(accumulate_hp_grads_and_remove_lp)
self.grad_accs.append(grad_acc)

wrapper(param, i, j)


def _get_padded_tensor(src_tensor, size):
if src_tensor.numel() >= size:
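Aside: the create_grad_acc_hooks code above relies on a common PyTorch idiom for reacting to per-parameter gradient accumulation: a throwaway expand_as view exposes the parameter's AccumulateGrad node, and a hook registered on that node fires during backward() once the parameter's .grad has been accumulated. A self-contained sketch of the same idiom outside DeepSpeed (names such as fp32_grads are illustrative only):

  import torch

  model = torch.nn.Linear(4, 2).bfloat16()
  # illustrative fp32 shadow gradients, keyed by parameter
  fp32_grads = {p: torch.zeros_like(p, dtype=torch.float32) for p in model.parameters()}
  grad_accs = []  # keep the AccumulateGrad nodes alive so their hooks are not collected

  for param in model.parameters():
      if param.requires_grad:
          # expand_as(param) returns a view whose grad_fn leads back to AccumulateGrad
          param_tmp = param.expand_as(param)
          grad_acc = param_tmp.grad_fn.next_functions[0][0]

          def make_hook(p):
              def hook(*unused):
                  # p.grad has been written by the time this hook runs
                  fp32_grads[p].add_(p.grad.float())
              return hook

          grad_acc.register_hook(make_hook(param))
          grad_accs.append(grad_acc)

  loss = model(torch.randn(3, 4).bfloat16()).sum()
  loss.backward()  # each hook runs as soon as its parameter's gradient is ready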
8 changes: 8 additions & 0 deletions deepspeed/runtime/config.py
@@ -168,6 +168,13 @@ def get_bfloat16_enabled(param_dict):
return False


def get_bfloat16_accumulate_grads_via_hooks(param_dict):
for key in [BFLOAT16, BFLOAT16_OLD]:
if key in param_dict.keys():
return get_scalar_param(param_dict[key], BFLOAT16_GRAD_ACC_VIA_HOOKS, BFLOAT16_GRAD_ACC_VIA_HOOKS_DEFAULT)
return False


def get_fp16_master_weights_and_grads_enabled(param_dict):
if get_fp16_enabled(param_dict):
return get_scalar_param(param_dict[FP16], FP16_MASTER_WEIGHTS_AND_GRADS, FP16_MASTER_WEIGHTS_AND_GRADS_DEFAULT)
@@ -813,6 +820,7 @@ def _initialize_params(self, param_dict):
self.fp16_enabled = get_fp16_enabled(param_dict)
self.fp16_auto_cast = get_fp16_auto_cast(param_dict)
self.bfloat16_enabled = get_bfloat16_enabled(param_dict)
self.bfloat16_accumulate_grads_via_hooks = get_bfloat16_accumulate_grads_via_hooks(param_dict)
assert not (self.fp16_enabled
and self.bfloat16_enabled), 'bfloat16 and fp16 modes cannot be simultaneously enabled'
self.fp16_master_weights_and_gradients = get_fp16_master_weights_and_grads_enabled(param_dict)
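Assuming get_scalar_param behaves like a dict lookup with a fallback default (consistent with its use elsewhere in this file), the new getter resolves the flag roughly as follows; the calls below are illustrative, not tests from the commit:

  # Available after this commit:
  from deepspeed.runtime.config import get_bfloat16_accumulate_grads_via_hooks

  get_bfloat16_accumulate_grads_via_hooks({"bf16": {"accumulate_grads_via_hooks": True}})  # True
  get_bfloat16_accumulate_grads_via_hooks({"bf16": {"enabled": True}})                     # False (default)
  get_bfloat16_accumulate_grads_via_hooks({"fp16": {"enabled": True}})                     # False (no bf16/bfloat16 section)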
4 changes: 4 additions & 0 deletions deepspeed/runtime/constants.py
@@ -126,6 +126,10 @@
BFLOAT16_ENABLED = "enabled"
BFLOAT16_ENABLED_DEFAULT = False

# BFLOAT16 optimizer gradient accumulation via hooks
BFLOAT16_GRAD_ACC_VIA_HOOKS = "accumulate_grads_via_hooks"
BFLOAT16_GRAD_ACC_VIA_HOOKS_DEFAULT = False

#########################################
# FP16 support
#########################################
3 changes: 2 additions & 1 deletion deepspeed/runtime/engine.py
Expand Up @@ -1451,7 +1451,8 @@ def _configure_bf16_optimizer(self, optimizer):
allgather_bucket_size=self.zero_allgather_bucket_size(),
dp_process_group=self.seq_data_parallel_group,
timers=timers,
grad_acc_dtype=self.get_data_types()[1])
grad_acc_dtype=self.get_data_types()[1],
accumulate_grads_via_hooks=self._config.bfloat16_accumulate_grads_via_hooks)

return optimizer

3 changes: 2 additions & 1 deletion deepspeed/runtime/pipe/engine.py
Expand Up @@ -766,7 +766,8 @@ def _exec_backward_pass(self, buffer_id):

if self.using_bf16_optimizer and not self.is_last_stage():
# manually call because we don't call optimizer.backward()
self.optimizer.update_hp_grads(clear_lp_grads=False)
if not self._config.bfloat16_accumulate_grads_via_hooks:
self.optimizer.update_hp_grads(clear_lp_grads=False)

# Free up the memory from the output of forward()
self.pipe_buffers['output_tensors'][buffer_id] = None