Save eval examples.
ebezzam committed Oct 5, 2023
1 parent 315d211 commit 18e46ad
Showing 3 changed files with 37 additions and 32 deletions.
9 changes: 5 additions & 4 deletions configs/train_unrolledADMM.yaml
@@ -7,18 +7,19 @@ hydra:
files:
dataset: data/DiffuserCam # Simulated: "mnist", "fashion_mnist", "cifar10", "CelebA". Measured: "DiffuserCam"
celeba_root: null # path to parent directory of CelebA: https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html
-psf: data/psf.tiff
+psf: data/psf/diffusercam_psf.tiff
+diffusercam_psf: True
n_files: null # null to use all for both train/test
downsample: 2 # factor by which to downsample the PSF, note that for DiffuserCam the PSF has 4x the resolution

torch: True
torch_device: 'cuda'


+# test set examples to visualize at the end of every epoch
+eval_disp_idx: [0, 1, 2, 3, 4]

display:
-# How many iterations to wait for intermediate plot.
-# Set to negative value for no intermediate plots.
-disp: 500
# Whether to plot results.
plot: True
# Gamma factor for plotting.
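The config change swaps the iteration-based display interval (`display.disp`) for `eval_disp_idx`, a fixed list of test-set indices that are reconstructed and saved at the end of every epoch. A minimal sketch of reading the new field with Hydra; the decorator arguments mirror the training script below, and the body is purely illustrative:

```python
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def main(config: DictConfig):
    # Indices of test-set examples to reconstruct and save after each epoch.
    print(config.eval_disp_idx)  # e.g. [0, 1, 2, 3, 4]


if __name__ == "__main__":
    main()
```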
52 changes: 29 additions & 23 deletions lensless/recon/utils.py
@@ -442,16 +442,14 @@ def detect_nan(grad):
if param.requires_grad:
param.register_hook(detect_nan)

-def train_epoch(self, data_loader, disp=-1):
+def train_epoch(self, data_loader):
"""
Train for one epoch.
Parameters
----------
data_loader : :py:class:`torch.utils.data.DataLoader`
Data loader to use for training.
-disp : int
-Display interval, if -1, no display
Returns
-------
@@ -481,15 +479,6 @@ def train_epoch(self, data_loader, disp=-1):
y_max = torch.amax(y, dim=(-1, -2, -3), keepdim=True) + eps
y = y / y_max

-if i % disp == 1:
-    img_pred = y_pred[0, 0].cpu().detach().numpy()
-    img_truth = y[0, 0].cpu().detach().numpy()
-
-    plt.imshow(img_pred)
-    plt.savefig(f"y_pred_{i-1}.png")
-    plt.imshow(img_truth)
-    plt.savefig(f"y_{i-1}.png")

self.optimizer.zero_grad(set_to_none=True)
# convert to CHW for loss and remove depth
y_pred = y_pred.reshape(-1, *y_pred.shape[-3:]).movedim(-1, -3)
@@ -542,7 +531,7 @@ def train_epoch(self, data_loader, disp=-1):

return mean_loss

-def evaluate(self, mean_loss, save_pt):
+def evaluate(self, mean_loss, save_pt, epoch, disp=None):
"""
Evaluate the reconstruction algorithm on the test dataset.
@@ -552,11 +541,26 @@ def evaluate(self, mean_loss, save_pt):
Mean loss of the last epoch.
save_pt : str
Path to save metrics dictionary to. If None, no logging of metrics.
+epoch : int
+Current epoch, used to name the output directory for saved examples.
+disp : list of int, optional
+Test set examples to visualize at the end of each epoch, by default None.
"""
if self.test_dataset is None:
return

+output_dir = None  # ensure defined for the benchmark call when disp is None
+if disp is not None:
+    output_dir = "eval_recon"
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
+    output_dir = os.path.join(output_dir, str(epoch))

# benchmarking
-current_metrics = benchmark(self.recon, self.test_dataset, batchsize=self.eval_batch_size)
+current_metrics = benchmark(
+    self.recon,
+    self.test_dataset,
+    batchsize=self.eval_batch_size,
+    save_idx=disp,
+    output_dir=output_dir,
+)

# update metrics with current metrics
self.metrics["LOSS"].append(mean_loss)
@@ -566,7 +570,7 @@ def evaluate(self, mean_loss, save_pt):
if save_pt:
# save dictionary metrics to file with json
with open(os.path.join(save_pt, "metrics.json"), "w") as f:
-json.dump(self.metrics, f)
+json.dump(self.metrics, f, indent=4)

# check best metric
if self.metrics["metric_for_best_model"] is None:
@@ -579,7 +583,7 @@ def evaluate(self, mean_loss, save_pt):
else:
return current_metrics[self.metrics["metric_for_best_model"]]

-def on_epoch_end(self, mean_loss, save_pt, epoch):
+def on_epoch_end(self, mean_loss, save_pt, epoch, disp=None):
"""
Called at the end of each epoch.
@@ -591,14 +595,16 @@ def on_epoch_end(self, mean_loss, save_pt, epoch):
Path to save metrics dictionary to. If None, no logging of metrics.
epoch : int
Current epoch.
+disp : list of int, optional
+Test set examples to visualize at the end of each epoch, by default None.
"""
if save_pt is None:
# Use current directory
save_pt = os.getcwd()

# save model
# self.save(path=save_pt, include_optimizer=False)
-epoch_eval_metric = self.evaluate(mean_loss, save_pt)
+epoch_eval_metric = self.evaluate(mean_loss, save_pt, epoch, disp=disp)
new_best = False
if (
self.metrics["metric_for_best_model"] == "PSNR"
@@ -619,7 +625,7 @@ def on_epoch_end(self, mean_loss, save_pt, epoch):
if self.save_every is not None and epoch % self.save_every == 0:
self.save(path=save_pt, include_optimizer=False, epoch=epoch)

-def train(self, n_epoch=1, save_pt=None, disp=-1):
+def train(self, n_epoch=1, save_pt=None, disp=None):
"""
Train the reconstruction algorithm.
@@ -629,21 +635,21 @@ def train(self, n_epoch=1, save_pt=None, disp=-1):
Number of epochs to train for, by default 1
save_pt : str, optional
Path to save metrics dictionary to. If None, use current directory, by default None
-disp : int, optional
-Display interval, if -1, no display. Default is -1.
+disp : list of int, optional
+Test set examples to visualize at the end of each epoch, by default None.
"""

start_time = time.time()

-self.evaluate(-1, save_pt)
+self.evaluate(-1, save_pt, epoch=0, disp=disp)
for epoch in range(n_epoch):
if self.logger is not None:
self.logger.info(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
else:
print(f"Epoch {epoch} with learning rate {self.scheduler.get_last_lr()}")
-mean_loss = self.train_epoch(self.train_dataloader, disp=disp)
+mean_loss = self.train_epoch(self.train_dataloader)
# offset because of evaluate before loop
-self.on_epoch_end(mean_loss, save_pt, epoch + 1)
+self.on_epoch_end(mean_loss, save_pt, epoch + 1, disp=disp)
self.scheduler.step()

if self.logger is not None:
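Taken together, `train` now takes a list of example indices instead of a scalar display interval. A hypothetical call site (`Trainer` construction elided; values are illustrative):

```python
# `trainer` is assumed to be a configured lensless.recon.utils.Trainer.
trainer.train(
    n_epoch=50,            # illustrative
    save_pt=None,          # metrics.json goes to the current directory
    disp=[0, 1, 2, 3, 4],  # reconstructions saved under eval_recon/<epoch>/
)
```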
8 changes: 3 additions & 5 deletions scripts/recon/train_unrolled.py
@@ -203,10 +203,6 @@ def prep_trainable_mask(config, psf, grayscale=False):
@hydra.main(version_base=None, config_path="../../configs", config_name="train_unrolledADMM")
def train_unrolled(config):

-disp = config.display.disp
-if disp < 0:
-    disp = None

save = config.save
if save:
save = os.getcwd()
@@ -268,6 +264,8 @@ def train_unrolled(config):
log.info(f"Train test size : {len(train_set)}")
log.info(f"Test test size : {len(test_set)}")

start_time = time.time()

# Load pre process model
@@ -344,7 +342,7 @@ def train_unrolled(config):
logger=log,
)

-trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=disp)
+trainer.train(n_epoch=config.training.epoch, save_pt=save, disp=config.eval_disp_idx)

log.info(f"Results saved in {save}")

