Skip to content

Commit

Permalink
feature: Add background estimator
Browse files Browse the repository at this point in the history
Added a new background preprocessor to estimate the background/ambient noise. This block can be simply added in front of the existing pre-inv-post pipeline
  • Loading branch information
StefanPetersTM committed Dec 17, 2024
1 parent 99f9ce2 commit 18e8e66
Show file tree
Hide file tree
Showing 4 changed files with 662 additions and 9 deletions.
31 changes: 31 additions & 0 deletions configs/train_background_estimator.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# python scripts/recon/background_estimator.py -cn configs/train_background_estimator

defaults:
- train_mirflickr_multilens_ambient
- _self_

wandb_project: null
device_ids: [1]
torch_device: cuda:1

# Dataset
files:
dataset: Lensless/TapeCam-Mirflickr-Ambient-100 # 100 examples
image_res: [600, 600]
per_pixel_color_shift: True
per_pixel_color_shift_range: [ 0.8, 1.2 ]

alignment:
# when there is no downsampling
top_left: [85, 185] # height, width
height: 178

optimizer:
type: AdamW
cosine_decay_warmup: True
final_lr: 2e-5

training:
lr: 7e-5
batch_size: 4
num_epochs: 4
18 changes: 18 additions & 0 deletions lensless/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1422,6 +1422,8 @@ def __init__(
cache_dir=None,
single_channel_psf=False,
random_flip=False,
per_pixel_color_shift=False,
per_pixel_color_shift_range=None,
bg_snr_range=None,
bg_fp=None,
**kwargs,
Expand Down Expand Up @@ -1461,6 +1463,10 @@ def __init__(
If multimask dataset, save the simulated PSFs.
random_flip : bool, optional
If True, randomly flip the lensless images vertically and horizonally with equal probability. By default, no flipping.
per_pixel_color_shift: bool, optional
If True: randomly shift the color of each pixel in the lensless image. By default, no color shift.
per_pixel_color_shift_range: list, optional
Range of possible color shifts for each pixel in the lensless image. Used in conjunction with 'per_pixel_color_shift'.
simulation_config : dict, optional
Simulation parameters for PSF if using a mask pattern.
bg_snr_range : list, optional
Expand Down Expand Up @@ -1488,6 +1494,8 @@ def __init__(

# augmentation
self.random_flip = random_flip
self.per_pixel_color_shift = per_pixel_color_shift
self.per_pixel_color_shift_range = per_pixel_color_shift_range

# deduce downsampling factor from the first image
data_0 = self.dataset[0]
Expand Down Expand Up @@ -1826,6 +1834,16 @@ def __getitem__(self, idx):
psf_aug = torch.flip(psf_aug, dims=(-3,))
background = torch.flip(background, dims=(-3,))

if self.per_pixel_color_shift:
color_filter = torch.empty(1, 1, 1, lensless.shape[-1], device=lensless.device).uniform_(*self.per_pixel_color_shift_range)
lensless = lensless * color_filter
lensed = lensed * color_filter

# Uncomment to visualize the effect of color shift
#save_image(background.squeeze().cpu().numpy(), f"background_pre{idx}.png")
background = background * color_filter
#save_image(background.squeeze().cpu().numpy(), f"background_post{idx}.png")

return_items = [lensless, lensed]
if self.multimask:
if self.return_mask_label:
Expand Down
Loading

0 comments on commit 18e8e66

Please sign in to comment.