Skip to content

Commit

Permalink
Clean up training with simulated dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 20, 2023
1 parent 91923f9 commit eb0c84e
Show file tree
Hide file tree
Showing 12 changed files with 224 additions and 104 deletions.
1 change: 0 additions & 1 deletion configs/fine-tune_PSF.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ trainable_mask:
initial_value: psf
mask_lr: 1e-3
L1_strength: 1.0 #False or float
use_mask_in_dataset : False # Work only with simulated dataset

#Training
training:
Expand Down
13 changes: 9 additions & 4 deletions configs/train_psf_from_scratch.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ defaults:
- train_unrolledADMM
- _self_

# Train Dataset
files:
dataset: mnist # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: /scratch/bezzam
downsample: 8

#Trainable Mask
trainable_mask:
mask_type: TrainablePSF #Null or "TrainablePSF"
initial_value: "random" # "random" or "DiffuserCam" or "DiffuserCam_gray"
initial_value: "random"

# Train Dataset
files:
dataset: "CelebA" # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
simulation:
grayscale: False
8 changes: 4 additions & 4 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ hydra:
job:
chdir: True # change to output folder

# Train Dataset
# Dataset
files:
dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
psf: data/psf.tiff
diffusercam_psf: True
eval_dataset: data/DiffuserCam_Test
n_files: null # null to use all for both datasets
n_files: null # null to use all for both train/test
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution

torch: True
Expand Down Expand Up @@ -60,7 +60,6 @@ trainable_mask:
grayscale: False
mask_lr: 1e-3
L1_strength: 1.0 #False or float
use_mask_in_dataset : True # Work only with simulated dataset

target: "object_plane" # "original" or "object_plane" or "label"

Expand All @@ -86,6 +85,7 @@ simulation:
# Downsampling for PSF
downsample: 8
# max value in simulated measurement (quantized to 8 bits)
quantize: False # must be False for differentiability
max_val: 255

#Training
Expand Down
15 changes: 14 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,20 @@ def benchmark(model, dataset, batchsize=1, metrics=None, **kwargs):
if metric == "ReconstructionError":
metrics_values[metric] += model.reconstruction_error().cpu().item()
else:
metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item()
if "LPIPS" in metric:
if prediction.shape[1] == 1:
# LPIPS needs 3 channels
metrics_values[metric] += (
metrics[metric](
prediction.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
)
.cpu()
.item()
)
else:
metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item()
else:
metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item()

model.reset()

Expand Down
32 changes: 21 additions & 11 deletions lensless/hardware/trainable_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import abc
import torch
from lensless.utils.image import is_grayscale


