Add more features to unrolled training.
ebezzam committed Nov 17, 2023
1 parent 331b8bf commit 2ffa96c
Showing 13 changed files with 632 additions and 185 deletions.
18 changes: 15 additions & 3 deletions configs/benchmark.yaml
@@ -41,11 +41,23 @@ admm:
# for DigiCamCelebA
files:
test_size: 0.15
dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
downsample: 1
celeba_root: /scratch/bezzam
psf: data/psf/adafruit_random_2mm_20231907.png

downsample: 1

# dataset: /scratch/bezzam/celeba_adafruit_random_2mm_20230720_10K
# psf: data/psf/adafruit_random_2mm_20231907.png
# vertical_shift: null
# horizontal_shift: null
# crop: null

dataset: /scratch/bezzam/celeba/celeba_adafruit_random_30cm_2mm_20231004_26K
psf: rpi_hq_adafruit_psf_2mm/raw_data_rgb.png
vertical_shift: -117
horizontal_shift: -25
crop:
vertical: [0, 525]
horizontal: [265, 695]

# for prepping ground truth data
# for simulated dataset
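The vertical_shift/horizontal_shift and crop values above register the lensless reconstruction to the ground-truth CelebA image before metrics are computed. A minimal sketch of how such values could be applied; the helper name and the use of np.roll are assumptions for illustration, not code from this commit, and images are assumed to have shape (H, W, C):

import numpy as np

def align_and_crop(img, vertical_shift=-117, horizontal_shift=-25,
                   vertical=(0, 525), horizontal=(265, 695)):
    # shift the reconstruction so it lines up with the ground truth
    img = np.roll(img, shift=vertical_shift, axis=0)
    img = np.roll(img, shift=horizontal_shift, axis=1)
    # keep only the region of interest used for evaluation
    return img[vertical[0] : vertical[1], horizontal[0] : horizontal[1]]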
4 changes: 4 additions & 0 deletions configs/train_psf_from_scratch.yaml
@@ -16,3 +16,7 @@ trainable_mask:

simulation:
grayscale: False


training:
crop_preloss: False # crop region for computing loss
20 changes: 17 additions & 3 deletions configs/train_unrolledADMM.yaml
@@ -5,6 +5,7 @@ hydra:


seed: 0
start_delay: null

# Dataset
files:
@@ -16,8 +17,15 @@ files:
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution
test_size: 0.15

vertical_shift: null
horizontal_shift: null
crop: null
# vertical: null
# horizontal: null

torch: True
torch_device: 'cuda'
measure: null # if measuring data on-the-fly

# see some outputs of classical ADMM before training
test_idx: [0, 1, 2, 3, 4]
@@ -37,6 +45,7 @@ save: True
reconstruction:
# Method: unrolled_admm, unrolled_fista
method: unrolled_admm
skip_unrolled: False

# Hyperparameters for each method
unrolled_fista: # for unrolled_fista
@@ -56,10 +65,17 @@ reconstruction:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: null
delay: null # add component after this many epochs
freeze: null
unfreeze: null
post_process:
network : null # UnetRes or DruNet or null
depth : 2 # depth of each up/downsampling layer. Ignore if network is DruNet
nc: null
delay: null # add component after this many epochs
freeze: null
unfreeze: null
train_last_layer: False
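The new delay, freeze and unfreeze entries appear to be epoch indices controlling when a pre-/post-processing component is added, frozen, or unfrozen during training. A hedged sketch of how a training loop might interpret them, using the freeze/unfreeze methods added further below in lensless/recon/trainable_recon.py; the config access pattern and the recon and pre_process_model names are illustrative, not taken from the training script:

def apply_schedule(recon, pre_cfg, epoch, pre_process_model=None):
    # interpret the delay / freeze / unfreeze epochs from the config above
    if pre_cfg.delay is not None and epoch == pre_cfg.delay:
        recon.set_pre_process(pre_process_model)  # add the component after `delay` epochs
    if pre_cfg.freeze is not None and epoch == pre_cfg.freeze:
        recon.freeze_pre_process()                # stop updating its weights
    if pre_cfg.unfreeze is not None and epoch == pre_cfg.unfreeze:
        recon.unfreeze_pre_process()              # resume updating its weights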

# Trainable Mask
trainable_mask:
@@ -108,11 +124,9 @@ training:
save_every: null
# In case of unstable training
skip_NAN: True
clip_grad: 1.0

crop_preloss: True # crop region for computing loss
crop: null
# vertical: null
# horizontal: null
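The new clip_grad value caps the gradient norm, and crop_preloss/crop restrict the loss to a region of interest. A sketch of how these options might act inside a training step; the function below is an assumption for illustration (not the actual training script), and images are assumed to have shape (..., H, W, C):

import torch

def training_step(recon, params, optimizer, criterion, lensless, lensed, crop=None, clip_grad=1.0):
    prediction = recon.batch_call(lensless)
    if crop is not None:  # crop_preloss: compare only the region of interest
        v, h = crop["vertical"], crop["horizontal"]
        prediction = prediction[..., v[0] : v[1], h[0] : h[1], :]
        lensed = lensed[..., v[0] : v[1], h[0] : h[1], :]
    loss = criterion(prediction, lensed)
    loss.backward()
    # clip_grad: cap the gradient norm to stabilize training
    torch.nn.utils.clip_grad_norm_(params, max_norm=clip_grad)
    optimizer.step()
    optimizer.zero_grad()
    return loss.item()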

optimizer:
type: Adam
1 change: 1 addition & 0 deletions lensless/eval/benchmark.py
@@ -121,6 +121,7 @@ def benchmark(

if save_idx is not None:
batch_idx = np.arange(idx, idx + batchsize)

for i, idx in enumerate(batch_idx):
if idx in save_idx:
prediction_np = prediction.cpu().numpy()[i].squeeze()
2 changes: 2 additions & 0 deletions lensless/recon/drunet/network_unet.py
@@ -112,6 +112,8 @@ def __init__(
):
super(UNetRes, self).__init__()

assert len(nc) == 4, "nc's length should be 4."

self.m_head = B.conv(in_nc, nc[0], bias=False, mode="C")

# downsample
109 changes: 83 additions & 26 deletions lensless/recon/trainable_recon.py
@@ -49,6 +49,7 @@ def __init__(
n_iter=1,
pre_process=None,
post_process=None,
skip_unrolled=False,
**kwargs,
):
"""
@@ -79,19 +80,13 @@ def __init__(
psf, dtype=dtype, n_iter=n_iter, **kwargs
)

# pre processing
(
self.pre_process,
self.pre_process_model,
self.pre_process_param,
) = self._prepare_process_block(pre_process)

# post processing
(
self.post_process,
self.post_process_model,
self.post_process_param,
) = self._prepare_process_block(post_process)
self.set_pre_process(pre_process)
self.set_post_process(post_process)
self.skip_unrolled = skip_unrolled
if self.skip_unrolled:
assert (
post_process is not None or pre_process is not None
), "If skip_unrolled is True, pre_process or post_process must be defined."

def _prepare_process_block(self, process):
"""
@@ -115,13 +110,68 @@ def _prepare_process_block(self, process):
else:
process_function = None
process_model = None

if process_function is not None:
process_param = torch.nn.Parameter(torch.tensor([1.0], device=self._psf.device))
else:
process_param = None

return process_function, process_model, process_param

def set_pre_process(self, pre_process):
(
self.pre_process,
self.pre_process_model,
self.pre_process_param,
) = self._prepare_process_block(pre_process)

def set_post_process(self, post_process):
(
self.post_process,
self.post_process_model,
self.post_process_param,
) = self._prepare_process_block(post_process)

def freeze_pre_process(self):
"""
Method for freezing the pre process block.
"""
if self.pre_process_param is not None:
self.pre_process_param.requires_grad = False
if self.pre_process_model is not None:
for param in self.pre_process_model.parameters():
param.requires_grad = False

def freeze_post_process(self):
"""
Method for freezing the post process block.
"""
if self.post_process_param is not None:
self.post_process_param.requires_grad = False
if self.post_process_model is not None:
for param in self.post_process_model.parameters():
param.requires_grad = False

def unfreeze_pre_process(self):
"""
Method for unfreezing the pre process block.
"""
if self.pre_process_param is not None:
self.pre_process_param.requires_grad = True
if self.pre_process_model is not None:
for param in self.pre_process_model.parameters():
param.requires_grad = True

def unfreeze_post_process(self):
"""
Method for unfreezing the post process block.
"""
if self.post_process_param is not None:
self.post_process_param.requires_grad = True
if self.post_process_model is not None:
for param in self.post_process_model.parameters():
param.requires_grad = True
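A brief, hypothetical usage of these new setters and freeze/unfreeze helpers for staged training; the constructor arguments and the unet, drunet, pretrain and fine_tune names are placeholders, not code from this commit:

recon = UnrolledADMM(psf, n_iter=10, pre_process=unet)  # pre-processor + unrolled ADMM
pretrain(recon)                  # placeholder: first training phase
recon.freeze_pre_process()       # fix the pre-processor weights...
recon.set_post_process(drunet)   # ...then attach a post-processor
fine_tune(recon)                 # placeholder: train the newly added component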

def batch_call(self, batch):
"""
Method for performing iterative reconstruction on a batch of images.
@@ -147,10 +197,14 @@ def batch_call(self, batch):

self.reset(batch_size=batch_size)

for i in range(self._n_iter):
self._update(i)
if not self.skip_unrolled:
for i in range(self._n_iter):
self._update(i)
image_est = self._form_image()

else:
image_est = self._data

image_est = self._form_image()
if self.post_process is not None:
image_est = self.post_process(image_est, self.post_process_param)
return image_est
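With skip_unrolled=True, batch_call bypasses the unrolled iterations and passes the (optionally pre-processed) measurement straight to the post-processor, giving a purely learned baseline without camera inversion. A sketch of such a configuration; the import path and the denoiser module are assumptions:

from lensless import UnrolledADMM  # assumed import path

recon = UnrolledADMM(psf, post_process=denoiser, skip_unrolled=True)
out = recon.batch_call(batch)  # no unrolled ADMM updates; output comes from the post-processor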
@@ -207,16 +261,19 @@ def apply(
if output_intermediate:
pre_processed_image = self._data[0, ...].clone()

im = super(TrainableReconstructionAlgorithm, self).apply(
n_iter=self._n_iter,
disp_iter=disp_iter,
plot_pause=plot_pause,
plot=plot,
save=save,
gamma=gamma,
ax=ax,
reset=reset,
)
if not self.skip_unrolled:
im = super(TrainableReconstructionAlgorithm, self).apply(
n_iter=self._n_iter,
disp_iter=disp_iter,
plot_pause=plot_pause,
plot=plot,
save=save,
gamma=gamma,
ax=ax,
reset=reset,
)
else:
im = self._data

# remove plot if returned
if plot:
22 changes: 18 additions & 4 deletions lensless/recon/unrolled_admm.py
@@ -78,10 +78,24 @@ def __init__(
psf, n_iter=n_iter, dtype=dtype, pad=pad, norm=norm, reset=False, **kwargs
)

self._mu1_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu1)
self._mu2_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu2)
self._mu3_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * mu3)
self._tau_p = torch.nn.Parameter(torch.ones(self._n_iter, device=self._psf.device) * tau)
if not self.skip_unrolled:
self._mu1_p = torch.nn.Parameter(
torch.ones(self._n_iter, device=self._psf.device) * mu1
)
self._mu2_p = torch.nn.Parameter(
torch.ones(self._n_iter, device=self._psf.device) * mu2
)
self._mu3_p = torch.nn.Parameter(
torch.ones(self._n_iter, device=self._psf.device) * mu3
)
self._tau_p = torch.nn.Parameter(
torch.ones(self._n_iter, device=self._psf.device) * tau
)
else:
self._mu1_p = torch.ones(self._n_iter, device=self._psf.device) * mu1
self._mu2_p = torch.ones(self._n_iter, device=self._psf.device) * mu2
self._mu3_p = torch.ones(self._n_iter, device=self._psf.device) * mu3
self._tau_p = torch.ones(self._n_iter, device=self._psf.device) * tau

# set prior
if psi is None:
15 changes: 10 additions & 5 deletions lensless/recon/unrolled_fista.py
@@ -61,17 +61,22 @@ def __init__(self, psf, n_iter=5, dtype=None, proj=non_neg, learn_tk=True, tk=1,
# learnable step size initialize as < 2 / lipschitz
Hadj_flat = self._convolver._Hadj.reshape(-1, self._psf_shape[3])
H_flat = self._convolver._H.reshape(-1, self._psf_shape[3])
self._alpha_p = torch.nn.Parameter(
torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device)
* (1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
)
if not self.skip_unrolled:
self._alpha_p = torch.nn.Parameter(
torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device)
* (1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values)
)
else:
self._alpha_p = torch.ones(self._n_iter, self._psf_shape[3]).to(psf.device) * (
1.8 / torch.max(torch.abs(Hadj_flat * H_flat), axis=0).values
)

# set tk, can be learnt if learn_tk=True
self._tk_p = [tk]
for i in range(self._n_iter):
self._tk_p.append((1 + np.sqrt(1 + 4 * self._tk_p[i] ** 2)) / 2)
self._tk_p = torch.Tensor(self._tk_p)
if learn_tk:
if learn_tk and not self.skip_unrolled:
self._tk_p = torch.nn.Parameter(self._tk_p).to(psf.device)

self.reset()
Expand Down