diff --git a/dacapo/experiments/tasks/losses/affinities_loss.py b/dacapo/experiments/tasks/losses/affinities_loss.py index 1bc9aded5..d731b27db 100644 --- a/dacapo/experiments/tasks/losses/affinities_loss.py +++ b/dacapo/experiments/tasks/losses/affinities_loss.py @@ -79,10 +79,15 @@ def compute(self, prediction, target, weight): weight[:, self.num_affinities :, ...], ) - return ( - torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target) - * affs_weight - ).mean() + self.lsds_to_affs_weight_ratio * ( - torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target) - * aux_weight - ).mean() + if aux.shape[1] == 0: + return torch.nn.BCEWithLogitsLoss(reduction="none")( + affs, affs_target + ).mean() + else: + return ( + torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target) + * affs_weight + ).mean() + self.lsds_to_affs_weight_ratio * ( + torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target) + * aux_weight + ).mean()