From 4dc7d10f7ba30bc046162eabbe3e50af892afaf8 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Sun, 4 Feb 2024 21:22:01 +0100 Subject: [PATCH 1/6] feat: add support for cupy in SparseThreshold. Ideally we want to have such support everywhere. --- modopt/opt/proximity.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/modopt/opt/proximity.py b/modopt/opt/proximity.py index e8492367..fc81a753 100644 --- a/modopt/opt/proximity.py +++ b/modopt/opt/proximity.py @@ -22,6 +22,7 @@ else: import_sklearn = True +from modopt.base.backend import get_array_module from modopt.base.transform import cube2matrix, matrix2cube from modopt.base.types import check_callable from modopt.interface.errors import warn @@ -215,7 +216,10 @@ def _cost_method(self, *args, **kwargs): Sparsity cost component """ - cost_val = np.sum(np.abs(self.weights * self._linear.op(args[0]))) + xp = get_array_module(args[0]) + cost_val = xp.sum(xp.abs(self.weights * self._linear.op(args[0]))) + if isinstance(cost_val, xp.ndarray): + cost_val = cost_val.item() if 'verbose' in kwargs and kwargs['verbose']: print(' - L1 NORM (X):', cost_val) From 1961acc6346da0b15b1b6ed9544195d235ec5b6c Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 5 Feb 2024 09:54:48 +0100 Subject: [PATCH 2/6] feat: add cupy wavelet transform. --- modopt/opt/linear/wavelet.py | 189 +++++++++++++++++++++++++++++++++++ 1 file changed, 189 insertions(+) diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index 6e22a2b0..6d72db0d 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -16,6 +16,14 @@ except ImportError: pywt_available = False +ptwt_available = True +try: + import ptwt + import torch + import cupy as cp +except: + ptwt_available = False + class WaveletConvolve(LinearParent): """Wavelet Convolution Class. @@ -214,3 +222,184 @@ def _adj_op(self, coeffs): wavelet=self.wavelet, mode=self.mode, ) + + +class TorchWaveletTransform: + """Wavelet transform using pytorch.""" + + wavedec3_keys = ["aad", "ada", "add", "daa", "dad", "dda", "ddd"] + + def __init__( + self, + shape: tuple[int, ...], + wavelet: str, + level: int, + mode: str, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + + def op(self, data: torch.Tensor) -> list[torch.Tensor]: + """Apply the wavelet decomposition on. + + Parameters + ---------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of + the operator. + + Returns + ------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + """ + if data.shape == self.shape: + data = data[None, ...] # add a batch dimension + + if len(self.shape) == 2: + if torch.is_complex(data): + # 2D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec2( + data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2) + ) + self.coeffs_shape = [coeffs_[0].shape] + self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_] + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c + ] + + return coeffs + # 2D Real + coeffs_ = ptwt.wavedec2( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c] + + if torch.is_complex(data): + # 3D Complex + data_ = torch.view_as_real(data) + coeffs_ = ptwt.wavedec3( + data_, + self.wavelet, + level=self.level, + mode=self.mode, + axes=(-4, -3, -2), + ) + # flatten list of tuple of tensors to a list of tensors + coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ + torch.view_as_complex(cc.contiguous()) + for c in coeffs_[1:] + for cc in c.values() + ] + + return coeffs + # 3D Real + coeffs_ = ptwt.wavedec3( + data, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2, -1) + ) + return [coeffs_[0]] + [cc for c in coeffs_[1:] for cc in c.values()] + + def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor: + """Apply the wavelet recomposition. + + Parameters + ---------- + list[torch.Tensor] + list of tensor each containing the data of a subband. + + Returns + ------- + data: torch.Tensor + 2D or 3D, real or complex data with last axes matching shape of the + operator. + + """ + if len(self.shape) == 2: + if torch.is_complex(coeffs[0]): + ## 2D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + tuple(torch.view_as_real(coeffs[i + k]) for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs, wavelet=self.wavelet, axes=(-3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 2D Real ## + coeffs_ = [coeffs[0]] + [ + tuple(coeffs[i + k] for k in range(3)) + for i in range(1, len(coeffs) - 2, 3) + ] + data = ptwt.waverec2(coeffs_, wavelet=self.wavelet, axes=(-2, -1)) + return data + + if torch.is_complex(coeffs[0]): + ## 3D Complex ## + # list of tensor to list of tuple of tensor + coeffs = [torch.view_as_real(coeffs[0])] + [ + { + v: torch.view_as_real(coeffs[i + k]) + for k, v in enumerate(self.wavedec3_keys) + } + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs, wavelet=self.wavelet, axes=(-4, -3, -2)) + return torch.view_as_complex(data.contiguous()) + ## 3D Real ## + coeffs_ = [coeffs[0]] + [ + {v: coeffs[i + k] for k, v in enumerate(self.wavedec3_keys)} + for i in range(1, len(coeffs) - 6, 7) + ] + data = ptwt.waverec3(coeffs_, wavelet=self.wavelet, axes=(-3, -2, -1)) + return data + + +class CupyWaveletTransform: + """Wrapper around torch wavelet transform to be compatible with the Modopt API.""" + + def __init__( + self, + shape: tuple[int, ...], + wavelet: str, + level: int, + mode: str, + ): + self.wavelet = wavelet + self.level = level + self.shape = shape + self.mode = mode + + self.operator = TorchWaveletTransform(shape, wavelet, level, mode) + + def op(self, data: cp.array) -> cp.ndarray: + """Apply Forward Wavelet transform on cupy array.""" + data_ = torch.as_tensor(data) + tensor_list = self.operator.op(data_) + # flatten the list of tensor to a cupy array + # this requires an on device copy... + self.coeffs_shape = [c.shape for c in tensor_list] + n_tot_coeffs = np.sum([np.prod(s) for s in self.coeffs_shape]) + ret = cp.zeros(n_tot_coeffs, dtype=np.complex64) # FIXME get dtype from torch + start = 0 + for t in tensor_list: + stop = start + np.prod(t.shape) + ret[start:stop] = cp.asarray(t.flatten()) + start = stop + + return ret + + def adj_op(self, data: cp.ndarray) -> cp.ndarray: + """Apply Adjoint Wavelet transform on cupy array.""" + start = 0 + tensor_list = [None] * len(self.coeffs_shape) + for i, s in enumerate(self.coeffs_shape): + stop = start + np.prod(s) + tensor_list[i] = torch.as_tensor(data[start:stop].reshape(s), device="cuda") + start = stop + ret_tensor = self.operator.adj_op(tensor_list) + return cp.from_dlpack(ret_tensor) From fc36139eb02fa3d777971b1101eaf8c1a6b937de Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 5 Feb 2024 10:30:33 +0100 Subject: [PATCH 3/6] feat: use compute_backend to dispatch. --- modopt/opt/linear/wavelet.py | 77 ++++++++++++++++++++++++++++++++++-- modopt/tests/test_opt.py | 16 ++++++-- 2 files changed, 86 insertions(+), 7 deletions(-) diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index 6d72db0d..c497e59e 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -21,7 +21,7 @@ import ptwt import torch import cupy as cp -except: +except ImportError: ptwt_available = False @@ -62,10 +62,53 @@ def __init__(self, filters, method='scipy'): + class WaveletTransform(LinearParent): """ 2D and 3D wavelet transform class. + This is a wrapper around either Pywavelet (CPU) or Pytorch Wavelet (GPU using Pytorch). + + Parameters + ---------- + wavelet_name: str + the wavelet name to be used during the decomposition. + shape: tuple[int,...] + Shape of the input data. The shape should be a tuple of length 2 or 3. + It should not contains coils or batch dimension. + nb_scales: int, default 4 + the number of scales in the decomposition. + mode: str, default "zero" + Boundary Condition mode + compute_backend: str, "numpy" or "cupy", default "numpy" + Backend library to use. "cupy" also requires a working installation of PyTorch and pytorch wavelets. + + **kwargs: extra kwargs for Pywavelet or Pytorch Wavelet + """ + def __init__(self, + wavelet_name, + shape, + level=4, + mode="symmetric", + compute_backend="numpy", + **kwargs): + + if compute_backend == "cupy" and ptwt_available: + self.operator = CupyWaveletTransform(wavelet_name, shape, level, mode) + elif compute_backend == "numpy" and pywt_available: + self.operator = CPUWaveletTransform(wavelet_name = wavelet_name,shape= shape,nb_scales=level, **kwargs) + else: + raise ValueError(f"Compute Backend {compute_backend} not available") + + + self.op = self.operator.op + self.adj_op = self.operator.adj_op + + +class CPUWaveletTransform(LinearParent): + """ + 2D and 3D wavelet transform class. + This is a light wrapper around PyWavelet, with multicoil support. Parameters @@ -359,7 +402,7 @@ def adj_op(self, coeffs: list[torch.Tensor]) -> torch.Tensor: return data -class CupyWaveletTransform: +class CupyWaveletTransform(LinearParent): """Wrapper around torch wavelet transform to be compatible with the Modopt API.""" def __init__( @@ -377,7 +420,20 @@ def __init__( self.operator = TorchWaveletTransform(shape, wavelet, level, mode) def op(self, data: cp.array) -> cp.ndarray: - """Apply Forward Wavelet transform on cupy array.""" + """Define the wavelet operator. + + This method returns the input data convolved with the wavelet filter. + + Parameters + ---------- + data: cp.ndarray + input 2D data array. + + Returns + ------- + coeffs: ndarray + the wavelet coefficients. + """ data_ = torch.as_tensor(data) tensor_list = self.operator.op(data_) # flatten the list of tensor to a cupy array @@ -394,7 +450,20 @@ def op(self, data: cp.array) -> cp.ndarray: return ret def adj_op(self, data: cp.ndarray) -> cp.ndarray: - """Apply Adjoint Wavelet transform on cupy array.""" + """Define the wavelet adjoint operator. + + This method returns the reconstructed image. + + Parameters + ---------- + coeffs: cp.ndarray + the wavelet coefficients. + + Returns + ------- + data: ndarray + the reconstructed data. + """ start = 0 tensor_list = [None] * len(self.coeffs_shape) for i, s in enumerate(self.coeffs_shape): diff --git a/modopt/tests/test_opt.py b/modopt/tests/test_opt.py index 4a82e33c..7c30186e 100644 --- a/modopt/tests/test_opt.py +++ b/modopt/tests/test_opt.py @@ -22,6 +22,13 @@ except ImportError: SKLEARN_AVAILABLE = False +PTWT_AVAILABLE = True +try: + import ptwt + import cupy +except ImportError: + PTWT_AVAILABLE = False + PYWT_AVAILABLE = True try: import pywt @@ -174,8 +181,12 @@ def case_linear_wavelet_convolve(self): return linop, data_op, data_adj_op, res_op, res_adj_op - @pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.") - def case_linear_wavelet_transform(self): + @parametrize( + compute_backend=[ + pytest.param("numpy", marks=pytest.mark.skipif(not PYWT_AVAILABLE, reason="PyWavelet not available.")), + pytest.param("cupy", marks=pytest.mark.skipif(not PTWT_AVAILABLE, reason="Pytorch Wavelet not available.")) + ]) + def case_linear_wavelet_transform(self, compute_backend="numpy"): linop = linear.WaveletTransform( wavelet_name="haar", shape=(8, 8), @@ -298,7 +309,6 @@ class ProxCases: [11.67394789, 12.87497954, 14.07601119], [15.27704284, 16.47807449, 17.67910614], ], - ] ) array233_3 = np.array( [ From 62b519a009c2bea4cc35e1e6f036e1e12e77e863 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Mon, 5 Feb 2024 11:04:37 +0100 Subject: [PATCH 4/6] fix: pass parameters by name --- modopt/opt/linear/wavelet.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index c497e59e..cb75ecff 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -94,9 +94,9 @@ def __init__(self, **kwargs): if compute_backend == "cupy" and ptwt_available: - self.operator = CupyWaveletTransform(wavelet_name, shape, level, mode) + self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode) elif compute_backend == "numpy" and pywt_available: - self.operator = CPUWaveletTransform(wavelet_name = wavelet_name,shape= shape,nb_scales=level, **kwargs) + self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, nb_scales=level, **kwargs) else: raise ValueError(f"Compute Backend {compute_backend} not available") @@ -417,7 +417,7 @@ def __init__( self.shape = shape self.mode = mode - self.operator = TorchWaveletTransform(shape, wavelet, level, mode) + self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) def op(self, data: cp.array) -> cp.ndarray: """Define the wavelet operator. From 80159e4855c4fdb72893c30e49af412cd5e383cb Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Fri, 9 Feb 2024 18:12:08 +0100 Subject: [PATCH 5/6] fix: provide a coeffs shape property. --- modopt/opt/linear/wavelet.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index cb75ecff..3fa2d60d 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -104,6 +104,9 @@ def __init__(self, self.op = self.operator.op self.adj_op = self.operator.adj_op + @property + def coeffs_shape(self): + return self.operator.coeffs_shape class CPUWaveletTransform(LinearParent): """ @@ -283,6 +286,7 @@ def __init__( self.level = level self.shape = shape self.mode = mode + self.coeffs_shape = None # will be set after op. def op(self, data: torch.Tensor) -> list[torch.Tensor]: """Apply the wavelet decomposition on. @@ -308,8 +312,6 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]: coeffs_ = ptwt.wavedec2( data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2) ) - self.coeffs_shape = [coeffs_[0].shape] - self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_] # flatten list of tuple of tensors to a list of tensors coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ torch.view_as_complex(cc.contiguous()) @@ -418,6 +420,7 @@ def __init__( self.mode = mode self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) + self.coeffs_shape = None # will be set after op def op(self, data: cp.array) -> cp.ndarray: """Define the wavelet operator. From bf9367470a204a820d9d9913947b9201e71601f6 Mon Sep 17 00:00:00 2001 From: Pierre-antoine Comby Date: Tue, 13 Feb 2024 18:28:56 +0100 Subject: [PATCH 6/6] fix: update name. --- modopt/opt/linear/wavelet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index 3fa2d60d..5feead66 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -96,7 +96,7 @@ def __init__(self, if compute_backend == "cupy" and ptwt_available: self.operator = CupyWaveletTransform(wavelet=wavelet_name, shape=shape, level=level, mode=mode) elif compute_backend == "numpy" and pywt_available: - self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, nb_scales=level, **kwargs) + self.operator = CPUWaveletTransform(wavelet_name=wavelet_name, shape=shape, level=level, **kwargs) else: raise ValueError(f"Compute Backend {compute_backend} not available")