Skip to content

Commit

Permalink
Fix loss test imports
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Dec 14, 2023
1 parent 647b75d commit b024882
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/losses/mse_loss_test.py → tests/losses/gmse_loss_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch
import unittest
from autoseg.losses import Weighted_MSELoss
from autoseg.losses import Weighted_GMSELoss


class DummyDiscriminator(torch.nn.Module):
Expand All @@ -12,15 +12,15 @@ class TestWeightedMSELoss(unittest.TestCase):
def setUp(self):
# Initialize Weighted_MSELoss with default settings
discrim = DummyDiscriminator()
self.weighted_mse_loss = Weighted_MSELoss(discrim=discrim)
self.weighted_gmse_loss = Weighted_GMSELoss(discrim=discrim)

def test_calc_loss_with_weights(self):
# Test _calc_loss method with weights provided
prediction = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
weights = torch.tensor([1.0, 0.0, 1.0], requires_grad=False)

loss = self.weighted_mse_loss._calc_loss(prediction, target, weights)
loss = self.weighted_gmse_loss._calc_loss(prediction, target, weights)

expected_loss = weights * (prediction - target) ** 2
if len(torch.nonzero(expected_loss)) != 0 and type(weights) == torch.Tensor:
Expand All @@ -36,7 +36,7 @@ def test_calc_loss_without_weights(self):
prediction = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)

loss = self.weighted_mse_loss._calc_loss(prediction, target)
loss = self.weighted_gmse_loss._calc_loss(prediction, target)

expected_loss = torch.mean((prediction - target) ** 2)
self.assertTrue(torch.allclose(loss, expected_loss))
Expand All @@ -52,7 +52,7 @@ def test_forward_with_gan_loss(self):
pred_enhanced = torch.randn(3, requires_grad=True)
gt_enhanced = torch.randn(3, requires_grad=True)

loss = self.weighted_mse_loss(
loss = self.weighted_gmse_loss(
pred_lsds=pred_lsds,
gt_lsds=gt_lsds,
lsds_weights=lsds_weights,
Expand Down

0 comments on commit b024882

Please sign in to comment.