Skip to content

Commit

Permalink
[gemini]remove registered gradients hooks (hpcaitech#5696)
Browse files Browse the repository at this point in the history
* fix gemini

fix gemini

* fix

fix
  • Loading branch information
flybird11111 authored May 9, 2024
1 parent 2229778 commit d4c5ef4
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 46 deletions.
11 changes: 10 additions & 1 deletion colossalai/zero/gemini/chunk/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@ class ChunkManager:
init_device (torch.device): optional, the device on which the chunk is initialized. The default is None.
"""

def __init__(self, chunk_configuration, init_device: Optional[torch.device] = None) -> None:
def __init__(
self,
chunk_configuration,
init_device: Optional[torch.device] = None,
reuse_fp16_chunk: bool = True,
) -> None:
self.device = init_device or get_accelerator().get_current_device()
self.dp_degree_chunk_size_dict: Dict[int, int] = dict()
self.kwargs_config = chunk_configuration
Expand All @@ -33,6 +38,10 @@ def __init__(self, chunk_configuration, init_device: Optional[torch.device] = No
self.accessed_chunks: Set[Chunk] = set()
self.accessed_mem: int = 0
self.total_mem: Dict[str, int] = {"cpu": 0, "cuda": 0}
self.reuse_fp16_chunk = reuse_fp16_chunk
# Whether model is accumulating gradients,
self.accumulating_grads = False
self.overflow_counter = 0

def register_tensor(
self,
Expand Down
7 changes: 6 additions & 1 deletion colossalai/zero/gemini/chunk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def init_chunk_manager(
model: nn.Module,
init_device: Optional[torch.device] = None,
hidden_dim: Optional[int] = None,
reuse_fp16_chunk: bool = True,
verbose: bool = False,
**kwargs,
) -> ChunkManager:
Expand Down Expand Up @@ -50,5 +51,9 @@ def init_chunk_manager(
)
dist.barrier()

chunk_manager = ChunkManager(config_dict, init_device)
chunk_manager = ChunkManager(
config_dict,
init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
return chunk_manager
105 changes: 69 additions & 36 deletions colossalai/zero/gemini/gemini_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,14 @@ def __init__(
verbose: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
self.enable_gradient_accumulation = enable_gradient_accumulation
if chunk_config_dict is not None:
self.chunk_manager = ChunkManager(chunk_config_dict, chunk_init_device)
self.chunk_manager = ChunkManager(
chunk_config_dict,
chunk_init_device,
reuse_fp16_chunk=reuse_fp16_chunk,
)
else:
# some ugly hotfix for the compatibility with Lightning
if search_range_m is None:
Expand All @@ -112,6 +118,7 @@ def __init__(
min_chunk_size_m=min_chunk_size_m,
strict_ddp_flag=strict_ddp_mode,
process_group=zero_group,
reuse_fp16_chunk=reuse_fp16_chunk,
verbose=verbose,
)
self.gemini_manager = GeminiManager(
Expand All @@ -128,7 +135,6 @@ def __init__(
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.overflow_counter = 0
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
self.param2name: Dict[nn.Parameter, str] = dict()
self.name2param: Dict[str, nn.Parameter] = dict()
Expand All @@ -137,14 +143,8 @@ def __init__(
self.zero_group = zero_group or _get_default_group()
self.extra_dp_group = extra_dp_group

self.reuse_fp16_chunk = master_weights
self.master_weights = master_weights

self.enable_gradient_accumulation = enable_gradient_accumulation
if self.enable_gradient_accumulation:
self.reuse_fp16_chunk = False
self.accumulating_grads = False # Whether model is accumulating gradients

self._logger = get_dist_logger()

if self.gemini_manager._premade_memstats_:
Expand Down Expand Up @@ -178,7 +178,29 @@ def __init__(
if is_ddp_ignored(p):
continue
if p.requires_grad:
p.register_hook(partial(self.grad_handle, p))
p._grad_handle = p.register_hook(
partial(
GeminiDDP.grad_handle,
chunk_manager=self.chunk_manager,
param2name=self.param2name,
grads_device=self.grads_device,
master_weights=self.master_weights,
enable_gradient_accumulation=self.enable_gradient_accumulation,
p=p,
)
)

def remove_hooks(self):
for p in self.module.parameters():
if is_ddp_ignored(p):
continue
if p.requires_grad:
assert hasattr(p, "_grad_handle")
p._grad_handle.remove()
delattr(p, "_grad_handle")

def __del__(self):
self.remove_hooks()

def parameters(self, recurse: bool = True):
return self.module.parameters(recurse)
Expand Down Expand Up @@ -324,8 +346,8 @@ def _post_backward(self):
f"{error_str}",
)
self._setup_grads_ptr()
if self.enable_gradient_accumulation and not self.accumulating_grads:
self.accumulating_grads = True # Turn on the state of gradient accumulation.
if self.enable_gradient_accumulation and not self.chunk_manager.accumulating_grads:
self.chunk_manager.accumulating_grads = True # Turn on the state of gradient accumulation.
self._logger.debug(
f"comp cuda demand time: {self.gemini_manager._comp_cuda_demand_time}, layout time: {self.gemini_manager._layout_time}, evict time: {self.gemini_manager._evict_time}, CPU->CUDA vol: {self.gemini_manager._h2d_volume}B, CUDA->CPU vol: {self.gemini_manager._d2h_volume}"
)
Expand All @@ -340,25 +362,34 @@ def backward(self, loss: torch.Tensor):
def backward_by_grad(self, tensor, grad):
raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")

def grad_handle(self, p, grad):
@staticmethod
def grad_handle(
grad,
chunk_manager: ChunkManager,
param2name: Dict,
grads_device: Dict,
master_weights: bool,
enable_gradient_accumulation: bool,
p: nn.Parameter,
):
setattr(p, "_gemini_reduced", True)
empty_grad = torch.empty_like(grad)
free_storage(empty_grad)
with torch._C.DisableTorchFunction():
chunk = self.chunk_manager.get_chunk(p)
chunk = chunk_manager.get_chunk(p)
if chunk.tensors_info[p].state != TensorState.HOLD_AFTER_BWD:
raise RuntimeError(
f"Parameter `{self.param2name[p]}` failed at the gradient reduction. "
f"Parameter `{param2name[p]}` failed at the gradient reduction. "
"Some unsupported torch function is operated upon this parameter."
)
grad_chunk = chunk
if not self.reuse_fp16_chunk:
if not self.accumulating_grads:
grad_chunk = self.chunk_manager.init_grad_chunk(chunk)
if not chunk_manager.reuse_fp16_chunk:
if not chunk_manager.accumulating_grads:
grad_chunk = chunk_manager.init_grad_chunk(chunk)
else:
assert chunk.grad_chunk is not None
if chunk.grad_chunk not in self.chunk_manager.accessed_chunks:
grad_chunk = self.chunk_manager.rearrange_accumulated_grad_chunk(chunk)
if chunk.grad_chunk not in chunk_manager.accessed_chunks:
grad_chunk = chunk_manager.rearrange_accumulated_grad_chunk(chunk)
else:
grad_chunk = chunk.grad_chunk
chunk.grad_chunk.l2_norm = None
Expand All @@ -371,33 +402,33 @@ def grad_handle(self, p, grad):
chunk.tensor_trans_state(p, TensorState.HOLD)

grad_chunk.tensor_trans_state(p, TensorState.READY_FOR_REDUCE)
if not self.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=self.reuse_fp16_chunk)
if not chunk_manager.accumulating_grads:
grad_chunk.copy_tensor_to_chunk_slice(p, grad, update_ptr=chunk_manager.reuse_fp16_chunk)
else:
grad_chunk.add_tensor_to_chunk_slice(p, grad)
reduced = self.chunk_manager.reduce_chunk(grad_chunk)
reduced = chunk_manager.reduce_chunk(grad_chunk)
if reduced:
if not self.reuse_fp16_chunk:
if not chunk_manager.reuse_fp16_chunk:
if chunk.keep_gathered:
self.chunk_manager.fake_release_chunk(chunk)
chunk_manager.fake_release_chunk(chunk)
else:
self.chunk_manager.release_chunk(chunk)
chunk_manager.release_chunk(chunk)
if grad_chunk.is_gathered:
grad_chunk.cuda_global_chunk.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_global_chunk.div_(chunk.extra_dp_size)
else:
grad_chunk.cuda_shard.div_(chunk.pg_size)
if self.extra_dp_group is not None:
if chunk.extra_dp_group is not None:
grad_chunk.cuda_shard.div_(chunk.extra_dp_size)
# check overflow elements
self.overflow_counter += grad_chunk.has_inf_or_nan
chunk_manager.overflow_counter += grad_chunk.has_inf_or_nan
# record l2 norm for gradient clipping. flag is bound to fp16 chunk
if chunk.l2_norm_flag:
grad_chunk.set_l2_norm()
self.chunk_manager.move_chunk(grad_chunk, self.grads_device[p], force_copy=True)
if not (self.master_weights) or (self.enable_gradient_accumulation):
self.chunk_manager.move_chunk(chunk, self.grads_device[p], force_copy=True)
chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True)
if not (master_weights) or (enable_gradient_accumulation):
chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True)
return empty_grad

def zero_grad(self, set_to_none: bool = False) -> None:
Expand Down Expand Up @@ -513,11 +544,11 @@ def _save_to_state_dict(self, destination, prefix, keep_vars, only_rank_0=True):

# get copies of fp32 parameters in CPU
# as memory of fp16_params may be reused by grad, it's not reliable, we should use fp32_params and convert to fp16
params = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
param_to_save_data = self._get_param_to_save_data(params, only_rank_0)
# get the mapping between copies and fp16 parameters
p_mapping = dict()
if self.reuse_fp16_chunk:
if self.chunk_manager.reuse_fp16_chunk:
for p, fp32_p in zip(self.fp16_params, self.fp32_params):
name = self.param2name[p]
assert fp32_p in param_to_save_data, "Parameter '{}' is neglected in the chunk list".format(name)
Expand Down Expand Up @@ -713,7 +744,7 @@ def load_parameter(chunk_slice, data):
name = self.param2name[p]
fp32_to_name[fp32_p] = name

params_to_load = self.fp32_params if self.reuse_fp16_chunk else self.fp16_params
params_to_load = self.fp32_params if self.chunk_manager.reuse_fp16_chunk else self.fp16_params
chunk_list = self.chunk_manager.get_chunks(params_to_load)
for chunk in chunk_list:
temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
Expand All @@ -728,7 +759,9 @@ def load_parameter(chunk_slice, data):
shard_fn = tensor.shard_fn
gather_fn = tensor.gather_fn

parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
parameter_name = (
fp32_to_name[tensor] if self.chunk_manager.reuse_fp16_chunk else self.param2name[tensor]
)
parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
load(
parameter_name,
Expand Down Expand Up @@ -900,7 +933,7 @@ def state_dict_shard(
gathered_param = param if keep_vars else param.detach()
else:
# as memory of fp16 param may be reused, we should use fp32 param and then convert to fp16
param_to_save = fp16_to_fp32[param] if self.reuse_fp16_chunk else param
param_to_save = fp16_to_fp32[param] if self.chunk_manager.reuse_fp16_chunk else param
if param_to_save not in gathered_param_buffer:
chunk = self.chunk_manager.get_chunk(param_to_save)
gathered_param_buffer.update(self._get_chunk_to_save_data(chunk, only_rank_0))
Expand Down
14 changes: 7 additions & 7 deletions colossalai/zero/gemini/gemini_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ def __init__(
self.module = module

def check_local_overflow(self) -> bool:
return self.module.overflow_counter > 0
return self.module.chunk_manager.overflow_counter > 0

def pre_zero_grad(self) -> None:
self.module.overflow_counter = 0
self.module.chunk_manager.overflow_counter = 0


class GeminiOptimizer(OptimizerWrapper):
Expand Down Expand Up @@ -202,7 +202,7 @@ def _set_grad_ptr(self):
chunk16 = self.param_to_chunk16[fake_param]
begin, end = self.param_to_range[fake_param]

grad_chunk16 = chunk16 if self.module.reuse_fp16_chunk else chunk16.grad_chunk
grad_chunk16 = chunk16 if self.module.chunk_manager.reuse_fp16_chunk else chunk16.grad_chunk
fake_param.data = grad_chunk16.payload[begin:end]
fake_param.grad = fake_param.data

Expand All @@ -221,14 +221,14 @@ def _update_fp16_params(self):

def _clear_global_norm(self) -> None:
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
grad_chunk.l2_norm = None

def _calc_global_norm(self) -> float:
norm_sqr: float = 0.0
group_to_norm = dict()
for c16 in self.chunk16_set:
grad_chunk = c16 if self.module.reuse_fp16_chunk else c16.grad_chunk
grad_chunk = c16 if self.module.chunk_manager.reuse_fp16_chunk else c16.grad_chunk
assert grad_chunk.l2_norm is not None

if grad_chunk.is_gathered:
Expand Down Expand Up @@ -275,7 +275,7 @@ def step(self, *args, **kwargs):
self._logger.info(f"Found overflow. Skip step")
self._clear_global_norm() # clear recorded norm
self.zero_grad() # reset all gradients
if self.module.reuse_fp16_chunk:
if self.module.chunk_manager.reuse_fp16_chunk:
self._update_fp16_params()
return

Expand All @@ -288,7 +288,7 @@ def step(self, *args, **kwargs):
self.zero_grad()
if self.module.master_weights:
self._update_fp16_params()
self.module.accumulating_grads = False
self.module.chunk_manager.accumulating_grads = False
return ret

def clip_grad_norm(self, model: torch.nn.Module, max_norm: float, norm_type: float = 2.0):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_zero/test_gemini/test_fwd_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
chunk_manager = model.chunk_manager
param_list = [p for p in model.parameters()]
chunk_list = chunk_manager.get_chunks(param_list)
if not model.reuse_fp16_chunk:
if not model.chunk_manager.reuse_fp16_chunk:
chunk_list = [chunk.grad_chunk for chunk in chunk_list]
for chunk in chunk_list:
chunk_manager.access_chunk(chunk)
Expand Down

0 comments on commit d4c5ef4

Please sign in to comment.