Merged main into dataset
StefanPetersTM committed Aug 8, 2024
1 parent 741837b commit 28fa3bf
Showing 5 changed files with 95 additions and 121 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.rst
@@ -15,7 +15,7 @@ Added

- Option to pass background image to ``utils.io.load_data``.
- Option to set image resolution with ``hardware.utils.display`` function.
- Option to do background removal in ``utils.dataset``.
- Option to add simulated background in ``utils.dataset``.
- Auxiliary loss on reconstructing output from pre-processor (not working).
- Option to set focal range for MultiLensArray.
- Option to remove deadspace modelling for programmable mask.
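
A minimal sketch of what these two dataset options amount to (hypothetical tensors; the actual implementation is in lensless/utils/dataset.py further down): a background measurement is scaled and added to the clean lensless measurement, and the simplest form of removal is plain subtraction, as done for the "subtraction_" reconstructions in the training script below.

import torch

lensless = torch.rand(1, 240, 320, 3)    # hypothetical clean lensless measurement
background = torch.rand(1, 240, 320, 3)  # hypothetical background measurement

alpha = 0.5                                      # scale chosen from a target SNR (see dataset.py diff)
image_with_bg = lensless + alpha * background    # "add simulated background"
recovered = image_with_bg - alpha * background   # "background removal" by subtraction
assert torch.allclose(recovered, lensless)
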
49 changes: 4 additions & 45 deletions configs/train_tapecam_simulated_background.yaml
@@ -1,54 +1,13 @@
# python scripts/recon/train_learning_based.py -cn train_mirflickr_tape
defaults:
  - train_unrolledADMM
  - train_mirflickr_tape
  - _self_

torch_device: 'cuda:0'
#device_ids: [0, 1, 2, 3]
eval_disp_idx: [ 1, 2, 4, 5, 9 ]

wandb_project:
device_ids:


# Dataset
files:
  dataset: bezzam/TapeCam-Mirflickr-25K
  huggingface_dataset: True
  huggingface_psf: psf.png
  downsample: 1
  # TODO: these parameters should be in the dataset?
  image_res: [ 900, 1200 ] # used during measurement
  rotate: False # if measurement is upside-down
  save_psf: True

  background_fp: ""
  background_snr_range: [ -10, 10 ]

# TODO: these parameters should be in the dataset?
alignment:
  # when there is no downsampling
  top_left: [ 45, 95 ] # height, width
  height: 250

training:
  batch_size: 4
  epoch: 25
  eval_batch_size: 4

reconstruction:
  method: unrolled_admm
  unrolled_admm:
    # Number of iterations
    n_iter: 5
    # Hyperparameters
    mu1: 1e-4
    mu2: 1e-4
    mu3: 1e-4
    tau: 2e-4
  pre_process:
    network: UnetRes # UnetRes or DruNet or null
    depth: 4 # depth of each up/downsampling layer. Ignore if network is DruNet
    nc: [ 32,64,116,128 ]
  post_process:
    network: UnetRes # UnetRes or DruNet or null
    depth: 4 # depth of each up/downsampling layer. Ignore if network is DruNet
    nc: [ 32,64,116,128 ]
background_snr_range: [0,0]
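
The new background_fp / background_snr_range entries control how strongly a background image is mixed into each lensless measurement. A minimal sketch of the implied scaling, assuming the range is given in dB (this is the computation implemented in lensless/utils/dataset.py below; the function name is only illustrative):

import torch

def background_scale(lensless: torch.Tensor, background: torch.Tensor, target_snr_db: float) -> torch.Tensor:
    """Return alpha such that var(lensless) / var(alpha * background) matches the target SNR (dB)."""
    sig_var = torch.var(lensless.flatten())
    bg_var = torch.var(background.flatten())
    return torch.sqrt(sig_var / bg_var / (10 ** (target_snr_db / 10)))

# With background_snr_range: [ -10, 10 ], a target SNR is drawn uniformly in that interval for
# every sample, and the background is injected as: lensless + alpha * background.
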
29 changes: 7 additions & 22 deletions lensless/eval/benchmark.py
@@ -121,31 +121,16 @@ def benchmark(

flip_lr = None
flip_ud = None
if dataset.random_flip:
lensless, lensed, psfs, flip_lr, flip_ud = batch
psfs = psfs.to(device)
elif dataset.multimask:
lensless, lensed, psfs = batch
lensless = batch[0].to(device)
lensed = batch[1].to(device)
if dataset.multimask or dataset.random_flip:
psfs = batch[2]
psfs = psfs.to(device)
if dataset.bg is not None:
lensless, lensed, background = batch
else:
lensless, lensed = batch
psfs = None

# if hasattr(dataset, "multimask"):
# if dataset.multimask:
# lensless, lensed, psfs = batch
# psfs = psfs.to(device)
# else:
# lensless, lensed = batch
# psfs = None
# else:
# lensless, lensed = batch
# psfs = None

lensless = lensless.to(device)
lensed = lensed.to(device)
if dataset.random_flip:
flip_lr = batch[3]
flip_ud = batch[4]

# add shot noise
if snr is not None:
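
The refactored loop above switches from tuple unpacking to positional indexing into the batch. A self-contained sketch of the item ordering it assumes (produced by the dataset wrapper's __getitem__ shown in the next file); tensors and flag values here are dummies for illustration:

import torch

lensless, lensed = torch.rand(1, 8, 8, 3), torch.rand(1, 8, 8, 3)
psfs, flip_lr, flip_ud = torch.rand(1, 8, 8, 3), False, True
scaled_bg = torch.rand(1, 8, 8, 3)
multimask, random_flip, has_background = False, True, True  # hypothetical dataset flags

batch = [lensless, lensed]            # indices 0 and 1: measurement pair, always present
if multimask or random_flip:
    batch.append(psfs)                # index 2: per-item PSF (or mask label)
if random_flip:
    batch += [flip_lr, flip_ud]       # indices 3 and 4: flip flags
if has_background:
    batch.append(scaled_bg)           # last item: simulated background, i.e. batch[-1]

lensless, lensed = batch[0], batch[1]  # mirrors the indexing used in benchmark()
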
59 changes: 34 additions & 25 deletions lensless/utils/dataset.py
@@ -1288,9 +1288,9 @@ def __init__(
cache_dir=None,
single_channel_psf=False,
random_flip=False,
bg_snr_range=None,
bg_fp=None,
**kwargs,
bg_snr_range=None,
bg_fp=None,
**kwargs,
):
"""
Wrapper for lensless datasets on Hugging Face.
@@ -1514,7 +1514,7 @@ def __init__(
)
self.simulator = simulator

if bg_fp is not None and os.path.isfile(bg_fp):
if bg_fp is not None:
assert (
bg_snr_range is not None
), "Since a background path was provided, the SNR range should not be empty"
@@ -1524,13 +1524,13 @@
return_float=True,
flip=rotate,
)
self.bg = torch.from_numpy(bg)
self.bg_sim = torch.from_numpy(bg)
# Used for background noise addition
self.bg_snr_range = bg_snr_range
# Precomputing for efficiency (used in the SNR computations)
self.background_var = torch.var(self.bg.flatten())
self.background_var = torch.var(self.bg_sim.flatten())
else:
self.bg = None
self.bg_sim = None
self.bg_snr_range = None
self.background_var = None

@@ -1649,32 +1649,41 @@ def __getitem__(self, idx):
lensed = torch.flip(lensed, dims=(-3,))
psf_aug = torch.flip(psf_aug, dims=(-3,))

# return corresponding PSF
return_items = [lensless, lensed]
if self.multimask:
if self.return_mask_label:
return lensless, lensed, mask_label
return_items.append(mask_label)
else:
if not self.random_flip:
return lensless, lensed, self.psf[mask_label]
else:
return lensless, lensed, psf_aug, flip_lr, flip_ud
return_items.append(self.psf[mask_label])
if self.random_flip:
return_items.append(flip_lr)
return_items.append(flip_ud)
else:
if self.random_flip:
return lensless, lensed, psf_aug, flip_lr, flip_ud
# Add background to achieve desired SNR
if self.bg is not None:
sig_var = torch.var(lensless.flatten())
target_snr = np.random.uniform(self.bg_snr_range[0], self.bg_snr_range[1])
alpha = torch.sqrt(sig_var / self.background_var / (10 ** (target_snr / 10)))
return_items.append(psf_aug)
return_items.append(flip_lr)
return_items.append(flip_ud)

scaled_bg = alpha * self.bg
# Add background to achieve desired SNR
if self.bg_sim is not None:
sig_var = torch.var(lensless.flatten())
target_snr = np.random.uniform(self.bg_snr_range[0], self.bg_snr_range[1])
alpha = torch.sqrt(sig_var / self.background_var / (10 ** (target_snr / 10)))

# Add background noise to the target image
image_with_bg = lensless + scaled_bg
scaled_bg = alpha * self.bg_sim

return image_with_bg, lensed, scaled_bg
else:
return lensless, lensed
# Add background noise to the target image
image_with_bg = lensless + scaled_bg

return image_with_bg, lensed, scaled_bg
else:
return lensless, lensed

# replace the clean measurement with the background-corrupted one, and also return the scaled background
return_items[0] = image_with_bg
return_items.append(scaled_bg)

return return_items

def extract_roi(
self,
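
A self-contained check of the background-injection step in __getitem__ above, assuming the dB convention for bg_snr_range (the names below are local stand-ins, not dataset attributes):

import numpy as np
import torch

lensless = torch.rand(1, 240, 320, 3)         # stand-in for a clean measurement
bg_sim = torch.rand(1, 240, 320, 3)           # stand-in for the loaded background image
background_var = torch.var(bg_sim.flatten())  # precomputed once, as in __init__

sig_var = torch.var(lensless.flatten())
target_snr = np.random.uniform(-10, 10)       # drawn from bg_snr_range
alpha = torch.sqrt(sig_var / background_var / (10 ** (target_snr / 10)))
scaled_bg = alpha * bg_sim
image_with_bg = lensless + scaled_bg

achieved_snr = 10 * torch.log10(sig_var / torch.var(scaled_bg.flatten()))
print(round(float(achieved_snr), 2), round(float(target_snr), 2))  # should agree
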
77 changes: 49 additions & 28 deletions scripts/recon/train_learning_based.py
@@ -227,7 +227,7 @@ def train_learned(config):
display_res=config.files.image_res,
alignment=config.alignment,
bg_snr_range=config.files.background_snr_range, # TODO check if correct
bg=config.files.background_fp,
bg_fp=config.files.background_fp,
)

else:
@@ -251,7 +251,7 @@
simulate_lensless=config.files.simulate_lensless,
random_flip=config.files.random_flip,
bg_snr_range=config.files.background_snr_range,
bg=config.files.background_fp,
bg_fp=config.files.background_fp,
)

test_set = HFDataset(
@@ -271,7 +271,7 @@
n_files=config.files.n_files,
simulation_config=config.simulation,
bg_snr_range=config.files.background_snr_range,
bg=config.files.background_fp,
bg_fp=config.files.background_fp,
force_rgb=config.files.force_rgb,
simulate_lensless=False, # in general evaluate on measured (set to False)
)
@@ -338,17 +338,19 @@ def train_learned(config):

flip_lr = None
flip_ud = None
if test_set.random_flip:
lensless, lensed, psf_recon, flip_lr, flip_ud = test_set[_idx]
psf_recon = psf_recon.to(device)
elif test_set.multimask:
lensless, lensed, psf_recon = test_set[_idx]
return_items = test_set[_idx]
lensless = return_items[0]
lensed = return_items[1]
if test_set.bg_sim is not None:
background = return_items[-1]
if test_set.multimask or test_set.random_flip:
psf_recon = return_items[2]
psf_recon = psf_recon.to(device)
elif test_set.bg is not None:
lensless, lensed, bg = test_set[_idx]
else:
lensless, lensed = test_set[_idx]
psf_recon = psf.clone()
if test_set.random_flip:
flip_lr = return_items[3]
flip_ud = return_items[4]

rotate_angle = False
if config.files.random_rotate:
@@ -375,24 +377,41 @@
shift = tuple(shift)

if config.files.random_rotate or config.files.random_shifts:

save_image(psf_recon[0].cpu().numpy(), f"psf_{_idx}.png")

# Reconstruct and plot image
reconstruct_save(_idx, config, crop, i, lensed, lensless, psf, test_set, "")
reconstruct_save(
_idx,
config,
crop,
i,
lensed,
lensless,
psf,
test_set,
"",
flip_lr,
flip_ud,
rotate_angle,
shift,
)
save_image(lensed[0].cpu().numpy(), f"lensed_{_idx}.png")
if test_set.bg != None:
if test_set.bg_sim is not None:
# Reconstruct and plot background subtracted image
reconstruct_save(
_idx,
config,
crop,
i,
lensed,
(lensless - bg),
(lensless - background),
psf,
test_set,
"subtraction_",
flip_lr,
flip_ud,
rotate_angle,
shift,
)
log.info(f"Train test size : {len(train_set)}")
log.info(f"Test test size : {len(test_set)}")
Expand All @@ -416,9 +435,11 @@ def train_learned(config):
nc=config.reconstruction.post_process.nc,
device=device,
device_ids=device_ids,
concatenate_compensation=config.reconstruction.compensation[-1]
if config.reconstruction.compensation is not None
else False,
concatenate_compensation=(
config.reconstruction.compensation[-1]
if config.reconstruction.compensation is not None
else False
),
)
post_proc_delay = config.reconstruction.post_process.delay

@@ -499,9 +520,9 @@
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_intermediate=True
if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0
else False,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
),
compensation=config.reconstruction.compensation,
compensation_residual=config.reconstruction.compensation_residual,
)
@@ -516,9 +537,9 @@
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
skip_unrolled=config.reconstruction.skip_unrolled,
return_intermediate=True
if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0
else False,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
),
compensation=config.reconstruction.compensation,
compensation_residual=config.reconstruction.compensation_residual,
)
@@ -529,9 +550,9 @@
K=config.reconstruction.trainable_inv.K,
pre_process=pre_process if pre_proc_delay is None else None,
post_process=post_process if post_proc_delay is None else None,
return_intermediate=True
if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0
else False,
return_intermediate=(
True if config.unrolled_output_factor > 0 or config.pre_proc_aux > 0 else False
),
)
elif config.reconstruction.method == "multi_wiener":

@@ -688,7 +709,7 @@ def reconstruct_save(
if cropped and i == 0:
log.info(f"Cropped shape : {res_np.shape}")

save_image(res_np, f"lensless_recon_{_idx}.png")
save_image(res_np, f"lensless_recon_{fp}{_idx}.png")

plt.figure()
plt.imshow(lensed_np, alpha=0.4)
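
The "subtraction_" outputs saved above form a background-subtraction baseline: the same reconstruction path is run on the measurement with the known simulated background removed. A minimal sketch (reconstruct is a hypothetical placeholder for the actual reconstruction call; the item layout follows the dataset wrapper):

import torch

def reconstruct(measurement: torch.Tensor) -> torch.Tensor:
    """Hypothetical placeholder for the learned reconstruction."""
    return measurement  # identity, just to keep the sketch runnable

return_items = [torch.rand(1, 8, 8, 3), torch.rand(1, 8, 8, 3), 0.1 * torch.rand(1, 8, 8, 3)]
lensless, lensed, background = return_items[0], return_items[1], return_items[-1]

recon = reconstruct(lensless)                    # saved as lensless_recon_<idx>.png
recon_sub = reconstruct(lensless - background)   # saved as lensless_recon_subtraction_<idx>.png
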