class TrainableMask(torch.nn.Module, metaclass=abc.ABCMeta):
Expand Down Expand Up @@ -70,23 +71,32 @@ class TrainablePSF(TrainableMask):
Parameters
----------
is_rgb : bool, optional
Whether the mask is RGB or not, by default True.
grayscale : bool, optional
Whether mask should be returned as grayscale when calling :py:class:`~lensless.hardware.trainable_mask.TrainableMask.get_psf`.
Otherwise PSF will be returned as RGB. By default False.
"""

def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, is_rgb=True, **kwargs):
def __init__(self, initial_mask, optimizer="Adam", lr=1e-3, grayscale=False, **kwargs):
super().__init__(initial_mask, optimizer, lr, **kwargs)
self._is_rgb = is_rgb
if is_rgb:
assert initial_mask.shape[-1] == 3, "RGB mask should have 3 channels"
else:
assert initial_mask.shape[-1] == 1, "Monochrome mask should have 1 channel"
assert (
len(initial_mask.shape) == 4
), "Mask must be of shape (depth, height, width, channels)"
self.grayscale = grayscale
self._is_grayscale = is_grayscale(initial_mask)
if grayscale:
assert self._is_grayscale, "Mask must be grayscale"

def get_psf(self):
if self._is_rgb:
return self._mask
if self._is_grayscale:
if self.grayscale:
# simulation in grayscale
return self._mask
else:
# replicate to 3 channels
return self._mask.expand(-1, -1, -1, 3)
else:
return self._mask.expand(-1, -1, -1, 3)
# assume RGB
return self._mask

def project(self):
self._mask.data = torch.clamp(self._mask, 0, 1)
15 changes: 14 additions & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(
recon,
train_dataset,
test_dataset,
test_size=0.15,
mask=None,
batch_size=4,
loss="l2",
Expand Down Expand Up @@ -274,6 +275,8 @@ def __init__(
Dataset to use for training.
test_dataset : :py:class:`torch.utils.data.Dataset`
Dataset to use for testing.
test_size : float, optional
If test_dataset is None, fraction of the train dataset to use for testing, by default 0.15.
mask : TrainableMask, optional
Trainable mask to use for training. If None, training uses a fixed PSF. By default None.
batch_size : int, optional
Expand Down Expand Up @@ -307,13 +310,16 @@ def __init__(

self.recon = recon

assert train_dataset is not None
if test_dataset is None:
assert test_size < 1.0 and test_size > 0.0
# split train dataset
train_size = int(0.9 * len(train_dataset))
train_size = int((1 - test_size) * len(train_dataset))
test_size = len(train_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
train_dataset, [train_size, test_size]
)
print(f"Train size : {train_size}, Test size : {test_size}")

self.train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
Expand Down Expand Up @@ -475,6 +481,12 @@ def train_epoch(self, data_loader, disp=-1):

loss_v = self.Loss(y_pred, y)
if self.lpips:

if y_pred.shape[1] == 1:
# if only one channel, repeat for LPIPS
y_pred = y_pred.repeat(1, 3, 1, 1)
y = y.repeat(1, 3, 1, 1)

# value for LPIPS needs to be in range [-1, 1]
loss_v = loss_v + self.lpips * torch.mean(
self.Loss_lpips(2 * y_pred - 1, 2 * y - 1)
Expand Down Expand Up @@ -623,6 +635,7 @@ def save(self, epoch, path="recon", include_optimizer=False):
)

psf_np = self.mask.get_psf().detach().cpu().numpy()[0, ...]
psf_np = psf_np.squeeze() # remove (potential) singleton color channel
save_image(psf_np, os.path.join(path, f"psf_epoch{epoch}.png"))
plot_image(psf_np, gamma=self.gamma)
plt.savefig(os.path.join(path, f"psf_epoch{epoch}_plot.png"))
Expand Down
19 changes: 15 additions & 4 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ def __init__(
dataset_is_CHW : bool, optional
If True, the input dataset is expected to output images with shape [C, H, W], by default ``False``.
flip : bool, optional
If True, images are flipped beffore the simulation, by default ``False``..
If True, images are flipped before the simulation, by default ``False``.
"""

