Fix for stage3 when setting different communication data type (microsoft#4540)

Co-authored-by: Olatunji Ruwase <[email protected]>
2 people authored and amaurya committed Feb 17, 2024
1 parent 385448e commit 5925526
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/stage3.py
@@ -1217,7 +1217,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None:
     @instrument_w_nvtx
     def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:
         dtype = buffer_to_reduce.dtype
-        if self.communication_data_type == self.dtype:
+        if self.communication_data_type != dtype:
             buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type)
         if self.postscale_gradients and self.gradient_predivide_factor != 1.0:
             buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor)
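
The effect of the one-line change can be sketched outside DeepSpeed. The snippet below is an illustration only, not code from the repository: `FakeTensor`, `cast_before_fix`, and `cast_after_fix` are hypothetical names, dtypes are plain strings, and it assumes the gradient buffer shares the model dtype (`self.dtype` in the diff). Under those assumptions, the old condition skips the cast exactly when the communication dtype differs from the model dtype, while the new one casts whenever the buffer dtype differs from the communication dtype.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that tracks only a dtype name."""
    dtype: str

    def to(self, dtype: str) -> "FakeTensor":
        # Same call shape as Tensor.to(dtype): return a tensor of the requested dtype.
        return FakeTensor(dtype)


def cast_before_fix(buf: FakeTensor, communication_data_type: str, model_dtype: str) -> FakeTensor:
    # Old condition: cast only when the communication dtype EQUALS the model dtype,
    # which is the one case where no cast is needed.
    if communication_data_type == model_dtype:
        buf = buf.to(communication_data_type)
    return buf


def cast_after_fix(buf: FakeTensor, communication_data_type: str) -> FakeTensor:
    # New condition: cast whenever the buffer's own dtype differs from the
    # communication dtype.
    dtype = buf.dtype
    if communication_data_type != dtype:
        buf = buf.to(communication_data_type)
    return buf


# Assumed scenario: bfloat16 gradients, float32 requested for communication.
grad_buffer = FakeTensor("bfloat16")
before = cast_before_fix(grad_buffer, "float32", model_dtype="bfloat16")
after = cast_after_fix(grad_buffer, "float32")
print(before.dtype, after.dtype)
```

In this scenario the pre-fix helper leaves the buffer in `bfloat16`, so the requested communication dtype is silently ignored, while the post-fix helper converts it to `float32` before the reduction.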