Skip to content

Commit

Permalink
Log reconstruction images.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Apr 29, 2024
1 parent 0bb90d5 commit 3baab47
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 3 deletions.
12 changes: 11 additions & 1 deletion lensless/eval/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tqdm import tqdm
import os
import numpy as np
import wandb

try:
import torch
Expand All @@ -37,6 +38,9 @@ def benchmark(
unrolled_output_factor=False,
return_average=True,
snr=None,
use_wandb=False,
label=None,
epoch=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -179,7 +183,13 @@ def benchmark(
prediction_np = prediction.cpu().numpy()[i]
# switch to [H, W, C] for saving
prediction_np = np.moveaxis(prediction_np, 0, -1)
save_image(prediction_np, fp=os.path.join(output_dir, f"{_batch_idx}.png"))
fp = os.path.join(output_dir, f"{_batch_idx}.png")
save_image(prediction_np, fp=fp)

if use_wandb:
assert epoch is not None, "epoch must be provided for wandb logging"
log_key = f"{_batch_idx}_{label}" if label is not None else f"{_batch_idx}"
wandb.log({log_key: wandb.Image(fp)}, step=epoch)

# normalization
prediction_max = torch.amax(prediction, dim=(-1, -2, -3), keepdim=True)
Expand Down
7 changes: 6 additions & 1 deletion lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -801,6 +801,8 @@ def evaluate(self, mean_loss, epoch, disp=None):
output_dir=output_dir,
crop=self.crop,
unrolled_output_factor=self.unrolled_output_factor,
use_wandb=self.use_wandb,
epoch=epoch,
)

# update metrics with current metrics
Expand Down Expand Up @@ -860,6 +862,9 @@ def evaluate(self, mean_loss, epoch, disp=None):
output_dir=output_dir,
crop=self.crop,
unrolled_output_factor=self.unrolled_output_factor,
use_wandb=self.use_wandb,
label=eval_set,
epoch=epoch,
)

# add metrics to dictionary
Expand Down Expand Up @@ -944,7 +949,7 @@ def train(self, n_epoch=1, save_pt=None, disp=None):

start_time = time.time()

self.evaluate(-1, epoch=0, disp=disp)
self.evaluate(mean_loss=1, epoch=0, disp=disp)
for epoch in range(n_epoch):

# add extra components (if specified)
Expand Down
4 changes: 3 additions & 1 deletion scripts/recon/train_unrolled.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,15 @@ def train_unrolled(config):
use_cuda = False
if "cuda" in config.torch_device and torch.cuda.is_available():
# if config.torch_device == "cuda" and torch.cuda.is_available():
log.info("Using GPU for training.")
log.info(f"Using GPU for training. Main device : {config.torch_device}")
device = config.torch_device
use_cuda = True
else:
log.info("Using CPU for training.")
device = "cpu"
device_ids = config.device_ids
if device_ids is not None:
log.info(f"Using multiple GPUs : {device_ids}")

# load dataset and create dataloader
train_set = None
Expand Down

0 comments on commit 3baab47

Please sign in to comment.