diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index 6e22a2b0..5feead66 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -16,6 +16,14 @@ except ImportError: pywt_available = False +ptwt_available = True +try: + import ptwt + import torch + import cupy as cp +except ImportError: + ptwt_available = False + class WaveletConvolve(LinearParent): """Wavelet Convolution Class. @@ -54,10 +62,56 @@ def __init__(self, filters, method='scipy'): + class WaveletTransform(LinearParent): """ 2D and 3D wavelet transform class. + This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch). + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + mode: str, default "zero" + Boundary Condition mode + compute_backend: str, "numpy" or "cupy", default "numpy" + Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets. + + **kwargs: extra kwargs for Pywavelet or Pytorch Wavelet + """ + def __init__(self, + wavelet_name, + shape, + level=4, + mode="symmetric", + compute_backend="numpy", + **kwargs): + + if compute_backend == "cupy" and ptwt_available: + self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode) + elif compute_backend == "numpy" and pywt_available: + self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs) + else: + raise ValueError(f"Compute Backend {compute_backend} not available") + + + self.op = self.operator.op + self.adj_op = self.operator.adj_op + + @property + def coeffs_shape(self): + return self.operator.coeffs_shape + +class CPUWaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + This is a light wrapper around PyWavelet, with multicoil support. Parameters @@ -214,3 +268,210 @@ def _adj_op(self, coeffs): wavelet=self.wavelet, mode=self.mode, ) + + +class TorchWaveletTransform: + """Wavelet transform using pytorch.""" + + wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"] + + def __init__( + self, + shape: tuple[int, ...], + wavelet: str, + level: int, + mode: str, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + self.coeffs_shape = None # will be set after op. + + def op(self, data: torch.Tensor) -> list[torch.Tensor]: + """Apply the wavelet decomposition on. + + Parameters + ---------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of + the operator. + + Returns + ------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + """ + if data.shape == self.shape: + data = data[None, ...] # add a batch dimension + + if len(self.shape) == 2: + if torch.is_complex(data): + # 2D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec2( + data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2) + ) + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c + ] + + return coeffs + # 2D Real + coeffs_ = ptwt.wavedec2( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c] + + if torch.is_complex(data): + # 3D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec3( + data_, + self.wavelet, + level=self.level, + mode=self.mode, + axes=(-4, -3, -2), + ) + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c.values() + ] + + return coeffs + # 3D Real + coeffs_ = ptwt.wavedec3( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()] + + def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor: + """Apply the wavelet recomposition. + + Parameters + ---------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + + Returns + ------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of the + operator. + + """ + if len(self.shape) == 2: + if torch.is_complex(coeffs[0]): + ## 2D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + tuple(torch.view_as_real(coeffs[i + k]) for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 2D Real ## + coeffs_ = [coeffs[0]] + [ + tuple(coeffs[i + k] for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1)) + return data + + if torch.is_complex(coeffs[0]): + ## 3D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + { + v: torch.view_as_real(coeffs[i + k]) + for k, v in enumerate(self.wavedec3_keys) + } + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 3D Real ## + coeffs_ = [coeffs[0]] + [ + {v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)} + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1)) + return data + + +class CupyWaveletTransform(LinearParent): + """Wrapper around torch wavelet transform to be compatible with the Modopt API.""" + + def __init__( + self, + shape: tuple[int, ...], + wavelet: str, + level: int, + mode: str, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + + self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) + self.coeffs_shape = None # will be set after op + + def op(self, data: cp.array) -> cp.ndarray: + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: cp.ndarray + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ + data_ = torch.as_tensor(data) + tensor_list = self.operator.op(data_) + # flatten the list of tensor to a cupy array + # this requires an on device copy... + self.coeffs_shape = [c.shape for c in tensor_list] + n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape]) + ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch + start = 0 + for t in tensor_list: + stop = start + np.prod(t.shape) + ret[start:stop] = cp.asarray(t.flatten()) + start = stop + + return ret + + def adj_op(self, data: cp.ndarray) -> cp.ndarray: + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: cp.ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ + start = 0 + tensor_list = [None] * len(self.coeffs_shape) + for i, s in enumerate(self.coeffs_shape): + stop = start + np.prod(s) + tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda") + start = stop + ret_tensor = self.operator.adj_op(tensor_list) + return cp.from_dlpack(ret_tensor) diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py index e8492367..fc81a753 100644 --- a/modopt/opt/proximity.py +++ b/modopt/opt/proximity.py @@ -22,6 +22,7 @@ else: import_sklearn = True +from modopt.base.backend import get_array_module from modopt.base.transform import cube2matrix, matrix2cube from modopt.base.types import check_callable from modopt.interface.errors import warn @@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs): Sparsity cost component """ - cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0]))) + xp = get_array_module(args[0]) + cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0]))) + if isinstance(cost_val, xp.ndarray): + cost_val = cost_val.item() if 'verbose' in kwargs and kwargs['verbose']: print(' - L1 NORM (X):', cost_val) diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 4a82e33c..7c30186e 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -22,6 +22,13 @@ except ImportError: SKLEARN_AVAILABLE = False +PTWT_AVAILABLE = True +try: + import ptwt + import cupy +except ImportError: + PTWT_AVAILABLE = False + PYWT_AVAILABLE = True try: import pywt @@ -174,8 +181,12 @@ def case_linear_wavelet_convolve(self): return linop, data_op, data_adj_op, res_op, res_adj_op - @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.") - def case_linear_wavelet_transform(self): + @parametrize( + compute_backend=[ + pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")), + pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available.")) + ]) + def case_linear_wavelet_transform(self, compute_backend="numpy"): linop = linear.WaveletTransform( wavelet_name="haar", shape=(8, 8), @@ -298,7 +309,6 @@ class ProxCases: [11.67394789, 12.87497954, 14.07601119], [15.27704284, 16.47807449, 17.67910614], ], - ] ) array233_3 = np.array( [