diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml index 3871be0d..12e7ddc7 100644 --- a/configs/train_unrolledADMM.yaml +++ b/configs/train_unrolledADMM.yaml @@ -7,7 +7,7 @@ hydra: files: dataset: data/DiffuserCam # Simulated : "mnist", "fashion_mnist", "cifar10", "CelebA". Measure :"DiffuserCam" celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html - psf: data/psf.tiff + psf: data/psf/diffusercam_psf.tiff diffusercam_psf: True n_files: null # null to use all for both train/test downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution @@ -15,10 +15,11 @@ files: torch: True torch_device: 'cuda' + +# test set example to visualize at the end of every epoch +eval_disp_idx: [0, 1, 2, 3, 4] + display: - # How many iterations to wait for intermediate plot. - # Set to negative value for no intermediate plots. - disp: 500 # Whether to plot results. plot: True # Gamma factor for plotting. diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 2ca758c6..cc4ca700 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -442,7 +442,7 @@ def detect_nan(grad): if param.requires_grad: param.register_hook(detect_nan) - def train_epoch(self, data_loader, disp=-1): + def train_epoch(self, data_loader): """ Train for one epoch. @@ -450,8 +450,6 @@ def train_epoch(self, data_loader, disp=-1): ---------- data_loader : :py:class:`torch.utils.data.DataLoader` Data loader to use for training. - disp : int - Display interval, if -1, no display Returns ------- @@ -481,15 +479,6 @@ def train_epoch(self, data_loader, disp=-1): y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps y = y / y_max - if i % disp == 1: - img_pred = y_pred[0, 0].cpu().detach().numpy() - img_truth = y[0, 0].cpu().detach().numpy() - - plt.imshow(img_pred) - plt.savefig(f"y_pred_{i-1}.png") - plt.imshow(img_truth) - plt.savefig(f"y_{i-1}.png") - self.optimizer.zero_grad(set_to_none=True) # convert to CHW for loss and remove depth y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3) @@ -542,7 +531,7 @@ def train_epoch(self, data_loader, disp=-1): return mean_loss - def evaluate(self, mean_loss, save_pt): + def evaluate(self, mean_loss, save_pt, epoch, disp=None): """ Evaluate the reconstruction algorithm on the test dataset. @@ -552,11 +541,26 @@ def evaluate(self, mean_loss, save_pt): Mean loss of the last epoch. save_pt : str Path to save metrics dictionary to. If None, no logging of metrics. + disp : list of int, optional + Test set examples to visualize at the end of each epoch, by default None. """ if self.test_dataset is None: return + + if disp is not None: + output_dir = os.path.join("eval_recon") + if not os.path.exists(output_dir): + os.mkdir(output_dir) + output_dir = os.path.join(output_dir, str(epoch)) + # benchmarking - current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size) + current_metrics = benchmark( + self.recon, + self.test_dataset, + batchsize=self.eval_batch_size, + save_idx=disp, + output_dir=output_dir, + ) # update metrics with current metrics self.metrics["LOSS"].append(mean_loss) @@ -566,7 +570,7 @@ def evaluate(self, mean_loss, save_pt): if save_pt: # save dictionary metrics to file with json with open(os.path.join(save_pt, "metrics.json"), "w") as f: - json.dump(self.metrics, f) + json.dump(self.metrics, f, indent=4) # check best metric if self.metrics["metric_for_best_model"] is None: @@ -579,7 +583,7 @@ def evaluate(self, mean_loss, save_pt): else: return current_metrics[self.metrics["metric_for_best_model"]] - def on_epoch_end(self, mean_loss, save_pt, epoch): + def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None): """ Called at the end of each epoch. @@ -591,6 +595,8 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): Path to save metrics dictionary to. If None, no logging of metrics. epoch : int Current epoch. + disp : list of int, optional + Test set examples to visualize at the end of each epoch, by default None. """ if save_pt is None: # Use current directory @@ -598,7 +604,7 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): # save model # self.save(path=save_pt, include_optimizer=False) - epoch_eval_metric = self.evaluate(mean_loss, save_pt) + epoch_eval_metric = self.evaluate(mean_loss, save_pt, epoch, disp=disp) new_best = False if ( self.metrics["metric_for_best_model"] == "PSNR" @@ -619,7 +625,7 @@ def on_epoch_end(self, mean_loss, save_pt, epoch): if self.save_every is not None and epoch % self.save_every == 0: self.save(path=save_pt, include_optimizer=False, epoch=epoch) - def train(self, n_epoch=1, save_pt=None, disp=-1): + def train(self, n_epoch=1, save_pt=None, disp=None): """ Train the reconstruction algorithm. @@ -629,21 +635,21 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): Number of epochs to train for, by default 1 save_pt : str, optional Path to save metrics dictionary to. If None, use current directory, by default None - disp : int, optional - Display interval, if -1, no display. Default is -1. + disp : list of int, optional + test set examples to visualize at the end of each epoch, by default None. """ start_time = time.time() - self.evaluate(-1, save_pt) + self.evaluate(-1, save_pt, epoch=0, disp=disp) for epoch in range(n_epoch): if self.logger is not None: self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") else: print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") - mean_loss = self.train_epoch(self.train_dataloader, disp=disp) + mean_loss = self.train_epoch(self.train_dataloader) # offset because of evaluate before loop - self.on_epoch_end(mean_loss, save_pt, epoch + 1) + self.on_epoch_end(mean_loss, save_pt, epoch + 1, disp=disp) self.scheduler.step() if self.logger is not None: diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index c9be1ee4..735ebf9b 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -203,10 +203,6 @@ def prep_trainable_mask(config, psf, grayscale=False): @hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM") def train_unrolled(config): - disp = config.display.disp - if disp < 0: - disp = None - save = config.save if save: save = os.getcwd() @@ -268,6 +264,8 @@ def train_unrolled(config): log.info(f"Train test size : {len(train_set)}") log.info(f"Test test size : {len(test_set)}") + raise ValueError + start_time = time.time() # Load pre process model @@ -344,7 +342,7 @@ def train_unrolled(config): logger=log, ) - trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp) + trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx) log.info(f"Results saved in {save}")