diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml
index 5d8aa6d4..b1e5e30a 100644
--- a/configs/train_coded_aperture.yaml
+++ b/configs/train_coded_aperture.yaml
@@ -15,12 +15,12 @@ trainable_mask:
   optimizer: Adam
   mask_lr: 1e-3
   L1_strength: False
+  binary: False
   initial_value:
     method: MLS
-    n_bits: 8
-    # MURA not working...
+    n_bits: 8  # (2**n_bits-1, 2**n_bits-1)
     # method: MURA
-    # n_bits: 3
+    # n_bits: 25  # (4*n_bits+1, 4*n_bits+1)
 
 simulation:
   grayscale: False
diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py
index 4ffa7b36..7c34d627 100644
--- a/lensless/hardware/mask.py
+++ b/lensless/hardware/mask.py
@@ -109,7 +109,6 @@ def __init__(
         self.torch_device = torch_device
 
         # create mask
-        self.mask = None
         self.create_mask()
         self.shape = self.mask.shape
 
@@ -216,19 +215,20 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):
 
         # initialize parameters
        if self.method.upper() == "MURA":
-            mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
-            self.row = self.mask[0, :]
-            self.col = self.mask[:, 0]
-            outer = np.outer(self.row, self.col)
-            assert np.all(outer == mask)
+            self.mask = self.squarepattern(4 * self.n_bits + 1)
+            self.row = None
+            self.col = None
         else:
             seq = max_len_seq(self.n_bits)[0]
             self.row = seq
             self.col = seq
 
-        if kwargs["is_torch"]:
-            self.row = torch.from_numpy(self.row).float()
-            self.col = torch.from_numpy(self.col).float()
+        if "is_torch" in kwargs and kwargs["is_torch"]:
+            if self.row is not None and self.col is not None:
+                self.row = torch.from_numpy(self.row).float()
+                self.col = torch.from_numpy(self.col).float()
+            else:
+                self.mask = torch.from_numpy(self.mask).float()
 
         super().__init__(**kwargs)
 
@@ -238,10 +238,13 @@ def create_mask(self):
         """
 
         # outer product
-        if self.is_torch:
-            self.mask = torch.outer(self.row, self.col)
+        if self.row is not None and self.col is not None:
+            if self.is_torch:
+                self.mask = torch.outer(self.row, self.col)
+            else:
+                self.mask = np.outer(self.row, self.col)
         else:
-            self.mask = np.outer(self.row, self.col)
+            assert self.mask is not None
 
         # resize to sensor shape
         if np.any(self.resolution != self.mask.shape):
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index b4328f45..eef57933 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -265,9 +265,17 @@ def __init__(
         self._mask = self._mask_obj.mask
 
         # 3) set learnable parameters (should be immediate attributes of the class)
-        self._row = torch.nn.Parameter(self._mask_obj.row)
-        self._col = torch.nn.Parameter(self._mask_obj.col)
-        initial_param = [self._row, self._col]
+        if self._mask_obj.row is not None:
+            # separable
+            self.separable = True
+            self._row = torch.nn.Parameter(self._mask_obj.row)
+            self._col = torch.nn.Parameter(self._mask_obj.col)
+            initial_param = [self._row, self._col]
+        else:
+            # non-separable
+            self.separable = False
+            self._vals = torch.nn.Parameter(self._mask_obj.mask)
+            initial_param = [self._vals]
         self.binary = binary
 
         # 4) set optimizer
@@ -279,8 +287,13 @@ def get_psf(self):
         return self._mask_obj.psf.unsqueeze(0)
 
     def project(self):
-        self._row.data = torch.clamp(self._row, 0, 1)
-        self._col.data = torch.clamp(self._col, 0, 1)
-        if self.binary:
-            self._row.data = torch.round(self._row)
-            self._col.data = torch.round(self._col)
+        if self.separable:
+            self._row.data = torch.clamp(self._row, 0, 1)
+            self._col.data = torch.clamp(self._col, 0, 1)
+            if self.binary:
+                self._row.data = torch.round(self._row)
+                self._col.data = torch.round(self._col)
+        else:
+            self._vals.data = torch.clamp(self._vals, 0, 1)
+            if self.binary:
+                self._vals.data = torch.round(self._vals)
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index c9193674..132464f0 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -289,6 +289,7 @@ def prep_trainable_mask(config, psf=None, downsample=None):
             distance_sensor=config.simulation.mask2sensor,
             optimizer=config.trainable_mask.optimizer,
             lr=config.trainable_mask.mask_lr,
+            binary=config.trainable_mask.binary,
             **config.trainable_mask.initial_value,
         )
 
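For reference, a minimal standalone sketch (plain PyTorch, not the library classes) of the projection logic that `project()` now applies in the non-separable MURA branch: the full mask is a single `torch.nn.Parameter`, and after each optimizer step it is clamped to [0, 1] and, if `binary` is set, rounded to {0, 1}. The `n_bits` value, dummy loss, and variable names below are illustrative only.

import torch

# Hypothetical stand-in for the non-separable branch above: the whole
# (4*n_bits+1, 4*n_bits+1) mask is one trainable parameter.
n_bits = 25
vals = torch.nn.Parameter(torch.rand(4 * n_bits + 1, 4 * n_bits + 1))
optimizer = torch.optim.Adam([vals], lr=1e-3)
binary = False  # corresponds to trainable_mask.binary in the config

# Dummy objective standing in for the reconstruction loss.
loss = (vals.sum() - 1.0) ** 2
loss.backward()
optimizer.step()

# Projection step, mirroring project(): keep values in [0, 1],
# and optionally snap them to {0, 1} for a binary mask.
vals.data = torch.clamp(vals, 0, 1)
if binary:
    vals.data = torch.round(vals)

Writing to `.data` keeps the projection out of the autograd graph, so the next iteration's gradients are simply taken at the projected values; the same pattern is used for the separable row/column parameters.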