diff --git a/modopt/opt/linear/wavelet.py b/modopt/opt/linear/wavelet.py index cb75ecff..3fa2d60d 100644 --- a/modopt/opt/linear/wavelet.py +++ b/modopt/opt/linear/wavelet.py @@ -104,6 +104,9 @@ def __init__(self, self.op = self.operator.op self.adj_op = self.operator.adj_op + @property + def coeffs_shape(self): + return self.operator.coeffs_shape class CPUWaveletTransform(LinearParent): """ @@ -283,6 +286,7 @@ def __init__( self.level = level self.shape = shape self.mode = mode + self.coeffs_shape = None # will be set after op. def op(self, data: torch.Tensor) -> list[torch.Tensor]: """Apply the wavelet decomposition on. @@ -308,8 +312,6 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]: coeffs_ = ptwt.wavedec2( data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2) ) - self.coeffs_shape = [coeffs_[0].shape] - self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_] # flatten list of tuple of tensors to a list of tensors coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [ torch.view_as_complex(cc.contiguous()) @@ -418,6 +420,7 @@ def __init__( self.mode = mode self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode) + self.coeffs_shape = None # will be set after op def op(self, data: cp.array) -> cp.ndarray: """Define the wavelet operator.