diff --git a/tests/losses/mse_loss_test.py b/tests/losses/gmse_loss_test.py similarity index 88% rename from tests/losses/mse_loss_test.py rename to tests/losses/gmse_loss_test.py index 04a39a8..7235ec1 100644 --- a/tests/losses/mse_loss_test.py +++ b/tests/losses/gmse_loss_test.py @@ -1,6 +1,6 @@ import torch import unittest -from autoseg.losses import Weighted_MSELoss +from autoseg.losses import Weighted_GMSELoss class DummyDiscriminator(torch.nn.Module): @@ -12,7 +12,7 @@ class TestWeightedMSELoss(unittest.TestCase): def setUp(self): # Initialize Weighted_MSELoss with default settings discrim = DummyDiscriminator() - self.weighted_mse_loss = Weighted_MSELoss(discrim=discrim) + self.weighted_gmse_loss = Weighted_GMSELoss(discrim=discrim) def test_calc_loss_with_weights(self): # Test _calc_loss method with weights provided @@ -20,7 +20,7 @@ def test_calc_loss_with_weights(self): target = torch.tensor([2.0, 2.0, 2.0], requires_grad=True) weights = torch.tensor([1.0, 0.0, 1.0], requires_grad=False) - loss = self.weighted_mse_loss._calc_loss(prediction, target, weights) + loss = self.weighted_gmse_loss._calc_loss(prediction, target, weights) expected_loss = weights * (prediction - target) ** 2 if len(torch.nonzero(expected_loss)) != 0 and type(weights) == torch.Tensor: @@ -36,7 +36,7 @@ def test_calc_loss_without_weights(self): prediction = torch.tensor([1.0, 2.0, 3.0], requires_grad=True) target = torch.tensor([2.0, 2.0, 2.0], requires_grad=True) - loss = self.weighted_mse_loss._calc_loss(prediction, target) + loss = self.weighted_gmse_loss._calc_loss(prediction, target) expected_loss = torch.mean((prediction - target) ** 2) self.assertTrue(torch.allclose(loss, expected_loss)) @@ -52,7 +52,7 @@ def test_forward_with_gan_loss(self): pred_enhanced = torch.randn(3, requires_grad=True) gt_enhanced = torch.randn(3, requires_grad=True) - loss = self.weighted_mse_loss( + loss = self.weighted_gmse_loss( pred_lsds=pred_lsds, gt_lsds=gt_lsds, lsds_weights=lsds_weights,