From c749b05cb71cbf9a15978b228b6a1da152960bfd Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Thu, 22 Aug 2024 07:48:47 +0000 Subject: [PATCH] fix pinning --- deepspeed/runtime/utils.py | 2 +- deepspeed/runtime/zero/stage3.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/utils.py b/deepspeed/runtime/utils.py index a71731bf47b0..871fca403e4d 100755 --- a/deepspeed/runtime/utils.py +++ b/deepspeed/runtime/utils.py @@ -1074,7 +1074,7 @@ def move_key(state, key): if pin_memory: pin_mem_key = f"{key}_pin_memory" if pin_mem_key not in state: - state[pin_mem_key] = torch.empty_like(state[key], device=device).pin_memory() + state[pin_mem_key] = get_accelerator().pin_memory(torch.empty_like(state[key], device=device)) state[pin_mem_key].copy_(state[key], non_blocking=non_blocking) state[key].data = state[pin_mem_key] else: diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index f75043248841..96f8ca65e5ea 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -2826,7 +2826,8 @@ def needs_offload(target): if needs_offload(OffloadStateTypeEnum.lp_params): if pin_memory: if not hasattr(self, "lp_param_contiguous_pin_buffer"): - self.lp_param_contiguous_pin_buffer = torch.empty_like(self.lp_param_buffer, device=device) + self.lp_param_contiguous_pin_buffer = get_accelerator().pin_memory( + torch.empty_like(self.lp_param_buffer, device=device)) self.lp_params_pin_buffers = [ get_accelerator().pin_memory(torch.empty_like(p.ds_tensor, device=device)) for p in self.module.parameters()