From 6d1521900f508e9dadf5d19e39fbba091246bf95 Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 12 Jan 2024 14:41:01 +0100 Subject: [PATCH] Add option to add unrolled output to loss. --- configs/train_unrolledADMM.yaml | 3 +- lensless/eval/benchmark.py | 54 +++++++++++++++++++++++- lensless/recon/recon.py | 1 + lensless/recon/trainable_recon.py | 23 ++++++++-- lensless/recon/utils.py | 70 +++++++++++++++++++++++++++++-- scripts/recon/train_unrolled.py | 3 ++ 6 files changed, 145 insertions(+), 9 deletions(-) diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index f7602f01..3ead25e4 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -139,4 +139,5 @@ optimizer: loss: 'l2' # set lpips to false to deactivate. Otherwise, give the weigth for the loss (the main loss l2/l1 always having a weigth of 1) -lpips: 1.0 \ No newline at end of file +lpips: 1.0 +unrolled_output_factor: False # whether to account for unrolled output in loss (there must be a post-processor) \ No newline at end of file diff --git a/lensless/eval/benchmark.py b/lensless/eval/benchmark.py index 8abd254e..46ae1cac 100644 --- a/lensless/eval/benchmark.py +++ b/lensless/eval/benchmark.py @@ -33,6 +33,7 @@ def benchmark( crop=None, save_idx=None, output_dir=None, + unrolled_output_factor=False, **kwargs, ): """ @@ -98,11 +99,17 @@ def benchmark( with torch.no_grad(): if batchsize == 1: model.set_data(lensless) - prediction = model.apply(plot=False, save=False, **kwargs) + prediction = model.apply( + plot=False, save=False, output_intermediate=unrolled_output_factor, **kwargs + ) else: prediction = model.batch_call(lensless, **kwargs) + if unrolled_output_factor: + unrolled_out = prediction[-1] + prediction = prediction[0] + # Convert to [N*D, C, H, W] for torchmetrics prediction = prediction.reshape(-1, *prediction.shape[-3:]).movedim(-1, -3) lensed = lensed.reshape(-1, *lensed.shape[-3:]).movedim(-1, -3) @@ -137,6 +144,7 @@ def benchmark( 
print("Warning: prediction is zero") lensed_max = torch.amax(lensed, dim=(1, 2, 3), keepdim=True) lensed = lensed / lensed_max + # compute metrics for metric in metrics: if metric == "ReconstructionError": @@ -157,6 +165,50 @@ def benchmark( else: metrics_values[metric] += metrics[metric](prediction, lensed).cpu().item() + # compute metrics for unrolled output + if unrolled_output_factor: + + # -- convert to CHW and remove depth + unrolled_out = unrolled_out.reshape(-1, *unrolled_out.shape[-3:]).movedim(-1, -3) + + # -- extract region of interest + if crop is not None: + unrolled_out = unrolled_out[ + ..., + crop["vertical"][0] : crop["vertical"][1], + crop["horizontal"][0] : crop["horizontal"][1], + ] + + # -- normalization + unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) + if torch.all(unrolled_out_max != 0): + unrolled_out = unrolled_out / unrolled_out_max + + # -- compute metrics + for metric in metrics: + if metric == "ReconstructionError": + # only have this for final output + continue + else: + if "LPIPS" in metric: + if unrolled_out.shape[1] == 1: + # LPIPS needs 3 channels; accumulate under the "_unrolled" key like the other unrolled metrics + metrics_values[metric + "_unrolled"] += ( + metrics[metric]( + unrolled_out.repeat(1, 3, 1, 1), lensed.repeat(1, 3, 1, 1) + ) + .cpu() + .item() + ) + else: + metrics_values[metric + "_unrolled"] += ( + metrics[metric](unrolled_out, lensed).cpu().item() + ) + else: + metrics_values[metric + "_unrolled"] += ( + metrics[metric](unrolled_out, lensed).cpu().item() + ) + model.reset() idx += batchsize diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index 444e3b0a..5ef3eb44 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -462,6 +462,7 @@ def apply( gamma=None, ax=None, reset=True, + **kwargs, ): """ Method for performing iterative reconstruction. 
Note that `set_data` diff --git a/lensless/recon/trainable_recon.py b/lensless/recon/trainable_recon.py index 52343c19..4f1af904 100644 --- a/lensless/recon/trainable_recon.py +++ b/lensless/recon/trainable_recon.py @@ -50,6 +50,7 @@ def __init__( pre_process=None, post_process=None, skip_unrolled=False, + return_unrolled_output=False, **kwargs, ): """ @@ -74,6 +75,10 @@ def __init__( post_process : :py:class:`function` or :py:class:`~torch.nn.Module`, optional If :py:class:`function` : Function to apply to the image estimate after the whole algorithm. Its input most be (image to process, noise_level), where noise_level is a learnable parameter. If it include aditional learnable parameters, they will not be added to the parameter list of the algorithm. To allow for traning, the function must be autograd compatible. If :py:class:`~torch.nn.Module` : A DruNet compatible network to apply to the image estimate after the whole algorithm. See ``utils.image.apply_denoiser`` for more details. The network will be included as a submodule of the algorithm and its parameters will be added to the parameter list of the algorithm. If this isn't intended behavior, set requires_grad=False. + skip_unrolled : bool, optional + Whether to skip the unrolled algorithm and only apply the pre- or post-processor block (e.g. to just use a U-Net for reconstruction). + return_unrolled_output : bool, optional + Whether to return the output of the unrolled algorithm if also using a post-processor block. """ assert isinstance(psf, torch.Tensor), "PSF must be a torch.Tensor" super(TrainableReconstructionAlgorithm, self).__init__( @@ -83,6 +88,11 @@ def __init__( self.set_pre_process(pre_process) self.set_post_process(post_process) self.skip_unrolled = skip_unrolled + self.return_unrolled_output = return_unrolled_output + if self.return_unrolled_output: + assert ( + post_process is not None + ), "If return_unrolled_output is True, post_process must be defined." 
if self.skip_unrolled: assert ( post_process is not None or pre_process is not None @@ -197,17 +207,24 @@ def batch_call(self, batch): self.reset(batch_size=batch_size) + # unrolled algorithm if not self.skip_unrolled: for i in range(self._n_iter): self._update(i) image_est = self._form_image() - else: image_est = self._data + # post process data if self.post_process is not None: - image_est = self.post_process(image_est, self.post_process_param) - return image_est + final_est = self.post_process(image_est, self.post_process_param) + else: + final_est = image_est + + if self.return_unrolled_output: + return final_est, image_est + else: + return final_est def apply( self, diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 53f23a1b..d376d951 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -265,6 +265,7 @@ def __init__( logger=None, crop=None, clip_grad=1.0, + unrolled_output_factor=False, # for adding components during training pre_process=None, pre_process_delay=None, @@ -322,6 +323,8 @@ def __init__( Logger to use for logging. If None, just print to terminal. Default is None. crop : dict, optional Crop to apply to images before computing loss (by applying a mask). If None, no crop is applied. Default is None. + unrolled_output_factor : float, optional + How much of the unrolled loss to add to the total loss. If False, no unrolled loss is added. Default is False. Only applicable if a post-processor is used. pre_process : :py:class:`torch.nn.Module`, optional Pre process component to add during training. Default is None. 
pre_process_delay : int, optional @@ -409,7 +412,7 @@ def __init__( else: raise ValueError(f"Unsuported loss : {loss}") - # Lpips loss + # -- Lpips loss if lpips: try: import lpips @@ -422,6 +425,15 @@ def __init__( self.crop = crop + # -- adding unrolled loss + self.unrolled_output_factor = unrolled_output_factor + if self.unrolled_output_factor: + assert self.unrolled_output_factor > 0 + assert self.post_process is not None + assert self.post_process_delay is not None + assert self.post_process_unfreeze is not None + assert self.post_process_freeze is not None + # optimizer self.clip_grad_norm = clip_grad self.optimizer_config = optimizer @@ -429,6 +441,7 @@ def __init__( self.metrics = { "LOSS": [], # train loss + "LOSS_TEST": [], # test loss "MSE": [], "MAE": [], "LPIPS_Vgg": [], @@ -539,6 +552,10 @@ def train_epoch(self, data_loader): # forward pass y_pred = self.recon.batch_call(X.to(self.device)) + if self.unrolled_output_factor: + unrolled_out = y_pred[1] + y_pred = y_pred[0] + # normalizing each output eps = 1e-12 y_pred_max = torch.amax(y_pred, dim=(-1, -2, -3), keepdim=True) + eps @@ -553,7 +570,7 @@ def train_epoch(self, data_loader): y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) y = y.reshape(-1, *y.shape[-3:]).movedim(-1, -3) - # crop + # extraction region of interest for loss if self.crop is not None: y_pred = y_pred[ ..., @@ -567,6 +584,8 @@ def train_epoch(self, data_loader): ] loss_v = self.Loss(y_pred, y) + + # add LPIPS loss if self.lpips: if y_pred.shape[1] == 1: @@ -580,6 +599,41 @@ def train_epoch(self, data_loader): ) if self.use_mask and self.l1_mask: loss_v = loss_v + self.l1_mask * torch.mean(torch.abs(self.mask._mask)) + + if self.unrolled_output_factor: + # -- normalize + unrolled_out_max = torch.amax(unrolled_out, dim=(-1, -2, -3), keepdim=True) + eps + unrolled_out = unrolled_out / unrolled_out_max + + # -- convert to CHW for loss and remove depth + unrolled_out = unrolled_out.reshape(-1, 
*unrolled_out.shape[-3:]).movedim(-1, -3) + + # -- extraction region of interest for loss + if self.crop is not None: + unrolled_out = unrolled_out[ + ..., + self.crop["vertical"][0] : self.crop["vertical"][1], + self.crop["horizontal"][0] : self.crop["horizontal"][1], + ] + + # -- compute unrolled output loss + loss_unrolled = self.Loss(unrolled_out, y) + + # -- add LPIPS loss + if self.lpips: + if unrolled_out.shape[1] == 1: + # if only one channel, repeat for LPIPS + unrolled_out = unrolled_out.repeat(1, 3, 1, 1) + + # value for LPIPS needs to be in range [-1, 1] + loss_unrolled = loss_unrolled + self.lpips * torch.mean( + self.Loss_lpips(2 * unrolled_out - 1, 2 * y - 1) + ) + + # -- add unrolled loss to total loss + loss_v = loss_v + self.unrolled_output_factor * loss_unrolled + + # backward pass loss_v.backward() if self.clip_grad_norm is not None: @@ -641,6 +695,7 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None): save_idx=disp, output_dir=output_dir, crop=self.crop, + unrolled_output_factor=self.unrolled_output_factor, ) # update metrics with current metrics @@ -660,9 +715,16 @@ def evaluate(self, mean_loss, save_pt, epoch, disp=None): eval_loss += self.lpips * current_metrics["LPIPS_Vgg"] if self.use_mask and self.l1_mask: eval_loss += self.l1_mask * np.mean(np.abs(self.mask._mask.cpu().detach().numpy())) - return eval_loss + if self.unrolled_output_factor: + unrolled_loss = current_metrics["MSE_unrolled"] + if self.lpips is not None: + unrolled_loss += self.lpips * current_metrics["LPIPS_Vgg_unrolled"] + eval_loss += self.unrolled_output_factor * unrolled_loss else: - return current_metrics[self.metrics["metric_for_best_model"]] + eval_loss = current_metrics[self.metrics["metric_for_best_model"]] + + self.metrics["LOSS_TEST"].append(eval_loss) + return eval_loss def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): """ diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index eaace9a8..8b4b2867 100644 --- 
a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -547,6 +547,7 @@ def train_unrolled(config): pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, skip_unrolled=config.reconstruction.skip_unrolled, + return_unrolled_output=True if config.unrolled_output_factor > 0 else False, ).to(device) elif config.reconstruction.method == "unrolled_admm": recon = UnrolledADMM( @@ -559,6 +560,7 @@ def train_unrolled(config): pre_process=pre_process if pre_proc_delay is None else None, post_process=post_process if post_proc_delay is None else None, skip_unrolled=config.reconstruction.skip_unrolled, + return_unrolled_output=True if config.unrolled_output_factor > 0 else False, ).to(device) else: raise ValueError(f"{config.reconstruction.method} is not a supported algorithm") @@ -606,6 +608,7 @@ def train_unrolled(config): post_process_freeze=config.reconstruction.post_process.freeze, post_process_unfreeze=config.reconstruction.post_process.unfreeze, clip_grad=config.training.clip_grad, + unrolled_output_factor=config.unrolled_output_factor, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)