Skip to content

Commit

Permalink
update losses test
Browse files Browse the repository at this point in the history
  • Loading branch information
torzdf committed May 4, 2022
1 parent 5569abb commit 1f55617
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tests/lib/model/losses_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def test_loss_output(loss_func, output_shape):

_LWPARAMS = [losses.GeneralizedLoss(), losses.GradientLoss(), losses.GMSDLoss(),
losses.LInfNorm(), k_losses.mean_absolute_error, k_losses.mean_squared_error,
k_losses.logcosh, losses.DSSIMObjective()]
k_losses.logcosh, losses.DSSIMObjective(), losses.MSSSIMLoss()]
_LWIDS = ["GeneralizedLoss", "GradientLoss", "GMSDLoss", "LInfNorm", "mae", "mse", "logcosh",
"DSSIMObjective"]
"DSSIMObjective", "MS-SSIM"]
_LWIDS = [f"{loss}[{get_backend().upper()}]" for loss in _LWIDS]


Expand All @@ -55,6 +55,8 @@ def test_loss_wrapper(loss_func):
if get_backend() == "amd":
if isinstance(loss_func, losses.GMSDLoss):
pytest.skip("GMSD Loss is not currently compatible with PlaidML")
if isinstance(loss_func, losses.MSSSIMLoss):
pytest.skip("MS-SSIM Loss is not currently compatible with PlaidML")
if hasattr(loss_func, "__name__") and loss_func.__name__ == "logcosh":
pytest.skip("LogCosh Loss is not currently compatible with PlaidML")
y_a = K.variable(np.random.random((2, 16, 16, 4)))
Expand Down

0 comments on commit 1f55617

Please sign in to comment.