Skip to content

Commit

Permalink
Improve trainable mask API.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Dec 6, 2023
1 parent cd76a3c commit 73bb229
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 53 deletions.
7 changes: 4 additions & 3 deletions configs/train_coded_aperture.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ defaults:
files:
dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: /scratch/bezzam
downsample: 8
downsample: 16 # TODO use simulation instead?

#Trainable Mask
trainable_mask:
mask_type: TrainableCodedAperture
optimizer: Adam
mask_lr: 1e-3
L1_strength: False
initial_value:
method: MLS
n_bits: 8
Expand All @@ -32,7 +33,7 @@ simulation:

training:
crop_preloss: False # crop region for computing loss
batch_size: 8
batch_size: 4
epoch: 25
eval_batch_size: 16
save_every: 5
save_every: 1
62 changes: 30 additions & 32 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,13 @@ def __init__(self, optimizer="Adam", lr=1e-3, **kwargs):
# # self._param = initial_param
# self._optimizer = getattr(torch.optim, optimizer)(self._param, lr=lr)
# self._counter = 0
self.optimizer = optimizer
self.lr = lr
self._optimizer = optimizer
self._lr = lr
self._counter = 0

def _set_optimizer(self, param):
"""Set the optimizer for the mask parameters."""
self._optimizer = getattr(torch.optim, self.optimizer)(param, lr=self.lr)
self._optimizer = getattr(torch.optim, self._optimizer)(param, lr=self._lr)

@abc.abstractmethod
def get_psf(self):
Expand All @@ -68,10 +69,6 @@ def update_mask(self):
self.project()
self._counter += 1

def get_vals(self):
"""Get the mask parameters."""
return self._param

