From c17dc33c04a65606d770705a2c9d4ae3e0ae5a9b Mon Sep 17 00:00:00 2001 From: Xinyu Lian Date: Tue, 21 Jan 2025 12:48:38 -0600 Subject: [PATCH] Using explicit GPU upcast for ZeRO-Offload (#6962) Following discussion in [PR-6670](https://github.com/microsoft/DeepSpeed/pull/6670), the explicit upcast is much more efficient than the implicit upcast, so this PR replaces the implicit upcast with an explicit one. The results on 3B model are shown below: | Option | BWD (ms) | Speed up | |------------|-----|------| | Before PR-6670 | 25603.30 | 1x | | After PR-6670 | 1174.31 | 21.8X | | After this PR| 309.2 | 82.8X | --- deepspeed/runtime/zero/stage3.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 9c06567ed100..a5c0c3340019 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self): self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer) offset = 0 - max_partition_numel = 0 for param in all_params: self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow( 0, offset, param.partition_numel()) offset += param.partition_numel() - max_partition_numel = max(max_partition_numel, param.partition_numel()) - if self.offload_optimizer: - self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory( - torch.empty(max_partition_numel, device=self.device)) def _link_all_hp_params(self): for p in self.module.parameters(): @@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L offload_fp32_gradients[i].append(grad_buffer.float()) offload_fp32_offsets[i].append(dest_offset) else: - buffer_numel = grad_buffer.numel() fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow( - 0, dest_offset, buffer_numel) - self.pinned_grad_buffer[:buffer_numel].copy_( - 
grad_buffer.to(dtype=torch.float32, non_blocking=True)) - get_accelerator().synchronize() - fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True) + 0, dest_offset, grad_buffer.numel()) + fp32_grad_tensor.copy_(grad_buffer.float()) # free the gradient if not get_accelerator().is_synchronized_device():