From 73bb22926f2e1657d13c968f7323c486050b959a Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Wed, 6 Dec 2023 16:49:46 +0100
Subject: [PATCH] Improve trainable mask API.

---
 configs/train_coded_aperture.yaml   |  7 ++--
 lensless/hardware/trainable_mask.py | 62 ++++++++++++++---------------
 lensless/recon/utils.py             | 33 +++++++--------
 scripts/recon/train_unrolled.py     |  9 ++++-
 4 files changed, 58 insertions(+), 53 deletions(-)

diff --git a/configs/train_coded_aperture.yaml b/configs/train_coded_aperture.yaml
index 4219910d..5d8aa6d4 100644
--- a/configs/train_coded_aperture.yaml
+++ b/configs/train_coded_aperture.yaml
@@ -7,13 +7,14 @@ defaults:
 files:
   dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
   celeba_root: /scratch/bezzam
-  downsample: 8
+  downsample: 16 # TODO use simulation instead?
 
 #Trainable Mask
 trainable_mask:
   mask_type: TrainableCodedAperture
   optimizer: Adam
   mask_lr: 1e-3
+  L1_strength: False
   initial_value:
     method: MLS
     n_bits: 8
@@ -32,7 +33,7 @@ simulation:
 
 training:
   crop_preloss: False # crop region for computing loss
-  batch_size: 8
+  batch_size: 4
   epoch: 25
   eval_batch_size: 16
-  save_every: 5
+  save_every: 1
diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py
index fd6c28cd..b4328f45 100644
--- a/lensless/hardware/trainable_mask.py
+++ b/lensless/hardware/trainable_mask.py
@@ -42,12 +42,13 @@ def __init__(self, optimizer="Adam", lr=1e-3, **kwargs):
         # # self._param = initial_param
         # self._optimizer = getattr(torch.optim, optimizer)(self._param, lr=lr)
         # self._counter = 0
-        self.optimizer = optimizer
-        self.lr = lr
+        self._optimizer = optimizer
+        self._lr = lr
+        self._counter = 0
 
     def _set_optimizer(self, param):
         """Set the optimizer for the mask parameters."""
-        self._optimizer = getattr(torch.optim, self.optimizer)(param, lr=self.lr)
+        self._optimizer = getattr(torch.optim, self._optimizer)(param, lr=self._lr)
 
     @abc.abstractmethod
     def get_psf(self):
@@ -68,10 +69,6 @@ def update_mask(self):
         self.project()
         self._counter += 1
 
-    def get_vals(self):
-        """Get the mask parameters."""
-        return self._param
-
     @abc.abstractmethod
     def project(self):
         """Abstract method for projecting the mask parameters to a valid space (should be a subspace of [0,1])."""
@@ -162,16 +159,16 @@ def __init__(
         super().__init__(optimizer, lr, **kwargs)
         self.train_mask_vals = train_mask_vals
         if train_mask_vals:
-            self._mask = torch.nn.Parameter(initial_vals)
+            self._vals = torch.nn.Parameter(initial_vals)
         else:
-            self._mask = initial_vals
+            self._vals = initial_vals
 
         if color_filter is not None:
-            self.color_filter = torch.nn.Parameter(color_filter)
+            self._color_filter = torch.nn.Parameter(color_filter)
             if train_mask_vals:
-                initial_param = [self._mask, self.color_filter]
+                initial_param = [self._vals, self._color_filter]
             else:
-                initial_param = [self.color_filter]
+                initial_param = [self._color_filter]
         else:
             assert (
                 train_mask_vals
@@ -202,12 +199,12 @@ def get_psf(self):
         mask = get_programmable_mask(
-            vals=self._mask,
+            vals=self._vals,
             sensor=self.sensor,
             slm_param=self.slm_param,
             rotate=self.rotate,
             flipud=self.flipud,
-            color_filter=self.color_filter,
+            color_filter=self._color_filter,
         )
 
         if self.vertical_shift is not None:
@@ -240,11 +237,11 @@ def get_psf(self):
 
     def project(self):
         if self.train_mask_vals:
-            self._mask.data = torch.clamp(self._mask, self.min_val, 1)
-        if self.color_filter is not None:
-            self.color_filter.data = torch.clamp(self.color_filter, 0, 1)
+            self._vals.data = torch.clamp(self._vals, self.min_val, 1)
+        if self._color_filter is not None:
+            self._color_filter.data = torch.clamp(self._color_filter, 0, 1)
             # normalize each row to 1
-            self.color_filter.data = self.color_filter / self.color_filter.sum(
+            self._color_filter.data = self._color_filter / self._color_filter.sum(
                 dim=[1, 2]
             ).unsqueeze(-1).unsqueeze(-1)
 
@@ -257,32 +254,33 @@ def __init__(
 
         TODO: Distinguish between separable and non-separable.
         """
+        # 1) call base constructor so parameters can be set
         super().__init__(optimizer, lr, **kwargs)
+        # 2) initialize mask
         assert "distance_sensor" in kwargs, "Distance to sensor must be specified"
         assert "method" in kwargs, "Method must be specified."
        assert "n_bits" in kwargs, "Number of bits must be specified."
+        self._mask_obj = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)
+        self._mask = self._mask_obj.mask
 
-        # initialize mask
-        self._mask = CodedAperture.from_sensor(sensor_name, downsample, is_torch=True, **kwargs)
-
-        # set learnable parameters (should be immediate attributes of the class)
-        self._row = torch.nn.Parameter(self._mask.row)
-        self._col = torch.nn.Parameter(self._mask.col)
+        # 3) set learnable parameters (should be immediate attributes of the class)
+        self._row = torch.nn.Parameter(self._mask_obj.row)
+        self._col = torch.nn.Parameter(self._mask_obj.col)
         initial_param = [self._row, self._col]
         self.binary = binary
 
-        # set optimizer
+        # 4) set optimizer
         self._set_optimizer(initial_param)
 
     def get_psf(self):
-        self._mask.create_mask()
-        self._mask.compute_psf()
-        return self._mask.psf.unsqueeze(0)
+        self._mask_obj.create_mask()
+        self._mask_obj.compute_psf()
+        return self._mask_obj.psf.unsqueeze(0)
 
     def project(self):
-        self.row.data = torch.clamp(self.row, 0, 1)
-        self.col = torch.clamp(self.col, 0, 1)
+        self._row.data = torch.clamp(self._row, 0, 1)
+        self._col.data = torch.clamp(self._col, 0, 1)
         if self.binary:
-            self.row.data = torch.round(self.row)
-            self.col.data = torch.round(self.col)
+            self._row.data = torch.round(self._row)
+            self._col.data = torch.round(self._col)
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index 53f23a1b..f249f2b1 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -579,7 +579,9 @@ def train_epoch(self, data_loader):
                         self.Loss_lpips(2 * y_pred - 1, 2 * y - 1)
                     )
                 if self.use_mask and self.l1_mask:
-                    loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask))
+                    for p in self.mask.parameters():
+                        if p.requires_grad:
+                            loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(p))
                 loss_v.backward()
 
                 if self.clip_grad_norm is not None:
@@ -659,7 +661,10 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
             if self.lpips is not None:
                 eval_loss += self.lpips * current_metrics["LPIPS_Vgg"]
             if self.use_mask and self.l1_mask:
-                eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy()))
+                for p in self.mask.parameters():
+                    if p.requires_grad:
+                        eval_loss += self.l1_mask * np.mean(np.abs(p.cpu().detach().numpy()))
+                # eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy()))
             return eval_loss
         else:
             return current_metrics[self.metrics["metric_for_best_model"]]
@@ -771,23 +776,18 @@ def save(self, epoch, path="recon", include_optimizer=False):
         # create directory if it does not exist
         if not os.path.exists(path):
            os.makedirs(path)
-        # save mask
+
+        # save mask parameters
         if self.use_mask:
-            # torch.save(self.mask._mask, os.path.join(path, f"mask_epoch{epoch}.pt"))
-            # save mask as numpy array
-            if self.mask.train_mask_vals:
-                np.save(
-                    os.path.join(path, f"mask_epoch{epoch}.npy"),
-                    self.mask._mask.cpu().detach().numpy(),
-                )
+            for name, param in self.mask.named_parameters():
 
-            if self.mask.color_filter is not None:
-                # save save numpy array
-                np.save(
-                    os.path.join(path, f"mask_color_filter_epoch{epoch}.npy"),
-                    self.mask.color_filter.cpu().detach().numpy(),
-                )
+                # save as numpy array
+                if param.requires_grad:
+                    np.save(
+                        os.path.join(path, f"mask{name}_epoch{epoch}.npy"),
+                        param.cpu().detach().numpy(),
+                    )
 
             torch.save(
                 self.mask._optimizer.state_dict(), os.path.join(path, f"mask_optim_epoch{epoch}.pt")
             )
@@ -802,5 +802,6 @@
         # save optimizer
         if include_optimizer:
             torch.save(self.optimizer.state_dict(), os.path.join(path, f"optim_epoch{epoch}.pt"))
+
         # save recon
         torch.save(self.recon.state_dict(), os.path.join(path, f"recon_epoch{epoch}"))
diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index 8a326a7a..c9193674 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -108,6 +108,13 @@ def simulate_dataset(config, generator=None):
         transform = transforms.Compose(transforms_list)
         train_ds = datasets.MNIST(root=data_path, train=True, download=True, transform=transform)
         test_ds = datasets.MNIST(root=data_path, train=False, download=True, transform=transform)
+
+        if config.files.n_files is not None:
+            train_size = int((1 - config.files.test_size) * config.files.n_files)
+            test_size = config.files.n_files - train_size
+            train_ds = Subset(train_ds, np.arange(train_size))
+            test_ds = Subset(test_ds, np.arange(test_size))
+
     elif config.files.dataset == "fashion_mnist":
         transform = transforms.Compose(transforms_list)
         train_ds = datasets.FashionMNIST(
@@ -274,8 +281,6 @@ def prep_trainable_mask(config, psf=None, downsample=None):
 
     if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig):
 
-        from lensless.hardware.trainable_mask import TrainableCodedAperture
-
        # from mask config
         mask = mask_class(
            # mask = TrainableCodedAperture(
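For reference, a minimal sketch of how the reworked TrainableCodedAperture might be driven after this change. The exact constructor keywords, the "rpi_hq" sensor name, the mask-to-sensor distance, and the placeholder objective below are illustrative assumptions, not taken from this patch:

import torch
from lensless.hardware.trainable_mask import TrainableCodedAperture

# "method", "n_bits" and "distance_sensor" are forwarded through **kwargs to
# CodedAperture.from_sensor(), as asserted in the constructor above.
mask = TrainableCodedAperture(
    sensor_name="rpi_hq",   # assumed sensor identifier
    downsample=16,
    optimizer="Adam",
    lr=1e-3,
    binary=True,
    method="MLS",
    n_bits=8,
    distance_sensor=4e-3,   # assumed distance, in meters
)

psf = mask.get_psf()       # rebuilds the mask pattern and its PSF from the current parameters
loss = psf.abs().mean()    # placeholder objective, only to illustrate the update cycle
loss.backward()
mask.update_mask()         # applies the optimizer update, then project() clamps to [0, 1] (and rounds if binary)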
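The trainer-side hunks switch from reaching into the private _mask attribute to iterating over whatever parameters the mask module exposes. A self-contained sketch of that pattern, assuming the mask is any torch.nn.Module (a stand-in with _row/_col parameters is used here) and l1_strength plays the role of self.l1_mask:

import os
import numpy as np
import torch

# Stand-in for a trainable mask: any torch.nn.Module behaves the same way.
mask = torch.nn.ParameterDict(
    {"_row": torch.nn.Parameter(torch.rand(128)), "_col": torch.nn.Parameter(torch.rand(128))}
)
l1_strength = 1e-4  # assumed value; the config key added above is L1_strength

# L1 penalty over every trainable parameter (mirrors the train_epoch hunk).
loss = torch.tensor(0.0)
for p in mask.parameters():
    if p.requires_grad:
        loss = loss + l1_strength * torch.mean(torch.abs(p))

# Save each trainable parameter as a NumPy array (mirrors the Trainer.save hunk).
epoch, path = 0, "recon"
os.makedirs(path, exist_ok=True)
for name, param in mask.named_parameters():
    if param.requires_grad:
        np.save(
            os.path.join(path, f"mask{name}_epoch{epoch}.npy"),
            param.cpu().detach().numpy(),
        )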
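The change to simulate_dataset() caps the number of simulated examples with torch.utils.data.Subset. A sketch of the same split logic outside the Hydra config, with made-up n_files and test_size values standing in for config.files.n_files and config.files.test_size:

import numpy as np
from torch.utils.data import Subset
from torchvision import datasets, transforms

n_files, test_size = 1000, 0.15  # assumed values
transform = transforms.ToTensor()
train_ds = datasets.MNIST(root="data", train=True, download=True, transform=transform)
test_ds = datasets.MNIST(root="data", train=False, download=True, transform=transform)

# Keep only the first train_size / test_size examples, as in the patch.
train_size = int((1 - test_size) * n_files)
train_ds = Subset(train_ds, np.arange(train_size))
test_ds = Subset(test_ds, np.arange(n_files - train_size))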