Commit

Add trainable inverse.

ebezzam committed Jan 29, 2024
1 parent 798d1d8 commit 96087a6

Showing 5 changed files with 176 additions and 2 deletions.
5 changes: 5 additions & 0 deletions configs/train_unrolledADMM.yaml
@@ -47,6 +47,9 @@ reconstruction:
# Method: unrolled_admm, unrolled_fista
method: unrolled_admm
skip_unrolled: False
init_processors: null # name of a pretrained model whose processors are used for initialization
init_pre: True # if `init_processors` is set, initialize the pre-processor from it (when available)
init_post: True # if `init_processors` is set, initialize the post-processor from it (when available)

# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
@@ -62,6 +65,8 @@ reconstruction:
mu2: 1e-4
mu3: 1e-4
tau: 2e-4
trainable_inv:
K: 1e-4
pre_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignored if network is DruNet
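For reference, the keys added above are meant to be used together with the new reconstruction method registered in scripts/recon/train_unrolled.py (see the last file in this commit). A minimal sketch of the relevant part of the config, showing only keys touched by this commit with illustrative values:

    reconstruction:
      method: trainable_inv   # new method handled in scripts/recon/train_unrolled.py
      init_processors: null   # or the name of a pretrained model to copy processors from
      init_pre: True
      init_post: True
      trainable_inv:
        K: 1e-4               # regularization of the inverse filter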
1 change: 1 addition & 0 deletions lensless/__init__.py
@@ -28,6 +28,7 @@
from .recon.trainable_recon import TrainableReconstructionAlgorithm
from .recon.unrolled_admm import UnrolledADMM
from .recon.unrolled_fista import UnrolledFISTA
from .recon.trainable_inversion import TrainableInversion
except Exception:
pass

84 changes: 83 additions & 1 deletion lensless/hardware/trainable_mask.py
@@ -8,10 +8,13 @@

import abc
import torch
import numpy as np
from lensless.utils.image import is_grayscale
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict
from lensless.hardware.slm import full2subpattern
from lensless.utils.image import rgb2gray


class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
@@ -129,7 +132,7 @@ def __init__(
mask2sensor=None,
downsample=None,
min_val=0,
**kwargs
**kwargs,
):
"""
Parameters
@@ -230,3 +233,82 @@ def project(self):
self.color_filter.data = self.color_filter / self.color_filter.sum(
dim=[1, 2]
).unsqueeze(-1).unsqueeze(-1)


"""
Utility functions to help prepare trainable masks.
"""

mask_type_to_class = {
"TrainablePSF": TrainablePSF,
"AdafruitLCD": AdafruitLCD,
}


def prep_trainable_mask(config, psf=None, downsample=None):

mask = None
color_filter = None
downsample = config["files"]["downsample"] if downsample is None else downsample
if config["trainable_mask"]["mask_type"] is not None:
mask_class = mask_type_to_class[config["trainable_mask"]["mask_type"]]

if config["trainable_mask"]["initial_value"] == "random":
if psf is not None:
initial_mask = torch.rand_like(psf)
else:
sensor = VirtualSensor.from_name(
config["simulation"]["sensor"], downsample=downsample
)
resolution = sensor.resolution
initial_mask = torch.rand((1, *resolution, 3))

elif config["trainable_mask"]["initial_value"] == "psf":
initial_mask = psf.clone()

# if the initial value is a file ending with "npy"
elif config["trainable_mask"]["initial_value"].endswith("npy"):

pattern = np.load(config["trainable_mask"]["initial_value"])

initial_mask = full2subpattern(
pattern=pattern,
shape=config["trainable_mask"]["ap_shape"],
center=config["trainable_mask"]["ap_center"],
slm=config["trainable_mask"]["slm"],
)
initial_mask = torch.from_numpy(initial_mask.astype(np.float32))

# prepare color filter if needed
from waveprop.devices import slm_dict
from waveprop.devices import SLMParam as SLMParam_wp

slm_param = slm_dict[config["trainable_mask"]["slm"]]
if (
config["trainable_mask"]["train_color_filter"]
and SLMParam_wp.COLOR_FILTER in slm_param.keys()
):
color_filter = slm_param[SLMParam_wp.COLOR_FILTER]
color_filter = torch.from_numpy(color_filter.copy()).to(dtype=torch.float32)

# add small random values
color_filter = color_filter + 0.1 * torch.rand_like(color_filter)

else:
raise ValueError(
f"Initial PSF value {config['trainable_mask']['initial_value']} not supported"
)

# convert to grayscale if needed
if config["trainable_mask"]["grayscale"] and not is_grayscale(initial_mask):
initial_mask = rgb2gray(initial_mask)

mask = mask_class(
initial_mask,
optimizer="Adam",
downsample=downsample,
color_filter=color_filter,
**config["trainable_mask"],
)

return mask
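A minimal sketch of how prep_trainable_mask can be driven from a training configuration. The key names follow the accesses in the function above; the concrete values, PSF shape, and sensor name are illustrative assumptions, and extra config keys are simply forwarded to the mask class.

    import torch
    from lensless.hardware.trainable_mask import prep_trainable_mask

    # only the keys read by prep_trainable_mask are shown; real training configs contain more
    config = {
        "files": {"downsample": 8},
        "simulation": {"sensor": "rpi_hq"},   # only used for "random" initialization without a PSF
        "trainable_mask": {
            "mask_type": "TrainablePSF",      # or "AdafruitLCD"; None disables the trainable mask
            "initial_value": "psf",           # "random", "psf", or a path to an ".npy" SLM pattern
            "grayscale": False,
        },
    }

    psf = torch.rand((1, 60, 80, 3))              # (depth, height, width, channels)
    mask = prep_trainable_mask(config, psf=psf)   # returns None when mask_type is None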
50 changes: 50 additions & 0 deletions lensless/recon/trainable_inversion.py
@@ -0,0 +1,50 @@
# #############################################################################
# trainable_inversion.py
# =================
# Authors :
# Eric BEZZAM [[email protected]]
# #############################################################################

from lensless.recon.trainable_recon import TrainableReconstructionAlgorithm


class TrainableInversion(TrainableReconstructionAlgorithm):
""" """

def __init__(self, psf, dtype=None, K=1e-4, **kwargs):
"""
Constructor for trainable inversion component as proposed in
the FlatNet work: https://siddiquesalman.github.io/flatnet/

Parameters
----------
psf : :py:class:`~torch.Tensor`
Point spread function (PSF) that models forward propagation.
Must be of shape (depth, height, width, channels) even if
depth = 1 and channels = 1. You can use :py:func:`~lensless.io.load_psf`
to load a PSF from a file such that it is in the correct format.
dtype : float32 or float64
Data type to use for optimization.
K : float
Regularization parameter.
"""

super(TrainableInversion, self).__init__(psf, n_iter=1, dtype=dtype, reset=False, **kwargs)
self._convolver._Hadj = self._convolver._Hadj / (self._convolver._H.norm() ** 2 + K)

self.reset()

def _form_image(self):
self._image_est[self._image_est < 0] = 0
return self._image_est

def _set_psf(self, psf):
return super()._set_psf(psf)

def reset(self, batch_size=1):
# no state variables
return

def _update(self, iter):
self._image_est = self._convolver.deconvolve(self._data)
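For context, the rescaling of `_Hadj` in the constructor turns the single `_update` step into a regularized inverse filter. Assuming `deconvolve` multiplies the measurement's Fourier transform by the (rescaled) adjoint filter and transforms back, as in the package's other reconstruction algorithms, the estimate before the non-negativity clipping in `_form_image` is roughly

    \hat{x} = \mathcal{F}^{-1}\left\{ \frac{\overline{H}}{\lVert H \rVert^{2} + K} \, \mathcal{F}\{y\} \right\}

where H is the frequency response of the PSF and K the regularization constant from the config. Note that ‖H‖² is a single scalar (the squared norm of the whole filter), not the per-frequency |H(f)|² of a Wiener filter; refining the result is left to the trainable pre- and post-processors inherited from TrainableReconstructionAlgorithm.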
38 changes: 37 additions & 1 deletion scripts/recon/train_unrolled.py
@@ -38,7 +38,7 @@
import os
import numpy as np
import time
from lensless import UnrolledFISTA, UnrolledADMM
from lensless import UnrolledFISTA, UnrolledADMM, TrainableInversion
from lensless.utils.dataset import (
DiffuserCamMirflickr,
SimulatedFarFieldDataset,
@@ -540,6 +540,34 @@ def train_unrolled(config):
param.requires_grad = False
# print(name, param.requires_grad, param.numel())

# initialize pre- and post-processor with weights from another (pretrained) model
if config.reconstruction.init_processors is not None:
from lensless.recon.model_dict import load_model, model_dict

model_orig = load_model(
model_dict["diffusercam"]["mirflickr"][config.reconstruction.init_processors],
psf=psf,
device=device,
)

# -- replace pre-process
if config.reconstruction.init_pre:
params1 = model_orig.pre_process_model.named_parameters()
params2 = pre_process.named_parameters()
dict_params2 = dict(params2)
for name1, param1 in params1:
if name1 in dict_params2:
dict_params2[name1].data.copy_(param1.data)

# -- replace post-process
if config.reconstruction.init_post:
params1_post = model_orig.post_process_model.named_parameters()
params2_post = post_process.named_parameters()
dict_params2_post = dict(params2_post)
for name1, param1 in params1_post:
if name1 in dict_params2_post:
dict_params2_post[name1].data.copy_(param1.data)

# create reconstruction algorithm
if config.reconstruction.method == "unrolled_fista":
recon = UnrolledFISTA(
@@ -566,6 +594,14 @@
skip_unrolled=config.reconstruction.skip_unrolled,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
elif config.reconstruction.method == "trainable_inv":
recon = TrainableInversion(
psf,
K=config.reconstruction.trainable_inv.K,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
else:
raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")

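Putting the pieces together, the new component is selected like the unrolled algorithms. A sketch of a training launch, assuming the script is run with Hydra-style overrides as elsewhere in the repository; the pretrained-model name is a placeholder and would have to be a key of model_dict["diffusercam"]["mirflickr"]:

    python scripts/recon/train_unrolled.py \
        reconstruction.method=trainable_inv \
        reconstruction.trainable_inv.K=1e-4 \
        reconstruction.init_processors=<pretrained_model_name> \
        reconstruction.init_pre=True \
        reconstruction.init_post=True

With init_processors set, the pre- and post-processors of the new model are warm-started by copying the matching parameters from that pretrained model, as done in the snippet above.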