
Commit

eliminate duplicated buffer for lp param
tohtana committed Sep 3, 2024
1 parent 2a4733e commit e9a499e
Showing 2 changed files with 27 additions and 24 deletions.
36 changes: 13 additions & 23 deletions deepspeed/runtime/zero/stage3.py
@@ -23,7 +23,7 @@
 from deepspeed.runtime.zero.config import ZeroStageEnum
 from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum
 from deepspeed.runtime.zero.parameter_offload import DeepSpeedZeRoOffload
-from deepspeed.runtime.zero.utils import apply_to_tensors_only
+from deepspeed.runtime.zero.utils import apply_to_tensors_only, get_mapping_to_flat_buffer
 from deepspeed.ops.adam import DeepSpeedCPUAdam
 from deepspeed.runtime.swap_tensor.partitioned_param_swapper import PartitionedParamStatus
 from deepspeed.runtime.swap_tensor.optimizer_utils import OptimizerSwapper
@@ -566,21 +566,15 @@ def defragment(tensors: List[Tensor]) -> Tensor:
     cpu_buffer = torch.empty(sum(p.numel() for p in tensors),
                              dtype=get_only_unique_item(t.dtype for t in tensors),
                              device="cpu")
-    tensor_infos: List[Tuple[Tensor, int, int]] = []
+    tensor_infos: List[Tuple[Tensor, int, int]] = get_mapping_to_flat_buffer(tensors)
     orig_device = get_only_unique_item(t.device for t in tensors)
 
-    offset = 0
-    for tensor in tensors:
-        tensor_numel = tensor.numel()
+    for tensor, offset, tensor_numel in tensor_infos:
         # move the tensor from device memory to host memory
         cpu_buffer.narrow(0, offset, tensor_numel).copy_(tensor)
         tensor.data = torch.empty(0, dtype=tensor.dtype, device=tensor.device)
 
-        # record some data so we can restore the device tensor later
-        tensor_infos.append((tensor, offset, tensor_numel))
-
-        offset += tensor_numel
-
     gc.collect()
     get_accelerator().empty_cache()
 
@@ -2828,20 +2822,14 @@ def needs_offload(target):
                 if not hasattr(self, "lp_param_contiguous_pin_buffer"):
                     self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory(
                         torch.empty_like(self.lp_param_buffer, device=device))
-                    self.lp_params_pin_buffers = [
-                        get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device))
-                        for p in self.module.parameters()
-                    ]
                 self.lp_param_contiguous_pin_buffer.copy_(self.lp_param_buffer, non_blocking=non_blocking)
                 self.lp_param_buffer.data = self.lp_param_contiguous_pin_buffer
-
-                for p, buf in zip(self.module.parameters(), self.lp_params_pin_buffers):
-                    buf.copy_(p.ds_tensor.data, non_blocking=non_blocking)
-                    p.ds_tensor.data = buf
+                cpu_buffer = self.lp_param_contiguous_pin_buffer
             else:
                 self.lp_param_buffer.data = self.lp_param_buffer.to(device, non_blocking=non_blocking)
-                for p in self.module.parameters():
-                    p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)
+                cpu_buffer = self.lp_param_buffer
+
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+                tensor.data = cpu_buffer.narrow(0, offset, tensor_numel)
 
             self.fp16_partitioned_groups_flat.clear()
             self.offloaded_states.add(OffloadStateTypeEnum.lp_params)
@@ -2895,11 +2883,13 @@ def reload_states(self, non_blocking: bool = False):
 
         # LP Param
         if OffloadStateTypeEnum.lp_params in self.offloaded_states:
-            self.lp_param_buffer.data = self.lp_param_buffer.data.to(device, non_blocking=non_blocking)
+            cpu_buffer = self.lp_param_contiguous_pin_buffer if hasattr(
+                self, "lp_param_contiguous_pin_buffer") else self.lp_param_buffer
+            self.lp_param_buffer.data = cpu_buffer.data.to(device, non_blocking=non_blocking)
             self._set_fp16_partitioned_groups_flat()
 
-            for p in self.module.parameters():
-                p.ds_tensor.data = p.ds_tensor.data.to(device, non_blocking=non_blocking)
+            for tensor, offset, tensor_numel in get_mapping_to_flat_buffer(self.module.parameters()):
+                tensor.data = self.lp_param_buffer.narrow(0, offset, tensor_numel)
             self.offloaded_states.remove(OffloadStateTypeEnum.lp_params)
 
             # LP grad
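
Note: the offload and reload paths above now share one pattern: copy the data into a single flat CPU buffer, then re-point each parameter's ds_tensor at a narrow() view of that buffer instead of keeping a second per-parameter CPU copy. A minimal standalone sketch of that pattern (not DeepSpeed code; toy 1-D tensors stand in for the ds_tensor partitions and a plain CPU tensor stands in for lp_param_buffer):

# Sketch only: flatten a list of tensors into one CPU buffer and keep views, no duplicate copies.
import torch

tensors = [torch.randn(4), torch.randn(6)]                                    # stand-ins for p.ds_tensor
flat = torch.empty(sum(t.numel() for t in tensors), dtype=tensors[0].dtype)   # stand-in for lp_param_buffer

offset = 0
for t in tensors:
    flat.narrow(0, offset, t.numel()).copy_(t)   # move the data into the flat buffer
    t.data = flat.narrow(0, offset, t.numel())   # keep only a view of the flat buffer
    offset += t.numel()

assert tensors[0].data_ptr() == flat.data_ptr()  # first tensor now aliases the flat buffer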
15 changes: 14 additions & 1 deletion deepspeed/runtime/zero/utils.py
@@ -4,7 +4,7 @@
 # DeepSpeed Team
 
 import os
-from typing import List
+from typing import List, Tuple
 
 import torch
 from deepspeed import comm as dist
@@ -160,3 +160,16 @@ def apply_to_tensors_only(function, value, warning_msg_fn=None):
         logger.warning(warning_msg_fn(value))
         warned = True
     return value
+
+
+def get_mapping_to_flat_buffer(tensors: List[torch.Tensor]) -> List[Tuple[torch.Tensor, int, int]]:
+    tensor_infos: List[Tuple[torch.Tensor, int, int]] = []
+
+    offset = 0
+    for tensor in tensors:
+        tensor_numel = tensor.numel()
+        # record some data so we can restore the device tensor later
+        tensor_infos.append((tensor, offset, tensor_numel))
+        offset += tensor_numel
+
+    return tensor_infos
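
For illustration, a small usage sketch of the new helper; the tensor sizes are toy values, and the import path matches the file modified above:

import torch
from deepspeed.runtime.zero.utils import get_mapping_to_flat_buffer

tensors = [torch.zeros(3), torch.zeros(5), torch.zeros(2)]
for tensor, offset, numel in get_mapping_to_flat_buffer(tensors):
    print(offset, numel)
# prints 0 3, then 3 5, then 8 2: each tensor's slot in a flat buffer of total size 10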
