Add logic for saving best model.
ebezzam committed Sep 15, 2023
1 parent fb7b8d4 commit fe2eff5
Showing 3 changed files with 60 additions and 6 deletions.
2 changes: 2 additions & 0 deletions configs/train_unrolledADMM.yaml
@@ -90,6 +90,8 @@ simulation:
training:
  batch_size: 8
  epoch: 50
  metric_for_best_model: LPIPS_Vgg
  save_every: null
  # in case of unstable training
  skip_NAN: True
  slow_start: False  # float, how much to reduce the learning rate for the first epoch
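With the Hydra configuration used by this repository, the new options can also be set at launch time instead of editing the YAML, e.g. (assuming standard Hydra override syntax; PSNR and 10 are illustrative values):

python scripts/recon/train_unrolled.py training.metric_for_best_model=PSNR training.save_every=10

Note that metric_for_best_model must be one of the keys of the trainer's metrics dictionary; the assert in the constructor below enforces this.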
58 changes: 52 additions & 6 deletions lensless/recon/utils.py
@@ -9,6 +9,7 @@

import json
import math
import numpy as np
import time
from hydra.utils import get_original_cwd
import os
@@ -250,6 +251,8 @@ def __init__(
        slow_start=None,
        skip_NAN=False,
        algorithm_name="Unknown",
        metric_for_best_model=None,
        save_every=None,
    ):
"""
Class to train a reconstruction algorithm. Inspired by Trainer from `HuggingFace <https://huggingface.co/docs/transformers/main_classes/trainer>`__.
@@ -288,6 +291,10 @@ def __init__(
            Whether to skip the update if any gradients are NaN (True) or to throw an error (False), by default False.
        algorithm_name : str, optional
            Algorithm name for logging, by default "Unknown".
        metric_for_best_model : str, optional
            Metric to use for saving the best model. If None, the evaluation loss is used. Default is None.
        save_every : int, optional
            Save the model every ``save_every`` epochs. If None, only the best model is saved.
        """
        self.device = recon._psf.device
@@ -379,6 +386,15 @@ def learning_rate_function(epoch):
"n_iter": self.recon._n_iter,
"algorithm": algorithm_name,
}
self.metric_for_best_model = metric_for_best_model
if self.metric_for_best_model is not None:
assert self.metric_for_best_model in self.metrics.keys()
if self.metric_for_best_model == "PSNR" or self.metric_for_best_model == "SSIM":
self.best_eval_score = 0
else:
self.best_eval_score = np.inf
self.best_epoch_fn = None
self.save_every = save_every
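The initialization above starts from the worst possible score, so the first evaluation always becomes the new best. A standalone sketch of this bookkeeping (helper names are illustrative, not part of the commit):

import numpy as np

HIGHER_IS_BETTER = ("PSNR", "SSIM")  # metrics where larger values are better

def init_best_score(metric):
    # worst possible starting score: 0 for higher-is-better metrics, +inf otherwise
    return 0 if metric in HIGHER_IS_BETTER else np.inf

def is_improvement(metric, current, best):
    return current > best if metric in HIGHER_IS_BETTER else current < best

assert is_improvement("PSNR", 25.0, init_best_score("PSNR"))
assert is_improvement("LPIPS_Vgg", 0.30, init_best_score("LPIPS_Vgg"))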

        # Backward hook that detects NaN in the gradient and prints the layer weights
        if not self.skip_NAN:
@@ -512,7 +528,16 @@ def evaluate(self, mean_loss, save_pt):
        with open(os.path.join(save_pt, "metrics.json"), "w") as f:
            json.dump(self.metrics, f)

    def on_epoch_end(self, mean_loss, save_pt):
        # check best metric
        if self.metric_for_best_model is None:
            eval_loss = current_metrics["MSE"]
            if self.lpips is not None:
                eval_loss += self.lpips * current_metrics["LPIPS_Vgg"]
            return eval_loss
        else:
            return current_metrics[self.metric_for_best_model]
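evaluate() thus returns the score that on_epoch_end() compares against the best so far: the chosen metric if one was given, otherwise MSE optionally combined with weighted LPIPS. A self-contained sketch of that fallback (eval_score and its arguments are illustrative names):

def eval_score(current_metrics, metric_for_best_model=None, lpips=None):
    if metric_for_best_model is None:
        # fall back to MSE, optionally adding LPIPS scaled by its loss weight
        score = current_metrics["MSE"]
        if lpips is not None:
            score += lpips * current_metrics["LPIPS_Vgg"]
        return score
    return current_metrics[metric_for_best_model]

print(eval_score({"MSE": 0.02, "LPIPS_Vgg": 0.30}, lpips=1.0))  # 0.02 + 1.0 * 0.30 = 0.32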

    def on_epoch_end(self, mean_loss, save_pt, epoch):
        """
        Called at the end of each epoch.
@@ -522,14 +547,34 @@ def on_epoch_end(self, mean_loss, save_pt):
            Mean loss of the last epoch.
        save_pt : str
            Path to save metrics dictionary to. If None, no logging of metrics.
        epoch : int
            Current epoch.
        """
        if save_pt is None:
            # Use current directory
            save_pt = os.getcwd()

        # save model
        self.save(path=save_pt, include_optimizer=False)
        self.evaluate(mean_loss, save_pt)
        # self.save(path=save_pt, include_optimizer=False)
        epoch_eval_metric = self.evaluate(mean_loss, save_pt)
        new_best = False
        if self.metric_for_best_model in ("PSNR", "SSIM"):
            # higher is better
            if epoch_eval_metric > self.best_eval_score:
                self.best_eval_score = epoch_eval_metric
                new_best = True
        else:
            # lower is better
            if epoch_eval_metric < self.best_eval_score:
                self.best_eval_score = epoch_eval_metric
                new_best = True

        if new_best:
            # remove the previous best checkpoint and save the new one
            if self.best_epoch_fn is not None:
                os.remove(os.path.join(save_pt, self.best_epoch_fn))
            self.best_epoch_fn = f"recon_BEST_epoch{epoch}.pt"
            self.save(path=save_pt, include_optimizer=False, fn=self.best_epoch_fn)

        if self.save_every is not None and epoch % self.save_every == 0:
            self.save(path=save_pt, include_optimizer=False, fn=f"recon_epoch{epoch}.pt")
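Because the previous best checkpoint is removed whenever a new best is found, the output directory holds at most one recon_BEST_epoch*.pt file alongside the periodic recon_epoch*.pt snapshots. A small simulation of that bookkeeping with made-up per-epoch PSNR values (standalone, not part of the commit):

scores = [20.1, 22.3, 21.9, 23.0, 22.8]  # hypothetical PSNR after each epoch
best_score, best_fn, save_every = 0.0, None, 2
kept = set()
for epoch, score in enumerate(scores, start=1):
    if score > best_score:  # PSNR: higher is better
        best_score = score
        if best_fn is not None:
            kept.discard(best_fn)  # the old best checkpoint is deleted
        best_fn = f"recon_BEST_epoch{epoch}.pt"
        kept.add(best_fn)
    if save_every is not None and epoch % save_every == 0:
        kept.add(f"recon_epoch{epoch}.pt")
print(sorted(kept))  # ['recon_BEST_epoch4.pt', 'recon_epoch2.pt', 'recon_epoch4.pt']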

    def train(self, n_epoch=1, save_pt=None, disp=-1):
        """
@@ -551,12 +596,13 @@ def train(self, n_epoch=1, save_pt=None, disp=-1):
        for epoch in range(n_epoch):
            print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
            mean_loss = self.train_epoch(self.train_dataloader, disp=disp)
            self.on_epoch_end(mean_loss, save_pt)
            # offset by one because evaluate() is called once before the loop
            self.on_epoch_end(mean_loss, save_pt, epoch + 1)
            self.scheduler.step()

        print(f"Train time : {time.time() - start_time} s")

def save(self, path="recon", include_optimizer=False):
def save(self, path="recon", include_optimizer=False, fn="recon.pt"):
# create directory if it does not exist
if not os.path.exists(path):
os.makedirs(path)
@@ -573,4 +619,4 @@ def save(self, path="recon", include_optimizer=False):
        if include_optimizer:
            torch.save(self.optimizer.state_dict(), os.path.join(path, "optim.pt"))
        # save recon
        torch.save(self.recon.state_dict(), os.path.join(path, "recon.pt"))
        torch.save(self.recon.state_dict(), os.path.join(path, fn))
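One way to restore the best weights after training (a sketch: load_best_checkpoint and model_dir are illustrative names, and recon must be an already-constructed model with the same architecture as the one trained):

import glob
import os

import torch

def load_best_checkpoint(recon, model_dir):
    candidates = glob.glob(os.path.join(model_dir, "recon_BEST_epoch*.pt"))
    assert len(candidates) == 1, "expected exactly one best checkpoint"
    recon.load_state_dict(torch.load(candidates[0], map_location="cpu"))
    return recon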
6 changes: 6 additions & 0 deletions scripts/recon/train_unrolled.py
@@ -258,7 +258,11 @@ def train_unrolled(config):
        psf_path=psf_path,
        downsample=config.files.downsample,
    )
    # test set is the first 1000 files, so train on the rest
    indices = dataset.allowed_idx[dataset.allowed_idx > 1000]
    if config.files.n_files is not None:
        indices = indices[: config.files.n_files]

    train_set = Subset(dataset, indices)
    print("Train set size : ", len(train_set))

@@ -288,6 +292,8 @@ def train_unrolled(config):
        slow_start=config.training.slow_start,
        skip_NAN=config.training.skip_NAN,
        algorithm_name=algorithm_name,
        metric_for_best_model=config.training.metric_for_best_model,
        save_every=config.training.save_every,
    )

    trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
