From 28ed9bdc6e62b2d3d359e61613045549f889cdec Mon Sep 17 00:00:00 2001
From: Nadav Elyahu <88962733+nelyahu@users.noreply.github.com>
Date: Mon, 16 Oct 2023 15:25:07 +0300
Subject: [PATCH] pipe engine _aggregate_total_loss: more efficient loss
 concatenation (#4327)

* _aggregate_total_loss: more efficient loss concatenation

Optimize the _aggregate_total_loss function to remove the dependency on
copying from device to host and back to device. This reduces the runtime
on the host.

* Fixing the if/else block on which the optimization should take place

---------

Co-authored-by: Olatunji Ruwase
---
 deepspeed/runtime/pipe/engine.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 0c7c9f7a1090d..dd1fd3dff5df2 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -549,7 +549,7 @@ def _aggregate_total_loss(self):
                 agg_loss /= self.dp_world_size
 
             assert self.global_rank in self.grid.pp_group
-            losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
+            losses = torch.stack([self.dp_group_loss, agg_loss])
             if self.is_pipe_parallel:
                 dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
             else:
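
Note on the change: below is a minimal, self-contained sketch (not DeepSpeed code; the tensor names and values are illustrative) contrasting the old and new ways of packing the two loss scalars, to show why torch.stack avoids the host round trip.

import torch

# Stand-ins for the two loss scalars that live on the accelerator inside
# _aggregate_total_loss; names and values here are purely illustrative.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dp_group_loss = torch.tensor(0.25, device=device)
agg_loss = torch.tensor(0.30, device=device)

# Old path: torch.Tensor(...) builds the list on the host, which pulls each
# device scalar through a device-to-host copy (a synchronization point), and
# the trailing .to(device) then copies the result back to the device.
losses_old = torch.Tensor([dp_group_loss, agg_loss]).to(device)

# New path: torch.stack concatenates the existing device tensors directly,
# so the packed tensor is produced on the device with no host round trip.
losses_new = torch.stack([dp_group_loss, agg_loss])

assert torch.allclose(losses_old, losses_new)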