diff --git a/lensless/hardware/trainable_mask.py b/lensless/hardware/trainable_mask.py index 897efac4..4618acf3 100644 --- a/lensless/hardware/trainable_mask.py +++ b/lensless/hardware/trainable_mask.py @@ -7,9 +7,13 @@ # ############################################################################# import abc +import omegaconf +import os +import numpy as np +from hydra.utils import get_original_cwd import torch -from lensless.utils.image import is_grayscale -from lensless.hardware.slm import get_programmable_mask, get_intensity_psf +from lensless.utils.image import is_grayscale, rgb2gray +from lensless.hardware.slm import full2subpattern, get_programmable_mask, get_intensity_psf from lensless.hardware.sensor import VirtualSensor from waveprop.devices import slm_dict from lensless.hardware.mask import CodedAperture @@ -144,7 +148,7 @@ def __init__( mask2sensor=None, downsample=None, min_val=0, - **kwargs + **kwargs, ): """ Parameters @@ -264,7 +268,7 @@ def __init__( torch_device="cuda", optimizer="Adam", lr=1e-3, - **kwargs + **kwargs, ): """ TODO: Distinguish between separable and non-separable. @@ -284,7 +288,7 @@ def __init__( psf_wavelength=[460e-9], is_torch=True, torch_device=torch_device, - **kwargs + **kwargs, ) self._mask = self._mask_obj.mask @@ -341,3 +345,103 @@ def project(self): self._mask_obj.compute_psf() self._psf = self._mask_obj.psf.unsqueeze(0) self._psf = self._psf / self._psf.norm() + + +""" +Utilities to prepare trainable masks. +""" + +trainable_mask_dict = { + "AdafruitLCD": AdafruitLCD, + "TrainablePSF": TrainablePSF, + "TrainableCodedAperture": TrainableCodedAperture, + "TrainableHeightVarying": None, + "TrainableMultiLensArray": None, +} + + +def prep_trainable_mask(config, psf=None, downsample=None): + mask = None + color_filter = None + downsample = config.files.downsample if downsample is None else downsample + if config.trainable_mask.mask_type is not None: + + assert config.trainable_mask.mask_type in trainable_mask_dict.keys(), ( + f"Trainable mask type {config.trainable_mask.mask_type} not supported. " + f"Supported types are {trainable_mask_dict.keys()}" + ) + mask_class = trainable_mask_dict[config.trainable_mask.mask_type] + + if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig): + + # from mask config + mask = mask_class( + # mask = TrainableCodedAperture( + sensor_name=config.simulation.sensor, + downsample=downsample, + distance_sensor=config.simulation.mask2sensor, + optimizer=config.trainable_mask.optimizer, + lr=config.trainable_mask.mask_lr, + binary=config.trainable_mask.binary, + torch_device=config.torch_device, + **config.trainable_mask.initial_value, + ) + + else: + + if config.trainable_mask.initial_value == "random": + if psf is not None: + initial_mask = torch.rand_like(psf) + else: + sensor = VirtualSensor.from_name( + config.simulation.sensor, downsample=downsample + ) + resolution = sensor.resolution + initial_mask = torch.rand((1, *resolution, 3)) + elif config.trainable_mask.initial_value == "psf": + initial_mask = psf.clone() + # if file ending with "npy" + elif config.trainable_mask.initial_value.endswith("npy"): + pattern = np.load( + os.path.join(get_original_cwd(), config.trainable_mask.initial_value) + ) + + initial_mask = full2subpattern( + pattern=pattern, + shape=config.trainable_mask.ap_shape, + center=config.trainable_mask.ap_center, + slm=config.trainable_mask.slm, + ) + initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) + + # prepare color filter if needed + from waveprop.devices import slm_dict + from waveprop.devices import SLMParam as SLMParam_wp + + slm_param = slm_dict[config.trainable_mask.slm] + if ( + config.trainable_mask.train_color_filter + and SLMParam_wp.COLOR_FILTER in slm_param.keys() + ): + color_filter = slm_param[SLMParam_wp.COLOR_FILTER] + color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) + + # add small random values + color_filter = color_filter + 0.1 * torch.rand_like(color_filter) + + else: + raise ValueError( + f"Initial PSF value {config.trainable_mask.initial_value} not supported" + ) + + if config.trainable_mask.grayscale and not is_grayscale(initial_mask): + initial_mask = rgb2gray(initial_mask) + + mask = mask_class( + initial_mask, + downsample=downsample, + color_filter=color_filter, + **config.trainable_mask, + ) + + return mask diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index 3cc1803c..4dc617d0 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -33,13 +33,13 @@ """ import logging -import omegaconf import hydra from hydra.utils import get_original_cwd import os import numpy as np import time from lensless import UnrolledFISTA, UnrolledADMM +from lensless.hardware.trainable_mask import prep_trainable_mask from lensless.utils.dataset import ( DiffuserCamMirflickr, SimulatedFarFieldDataset, @@ -48,9 +48,6 @@ HITLDatasetTrainableMask, ) from torch.utils.data import Subset -import lensless.hardware.trainable_mask -from lensless.hardware.slm import full2subpattern -from lensless.hardware.sensor import VirtualSensor from lensless.recon.utils import create_process_network from lensless.utils.image import rgb2gray, is_grayscale from lensless.utils.simulation import FarFieldSimulator @@ -277,88 +274,6 @@ def simulate_dataset(config, generator=None): return train_ds_prop, test_ds_prop, mask -def prep_trainable_mask(config, psf=None, downsample=None): - mask = None - color_filter = None - downsample = config.files.downsample if downsample is None else downsample - if config.trainable_mask.mask_type is not None: - mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type) - - if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig): - - # from mask config - mask = mask_class( - # mask = TrainableCodedAperture( - sensor_name=config.simulation.sensor, - downsample=downsample, - distance_sensor=config.simulation.mask2sensor, - optimizer=config.trainable_mask.optimizer, - lr=config.trainable_mask.mask_lr, - binary=config.trainable_mask.binary, - torch_device=config.torch_device, - **config.trainable_mask.initial_value, - ) - - else: - - if config.trainable_mask.initial_value == "random": - if psf is not None: - initial_mask = torch.rand_like(psf) - else: - sensor = VirtualSensor.from_name( - config.simulation.sensor, downsample=downsample - ) - resolution = sensor.resolution - initial_mask = torch.rand((1, *resolution, 3)) - elif config.trainable_mask.initial_value == "psf": - initial_mask = psf.clone() - # if file ending with "npy" - elif config.trainable_mask.initial_value.endswith("npy"): - pattern = np.load( - os.path.join(get_original_cwd(), config.trainable_mask.initial_value) - ) - - initial_mask = full2subpattern( - pattern=pattern, - shape=config.trainable_mask.ap_shape, - center=config.trainable_mask.ap_center, - slm=config.trainable_mask.slm, - ) - initial_mask = torch.from_numpy(initial_mask.astype(np.float32)) - - # prepare color filter if needed - from waveprop.devices import slm_dict - from waveprop.devices import SLMParam as SLMParam_wp - - slm_param = slm_dict[config.trainable_mask.slm] - if ( - config.trainable_mask.train_color_filter - and SLMParam_wp.COLOR_FILTER in slm_param.keys() - ): - color_filter = slm_param[SLMParam_wp.COLOR_FILTER] - color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32) - - # add small random values - color_filter = color_filter + 0.1 * torch.rand_like(color_filter) - - else: - raise ValueError( - f"Initial PSF value {config.trainable_mask.initial_value} not supported" - ) - - if config.trainable_mask.grayscale and not is_grayscale(initial_mask): - initial_mask = rgb2gray(initial_mask) - - mask = mask_class( - initial_mask, - downsample=downsample, - color_filter=color_filter, - **config.trainable_mask, - ) - - return mask - - @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config):