Skip to content

Commit

Permalink
Average only valid part of the ipg buffer.
Browse files Browse the repository at this point in the history
When contiguous gradients is used ipg buffer may not be fully utilized.
Call average_tensor only for the slice with valid gradints

Change-Id: I760559d52c2f91e15cd6cd0b48e534ec2352802a
  • Loading branch information
BacharL committed Mar 13, 2024
1 parent d9e12d3 commit dcf6282
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage_1_and_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,7 @@ def reduce_ipg_grads(self):
self.average_tensor(extra_large_grad_reduc.view(-1))
self.extra_large_param_to_reduce = None
else:
self.average_tensor(self.ipg_buffer[self.ipg_index])
self.average_tensor(self.ipg_buffer[self.ipg_index].narrow(0, 0, self.elements_in_ipg_bucket))
else:
self.buffered_reduce_fallback(None,
self.grads_in_ipg_bucket,
Expand Down

0 comments on commit dcf6282

Please sign in to comment.