Skip to content

Commit

Permalink
avoid adding auxiliary loss if not predicting affs since this results…
Browse files Browse the repository at this point in the history
… in nan loss
  • Loading branch information
pattonw committed Oct 16, 2024
1 parent 1369fa3 commit eb03070
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions dacapo/experiments/tasks/losses/affinities_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,15 @@ def compute(self, prediction, target, weight):
weight[:, self.num_affinities :, ...],
)

return (
torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
* affs_weight
).mean() + self.lsds_to_affs_weight_ratio * (
torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
* aux_weight
).mean()
if aux.shape[1] == 0:
return torch.nn.BCEWithLogitsLoss(reduction="none")(
affs, affs_target
).mean()
else:
return (
torch.nn.BCEWithLogitsLoss(reduction="none")(affs, affs_target)
* affs_weight
).mean() + self.lsds_to_affs_weight_ratio * (
torch.nn.MSELoss(reduction="none")(torch.nn.Sigmoid()(aux), aux_target)
* aux_weight
).mean()

0 comments on commit eb03070

Please sign in to comment.