From 5892fba0db9b24d195aed4f55609fc9a8fb37732 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 10 Oct 2023 15:57:08 +0200 Subject: [PATCH] Add support to train adafruit mask. --- configs/train_celeba_digicam.yaml | 44 ++++++++ configs/train_celeba_digicam_mask.yaml | 64 +++++++++++ configs/train_unrolledADMM.yaml | 12 +- lensless/hardware/trainable_mask.py | 98 ++++++++++++++++- lensless/recon/utils.py | 33 +++--- lensless/utils/dataset.py | 36 +++++- scripts/recon/train_unrolled.py | 145 +++++++++++++++++-------- 7 files changed, 352 insertions(+), 80 deletions(-) create mode 100644 configs/train_celeba_digicam.yaml create mode 100644 configs/train_celeba_digicam_mask.yaml diff --git a/configs/train_celeba_digicam.yaml b/configs/train_celeba_digicam.yaml new file mode 100644 index 00000000..91604010 --- /dev/null +++ b/configs/train_celeba_digicam.yaml @@ -0,0 +1,44 @@ +# python scripts/recon/train_unrolled.py -cn train_celeba_digicam +defaults: + - train_unrolledADMM + - _self_ + +# Train Dataset +files: + dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + celeba_root: /scratch/bezzam + psf: data/psf/adafruit_random_2mm_20231907.png + +# for prepping ground truth data +simulation: + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + object_height: 0.33 # [m] + + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + + +# see some outputs of classical ADMM before training +test_idx: [0, 1, 2, 3, 4] + +#Training +training: + batch_size: 2 + epoch: 50 + eval_batch_size: 15 + + # crop: null + crop_preloss: True + diff --git a/configs/train_celeba_digicam_mask.yaml b/configs/train_celeba_digicam_mask.yaml new file mode 100644 index 00000000..310ac058 --- /dev/null +++ b/configs/train_celeba_digicam_mask.yaml @@ -0,0 +1,64 @@ +# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask +defaults: + - train_celeba_digicam + - _self_ + +# Train Dataset +files: + dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K + celeba_root: /scratch/bezzam + psf: data/psf/adafruit_random_2mm_20231907.png + +# for prepping ground truth data +simulation: + scene2mask: 0.25 # [m] + mask2sensor: 0.002 # [m] + object_height: 0.33 # [m] + + +reconstruction: + method: unrolled_admm + unrolled_admm: + # Number of iterations + n_iter: 10 + + pre_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + post_process: + network : null # UnetRes or DruNet or null + depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet + +#Training +training: + batch_size: 2 + epoch: 50 + eval_batch_size: 15 + + # crop: null + crop_preloss: True + +#Trainable Mask +trainable_mask: + mask_type: AdafruitLCD #Null or "TrainablePSF" or "AdafruitLCD" + # "random" (with shape of config.files.psf) or path to npy file + initial_value: data/psf/adafruit_random_pattern_20230719.npy + grayscale: False + mask_lr: 1e-3 + L1_strength: False + + # only for AdafruitLCD + ap_center: [59, 76] + ap_shape: [19, 26] + rotate: -0.8 # rotation in degrees + slm: adafruit + sensor: rpi_hq + flipud: True + waveprop: True + # to align with measured PSF (so reconstruction also aligned) + vertical_shift: -20 # [px] + horizontal_shift: -100 # [px] + # below are ignored if waveprop=False + scene2mask: 0.3 # [m] + mask2sensor: 0.002 # [m] + \ No newline at end of file diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index df6552e2..745343ce 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -19,6 +19,8 @@ files: torch: True torch_device: 'cuda' +# see some outputs of classical ADMM before training +test_idx: [0, 1, 2, 3, 4] # test set example to visualize at the end of every epoch eval_disp_idx: [0, 1, 2, 3, 4] @@ -59,7 +61,7 @@ reconstruction: #Trainable Mask trainable_mask: - mask_type: Null #Null or "TrainablePSF" + mask_type: Null #Null or "TrainablePSF" or "AdafruitLCD" # "random" (with shape of config.files.psf) or "psf" (using config.files.psf) initial_value: psf grayscale: False @@ -106,10 +108,10 @@ training: skip_NAN: True slow_start: False #float how much to reduce lr for first epoch - crop: null # crop region for computing loss - # crop: - # vertical: [30, 560] - # horizontal: [275, 710] + crop_preloss: True # crop region for computing loss + crop: null + # vertical: null + # horizontal: null optimizer: type: Adam diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 9bc70bc8..451670a5 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -9,6 +9,9 @@ import abc import torch from lensless.utils.image import is_grayscale +from lensless.hardware.slm import get_programmable_mask, get_intensity_psf +from lensless.hardware.sensor import VirtualSensor +from waveprop.devices import slm_dict class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta): @@ -37,7 +40,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs): """ super().__init__() self._mask = torch.nn.Parameter(initial_mask) - self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs) + self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr) self._counter = 0 @abc.abstractmethod @@ -53,7 +56,7 @@ def get_psf(self): raise NotImplementedError def update_mask(self): - """Update the mask parameters. Acoording to externaly updated gradiants.""" + """Update the mask parameters. According to externaly updated gradiants.""" self._optimizer.step() self._optimizer.zero_grad(set_to_none=True) self.project() @@ -100,3 +103,94 @@ def get_psf(self): def project(self): self._mask.data = torch.clamp(self._mask, 0, 1) + + +class AdafruitLCD(TrainableMask): + def __init__( + self, + initial_vals, + sensor, + slm, + rotate=None, + flipud=False, + use_waveprop=None, + vertical_shift=None, + horizontal_shift=None, + scene2mask=None, + mask2sensor=None, + downsample=None, + **kwargs + ): + """ + Parameters + ---------- + initial_vals : :py:class:`~torch.Tensor` + Initial mask parameters. + sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor` + Sensor object. + slm_param : :py:class:`~lensless.hardware.slm.SLMParam` + SLM parameters. + rotate : float, optional + Rotation angle in degrees, by default None + flipud : bool, optional + Whether to flip the mask vertically, by default False + """ + super().__init__(initial_vals, **kwargs) + + self.slm_param = slm_dict[slm] + self.sensor = VirtualSensor.from_name(sensor, downsample=downsample) + self.rotate = rotate + self.flipud = flipud + self.use_waveprop = use_waveprop + self.scene2mask = scene2mask + self.mask2sensor = mask2sensor + self.vertical_shift = vertical_shift + self.horizontal_shift = horizontal_shift + if downsample is not None and vertical_shift is not None: + self.vertical_shift = vertical_shift // downsample + if downsample is not None and horizontal_shift is not None: + self.horizontal_shift = horizontal_shift // downsample + if self.use_waveprop: + assert self.scene2mask is not None + assert self.mask2sensor is not None + + def get_psf(self): + + mask = get_programmable_mask( + vals=self._mask, + sensor=self.sensor, + slm_param=self.slm_param, + rotate=self.rotate, + flipud=self.flipud, + ) + + if self.vertical_shift is not None: + mask = torch.roll(mask, self.vertical_shift, dims=1) + + if self.horizontal_shift is not None: + mask = torch.roll(mask, self.horizontal_shift, dims=2) + + psf_in = get_intensity_psf( + mask=mask, + sensor=self.sensor, + waveprop=self.use_waveprop, + scene2mask=self.scene2mask, + mask2sensor=self.mask2sensor, + ) + + # add first dimension (depth) + psf_in = psf_in.unsqueeze(0) + + # move channels to last dimension + psf_in = psf_in.permute(0, 2, 3, 1) + + # flip mask + psf_in = torch.flip(psf_in, dims=[-3, -2]) + + # normalize + psf_in = psf_in / psf_in.norm() + + return psf_in + + def project(self): + self._mask.data = torch.clamp(self._mask, 0, 1) diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index dca1cd03..439fe9d9 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -373,20 +373,7 @@ def __init__( "lpips package is need for LPIPS loss. Install using : pip install lpips" ) - if crop is not None: - datashape = train_dataset[0][0].shape - # create binary mask to multiply with before computing loss - self.mask_crop = torch.zeros(datashape, dtype=torch.bool).to(self.device) - - # move channel dimension to third to last - self.mask_crop = self.mask_crop.movedim(-1, -3) - - # set values to True in mask - self.mask_crop[ - :, :, crop.vertical[0] : crop.vertical[1], crop.horizontal[0] : crop.horizontal[1] - ] = True - else: - self.mask_crop = None + self.crop = crop # optimizer if optimizer == "Adam": @@ -484,7 +471,7 @@ def train_epoch(self, data_loader): # update psf according to mask if self.use_mask: - self.recon._set_psf(self.mask.get_psf()) + self.recon._set_psf(self.mask.get_psf().to(self.device)) # forward pass y_pred = self.recon.batch_call(X.to(self.device)) @@ -503,9 +490,17 @@ def train_epoch(self, data_loader): y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) # crop - if self.mask_crop is not None: - y_pred = y_pred * self.mask_crop - y = y * self.mask_crop + if self.crop is not None: + y_pred = y_pred[ + ..., + self.crop["vertical"][0] : self.crop["vertical"][1], + self.crop["horizontal"][0] : self.crop["horizontal"][1], + ] + y = y[ + ..., + self.crop["vertical"][0] : self.crop["vertical"][1], + self.crop["horizontal"][0] : self.crop["horizontal"][1], + ] loss_v = self.Loss(y_pred, y) if self.lpips: @@ -583,7 +578,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None): batchsize=self.eval_batch_size, save_idx=disp, output_dir=output_dir, - mask_crop=self.mask_crop, + crop=self.crop, ) # update metrics with current metrics diff --git a/lensless/utils/dataset.py b/lensless/utils/dataset.py index c794ebb2..2aeec33a 100644 --- a/lensless/utils/dataset.py +++ b/lensless/utils/dataset.py @@ -355,13 +355,17 @@ def __init__( psf_path=None, downsample=1, flip=True, - vertical_shift=-85, - horizontal_shift=-15, + vertical_shift=None, + horizontal_shift=None, + crop=None, simulation_config=None, **kwargs, ): """ + Some parameters default to work for the ``celeba_adafruit_random_2mm_20230720_10K`` dataset, + namely: flip, vertical_shift, horizontal_shift, crop, simulation_config. + Parameters ---------- celeba_root : str @@ -375,11 +379,33 @@ def __init__( flip : bool, optional If True, measurements are flipped, by default ``True``. Does not get applied to the original images. vertical_shift : int, optional - Vertical shift (in pixels) of the lensed images to align, by default 0. + Vertical shift (in pixels) of the lensed images to align. horizontal_shift : int, optional - Horizontal shift (in pixels) of the lensed images to align, by default 0. + Horizontal shift (in pixels) of the lensed images to align. + crop : dict, optional + Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]) to select region of interest. """ + if vertical_shift is None: + # default to (no downsampling) of celeba_adafruit_random_2mm_20230720_10K + vertical_shift = -85 + horizontal_shift = -5 + + if crop is None: + crop = {"vertical": [30, 560], "horizontal": [285, 720]} + self.crop = crop + + self.vertical_shift = vertical_shift + self.horizontal_shift = horizontal_shift + if downsample != 1: + self.vertical_shift = int(self.vertical_shift // downsample) + self.horizontal_shift = int(self.horizontal_shift // downsample) + + self.crop["vertical"][0] = int(self.crop["vertical"][0] // downsample) + self.crop["vertical"][1] = int(self.crop["vertical"][1] // downsample) + self.crop["horizontal"][0] = int(self.crop["horizontal"][0] // downsample) + self.crop["horizontal"][1] = int(self.crop["horizontal"][1] // downsample) + # download dataset if necessary if data_dir is None: data_dir = os.path.join( @@ -432,8 +458,6 @@ def __init__( # load PSF self.flip_measurement = flip - self.vertical_shift = vertical_shift - self.horizontal_shift = horizontal_shift psf, background = load_psf( psf_path, downsample=downsample * 4, # PSF is 4x the resolution of the images diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index e8d05a31..dbeacbdc 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -56,6 +56,7 @@ from lensless.utils.io import load_psf from lensless.utils.io import save_image from lensless.utils.plot import plot_image +from lensless import ADMM import matplotlib.pyplot as plt # A logger for this file @@ -177,7 +178,7 @@ def simulate_dataset(config): return train_ds_prop, test_ds_prop, mask -def prep_trainable_mask(config, psf, grayscale=False): +def prep_trainable_mask(config, psf, downsample=None): mask = None if config.trainable_mask.mask_type is not None: mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) @@ -186,6 +187,28 @@ def prep_trainable_mask(config, psf, grayscale=False): initial_mask = torch.rand_like(psf) elif config.trainable_mask.initial_value == "psf": initial_mask = psf.clone() + # if file ending with "npy" + elif config.trainable_mask.initial_value.endswith("npy"): + pattern = np.load(os.path.join(get_original_cwd(), config.trainable_mask.initial_value)) + + # -- apply aperture + ap_center = np.array(config.trainable_mask.ap_center) + ap_shape = np.array(config.trainable_mask.ap_shape) + # -- extract aperture region + idx_1 = ap_center[0] - ap_shape[0] // 2 + idx_2 = ap_center[1] - ap_shape[1] // 2 + + initial_mask = pattern[ + :, + idx_1 : idx_1 + ap_shape[0], + idx_2 : idx_2 + ap_shape[1], + ] + initial_mask = initial_mask / 255.0 + if config.trainable_mask.slm == "adafruit": + # flatten color channel along rows + initial_mask = initial_mask.reshape((-1, initial_mask.shape[-1]), order="F") + + initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) else: raise ValueError( f"Initial PSF value {config.trainable_mask.initial_value} not supported" @@ -195,7 +218,7 @@ def prep_trainable_mask(config, psf, grayscale=False): initial_mask = rgb2gray(initial_mask) mask = mask_class( - initial_mask, optimizer="Adam", lr=config.trainable_mask.mask_lr, grayscale=grayscale + initial_mask, optimizer="Adam", downsample=downsample, **config.trainable_mask ) return mask @@ -206,10 +229,9 @@ def train_unrolled(config): # set seed seed = config.seed - if seed is not None: - torch.manual_seed(seed) - np.random.seed(seed) - generator = torch.Generator().manual_seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + generator = torch.Generator().manual_seed(seed) save = config.save if save: @@ -268,48 +290,14 @@ def train_unrolled(config): celeba_root=config.files.celeba_root, psf_path=os.path.join(get_original_cwd(), config.files.psf), downsample=config.files.downsample, - vertical_shift=config.files.vertical_shift, - horizontal_shift=config.files.horizontal_shift, + # vertical_shift=config.files.vertical_shift, + # horizontal_shift=config.files.horizontal_shift, simulation_config=config.simulation, + crop=config.training.crop, ) dataset.psf = dataset.psf.to(device) - psf = dataset.psf log.info(f"Data shape : {dataset[0][0].shape}") - # reconstruct lensless with ADMM - lensless, lensed = dataset[0] - from lensless import ADMM - - recon = ADMM(psf) - recon.set_data(lensless.to(psf.device)) - print("Reconstructing lensless image with ADMM...") - start_time = time.time() - res = recon.apply(disp_iter=None, plot=False, n_iter=10) - print(f"Processing time : {time.time() - start_time} s") - res_np = res[0].cpu().numpy() - res_np = res_np / res_np.max() - save_image(res_np, "lensless_recon.png") - lensed_np = lensed[0].cpu().numpy() - save_image(lensed_np, "lensed.png") - lensless_np = lensless[0].cpu().numpy() - save_image(lensless_np, "lensless_raw.png") - - # -- plot lensed and res on top of each other - if config.training.crop is not None: - res_np = res_np[ - config.training.crop.vertical[0] : config.training.crop.vertical[1], - config.training.crop.horizontal[0] : config.training.crop.horizontal[1], - ] - lensed_np = lensed_np[ - config.training.crop.vertical[0] : config.training.crop.vertical[1], - config.training.crop.horizontal[0] : config.training.crop.horizontal[1], - ] - log.info(f"Cropped shape : {res_np.shape}") - plt.figure() - plt.imshow(lensed_np, alpha=0.5) - plt.imshow(res_np, alpha=0.7) - plt.savefig("overlay_lensed_recon.png") - # train-test split train_size = int((1 - config.files.test_size) * len(dataset)) test_size = len(dataset) - train_size @@ -321,17 +309,78 @@ def train_unrolled(config): test_set = Subset(test_set, np.arange(config.files.n_files)) # -- if learning mask - mask = prep_trainable_mask(config, dataset.psf) + downsample = config.files.downsample * 4 # measured files are 4x downsampled + mask = prep_trainable_mask(config, dataset.psf, downsample=downsample) + if mask is not None: # plot initial PSF - psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] - if config.trainable_mask.grayscale: - psf_np = psf_np[:, :, -1] + with torch.no_grad(): + psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] + if config.trainable_mask.grayscale: + psf_np = psf_np[:, :, -1] save_image(psf_np, os.path.join(save, "psf_initial.png")) plot_image(psf_np, gamma=config.display.gamma) plt.savefig(os.path.join(save, "psf_initial_plot.png")) + # save original PSF as well + psf_meas = dataset.psf.detach().cpu().numpy()[0, ...] + plot_image(psf_meas, gamma=config.display.gamma) + plt.savefig(os.path.join(save, "psf_meas_plot.png")) + + with torch.no_grad(): + psf = mask.get_psf().to(dataset.psf) + + else: + + psf = dataset.psf + + # print info about PSF + log.info(f"PSF shape : {psf.shape}") + log.info(f"PSF min : {psf.min()}") + log.info(f"PSF max : {psf.max()}") + log.info(f"PSF dtype : {psf.dtype}") + log.info(f"PSF norm : {psf.norm()}") + + # reconstruct lensless with ADMM + if config.test_idx is not None: + + log.info("Reconstruction a few images with ADMM...") + + for i, _idx in enumerate(config.test_idx): + + lensless, lensed = dataset[_idx] + recon = ADMM(psf) + recon.set_data(lensless.to(psf.device)) + start_time = time.time() + res = recon.apply(disp_iter=None, plot=False, n_iter=10) + res_np = res[0].cpu().numpy() + res_np = res_np / res_np.max() + save_image(res_np, f"lensless_recon_{_idx}.png") + lensed_np = lensed[0].cpu().numpy() + save_image(lensed_np, f"lensed_{_idx}.png") + lensless_np = lensless[0].cpu().numpy() + save_image(lensless_np, f"lensless_raw_{_idx}.png") + + # -- plot lensed and res on top of each other + if config.training.crop_preloss: + crop = dataset.crop + + res_np = res_np[ + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + lensed_np = lensed_np[ + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + if i == 0: + log.info(f"Cropped shape : {res_np.shape}") + plt.figure() + plt.imshow(lensed_np, alpha=0.4) + plt.imshow(res_np, alpha=0.7) + plt.savefig(f"overlay_lensed_recon_{_idx}.png") + else: train_set, test_set, mask = simulate_dataset(config) @@ -418,7 +467,7 @@ def train_unrolled(config): save_every=config.training.save_every, gamma=config.display.gamma, logger=log, - crop=config.training.crop, + crop=dataset.crop if config.training.crop_preloss else None, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)