From eb03070a08cc670896fa392ddd8d49e383e1abfc Mon Sep 17 00:00:00 2001
From: William Patton
Date: Wed, 16 Oct 2024 09:22:57 -0700
Subject: [PATCH] avoid adding auxiliary loss if not predicting affs since this
 results in nan loss

---
 .../tasks/losses/affinities_loss.py           | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/dacapo/experiments/tasks/losses/affinities_loss.py b/dacapo/experiments/tasks/losses/affinities_loss.py
index 1bc9aded5..d731b27db 100644
--- a/dacapo/experiments/tasks/losses/affinities_loss.py
+++ b/dacapo/experiments/tasks/losses/affinities_loss.py
@@ -79,10 +79,15 @@ def compute(self, prediction, target, weight):
             weight[:, self.num_affinities :, ...],
         )
 
-        return (
-            torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
-            * affs_weight
-        ).mean() + self.lsds_to_affs_weight_ratio * (
-            torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
-            * aux_weight
-        ).mean()
+        if aux.shape[1] == 0:
+            return torch.nn.BCEWithLogitsLoss(reduction="none")(
+                affs, affs_target
+            ).mean()
+        else:
+            return (
+                torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
+                * affs_weight
+            ).mean() + self.lsds_to_affs_weight_ratio * (
+                torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
+                * aux_weight
+            ).mean()