diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 0e45ffb8..dace6d18 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -22,6 +22,12 @@ except ImportError: SKLEARN_AVAILABLE = False +PYWT_AVAILABLE = True +try: + import pywt + import joblib +except ImportError: + PYWT_AVAILABLE = False # Basic functions to be used as operators or as dummy functions func_identity = lambda x_val: x_val @@ -156,7 +162,7 @@ def case_linear_identity(self): return linop, data_op, data_adj_op, res_op, res_adj_op - def case_linear_wavelet(self): + def case_linear_wavelet_convolve(self): """Case linear operator wavelet.""" linop = linear.WaveletConvolve( filters=np.arange(8).reshape(2, 2, 2).astype(float) @@ -168,6 +174,19 @@ def case_linear_wavelet(self): return linop, data_op, data_adj_op, res_op, res_adj_op + @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.") + def case_linear_wavelet_transform(self): + linop = linear.WaveletTransform( + wavelet_name="haar", + shape=(8, 8), + level=2, + ) + data_op = np.arange(64).reshape(8, 8).astype(float) + res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2)) + data_adj_op = linop.op(data_op) + res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar") + return linop, data_op, data_adj_op, res_op, res_adj_op + @parametrize(weights=[[1.0, 1.0], None]) def case_linear_combo(self, weights): """Case linear operator combo with weights."""