DeepSpeedZeroOptimizer: refactor bit16 flattening to support more accelerators #4833

Merged: 11 commits, Jan 11, 2024
deepspeed/runtime/zero/stage_1_and_2.py: 50 changes (33 additions, 17 deletions)
@@ -75,11 +75,6 @@ def get_alignment_padding(tensor_list, alignment):
return (alignment - remainder) if remainder else remainder


def move_to_cpu(tensor_list):
for tensor in tensor_list:
tensor.data = tensor.data.cpu()


def print_rank_msg(msg):
print(f"rank {dist.get_rank()} - {msg}")

@@ -294,6 +289,7 @@ def __init__(self,

self.round_robin_bit16_groups = []
self.round_robin_bit16_indices = []
self.round_robin_bit16_meta = []

# Use different parallel to do all_to_all_reduce related things
# padding on each partition for alignment purposes
@@ -316,7 +312,14 @@

see_memory_usage(f"Before moving param group {i} to CPU")
# move all the parameters to cpu to free up GPU space for creating flat buffer
move_to_cpu(self.bit16_groups[i])

# Create temp CPU param copies, free accelerator tensors
orig_group_numel = 0
for param in self.bit16_groups[i]:
orig_group_numel += param.numel()
param.cpu_data = param.data.cpu()
param.data = torch.empty(1).to(param.device)

empty_cache()
see_memory_usage(f"After moving param group {i} to CPU", force=False)
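
For reference, a minimal standalone sketch of this parking step, assuming a plain list of torch.nn.Parameter objects rather than a DeepSpeed bit16 group (park_on_cpu and orig_numel are illustrative names, not part of the PR):

import torch

def park_on_cpu(params):
    # Keep a CPU copy of every parameter and shrink each accelerator allocation
    # to a single element, freeing device memory before the flat buffer is built.
    orig_numel = 0
    for p in params:
        orig_numel += p.numel()            # record the size before p.data is replaced
        p.cpu_data = p.data.cpu()          # temporary CPU copy, consumed when flattening
        p.data = torch.empty(1).to(p.device)
    return orig_numel

Counting elements before the placeholder assignment matters: once p.data points at a one-element tensor, p.numel() no longer reports the original size, which is why orig_group_numel is accumulated first in the hunk above.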

@@ -334,18 +337,31 @@ def __init__(self,
self.round_robin_bit16_groups.append(round_robin_tensors)
self.round_robin_bit16_indices.append(round_robin_indices)

# create flat buffer in CPU and move to GPU
self.bit16_groups_flat.append(
self.flatten_dense_tensors_aligned(
self.round_robin_bit16_groups[i],
self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i])).to(
get_accelerator().current_device_name()))
# Create meta tensors list, ordered according to round_robin_tensors
meta_tensors = []
for param in round_robin_tensors:
meta_tensors.append(torch.zeros_like(param.cpu_data, device="meta"))
self.round_robin_bit16_meta.append(meta_tensors)

# create flat buffer in CPU
flattened_buffer = self.flatten_dense_tensors_aligned(
self.round_robin_bit16_groups[i],
self.nccl_start_alignment_factor * dist.get_world_size(group=self.real_dp_process_group[i]),
use_cpu_data=True)

# free temp CPU params
for param in self.bit16_groups[i]:
del param.cpu_data

# Move CPU flat tensor to the accelerator memory.
self.bit16_groups_flat.append(flattened_buffer.to(get_accelerator().current_device_name()))
del flattened_buffer

see_memory_usage(f"After flattening and moving param group {i} to GPU", force=False)

# Record padding required for alignment
if partition_id == dist.get_world_size(group=self.real_dp_process_group[i]) - 1:
padding = self.bit16_groups_flat[i].numel() - sum(
[t.numel() for t in self.round_robin_bit16_groups[i]])
padding = self.bit16_groups_flat[i].numel() - orig_group_numel
else:
padding = 0
self.groups_padding.append(padding)
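
Putting this hunk together for a single group: shape-only meta tensors are recorded for later unflattening, the parked CPU copies are flattened with tail padding up to the alignment boundary, and one bulk copy moves the buffer to the accelerator. A rough self-contained sketch, assuming the params were parked as in the sketch above and using torch._utils._flatten_dense_tensors in place of self.flatten:

import torch
from torch._utils import _flatten_dense_tensors

def build_group_flat_buffer(params, orig_numel, alignment, device):
    # params are assumed to be parked already: each carries a temporary
    # .cpu_data copy and a one-element .data placeholder on the accelerator.

    # Meta tensors keep shape and dtype without any storage; they are all the
    # later unflatten step needs to slice the device buffer into per-param views.
    meta = [torch.zeros_like(p.cpu_data, device="meta") for p in params]

    # Flatten on the CPU, zero-padding the tail to a multiple of `alignment`.
    cpu_tensors = [p.cpu_data for p in params]
    remainder = orig_numel % alignment
    if remainder:
        cpu_tensors.append(torch.zeros(alignment - remainder, dtype=cpu_tensors[0].dtype))
    flat_cpu = _flatten_dense_tensors(cpu_tensors)

    for p in params:
        del p.cpu_data                      # temporary CPU copies are no longer needed
    flat = flat_cpu.to(device)              # single bulk host-to-accelerator transfer
    padding = flat.numel() - orig_numel     # what the last rank records as groups_padding
    return flat, meta, padding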
@@ -596,8 +612,7 @@ def _configure_moe_settings(self):
assert self.ep_process_group is not None, "Expert parallel group should be configured with MoE"

def _update_model_bit16_weights(self, group_index):
updated_params = self.unflatten(self.bit16_groups_flat[group_index],
self.round_robin_bit16_groups[group_index])
updated_params = self.unflatten(self.bit16_groups_flat[group_index], self.round_robin_bit16_meta[group_index])
for p, q in zip(self.round_robin_bit16_groups[group_index], updated_params):
p.data = q.data
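
The switch to round_robin_bit16_meta above relies on unflattening needing only each reference tensor's shape and element count, which meta tensors provide without holding any storage. A small self-contained illustration with torch._utils._unflatten_dense_tensors (self.unflatten is assumed to behave equivalently; the shapes below are arbitrary):

import torch
from torch._utils import _unflatten_dense_tensors

shapes = [(2, 3), (4,), (5, 1)]
meta = [torch.zeros(s, device="meta") for s in shapes]    # shapes only, no storage
flat = torch.arange(sum(m.numel() for m in meta), dtype=torch.float32)

# Each result is a view into `flat`, carved out purely from the meta tensors'
# numel and shape, so the original bit16 tensors never have to be kept around.
views = _unflatten_dense_tensors(flat, meta)
assert [tuple(v.shape) for v in views] == shapes
flat[0] = 42.0
assert views[0][0, 0].item() == 42.0                      # the views alias flat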

@@ -887,7 +902,8 @@ def report_ipg_memory_usage(self, tag, param_elems):
)

# create a flat tensor aligned at the alignment boundary
def flatten_dense_tensors_aligned(self, tensor_list, alignment):
def flatten_dense_tensors_aligned(self, tensor_list, alignment, use_cpu_data=False):
tensor_list = [param.cpu_data for param in tensor_list] if use_cpu_data else tensor_list
return self.flatten(align_dense_tensors(tensor_list, alignment))
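
In effect, use_cpu_data=True substitutes the parked CPU copies before the existing align-and-flatten path, so the aligned flat buffer is assembled in host memory. A tiny hedged demo of that substitution (toy fp16 parameters, alignment padding omitted for brevity, _flatten_dense_tensors standing in for self.flatten):

import torch
from torch._utils import _flatten_dense_tensors

# Hypothetical toy group whose real data has been parked on the CPU.
params = [torch.nn.Parameter(torch.ones(4, 4, dtype=torch.float16), requires_grad=False)
          for _ in range(3)]
for p in params:
    p.cpu_data = p.data.cpu()
    p.data = torch.empty(1).to(p.device)

# use_cpu_data=True amounts to flattening the parked copies instead of the
# placeholder .data tensors, so the result lives on the CPU until .to(device).
flat = _flatten_dense_tensors([p.cpu_data for p in params])
assert flat.device.type == "cpu" and flat.numel() == 3 * 16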

############### Independent Partition Gradient ########################