Skip to content

Commit

Permalink
feat: add test case for wavelet transform.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Nov 15, 2023
1 parent 869af9a commit 906ddc4
Showing 1 changed file with 20 additions and 1 deletion.
21 changes: 20 additions & 1 deletion modopt/tests/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
except ImportError:
SKLEARN_AVAILABLE = False

PYWT_AVAILABLE = True
try:
import pywt
import joblib
except ImportError:
PYWT_AVAILABLE = False

# Basic functions to be used as operators or as dummy functions
func_identity = lambda x_val: x_val
Expand Down Expand Up @@ -156,7 +162,7 @@ def case_linear_identity(self):

return linop, data_op, data_adj_op, res_op, res_adj_op

def case_linear_wavelet(self):
def case_linear_wavelet_convolve(self):
"""Case linear operator wavelet."""
linop = linear.WaveletConvolve(
filters=np.arange(8).reshape(2, 2, 2).astype(float)
Expand All @@ -168,6 +174,19 @@ def case_linear_wavelet(self):

return linop, data_op, data_adj_op, res_op, res_adj_op

@pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")
def case_linear_wavelet_transform(self):
linop = linear.WaveletTransform(
wavelet_name="haar",
shape=(8, 8),
level=2,
)
data_op = np.arange(64).reshape(8, 8).astype(float)
res_op, slices, shapes = pywt.ravel_coeffs(pywt.wavedecn(data_op, "haar", level=2))
data_adj_op = linop.op(data_op)
res_adj_op = pywt.waverecn(pywt.unravel_coeffs(data_adj_op, slices, shapes, "wavedecn"), "haar")
return linop, data_op, data_adj_op, res_op, res_adj_op

@parametrize(weights=[[1.0, 1.0], None])
def case_linear_combo(self, weights):
"""Case linear operator combo with weights."""
Expand Down

0 comments on commit 906ddc4

Please sign in to comment.