Skip to content

Commit

Permalink
Accept relative paths for reference background. (#143)
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam authored Aug 21, 2024
1 parent 85812fd commit b439e08
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
4 changes: 2 additions & 2 deletions configs/upload_tapecam_mirflickr_ambient.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@ defaults:
- _self_

repo_id: "Lensless/TapeCam-Mirflickr-Ambient"
n_files: null
n_files: 16000
test_size: 0.15
# -- to match TapeCam without ambient light
split: 100 # "first: first `nfiles*test_size` for test, `int`: test_size*split for test (interleaved) as if multimask with this many masks

lensless:
dir: data/100_samples
dir: /dev/shm/tape_15k_ambient/all_measured_20240805-143922
ambient: True
ext: ".png"

Expand Down
23 changes: 17 additions & 6 deletions lensless/recon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -875,15 +875,26 @@ def train_epoch(self, data_loader):
# get batch
flip_lr = None
flip_ud = None
if self.train_random_flip:
X, y, psfs, flip_lr, flip_ud = batch
psfs = psfs.to(self.device)
elif self.train_multimask:
X, y, psfs = batch
X = batch[0].to(self.device)
y = batch[1].to(self.device)
if self.train_multimask or self.train_random_flip:
psfs = batch[2]
psfs = psfs.to(self.device)
else:
X, y = batch
psfs = None
if self.train_random_flip:
flip_lr = batch[3]
flip_ud = batch[4]

# if self.train_random_flip:
# X, y, psfs, flip_lr, flip_ud = batch
# psfs = psfs.to(self.device)
# elif self.train_multimask:
# X, y, psfs = batch
# psfs = psfs.to(self.device)
# else:
# X, y = batch
# psfs = None

random_rotate = False
if self.random_rotate:
Expand Down
16 changes: 8 additions & 8 deletions scripts/recon/train_learning_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import wandb
import logging
import hydra
from hydra.utils import get_original_cwd
from hydra.utils import get_original_cwd, to_absolute_path
import os
import numpy as np
import time
Expand Down Expand Up @@ -227,7 +227,7 @@ def train_learned(config):
display_res=config.files.image_res,
alignment=config.alignment,
bg_snr_range=config.files.background_snr_range, # TODO check if correct
bg_fp=config.files.background_fp,
bg_fp=to_absolute_path(config.files.background_fp),
)

else:
Expand All @@ -251,7 +251,7 @@ def train_learned(config):
simulate_lensless=config.files.simulate_lensless,
random_flip=config.files.random_flip,
bg_snr_range=config.files.background_snr_range,
bg_fp=config.files.background_fp,
bg_fp=to_absolute_path(config.files.background_fp),
)

test_set = HFDataset(
Expand All @@ -271,7 +271,7 @@ def train_learned(config):
n_files=config.files.n_files,
simulation_config=config.simulation,
bg_snr_range=config.files.background_snr_range,
bg_fp=config.files.background_fp,
bg_fp=to_absolute_path(config.files.background_fp),
force_rgb=config.files.force_rgb,
simulate_lensless=False, # in general evaluate on measured (set to False)
)
Expand Down Expand Up @@ -341,6 +341,7 @@ def train_learned(config):
return_items = test_set[_idx]
lensless = return_items[0]
lensed = return_items[1]

if test_set.bg_sim is not None:
background = return_items[-1]
if test_set.multimask or test_set.random_flip:
Expand Down Expand Up @@ -379,6 +380,9 @@ def train_learned(config):
if config.files.random_rotate or config.files.random_shifts:
save_image(psf_recon[0].cpu().numpy(), f"psf_{_idx}.png")

save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
save_image(lensless[0].cpu().numpy(), f"lensless_raw_{_idx}.png")

# Reconstruct and plot image
reconstruct_save(
_idx,
Expand All @@ -395,7 +399,6 @@ def train_learned(config):
rotate_angle,
shift,
)
save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
if test_set.bg_sim is not None:
# Reconstruct and plot background subtracted image
reconstruct_save(
Expand Down Expand Up @@ -665,9 +668,6 @@ def reconstruct_save(
res_np = res_np / res_np.max()
lensed_np = lensed[0].cpu().numpy()

lensless_np = lensless[0].cpu().numpy()
save_image(lensless_np, f"lensless_raw_{_idx}.png")

# -- plot lensed and res on top of each other
cropped = False
if hasattr(test_set, "alignment"):
Expand Down

0 comments on commit b439e08

Please sign in to comment.