From 91923f9798ef12c40f174c27e4001110563afde2 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Tue, 19 Sep 2023 13:17:05 +0000 Subject: [PATCH] Clean up fine-tuning PSF. --- configs/fine-tune_PSF.yaml | 3 +- configs/train_unrolledADMM.yaml | 4 +- lensless/hardware/trainable_mask.py | 9 ++- scripts/recon/train_unrolled.py | 85 +++++++++++------------------ 4 files changed, 44 insertions(+), 57 deletions(-) diff --git a/configs/fine-tune_PSF.yaml b/configs/fine-tune_PSF.yaml index 6047f618..9576a111 100644 --- a/configs/fine-tune_PSF.yaml +++ b/configs/fine-tune_PSF.yaml @@ -6,14 +6,13 @@ defaults: #Trainable Mask trainable_mask: mask_type: TrainablePSF #Null or "TrainablePSF" - initial_value: "DiffuserCam" # "random" or "DiffuserCam" or "DiffuserCam_gray" + initial_value: psf mask_lr: 1e-3 L1_strength: 1.0 #False or float use_mask_in_dataset : False # Work only with simulated dataset #Training training: - epoch: 50 save_every: 5 display: diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index ba46d203..786141e3 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -55,7 +55,9 @@ reconstruction: #Trainable Mask trainable_mask: mask_type: Null #Null or "TrainablePSF" - initial_value: "DiffuserCam_gray" # "random" or "DiffuserCam" or "DiffuserCam_gray" + # "random" (with shape of config.files.psf) or "psf" (using config.files.psf) + initial_value: psf + grayscale: False mask_lr: 1e-3 L1_strength: 1.0 #False or float use_mask_in_dataset : True # Work only with simulated dataset diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 7c48d7ea..6517f52a 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -3,6 +3,7 @@ # ================== # Authors : # Yohann PERRON [yohann.perron@gmail.com] +# Eric BEZZAM [ebezzam@gmail.com] # ############################################################################# import abc @@ -76,12 +77,16 @@ class TrainablePSF(TrainableMask): def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, is_rgb=True, **kwargs): super().__init__(initial_mask, optimizer, lr, **kwargs) self._is_rgb = is_rgb + if is_rgb: + assert initial_mask.shape[-1] == 3, "RGB mask should have 3 channels" + else: + assert initial_mask.shape[-1] == 1, "Monochrome mask should have 1 channel" def get_psf(self): if self._is_rgb: - return self._mask.expand(-1, -1, -1, 3) - else: return self._mask + else: + return self._mask.expand(-1, -1, -1, 3) def project(self): self._mask.data = torch.clamp(self._mask, 0, 1) diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 4075562f..d2610ede 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -153,6 +153,33 @@ def simulate_dataset(config): return ds_prop, mask +def prep_trainable_mask(config, dataset): + mask = None + if config.trainable_mask.mask_type is not None: + mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) + + if config.trainable_mask.initial_value == "random": + initial_mask = torch.rand_like(dataset.psf) + elif config.trainable_mask.initial_value == "psf": + initial_mask = dataset.psf.clone() + else: + raise ValueError( + f"Initial PSF value {config.trainable_mask.initial_value} not supported" + ) + + if config.trainable_mask.grayscale: + initial_mask = rgb2gray(initial_mask) + + mask = mask_class( + initial_mask, + optimizer="Adam", + lr=config.trainable_mask.mask_lr, + is_rgb=not config.trainable_mask.grayscale, + ) + + return mask + + @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config): @@ -171,37 +198,6 @@ def train_unrolled(config): print("Using CPU for training.") device = "cpu" - # # benchmarking dataset: - # eval_path = os.path.join(get_original_cwd(), config.files.eval_dataset) - # benchmark_dataset = DiffuserCamTestDataset( - # data_dir=eval_path, downsample=config.files.downsample, n_files=config.files.n_files - # ) - - # diffusercam_psf = benchmark_dataset.psf.to(device) - # # background = benchmark_dataset.background - - # # convert psf from BGR to RGB - # diffusercam_psf = diffusercam_psf[..., [2, 1, 0]] - - # # create mask - # if config.trainable_mask.mask_type is not None: - # mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - # if config.trainable_mask.initial_value == "random": - # mask = mask_class( - # torch.rand_like(diffusercam_psf), optimizer="Adam", lr=config.trainable_mask.mask_lr - # ) - # elif config.trainable_mask.initial_value == "DiffuserCam": - # mask = mask_class(diffusercam_psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) - # elif config.trainable_mask.initial_value == "DiffuserCam_gray": - # mask = mask_class( - # diffusercam_psf[:, :, :, 0, None], - # optimizer="Adam", - # lr=config.trainable_mask.mask_lr, - # is_rgb=not config.simulation.grayscale, - # ) - # else: - # mask = None - # load dataset and create dataloader train_set = None test_set = None @@ -229,33 +225,19 @@ def train_unrolled(config): print("Test test size : ", len(test_set)) # -- if learning mask - mask = None - if config.trainable_mask.mask_type is not None: - mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - - if config.trainable_mask.initial_value == "random": - mask = mask_class( - torch.rand_like(dataset.psf), optimizer="Adam", lr=config.trainable_mask.mask_lr - ) - # TODO : change to PSF - elif config.trainable_mask.initial_value == "DiffuserCam": - mask = mask_class(dataset.psf, optimizer="Adam", lr=config.trainable_mask.mask_lr) - elif config.trainable_mask.initial_value == "DiffuserCam_gray": - # TODO convert to grayscale - mask = mask_class( - dataset.psf[:, :, :, 0, None], - optimizer="Adam", - lr=config.trainable_mask.mask_lr, - is_rgb=not config.simulation.grayscale, - ) - + mask = prep_trainable_mask(config, dataset) + if mask is not None: # plot initial PSF psf_np = mask.get_psf().detach().cpu().numpy()[0, ...] + if config.trainable_mask.grayscale: + psf_np = psf_np[:, :, -1] + save_image(psf_np, os.path.join(save, "psf_initial.png")) plot_image(psf_np, gamma=config.display.gamma) plt.savefig(os.path.join(save, "psf_initial_plot.png")) else: + # Use a simulated dataset if config.trainable_mask.use_mask_in_dataset: train_set, mask = simulate_dataset(config) @@ -263,7 +245,6 @@ def train_unrolled(config): print("Trainable Mask will be used in the test dataset") test_set = None else: - # TODO handlge case where finetuning PSF train_set, mask = simulate_dataset(config) start_time = time.time()