Add support to train adafruit mask.
ebezzam committed Oct 10, 2023
1 parent 653ae56 commit 5892fba
Showing 7 changed files with 352 additions and 80 deletions.
44 changes: 44 additions & 0 deletions configs/train_celeba_digicam.yaml
@@ -0,0 +1,44 @@
# python scripts/recon/train_unrolled.py -cn train_celeba_digicam
defaults:
- train_unrolledADMM
- _self_

# Train Dataset
files:
dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
celeba_root: /scratch/bezzam
psf: data/psf/adafruit_random_2mm_20231907.png

# for prepping ground truth data
simulation:
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
object_height: 0.33 # [m]


reconstruction:
method: unrolled_admm
unrolled_admm:
# Number of iterations
n_iter: 10

pre_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
post_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet


# see some outputs of classical ADMM before training
test_idx: [0, 1, 2, 3, 4]

# Training
training:
batch_size: 2
epoch: 50
eval_batch_size: 15

# crop: null
crop_preloss: True
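
For reference, a minimal sketch (not part of the commit) of inspecting the composed config with Hydra's compose API, instead of the command in the header comment above. It assumes a recent Hydra (>= 1.2) and that the snippet is run from a script at the repository root so the relative ``configs`` path resolves.

```python
# Minimal sketch: compose the config and apply command-line-style overrides.
from hydra import compose, initialize

with initialize(version_base=None, config_path="configs"):
    cfg = compose(
        config_name="train_celeba_digicam",
        overrides=["training.epoch=10", "reconstruction.unrolled_admm.n_iter=5"],
    )

print(cfg.files.dataset)        # /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
print(cfg.training.batch_size)  # 2
print(cfg.reconstruction.unrolled_admm.n_iter)  # 5 (overridden above)
```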

64 changes: 64 additions & 0 deletions configs/train_celeba_digicam_mask.yaml
@@ -0,0 +1,64 @@
# python scripts/recon/train_unrolled.py -cn train_celeba_digicam_mask
defaults:
- train_celeba_digicam
- _self_

# Train Dataset
files:
dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
celeba_root: /scratch/bezzam
psf: data/psf/adafruit_random_2mm_20231907.png

# for prepping ground truth data
simulation:
scene2mask: 0.25 # [m]
mask2sensor: 0.002 # [m]
object_height: 0.33 # [m]


reconstruction:
method: unrolled_admm
unrolled_admm:
# Number of iterations
n_iter: 10

pre_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
post_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet

# Training
training:
batch_size: 2
epoch: 50
eval_batch_size: 15

# crop: null
crop_preloss: True

# Trainable Mask
trainable_mask:
mask_type: AdafruitLCD #Null or "TrainablePSF" or "AdafruitLCD"
# "random" (with shape of config.files.psf) or path to npy file
initial_value: data/psf/adafruit_random_pattern_20230719.npy
grayscale: False
mask_lr: 1e-3
L1_strength: False

# only for AdafruitLCD
ap_center: [59, 76]
ap_shape: [19, 26]
rotate: -0.8 # rotation in degrees
slm: adafruit
sensor: rpi_hq
flipud: True
waveprop: True
# to align with measured PSF (so reconstruction also aligned)
vertical_shift: -20 # [px]
horizontal_shift: -100 # [px]
# below are ignored if waveprop=False
scene2mask: 0.3 # [m]
mask2sensor: 0.002 # [m]
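
A minimal sketch of how the ``trainable_mask`` settings above could map onto the new ``AdafruitLCD`` class (added in ``lensless/hardware/trainable_mask.py`` below). The actual wiring is done by the training script; the ``waveprop`` to ``use_waveprop`` and ``mask_lr`` to ``lr`` renamings, the downsampling factor, and the handling of the aperture settings (``ap_center``/``ap_shape``, left out here) are assumptions.

```python
# Minimal sketch (not the training script) of building the trainable mask from the config above.
import numpy as np
import torch

from lensless.hardware.trainable_mask import AdafruitLCD

# initial LCD pattern referenced by `initial_value` (assumed to be available locally);
# any rescaling of the stored values to [0, 1] is assumed to be handled upstream if needed
pattern = np.load("data/psf/adafruit_random_pattern_20230719.npy")

mask = AdafruitLCD(
    torch.from_numpy(pattern.astype(np.float32)),  # initial_vals
    sensor="rpi_hq",
    slm="adafruit",
    rotate=-0.8,
    flipud=True,
    use_waveprop=True,       # `waveprop` in the YAML (renaming assumed)
    vertical_shift=-20,
    horizontal_shift=-100,
    scene2mask=0.3,
    mask2sensor=0.002,
    downsample=8,            # hypothetical; should match the dataset downsampling
    optimizer="Adam",
    lr=1e-3,                 # `mask_lr` in the YAML (renaming assumed)
)

psf = mask.get_psf()  # PSF simulated from the current mask values
```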

12 changes: 7 additions & 5 deletions configs/train_unrolledADMM.yaml
@@ -19,6 +19,8 @@ files:
torch: True
torch_device: 'cuda'

# see some outputs of classical ADMM before training
test_idx: [0, 1, 2, 3, 4]

# test set example to visualize at the end of every epoch
eval_disp_idx: [0, 1, 2, 3, 4]
@@ -59,7 +61,7 @@ reconstruction:

#Trainable Mask
trainable_mask:
mask_type: Null #Null or "TrainablePSF"
mask_type: Null #Null or "TrainablePSF" or "AdafruitLCD"
# "random" (with shape of config.files.psf) or "psf" (using config.files.psf)
initial_value: psf
grayscale: False
@@ -106,10 +108,10 @@ training:
skip_NAN: True
slow_start: False #float how much to reduce lr for first epoch

crop: null # crop region for computing loss
# crop:
# vertical: [30, 560]
# horizontal: [275, 710]
crop_preloss: True # crop region for computing loss
crop: null
# vertical: null
# horizontal: null

optimizer:
type: Adam
98 changes: 96 additions & 2 deletions lensless/hardware/trainable_mask.py
@@ -9,6 +9,9 @@
import abc
import torch
from lensless.utils.image import is_grayscale
from lensless.hardware.slm import get_programmable_mask, get_intensity_psf
from lensless.hardware.sensor import VirtualSensor
from waveprop.devices import slm_dict


class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
@@ -37,7 +40,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
"""
super().__init__()
self._mask = torch.nn.Parameter(initial_mask)
self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs)
self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr)
self._counter = 0

@abc.abstractmethod
@@ -53,7 +56,7 @@ def get_psf(self):
raise NotImplementedError

def update_mask(self):
"""Update the mask parameters. Acoording to externaly updated gradiants."""
"""Update the mask parameters. According to externaly updated gradiants."""
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self.project()
@@ -100,3 +103,94 @@ def get_psf(self):

def project(self):
self._mask.data = torch.clamp(self._mask, 0, 1)


class AdafruitLCD(TrainableMask):
def __init__(
self,
initial_vals,
sensor,
slm,
rotate=None,
flipud=False,
use_waveprop=None,
vertical_shift=None,
horizontal_shift=None,
scene2mask=None,
mask2sensor=None,
downsample=None,
**kwargs
):
"""
Parameters
----------
initial_vals : :py:class:`~torch.Tensor`
Initial mask parameters.
sensor : :py:class:`~lensless.hardware.sensor.VirtualSensor`
Sensor object.
slm_param : :py:class:`~lensless.hardware.slm.SLMParam`
SLM parameters.
rotate : float, optional
Rotation angle in degrees, by default None
flipud : bool, optional
Whether to flip the mask vertically, by default False
"""
super().__init__(initial_vals, **kwargs)

self.slm_param = slm_dict[slm]
self.sensor = VirtualSensor.from_name(sensor, downsample=downsample)
self.rotate = rotate
self.flipud = flipud
self.use_waveprop = use_waveprop
self.scene2mask = scene2mask
self.mask2sensor = mask2sensor
self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
if downsample is not None and vertical_shift is not None:
self.vertical_shift = vertical_shift // downsample
if downsample is not None and horizontal_shift is not None:
self.horizontal_shift = horizontal_shift // downsample
if self.use_waveprop:
assert self.scene2mask is not None
assert self.mask2sensor is not None

def get_psf(self):

mask = get_programmable_mask(
vals=self._mask,
sensor=self.sensor,
slm_param=self.slm_param,
rotate=self.rotate,
flipud=self.flipud,
)

if self.vertical_shift is not None:
mask = torch.roll(mask, self.vertical_shift, dims=1)

if self.horizontal_shift is not None:
mask = torch.roll(mask, self.horizontal_shift, dims=2)

psf_in = get_intensity_psf(
mask=mask,
sensor=self.sensor,
waveprop=self.use_waveprop,
scene2mask=self.scene2mask,
mask2sensor=self.mask2sensor,
)

# add first dimension (depth)
psf_in = psf_in.unsqueeze(0)

# move channels to last dimension
psf_in = psf_in.permute(0, 2, 3, 1)

# flip PSF along height and width
psf_in = torch.flip(psf_in, dims=[-3, -2])

# normalize
psf_in = psf_in / psf_in.norm()

return psf_in

def project(self):
self._mask.data = torch.clamp(self._mask, 0, 1)
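
A condensed sketch of how a trainable mask enters one training step, mirroring the ``train_epoch`` changes in ``lensless/recon/utils.py`` below; ``recon``, ``X``, ``lensed``, ``loss_fn``, ``device``, and ``mask`` are assumed to exist (e.g. ``mask`` built as in the config sketch above), so this is a fragment rather than a runnable script.

```python
# Sketch of one training iteration with a trainable mask
# (`recon`, `X`, `lensed`, `loss_fn`, `device`, `mask` assumed to be defined).
psf = mask.get_psf().to(device)          # re-simulate the PSF from the current mask values
recon._set_psf(psf)                      # update the unrolled reconstruction with the new PSF

y_pred = recon.batch_call(X.to(device))  # forward pass on the lensless measurements
loss = loss_fn(y_pred, lensed.to(device))
loss.backward()

mask.update_mask()                       # optimizer step on the mask + clamp values to [0, 1]
```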
33 changes: 14 additions & 19 deletions lensless/recon/utils.py
@@ -373,20 +373,7 @@ def __init__(
"lpips package is need for LPIPS loss. Install using : pip install lpips"
)

if crop is not None:
datashape = train_dataset[0][0].shape
# create binary mask to multiply with before computing loss
self.mask_crop = torch.zeros(datashape, dtype=torch.bool).to(self.device)

# move channel dimension to third to last
self.mask_crop = self.mask_crop.movedim(-1, -3)

# set values to True in mask
self.mask_crop[
:, :, crop.vertical[0] : crop.vertical[1], crop.horizontal[0] : crop.horizontal[1]
] = True
else:
self.mask_crop = None
self.crop = crop

# optimizer
if optimizer == "Adam":
@@ -484,7 +471,7 @@ def train_epoch(self, data_loader):

# update psf according to mask
if self.use_mask:
self.recon._set_psf(self.mask.get_psf())
self.recon._set_psf(self.mask.get_psf().to(self.device))

# forward pass
y_pred = self.recon.batch_call(X.to(self.device))
@@ -503,9 +490,17 @@
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# crop
if self.mask_crop is not None:
y_pred = y_pred * self.mask_crop
y = y * self.mask_crop
if self.crop is not None:
y_pred = y_pred[
...,
self.crop["vertical"][0] : self.crop["vertical"][1],
self.crop["horizontal"][0] : self.crop["horizontal"][1],
]
y = y[
...,
self.crop["vertical"][0] : self.crop["vertical"][1],
self.crop["horizontal"][0] : self.crop["horizontal"][1],
]

loss_v = self.Loss(y_pred, y)
if self.lpips:
@@ -583,7 +578,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
batchsize=self.eval_batch_size,
save_idx=disp,
output_dir=output_dir,
mask_crop=self.mask_crop,
crop=self.crop,
)

# update metrics with current metrics
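
A small self-contained sketch of the new crop-before-loss behaviour above, using the crop region that the CelebA DigiCam dataset defaults to (tensor sizes here are hypothetical):

```python
import torch

# Crop region (vertical/horizontal pixel ranges) applied to both prediction and target
# before computing the loss; values match the dataset defaults in lensless/utils/dataset.py.
crop = {"vertical": [30, 560], "horizontal": [285, 720]}

y_pred = torch.rand(4, 3, 600, 800)  # (batch, channels, height, width), hypothetical sizes
y = torch.rand(4, 3, 600, 800)

y_pred_crop = y_pred[
    ..., crop["vertical"][0] : crop["vertical"][1], crop["horizontal"][0] : crop["horizontal"][1]
]
y_crop = y[
    ..., crop["vertical"][0] : crop["vertical"][1], crop["horizontal"][0] : crop["horizontal"][1]
]

loss = torch.nn.functional.mse_loss(y_pred_crop, y_crop)  # loss only on the region of interest
print(y_pred_crop.shape)  # torch.Size([4, 3, 530, 435])
```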
36 changes: 30 additions & 6 deletions lensless/utils/dataset.py
@@ -355,13 +355,17 @@ def __init__(
psf_path=None,
downsample=1,
flip=True,
vertical_shift=-85,
horizontal_shift=-15,
vertical_shift=None,
horizontal_shift=None,
crop=None,
simulation_config=None,
**kwargs,
):
"""
Some parameter defaults are set to work with the ``celeba_adafruit_random_2mm_20230720_10K`` dataset, namely ``flip``, ``vertical_shift``, ``horizontal_shift``, ``crop``, and ``simulation_config``.

Parameters
----------
celeba_root : str
@@ -375,11 +379,33 @@
flip : bool, optional
If True, measurements are flipped, by default ``True``. Does not get applied to the original images.
vertical_shift : int, optional
Vertical shift (in pixels) of the lensed images to align, by default 0.
Vertical shift (in pixels) of the lensed images to align.
horizontal_shift : int, optional
Horizontal shift (in pixels) of the lensed images to align, by default 0.
Horizontal shift (in pixels) of the lensed images to align.
crop : dict, optional
Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]) to select region of interest.
"""

if vertical_shift is None:
# defaults (at full resolution, i.e. no downsampling) for celeba_adafruit_random_2mm_20230720_10K
vertical_shift = -85
horizontal_shift = -5

if crop is None:
crop = {"vertical": [30, 560], "horizontal": [285, 720]}
self.crop = crop

self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
if downsample != 1:
self.vertical_shift = int(self.vertical_shift // downsample)
self.horizontal_shift = int(self.horizontal_shift // downsample)

self.crop["vertical"][0] = int(self.crop["vertical"][0] // downsample)
self.crop["vertical"][1] = int(self.crop["vertical"][1] // downsample)
self.crop["horizontal"][0] = int(self.crop["horizontal"][0] // downsample)
self.crop["horizontal"][1] = int(self.crop["horizontal"][1] // downsample)

# download dataset if necessary
if data_dir is None:
data_dir = os.path.join(
@@ -432,8 +458,6 @@

# load PSF
self.flip_measurement = flip
self.vertical_shift = vertical_shift
self.horizontal_shift = horizontal_shift
psf, background = load_psf(
psf_path,
downsample=downsample * 4, # PSF is 4x the resolution of the images
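
A standalone sketch of how the default alignment parameters above are rescaled when the dataset is downsampled, mirroring the constructor logic; note that Python's floor division pushes the negative shifts toward more negative values.

```python
# Sketch of the alignment-parameter rescaling in the dataset constructor above.
def rescale_alignment(downsample=1):
    # defaults for celeba_adafruit_random_2mm_20230720_10K at full resolution
    vertical_shift, horizontal_shift = -85, -5
    crop = {"vertical": [30, 560], "horizontal": [285, 720]}
    if downsample != 1:
        vertical_shift = int(vertical_shift // downsample)
        horizontal_shift = int(horizontal_shift // downsample)
        crop = {
            key: [int(val[0] // downsample), int(val[1] // downsample)]
            for key, val in crop.items()
        }
    return vertical_shift, horizontal_shift, crop


print(rescale_alignment(downsample=2))
# (-43, -3, {'vertical': [15, 280], 'horizontal': [142, 360]})
```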