Skip to content

Commit

Permalink
Add support for multimask training.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Feb 28, 2024
1 parent b62e1ff commit 223f6a2
Show file tree
Hide file tree
Showing 10 changed files with 267 additions and 27 deletions.
1 change: 1 addition & 0 deletions configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ start_delay: null
# Dataset
files:
dataset: /scratch/bezzam/DiffuserCam_mirflickr/dataset # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam"
huggingface_dataset: null
celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
psf: data/psf/diffusercam_psf.tiff
diffusercam_psf: True
Expand Down
23 changes: 23 additions & 0 deletions configs/train_unrolled_multimask.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# python scripts/recon/train_unrolled.py -cn train_unrolled_multimask
defaults:
- train_unrolledADMM
- _self_

# Dataset
files:
dataset: bezzam/DigiCam-Mirflickr-MultiMask-1K
huggingface_dataset: True
downsample: 1.6
image_res: [900, 1200] # used during measurement
rotate: True # if measurement is upside-down

alignment:
# when there is no downsampling
topright: [80, 100] # height, width
height: 200

training:
batch_size: 4
epoch: 25
eval_batch_size: 10

30 changes: 27 additions & 3 deletions lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def benchmark(
batchsize=1,
metrics=None,
crop=None,
alignment=False,
save_idx=None,
output_dir=None,
unrolled_output_factor=False,
Expand All @@ -58,6 +59,8 @@ def benchmark(
Directory to save the predictions, by default save in working directory if save_idx is provided.
crop : dict, optional
Dictionary of crop parameters (vertical: [start, end], horizontal: [start, end]), by default None (no crop).
alignment : dict, optional
Similar to crop. Dictionary of alignment parameters (topright: [height, width], height: pix). Expects ``recon_width`` in ``dataset``. By default None (no alignment).
unrolled_output_factor : bool, optional
If True, compute metrics for unrolled output, by default False.
return_average : bool, optional
Expand All @@ -73,6 +76,11 @@ def benchmark(
assert isinstance(model._psf, torch.Tensor), "model need to be constructed with torch support"
device = model._psf.device

if hasattr(dataset, "psfs"):
multipsf_dataset = True
else:
multipsf_dataset = False

if output_dir is None:
output_dir = os.getcwd()
else:
Expand Down Expand Up @@ -107,7 +115,14 @@ def benchmark(
model.reset()
idx = 0
with torch.no_grad():
for lensless, lensed in tqdm(dataloader):
for batch in tqdm(dataloader):
if multipsf_dataset:
lensless, lensed, psfs = batch
psfs = psfs.to(device)
else:
lensless, lensed = batch
psfs = None

lensless = lensless.to(device)
lensed = lensed.to(device)

Expand All @@ -118,13 +133,15 @@ def benchmark(

# compute predictions
if batchsize == 1:
# TODO : handle multipsf
assert not multipsf_dataset
model.set_data(lensless)
prediction = model.apply(
plot=False, save=False, output_intermediate=unrolled_output_factor, **kwargs
)

else:
prediction = model.batch_call(lensless, **kwargs)
prediction = model.batch_call(lensless, psfs, **kwargs)

if unrolled_output_factor:
unrolled_out = prediction[-1]
Expand All @@ -134,7 +151,14 @@ def benchmark(
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)

if crop is not None:
if alignment is not None:
prediction = prediction[
...,
alignment["topright"][0] : alignment["topright"][0] + alignment["height"],
alignment["topright"][1] : alignment["topright"][1] + alignment["width"],
]
# expected that lensed is also reshaped accordingly
elif crop is not None:
prediction = prediction[
...,
crop["vertical"][0] : crop["vertical"][1],
Expand Down
6 changes: 4 additions & 2 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def __init__(
else:
raise ValueError(f"Unsupported dtype : {self._dtype}")

self._convolver_param = {"dtype": dtype, "pad": pad, **kwargs}
self._convolver = RealFFTConvolve2D(psf, dtype=dtype, pad=pad, **kwargs)
self._padded_shape = self._convolver._padded_shape

Expand Down Expand Up @@ -445,8 +446,9 @@ def _set_psf(self, psf):
psf : :py:class:`~numpy.ndarray` or :py:class:`~torch.Tensor`
PSF to set.
"""
assert len(psf.shape) == 4, "PSF must be 4D: (depth, height, width, channels)."
assert psf.shape[3] == 3 or psf.shape[3] == 1, "PSF must either be rgb (3) or grayscale (1)"
assert (
psf.shape[-1] == 3 or psf.shape[-1] == 1
), "PSF must either be rgb (3) or grayscale (1)"
assert self._psf.shape == psf.shape, "new PSF must have same shape as old PSF"
assert isinstance(psf, type(self._psf)), "new PSF must have same type as old PSF"

Expand Down
10 changes: 6 additions & 4 deletions lensless/recon/rfft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,12 @@ def __init__(self, psf, dtype=None, pad=True, norm="ortho", **kwargs):

# prepare shapes for reconstruction

assert len(psf.shape) == 4, "Expected 4D PSF of shape (depth, width, height, channels)"
self._use_3d = psf.shape[0] != 1
self._is_rgb = psf.shape[3] == 3
assert self._is_rgb or psf.shape[3] == 1
assert (
len(psf.shape) >= 4
), "Expected 4D PSF of shape ([batch], depth, width, height, channels)"
self._use_3d = psf.shape[-4] != 1
self._is_rgb = psf.shape[-1] == 3
assert self._is_rgb or psf.shape[-1] == 1

# save normalization
self.norm = norm
Expand Down
10 changes: 9 additions & 1 deletion lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from matplotlib import pyplot as plt
from lensless.recon.recon import ReconstructionAlgorithm
from lensless.utils.plot import plot_image
from lensless.recon.rfft_convolve import RealFFTConvolve2D

try:
import torch
Expand Down Expand Up @@ -191,7 +192,7 @@ def unfreeze_post_process(self):
for param in self.post_process_model.parameters():
param.requires_grad = True

def batch_call(self, batch):
def batch_call(self, batch, psfs=None):
"""
Method for performing iterative reconstruction on a batch of images.
This implementation is a properly vectorized implementation of FISTA.
Expand All @@ -200,6 +201,8 @@ def batch_call(self, batch):
----------
batch : :py:class:`~torch.Tensor` of shape (batch, depth, channels, height, width)
The lensless images to reconstruct.
psfs : :py:class:`~torch.Tensor` of shape (batch, depth, channels, height, width)
The lensless images to reconstruct.
Returns
-------
Expand All @@ -209,6 +212,11 @@ def batch_call(self, batch):
self._data = batch
assert len(self._data.shape) == 5, "batch must be of shape (N, D, C, H, W)"
batch_size = batch.shape[0]
if psfs is not None:
# assert same shape
assert psfs.shape == batch.shape, "psfs must have the same shape as batch"
# -- update convolver
self._convolver = RealFFTConvolve2D(psfs.to(self._psf.device), **self._convolver_param)

# pre process data
if self.pre_process is not None:
Expand Down
16 changes: 8 additions & 8 deletions lensless/recon/unrolled_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,18 @@ def reset(self, batch_size=1):
self._eta = torch.zeros_like(self._U)
self._rho = torch.zeros_like(self._X)

# precompute_R_divmat
# precompute_R_divmat [iter, batch, depth, height, width, channels]
self._R_divmat = 1.0 / (
self._mu1[:, None, None, None, None]
* (torch.abs(self._convolver._Hadj * self._convolver._H))
+ self._mu2[:, None, None, None, None] * torch.abs(self._PsiTPsi)
+ self._mu3[:, None, None, None, None]
self._mu1[:, None, None, None, None, None]
* (torch.abs(self._convolver._Hadj * self._convolver._H))[None, ...]
+ self._mu2[:, None, None, None, None, None] * torch.abs(self._PsiTPsi)
+ self._mu3[:, None, None, None, None, None]
).type(self._complex_dtype)

# precompute_X_divmat
# precompute_X_divmat [iter, batch, depth, height, width, channels]
self._X_divmat = 1.0 / (
self._convolver._pad(torch.ones_like(self._psf[None, ...]))
+ self._mu1[:, None, None, None, None]
self._convolver._pad(torch.ones_like(self._convolver._psf))[None, ...]
+ self._mu1[:, None, None, None, None, None]
)

def _U_update(self, iter):
Expand Down
31 changes: 28 additions & 3 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,7 @@ def __init__(
gamma=None,
logger=None,
crop=None,
alignment=None,
clip_grad=1.0,
unrolled_output_factor=False,
# for adding components during training
Expand Down Expand Up @@ -417,6 +418,10 @@ def __init__(
)
self.print(f"Train size : {train_size}, Test size : {test_size}")

if hasattr(train_dataset, "psfs"):
self.multipsf_dataset = True
else:
self.multipsf_dataset = False
self.train_dataloader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=batch_size,
Expand Down Expand Up @@ -470,6 +475,7 @@ def __init__(
)

self.crop = crop
self.alignment = alignment

# -- adding unrolled loss
self.unrolled_output_factor = unrolled_output_factor
Expand Down Expand Up @@ -590,7 +596,16 @@ def train_epoch(self, data_loader):
mean_loss = 0.0
i = 1.0
pbar = tqdm(data_loader)
for X, y in pbar:
for batch in pbar:

# get batch
if self.multipsf_dataset:
X, y, psfs = batch
psfs = psfs.to(self.device)
else:
X, y = batch
psfs = None

# send to device
X = X.to(self.device)
y = y.to(self.device)
Expand All @@ -600,7 +615,7 @@ def train_epoch(self, data_loader):
self.recon._set_psf(self.mask.get_psf().to(self.device))

# forward pass
y_pred = self.recon.batch_call(X.to(self.device))
y_pred = self.recon.batch_call(X, psfs=psfs)
if self.unrolled_output_factor:
unrolled_out = y_pred[1]
y_pred = y_pred[0]
Expand All @@ -619,7 +634,16 @@ def train_epoch(self, data_loader):
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# extraction region of interest for loss
if self.crop is not None:
if self.alignment is not None:
y_pred = y_pred[
...,
self.alignment["topright"][0] : self.alignment["topright"][0]
+ self.alignment["height"],
self.alignment["topright"][1] : self.alignment["topright"][1]
+ self.alignment["width"],
]
# expected that lensed is also reshaped accordingly
elif self.crop is not None:
y_pred = y_pred[
...,
self.crop["vertical"][0] : self.crop["vertical"][1],
Expand Down Expand Up @@ -771,6 +795,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
save_idx=disp,
output_dir=output_dir,
crop=self.crop,
alignment=self.alignment,
unrolled_output_factor=self.unrolled_output_factor,
)

Expand Down
Loading

0 comments on commit 223f6a2

Please sign in to comment.