uniform deepspeed overflow check #5424

Open. Wants to merge 2 commits into base: master.
22 changes: 19 additions & 3 deletions deepspeed/runtime/bf16_optimizer.py
@@ -13,9 +13,10 @@
from packaging import version as pkg_version
from deepspeed.git_version_info import version
from deepspeed.runtime.utils import (get_global_norm_of_tensors, clip_tensors_by_global_norm, DummyOptim,
align_dense_tensors, all_gather_dp_groups, is_model_parallel_parameter,
see_memory_usage, graph_process, get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups
CheckOverflow, align_dense_tensors, all_gather_dp_groups,
is_model_parallel_parameter, see_memory_usage, graph_process,
get_norm_with_moe_layers)
from deepspeed.utils import link_hp_params, lazy_init_hp_params_optimizer_state, fragment_address, groups, logger
from deepspeed.moe.utils import is_moe_param, is_moe_param_group
from deepspeed.utils.bwc import bwc_tensor_model_parallel_rank
from deepspeed.checkpoint import enable_universal_checkpoint
@@ -31,6 +32,7 @@ class BF16_Optimizer(ZeROOptimizer):
def __init__(self,
init_optimizer,
param_names,
deepspeed=None,
mpu=None,
clip_grad=0.0,
norm_type=2,
@@ -92,6 +94,10 @@ def __init__(self,
if self.using_real_optimizer:
self._setup_for_real_optimizer()

# Overflow check init
self.overflow = False
Reviewer comment (Contributor): Should self.overflow be a class member since it seems to be only used once? [A sketch of that alternative follows this file's diff.]

self.overflow_checker = CheckOverflow(self.bf16_groups, mpu=self.mpu, deepspeed=deepspeed)

see_memory_usage('end bf16_optimizer', force=True)

def _configure_moe_settings(self):
@@ -280,6 +286,16 @@ def step(self, closure=None):
if closure is not None:
raise NotImplementedError(f'{self.__class__} does not support closure.')

self.overflow = self.overflow_checker.check()
if self.overflow:
logger.warning(f"all_groups_norm Overflow in BF16_Optimizer. Skipping step.")
see_memory_usage('After overflow before clearing gradients')
self.clear_hp_grads()
see_memory_usage('After overflow after clearing gradients')
#TODO: add timer
#TODO: save ckpt, then crash
return

non_expert_grads_for_norm, expert_grads_for_norm = self.get_grads_for_norm()
non_expert_groups_norm = get_global_norm_of_tensors(input_tensors=non_expert_grads_for_norm,
mpu=self.mpu,
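Following up on the reviewer question above about self.overflow: a minimal sketch of keeping the flag local to step(), under the assumption that nothing outside step() reads optimizer.overflow (if the engine or logging code does read it, the attribute has to stay). This illustrates the suggestion and is not code from the PR.

# Hypothetical variant of BF16_Optimizer.step(): the overflow flag stays local.
def step(self, closure=None):
    if closure is not None:
        raise NotImplementedError(f'{self.__class__} does not support closure.')

    overflow = self.overflow_checker.check()  # local variable instead of self.overflow
    if overflow:
        logger.warning("Overflow detected in BF16_Optimizer; skipping step.")
        self.clear_hp_grads()
        return
    # ... gradient-norm computation, clipping and the underlying optimizer step follow ...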
2 changes: 2 additions & 0 deletions deepspeed/runtime/engine.py
@@ -1473,6 +1473,7 @@ def _configure_bf16_optimizer(self, optimizer):
timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
optimizer = BF16_Optimizer(optimizer,
self.param_names,
deepspeed=self,
Reviewer comment (Contributor): Don't pass self into submodules. [A sketch of the suggested alternative follows this file's diff.]

mpu=self.mpu,
clip_grad=clip_grad,
allgather_bucket_size=self.zero_allgather_bucket_size(),
@@ -1532,6 +1533,7 @@ def _configure_zero_optimizer(self, optimizer):
overlap_comm=overlap_comm,
offload_optimizer_config=self.zero_offload_optimizer(),
mpu=self.mpu,
deepspeed=self,
postscale_gradients=self.postscale_gradients(),
gradient_predivide_factor=self.gradient_predivide_factor(),
gradient_accumulation_steps=self.gradient_accumulation_steps(),
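The two hunks above wire the engine into the optimizers via deepspeed=self, which is what the reviewer comment objects to. As an illustration of the suggested alternative (not the PR's code), the call site could pass only the flag the overflow check needs; the enable_backward_allreduce keyword below is a hypothetical parameter, not part of the current BF16_Optimizer signature.

# Hypothetical call site: hand the submodule the specific attribute it needs
# rather than a reference to the whole engine.
optimizer = BF16_Optimizer(optimizer,
                           self.param_names,
                           enable_backward_allreduce=self.enable_backward_allreduce,
                           mpu=self.mpu,
                           clip_grad=clip_grad,
                           allgather_bucket_size=self.zero_allgather_bucket_size())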
10 changes: 6 additions & 4 deletions deepspeed/runtime/utils.py
@@ -181,12 +181,13 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group):
class CheckOverflow(object):
'''Checks for overflow in gradient across parallel process'''

def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None):
def __init__(self, param_groups=None, mpu=None, zero_reduce_scatter=False, deepspeed=None, partition_grads=False):
Reviewer comment (Contributor): Passing the deepspeed engine into a submodule is not a good design and creates all sorts of cyclic reference issues. It is better to pass the specific attributes that are needed, such as enable_backward_allreduce. [A sketch of this alternative follows this file's diff.]

self.mpu = mpu
self.params = [] if param_groups else None
self.zero_reduce_scatter = zero_reduce_scatter
self.deepspeed = deepspeed
self.has_moe_params = False
self.partition_grads = partition_grads
if param_groups:
for group in param_groups:
for param in group:
@@ -234,7 +235,7 @@ def check(self, param_groups=None):
# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
for i, p in enumerate(params):
if p.grad is not None and self._has_inf_or_nan(p.grad.data, i):
if p.grad is not None and self._has_inf_or_nan(p.grad.data):
return True
return False

@@ -261,15 +262,16 @@ def has_overflow(self, params, has_moe_params=None):
not using_pipeline and self.deepspeed.enable_backward_allreduce is False):
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_data_parallel_group())
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.mpu.get_model_parallel_group())
elif self.deepspeed is not None and self.deepspeed.enable_backward_allreduce is False:
elif self.deepspeed is not None and (self.deepspeed.enable_backward_allreduce is False
or self.partition_grads is True):
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=dist.get_world_group())

overflow = overflow_gpu[0].item()
return bool(overflow)

# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, i):
def _has_inf_or_nan(x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
Expand Down
71 changes: 16 additions & 55 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -12,7 +12,7 @@
from deepspeed.runtime.base_optimizer import ZeROOptimizer
from deepspeed.runtime.fp16.loss_scaler import CreateLossScaler
from deepspeed.runtime.utils import (empty_cache, see_memory_usage, inf, is_model_parallel_parameter,
align_dense_tensors, all_gather_dp_groups)
align_dense_tensors, all_gather_dp_groups, CheckOverflow)
from deepspeed.runtime.zero.config import ZeroStageEnum
from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum
from deepspeed.ops.adam import DeepSpeedCPUAdam
@@ -124,6 +124,7 @@ def __init__(self,
overlap_comm=False,
offload_optimizer_config=None,
mpu=None,
deepspeed=None,
clip_grad=0.0,
gradient_accumulation_dtype=torch.float32,
communication_data_type=torch.float16,
@@ -218,7 +219,6 @@ def __init__(self,
self.model_parallel_world_size = mpu.get_model_parallel_world_size()
self.model_parallel_rank = bwc_tensor_model_parallel_rank(mpu)

self.overflow = False
self.clip_grad = clip_grad
self.communication_data_type = communication_data_type
self.gradient_predivide_factor = gradient_predivide_factor
@@ -551,6 +551,12 @@ def __init__(self,
self._enable_universal_checkpoint()
self._param_slice_mappings = self._create_param_mapping()

self.overflow = False
self.overflow_checker = CheckOverflow(self.bit16_groups,
mpu=self.mpu,
deepspeed=deepspeed,
partition_grads=self.partition_gradients)

def destroy(self):
for hook in self._grad_acc_hooks:
hook.remove()
@@ -1166,7 +1172,7 @@ def get_grad_position(self, group_id, tensor_list, first_offset, partition_size)

def update_overflow_tracker_for_param_grad(self, param):
grad_accum = self.get_param_gradient_attribute(param)
if grad_accum is not None and self._has_inf_or_nan(grad_accum.data):
if grad_accum is not None and self.overflow_checker._has_inf_or_nan(grad_accum.data):
self.local_overflow = True

def _get_offload_gradient_dict(self):
@@ -1821,7 +1827,7 @@ def step(self, closure=None):
see_memory_usage(f"In step before checking overflow")

# First compute norm for all group so we know if there is overflow
if self.dtype == torch.float16:
if self.dtype == torch.float16 or self.dtype == torch.bfloat16:
self.check_overflow()

prev_scale = self.loss_scale
@@ -1973,56 +1979,11 @@ def unscale_and_clip_grads(self, grad_groups_flat, total_norm):
else:
grad.data.mul_(1. / combined_scale)

def _check_overflow(self, partition_gradients=True):
self.overflow = self.has_overflow(partition_gradients)

# `params` is a list / generator of torch.Variable
def has_overflow_serial(self, params):
invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name())
for p in params:
if p.grad is not None:
invalid_grad_count += self._has_inf_or_nan(p.grad)
return invalid_grad_count.bool()

def has_overflow_partitioned_grads_serial(self):
invalid_grad_count = torch.zeros([1], dtype=torch.float, device=get_accelerator().current_device_name())
for i in range(len(self.bit16_groups)):
for j, grad in enumerate(self.averaged_gradients[i]):
if grad is not None:
invalid_grad_count += self._has_inf_or_nan(grad)
return invalid_grad_count.bool()

def has_overflow(self, partition_gradients=True):
if partition_gradients:
overflow = self.local_overflow if self.cpu_offload else self.has_overflow_partitioned_grads_serial()
overflow_gpu = get_accelerator().ByteTensor([overflow]) if self.cpu_offload else overflow.byte().to(
get_accelerator().current_device_name())
'''This will capture overflow across all data parallel and expert parallel process
Since expert parallel process are a subset of data parallel process'''
dist.all_reduce(overflow_gpu, op=dist.ReduceOp.MAX, group=self.dp_process_group)

def _check_overflow(self):
if self.cpu_offload:
self.overflow = self.local_overflow
else:
params = []
for group in self.bit16_groups:
for param in group:
params.append(param)
overflow_gpu = self.has_overflow_serial(params).byte().to(get_accelerator().current_device_name())

# Since each model parallel GPU carries only part of the model,
# make sure overflow flag is synced across all the model parallel GPUs
self._model_parallel_all_reduce(tensor=overflow_gpu, op=dist.ReduceOp.MAX)

overflow = overflow_gpu[0].item()
return bool(overflow)

# `x` is a torch.Tensor
@staticmethod
def _has_inf_or_nan(x, j=None):
float_x = x.float()
nan = float_x.isnan()
inf = float_x.isinf()
inf_or_nan = nan.logical_or(inf)
return inf_or_nan.float().max()
self.overflow = self.overflow_checker.check()

def backward(self, loss, retain_graph=False):
"""
@@ -2059,8 +2020,8 @@ def backward(self, loss, retain_graph=False):
if self.use_grad_accum_attribute:
self.fill_grad_accum_attribute()

def check_overflow(self, partition_gradients=True):
self._check_overflow(partition_gradients)
def check_overflow(self):
self._check_overflow()

def _update_scale(self, has_overflow=False):
self.loss_scaler.update_scale(has_overflow)
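With the per-optimizer overflow code deleted above, stage 1/2 now leans on the shared checker (overflow_checker.check() and overflow_checker._has_inf_or_nan). For reference, a minimal, self-contained sketch of an inf/nan test equivalent to the deleted stage_1_and_2 helper; the actual utils.py version is truncated in this diff and may differ in details such as exception handling.

# Minimal sketch of an inf/nan gradient test, modelled on the deleted helper above.
import torch


def has_inf_or_nan(x: torch.Tensor) -> bool:
    float_x = x.float()  # upcast so half/bfloat16 values are compared in fp32
    return bool(float_x.isnan().logical_or(float_x.isinf()).any())


# Example: a bf16 gradient containing an inf is flagged, a finite one is not.
assert has_inf_or_nan(torch.tensor([1.0, float('inf')], dtype=torch.bfloat16))
assert not has_inf_or_nan(torch.ones(4, dtype=torch.bfloat16))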