diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py
index 2eec7cbc96a1..c24bc16cdbfe 100755
--- a/deepspeed/runtime/zero/stage_1_and_2.py
+++ b/deepspeed/runtime/zero/stage_1_and_2.py
@@ -75,11 +75,6 @@ def get_alignment_padding(tensor_list, alignment):
     return (alignment - remainder) if remainder else remainder
 
 
-def move_to_cpu(tensor_list):
-    for tensor in tensor_list:
-        tensor.data = tensor.data.cpu()
-
-
 def print_rank_msg(msg):
     print(f"rank {dist.get_rank()} - {msg}")
 
@@ -294,6 +289,7 @@ def __init__(self,
 
         self.round_robin_bit16_groups = []
         self.round_robin_bit16_indices = []
+        self.round_robin_bit16_meta = []
 
         # Use different parallel to do all_to_all_reduce related things
         # padding on each partition for alignment purposes
@@ -316,7 +312,14 @@ def __init__(self,
 
             see_memory_usage(f"Before moving param group {i} to CPU")
             # move all the parameters to cpu to free up GPU space for creating flat buffer
-            move_to_cpu(self.bit16_groups[i])
+
+            # Create temp CPU param copies, free accelerator tensors
+            orig_group_numel = 0
+            for param in self.bit16_groups[i]:
+                orig_group_numel += param.numel()
+                param.cpu_data = param.data.cpu()
+                param.data = torch.empty(1).to(param.device)
+
             empty_cache()
             see_memory_usage(f"After moving param group {i} to CPU", force=False)
 
@@ -334,18 +337,31 @@ def __init__(self,
             self.round_robin_bit16_groups.append(round_robin_tensors)
             self.round_robin_bit16_indices.append(round_robin_indices)
 
-            # create flat buffer in CPU and move to GPU
-            self.bit16_groups_flat.append(
-                self.flatten_dense_tensors_aligned(
-                    self.round_robin_bit16_groups[i],
-                    self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
-                        get_accelerator().current_device_name()))
+            # Create meta tensors list, ordered according to round_robin_tensors
+            meta_tensors = []
+            for param in round_robin_tensors:
+                meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
+            self.round_robin_bit16_meta.append(meta_tensors)
+
+            # create flat buffer in CPU
+            flattened_buffer = self.flatten_dense_tensors_aligned(
+                self.round_robin_bit16_groups[i],
+                self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]),
+                use_cpu_data=True)
+
+            # free temp CPU params
+            for param in self.bit16_groups[i]:
+                del param.cpu_data
+
+            # Move CPU flat tensor to the accelerator memory.
+            self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
+            del flattened_buffer
+
             see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)
 
             # Record padding required for alignment
             if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
-                padding = self.bit16_groups_flat[i].numel() - sum(
-                    [t.numel() for t in self.round_robin_bit16_groups[i]])
+                padding = self.bit16_groups_flat[i].numel() - orig_group_numel
             else:
                 padding = 0
             self.groups_padding.append(padding)
@@ -596,8 +612,7 @@ def _configure_moe_settings(self):
         assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"
 
     def _update_model_bit16_weights(self, group_index):
-        updated_params = self.unflatten(self.bit16_groups_flat[group_index],
-                                        self.round_robin_bit16_groups[group_index])
+        updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_meta[group_index])
         for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
             p.data = q.data
 
@@ -887,7 +902,8 @@ def report_ipg_memory_usage(self, tag, param_elems):
         )
 
     # create a flat tensor aligned at the alignment boundary
-    def flatten_dense_tensors_aligned(self, tensor_list, alignment):
+    def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False):
+        tensor_list = [param.cpu_data for param in tensor_list] if use_cpu_data else tensor_list
        return self.flatten(align_dense_tensors(tensor_list, alignment))
 
     ############### Independent Partition Gradient ########################
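
For reviewers, a minimal standalone sketch of the pattern the diff implements: stash each parameter's data on CPU, flatten on the host, move only the single flat buffer to the accelerator, and later unflatten against meta tensors that carry shape information only. This is not the DeepSpeed code itself; it calls `torch._utils` helpers directly instead of the optimizer's `self.flatten`/`self.unflatten` wrappers, omits the alignment padding (`align_dense_tensors`) and round-robin reordering, and names such as `build_flat_buffer` are illustrative.

```python
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors


def build_flat_buffer(params, device):
    # Stash CPU copies, then shrink each param's device tensor to free accelerator memory.
    cpu_copies = [p.data.cpu() for p in params]
    for p in params:
        p.data = torch.empty(1, device=p.device)

    # Meta tensors record only the shapes needed to unflatten later; they allocate no storage.
    meta = [torch.zeros_like(c, device="meta") for c in cpu_copies]

    # Flatten on the host, then move the single flat buffer to the device.
    flat = _flatten_dense_tensors(cpu_copies).to(device)
    del cpu_copies  # the per-parameter CPU copies are no longer needed

    # Re-point each param at a view of the flat buffer; the meta tensors
    # supply shape information only, mirroring _update_model_bit16_weights.
    for p, q in zip(params, _unflatten_dense_tensors(flat, meta)):
        p.data = q.data
    return flat
```

The point of the meta tensors is that `_unflatten_dense_tensors` only reads shapes from its second argument, so the original per-parameter copies never need to be kept alive alongside the flat buffer.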