From 6cb792008212ac86607e6e08183bc35e45d36815 Mon Sep 17 00:00:00 2001
From: Baizhou Zhang
Date: Fri, 7 Jul 2023 12:52:33 +0800
Subject: [PATCH] [checkpointio] unsharded optimizer checkpoint for Gemini using all_gather

---
 colossalai/nn/optimizer/cpu_adam.py        |   3 -
 colossalai/nn/optimizer/fused_adam.py      |   3 -
 colossalai/nn/optimizer/hybrid_adam.py     |   3 -
 colossalai/testing/comparison.py           |   3 +-
 colossalai/zero/gemini/gemini_optimizer.py | 146 ++++++++----------
 .../test_gemini_checkpoint_io.py           |   5 +-
 .../test_gemini_torch_compability.py       |  25 +--
 7 files changed, 84 insertions(+), 104 deletions(-)

diff --git a/colossalai/nn/optimizer/cpu_adam.py b/colossalai/nn/optimizer/cpu_adam.py
index 8e7652d64d1b..3a6d37103398 100644
--- a/colossalai/nn/optimizer/cpu_adam.py
+++ b/colossalai/nn/optimizer/cpu_adam.py
@@ -176,6 +176,3 @@ def step(self, closure=None, div_scale: float = -1):
                 raise RuntimeError
         self._post_step()
         return loss
-
-    def get_state_names(self):
-        return ['step', 'exp_avg', 'exp_avg_sq']
diff --git a/colossalai/nn/optimizer/fused_adam.py b/colossalai/nn/optimizer/fused_adam.py
index 7c645f955b6a..82a6250f1fd1 100644
--- a/colossalai/nn/optimizer/fused_adam.py
+++ b/colossalai/nn/optimizer/fused_adam.py
@@ -147,6 +147,3 @@ def step(self, closure=None, grads=None, output_params=None, scale=None, grad_no
                                  group['weight_decay'], div_scale)
 
         return loss
-
-    def get_state_names(self):
-        return ['step', 'exp_avg', 'exp_avg_sq']
diff --git a/colossalai/nn/optimizer/hybrid_adam.py b/colossalai/nn/optimizer/hybrid_adam.py
index 7f982a8e9b7b..84903ac36832 100644
--- a/colossalai/nn/optimizer/hybrid_adam.py
+++ b/colossalai/nn/optimizer/hybrid_adam.py
@@ -152,6 +152,3 @@ def step(self, closure=None, div_scale: float = -1):
                                  bias_correction, group['weight_decay'], div_scale)
         self._post_step()
         return loss
-
-    def get_state_names(self):
-        return ['step', 'exp_avg', 'exp_avg_sq']
diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py
index 2d9d77891f91..8d9ec8ab5f35 100644
--- a/colossalai/testing/comparison.py
+++ b/colossalai/testing/comparison.py
@@ -39,7 +39,8 @@ def assert_equal_in_group(tensor: Tensor, process_group: ProcessGroup = None):
 
 
 def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
-    assert len(list(d1.keys())) == len(list(d2.keys()))
+    assert len(list(d1.keys())) == len(list(d2.keys())), \
+        f"Number of keys unequal: {len(list(d1.keys()))} vs {len(list(d2.keys()))}"
     for k, v1 in d1.items():
         assert k in d2
         v2 = d2[k]
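
Note: the three Adam variants above each lose their hard-coded get_state_names() helper. The checkpoint logic in gemini_optimizer.py below no longer needs it, because state names can be read off the optimizer's state at runtime. A rough sketch of that idea for any torch.optim-style optimizer (the helper name local_state_names is illustrative, not part of this patch):

    import torch

    def local_state_names(optim: torch.optim.Optimizer, param: torch.Tensor):
        # For the Adam family this yields ['step', 'exp_avg', 'exp_avg_sq'] once a
        # step has run; it returns None if the parameter has no state on this rank.
        state = optim.state.get(param, None)
        return list(state.keys()) if state else None
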
diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py
index 18f0571bf990..99aff6f1c527 100644
--- a/colossalai/zero/gemini/gemini_optimizer.py
+++ b/colossalai/zero/gemini/gemini_optimizer.py
@@ -1,8 +1,8 @@
 # this code is inspired by the DeepSpeed library and implemented with our own design from scratch
 import copy
+import gc
 import math
 import warnings
-from collections import abc as container_abcs
 from typing import Any, Dict, Set, Tuple
 
 import torch
@@ -348,10 +348,11 @@ def get_offsets(self, param_id: int) -> tuple:
             chunk_offset(int): Offset of parameter inside the chunk.
             shard_offset(int): Offset of its optimizer state shard
                 relative to the whole optimizer state.
-            param_size(int): Length of parameter shard owned by current process.
+            shard_size(int): Length of parameter shard owned by current process.
         '''
-        assert param_id in self.id_to_fake_params
+        if param_id not in self.id_to_fake_params:
+            return -1, -1, -1
         fake_param = self.id_to_fake_params[param_id]
         chunk = self.param_to_chunk32[fake_param].paired_chunk
         param = self.id_to_real_params[param_id]
@@ -360,15 +361,16 @@ def get_offsets(self, param_id: int) -> tuple:
         begin_in_chunk, end_in_chunk = self.param_to_range[fake_param]
         chunk_offset = begin_in_chunk
         shard_offset = begin_in_chunk + chunk.shard_begin - param_info.offset
-        param_size = end_in_chunk - begin_in_chunk
+        shard_size = end_in_chunk - begin_in_chunk
         assert chunk_offset >= 0 and shard_offset >= 0
-        return chunk_offset, shard_offset, param_size
+        return chunk_offset, shard_offset, shard_size
 
-    def collect_states(self, param_id: int) -> dict:
+    def collect_states(self, param_id: int, only_rank_0: bool = True) -> dict:
         """
         Args:
             param_id (int): id of the parameter whose state is to be gathered at master rank.
+            only_rank_0(bool): if True, states will be collected only on the master rank, otherwise on every rank.
 
         Returns:
             collected_states(dict): the gathered optimzier state of parameter with given id
@@ -379,21 +381,44 @@ def collect_states(self, param_id: int) -> dict:
 
         # Get param & chunk & process group.
         param = self.id_to_real_params[param_id]
+        fake_param = self.id_to_fake_params.get(param_id, None)
         chunk = self.chunk_manager.get_chunk(param)
         process_group = chunk.torch_pg
         rank = dist.get_rank(process_group)
         master_rank = 0
-        state_names = self.optim.get_state_names()
         collected_states = {}
 
+        # Fetch names of states through all_gather.
+        local_state_names = None
+        if fake_param is not None:
+            local_state_names = list(self.optim.state[fake_param].keys())
+        gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+        dist.barrier()
+        dist.all_gather_object(gathered_state_names, local_state_names)
+        state_names = None
+        for names in gathered_state_names:
+            if names is not None:
+                # Assume different devices share the same set of state names if they have any.
+                state_names = copy.deepcopy(names)
+                break
+
+        # Directly return if this parameter doesn't have optimizer states,
+        # e.g. the parameter is frozen or the layer was dropped.
+        if state_names is None:
+            return collected_states
+
+        # The boolean is_collector indicates whether the current rank
+        # needs to gather the whole optimizer states.
+        # Only the master rank is a collector when only_rank_0 is True.
+        # Every rank is a collector when only_rank_0 is False.
+        is_collector = (rank == master_rank) or (not only_rank_0)
+
         # If the chunk is kept gathered,
         # the parameteres are treated the same as that of those in strict DDP during training.
         # So states can be directly fetched from current device.
         if chunk.keep_gathered:
             assert param_id in self.id_to_fake_params
-            fake_param = self.id_to_fake_params[param_id]
-            if rank == master_rank:
+            if is_collector:
                 states = self.optim.state[fake_param]
                 for state_name in state_names:
                     if state_name == 'step':
@@ -408,42 +433,8 @@ def collect_states(self, param_id: int) -> dict:
 
         # Check whether the param with given id is managed by current process.
         own_param = param_id in self.id_to_fake_params
-        # Compute position information (offsets) of state shard.
-        # If current process doesn't control this param, position message should be [rank, -1, -1]
-        # else it should be [rank, start_offset, end_offset]
-        state_shard_range = torch.tensor([rank, -1, -1], dtype=torch.int, requires_grad=False).cuda()
-        if own_param:
-            _, shard_offset, param_size = self.get_offsets(param_id)
-            state_shard_range[1] = shard_offset
-            state_shard_range[2] = shard_offset + param_size
-
-        # Ranks other than master send position messages to master.
-        # The master rank should receive a dict mapping rank number to state shard offsets, in the form of:
-        # {rank_0: (0, x_1), rank_1: (x_1, x_2), rank2: (x_2, x_3), ... rank_{n-1}: (x_{n-1}, param.numels()) }
-        dist.barrier(process_group)
-        range_info = dict()
-        if rank == master_rank:
-
-            # Record if master owns this param.
-            if state_shard_range[-1] >= 0:
-                range_info[master_rank] = (state_shard_range[-2].item(), state_shard_range[-1].item())
-
-            container_tensor = torch.zeros_like(state_shard_range, dtype=torch.int, requires_grad=False).cuda()
-            for src in range(0, dist.get_world_size(process_group)):
-                if src == master_rank:
-                    continue
-                dist.recv(container_tensor, src, process_group)
-                # Only keeps valid position messages.
-                if container_tensor[-1] >= 0:
-                    range_info[container_tensor[0].item()] = (container_tensor[-2].item(), container_tensor[-1].item())
-            assert len(list(range_info.keys())) > 0
-
-        else:
-            dist.send(state_shard_range, master_rank, process_group)
-        dist.barrier(process_group)
-
-        # Master get prepared for state collecting.
-        if rank == master_rank:
+        # Collector gets prepared for state collecting.
+        if is_collector:
             for state_name in state_names:
                 if state_name == 'step':
                     # To keep aligned with pytorch, state 'step' is stored as a pytorch tensor with type float32.
@@ -452,35 +443,33 @@ def collect_states(self, param_id: int) -> dict:
                     collected_states[state_name] = torch.zeros(param.numel(),
                                                                dtype=torch.float32,
                                                                requires_grad=False).cpu()
-        # Master starts to receive state tensors from other ranks
-        # after all the position messages of state shards has been collected.
-        compacted_states = self.pack_optimizer_states_to_tensor(param_id)
-        if rank == master_rank:
-            for src, shard_range in range_info.items():
+        # Materials for gathering, including compacted state tensors, and the offset of the shard inside each state.
+        compacted_states = self.pack_optimizer_states_to_tensor(param_id, state_names) if own_param else None
+        _, shard_offset, shard_size = self.get_offsets(param_id)
 
-                # If src is master, directly load from compacted_states tensor.
-                if src == master_rank:
-                    self.load_from_compacted_states(compacted_states, collected_states, shard_range)
-                    continue
+        # Collectors gather state shards through all_gather.
+        gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
 
-                # If src is not master, receive states from other ranks.
-                shard_size = shard_range[1] - shard_range[0]
-                num_states = len(state_names)
-                container_size = 1 + (num_states - 1) * shard_size if 'step' in state_names \
-                    else num_states * shard_size
-                container_tensor = torch.zeros(container_size, dtype=torch.float32, requires_grad=False).cuda()
-                dist.recv(container_tensor, src, process_group)
-                self.load_from_compacted_states(container_tensor, collected_states, shard_range)
-                del container_tensor
-        else:
-            if own_param:
-                dist.send(compacted_states, master_rank, process_group)
+        dist.barrier()
+        dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
 
-        del compacted_states
-        dist.barrier(process_group)
+        if is_collector:
+            for state_shard in gathered_state_shards:
+                compacted_states = state_shard[0]
+                shard_offset = state_shard[1]
+                shard_size = state_shard[2]
+                if compacted_states is None:
+                    continue
+                self.load_from_compacted_states(compacted_states, collected_states, state_names, shard_offset,
+                                                shard_size)
+
+        # Clean gathered states
+        for state_shard in gathered_state_shards:
+            del state_shard[0]
+        gc.collect()
 
         # Reshape tensors
-        if rank == master_rank:
+        if is_collector:
             for state_name, state_tensor in collected_states.items():
                 if state_tensor.numel() == param.numel():
                     collected_states[state_name] = torch.reshape(state_tensor, param.shape)
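
Note: the two hunks above replace the rank-by-rank send/recv protocol with a single collective per parameter: every rank contributes its (possibly missing) shard plus the shard's offset, and the collecting rank stitches the pieces together. A rough standalone sketch of that pattern, assuming flat fp32 shards and a known total length (gather_flat_state and its arguments are illustrative names, not the patch's API):

    import torch
    import torch.distributed as dist

    def gather_flat_state(local_shard, shard_offset: int, total_numel: int, group=None):
        # all_gather_object pickles arbitrary Python objects, so shards may have
        # different lengths on different ranks, or be None on ranks that own nothing.
        gathered = [None for _ in range(dist.get_world_size(group))]
        dist.all_gather_object(gathered, [local_shard, shard_offset], group=group)

        # Stitch the shards into the full flat state tensor by their offsets.
        full = torch.zeros(total_numel, dtype=torch.float32)
        for shard, offset in gathered:
            if shard is None:
                continue
            full[offset:offset + shard.numel()].copy_(shard.cpu().float())
        return full
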
@@ -489,6 +478,7 @@ def collect_states(self, param_id: int) -> dict:
 
     def pack_optimizer_states_to_tensor(self,
                                         param_id: int,
+                                        state_names: list,
                                         device: torch.device = torch.device('cuda'),
                                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
         '''
@@ -502,7 +492,7 @@ def pack_optimizer_states_to_tensor(self,
         states = self.optim.state[fake_param]
         shard_size = param_range[1] - param_range[0]
         compacted_size = 0
-        for name in self.optim.get_state_names():
+        for name in state_names:
             if name == 'step':
                 compacted_size += 1
             else:
@@ -526,16 +516,16 @@ def pack_optimizer_states_to_tensor(self,
 
         return compacted_states
 
-    def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, shard_range: tuple):
+    def load_from_compacted_states(self, compacted_states: torch.Tensor, collected_states: dict, state_names: list,
+                                   shard_start: int, shard_size: int):
         '''
         Given a tensor carrying compacted optimizer states,
         update these states to collected_states.
         '''
-        shard_start, shard_end = shard_range
-        shard_size = shard_end - shard_start
+        shard_end = shard_start + shard_size
         next_state_offset = 0
-        for state_name in self.optim.get_state_names():
+        for state_name in state_names:
             if state_name == 'step':
                 collected_states['step'].data = torch.tensor(compacted_states[next_state_offset].item(),
                                                              dtype=torch.float32,
@@ -564,10 +554,6 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
         Warning: This method will gather and return the whole optimizer state_dict,
             so it should be called only when memory resources are abundant.
         """
-
-        if not only_rank_0:
-            raise ValueError("False 'only_rank_0' in self.state_dict() method is currently not supported")
-
         state_dict = {}
         state_dict['param_groups'] = copy.deepcopy(self.param_groups_backup)
 
@@ -589,7 +575,7 @@ def state_dict(self, only_rank_0: bool = True) -> dict:
         state_dict['state'] = dict()
         for param_id in self.id_to_real_params.keys():
             dist.barrier()
-            state_dict['state'][param_id] = self.collect_states(param_id=param_id)
+            state_dict['state'][param_id] = self.collect_states(param_id=param_id, only_rank_0=only_rank_0)
         return state_dict
 
     def load_param_groups(self, saved_param_groups: list):
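
Note: pack_optimizer_states_to_tensor and load_from_compacted_states above agree on a simple flat layout: an optional single element for 'step' followed by one flattened shard per remaining state, in state-name order (which is why the removed container_size formula was 1 + (num_states - 1) * shard_size when 'step' was present). A self-contained sketch of that layout, with illustrative helper names rather than the patch's API:

    import torch

    def pack_states(states: dict, state_names: list) -> torch.Tensor:
        # Layout: [step (1 element, if present)][exp_avg shard][exp_avg_sq shard]...
        segments = []
        for name in state_names:
            if name == 'step':
                segments.append(torch.tensor([float(states[name])], dtype=torch.float32))
            else:
                segments.append(states[name].reshape(-1).to(torch.float32))
        return torch.cat(segments)

    def unpack_states(packed: torch.Tensor, state_names: list, shard_size: int) -> dict:
        # Walk the flat tensor with a running offset, mirroring pack_states.
        out, offset = {}, 0
        for name in state_names:
            if name == 'step':
                out[name] = packed[offset].item()
                offset += 1
            else:
                out[name] = packed[offset:offset + shard_size].clone()
                offset += shard_size
        return out
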
diff --git a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
index aab6d5bf7912..0235ff2e2c81 100644
--- a/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
+++ b/tests/test_checkpoint_io/test_gemini_checkpoint_io.py
@@ -58,7 +58,7 @@ def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: b
 def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_shard: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin(placement_policy=placement_policy)
+    plugin = GeminiPlugin(placement_policy=placement_policy, precision="fp16", initial_scale=(2**14))
     booster = Booster(plugin=plugin)
 
     model = model_fn()
@@ -91,7 +91,8 @@ def exam_state_dict(placement_policy, shard: bool, model_name: str, size_per_sha
                           new_model.unwrap().state_dict(only_rank_0=False), False)
 
     booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
+    check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False),
+                           new_optimizer.unwrap().state_dict(only_rank_0=False), False)
 
     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
diff --git a/tests/test_checkpoint_io/test_gemini_torch_compability.py b/tests/test_checkpoint_io/test_gemini_torch_compability.py
index dfd08c000e40..b34e3e3a1310 100644
--- a/tests/test_checkpoint_io/test_gemini_torch_compability.py
+++ b/tests/test_checkpoint_io/test_gemini_torch_compability.py
@@ -25,7 +25,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
 
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin()
+    plugin = GeminiPlugin(precision="fp16", initial_scale=(2**14))
     booster = Booster(plugin=plugin)
 
     model = model_fn()
@@ -65,12 +65,7 @@ def exam_torch_load_from_gemini(shard: bool, model_name: str):
                           new_model.state_dict(), False)
 
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    old_state_dict = optimizer.state_dict()
-    new_state_dict = new_optimizer.state_dict()
-
-    # The complete optimizer state_dict of GeminiPlugin is only collected on rank 0
-    if dist.get_rank() == 0:
-        check_state_dict_equal(old_state_dict, new_state_dict, False)
+    check_state_dict_equal(optimizer.unwrap().state_dict(only_rank_0=False), new_optimizer.state_dict(), False)
 
     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
@@ -134,12 +129,18 @@ def exam_gemini_load_from_torch(shard: bool, model_name: str):
                           model.state_dict(), False)
 
     new_booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
-    new_state_dict = new_optimizer.state_dict()
     old_state_dict = optimizer.state_dict()
-
-    # The complete optimizer state_dict of GeminiPlugin is only collected on rank 0
-    if dist.get_rank() == 0:
-        check_state_dict_equal(new_state_dict, old_state_dict, False)
+    new_state_dict = new_optimizer.unwrap().state_dict(only_rank_0=False)
+
+    # Comparison of param_groups needs special care here,
+    # since not all hyperparameters in Adam are used by HybridAdam
+    hyperparameters_to_examine = ['params', 'lr', 'betas', 'eps', 'weight_decay']
+    for old_group, new_group in zip(old_state_dict['param_groups'], new_state_dict['param_groups']):
+        for k in hyperparameters_to_examine:
+            assert k in old_group and k in new_group, \
+                f"Old group's keys: {list(old_group.keys())}, New group's keys: {list(new_group.keys())}"
+            assert old_group[k] == new_group[k]
+    check_state_dict_equal(old_state_dict['state'], new_state_dict['state'], False)
 
     # Check the new model/optimizer can successfully run.
     data = data_gen_fn()
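
Note: with this change the full, unsharded optimizer state_dict can be materialized on every rank, which is what the updated assertions above rely on. A condensed usage sketch under the Gemini plugin, assuming a distributed environment has already been initialized (e.g. via colossalai.launch_from_torch); the tiny model and hyperparameters are for illustration only:

    import torch
    import torch.nn as nn
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    plugin = GeminiPlugin(precision="fp16", initial_scale=2**14)
    booster = Booster(plugin=plugin)

    model = nn.Linear(8, 8)
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Run one step so that the Adam states ('step', 'exp_avg', 'exp_avg_sq') exist.
    output = model(torch.rand(4, 8, dtype=torch.float16, device='cuda'))
    booster.backward(output.mean(), optimizer)
    optimizer.step()

    # only_rank_0=False now gathers the whole state_dict on every rank, so tests
    # no longer need to guard the comparison with `if dist.get_rank() == 0`.
    full_state_dict = optimizer.unwrap().state_dict(only_rank_0=False)
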