diff --git a/src/autoseg/losses/mtlsd_losses.py b/src/autoseg/losses/mtlsd_losses.py index f1523a4..949cc41 100644 --- a/src/autoseg/losses/mtlsd_losses.py +++ b/src/autoseg/losses/mtlsd_losses.py @@ -40,9 +40,11 @@ def forward( aff_loss = self.aff_lambda * self._calc_loss(pred_affs, gt_affs, affs_weights) # calculate MSE loss for GAN errors - real_scores = self.discriminator(gt_enhanced) - fake_scores = self.discriminator(pred_enhanced) - - gan_loss = self.gan_lambda * (torch.mean((real_scores - 1) ** 2) + torch.mean(fake_scores ** 2)) + if gt_enhanced is not None and pred_enhanced is not None: + real_scores = self.discriminator(gt_enhanced) + fake_scores = self.discriminator(pred_enhanced) + gan_loss = self.gan_lambda * (torch.mean((real_scores - 1) ** 2) + torch.mean(fake_scores ** 2)) + else: + gan_loss: float = 0. return lsd_loss + aff_loss + gan_loss \ No newline at end of file