Using explicit GPU upcast for ZeRO-Offload (#6962)
Following the discussion in
[PR-6670](#6670), the explicit
upcast is much more efficient than the implicit upcast, so this PR replaces
the implicit upcast with an explicit one. Upcasting the gradient to fp32 on
the GPU before the device-to-host copy also makes the intermediate pinned
staging buffer and the extra synchronize unnecessary, so they are removed.

The results on a 3B model are shown below:

| Option         | BWD (ms) | Speedup |
|----------------|----------|---------|
| Before PR-6670 | 25603.30 | 1x      |
| After PR-6670  | 1174.31  | 21.8x   |
| After this PR  | 309.2    | 82.8x   |
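
As a rough illustration of the difference (a minimal standalone sketch, not the DeepSpeed code itself; the tensor names and sizes are assumptions), the implicit path lets the dtype conversion happen inside the cross-device copy, while the explicit path upcasts on the GPU first and then performs a plain same-dtype device-to-host copy:

```python
import torch

# Illustrative sketch only; names and sizes are assumptions, not DeepSpeed code.
if torch.cuda.is_available():
    # Low-precision gradient partition living on the GPU.
    gpu_grad = torch.randn(1 << 24, dtype=torch.bfloat16, device="cuda")
    # Offloaded fp32 gradient buffer living on the CPU.
    cpu_fp32_grad = torch.empty(gpu_grad.numel(), dtype=torch.float32)

    # Implicit upcast: the bf16 -> fp32 conversion happens inside the
    # cross-device copy (the slow path discussed in PR-6670).
    cpu_fp32_grad.copy_(gpu_grad)

    # Explicit upcast: convert to fp32 on the GPU first, then do a plain
    # same-dtype device-to-host copy.
    cpu_fp32_grad.copy_(gpu_grad.float())
```

The trade-off is a temporary fp32 copy of the gradient partition on the GPU in exchange for keeping the device-to-host transfer a plain same-dtype copy, which the table above suggests is well worth it.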
xylian86 authored Jan 21, 2025
1 parent 8d1bc0a commit c17dc33
Showing 1 changed file with 2 additions and 11 deletions.
13 changes: 2 additions & 11 deletions deepspeed/runtime/zero/stage3.py
@@ -546,15 +546,10 @@ def _setup_for_real_optimizer(self):
            self.grad_partitions_flat_buffer = get_accelerator().pin_memory(self.grad_partitions_flat_buffer)

        offset = 0
-       max_partition_numel = 0
        for param in all_params:
            self.__param_id_to_grad_partition[param.ds_id] = self.grad_partitions_flat_buffer.narrow(
                0, offset, param.partition_numel())
            offset += param.partition_numel()
-           max_partition_numel = max(max_partition_numel, param.partition_numel())
-       if self.offload_optimizer:
-           self.pinned_grad_buffer: Tensor = get_accelerator().pin_memory(
-               torch.empty(max_partition_numel, device=self.device))

    def _link_all_hp_params(self):
        for p in self.module.parameters():
@@ -1510,13 +1505,9 @@ def partition_grads(self, params_to_release: List[Parameter], grad_partitions: L
                        offload_fp32_gradients[i].append(grad_buffer.float())
                        offload_fp32_offsets[i].append(dest_offset)
                    else:
-                       buffer_numel = grad_buffer.numel()
                        fp32_grad_tensor = self.fp32_partitioned_groups_flat[i].grad.narrow(
-                           0, dest_offset, buffer_numel)
-                       self.pinned_grad_buffer[:buffer_numel].copy_(
-                           grad_buffer.to(dtype=torch.float32, non_blocking=True))
-                       get_accelerator().synchronize()
-                       fp32_grad_tensor.copy_(self.pinned_grad_buffer[:buffer_numel], non_blocking=True)
+                           0, dest_offset, grad_buffer.numel())
+                       fp32_grad_tensor.copy_(grad_buffer.float())

            # free the gradient
            if not get_accelerator().is_synchronized_device():
