Skip to content

Commit

Permalink
Add optional gan loss to MSE
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 20, 2023
1 parent 546457a commit bb54f1d
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/autoseg/losses/mtlsd_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,11 @@ def forward(
aff_loss = self.aff_lambda * self._calc_loss(pred_affs, gt_affs, affs_weights)

# calculate MSE loss for GAN errors
real_scores = self.discriminator(gt_enhanced)
fake_scores = self.discriminator(pred_enhanced)

gan_loss = self.gan_lambda * (torch.mean((real_scores - 1) ** 2) + torch.mean(fake_scores ** 2))
if gt_enhanced is not None and pred_enhanced is not None:
real_scores = self.discriminator(gt_enhanced)
fake_scores = self.discriminator(pred_enhanced)
gan_loss = self.gan_lambda * (torch.mean((real_scores - 1) ** 2) + torch.mean(fake_scores ** 2))
else:
gan_loss: float = 0.

return lsd_loss + aff_loss + gan_loss

0 comments on commit bb54f1d

Please sign in to comment.