From fe2eff5dd5ce3aa1867693a2cd6086d7ef096818 Mon Sep 17 00:00:00 2001
From: Eric Bezzam
Date: Fri, 15 Sep 2023 09:38:58 +0000
Subject: [PATCH] Add logic for saving best model.

---
 configs/train_unrolledADMM.yaml |  2 ++
 lensless/recon/utils.py         | 58 +++++++++++++++++++++++++++++----
 scripts/recon/train_unrolled.py |  6 ++++
 3 files changed, 60 insertions(+), 6 deletions(-)

diff --git a/configs/train_unrolledADMM.yaml b/configs/train_unrolledADMM.yaml
index 65115aed..d19cee5d 100644
--- a/configs/train_unrolledADMM.yaml
+++ b/configs/train_unrolledADMM.yaml
@@ -90,6 +90,8 @@ simulation:
 training:
   batch_size: 8
   epoch: 50
+  metric_for_best_model: LPIPS_Vgg   # metric for selecting best model; null defaults to evaluation loss
+  save_every: null   # save model every `save_every` epochs; null saves only the best model
   #In case of instable training
   skip_NAN: True
   slow_start: False #float how much to reduce lr for first epoch
diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py
index b82b62b7..91966194 100644
--- a/lensless/recon/utils.py
+++ b/lensless/recon/utils.py
@@ -9,6 +9,7 @@
 
 import json
 import math
+import numpy as np
 import time
 from hydra.utils import get_original_cwd
 import os
@@ -250,6 +251,8 @@ def __init__(
         slow_start=None,
         skip_NAN=False,
         algorithm_name="Unknown",
+        metric_for_best_model=None,
+        save_every=None,
     ):
         """
         Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace `__.
@@ -288,6 +291,10 @@
             Whether to skip update if any gradiant are NAN (True) or to throw an error(False), by default False
         algorithm_name : str, optional
             Algorithm name for logging, by default "Unknown".
+        metric_for_best_model : str, optional
+            Metric used to select the best model. If None, the evaluation loss is used, by default None.
+        save_every : int, optional
+            Save the model every ``save_every`` epochs. If None, only the best model is saved, by default None.
         """
 
         self.device = recon._psf.device
@@ -379,6 +386,15 @@ def learning_rate_function(epoch):
             "n_iter": self.recon._n_iter,
             "algorithm": algorithm_name,
         }
+        self.metric_for_best_model = metric_for_best_model
+        if self.metric_for_best_model is not None:
+            assert self.metric_for_best_model in self.metrics.keys()
+        if self.metric_for_best_model in ("PSNR", "SSIM"):  # higher is better
+            self.best_eval_score = 0
+        else:  # lower is better (MSE, LPIPS, evaluation loss)
+            self.best_eval_score = np.inf
+        self.best_epoch_fn = None  # filename of current best checkpoint
+        self.save_every = save_every
 
         # Backward hook that detect NAN in the gradient and print the layer weights
         if not self.skip_NAN:
@@ -512,7 +528,16 @@ def evaluate(self, mean_loss, save_pt):
 
         with open(os.path.join(save_pt, "metrics.json"), "w") as f:
             json.dump(self.metrics, f)
 
-    def on_epoch_end(self, mean_loss, save_pt):
+        # return the metric used to select the best model
+        if self.metric_for_best_model is None:
+            eval_loss = current_metrics["MSE"]
+            if self.lpips is not None:
+                eval_loss += self.lpips * current_metrics["LPIPS_Vgg"]
+            return eval_loss
+        else:
+            return current_metrics[self.metric_for_best_model]
+
+    def on_epoch_end(self, mean_loss, save_pt, epoch):
         """
         Called at the end of each epoch.
 
         Parameters
         ----------
         mean_loss : float
             Mean loss of the last epoch.
         save_pt : str
             Path to save metrics dictionary to. If None, no logging of metrics.
+        epoch : int
+            Current epoch.
""" if save_pt is None: # Use current directory save_pt = os.getcwd() # save model - self.save(path=save_pt, include_optimizer=False) - self.evaluate(mean_loss, save_pt) + # self.save(path=save_pt, include_optimizer=False) + epoch_eval_metric = self.evaluate(mean_loss, save_pt) + new_best = False + if self.metric_for_best_model == "PSNR" or self.metric_for_best_model == "SSIM": + if epoch_eval_metric > self.best_eval_score: + self.best_eval_score = epoch_eval_metric + new_best = True + else: + if epoch_eval_metric < self.best_eval_score: + self.best_eval_score = epoch_eval_metric + new_best = True + + if new_best: + if self.best_epoch_fn is not None: + os.remove(os.path.join(save_pt, self.best_epoch_fn)) + self.best_epoch_fn = f"recon_BEST_epoch{epoch}.pt" + self.save(path=save_pt, include_optimizer=False, fn=self.best_epoch_fn) + + if self.save_every is not None and epoch % self.save_every == 0: + self.save(path=save_pt, include_optimizer=False, fn=f"recon_epoch{epoch}.pt") def train(self, n_epoch=1, save_pt=None, disp=-1): """ @@ -551,12 +596,13 @@ def train(self, n_epoch=1, save_pt=None, disp=-1): for epoch in range(n_epoch): print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}") mean_loss = self.train_epoch(self.train_dataloader, disp=disp) - self.on_epoch_end(mean_loss, save_pt) + # offset because of evaluate before loop + self.on_epoch_end(mean_loss, save_pt, epoch + 1) self.scheduler.step() print(f"Train time : {time.time() - start_time} s") - def save(self, path="recon", include_optimizer=False): + def save(self, path="recon", include_optimizer=False, fn="recon.pt"): # create directory if it does not exist if not os.path.exists(path): os.makedirs(path) @@ -573,4 +619,4 @@ def save(self, path="recon", include_optimizer=False): if include_optimizer: torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt")) # save recon - torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt")) + torch.save(self.recon.state_dict(), os.path.join(path, f"{fn}")) diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py index add6b8b8..7e07faaa 100644 --- a/scripts/recon/train_unrolled.py +++ b/scripts/recon/train_unrolled.py @@ -258,7 +258,11 @@ def train_unrolled(config): psf_path=psf_path, downsample=config.files.downsample, ) + # test set is after 1000 indices = dataset.allowed_idx[dataset.allowed_idx > 1000] + if config.files.n_files is not None: + indices = indices[: config.files.n_files] + train_set = Subset(dataset, indices) print("Train test size : ", len(train_set)) @@ -288,6 +292,8 @@ def train_unrolled(config): slow_start=config.training.slow_start, skip_NAN=config.training.skip_NAN, algorithm_name=algorithm_name, + metric_for_best_model=config.training.metric_for_best_model, + save_every=config.training.save_every, ) trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)