diff --git a/configs/upload_tapecam_mirflickr_ambient.yaml b/configs/upload_tapecam_mirflickr_ambient.yaml index 0d62238a..f1196ca3 100644 --- a/configs/upload_tapecam_mirflickr_ambient.yaml +++ b/configs/upload_tapecam_mirflickr_ambient.yaml @@ -4,13 +4,13 @@ defaults: - _self_ repo_id: "Lensless/TapeCam-Mirflickr-Ambient" -n_files: null +n_files: 16000 test_size: 0.15 # -- to match TapeCam without ambient light split: 100 # "first: first `nfiles*test_size` for test, `int`: test_size*split for test (interleaved) as if multimask with this many masks lensless: - dir: data/100_samples + dir: /dev/shm/tape_15k_ambient/all_measured_20240805-143922 ambient: True ext: ".png" diff --git a/lensless/recon/utils.py b/lensless/recon/utils.py index 65c6e238..9c75877f 100644 --- a/lensless/recon/utils.py +++ b/lensless/recon/utils.py @@ -875,15 +875,26 @@ def train_epoch(self, data_loader): # get batch flip_lr = None flip_ud = None - if self.train_random_flip: - X, y, psfs, flip_lr, flip_ud = batch - psfs = psfs.to(self.device) - elif self.train_multimask: - X, y, psfs = batch + X = batch[0].to(self.device) + y = batch[1].to(self.device) + if self.train_multimask or self.train_random_flip: + psfs = batch[2] psfs = psfs.to(self.device) else: - X, y = batch psfs = None + if self.train_random_flip: + flip_lr = batch[3] + flip_ud = batch[4] + + # if self.train_random_flip: + # X, y, psfs, flip_lr, flip_ud = batch + # psfs = psfs.to(self.device) + # elif self.train_multimask: + # X, y, psfs = batch + # psfs = psfs.to(self.device) + # else: + # X, y = batch + # psfs = None random_rotate = False if self.random_rotate: diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index 77e6e7d0..6a048d13 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -31,7 +31,7 @@ import wandb import logging import hydra -from hydra.utils import get_original_cwd +from hydra.utils import get_original_cwd, to_absolute_path import os import numpy as np import time @@ -227,7 +227,7 @@ def train_learned(config): display_res=config.files.image_res, alignment=config.alignment, bg_snr_range=config.files.background_snr_range, # TODO check if correct - bg_fp=config.files.background_fp, + bg_fp=to_absolute_path(config.files.background_fp), ) else: @@ -251,7 +251,7 @@ def train_learned(config): simulate_lensless=config.files.simulate_lensless, random_flip=config.files.random_flip, bg_snr_range=config.files.background_snr_range, - bg_fp=config.files.background_fp, + bg_fp=to_absolute_path(config.files.background_fp), ) test_set = HFDataset( @@ -271,7 +271,7 @@ def train_learned(config): n_files=config.files.n_files, simulation_config=config.simulation, bg_snr_range=config.files.background_snr_range, - bg_fp=config.files.background_fp, + bg_fp=to_absolute_path(config.files.background_fp), force_rgb=config.files.force_rgb, simulate_lensless=False, # in general evaluate on measured (set to False) ) @@ -341,6 +341,7 @@ def train_learned(config): return_items = test_set[_idx] lensless = return_items[0] lensed = return_items[1] + if test_set.bg_sim is not None: background = return_items[-1] if test_set.multimask or test_set.random_flip: @@ -379,6 +380,9 @@ def train_learned(config): if config.files.random_rotate or config.files.random_shifts: save_image(psf_recon[0].cpu().numpy(), f"psf_{_idx}.png") + save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png") + save_image(lensless[0].cpu().numpy(), f"lensless_raw_{_idx}.png") + # Reconstruct and plot image reconstruct_save( _idx, @@ -395,7 +399,6 @@ def train_learned(config): rotate_angle, shift, ) - save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png") if test_set.bg_sim is not None: # Reconstruct and plot background subtracted image reconstruct_save( @@ -665,9 +668,6 @@ def reconstruct_save( res_np = res_np / res_np.max() lensed_np = lensed[0].cpu().numpy() - lensless_np = lensless[0].cpu().numpy() - save_image(lensless_np, f"lensless_raw_{_idx}.png") - # -- plot lensed and res on top of each other cropped = False if hasattr(test_set, "alignment"):