diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index 4caa4e2f8377..2274cc52cdb9 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -1217,7 +1217,7 @@ def __reduce_and_partition_ipg_grads(self, safe_mode: bool = False) -> None: @instrument_w_nvtx def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]: dtype = buffer_to_reduce.dtype - if self.communication_data_type == self.dtype: + if self.communication_data_type != dtype: buffer_to_reduce = buffer_to_reduce.to(self.communication_data_type) if self.postscale_gradients and self.gradient_predivide_factor != 1.0: buffer_to_reduce = buffer_to_reduce.div_(self.gradient_predivide_factor)