# we do the flipping before the simulation
Expand All @@ -192,6 +192,10 @@ def __init__(
assert simulator.fft_shape is not None, "Simulator should have a psf"
self.sim = simulator

@property
def psf(self):
return self.sim.get_psf()

def get_image(self, index):
return self.dataset[index]

Expand All @@ -206,7 +210,14 @@ def _get_images_pair(self, index):
if self._pre_transform is not None:
img = self._pre_transform(img)

lensless, lensed = self.sim.propagate(img, return_object_plane=True)
lensless, lensed = self.sim.propagate_image(img, return_object_plane=True)

if lensed.shape[-1] == 1 and lensless.shape[-1] == 3:
# copy to 3 channels
lensed = lensed.repeat(1, 1, 3)
assert (
lensed.shape[-1] == lensless.shape[-1]
), "Lensed and lensless should have same number of channels"

return lensless, lensed

Expand Down Expand Up @@ -319,7 +330,7 @@ def _get_images_pair(self, idx):

# project original image to lensed space
with torch.no_grad():
lensed = self.sim.propagate()
lensed = self.sim.propagate_image()

return lensless, lensed

Expand Down Expand Up @@ -581,7 +592,7 @@ def __init__(
def _get_images_pair(self, index):
# update psf
psf = self._mask.get_psf()
self.sim.set_psf(psf)
self.sim.set_point_spread_function(psf)

# return simulated images
return super()._get_images_pair(index)
17 changes: 17 additions & 0 deletions lensless/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,23 @@ def resize(img, factor=None, shape=None, interpolation=cv2.INTER_CUBIC):
return np.clip(resized, min_val, max_val)


def is_grayscale(img):
    """
    Check if image is grayscale. Assuming image is of shape ([depth,] height, width, color).

    Parameters
    ----------
    img : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
        Image array.

    Returns
    -------
    bool
        Whether image is grayscale, i.e. whether its last (color) dimension has size 1.
    """
    # only the size of the trailing dimension is inspected; pixel values are not checked
    return img.shape[-1] == 1


def rgb2gray(rgb, weights=None, keepchanneldim=True):
"""
Convert RGB array to grayscale.
Expand Down
58 changes: 44 additions & 14 deletions lensless/utils/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Eric BEZZAM [[email protected]]
# #############################################################################

import numpy as np
from waveprop.simulation import FarFieldSimulator as FarFieldSimulator_wp
import torch


class FarFieldSimulator(FarFieldSimulator_wp):
Expand All @@ -34,7 +34,7 @@ def __init__(
"""
Parameters
----------
psf : np.ndarray, optional.
psf : np.ndarray or torch.Tensor, optional.
Point spread function. If not provided, return image at object plane.
object_height : float or (float, float)
Height of object in meters. Or range of values to randomly sample from.
Expand All @@ -58,9 +58,15 @@ def __init__(
Whether to quantize image, by default True.
"""

if psf is not None:
# convert HWC to CHW
psf = psf.squeeze().movedim(-1, 0)
assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"

if torch.is_tensor(psf):
# drop depth dimension, and convert HWC to CHW
psf = psf[0].movedim(-1, 0)
assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels"
else:
psf = psf[0]
assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels"

super().__init__(
object_height,
Expand All @@ -78,6 +84,13 @@ def __init__(
**kwargs
)

if self.is_torch:
assert self.psf.shape[0] == 1 or self.psf.shape[0] == 3, "PSF must have 1 or 3 channels"
else:
assert (
self.psf.shape[-1] == 1 or self.psf.shape[-1] == 3
), "PSF must have 1 or 3 channels"

# save all the parameters in a dict
self.params = {
"object_height": object_height,
Expand All @@ -94,7 +107,15 @@ def __init__(
}
self.params.update(kwargs)

def set_psf(self, psf):
def get_psf(self):
if self.is_torch:
# convert CHW to HWC
return self.psf.movedim(0, -1).unsqueeze(0)
else:
return self.psf[None, ...]

# needs different name from parent class
def set_point_spread_function(self, psf):
"""
Set point spread function.
Expand All @@ -103,19 +124,32 @@ def set_psf(self, psf):
psf : np.ndarray or torch.Tensor
Point spread function.
"""
psf = psf.squeeze().movedim(-1, 0)
assert len(psf.shape) == 4, "PSF must be of shape (depth, height, width, channels)"

if torch.is_tensor(psf):
# convert HWC to CHW
psf = psf[0].movedim(-1, 0)
assert psf.shape[0] == 1 or psf.shape[0] == 3, "PSF must have 1 or 3 channels"
else:
psf = psf[0]
assert psf.shape[-1] == 1 or psf.shape[-1] == 3, "PSF must have 1 or 3 channels"

return super().set_psf(psf)

def propagate(self, obj, return_object_plane=False):
def propagate_image(self, obj, return_object_plane=False):
"""
Parameters
----------
obj : np.ndarray or torch.Tensor
Single image to propagate at format HWC.
Single image to propagate of format HWC.
return_object_plane : bool, optional
Whether to return object plane, by default False.
"""

assert obj.shape[-1] == 1 or obj.shape[-1] == 3, "Image must have 1 or 3 channels"

if self.is_torch:
# channel in first dimension as expected by waveprop for pytorch
obj = obj.moveaxis(-1, 0)
res = super().propagate(obj, return_object_plane)
if isinstance(res, tuple):
Expand All @@ -124,10 +158,6 @@ def propagate(self, obj, return_object_plane=False):
res = res.moveaxis(-3, -1)
return res
else:
obj = np.moveaxis(obj, -1, 0)
# TODO: not tested, but normally don't need to move dimensions for numpy
res = super().propagate(obj, return_object_plane)
if isinstance(res, tuple):
res = np.moveaxis(res[0], -3, -1), np.moveaxis(res[1], -3, -1)
else:
res = np.moveaxis(res, -3, -1)
return res
2 changes: 1 addition & 1 deletion mask_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
sympy>=1.11.1
perlin_numpy @ git+https://github.com/pvigier/perlin-numpy.git@5e26837db14042e51166eb6cad4c0df2c1907016
waveprop>=0.0.7
waveprop>=0.0.8
Loading

0 comments on commit eb0c84e

Please sign in to comment.