@abc.abstractmethod
def project(self):
"""Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1])."""
Expand Down Expand Up @@ -162,16 +159,16 @@ def __init__(
super().__init__(optimizer, lr, **kwargs)
self.train_mask_vals = train_mask_vals
if train_mask_vals:
self._mask = torch.nn.Parameter(initial_vals)
self._vals = torch.nn.Parameter(initial_vals)
else:
self._mask = initial_vals
self._vals = initial_vals

if color_filter is not None:
self.color_filter = torch.nn.Parameter(color_filter)
self._color_filter = torch.nn.Parameter(color_filter)
if train_mask_vals:
initial_param = [self._mask, self.color_filter]
initial_param = [self._vals, self._color_filter]
else:
initial_param = [self.color_filter]
initial_param = [self._color_filter]
else:
assert (
train_mask_vals
Expand Down Expand Up @@ -202,12 +199,12 @@ def __init__(
def get_psf(self):

mask = get_programmable_mask(
vals=self._mask,
vals=self._vals,
sensor=self.sensor,
slm_param=self.slm_param,
rotate=self.rotate,
flipud=self.flipud,
color_filter=self.color_filter,
color_filter=self._color_filter,
)

if self.vertical_shift is not None:
Expand Down Expand Up @@ -240,11 +237,11 @@ def get_psf(self):

def project(self):
if self.train_mask_vals:
self._mask.data = torch.clamp(self._mask, self.min_val, 1)
if self.color_filter is not None:
self.color_filter.data = torch.clamp(self.color_filter, 0, 1)
self._vals.data = torch.clamp(self._vals, self.min_val, 1)
if self._color_filter is not None:
self._color_filter.data = torch.clamp(self._color_filter, 0, 1)
# normalize each row to 1
self.color_filter.data = self.color_filter / self.color_filter.sum(
self._color_filter.data = self._color_filter / self._color_filter.sum(
dim=[1, 2]
).unsqueeze(-1).unsqueeze(-1)

Expand All @@ -257,32 +254,33 @@ def __init__(
TODO: Distinguish between separable and non-separable.
"""

# 1) call base constructor so parameters can be set
super().__init__(optimizer, lr, **kwargs)

# 2) initialize mask
assert "distance_sensor" in kwargs, "Distance to sensor must be specified"
assert "method" in kwargs, "Method must be specified."
assert "n_bits" in kwargs, "Number of bits must be specified."
self._mask_obj = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)
self._mask = self._mask_obj.mask

# initialize mask
self._mask = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)

# set learnable parameters (should be immediate attributes of the class)
self._row = torch.nn.Parameter(self._mask.row)
self._col = torch.nn.Parameter(self._mask.col)
# 3) set learnable parameters (should be immediate attributes of the class)
self._row = torch.nn.Parameter(self._mask_obj.row)
self._col = torch.nn.Parameter(self._mask_obj.col)
initial_param = [self._row, self._col]
self.binary = binary

# set optimizer
# 4) set optimizer
self._set_optimizer(initial_param)

def get_psf(self):
self._mask.create_mask()
self._mask.compute_psf()
return self._mask.psf.unsqueeze(0)
self._mask_obj.create_mask()
self._mask_obj.compute_psf()
return self._mask_obj.psf.unsqueeze(0)

def project(self):
self.row.data = torch.clamp(self.row, 0, 1)
self.col = torch.clamp(self.col, 0, 1)
self._row.data = torch.clamp(self._row, 0, 1)
self._col.data = torch.clamp(self._col, 0, 1)
if self.binary:
self.row.data = torch.round(self.row)
self.col.data = torch.round(self.col)
self._row.data = torch.round(self._row)
self._col.data = torch.round(self._col)
33 changes: 17 additions & 16 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,9 @@ def train_epoch(self, data_loader):
self.Loss_lpips(2 * y_pred - 1, 2 * y - 1)
)
if self.use_mask and self.l1_mask:
loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask))
for p in self.mask.parameters():
if p.requires_grad:
loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(p))
loss_v.backward()

if self.clip_grad_norm is not None:
Expand Down Expand Up @@ -659,7 +661,10 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
if self.lpips is not None:
eval_loss += self.lpips * current_metrics["LPIPS_Vgg"]
if self.use_mask and self.l1_mask:
eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy()))
for p in self.mask.parameters():
if p.requires_grad:
eval_loss += self.l1_mask * np.mean(np.abs(p.cpu().detach().numpy()))
# eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy()))
return eval_loss
else:
return current_metrics[self.metrics["metric_for_best_model"]]
Expand Down Expand Up @@ -771,23 +776,18 @@ def save(self, epoch, path="recon", include_optimizer=False):
# create directory if it does not exist
if not os.path.exists(path):
os.makedirs(path)
# save mask

# save mask parameters
if self.use_mask:
# torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt"))

# save mask as numpy array
if self.mask.train_mask_vals:
np.save(
os.path.join(path, f"mask_epoch{epoch}.npy"),
self.mask._mask.cpu().detach().numpy(),
)
for name, param in self.mask.named_parameters():

if self.mask.color_filter is not None:
# save save numpy array
np.save(
os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"),
self.mask.color_filter.cpu().detach().numpy(),
)
# save as numpy array
if param.requires_grad:
np.save(
os.path.join(path, f"mask{name}_epoch{epoch}.npy"),
param.cpu().detach().numpy(),
)

torch.save(
self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt")
Expand All @@ -802,5 +802,6 @@ def save(self, epoch, path="recon", include_optimizer=False):
# save optimizer
if include_optimizer:
torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt"))

# save recon
torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}"))
9 changes: 7 additions & 2 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,13 @@ def simulate_dataset(config, generator=None):
transform = transforms.Compose(transforms_list)
train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)

if config.files.n_files is not None:
train_size = int((1 - config.files.test_size) * config.files.n_files)
test_size = config.files.n_files - train_size
train_ds = Subset(train_ds, np.arange(train_size))
test_ds = Subset(test_ds, np.arange(test_size))

elif config.files.dataset == "fashion_mnist":
transform = transforms.Compose(transforms_list)
train_ds = datasets.FashionMNIST(
Expand Down Expand Up @@ -274,8 +281,6 @@ def prep_trainable_mask(config, psf=None, downsample=None):

if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig):

from lensless.hardware.trainable_mask import TrainableCodedAperture

# from mask config
mask = mask_class(
# mask = TrainableCodedAperture(
Expand Down

0 comments on commit 73bb229

Please sign in to comment.