diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index c497e59e..cb75ecff 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -94,9 +94,9 @@ def __init__(self, **kwargs): if compute_backend == "cupy" and ptwt_available: - self.operator = CupyWaveletTransform(wavelet_name, shape, level, mode) + self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode) elif compute_backend == "numpy" and pywt_available: - self.operator = CPUWaveletTransform(wavelet_name = wavelet_name,shape= shape,nb_scales=level, **kwargs) + self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, nb_scales=level, **kwargs) else: raise ValueError(f"Compute Backend {compute_backend} not available") @@ -417,7 +417,7 @@ def __init__( self.shape = shape self.mode = mode - self.operator = TorchWaveletTransform(shape, wavelet, level, mode) + self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) def op(self, data: cp.array) -> cp.ndarray: """Define the wavelet operator.