Skip to content

Commit

Permalink
fix: provide a coeffs shape property.
Browse files Browse the repository at this point in the history
  • Loading branch information
paquiteau committed Feb 9, 2024
1 parent 62b519a commit 80159e4
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions modopt/opt/linear/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,9 @@ def __init__(self,
self.op = self.operator.op
self.adj_op = self.operator.adj_op

@property
def coeffs_shape(self):
return self.operator.coeffs_shape

class CPUWaveletTransform(LinearParent):
"""
Expand Down Expand Up @@ -283,6 +286,7 @@ def __init__(
self.level = level
self.shape = shape
self.mode = mode
self.coeffs_shape = None # will be set after op.

def op(self, data: torch.Tensor) -> list[torch.Tensor]:
"""Apply the wavelet decomposition on.
Expand All @@ -308,8 +312,6 @@ def op(self, data: torch.Tensor) -> list[torch.Tensor]:
coeffs_ = ptwt.wavedec2(
data_, self.wavelet, level=self.level, mode=self.mode, axes=(-3, -2)
)
self.coeffs_shape = [coeffs_[0].shape]
self.coeffs_shape += [tuple(cc.shape for cc in c) for c in coeffs_]
# flatten list of tuple of tensors to a list of tensors
coeffs = [torch.view_as_complex(coeffs_[0].contiguous())] + [
torch.view_as_complex(cc.contiguous())
Expand Down Expand Up @@ -418,6 +420,7 @@ def __init__(
self.mode = mode

self.operator = TorchWaveletTransform(shape=shape, wavelet=wavelet, level=level,mode=mode)
self.coeffs_shape = None # will be set after op

def op(self, data: cp.array) -> cp.ndarray:
"""Define the wavelet operator.
Expand Down

0 comments on commit 80159e4

Please sign in to comment.