From 7e2028ecdf3dbf6a52e6e9141418ca268305676f Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Wed, 5 Jul 2023 09:51:56 +0800 Subject: [PATCH] [checkpointio] unsharded optimizer checkpoint for Gemini plugin --- colossalai/booster/plugin/gemini_plugin.py | 78 ++-- .../checkpoint_io/checkpoint_io_base.py | 2 + .../checkpoint_io/general_checkpoint_io.py | 14 +- colossalai/checkpoint_io/utils.py | 22 +- colossalai/interface/optimizer.py | 6 + colossalai/nn/optimizer/cpu_adam.py | 3 + colossalai/nn/optimizer/fused_adam.py | 3 + colossalai/nn/optimizer/hybrid_adam.py | 3 + colossalai/testing/comparison.py | 47 ++- colossalai/zero/gemini/gemini_optimizer.py | 353 +++++++++++++++++- .../test_gemini_checkpoint_io.py | 155 ++++++-- 11 files changed, 607 insertions(+), 79 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 1173589fcd49..eb7eaff6e728 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -33,44 +33,40 @@ def __init__(self) -> None: super().__init__() self.coordinator = DistCoordinator() - def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): - """ - Load model from checkpoint with automatic unwrapping. - """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - return super().load_unsharded_model(model, checkpoint, strict=strict) - def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool): """ - Save model to checkpoint but only on master process. + Save sharded model to checkpoint but only on master process. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. + As there is communication when getting state dict, this must be called on all processes. """ - # the model should be unwrapped in self.load_model via ModelWrapper.unwrap - # as there is communication when get state dict, this must be called on all processes state_dict = model.state_dict(only_rank_0=True) if self.coordinator.is_master(): save_state_dict(state_dict, checkpoint, use_safetensors) - def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): + def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True): """ - Save optimizer to checkpoint but only on master process. + Load model from checkpoint with automatic unwrapping. + The model should be unwrapped in self.load_model via ModelWrapper.unwrap. """ - # TODO(ver217): optimizer state dict is sharded - warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor) - - def load_optimizer(self, optimizer: Optimizer, checkpoint: str): - warnings.warn( - 'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.') - checkpoint = f'{checkpoint}.rank{self.coordinator.rank}' - super().load_optimizer(optimizer, checkpoint) + super().load_unsharded_model(model, checkpoint, strict=strict) - def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool): """ - Save model to checkpoint but only on master process. + Save unsharded optimizer state dict to checkpoint. 
+ After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank. + As there is communication when getting state dict, this must be called on all processes. + The saving process will only be executed by master rank. """ + state_dict = optimizer.state_dict() if self.coordinator.is_master(): - super().save_lr_scheduler(lr_scheduler, checkpoint) + save_state_dict(state_dict, checkpoint, use_safetensors=False) + + def load_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str): + """ + Loading unsharded optimizer from checkpoint file. + For each process, only loading optimizer states of parameters it controls. + """ + super().load_unsharded_optimizer(optimizer, checkpoint) def save_sharded_model(self, model: GeminiDDP, @@ -82,6 +78,12 @@ def save_sharded_model(self, """ Save sharded model """ + if os.path.isfile(checkpoint_path): + logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file") + return + + Path(checkpoint_path).mkdir(parents=True, exist_ok=True) + state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32) weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors) total_size = 0 @@ -117,6 +119,30 @@ def load_sharded_model(self, """ return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False) + def save_sharded_optimizer(self, optimizer: Optimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, + size_per_shard: int): + """ + Save sharded optimizer state dict to checkpoint folder. + As there is communication when getting state dict, this must be called on all processes. + """ + Path(checkpoint).mkdir(parents=True, exist_ok=True) + super().save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) + + def load_sharded_optimizer(self, optimizer: Optimizer, checkpoint_index_file: Path, prefix: str): + """ + Loading sharded optimizer from checkpoint folder, with index file given. + For each process, only loading optimizer states of parameters it controls. + """ + # TODO(Baizhou): To be implemented. + pass + + def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str): + """ + Save model to checkpoint but only on master process. + """ + if self.coordinator.is_master(): + super().save_lr_scheduler(lr_scheduler, checkpoint) + class GeminiModel(ModelWrapper): @@ -219,7 +245,7 @@ def __init__( min_chunk_size_m: float = 32, memstats: Optional[MemStats] = None, gpu_margin_mem_ratio: float = 0.0, - initial_scale: float = 2**32, + initial_scale: float = 2**16, min_scale: float = 1, growth_factor: float = 2, backoff_factor: float = 0.5, diff --git a/colossalai/checkpoint_io/checkpoint_io_base.py b/colossalai/checkpoint_io/checkpoint_io_base.py index 8ff9d87c288e..baff24e1cb25 100644 --- a/colossalai/checkpoint_io/checkpoint_io_base.py +++ b/colossalai/checkpoint_io/checkpoint_io_base.py @@ -152,6 +152,7 @@ def load_optimizer(self, optimizer: Optimizer, checkpoint: str, prefix: str = No names to compose the keys in state_dict. Defaults to None. size_per_shard (int, optional): Maximum size of checkpoint shard file in MB. This is useful only when ``shard=True``. Defaults to 1024. """ + index_file_exists, index_file_path = has_index_file(checkpoint) if Path(checkpoint).is_dir() and not index_file_exists: @@ -186,6 +187,7 @@ def save_optimizer(self, prefix (str): prefix for the optimizer checkpoint when shard = True. Default: None. 
size_per_shard (int): size per shard in MB. Default: 1024. This value is only used when shard is set to True. """ + if shard: self.save_sharded_optimizer(optimizer, checkpoint, gather_dtensor, prefix, size_per_shard) else: diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 26cafcada2c5..e1d9066948dd 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -28,6 +28,7 @@ shard_model_checkpoint, shard_optimizer_checkpoint, sharded_optimizer_loading_epilogue, + unwrap_optimizer, ) __all__ = ['GeneralCheckpointIO'] @@ -59,7 +60,7 @@ def load_sharded_optimizer(self, optimizer: Optimizer, index_file_path: str, pre # If optimizer is wrapped, unwrap it. if isinstance(optimizer, OptimizerWrapper): - optimizer = optimizer.optim + optimizer = unwrap_optimizer(optimizer) # Read checkpoint index file. ckpt_index_file = CheckpointIndexFile.from_file(index_file_path) @@ -96,6 +97,11 @@ def save_sharded_optimizer( - A group file (pytorch_optim_group.bin) recording information of param_groups - Multiple files (pytorch_optim-000XX.bin) that store state tensors of optimizer in a sharding way """ + + # If optimizer is wrapped, unwrap it. + if isinstance(optimizer, OptimizerWrapper): + optimizer = unwrap_optimizer(optimizer) + if os.path.isfile(checkpoint): logging.error(f"Provided path ({checkpoint}) should be a directory, not a file") return @@ -121,9 +127,8 @@ def save_sharded_optimizer( shard, current_size = shard_pair shard_file = get_shard_filename(states_name, idx) total_size = total_size + current_size - for param_id in shard.keys(): - index_file.append_weight_map(str(param_id), shard_file) - + for key in shard.keys(): + index_file.append_weight_map(key, shard_file) checkpoint_file_path = os.path.join(checkpoint, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors=False) @@ -177,7 +182,6 @@ def save_sharded_model(self, total_size = total_size + shard_pair[1] for key in shard.keys(): index_file.append_weight_map(key, shard_file) - checkpoint_file_path = os.path.join(checkpoint_path, shard_file) save_state_dict(shard, checkpoint_file_path, use_safetensors) diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py index 485577b9650c..f8756b1fed1c 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -10,6 +10,8 @@ import torch.nn as nn from torch.optim import Optimizer +from colossalai.interface import OptimizerWrapper +from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.tensor.d_tensor import is_distributed_tensor SAFE_WEIGHTS_NAME = "model.safetensors" @@ -88,6 +90,17 @@ def is_safetensor_checkpoint(checkpoint_file_path: str) -> bool: # ====================================== # Helper functions for saving shard file # ====================================== +def unwrap_optimizer(optimizer: OptimizerWrapper): + ''' + Unwrap a wrapped optimizer. + This method should be used before saving/loading it to/from sharded checkpoints. 
+ ''' + unwrapped_optim = optimizer.optim + if isinstance(unwrapped_optim, ColossalaiOptimizer): + unwrapped_optim = unwrapped_optim.optim + return unwrapped_optim + + def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) -> Iterator[Tuple[OrderedDict, int]]: """ Splits a model state dictionary in sub-checkpoints so that the final size of each sub-checkpoint does not exceed a @@ -103,7 +116,7 @@ def shard_model_checkpoint(state_dict: torch.Tensor, max_shard_size: int = 1024) weight_size = calculate_tensor_size(weight) # If this weight is going to tip up over the maximal size, we split. - if current_block_size + weight_size > max_shard_size: + if current_block_size + weight_size > max_shard_size and current_block_size > 0: ret_block = current_block ret_block_size = current_block_size current_block = {} @@ -140,9 +153,10 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> isDTensor = False for state_tensor in state.values(): - # When state_tensor is None (e.g., a SGD optimizer with momentum set to 0), + # When state_tensor is not of Tensor class, + # e.g., a SGD optimizer with momentum set to 0 can have None as state # The calculation of tensor size should be skipped to avoid error. - if state_tensor is None: + if not isinstance(state_tensor, torch.Tensor): continue # If the states are stored as DTensors, mark isDTensor as true. @@ -152,7 +166,7 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) -> if not isDTensor: - if current_block_size + state_size > max_shard_size: + if current_block_size + state_size > max_shard_size and current_block_size > 0: ret_block = current_block ret_block_size = current_block_size current_block = {} diff --git a/colossalai/interface/optimizer.py b/colossalai/interface/optimizer.py index dd9acab17584..0eaf2e1ef8ba 100644 --- a/colossalai/interface/optimizer.py +++ b/colossalai/interface/optimizer.py @@ -119,3 +119,9 @@ def unscale_grad(self): """ raise NotImplementedError( "The method unscale_grad is only available for optimizers with mixed precision training") + + def unwrap(self): + """ + Unwrap the optimizer for checkpoint saving/loading. 
+ """ + return self.optim diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py index 3a6d37103398..8e7652d64d1b 100644 --- a/colossalai/nn/optimizer/cpu_adam.py +++ b/colossalai/nn/optimizer/cpu_adam.py @@ -176,3 +176,6 @@ def step(self, closure=None, div_scale: float = -1): raise RuntimeError self._post_step() return loss + + def get_state_names(self): + return ['step', 'exp_avg', 'exp_avg_sq'] diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py index 82a6250f1fd1..7c645f955b6a 100644 --- a/colossalai/nn/optimizer/fused_adam.py +++ b/colossalai/nn/optimizer/fused_adam.py @@ -147,3 +147,6 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no group['weight_decay'], div_scale) return loss + + def get_state_names(self): + return ['step', 'exp_avg', 'exp_avg_sq'] diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py index 84903ac36832..7f982a8e9b7b 100644 --- a/colossalai/nn/optimizer/hybrid_adam.py +++ b/colossalai/nn/optimizer/hybrid_adam.py @@ -152,3 +152,6 @@ def step(self, closure=None, div_scale: float = -1): bias_correction, group['weight_decay'], div_scale) self._post_step() return loss + + def get_state_names(self): + return ['step', 'exp_avg', 'exp_avg_sq'] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index 5cbfb936b144..123fc38b941c 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -16,7 +16,12 @@ def assert_not_equal(a: Tensor, b: Tensor): def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1e-3): - assert_close(a, b, rtol=rtol, atol=atol) + assert_close(a, + b, + rtol=rtol, + atol=atol, + msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \ + dtype: {a.dtype} vs {b.dtype}") def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): @@ -33,25 +38,35 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None): def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True): - for k, v in d1.items(): - if isinstance(v, dict): - check_state_dict_equal(v, d2[k]) - elif isinstance(v, list): - for i in range(len(v)): - if isinstance(v[i], torch.Tensor): + assert len(list(d1.keys())) == len(list(d2.keys())) + for k, v1 in d1.items(): + assert k in d2 + v2 = d2[k] + if isinstance(v1, dict): + assert isinstance(v2, dict) + check_state_dict_equal(v1, v2, ignore_device) + elif isinstance(v1, list): + assert isinstance(v2, list) + for v1_i, v2_i in zip(v1, v2): + if isinstance(v1_i, torch.Tensor): + assert isinstance(v2_i, torch.Tensor) if not ignore_device: - v[i] = v[i].to("cpu") - d2[k][i] = d2[k][i].to("cpu") - assert torch.equal(v[i], d2[k][i]) + v1_i = v1_i.to("cpu") + v2_i = v2_i.to("cpu") + assert_close_loose(v1_i, v2_i) + elif isinstance(v1_i, dict): + assert isinstance(v2_i, dict) + check_state_dict_equal(v1_i, v2_i, ignore_device) else: - assert v[i] == d2[k][i] - elif isinstance(v, torch.Tensor): + assert v1_i == v2_i, f"{v1_i} not equals to {v2_i}" + elif isinstance(v1, torch.Tensor): + assert isinstance(v2, torch.Tensor) if not ignore_device: - v = v.to("cpu") - d2[k] = d2[k].to("cpu") - assert torch.equal(v, d2[k]) + v1 = v1.to("cpu") + v2 = v2.to("cpu") + assert_close_loose(v1, v2) else: - assert v == d2[k] + assert v1 == v2, f"{v1} not equals to {v2}" def assert_hf_output_close(out1: Any, diff --git a/colossalai/zero/gemini/gemini_optimizer.py 
b/colossalai/zero/gemini/gemini_optimizer.py index 267deb1e8699..b019938fceee 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -1,6 +1,8 @@ # this code is inspired by the DeepSpeed library and implemented with our own design from scratch +import copy import math import warnings +from collections import abc as container_abcs from typing import Any, Dict, Set, Tuple import torch @@ -101,6 +103,11 @@ def __init__(self, self.clipping_flag = clipping_norm > 0.0 self.max_norm = clipping_norm self.verbose = verbose + self.param_groups_backup = list() + + # Mapping from integer id to real/fake param tensor, used for checkpointing. + self.id_to_real_params: Dict[int, Parameter] = dict() + self.id_to_fake_params: Dict[int, Parameter] = dict() if self.clipping_flag: assert norm_type == 2.0, "ZeroOptimizer only supports L2 norm now" @@ -301,25 +308,365 @@ def get_range_pair(local_chunk: Chunk, local_param: Parameter): end = min(local_chunk.shard_size, param_info.end - local_chunk.shard_begin) return begin, end + param_id = -1 for group in self.optim.param_groups: fake_params_list = list() - + group_backup = {k: v for k, v in group.items() if k != 'params'} + group_ids = [] for param in group['params']: + + # Record the mapping of id to current param. + param_id += 1 + self.id_to_real_params[param_id] = param + group_ids.append(param_id) + + # If current param is controlled by current process, add it to fake_param. if is_ddp_ignored(param): continue chunk16 = self.chunk_manager.get_chunk(param) range_pair = get_range_pair(chunk16, param) if range_pair[0] >= range_pair[1]: continue - grad_device = self.module.grads_device[param] fake_param = torch.nn.Parameter(torch.empty([0], device=grad_device)) self.param_to_chunk32[fake_param] = chunk16.paired_chunk self.param_to_range[fake_param] = range_pair - + self.id_to_fake_params[param_id] = fake_param fake_params_list.append(fake_param) + # Update self.optim.param_groups as well as backup group. group['params'] = fake_params_list + group_backup['params'] = group_ids + self.param_groups_backup.append(group_backup) + + def get_offsets(self, param_id: int) -> tuple: + # Args: param_id(int): The id of parameter. + # + # Returns: chunk_offset(int):Offset of parameter inside the chunk. + # shard_offset(int): Offset of its optimizer state shard + # relative to the whole optimizer state. + # param_size(int): Length of parameter shard owned by current process. + + assert param_id in self.id_to_fake_params + fake_param = self.id_to_fake_params[param_id] + chunk = self.param_to_chunk32[fake_param].paired_chunk + param = self.id_to_real_params[param_id] + param_info = chunk.tensors_info[param] + + param_range = self.param_to_range[fake_param] + begin_in_chunk, end_in_chunk = param_range + chunk_offset = begin_in_chunk + shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset + param_size = end_in_chunk - begin_in_chunk + assert chunk_offset >= 0 and shard_offset >= 0 + + return chunk_offset, shard_offset, param_size + + def collect_states(self, param_id: int) -> dict: + """ + Args: + param_id (int): id of the parameter whose state is to be gathered at master rank. + + Returns: + collected_states(dict): the gathered optimzier state of parameter with given id + if this method is called by master rank, otherwise an empty dict. + + This method can work only when called by all processes simultaneously. + """ + + # Get param & chunk & process group. 
+ param = self.id_to_real_params[param_id] + chunk = self.chunk_manager.get_chunk(param) + process_group = chunk.torch_pg + rank = dist.get_rank(process_group) + master_rank = 0 + state_names = self.optim.get_state_names() + + collected_states = {} + + # If the chunk is kept gathered, + # the parameteres are treated the same as that of those in strict DDP during training. + # So states can be directly fetched from current device. + if chunk.keep_gathered: + assert param_id in self.id_to_fake_params + fake_param = self.id_to_fake_params[param_id] + if rank == master_rank: + states = self.optim.state[fake_param] + for state_name in state_names: + if state_name == 'step': + # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. + collected_states[state_name] = torch.tensor(states['step'], + dtype=torch.float32, + requires_grad=False).cpu() + else: + collected_states[state_name] = states[state_name].detach().clone().to(torch.float32).cpu() + return collected_states + + # Check whether the param with given id is managed by current process. + own_param = param_id in self.id_to_fake_params + + # Compute position information (offsets) of state shard. + # If current process doesn't control this param, position message should be [rank, -1, -1] + # else it should be [rank, start_offset, end_offset] + state_shard_range = torch.tensor([rank, -1, -1], dtype=torch.int, requires_grad=False).cuda() + if own_param: + _, shard_offset, param_size = self.get_offsets(param_id) + state_shard_range[1] = shard_offset + state_shard_range[2] = shard_offset + param_size + + # Ranks other than master send position messages to master. + # The master rank should receive a dict mapping rank number to state shard offsets, in the form of: + # {rank_0: (0, x_1), rank_1: (x_1, x_2), rank2: (x_2, x_3), ... rank_{n-1}: (x_{n-1}, param.numels()) } + dist.barrier(process_group) + range_info = dict() + if rank == master_rank: + + # Record if master owns this param. + if state_shard_range[-1] >= 0: + range_info[master_rank] = (state_shard_range[-2].item(), state_shard_range[-1].item()) + + container_tensor = torch.zeros_like(state_shard_range, dtype=torch.int, requires_grad=False).cuda() + for src in range(0, dist.get_world_size(process_group)): + if src == master_rank: + continue + dist.recv(container_tensor, src, process_group) + # Only keeps valid position messages. + if container_tensor[-1] >= 0: + range_info[container_tensor[0].item()] = (container_tensor[-2].item(), container_tensor[-1].item()) + assert len(list(range_info.keys())) > 0 + + else: + dist.send(state_shard_range, master_rank, process_group) + dist.barrier(process_group) + + # Master get prepared for state collecting. + if rank == master_rank: + for state_name in state_names: + if state_name == 'step': + # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32. + collected_states[state_name] = torch.tensor(0.0, dtype=torch.float32, requires_grad=False).cpu() + else: + collected_states[state_name] = torch.zeros(param.numel(), dtype=torch.float32, + requires_grad=False).cpu() + + # Master starts to receive state tensors from other ranks + # after all the position messages of state shards has been collected. + compacted_states = self.pack_optimizer_states_to_tensor(param_id) + if rank == master_rank: + for src, shard_range in range_info.items(): + + # If src is master, directly load from compacted_states tensor. 
+ if src == master_rank: + self.load_from_compacted_states(compacted_states, collected_states, shard_range) + continue + + # If src is not master, receive states from other ranks. + shard_size = shard_range[1] - shard_range[0] + num_states = len(state_names) + container_size = 1 + (num_states - 1) * shard_size if 'step' in state_names \ + else num_states * shard_size + container_tensor = torch.zeros(container_size, dtype=torch.float32, requires_grad=False).cuda() + dist.recv(container_tensor, src, process_group) + self.load_from_compacted_states(container_tensor, collected_states, shard_range) + del container_tensor + else: + if own_param: + dist.send(compacted_states, master_rank, process_group) + + del compacted_states + dist.barrier(process_group) + + # Reshape tensors + if rank == master_rank: + for state_name, state_tensor in collected_states.items(): + if state_name == 'step': + continue + assert state_tensor.numel() == param.numel() + collected_states[state_name] = torch.reshape(state_tensor, param.shape) + + return collected_states + + def pack_optimizer_states_to_tensor(self, + param_id: int, + device: torch.device = torch.device('cuda'), + dtype: torch.dtype = torch.float32) -> torch.Tensor: + ''' + With param id given, pack its optimizer states into a compact tensor and return. + ''' + if param_id not in self.id_to_fake_params: + return None + + fake_param = self.id_to_fake_params[param_id] + param_range = self.param_to_range[fake_param] + states = self.optim.state[fake_param] + shard_size = param_range[1] - param_range[0] + compacted_size = 0 + for name in self.optim.get_state_names(): + if name == 'step': + compacted_size += 1 + else: + compacted_size += shard_size + compacted_states = torch.zeros(compacted_size, dtype=dtype, device=device, requires_grad=False) + + next_state_offset = 0 + for state_name, state_tensor in states.items(): + # State 'step' needs special operation. + if state_name == 'step': + if isinstance(state_tensor, torch.Tensor): + compacted_states[next_state_offset] = state_tensor[0].item() + else: + assert isinstance(state_tensor, int) + compacted_states[next_state_offset] = state_tensor + next_state_offset += 1 + else: + assert state_tensor.numel() == shard_size + compacted_states[next_state_offset:next_state_offset + shard_size].copy_(state_tensor) + next_state_offset += shard_size + + return compacted_states + + def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, shard_range: tuple): + ''' + Given a tensor carrying compacted optimizer states, + update these states to collected_states. + ''' + shard_start, shard_end = shard_range + shard_size = shard_end - shard_start + next_state_offset = 0 + + for state_name in self.optim.get_state_names(): + if state_name == 'step': + collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(), + dtype=torch.float32, + requires_grad=False).cpu() + next_state_offset += 1 + else: + target_segment = collected_states[state_name][shard_start:shard_end] + target_segment.copy_(compacted_states[next_state_offset:next_state_offset + shard_size]) + next_state_offset += shard_size + + def state_dict(self, only_rank_0: bool = True) -> dict: + """ + Args: + only_rank_0 (bool): a boolean value indicating whether the state_dict is collected + only on rank 0, dafault to True. + + Returns: + The complete state of the optimizer as a :class:`dict`. + It contains two entries: + + * state - a dict holding current optimization state. 
Its content + differs between optimizer classes. + * param_groups - a list containing all parameter groups where each + parameter group is a dict. + + Warning: This method will gather and return the whole optimizer state_dict, + so it should be called only when memory resources are abundant. + """ + + if not only_rank_0: + raise ValueError("False 'only_rank_0' in self.state_dict() method is currently not supported") + + state_dict = {} + state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup) + + torch_special_hyperparameters = { + 'amsgrad': False, + 'maximize': False, + 'foreach': None, + 'capturable': False, + 'differentiable': False, + 'fused': False + } + + for group in state_dict['param_groups']: + for k, v in torch_special_hyperparameters.items(): + if k not in group: + group[k] = v + + # Collect optimizer states. + state_dict['state'] = dict() + for param_id in self.id_to_real_params.keys(): + dist.barrier() + state_dict['state'][param_id] = self.collect_states(param_id=param_id) + return state_dict + + def load_param_groups(self, saved_param_groups: list): + """ + Load saved_param_groups into + self.param_groups and self.param_groups_backup + """ + self.param_groups_backup = copy.deepcopy(saved_param_groups) + + # discard the older param_groups + self.optim.param_groups = [] + + for group in saved_param_groups: + fake_params_list = list() + updated_group = {k: v for k, v in group.items() if k != 'params'} + for param_id in group['params']: + if param_id not in self.id_to_fake_params: + continue + fake_param = self.id_to_fake_params[param_id] + fake_params_list.append(fake_param) + updated_group['params'] = fake_params_list + self.optim.param_groups.append(updated_group) + + def load_single_param_states(self, param_id: int, saved_states: dict): + """ + Load saved optimizer states into parameter with given id. + """ + + def cast(param, state_range, value, key=None): + """ + Make a copy of the needed segment of value and cast it to device of param. + """ + assert isinstance(value, torch.Tensor) + ret_val = value + if (key == "step"): + assert value.numel() == 1 + ret_val = int(value.item()) + else: + state_start, state_end = state_range + ret_val = torch.zeros(state_end - state_start, + dtype=torch.float32, + device=param.device, + requires_grad=False) + ret_val.copy_(value.flatten()[state_start:state_end]) + return ret_val + + assert param_id in self.id_to_fake_params + fake_param = self.id_to_fake_params[param_id] + _, state_offset, param_size = self.get_offsets(param_id) + state_range = (state_offset, state_offset + param_size) + + # Copy states assigned to param (and cast tensors to appropriate types). + updated_states = dict() + for k, v in saved_states.items(): + updated_states[k] = cast(fake_param, state_range, v, k) + del v # clean loaded states + self.optim.state[fake_param].update(updated_states) + + def load_state_dict(self, state_dict: dict): + """Loads optimizer state from whole optimizer state_dict. + During loading, filter out the part of states not considered by current process. + + Args: + state_dict (dict): optimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + assert 'param_groups' in state_dict + self.load_param_groups(state_dict['param_groups']) + + state = state_dict['state'] + + for param_id, param_states in state.items(): + if param_id in self.id_to_fake_params: + self.load_single_param_states(param_id, param_states) + + # Epilogue for pytorch optimizer. 
+ self.optim._hook_for_profile() # To support multiprocessing pickle/unpickle. + self.optim.defaults.setdefault('differentiable', False) class GeminiAdamOptimizer(ZeroOptimizer): diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py index 602cf468c944..6e9cbdaffa23 100644 --- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py +++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py @@ -3,17 +3,14 @@ import pytest import torch import torch.distributed as dist +from torch.optim import Adam from utils import shared_tempdir import colossalai from colossalai.booster import Booster -from colossalai.booster.plugin import GeminiPlugin -from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO +from colossalai.booster.plugin import GeminiPlugin, TorchDDPPlugin from colossalai.nn.optimizer import HybridAdam from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn -from colossalai.zero import ZeroDDP -from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration -from colossalai.zero.gemini.gemini_mgr import GeminiManager from tests.kit.model_zoo import model_zoo @@ -29,30 +26,29 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b pretrained_path = os.path.join(tempdir, 'pretrained') bert_model.config.save_pretrained(save_directory=pretrained_path) - # TODO(ver217): use boost api - config_dict, *_ = search_chunk_configuration(bert_model, search_range_m=1, search_interval=100) - chunk_manager = ChunkManager(config_dict) - gemini_manager = GeminiManager(placement_policy, chunk_manager) - bert_model = ZeroDDP(bert_model, gemini_manager) - bert_model.train() - - ckpt_io = GeminiCheckpointIO() + plugin = GeminiPlugin(placement_policy=placement_policy) + booster = Booster(plugin=plugin) + bert_model, _, _, _, _ = booster.boost(bert_model) model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2 - ckpt_io.save_model(bert_model, (pretrained_path), + + booster.save_model(bert_model, + pretrained_path, True, True, '', (model_size / 3), use_safetensors=use_safetensors) dist.barrier() + new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path) - check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32), + check_state_dict_equal(bert_model.unwrap().state_dict(only_rank_0=False, dtype=torch.float32), new_bert_model.state_dict(), False) @parameterize('placement_policy', ['cuda', 'cpu']) -@parameterize('shard', [True, False]) +@parameterize('shard', [False]) @parameterize('model_name', ['transformers_gpt']) -def exam_state_dict(placement_policy, shard: bool, model_name: str): +@parameterize('size_per_shard', [32]) +def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int): (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) criterion = lambda x: x.mean() plugin = GeminiPlugin(placement_policy=placement_policy) @@ -78,18 +74,125 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str): with shared_tempdir() as tempdir: model_ckpt_path = f"{tempdir}/model" optimizer_ckpt_path = f"{tempdir}/optimizer" - booster.save_model(model, model_ckpt_path) - if not shard: - # TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint - booster.save_optimizer(optimizer, optimizer_ckpt_path) + booster.save_model(model, model_ckpt_path, 
shard=shard, size_per_shard=size_per_shard) + + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard, size_per_shard=size_per_shard) dist.barrier() booster.load_model(new_model, model_ckpt_path) check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False), new_model.unwrap().state_dict(only_rank_0=False), False) - if not shard: - booster.load_optimizer(new_optimizer, optimizer_ckpt_path) - check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + + booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False) + + +@parameterize('shard', [False]) +@parameterize('model_name', ['transformers_gpt']) +def exam_torch_load_from_gemini(shard: bool, model_name: str): + + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = GeminiPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = HybridAdam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = Adam(new_model.parameters(), lr=0.001) + new_plugin = TorchDDPPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading HybridAdam states to torch.Adam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. 
+ check_state_dict_equal( + model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + new_model.state_dict(), False) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + old_state_dict = optimizer.state_dict() + new_state_dict = new_optimizer.state_dict() + + # The complete optimizer state_dict of GeminiPlugin is only collected on rank 0 + if dist.get_rank() == 0: + check_state_dict_equal(old_state_dict, new_state_dict, False) + + +@parameterize('shard', [False]) +@parameterize('model_name', ['transformers_gpt']) +def exam_gemini_load_from_torch(shard: bool, model_name: str): + + (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values())) + criterion = lambda x: x.mean() + plugin = TorchDDPPlugin() + booster = Booster(plugin=plugin) + + model = model_fn() + optimizer = Adam(model.parameters(), lr=0.001) + model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion) + + data = data_gen_fn() + data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()} + output = model(**data) + output = output_transform_fn(output) + output_key = list(output.keys())[0] + loss = criterion(output[output_key]) + + booster.backward(loss, optimizer) + optimizer.step() + + with shared_tempdir() as tempdir: + model_ckpt_path = f"{tempdir}/model" + optimizer_ckpt_path = f"{tempdir}/optimizer" + + booster.save_model(model, model_ckpt_path, shard=shard) + booster.save_optimizer(optimizer, optimizer_ckpt_path, shard=shard) + dist.barrier() + + new_model = model_fn() + new_optimizer = HybridAdam(new_model.parameters(), lr=0.001) + new_plugin = GeminiPlugin() + new_booster = Booster(plugin=new_plugin) + new_model, new_optimizer, criterion, _, _ = new_booster.boost(new_model, new_optimizer, criterion) + + # Loading torch.Adam states to HybridAdam + new_booster.load_model(new_model, model_ckpt_path, strict=True) + + # Add prefix to get aligned with pytorch parameter names. + check_state_dict_equal( + new_model.unwrap().state_dict(only_rank_0=False, prefix='module.module.', dtype=torch.float32), + model.state_dict(), False) + + new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path) + new_state_dict = new_optimizer.state_dict() + old_state_dict = optimizer.state_dict() + + # The complete optimizer state_dict of GeminiPlugin is only collected on rank 0 + if dist.get_rank() == 0: + check_state_dict_equal(new_state_dict, old_state_dict, False) def run_dist(rank, world_size, port): @@ -97,10 +200,12 @@ def run_dist(rank, world_size, port): colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') exam_state_dict() exam_state_dict_with_origin() + exam_torch_load_from_gemini() + exam_gemini_load_from_torch() @pytest.mark.dist -@pytest.mark.parametrize('world_size', [2]) +@pytest.mark.parametrize('world_size', [1, 2]) @rerun_if_address_is_in_use() def test_gemini_ckpIO(world_size): spawn(run_dist, world_size)
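
A minimal end-to-end sketch of how the unsharded checkpointing path introduced by this patch is driven through the Booster API, mirroring the flow of the new tests above. This is an illustrative sketch, not part of the patch: build_model(), the ckpt/ output paths, the learning rate, the dummy input shape, and the launch_from_torch entry point are assumptions for the example; GeminiPlugin, Booster, HybridAdam, booster.boost, booster.backward, and the save_model/save_optimizer/load_model/load_optimizer calls are the APIs exercised in the patch itself.

import os

import torch
import torch.distributed as dist

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.nn.optimizer import HybridAdam


def build_model():
    # Placeholder model for the sketch; any torch.nn.Module can be wrapped by GeminiPlugin.
    return torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1))


def main():
    # Assumes the script is started with torchrun; the tests above instead call
    # colossalai.launch inside a spawn helper.
    colossalai.launch_from_torch(config={})

    booster = Booster(plugin=GeminiPlugin(placement_policy='cuda'))

    model = build_model()
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    criterion = lambda x: x.mean()
    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    # One training step so the optimizer holds non-trivial states to checkpoint.
    loss = criterion(model(torch.rand(4, 8).cuda()))
    booster.backward(loss, optimizer)
    optimizer.step()

    os.makedirs('ckpt', exist_ok=True)    # illustrative output directory

    # shard=False takes the unsharded path added by this patch: the complete
    # optimizer state dict is gathered on rank 0 and written as a single file.
    # Every rank must make these calls, because gathering the states involves
    # collective communication.
    booster.save_model(model, 'ckpt/model', shard=False)
    booster.save_optimizer(optimizer, 'ckpt/optimizer', shard=False)
    dist.barrier()

    # On reload, each rank reads the full file but keeps only the state slices
    # of the parameters it manages.
    booster.load_model(model, 'ckpt/model')
    booster.load_optimizer(optimizer, 'ckpt/optimizer')


if __name__ == '__main__':
    main()

Note the asymmetry the patch's docstrings describe: saving with shard=False is a collective operation that only rank 0 completes with a file write, while loading filters the full state dict down to the parameters each rank controls.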