Fix MURA.
ebezzam committed Dec 6, 2023
1 parent 73bb229 commit 5e1f8cc
Showing 4 changed files with 40 additions and 23 deletions.
6 changes: 3 additions & 3 deletions configs/train_coded_aperture.yaml
@@ -15,12 +15,12 @@ trainable_mask:
   optimizer: Adam
   mask_lr: 1e-3
   L1_strength: False
+  binary: False
   initial_value:
     method: MLS
-    n_bits: 8
-    # MURA not working...
+    n_bits: 8 # (2**n_bits-1, 2**n_bits-1)
     # method: MURA
-    # n_bits: 3
+    # n_bits: 25 # (4*n_bits+1, 4*n_bits+1)
 
 simulation:
   grayscale: False
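The size comments map each method to its pattern dimensions: an MLS mask is built from a maximum-length sequence of length 2**n_bits - 1, while the MURA pattern is generated on a (4*n_bits + 1) square grid (101 for n_bits = 25, which is prime, as the MURA construction requires). A minimal sketch of the two sizes, assuming only numpy and scipy's max_len_seq as imported in mask.py:

    import numpy as np
    from scipy.signal import max_len_seq

    n_bits_mls, n_bits_mura = 8, 25

    seq = max_len_seq(n_bits_mls)[0]   # m-sequence of length 2**8 - 1 = 255
    mls_mask = np.outer(seq, seq)      # (255, 255), separable by construction
    print(mls_mask.shape)

    mura_side = 4 * n_bits_mura + 1    # 101 x 101 grid
    print((mura_side, mura_side))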
27 changes: 15 additions & 12 deletions lensless/hardware/mask.py
@@ -109,7 +109,6 @@ def __init__(
         self.torch_device = torch_device
 
         # create mask
-        self.mask = None
         self.create_mask()
         self.shape = self.mask.shape

@@ -216,19 +215,20 @@ def __init__(self, method="MLS", n_bits=8, **kwargs):
 
         # initialize parameters
         if self.method.upper() == "MURA":
-            mask = self.squarepattern(4 * self.n_bits + 1)[1:, 1:]
-            self.row = self.mask[0, :]
-            self.col = self.mask[:, 0]
-            outer = np.outer(self.row, self.col)
-            assert np.all(outer == mask)
+            self.mask = self.squarepattern(4 * self.n_bits + 1)
+            self.row = None
+            self.col = None
         else:
             seq = max_len_seq(self.n_bits)[0]
             self.row = seq
             self.col = seq
 
-        if kwargs["is_torch"]:
-            self.row = torch.from_numpy(self.row).float()
-            self.col = torch.from_numpy(self.col).float()
+        if "is_torch" in kwargs and kwargs["is_torch"]:
+            if self.row is not None and self.col is not None:
+                self.row = torch.from_numpy(self.row).float()
+                self.col = torch.from_numpy(self.col).float()
+            else:
+                self.mask = torch.from_numpy(self.mask).float()
 
         super().__init__(**kwargs)
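The rewritten branch reflects a structural difference between the two patterns: an MLS mask is separable (the outer product of one 1-D sequence with itself, hence rank 1), so storing row and col suffices, while a MURA pattern is a genuine 2-D array that cannot be recovered from a single row and column, so the full mask is stored and row/col are set to None. A small illustration of the distinction, using a generic random mask as a stand-in rather than the repo's squarepattern:

    import numpy as np
    from scipy.signal import max_len_seq

    seq = max_len_seq(4)[0].astype(float)    # 15-element m-sequence
    mls = np.outer(seq, seq)                 # separable: rank 1 by construction
    assert np.linalg.matrix_rank(mls) == 1   # row and col fully determine it

    rng = np.random.default_rng(0)
    generic = rng.integers(0, 2, (15, 15)).astype(float)
    print(np.linalg.matrix_rank(generic))    # almost surely > 1: not separable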

@@ -238,10 +238,13 @@ def create_mask(self):
         """
 
         # outer product
-        if self.is_torch:
-            self.mask = torch.outer(self.row, self.col)
-        else:
-            self.mask = np.outer(self.row, self.col)
+        if self.row is not None and self.col is not None:
+            if self.is_torch:
+                self.mask = torch.outer(self.row, self.col)
+            else:
+                self.mask = np.outer(self.row, self.col)
+        else:
+            assert self.mask is not None
 
         # resize to sensor shape
         if np.any(self.resolution != self.mask.shape):
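In the separable branch, the torch and numpy paths compute the same outer product; the split only determines the array type. A quick equivalence check with hypothetical values:

    import numpy as np
    import torch

    row = np.random.rand(5).astype(np.float32)
    col = np.random.rand(5).astype(np.float32)

    mask_np = np.outer(row, col)
    mask_pt = torch.outer(torch.from_numpy(row), torch.from_numpy(col))
    assert np.allclose(mask_np, mask_pt.numpy())  # same values, different backend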
29 changes: 21 additions & 8 deletions lensless/hardware/trainable_mask.py
@@ -265,9 +265,17 @@ def __init__(
         self._mask = self._mask_obj.mask
 
         # 3) set learnable parameters (should be immediate attributes of the class)
-        self._row = torch.nn.Parameter(self._mask_obj.row)
-        self._col = torch.nn.Parameter(self._mask_obj.col)
-        initial_param = [self._row, self._col]
+        if self._mask_obj.row is not None:
+            # separable
+            self.separable = True
+            self._row = torch.nn.Parameter(self._mask_obj.row)
+            self._col = torch.nn.Parameter(self._mask_obj.col)
+            initial_param = [self._row, self._col]
+        else:
+            # non-separable
+            self.separable = False
+            self._vals = torch.nn.Parameter(self._mask_obj.mask)
+            initial_param = [self._vals]
         self.binary = binary
 
         # 4) set optimizer
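Registering the learnable values as torch.nn.Parameter attributes directly on the module is what lets the optimizer find them. A minimal sketch of the conditional registration, as a toy module rather than the repo's class:

    import torch

    class ToyMask(torch.nn.Module):
        # register either a row/col pair (separable) or a full 2-D grid
        def __init__(self, row=None, col=None, mask=None):
            super().__init__()
            self.separable = row is not None
            if self.separable:
                self._row = torch.nn.Parameter(row)
                self._col = torch.nn.Parameter(col)
            else:
                self._vals = torch.nn.Parameter(mask)

        def forward(self):
            if self.separable:
                return torch.outer(self._row, self._col)
            return self._vals

    toy = ToyMask(mask=torch.rand(8, 8))
    print([p.shape for p in toy.parameters()])  # one (8, 8) parameter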
@@ -279,8 +287,13 @@ def get_psf(self):
         return self._mask_obj.psf.unsqueeze(0)
 
     def project(self):
-        self._row.data = torch.clamp(self._row, 0, 1)
-        self._col.data = torch.clamp(self._col, 0, 1)
-        if self.binary:
-            self._row.data = torch.round(self._row)
-            self._col.data = torch.round(self._col)
+        if self.separable:
+            self._row.data = torch.clamp(self._row, 0, 1)
+            self._col.data = torch.clamp(self._col, 0, 1)
+            if self.binary:
+                self._row.data = torch.round(self._row)
+                self._col.data = torch.round(self._col)
+        else:
+            self._vals.data = torch.clamp(self._vals, 0, 1)
+            if self.binary:
+                self._vals.data = torch.round(self._vals)
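project() implements the projection step of projected gradient descent: after each optimizer update, the mask values are pulled back into the physically meaningful set (transmittance in [0, 1], rounded to {0, 1} when binary). A sketch of one training step with a stand-in loss, not the repo's training loop:

    import torch

    vals = torch.nn.Parameter(torch.rand(8, 8))
    opt = torch.optim.Adam([vals], lr=1e-3)

    loss = (vals.sum() - 10.0) ** 2  # stand-in objective
    loss.backward()
    opt.step()

    # projection, mirroring the non-separable branch above
    with torch.no_grad():
        vals.data = torch.clamp(vals, 0, 1)
        binary = True  # corresponds to the new trainable_mask.binary flag
        if binary:
            vals.data = torch.round(vals)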
1 change: 1 addition & 0 deletions scripts/recon/train_unrolled.py
@@ -289,6 +289,7 @@ def prep_trainable_mask(config, psf=None, downsample=None):
             distance_sensor=config.simulation.mask2sensor,
             optimizer=config.trainable_mask.optimizer,
             lr=config.trainable_mask.mask_lr,
+            binary=config.trainable_mask.binary,
             **config.trainable_mask.initial_value,
         )

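The new keyword is read from the trainable_mask block of the config, while **config.trainable_mask.initial_value unpacks method and n_bits into the same call. A sketch of the keyword plumbing, with a plain dict standing in for the Hydra config and a hypothetical make_mask in place of the constructor called above:

    def make_mask(optimizer, lr, binary, method, n_bits):
        # hypothetical stand-in for the trainable-mask constructor
        return {"optimizer": optimizer, "lr": lr, "binary": binary,
                "method": method, "n_bits": n_bits}

    cfg = {
        "optimizer": "Adam",
        "mask_lr": 1e-3,
        "binary": False,
        "initial_value": {"method": "MLS", "n_bits": 8},
    }

    mask = make_mask(
        optimizer=cfg["optimizer"],
        lr=cfg["mask_lr"],
        binary=cfg["binary"],
        **cfg["initial_value"],  # expands to method="MLS", n_bits=8
    )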
