Skip to content

Commit

Permalink
pipe engine _aggregate_total_loss: more efficient loss concatenation (#…
Browse files Browse the repository at this point in the history
…4327)

* _aggregate_total_loss: more efficient loss concatenation

optimize _aggregate_total_loss function in order to remove dependancy
of copying from device to host and back to device.
This reduce the runtime on the host.

* Fixing the if/else block on which the optimization should take place

---------

Co-authored-by: Olatunji Ruwase <[email protected]>
  • Loading branch information
nelyahu and tjruwase authored Oct 16, 2023
1 parent 12aedac commit 28ed9bd
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,7 @@ def _aggregate_total_loss(self):
agg_loss /= self.dp_world_size

assert self.global_rank in self.grid.pp_group
losses = torch.Tensor([self.dp_group_loss, agg_loss]).to(self.device)
losses = torch.stack([self.dp_group_loss, agg_loss])
if self.is_pipe_parallel:
dist.broadcast(tensor=losses, src=self.global_rank, group=self.mpu.get_pipe_parallel_group())
else:
Expand Down

0 comments on commit 28ed9bd

Please sign in to comment.