Skip to content

Commit

Permalink
Update low_level_optim.py
Browse files Browse the repository at this point in the history
  • Loading branch information
flybird11111 authored Aug 6, 2024
1 parent 7e0c777 commit 6f29436
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions colossalai/zero/low_level/low_level_optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,15 +345,15 @@ def _run_reduction(self):
self._update_unpartitoned_grad(bucket_store, grad_in_bucket.values(), flat_grads_per_rank, group_id)
else:
flat_grads_list = list(flat_grads.split(len(flat_grads) // bucket_store.world_size))
recieved_grad = torch.zeros_like(flat_grads_list[0])
received_grad = torch.zeros_like(flat_grads_list[0])
if self._fp8_communication:
reduce_scatter_fp8(
recieved_grad,
received_grad,
flat_grads_list,
group=bucket_store.torch_pg,
)
else:
dist.reduce_scatter(recieved_grad, flat_grads_list, group=bucket_store.torch_pg)
dist.reduce_scatter(received_grad, flat_grads_list, group=bucket_store.torch_pg)

if received_grad.dtype != grad_dtype:
received_grad = received_grad.to(grad_dtype)
Expand Down

0 comments on commit 6f29436

Please sign in to comment.