Skip to content

Commit

Permalink
update changelog.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Nov 18, 2023
1 parent 49f7a9f commit 5ef7015
Show file tree
Hide file tree
Showing 7 changed files with 56 additions and 49 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ Added
- Trainable reconstruction can return intermediate outputs (between pre- and post-processing).
- Auto-download for DRUNet model.
- ``utils.dataset.DiffuserCamMirflickr`` helper class for Mirflickr dataset.
- Option to crop section of image for computing loss when training unrolled.
- Option to learn color filter of RGB mask.
- Trainable mask for Adafruit LCD.
- Utility for capture image.
- Option to freeze/unfreeze/add pre- and post-processor components during training.
- Option to skip unrolled training and just use U-Net.
- Dataset objects for Adafruit LCD: measured CelebA and hardware-in-the-loop.

Changed
~~~~~~~
Expand All @@ -29,6 +36,7 @@ Bugfix
- Support for unrolled reconstruction with grayscale, needed to copy to three channels for LPIPS.
- Fix bad train/test split for DiffuserCamMirflickr in unrolled training.
- Resize utility.
- Aperture, index to dimension conversion.


1.0.5 - (2023-09-05)
Expand Down
2 changes: 1 addition & 1 deletion configs/fine-tune_PSF.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ defaults:

#Trainable Mask
trainable_mask:
mask_type: TrainablePSF #Null or "TrainablePSF"
mask_type: TrainablePSF
initial_value: psf
mask_lr: 1e-3
L1_strength: 1.0 #False or float
Expand Down
15 changes: 12 additions & 3 deletions configs/train_psf_from_scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,21 @@ files:

#Trainable Mask
trainable_mask:
mask_type: TrainablePSF #Null or "TrainablePSF"
initial_value: "random"
mask_type: TrainablePSF
initial_value: random

simulation:
grayscale: False

flip: False
scene2mask: 40e-2
mask2sensor: 2e-3
sensor: "rpi_hq"
downsample: 16
object_height: 0.30

training:
crop_preloss: False # crop region for computing loss
batch_size: 8
epoch: 25
eval_batch_size: 16
save_every: 5
2 changes: 1 addition & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ reconstruction:

