Skip to content

Commit

Permalink
Add option to add unrolled output to loss.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Jan 12, 2024
1 parent 07f533f commit 6d15219
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 9 deletions.
3 changes: 2 additions & 1 deletion configs/train_unrolledADMM.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,5 @@ optimizer:

loss: 'l2'
# set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1)
lpips: 1.0
lpips: 1.0
unrolled_output_factor: False # whether to account for unrolled output in loss (there must post-processor)
54 changes: 53 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def benchmark(
crop=None,
save_idx=None,
output_dir=None,
unrolled_output_factor=False,
**kwargs,
):
"""
Expand Down Expand Up @@ -98,11 +99,17 @@ def benchmark(
with torch.no_grad():
if batchsize == 1:
model.set_data(lensless)
prediction = model.apply(plot=False, save=False, **kwargs)
prediction = model.apply(
plot=False, save=False, output_intermediate=unrolled_output_factor, **kwargs
)

else:
prediction = model.batch_call(lensless, **kwargs)

if unrolled_output_factor:
unrolled_out = prediction[-1]
prediction = prediction[0]

# Convert to [N*D, C, H, W] for torchmetrics
prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3)
lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3)
Expand Down Expand Up @@ -137,6 +144,7 @@ def benchmark(
print("Warning: prediction is zero")
lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True)
lensed = lensed / lensed_max

# compute metrics
for metric in metrics:
if metric == "ReconstructionError":
Expand All @@ -157,6 +165,50 @@ def benchmark(
else:
metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item()

# compute metrics for unrolled output
if unrolled_output_factor:

# -- convert to CHW and remove depth
unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

# -- extraction region of interest
if crop is not None:
unrolled_out = unrolled_out[
...,
crop["vertical"][0] : crop["vertical"][1],
crop["horizontal"][0] : crop["horizontal"][1],
]

# -- normalization
unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True)
if torch.all(unrolled_out_max != 0):
unrolled_out = unrolled_out / unrolled_out_max

# -- compute metrics
for metric in metrics:
if metric == "ReconstructionError":
# only have this for final output
continue
else:
if "LPIPS" in metric:
if unrolled_out.shape[1] == 1:
# LPIPS needs 3 channels
metrics_values[metric] += (
metrics[metric](
unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1)
)
.cpu()
.item()
)
else:
metrics_values[metric + "_unrolled"] += (
metrics[metric](unrolled_out, lensed).cpu().item()
)
else:
metrics_values[metric + "_unrolled"] += (
metrics[metric](unrolled_out, lensed).cpu().item()
)

model.reset()
idx += batchsize

Expand Down
1 change: 1 addition & 0 deletions lensless/recon/recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ def apply(
gamma=None,
ax=None,
reset=True,
**kwargs,
):
"""
Method for performing iterative reconstruction. Note that `set_data`
Expand Down
23 changes: 20 additions & 3 deletions lensless/recon/trainable_recon.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
pre_process=None,
post_process=None,
skip_unrolled=False,
return_unrolled_output=False,
**kwargs,
):
"""
Expand All @@ -74,6 +75,10 @@ def __init__(
post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional
If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible.
If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False.
skip_unrolled : bool, optional
Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction).
return_unrolled_output : bool, optional
Whether to return the output of the unrolled algorithm if also using a post-processor block.
"""
assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor"
super(TrainableReconstructionAlgorithm, self).__init__(
Expand All @@ -83,6 +88,11 @@ def __init__(
self.set_pre_process(pre_process)
self.set_post_process(post_process)
self.skip_unrolled = skip_unrolled
self.return_unrolled_output = return_unrolled_output
if self.return_unrolled_output:
assert (
post_process is not None
), "If return_unrolled_output is True, post_process must be defined."
if self.skip_unrolled:
assert (
post_process is not None or pre_process is not None
Expand Down Expand Up @@ -197,17 +207,24 @@ def batch_call(self, batch):

self.reset(batch_size=batch_size)

# unrolled algorithm
if not self.skip_unrolled:
for i in range(self._n_iter):
self._update(i)
image_est = self._form_image()

else:
image_est = self._data

# post process data
if self.post_process is not None:
image_est = self.post_process(image_est, self.post_process_param)
return image_est
final_est = self.post_process(image_est, self.post_process_param)
else:
final_est = image_est

if self.return_unrolled_output:
return final_est, image_est
else:
return final_est

def apply(
self,
Expand Down
70 changes: 66 additions & 4 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def __init__(
logger=None,
crop=None,
clip_grad=1.0,
unrolled_output_factor=False,
# for adding components during training
pre_process=None,
pre_process_delay=None,
Expand Down Expand Up @@ -322,6 +323,8 @@ def __init__(
Logger to use for logging. If None, just print to terminal. Default is None.
crop : dict, optional
Crop to apply to images before computing loss (by applying a mask). If None, no crop is applied. Default is None.
unrolled_output_factor : float, optional
How much of the unrolled loss to add to the total loss. If False, no unrolled loss is added. Default is False. Only applicable if a post-processor is used.
pre_process : :py:class:`torch.nn.Module`, optional
Pre process component to add during training. Default is None.
pre_process_delay : int, optional
Expand Down Expand Up @@ -409,7 +412,7 @@ def __init__(
else:
raise ValueError(f"Unsuported loss : {loss}")

# Lpips loss
# -- Lpips loss
if lpips:
try:
import lpips
Expand All @@ -422,13 +425,23 @@ def __init__(

self.crop = crop

# -- adding unrolled loss
self.unrolled_output_factor = unrolled_output_factor
if self.unrolled_output_factor:
assert self.unrolled_output_factor > 0
assert self.post_process is not None
assert self.post_process_delay is not None
assert self.post_process_unfreeze is not None
assert self.post_process_freeze is not None

# optimizer
self.clip_grad_norm = clip_grad
self.optimizer_config = optimizer
self.set_optimizer()

self.metrics = {
"LOSS": [], # train loss
"LOSS_TEST": [], # test loss
"MSE": [],
"MAE": [],
"LPIPS_Vgg": [],
Expand Down Expand Up @@ -539,6 +552,10 @@ def train_epoch(self, data_loader):

# forward pass
y_pred = self.recon.batch_call(X.to(self.device))
if self.unrolled_output_factor:
unrolled_out = y_pred[1]
y_pred = y_pred[0]

# normalizing each output
eps = 1e-12
y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps
Expand All @@ -553,7 +570,7 @@ def train_epoch(self, data_loader):
y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3)

# crop
# extraction region of interest for loss
if self.crop is not None:
y_pred = y_pred[
...,
Expand All @@ -567,6 +584,8 @@ def train_epoch(self, data_loader):
]

loss_v = self.Loss(y_pred, y)

# add LPIPS loss
if self.lpips:

if y_pred.shape[1] == 1:
Expand All @@ -580,6 +599,41 @@ def train_epoch(self, data_loader):
)
if self.use_mask and self.l1_mask:
loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask))

if self.unrolled_output_factor:
# -- normalize
unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) + eps
unrolled_out = unrolled_out / unrolled_out_max

# -- convert to CHW for loss and remove depth
unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3)

# -- extraction region of interest for loss
if self.crop is not None:
unrolled_out = unrolled_out[
...,
self.crop["vertical"][0] : self.crop["vertical"][1],
self.crop["horizontal"][0] : self.crop["horizontal"][1],
]

# -- compute unrolled output loss
loss_unrolled = self.Loss(unrolled_out, y)

# -- add LPIPS loss
if self.lpips:
if unrolled_out.shape[1] == 1:
# if only one channel, repeat for LPIPS
unrolled_out = unrolled_out.repeat(1, 3, 1, 1)

# value for LPIPS needs to be in range [-1, 1]
loss_unrolled = loss_unrolled + self.lpips * torch.mean(
self.Loss_lpips(2 * unrolled_out - 1, 2 * y - 1)
)

# -- add unrolled loss to total loss
loss_v = loss_v + self.unrolled_output_factor * loss_unrolled

# backward pass
loss_v.backward()

if self.clip_grad_norm is not None:
Expand Down Expand Up @@ -641,6 +695,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
save_idx=disp,
output_dir=output_dir,
crop=self.crop,
unrolled_output_factor=self.unrolled_output_factor,
)

# update metrics with current metrics
Expand All @@ -660,9 +715,16 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None):
eval_loss += self.lpips * current_metrics["LPIPS_Vgg"]
if self.use_mask and self.l1_mask:
eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy()))
return eval_loss
if self.unrolled_output_factor:
unrolled_loss = current_metrics["MSE_unrolled"]
if self.lpips is not None:
unrolled_loss += self.lpips * current_metrics["LPIPS_Vgg_unrolled"]
eval_loss += self.unrolled_output_factor * unrolled_loss
else:
return current_metrics[self.metrics["metric_for_best_model"]]
eval_loss = current_metrics[self.metrics["metric_for_best_model"]]

self.metrics["LOSS_TEST"].append(eval_loss)
return eval_loss

def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None):
"""
Expand Down
3 changes: 3 additions & 0 deletions scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,7 @@ def train_unrolled(config):
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
elif config.reconstruction.method == "unrolled_admm":
recon = UnrolledADMM(
Expand All @@ -559,6 +560,7 @@ def train_unrolled(config):
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_unrolled_output=True if config.unrolled_output_factor > 0 else False,
).to(device)
else:
raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")
Expand Down Expand Up @@ -606,6 +608,7 @@ def train_unrolled(config):
post_process_freeze=config.reconstruction.post_process.freeze,
post_process_unfreeze=config.reconstruction.post_process.unfreeze,
clip_grad=config.training.clip_grad,
unrolled_output_factor=config.unrolled_output_factor,
)

trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)
Expand Down

0 comments on commit 6d15219

Please sign in to comment.