From 1f5561795aae8ca9eb4fd03de25e5f31b3e7e42e Mon Sep 17 00:00:00 2001 From: torzdf <36920800+torzdf@users.noreply.github.com> Date: Wed, 4 May 2022 09:56:48 +0100 Subject: [PATCH] update losses test --- tests/lib/model/losses_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/lib/model/losses_test.py b/tests/lib/model/losses_test.py index 1ff2fde597..7c83b8755e 100644 --- a/tests/lib/model/losses_test.py +++ b/tests/lib/model/losses_test.py @@ -43,9 +43,9 @@ def test_loss_output(loss_func, output_shape): _LWPARAMS = [losses.GeneralizedLoss(), losses.GradientLoss(), losses.GMSDLoss(), losses.LInfNorm(), k_losses.mean_absolute_error, k_losses.mean_squared_error, - k_losses.logcosh, losses.DSSIMObjective()] + k_losses.logcosh, losses.DSSIMObjective(), losses.MSSSIMLoss()] _LWIDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm", "mae", "mse", "logcosh", - "DSSIMObjective"] + "DSSIMObjective", "MS-SSIM"] _LWIDS = [f"{loss}[{get_backend().upper()}]" for loss in _LWIDS] @@ -55,6 +55,8 @@ def test_loss_wrapper(loss_func): if get_backend() == "amd": if isinstance(loss_func, losses.GMSDLoss): pytest.skip("GMSD Loss is not currently compatible with PlaidML") + if isinstance(loss_func, losses.MSSSIMLoss): + pytest.skip("MS-SSIM Loss is not currently compatible with PlaidML") if hasattr(loss_func, "__name__") and loss_func.__name__ == "logcosh": pytest.skip("LogCosh Loss is not currently compatible with PlaidML") y_a = K.variable(np.random.random((2, 16, 16, 4)))