Move prep trainable mask into package.
ebezzam committed Dec 18, 2023
1 parent 29f0126 commit d2cc322
Showing 2 changed files with 110 additions and 91 deletions.
114 changes: 109 additions & 5 deletions lensless/hardware/trainable_mask.py
@@ -7,9 +7,13 @@
# #############################################################################

import abc
import omegaconf
import os
import numpy as np
from hydra.utils import get_original_cwd
import torch
from lensless.utils.image import is_grayscale
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from lensless.utils.image import is_grayscale, rgb2gray
from lensless.hardware.slm import full2subpattern, get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict
from lensless.hardware.mask import CodedAperture
@@ -144,7 +148,7 @@ def __init__(
mask2sensor=None,
downsample=None,
min_val=0,
**kwargs
**kwargs,
):
"""
Parameters
Expand Down Expand Up @@ -264,7 +268,7 @@ def __init__(
torch_device="cuda",
optimizer="Adam",
lr=1e-3,
**kwargs
**kwargs,
):
"""
TODO: Distinguish between separable and non-separable.
@@ -284,7 +288,7 @@ def __init__(
psf_wavelength=[460e-9],
is_torch=True,
torch_device=torch_device,
**kwargs
**kwargs,
)
self._mask = self._mask_obj.mask

@@ -341,3 +345,103 @@ def project(self):
self._mask_obj.compute_psf()
self._psf = self._mask_obj.psf.unsqueeze(0)
self._psf = self._psf / self._psf.norm()


"""
Utilities to prepare trainable masks.
"""

trainable_mask_dict = {
"AdafruitLCD": AdafruitLCD,
"TrainablePSF": TrainablePSF,
"TrainableCodedAperture": TrainableCodedAperture,
"TrainableHeightVarying": None,
"TrainableMultiLensArray": None,
}


def prep_trainable_mask(config, psf=None, downsample=None):
mask = None
color_filter = None
downsample = config.files.downsample if downsample is None else downsample
if config.trainable_mask.mask_type is not None:

assert config.trainable_mask.mask_type in trainable_mask_dict.keys(), (
f"Trainable mask type {config.trainable_mask.mask_type} not supported. "
f"Supported types are {trainable_mask_dict.keys()}"
)
mask_class = trainable_mask_dict[config.trainable_mask.mask_type]

if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig):

# from mask config
mask = mask_class(
# mask = TrainableCodedAperture(
sensor_name=config.simulation.sensor,
downsample=downsample,
distance_sensor=config.simulation.mask2sensor,
optimizer=config.trainable_mask.optimizer,
lr=config.trainable_mask.mask_lr,
binary=config.trainable_mask.binary,
torch_device=config.torch_device,
**config.trainable_mask.initial_value,
)

else:

if config.trainable_mask.initial_value == "random":
if psf is not None:
initial_mask = torch.rand_like(psf)
else:
sensor = VirtualSensor.from_name(
config.simulation.sensor, downsample=downsample
)
resolution = sensor.resolution
initial_mask = torch.rand((1, *resolution, 3))
elif config.trainable_mask.initial_value == "psf":
initial_mask = psf.clone()
# if file ending with "npy"
elif config.trainable_mask.initial_value.endswith("npy"):
pattern = np.load(
os.path.join(get_original_cwd(), config.trainable_mask.initial_value)
)

initial_mask = full2subpattern(
pattern=pattern,
shape=config.trainable_mask.ap_shape,
center=config.trainable_mask.ap_center,
slm=config.trainable_mask.slm,
)
initial_mask = torch.from_numpy(initial_mask.astype(np.float32))

# prepare color filter if needed
from waveprop.devices import slm_dict
from waveprop.devices import SLMParam as SLMParam_wp

slm_param = slm_dict[config.trainable_mask.slm]
if (
config.trainable_mask.train_color_filter
and SLMParam_wp.COLOR_FILTER in slm_param.keys()
):
color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)

# add small random values
color_filter = color_filter + 0.1 * torch.rand_like(color_filter)

else:
raise ValueError(
f"Initial PSF value {config.trainable_mask.initial_value} not supported"
)

if config.trainable_mask.grayscale and not is_grayscale(initial_mask):
initial_mask = rgb2gray(initial_mask)

mask = mask_class(
initial_mask,
downsample=downsample,
color_filter=color_filter,
**config.trainable_mask,
)

return mask
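
For orientation, here is a minimal sketch of how a training script could now build a trainable mask through the packaged helper. The call signature (prep_trainable_mask(config, psf=None, downsample=None)) and the config fields mirror what the diff above reads; the concrete values, sensor name, and PSF shape are placeholders, and the exact schema expected by a given mask class depends on the repository's Hydra configs, so this is illustrative rather than a guaranteed-runnable recipe.

import torch
from omegaconf import OmegaConf
from lensless.hardware.trainable_mask import prep_trainable_mask

# Placeholder config mirroring the fields accessed by prep_trainable_mask;
# real experiments load this from the Hydra configs under configs/.
config = OmegaConf.create(
    {
        "torch_device": "cpu",
        "files": {"downsample": 4},
        "simulation": {"sensor": "rpi_hq", "mask2sensor": 4e-3},
        "trainable_mask": {
            "mask_type": "TrainablePSF",
            "initial_value": "psf",  # start optimization from the measured PSF
            "grayscale": False,
            "optimizer": "Adam",
            "mask_lr": 1e-3,
        },
    }
)

psf = torch.rand((1, 96, 128, 3))  # placeholder measured PSF (depth, height, width, color)
mask = prep_trainable_mask(config, psf=psf)

Compared with the previous script-local version, the mask class is resolved through the explicit trainable_mask_dict registry (with an assertion listing the supported types) rather than a getattr lookup on the module, and the helper can now be reused by other training scripts.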
87 changes: 1 addition & 86 deletions scripts/recon/train_unrolled.py
@@ -33,13 +33,13 @@
"""

import logging
import omegaconf
import hydra
from hydra.utils import get_original_cwd
import os
import numpy as np
import time
from lensless import UnrolledFISTA, UnrolledADMM
from lensless.hardware.trainable_mask import prep_trainable_mask
from lensless.utils.dataset import (
DiffuserCamMirflickr,
SimulatedFarFieldDataset,
Expand All @@ -48,9 +48,6 @@
HITLDatasetTrainableMask,
)
from torch.utils.data import Subset
import lensless.hardware.trainable_mask
from lensless.hardware.slm import full2subpattern
from lensless.hardware.sensor import VirtualSensor
from lensless.recon.utils import create_process_network
from lensless.utils.image import rgb2gray, is_grayscale
from lensless.utils.simulation import FarFieldSimulator
@@ -277,88 +274,6 @@ def simulate_dataset(config, generator=None):
return train_ds_prop, test_ds_prop, mask


def prep_trainable_mask(config, psf=None, downsample=None):
mask = None
color_filter = None
downsample = config.files.downsample if downsample is None else downsample
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

if isinstance(config.trainable_mask.initial_value, omegaconf.dictconfig.DictConfig):

# from mask config
mask = mask_class(
# mask = TrainableCodedAperture(
sensor_name=config.simulation.sensor,
downsample=downsample,
distance_sensor=config.simulation.mask2sensor,
optimizer=config.trainable_mask.optimizer,
lr=config.trainable_mask.mask_lr,
binary=config.trainable_mask.binary,
torch_device=config.torch_device,
**config.trainable_mask.initial_value,
)

else:

if config.trainable_mask.initial_value == "random":
if psf is not None:
initial_mask = torch.rand_like(psf)
else:
sensor = VirtualSensor.from_name(
config.simulation.sensor, downsample=downsample
)
resolution = sensor.resolution
initial_mask = torch.rand((1, *resolution, 3))
elif config.trainable_mask.initial_value == "psf":
initial_mask = psf.clone()
# if file ending with "npy"
elif config.trainable_mask.initial_value.endswith("npy"):
pattern = np.load(
os.path.join(get_original_cwd(), config.trainable_mask.initial_value)
)

initial_mask = full2subpattern(
pattern=pattern,
shape=config.trainable_mask.ap_shape,
center=config.trainable_mask.ap_center,
slm=config.trainable_mask.slm,
)
initial_mask = torch.from_numpy(initial_mask.astype(np.float32))

# prepare color filter if needed
from waveprop.devices import slm_dict
from waveprop.devices import SLMParam as SLMParam_wp

slm_param = slm_dict[config.trainable_mask.slm]
if (
config.trainable_mask.train_color_filter
and SLMParam_wp.COLOR_FILTER in slm_param.keys()
):
color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)

# add small random values
color_filter = color_filter + 0.1 * torch.rand_like(color_filter)

else:
raise ValueError(
f"Initial PSF value {config.trainable_mask.initial_value} not supported"
)

if config.trainable_mask.grayscale and not is_grayscale(initial_mask):
initial_mask = rgb2gray(initial_mask)

mask = mask_class(
initial_mask,
downsample=downsample,
color_filter=color_filter,
**config.trainable_mask,
)

return mask


@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def train_unrolled(config):


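For completeness, the script-side change boils down to deleting the local copy of prep_trainable_mask and importing it from the package instead, so existing call sites can stay unchanged. A call would then look something like the line below (hypothetical, since the bodies of simulate_dataset and train_unrolled are truncated in this diff):

mask = prep_trainable_mask(config, psf=psf, downsample=config.files.downsample)  # hypothetical call site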