#Trainable Mask
trainable_mask:
mask_type: Null #Null or "TrainablePSF" or "AdafruitLCD"
mask_type: null #Null or "TrainablePSF" or "AdafruitLCD"
# "random" (with shape of config.files.psf) or "psf" (using config.files.psf)
initial_value: psf
grayscale: False
Expand Down
18 changes: 0 additions & 18 deletions lensless/hardware/slm.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,6 @@ def get_programmable_mask(

color_filter_idx = i // n_active_slm_pixels[1] % n_color_filter

# if color_filter is not None:
# _rect = np.tile(color_filter[color_filter_idx][0][:, np.newaxis, np.newaxis], (1, _height_pixel, _width_pixel))
# else:
# _rect = np.ones((1, _height_pixel, _width_pixel))

# if use_torch:
# _rect = torch.tensor(_rect).to(slm_vals_flat)

# import pudb; pudb.set_trace()

mask_val = slm_vals_flat[i] * color_filter[color_filter_idx][0]
if isinstance(mask_val, np.ndarray):
mask_val = mask_val[:, np.newaxis, np.newaxis]
Expand All @@ -199,14 +189,6 @@ def get_programmable_mask(
_center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel,
] = mask_val

# mask[
# :,
# _center_top_left_pixel[0] : _center_top_left_pixel[0] + _height_pixel,
# _center_top_left_pixel[1] : _center_top_left_pixel[1] + _width_pixel,
# ] = (
# slm_vals_flat[i] * _rect
# )

# # quantize mask
# if use_torch:
# mask = mask / torch.max(mask)
Expand Down
2 changes: 1 addition & 1 deletion lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, **kwargs):
"""
super().__init__()
self._mask = torch.nn.Parameter(initial_mask)
self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr)
self._optimizer = getattr(torch.optim, optimizer)([self._mask], lr=lr, **kwargs)
self.train_mask_vals = True
self._counter = 0

Expand Down
58 changes: 33 additions & 25 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from torch.utils.data import Subset
import lensless.hardware.trainable_mask
from lensless.hardware.slm import full2subpattern
from lensless.hardware.sensor import VirtualSensor
from lensless.recon.utils import create_process_network
from lensless.utils.image import rgb2gray, is_grayscale
from lensless.utils.simulation import FarFieldSimulator
Expand All @@ -72,23 +73,31 @@ def simulate_dataset(config, generator=None):
else:
device = "cpu"

# prepare PSF
psf_fp = os.path.join(get_original_cwd(), config.files.psf)
psf, _ = load_psf(
psf_fp,
downsample=config.files.downsample,
return_float=True,
return_bg=True,
bg_pix=(0, 15),
)
if config.files.diffusercam_psf:
transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])
psf = transform_BRG2RGB(torch.from_numpy(psf))
# -- prepare PSF
psf = None
if config.trainable_mask.mask_type is None or config.trainable_mask.initial_value == "psf":
psf_fp = os.path.join(get_original_cwd(), config.files.psf)
psf, _ = load_psf(
psf_fp,
downsample=config.files.downsample,
return_float=True,
return_bg=True,
bg_pix=(0, 15),
)
if config.files.diffusercam_psf:
transform_BRG2RGB = transforms.Lambda(lambda x: x[..., [2, 1, 0]])
psf = transform_BRG2RGB(torch.from_numpy(psf))

# drop depth dimension
psf = psf.to(device)

# drop depth dimension
psf = psf.to(device)
else:
# training mask / PSF
# mask = prep_trainable_mask(config, psf, downsample=config.files.downsample)
mask = prep_trainable_mask(config, psf, downsample=config.simulation.downsample)
psf = mask.get_psf().to(device)

# load dataset
# -- load dataset
pre_transform = None
transforms_list = [transforms.ToTensor()]
data_path = os.path.join(get_original_cwd(), "data")
Expand Down Expand Up @@ -117,9 +126,6 @@ def simulate_dataset(config, generator=None):
assert os.path.isdir(
data_path
), f"Data path {data_path} does not exist. Make sure you download the CelebA dataset and provide the parent directory as 'config.files.celeba_root'. Download link: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html"

# add rotate by 90 degrees to transform list
# pre_transform = transforms.RandomRotation(degrees=(-90, -90))
transform = transforms.Compose(transforms_list)
if config.files.n_files is None:
train_ds = datasets.CelebA(
Expand All @@ -143,11 +149,6 @@ def simulate_dataset(config, generator=None):
if config.simulation.grayscale and not is_grayscale(psf):
psf = rgb2gray(psf)

# prepare mask
# mask = prep_trainable_mask(config, psf, grayscale=config.simulation.grayscale)
mask = prep_trainable_mask(config, psf, downsample=config.files.downsample)
psf = mask.get_psf().to(device)

# check if gpu is available
device_conv = config.torch_device
if device_conv == "cuda" and torch.cuda.is_available():
Expand All @@ -162,6 +163,8 @@ def simulate_dataset(config, generator=None):
**config.simulation,
)

# import pudb; pudb.set_trace()

# create Pytorch dataset and dataloader
crop = config.files.crop.copy() if config.files.crop is not None else None
if mask is None:
Expand Down Expand Up @@ -264,14 +267,19 @@ def simulate_dataset(config, generator=None):
return train_ds_prop, test_ds_prop, mask


def prep_trainable_mask(config, psf, downsample=None):
def prep_trainable_mask(config, psf=None, downsample=None):
mask = None
color_filter = None
if config.trainable_mask.mask_type is not None:
mask_class = getattr(lensless.hardware.trainable_mask, config.trainable_mask.mask_type)

if config.trainable_mask.initial_value == "random":
initial_mask = torch.rand_like(psf)
if psf is not None:
initial_mask = torch.rand_like(psf)
else:
sensor = VirtualSensor.from_name(config.simulation.sensor, downsample=downsample)
resolution = sensor.resolution
initial_mask = torch.rand((1, *resolution, 3))
elif config.trainable_mask.initial_value == "psf":
initial_mask = psf.clone()
# if file ending with "npy"
Expand Down

0 comments on commit 5ef7015

Please sign in to comment.