From e9a499ef7ebd11c9d80f89fa2346daf1d7759999 Mon Sep 17 00:00:00 2001
From: Masahiro Tanaka
Date: Tue, 3 Sep 2024 23:09:58 +0000
Subject: [PATCH] eliminate duplicated buffer for lp param

---
 deepspeed/runtime/zero/stage3.py | 36 ++++++++++++--------------------
 deepspeed/runtime/zero/utils.py  | 15 ++++++++++++-
 2 files changed, 27 insertions(+), 24 deletions(-)

diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py
index 45e311efa564..d84ac16331f8 100644
--- a/deepspeed/runtime/zero/stage3.py
+++ b/deepspeed/runtime/zero/stage3.py
@@ -23,7 +23,7 @@
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
-from deepspeed.runtime.zero.utils import apply_to_tensors_only
+from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
 from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
@@ -566,21 +566,15 @@ def defragment(tensors: List[Tensor]) -> Tensor:
         cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
                                  dtype=get_only_unique_item(t.dtype for t in tensors),
                                  device="cpu")
-        tensor_infos: List[Tuple[Tensor, int, int]] = []
+        tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)
         orig_device = get_only_unique_item(t.device for t in tensors)
 
         offset = 0
-        for tensor in tensors:
-            tensor_numel = tensor.numel()
+        for tensor, offset, tensor_numel in tensor_infos:
             # move the tensor from device memory to host memory
             cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
             tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)
 
-            # record some data so we can restore the device tensor later
-            tensor_infos.append((tensor, offset, tensor_numel))
-
-            offset += tensor_numel
-
         gc.collect()
         get_accelerator().empty_cache()
 
@@ -2828,20 +2822,14 @@ def needs_offload(target):
                 if not hasattr(self, "lp_param_contiguous_pin_buffer"):
                     self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory(
                         torch.empty_like(self.lp_param_buffer, device=device))
-                    self.lp_params_pin_buffers = [
-                        get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device))
-                        for p in self.module.parameters()
-                    ]
                 self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking)
-                self.lp_param_buffer.data = self.lp_param_contiguous_pin_buffer
-
-                for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers):
-                    buf.copy_(p.ds_tensor.data, non_blocking=non_blocking)
-                    p.ds_tensor.data = buf
+                cpu_buffer = self.lp_param_contiguous_pin_buffer
             else:
                 self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking)
-                for p in self.module.parameters():
-                    p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)
+                cpu_buffer = self.lp_param_buffer
+
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+                tensor.data = cpu_buffer.narrow(0, offset, tensor_numel)
 
             self.fp16_partitioned_groups_flat.clear()
             self.offloaded_states.add(OffloadStateTypeEnum.lp_params)
@@ -2895,11 +2883,13 @@ def reload_states(self, non_blocking: bool = False):
 
         # LP Param
         if OffloadStateTypeEnum.lp_params in self.offloaded_states:
-            self.lp_param_buffer.data = self.lp_param_buffer.data.to(device, non_blocking=non_blocking)
+            cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr(
+                self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer
+            self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
             self._set_fp16_partitioned_groups_flat()
 
-            for p in self.module.parameters():
-                p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+                tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel)
             self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)
 
         # LP grad
diff --git a/deepspeed/runtime/zero/utils.py b/deepspeed/runtime/zero/utils.py
index 8f913d065934..2d1cf17962d8 100755
--- a/deepspeed/runtime/zero/utils.py
+++ b/deepspeed/runtime/zero/utils.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 import os
-from typing import List
+from typing import List, Tuple
 
 import torch
 from deepspeed import comm as dist
@@ -160,3 +160,16 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None):
             logger.warning(warning_msg_fn(value))
             warned = True
         return value
+
+
+def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]:
+    tensor_infos: List[Tuple[torch.Tensor, int, int]] = []
+
+    offset = 0
+    for tensor in tensors:
+        tensor_numel = tensor.numel()
+        # record some data so we can restore the device tensor later
+        tensor_infos.append((tensor, offset, tensor_numel))
+        offset += tensor_numel
+
+    return tensor_infos
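
Usage sketch (illustrative, not part of the patch): the new get_mapping_to_flat_buffer
helper returns (tensor, offset, numel) triples that let a single flat buffer back every
parameter partition, which is what allows offload_states/reload_states to drop the
per-parameter lp_params_pin_buffers copies. The names partitions/flat and the sizes
below are invented for the example; the tensors are 1-D, mirroring p.ds_tensor.

    from typing import List, Tuple

    import torch


    def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]:
        # Same logic as the helper the patch adds to deepspeed/runtime/zero/utils.py.
        tensor_infos: List[Tuple[torch.Tensor, int, int]] = []
        offset = 0
        for tensor in tensors:
            tensor_numel = tensor.numel()
            tensor_infos.append((tensor, offset, tensor_numel))
            offset += tensor_numel
        return tensor_infos


    # 1-D tensors standing in for the p.ds_tensor partitions; sizes are arbitrary.
    partitions = [torch.randn(16), torch.randn(8), torch.randn(6)]
    flat = torch.empty(sum(p.numel() for p in partitions))

    for tensor, offset, numel in get_mapping_to_flat_buffer(partitions):
        # Copy each partition into its slice of the flat buffer...
        flat.narrow(0, offset, numel).copy_(tensor)
        # ...then repoint the partition at that slice. narrow() returns a view,
        # so the partitions now share flat's storage instead of holding copies.
        tensor.data = flat.narrow(0, offset, numel)

    # The second partition starts 16 elements into the flat buffer.
    assert partitions[1].data_ptr() == flat.data_ptr() + 16 * flat.element_size()

Because the restored partitions are views into the flat buffer, moving or pinning
that one buffer moves every partition with it; no second staging copy is needed.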