diff --git a/spanet/network/jet_reconstruction/jet_reconstruction_training.py b/spanet/network/jet_reconstruction/jet_reconstruction_training.py index b5d6ea5..bc12d6c 100644 --- a/spanet/network/jet_reconstruction/jet_reconstruction_training.py +++ b/spanet/network/jet_reconstruction/jet_reconstruction_training.py @@ -238,7 +238,7 @@ def training_step(self, batch: Batch, batch_nb: int) -> Dict[str, Tensor]: # Take the weighted average of the symmetric loss terms. masks = masks.unsqueeze(1) - symmetric_losses = (weights * symmetric_losses).sum(-1) / masks.sum(-1) + symmetric_losses = (weights * symmetric_losses).sum(-1) / torch.clamp(masks.sum(-1), 1, None) assignment_loss, detection_loss = torch.unbind(symmetric_losses, 1) # ===================================================================================================